pythonでNMSを実装し、複数の矩形をマージする

pythonでNMS (Non-Maximum Suppression)を実装する方法を紹介します。NMSはSSDやYOLOといった物体検出AIの後処理として使用されるアルゴリズムです。

物体検出AIでは、人の顔などの検出目標の周辺にたくさんの矩形が出力されてしまいます。それらの矩形をマージしてすっきりさせるためのアルゴリズムがNMSです。

NMSを適用したイメージは下の方に掲載している画像を参照ください。

この記事ではNMSの実装方法を中心に説明します。NMSの処理の内容は詳しく説明している方がいらっしゃいますので、こちらの外部記事をご参照ください。

この記事でできること

NMS (Non-Maximum Suppression)を使って、たくさんの矩形(長方形)を統合してすっきりさせます。次の画像はNMS適用前の画像イメージです。

NMSをかける前のたくさんの矩形
NMSをかける前のたくさんの矩形

この例では、いらすとやの「人脈・コネがない人のイラスト」の顔の周辺に、たくさんの矩形を描画しました。矩形だけでなく、クラス名やスコア値も表示されてごちゃごちゃになっています。NMSをかける前のSSDやYOLOの出力はこのようになりがちです。

この画像に対してNMSをかけると次のようになります。

NMSをかけた後のすっきりした画像
NMSをかけた後のすっきりした画像

かなりすっきりしました。重複する矩形をマージするのがNMSの処理です。

上の画像のたくさんの矩形はランダムに付与したもので、それらをNMSで統合しました。便宜上、矩形の統合と表現しましたが、NMSの実際の処理では各矩形の信頼度(スコア値)を参照し、最も信頼度の高いものが残るようになっています。

NMS(Non-Maximum Suppresion)とは

NMSをすごく簡単に単純に説明すると、信頼度の高い矩形と大きく重なっている矩形を消去するアルゴリズムです。信頼度の最も高い矩形を選んで、その矩形と大きく重なっていて、かつ、信頼度の低い矩形を消していきます。

この説明で「重なり」とは、2つの矩形の領域の共通部分のことで、重なり具合を表す定量的な指標としてIoU(Intersection over Union)を使います。「大きく重なっている」ことの判定は、IoUがある閾値以上であるかで判定します。この閾値をIoU閾値などと呼び、0.5等の数値が選ばれます。

IoUについては次にリンクする依存の記事で実装方法を説明しました。本エントリーでは、この記事で説明した関数”iou(a, b)” (矩形aとbのIoUを計算する関数)を使ってNMSを実装します。

NMS(Non-Maximum Suppresion)を計算する関数の実装例

まず、NMSを計算する関数のサンプルコードを次に記載します。

def nms(bboxes, scores, classes, iou_threshold=0.5):
    new_bboxes = [] # NMS適用後の矩形リスト
    new_scores = [] # NMS適用後の信頼度(スコア値)リスト
    new_classes = [] # NMS適用後のクラスのリスト

    while len(bboxes) > 0:
        # スコア最大の矩形のインデックスを取得
        argmax = scores.index(max(scores))

        # スコア最大の矩形、スコア値、クラスをそれぞれのリストから消去
        bbox = bboxes.pop(argmax)
        score = scores.pop(argmax)
        clss = classes.pop(argmax)        

        # スコア最大の矩形と、対応するスコア値、クラスをNMS適用後のリストに格納
        new_bboxes.append(bbox)
        new_scores.append(score)
        new_classes.append(clss)

        pop_i = []
        for i, bbox_tmp in enumerate(bboxes):
            # スコア最大の矩形bboxとのIoUがiou_threshold以上のインデックスを取得
            if iou(bbox, bbox_tmp) >= iou_threshold:
                pop_i.append(i)

        # 取得したインデックス(pop_i)の矩形、スコア値、クラスをそれぞれのリストから消去
        for i in pop_i[::-1]:
            bboxes.pop(i)
            scores.pop(i)
            classes.pop(i)

    return new_bboxes, new_scores, new_classes

順番に説明していきます。

def nms(bboxes, scores, classes, iou_threshold=0.5):
    new_bboxes = [] # NMS適用後の矩形リスト
    new_scores = [] # NMS適用後の信頼度(スコア値)リスト
    new_classes = [] # NMS適用後のクラスのリスト

この関数nms()では、矩形のリストbboxesと、対応する信頼度(スコア値)のリストscores、クラスのリストclasses、IoU閾値iou_thresholdを引数にとります。

そして、NMS適用後の、矩形のリストnew_bboxes 、対応する信頼度(スコア値)のリストnew_scores 、クラスのリストnew_classes を返します。

各リストに含まれる矩形、信頼度、クラスのインデックスは、それぞれ対応している必要があります
。つまり、i番目の矩形bboxes[i]に対応する信頼度がscores[i]に格納され、クラスも同様にclasses[i]に格納されていることが前提です。

    while len(bboxes) > 0:
        # スコア最大の矩形のインデックスを取得
        argmax = scores.index(max(scores))

        # スコア最大の矩形、スコア値、クラスをそれぞれのリストから消去
        bbox = bboxes.pop(argmax)
        score = scores.pop(argmax)
        clss = classes.pop(argmax)        

この関数の処理では、入力された矩形のリスト”bboxes”から、まずはスコア値最大の矩形”bbox”をpopメソッドで削除し、whileループでbboxesが空になるまで以下の処理を繰り返します。

        for i, bbox_tmp in enumerate(bboxes):
            # スコア最大の矩形bboxとのIoUがiou_threshold以上のインデックスを取得
            if iou(bbox, bbox_tmp) >= iou_threshold:
                pop_i.append(i)

        # 取得したインデックス(pop_i)の矩形、スコア値、クラスをそれぞれのリストから消去
        for i in pop_i[::-1]:
            bboxes.pop(i)
            scores.pop(i)
            classes.pop(i)

ここでは、矩形が格納された配列”bboxes”から、スコア値最大の矩形”bbox”とのIoUが”iou_threshold”以上の矩形を取得し、削除します。

上記で紹介した関数iou()を用いて”bbox”と配列”bboxes”内の矩形”bbox_tmp”のIoUを各々計算し、IoUの値がiou_threshold以上であるなら、リスト”pop_i”にインデックスを追加します。リスト”pop_i”は、リスト”bboxes”等から削除すべきインデックスのリストで、popメソッドを使って、リストbboxes、scores、classesから要素を削除しています。

ここで注意すべきは、popメソッドでリストから要素を削除すると、削除された要素の分だけインデックスがずれてしまうことです。popを適用する要素を順番に気をつけないと、取得したインデックス”pop_i”と、削除したい矩形のインデックスがずれてしまいます。

この問題を回避するため、popで削除するインデックスは値が大きい順にします。取得したインデックスのリスト”pop_i”のうち値の大きな要素から選び、リスト”bboxes”等から要素をpopしていきます。これを実現しているのが”for i in pop_i[::-1]:”というfor文です。”pop_i”に含まれるインデックスのうち、最大のものから要素を選んでいき、リスト”bboxes”等から要素をpopで削除していきます。

上記の処理をwhile文で繰り返し、最終的に得られたnew_bboxes 、new_scores 、new_classesがNMS適用後の矩形や信頼度、クラスです。

関数nms()の使用例

冒頭で説明したいらすとやの画像の例を使って、上で紹介した関数nms()の使用例を紹介します。サンプルコードは次の通りです。

このソースコードでは、まず人の顔周辺にランダムに矩形を生成し、NMS適用前と適用後の矩形を描画した画像を保存する処理となっています。

from PIL import Image
from draw import draw_rect # 過去記事で紹介した関数

input_img_name="img/jinmyaku_nai.png" 
# 人脈・コネがない人のイラスト https://www.irasutoya.com/2015/09/blog-post_47.html
out_img_rnd_rects_name="img/jinmyaku_nai_rnd_rect.png"
out_img_name="img/jinmyaku_nai_rect_nms.png"

if __name__ == '__main__':
    import random
    import time
    import copy
    random.seed(314)
    img = Image.open(input_img_name)

    iou_threshold = 0.3
    # 顔の矩形リスト
    rects = [[161,50,260,161], 
             [318,135,418,243], 
             [46,267,142,373], 
             [181,244,275,352], 
             [340,351,440,455], 
             [500,257,598,371], 
             [648,33,742,146]]
    classes = ["woman", "woman", "man", "man", "woman", "man", "man"]
    # 信頼度の値は適当
    scores = [0.51, 0.61, 0.71, 0.75, 0.81, 0.85, 0.91]

    # ランダムに矩形を生成
    new_rects = []
    new_classes = []
    new_scores = []
    for i, rect in enumerate(rects):
        new_rects.append(rects[i])
        new_classes.append(classes[i])
        new_scores.append(scores[i])
        for dmy in range(10):
            new_rect = [a + 30*random.random() for a in rect]
            new_rects.append(new_rect)

            if random.random() < 0.5: cl = "man"
            else: cl = "woman"
            new_classes.append(cl)
            scr = random.random()*0.1
            new_scores.append(scr)

    # NMS適用前の画像を保存
    img_r = copy.copy(img)
    img_r = draw_rect(img_r, new_rects, new_scores, new_classes)    
    img_r.save(out_img_rnd_rects_name)

    # NMSを実行
    nms_rects, nms_scores, nms_classes = nms(new_rects, new_scores, new_classes, iou_threshold=iou_threshold)

    # NMS適用後の画像を保存
    img = draw_rect(img, nms_rects, nms_scores, nms_classes)  
    img.save(out_img_name)

このソースで保存される画像ファイル”jinmyaku_nai_rnd_rect.png”と”jinmyaku_nai_rect_nms.png”は、冒頭に掲載した2つの画像にそれぞれ対応します。

なお、このソースコードでは自作の関数”draw_rect()”を使用していますが、関数”draw_rect()”は画像に矩形やテキストを描画する関数で、次の記事で紹介したものです。

おわりに

今回はNMSのpythonでの実装方法を説明しました。numpyやtensorflow等の特殊なライブラリを使用しない、シンプルな実装方法を紹介したつもりです。

次の記事では、NMSの発展手法であるSoft-NMSも紹介していますので、興味ありましたら参照してみて下さい。

コメント

タイトルとURLをコピーしました