Backpropagation(BP) 倒傳遞法 #1 工作原理與說明
本篇會介紹在機器學習(machine learning)與深度學習(deep learning)領域裡很流行的倒傳遞法(Back Propagation/ Backpropagation, BP)的精髓:梯度下降法(Gradient Descent)、連鎖率(Chain Rule)
你想要知道該如何以Python實作BP並應用於優化層類神經網路可以讀這篇:Backpropagation(BP) 倒傳遞法 #2 貓貓分類器-2層類神經網路;你想要知道該如何優化多層類神經網路可以讀這篇:Backpropagation(BP) 倒傳遞法 #3 貓貓分類器-N層類神經網路
倒傳遞法(Backpropagation),這是一個很多學者在同一個年代都有發表過的最佳化演算法,其中包括鼎鼎大名的 Rumelhart 與 Hinton 在1986發表的『Learning representations by back-propagation errors』與更早幾年歸納這個方法的Webors於1974所發表的博士學位論文也有提到。BP是一種可大致分為正向傳遞(Forward pass)與反向傳遞(Backward pass),其中又結合梯度下降法(Gradient Descent)和微積分中的連鎖率(Chain Rule)而成的最佳化演算法。
那什麼是梯度下降呢?
就直接從梯度下降法切入Backpropagation吧!
梯度下降法的基礎概念
假設最佳化目標:對成本函數最小化
而梯度下降法就從參數初始位置朝向最陡的下坡方向前進並更新參數位置,獲得更新後的參數最終可以帶來降低成本的效果。那獲得坡度資訊的方法就是使用導函數(精確地來說應該是偏微分)。微積分應該有學到,對函數求得一階導函數可以獲得斜率函數,梯度下降法就是運用這個特性來優化成本函數。
梯度下降法的數學基礎
假設最佳化目標:對成本函數(J(w))最小化
先只用一個參數來看梯度下降法。可以注意到,這邊設定成本函數 J 擁有一個輸入值 w,就是要解釋微分帶來的效果。
若成本函數J是一個拋物線,如下圖(1):
圖(1):成本函數J
若要找到能產出最小成本的w,就是必須要不斷的改變w帶入J來嘗試。
但是電腦沒有上帝視角,所以我們透過微分讓電腦知道應該要朝哪個方向來方法更新 w,因此產生了下面的更新公式(1)。從公式(1)可以發現有一個未曾看過的α,這是學習速率(learning rate),用來控制學習步伐的參數,數值通常是介於0到1之間。
w=w–αdJ(w)dw | (1) |
公式(1)為我們帶來使用梯度更新w的概念,所以我們可以用下圖(2)來理解。可以發現紫色的三角形就是我們每次計算出來微分值dJ(w)dw,而這張圖是建立在當微分值dJ(w)dw大於0的情況。
圖(2):梯度下降示意圖(當斜率為大於0)
如果w的初始值在比較接近原點的地方呢?可以用左半邊的線段斜率是負數來幫助理解,所以微分值就會小於0,但是稍微計算一下就知道就如果是dJ(w)dw<0帶入更新公式(1)計算的結果會讓w數值變大,也就是會讓w趨近於J(w)較小的方向!(如圖(3)左半部所示)
圖(3):梯度下降示意圖(當斜率小於0)
梯度下降法
假設優化目標:對成本函數 J(w,b) 最小化
上面已經介紹過一個參數的成本函數,但是在真實應用上成本函數中不會只有一個參數。
因為要做微分的目標不只一個,此時就必須要使用到偏微分的方法,才能知道參數w、b分別對於成本函數J的影響。
偏微分記號我們以∂表示。更新公式則修改成如下公式(2)和公式(3):
w=w−α∂J(w,b)∂w | (2) |
b=b−α∂J(w,b)∂b | (3) |
梯度下降法,就是利用公式(2)、(3)的概念對成本函數進行優化,經過若干迭代之後就可以得到優化後的參數w和b。
連鎖率是Backward pass的精髓,一定要懂!
下方圖(3)為成本函數J(a,b,c)的計算圖(computation graph),我們可以藉由這樣的圖來理解成本函數的計算過程,以及連鎖率。
首先,順著計算圖的流程分別先計算u、v最後算出成本函數J可以獲得成本J。(其實整個計算成本的過程就是Forward pass)
圖(4):簡易計算圖
倘若欲優化成本函數J,就勢必要優化參數a、b、c。
所以要透過上述的梯度下降法優化J,就要倒著計算圖的順序來找出這幾個參數對J的偏微分值,如此一來便可知道這三個參數改變一點點的話,對於成本函數的影響是多少。
因此我們得從J回推到參數a、b、c,倒著計算圖的順序可以先看到J=3v,用偏微分可計算出若給v一點點變動,能夠影響多少J的變動量為:∂J∂v。
接續著來看v=a+u,分別用偏微分計算更改a一點點,會v對產生多少影響:∂v∂a,想當然u對v的影響就是:∂v∂u。
再來就是u=bc,依照上述偏微分的計算方式可以得知b對u的影響就是∂u∂b,而c對u的影響則是∂u∂c。
現在順著計算圖的方向來看誰會影響誰:
a會影響v,而v又會影響J,我們可以這樣表示:a→v→J
b→u→v→J
c→u→v→J
既然優化成本函數時就必須要計算各個參數對成本函數的影響量(∂J∂a、∂J∂b、∂J∂c),那我們就可以透過這個上述『誰影響誰』的路徑來計算:
∂J∂a=∂v∂a∂J∂v | (4) |
∂J∂b=∂u∂b∂v∂u∂J∂v | (5) |
∂J∂c=∂u∂c∂v∂u∂J∂v | (6) |
像這樣子的影響鏈就是微積分這門學問裡的連鎖率(Chain Rule)。
依照圖(4)計算圖的流程,我們是可以算出各參數分別對成本函數偏微分值分別是多少:∂J∂a=3、∂J∂b=6、∂J∂c=9
最終獲得∂J∂a、∂J∂b、∂J∂c,就可以用來更新參數了!
(其實,計算∂J∂a、∂J∂b、∂J∂c的過程就是Backward pass)
- Andrew Ng – Neural Networks & Deep Learning in Coursera
- (paper) Learning representations by back-propagation errors
Pingback:Backpropagation(BP) 倒傳遞法 #2 貓貓分類器-2層類神經網路 - BrilliantCode.net