交通場景、車道分割演算法-SCNN
論文中提出了一個新穎的網路Spatial CNN,該網路在圖片的行和列上做資訊傳遞,可以有效的識別強先驗結構的目標。同時論文提出了一個大型的車道檢測資料集UCLane,用於進一步推動自動駕駛發展。
offical github: https://github.com/XingangPan/SCNN
paper: Spatial As Deep: Spatial CNN for Traffic Scene Understanding
模型
本文提出的SCNN演算法將傳統的卷積層接層(layer-by-layer)的連線形式的轉為feature map中片連片卷積(slice-by-slice)的形式,使得圖中畫素行和列之間能夠傳遞資訊。這特別適用於檢測長距離連續形狀的目標或大型目標,有著極強的空間關係但是外觀線索較差的目標,例如交通線,電線杆和牆。
傳統的CNN不能有效處理具有長距離連續的形狀(尤其是在遮擋的情況下)。MRF/CRF+CNN的結構使用一個大卷積核來進行資訊傳遞,但是會導致計算效率低下,並且大卷積核很難訓練,如下圖(a)所示;而SCNN分別在列方向與行方向使用寬卷積做了迴圈的資訊傳遞,這樣就增強了空間資訊進而對於識別結構化物件特別有效,如下圖(b)所示。
Model
D、U、R、L是四個資訊傳遞模組。D、U沿著H方向做了從上到下和從下到上的資訊傳遞;R、L沿著W方向做了從左到右和從右到左的資訊傳遞。資訊傳遞的公式如下所示,f是relu函式,每一個模組的卷積函式都共享同一個卷積核。
簡單的舉一個例子,假設 x0
為 h
方向上的第一片特徵, x1
為第二片,那麼 x0
到 x1
的資訊傳遞過程就是 x1=x1+relu(conv2D(x0))
,後面的操作就可以這樣迴圈下去。這個操作類似於迴圈的殘差操作,既能夠加快計算效率又能傳遞長資訊。
Message pass equation
在資訊傳遞(Message Pass)過程中,MRF/CRF中每個畫素點會直接接收其他所有畫素點的資訊(大卷積核實現),這其中有許多冗餘計算;而SCNN在資訊傳遞的時候並不是獲取全域性元素,而是順序傳遞,由此簡化了資訊傳遞的結構加快了模型的運算效率,如下圖所示:
Message pass
在進行車道檢測時,在上述模型的基礎上,在輸出結果上添加了一個分支網路。這個分支網路能夠直接區分不同車道標記,這樣魯棒性更好。共有4中型別的車道線。輸出的概率圖經過這個分支網路預測車道標記是否存在。
對於存在值大於0.5的車道標記,在對應的概率圖每20行搜尋以獲得最高的響應位置,然後通過三次樣條函式連線這些點(cubic splines),就得到了最終的預測。
road lane
這是該演算法在車道分割上達成的效果,使用了 UCLane資料庫 :
road lane res
這是該演算法在交通場景分割上達成的效果,使用了 cityscapes資料庫 :
traffic sense res
模型實現
這個實現與官方實現並不是完全一致,僅用來理解SCNN的網路結構,同時由於資料集較大也沒有進行訓練測試。想要訓練使用這個模型可以下載官方的torch版本或者 tf版本 。
首先我們將資訊傳遞的過程封裝成一個keras層,每個MessagePass層沿一個軸做兩個方向的資訊傳遞,如下所示:
class MessagePass(Layer): def __init__(self, output_dim, axis, **kwargs): if 'input_shape' not in kwargs and 'input_dim' in kwargs: kwargs['input_shape'] = (kwargs.pop('input_dim'),) super(MessagePass, self).__init__(**kwargs) self.output_dim = output_dim self.axis = axis def build(self, input_shape): assert self.axis in [1, 2] assert input_shape[-1] == self.output_dim if self.axis == 1: kernel_shape = [1, 9, input_shape[-1], self.output_dim] if self.axis == 2: kernel_shape = [9, 1, input_shape[-1], self.output_dim] self.w1 = self.add_weight(name='one', shape=kernel_shape, initializer='glorot_uniform', trainable=True) self.w2 = self.add_weight(name='two', shape=kernel_shape, initializer='glorot_uniform', trainable=True) super(MessagePass, self).build(input_shape) def call(self, inputs, **kwargs): h, w = int(inputs.shape[1]), int(inputs.shape[2]) if self.axis == 1: n = h if self.axis == 2: n = w feature_slice_old = [] feature_slice_new = [] for i in range(n): if self.axis == 1: cur_slice = K.expand_dims(inputs[:, i, :, :], axis=1) else: cur_slice = K.expand_dims(inputs[:, :, i, :], axis=2) feature_slice_old.append(cur_slice) if i == 0: feature_slice_new.append(cur_slice) else: tmp = K.relu(K.conv2d(feature_slice_old[i - 1], self.w1, padding='same')) tmp = tmp + feature_slice_old[i] feature_slice_new.append(tmp) feature_slice_old = feature_slice_new feature_slice_new = [] for i in reversed(range(n)): if self.axis == 1: cur_slice = K.expand_dims(inputs[:, i, :, :], axis=1) else: cur_slice = K.expand_dims(inputs[:, :, i, :], axis=2) feature_slice_old.append(cur_slice) if i == (n - 1): feature_slice_new.append(cur_slice) else: tmp = K.relu(K.conv2d(feature_slice_old[i - 1], self.w2, padding='same')) tmp = tmp + feature_slice_old[i] feature_slice_new.append(tmp) output = K.stack(feature_slice_new, axis=self.axis) output = K.squeeze(output, axis=self.axis + 1) return output def compute_output_shape(self, input_shape): return (input_shape[0], input_shape[1], input_shape[2], self.output_dim)
我們選擇DenseNet 121作為backbone網路,選取8倍縮小的那一個特徵圖作為輸出特徵,然後為這個特徵接上資訊傳遞層,如下所示:
class SCNN: def __init__(self, height, width, classes=5): self.classes = classes self.height = height self.width = width def backbone(self): model = DenseNet121( input_shape=(self.height, self.width, 3), weights=None, include_top=False) out_conv = model.get_layer('pool3_conv').output return model.input, out_conv def build(self): inputs, conv_out = self.backbone() conv_out = Conv2D(128, (1, 1), padding='same')(conv_out) conv_out = BatchNormalization()(conv_out) conv_out = Activation('relu')(conv_out) conv_out = MessagePass(128, 1)(conv_out) conv_out = MessagePass(128, 2)(conv_out) conv_out = Conv2D(self.classes, (1, 1), activation='softmax', padding='same')(conv_out) prob_output = UpSampling2D((8, 8))(conv_out) # add lane existence prediction branch x = AveragePooling2D(strides=2)(conv_out) x = Flatten()(x) x = Dense(128, activation='relu')(x) existence_output = Dense(4, activation='sigmoid')(x) model = Model(inputs=inputs, outputs=[prob_output, existence_output]) opt = SGD(lr=0.01, momentum=0.9, decay=0.0001) model.compile( optimizer=opt, loss=['categorical_crossentropy', 'binary_crossentropy']) return model if __name__ == '__main__': model = SCNN(288, 800).build() print(model.summary())