從最優化的角度看待 Softmax 損失函式
加入極市 專業CV交流群,與 6000+來自騰訊,華為,百度,北大,清華,中科院 等名企名校視覺開發者互動交流!更有機會與 李開復老師 等大牛群內互動!
同時提供每月大咖直播分享、真實專案需求對接、乾貨資訊彙總,行業技術交流 。 點選文末“ 閱讀原文 ”立刻申請入群~
作 者 | 王峰
來源 | https://zhuanlan.zhihu.com/p/45014864
本文經作者授權轉載,二次轉載請聯絡原作者。
Softmax交叉熵損失函式應該是目前最常用的分類損失函數了,在大部分文章中,Softmax交叉熵損失函式都是從概率角度來解釋的,本週二極市就推送了一篇 Softmax相關文章: 一文道盡softmax loss及其變種 。
本文將嘗試從最優化的角度來推匯出Softmax交叉熵損失函式,希望能夠啟發出更多的研究思路。
一般而言,最優化的問題通常需要構造一個目標函式,然後尋找能夠使目標函式取得最大/最小值的方法。目標函式往往難以優化,所以有了各種relax、smooth的方法,例如使用L1範數取代L0範數、使用sigmoid取代階躍函式等等。
那麼我們就要思考一個問題:使用神經網路進行多分類(假設為 C 類)時的目標函式是什麼?神經網路的作用是學習一個非線性函式 f(x) ,將輸入轉換成我們希望的輸出。這裡我們不考慮網路結構,只考慮分類器(也就是損失函式)的話,最簡單的方法莫過於直接輸出一維的類別序號 。而這個方法的缺點顯而易見:我們事先並不知道這些類別之間的關係,而這樣做默認了相近的整數的類是相似的,為什麼第2類的左右分別是第1類和第3類,也許第2類跟第5類更為接近呢?
為了解決這個問題,可以將各個類別的輸出獨立開來,不再只輸出1個數而是輸出 C 個分數(某些文章中叫作logit[1],但我感覺這個詞用得沒什麼道理,參見評論),每個類別佔據一個維度,這樣就沒有誰與誰更近的問題了。那麼如果讓一個樣本的真值標籤(ground-truth label)所對應的分數比其他分數更大,就可以通過比較 C 個分數的大小來判斷樣本的類別了。這裡沿用我的論文[2]使用的名詞,稱真值標籤對應的類別分數為目標分數(target score),其他的叫非目標分數(non-target score)。
這樣我們就得到了一個優化目標:
輸出C個分數,使目標分數比非目標分數更大。
換成數學描述,設 為真值標籤的序號,那優化目標即為:
。
得到了目標函式之後,就要考慮優化問題了。我們可以給 一個負的梯度,給其他所有 一個正的梯度,經過梯度下降法,即可使 升高而 下降。為了控制整個神經網路的幅度,不可以讓 無限地上升或下降,所以我們利用max函式,讓在 剛剛超過 時就停止上升:
。
然而這樣做往往會使模型的泛化效能比較差,我們在訓練集上才剛剛讓 超過 ,那測試集很可能就不會超過。借鑑svm裡間隔的概念,我們新增一個引數,讓 比 大過一定的數值才停止:
。
這樣我們就推匯出了hinge loss...唔,好像跑題了,我們本來不是要說Softmax的麼...不過既然跑題了就多說點,為什麼hinge loss在SVM時代大放異彩,但在神經網路時代就不好用了呢?主要就是因為svm時代我們用的是二分類,通過使用一些小技巧比如1 vs 1、1 vs n等方式來做多分類問題。而如論文[3]這樣直接把hinge loss應用在多分類上的話,當類別數 特別大時,會有大量的非目標分數得到優化,這樣每次優化時的梯度幅度不等且非常巨大,極易梯度爆炸。
其實要解決這個梯度爆炸的問題也不難,我們把優化目標換一種說法:
輸出C個分數,使目標分數比最大的非目標分數更大。
跟之前相比,多了一個限制詞“最大的”,但其實我們的目標並沒有改變,“目標分數比最大的非目標分數更大”實際上等價於“目標分數比所有非目標分數更大”。這樣我們的損失函式就變成了:
。
在優化這個損失函式時,每次最多隻會有一個+1的梯度和一個-1的梯度進入網路,梯度幅度得到了限制。但這樣修改每次優化的分數過少,會使得網路收斂極其緩慢,這時就又要祭出smooth大法了。那麼max函式的smooth版是什麼?有同學會脫口而出:softmax!恭喜你答錯了...
這裡出現了一個經典的歧義,softmax實際上並不是max函式的smooth版,而是one-hot向量(最大值為1,其他為0)的smooth版。其實從輸出上來看也很明顯,softmax的輸出是個向量,而max函式的輸出是一個數值,不可能直接用softmax來取代max。max函式真正的smooth版本是LogSumExp函式( LogSumExp:https://en.wikipedia.org/wiki/LogSumExp ),對此感興趣的讀者還可以看看這個部落格: 尋求一個光滑的最大值函式(https://kexue.fm/archives/3290) 。
使用LogSumExp函式取代max函式:
,
LogSumExp函式的導數恰好為softmax函式:
。
經過這一變換,給予非目標分數的1的梯度將會通過LogSumExp函式傳播給所有的非目標分數,各個非目標分數得到的梯度是通過softmax函式進行分配的,較大的非目標分數會得到更大的梯度使其更快地下降。這些非目標分數的梯度總和為1,目標分數得到的梯度為-1,總和為0,絕對值和為2,這樣我們就有效地限制住了梯度的總幅度。
LogSumExp函式值是大於等於max函式值的,而且等於取到的條件也是非常苛刻的(具體情況還是得看我的博士論文,這裡公式已經很多了,再寫就沒法看了),所以使用LogSumExp函式相當於變相地加了一定的 m 。但這往往還是不夠的,我們可以選擇跟hinge loss一樣新增一個 ,那樣效果應該也會不錯,不過softmax交叉熵損失走的是另一條路:繼續smooth。
注意到ReLU函式 也有一個smooth版,即softplus函式 。使用softplus函式之後,即使 超過了LogSumExp函式,仍會得到一點點梯度讓 繼續上升,這樣其實也是變相地又增加了一點 ,使得泛化效能有了一定的保障。替換之後就可以得到:
這個就是大家所熟知的softmax交叉熵損失函數了。在經過兩步smooth化之後,我們將一個難以收斂的函式逐步改造成了softmax交叉熵損失函式,解決了原始的目標函式難以優化的問題。從這個推導過程中我們可以看出smooth化不僅可以讓優化更暢通,而且還變相地在類間引入了一定的間隔,從而提升了泛化效能。
至於如何利用這個推導來對損失函式進行修改和一些進一步的分析,未完待續...
[1] Pereyra G, Tucker G, Chorowski J, et al. Regularizing neural networks by penalizing confident output distributions[J]. arXiv preprint arXiv:1701.06548, 2017.
[2] Wang F, Cheng J, Liu W, et al. Additive margin softmax for face verification[J]. IEEE Signal Processing Letters, 2018, 25(7): 926-930.
[3] Tang Y. Deep learning using linear support vector machines[J]. arXiv preprint arXiv:1306.0239, 2013.
*延伸閱讀
一文道盡softmax loss及其變種
分享神經網路中設計loss function的一些技巧
CVPR 2018 | Repulsion loss:專注於遮擋情況下的行人檢測
每月大咖直播分享、真實專案需求對接、乾貨資訊彙總,行業技術交流 。 點選左下角“ 閱讀原文 ”立刻申請入群~
覺得有用麻煩給個好看啦~