CNN主要處理圖像信息,主要應(yīng)用于計(jì)算機(jī)視覺(jué)領(lǐng)域。
RNN(recurrent neural network)主要就是處理序列數(shù)據(jù)(自然語(yǔ)言處理、語(yǔ)音識(shí)別、視頻分類(lèi)、文本情感分析、翻譯),核心就是它能保持過(guò)去的記憶。但RNN有著梯度消失問(wèn)題,專(zhuān)家之后接著改進(jìn)為L(zhǎng)STM和GRU結(jié)構(gòu)。下面將用通俗的語(yǔ)言分別詳細(xì)介紹。
對(duì)機(jī)器學(xué)習(xí)或深度學(xué)習(xí)不太熟的童鞋可以先康康這幾篇哦:
《無(wú)廢話(huà)的機(jī)器學(xué)習(xí)筆記》、《一文極速理解深度學(xué)習(xí)》、《一文總結(jié)經(jīng)典卷積神經(jīng)網(wǎng)絡(luò)CNN模型》
RNN(Recurrent Neural Network)
RNN中的處理單元,中間綠色就是過(guò)去處理的結(jié)果,左邊第一幅圖就是正常的DNN,不會(huì)保存過(guò)去的結(jié)果,右邊的圖都有一個(gè)特點(diǎn),輸出的結(jié)果(藍(lán)色)不僅取決于當(dāng)前的輸入,還取決于過(guò)去的輸入!不同的單元能賦予RNN不同的能力,如 多對(duì)一就能對(duì)一串文本進(jìn)行分類(lèi),輸出離散值,比如根據(jù)你的言語(yǔ)判斷你今天高不高興。
RNN中保存著過(guò)去的信息,輸出取決于現(xiàn)在與過(guò)去。如果大伙學(xué)過(guò)數(shù)電,這就是狀態(tài)機(jī)!這玩意跟觸發(fā)器很像。
有個(gè)很重要的點(diǎn):
這個(gè)權(quán)重fw沿時(shí)間維度是一致的,權(quán)值共享。就像CNN中一個(gè)卷積核在卷積過(guò)程中參數(shù)一致。所以CNN是沿著空間維度權(quán)值共享;RNN是沿著時(shí)間維度權(quán)值共享。
具體來(lái)說(shuō)有三個(gè)權(quán)重,過(guò)去與現(xiàn)在各一個(gè)權(quán)重,加起來(lái)再來(lái)一個(gè)權(quán)重。 它們都沿著時(shí)間維度權(quán)值共享。不然每個(gè)時(shí)間都不一樣權(quán)重,參數(shù)量會(huì)很恐怖。
整體的計(jì)算圖(多對(duì)多):
每次的輸出y可以與標(biāo)簽值構(gòu)建損失函數(shù),這樣就跟之前DNN訓(xùn)練模型思想一樣,訓(xùn)練3套權(quán)重使損失函數(shù)不斷下降至滿(mǎn)意。
反向傳播要沿時(shí)間反向傳回去(backpropagation through time,BPTT)
Forward through entire sequence to compute loss, then backward through entire sequence to compute gradient.
這樣會(huì)有問(wèn)題,就是一下子把全部序列弄進(jìn)來(lái)求梯度,運(yùn)算量非常大。實(shí)際我們會(huì)將大序列分成等長(zhǎng)的小序列,分別處理:
不同隱含層中不同的值負(fù)責(zé)的是語(yǔ)料庫(kù)中不同的特征,所以隱含狀態(tài)的個(gè)數(shù)越多,模型就越能捕獲文本的底層特征。
下面來(lái)看一個(gè)例子:字符級(jí)語(yǔ)言模型(由上文預(yù)測(cè)下文):
我想輸入hell,然后模型預(yù)測(cè)我會(huì)輸出o;或者我輸入h,模型輸出e,我再輸入e,模型輸出l…
首先對(duì)h,e,l,o進(jìn)行獨(dú)熱編碼,然后構(gòu)建模型進(jìn)行訓(xùn)練。
輸入莎士比亞的劇本,讓模型自己生成劇本,訓(xùn)練過(guò)程:
輸入latex文本,讓模型自己生成內(nèi)容,公式寫(xiě)得有模有樣的,就不知道對(duì)不對(duì):
當(dāng)然輸入代碼,模型也會(huì)輸出代碼。所以現(xiàn)在火熱的Chatgpt的本質(zhì)就是RNN。
對(duì)于圖像描述,專(zhuān)家會(huì)先用CNN對(duì)圖像進(jìn)行特征抽?。?a class="article-link" target="_blank" href="/baike/522651.html">編碼器),然后將特征再輸入RNN進(jìn)行圖像描述(解碼器)。
還可以結(jié)合注意力機(jī)制(Image captioning with attention):
普通堆疊的RNN一旦隱含層變多變深,反向傳播時(shí)就很容易出現(xiàn)梯度消失/爆炸。
子豪兄總結(jié)得非常好,以最簡(jiǎn)單的三層網(wǎng)絡(luò)來(lái)看,對(duì)于輸出的O3可以列出損失函數(shù)L3,對(duì)L3進(jìn)行求偏導(dǎo),分別對(duì)輸出權(quán)重w0,輸入權(quán)重wx,過(guò)去權(quán)重ws進(jìn)行求導(dǎo)。我們發(fā)現(xiàn)對(duì)w0求偏導(dǎo)會(huì)很輕松。
但是,由于鏈?zhǔn)椒▌t(chain rule),對(duì)輸入權(quán)重wx和過(guò)去權(quán)重ws求偏導(dǎo)就會(huì)很痛苦。在表達(dá)式里,對(duì)于越是前面層的鏈?zhǔn)角髮?dǎo),乘積項(xiàng)越多,所以很容易梯度消失/爆炸,梯度消失占大多數(shù)。
LSTM(Long Short-Term Memory)
長(zhǎng)短時(shí)記憶神經(jīng)網(wǎng)絡(luò)(LSTM) 應(yīng)運(yùn)而生!
LSTM既有長(zhǎng)期記憶也有短期記憶,包括遺忘門(mén)、輸入門(mén)、輸出門(mén)、長(zhǎng)期記憶單元。右圖紅色函數(shù)是sigmoid,藍(lán)色函數(shù)是tanh。
C是長(zhǎng)期記憶,h是短期記憶。
所以當(dāng)前輸出ht是由短期記憶產(chǎn)生的。
我們看到長(zhǎng)期記憶那條線(xiàn)是貫通的,且只有乘加操作。
LSTM算法詳解:
下面幾個(gè)圖完美解釋了:
所以總共有四個(gè)權(quán)重:Wf、Wi、Wc、Wo,當(dāng)然還有它們對(duì)應(yīng)的偏置項(xiàng)。
整體過(guò)程可以概括為:遺忘、更新、輸出。(更新包括先選擇保留信息,再更新最新記憶。)
原論文中的圖也非常形象:
現(xiàn)在反向傳播求偏導(dǎo)就舒服了
GRU(Gated Recurrent Unit)
GRU也能很好解決梯度消失問(wèn)題,結(jié)構(gòu)簡(jiǎn)單一點(diǎn),主要就是重置門(mén)和更新門(mén)。
GRU與LSTM對(duì)比:
- 參數(shù)數(shù)量:GRU的參數(shù)數(shù)量相對(duì)LSTM來(lái)說(shuō)更少,因?yàn)樗鼘STM中的輸入門(mén)、遺忘門(mén)和輸出門(mén)合并為了一個(gè)門(mén)控單元,從而減少了模型參數(shù)的數(shù)量。
LSTM中有三個(gè)門(mén)控單元:輸入門(mén)、遺忘門(mén)和輸出門(mén)。每個(gè)門(mén)控單元都有自己的權(quán)重矩陣和偏置向量。這些門(mén)控單元負(fù)責(zé)控制歷史信息的流入和流出。
GRU中只有兩個(gè)門(mén)控單元:更新門(mén)和重置門(mén)。它們共享一個(gè)權(quán)重矩陣和一個(gè)偏置向量。更新門(mén)控制當(dāng)前輸入和上一時(shí)刻的輸出對(duì)當(dāng)前時(shí)刻的輸出的影響,而重置門(mén)則控制上一時(shí)刻的輸出對(duì)當(dāng)前時(shí)刻的影響。 - 計(jì)算速度:由于參數(shù)數(shù)量更少,GRU的計(jì)算速度相對(duì)LSTM更快。
- 長(zhǎng)序列建模:在處理長(zhǎng)序列數(shù)據(jù)時(shí),LSTM更加優(yōu)秀。由于LSTM中引入了一個(gè)長(zhǎng)期記憶單元(Cell State),使得它可以更好地處理長(zhǎng)序列中的梯度消失和梯度爆炸問(wèn)題。
GRU適用于:
處理簡(jiǎn)單序列數(shù)據(jù),如語(yǔ)言模型和文本生成等任務(wù)。
處理序列數(shù)據(jù)時(shí)需要快速訓(xùn)練和推斷的任務(wù),如實(shí)時(shí)語(yǔ)音識(shí)別、語(yǔ)音合成等。
對(duì)計(jì)算資源有限的場(chǎng)景,如嵌入式設(shè)備、移動(dòng)設(shè)備等。
LSTM適用于:
處理復(fù)雜序列數(shù)據(jù),如長(zhǎng)文本分類(lèi)、機(jī)器翻譯、語(yǔ)音識(shí)別等任務(wù)。
處理需要長(zhǎng)時(shí)依賴(lài)關(guān)系的序列數(shù)據(jù),如長(zhǎng)文本、長(zhǎng)語(yǔ)音等。
對(duì)準(zhǔn)確度要求較高的場(chǎng)景,如股票預(yù)測(cè)、醫(yī)學(xué)診斷等。
公式總結(jié):