首頁 > 運動

深度學習教程:神經網路不工作的37種應對法

由 深度學習中文社群 發表于 運動2021-12-14

簡介當你要解決分類問題而且輸入資料的類別非常不平衡時,在輸出層有時會出現這樣的現象”檢查每層引數的更新量,它們也應該滿足高斯分佈(34)使用不同的最佳化器一般情況下,最佳化器不應該是導致網路不訓練的罪魁禍首,除非選擇的超引數尤其糟糕

nan ind;nan ind 怎麼解決

前言

網路的訓練已經持續了12個小時,看上去一切都挺正常:梯度在波動,損失值在減小。但是真正用起來的時候卻令人大跌眼鏡:輸出都是0,或者輸出一些噪聲,總而言之產生的結果完全不能用。“我做錯了什麼?”——我問我的計算機,換來的只是沉默。

如果你的模型產生的都是垃圾(比如只是輸出平均值,或者準確率很差),那你該從哪兒入手檢查呢?

網路沒有得到訓練的原因有很多,經歷過多次除錯之後,我發現我做的檢查專案是類似的。因此,我把這些經驗彙編出來,給出如下這份便捷指南。這份列表提供了很多不錯的點子,希望對正在閱讀的你也能有所幫助。

一開始你可以做的事

很多事情都可能出錯,但是其中有一些出錯的可能性更大。通常我的應急預案包含如下幾項:

先使用一種簡單的,對你要解決的問題久經考驗的模型(例如對影象問題使用VGG)。如果可能的話,使用標準的損失函式

關掉所有的額外功能,例如正則化和資料增強

如果是在微調模型,仔細比對一下兩邊的預處理過程。新模型輸入資料的預處理和訓練原始模型時做的預處理應該一樣

驗證輸入資料是對的

先使用一個非常小的資料集(2-20個樣本)。讓模型過擬合這個資料集,然後逐漸加入更多資料

慢慢把之前關掉的額外功能都開開:資料增強、正則化、自定義損失函式等等。試著使用更復雜的模型

如果上述步驟都不奏效,沿著下面這個列表裡的內容開始一項項檢查吧

一、資料集問題

深度學習教程:神經網路不工作的37種應對法

(1)檢查你的輸入資料

檢查你送入網路的資料是否有意義。比如我曾經不止一次將影象的寬和高弄混。有時候我還會不小心給進全是0的資料。或者我還把同樣一批資料一遍又一遍作為輸入。所以把輸入輸出的若干批資料打印出來看一看,確認它們都是對的

(2)嘗試隨機輸入

試著向網路送入隨機數,而不是真實資料,看看是不是產生同樣的錯誤。如果的確產生了同樣的錯誤,那這基本上證實了你的網路中的某一部分是把資料變成了垃圾。試著逐層或者逐運算子除錯一下,看看哪兒出問題了

(3)檢查讀資料部分的程式碼

可能你的資料沒問題,但是把輸入送進網路的這部分程式碼有些問題。打印出網路第一層的、沒有經過其它操作的輸入,然後檢查一下

(4)確定輸入與輸出之間的關係沒問題

隨機選擇一些樣本,看看它們的標籤對不對。確定在將輸入樣本打亂時沒有打亂輸出標籤

(5)輸入與輸出之間的關係是否太過隨機?

也有可能是輸入和輸出之間的關係裡確定性的部分太小,隨機性的部分太大(可能有人會爭辯說,股票價格就是這樣的),也就是說,輸入與輸出之間不夠相關。不過沒有一種普遍的方法能對此作出判定,因為這取決於資料本身的性質

(6)資料集裡的噪聲是不是太多?

我曾經遇到過一次這樣的情況,那次是我想從一個食品網站上爬取一個圖片資料集,但是錯誤的標籤太多了,網路根本學不到東西。人肉檢查一些輸入看看標籤是不是有問題。至於錯誤標籤佔比超過多少能稱為“噪聲資料太多”,這個問題一直都比較有爭議。這篇論文使用了MNIST資料集,不過有50%的資料都對應了錯誤的標籤,而模型最後也能得到超過50%的準確率

(7)打亂資料集

如果資料集沒有被打亂過,而且內部存在某些特別的順序(比如資料是按標籤排列的),那麼這會對學習造成很差的影響。因此要把資料集打亂以避免這樣的現象——注意,打亂時要把輸入和標籤一起打亂

(8)減小類別的不平衡性

標記為類別A的影象數量是否千倍於標記為類別B的影象數量?那可能需要重新設計損失函式,或者使用其它不平衡資料集處理辦法

(9)訓練樣本夠嗎?

如果你是從頭訓練一個網路(不是在已有模型基礎上微調),那麼可能需要很多資料。對影象分類問題,討論指出對每個類別需要至少1000張影象

(10)確定每批資料不止包含一種標籤

如果資料集是排序過的(例如前10000個樣本的標籤都相同),那麼就會出現這樣的情況。把資料集打亂即可

(11)降低每批資料的數量(batch size)

這篇論文(

On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima

)指出如果一批資料包含的資料量太大,會降低模型的泛化能力

附一條:使用標準資料集(例如MNIST、CIFAR10)

當測試新的網路結構或者編寫新程式碼時,先使用標準資料集,不要使用自己的資料。因為對這些資料集已經有很多可參考的評估結果,而且這些資料集背後的問題已經被證明是‘可以被解決的’。對這些資料集,不會有標籤噪聲的問題,不會有訓練集/測試集分佈不同的問題,也不會有資料不可分的問題,等等

二、資料歸一化/資料增強

(12)標準化特徵

你把輸入標準化,使其均值為0方差為1了嗎

(13)是否做了太多資料增強

資料增強能起到正則化的效果。如果做了太多資料增強並且又使用了其它正則化方法(例如權重的L2範數、dropout等等)可能會導致網路欠擬合

(14)檢查預訓練模型的預處理

如果使用了預訓練模型,確保你在訓練時使用了訓練原始模型時一樣的歸一化和預處理方法。例如,影象畫素點是應該在[0, 1]內,還是應該在[-1, 1]內,還是應該在[0, 255]內?

(15)檢查訓練集/驗證集/測試集的預處理過程

CS231n給出了一個經常被人踩的坑

“所有預處理的統計過程(例如對資料求均值的過程)只能在訓練資料上做,然後再用到驗證集/測試集上。比如,在整個資料集上計算均值然後把每張圖片減掉均值,再把處理後的資料集分成訓練/驗證/測試集是不對的”

另外,要看看是不是對每個樣本/每批資料使用了同樣的預處理方法

三、實現問題

(16)試著解決原始問題的一個簡化版本

這有助於你找出問題出在哪兒。例如,如果原始問題的目標輸出是物體類別和座標,先試著只預測類別

(17)計算隨機情況下模型應該取得的正確損失值

這條仍然來自於CS231n的講義:“使用小引數初始化模型,不加正則化。例如,假如我們有10個類,那麼最開始模型隨機亂猜的話,對每條樣本,有10%的可能猜到正確類別,而Softmax損失函式值是猜到正確類別機率的負對數,因此最開始模型的損失值應該是-ln(0。1) = 2。302”

在之後,再試著加入正則化,這樣在開始模型的損失值會變大

(18)檢查損失函式

如果是自己實現一個損失函式,檢查有沒有bug,編寫單元測試。我自己寫的損失函式可能會有哪兒不太對,因此會暗搓搓地拖垮網路的效能

(19)驗證損失函式的輸入

如果使用了某個框架提供的損失函式,要保證你傳進去的是它想要的值。比如,我在用PyTorch的時候總會弄混NLLLoss和CrossEntropyLoss,這兩個函式的唯一區別是前者需要的引數是被softmax過的,而後者不用

(20)調整損失權重

如果你的損失函式是多個更小損失函式的組合體,要確保每個子損失函式值的相對大小沒問題。這需要你測試各損失函式權重的不同組合

(21)監測其它指標

有時候損失函式值不是衡量網路是否正確訓練的最好指標。如果可以,試著使用一些其它指標,比如準確率

(22)測試自定義的網路層

你是否自己實現了網路中的若干層?仔細檢查一下,以確保它們是按你想的那樣工作的

(23)檢查“被凍結的”層或變數

有些層/變數本來是應該被學習的,檢檢視看你是否不小心把它們的梯度更新停掉了

(24)增大網路大小

可能你的網路表示能力不足,難以擬合目標函式。試著增多層數或增多全連線層的隱藏單元數

(25)檢查隱藏維度錯誤

如果你的輸入類似於 (k, H, W) = (64, 64, 64) ,那麼很容易漏過與維度錯誤有關的問題。對輸入維度使用一些比較怪的數字(比如每個維度設定一個不同的質數),然後檢查輸入是如何在網路中傳播的

(26)檢查梯度

如果你自己手動實現了梯度下降,可以試著做一些梯度檢查來保證反向傳播的正確性。可參考 1 23

四、訓練問題

(27)先解決一個小問題

在資料的一個很小的子集上讓模型過擬合,以保證模型的有效性。比如,只在一個或兩個樣本上訓練模型,看看網路能否把它們區分開來,然後對各個類別加入更多樣本

(28)檢查權重的初始化

如果沒底,使用Xavier初始化法或He初始化法。此外,不好的初始化可能會把模型帶入到一個差的區域性極小值點,因此可以試試不同的初始化結果,看看是否管用

(29)修改超引數

可能你用的超引數比較差,如果可以的話,試試網格搜尋

(30)降低正則程度

太多正則化會導致網路嚴重欠擬合。考慮將諸如dropout、batch norm、權重/偏置L2正則化等手段都放鬆一些。fast。ai有一門很棒的課程叫“軟開人員的深度學習實踐”,在這裡Jeremy Howard建議我們先要避免欠擬合,也就是說要先足夠過擬合訓練資料,然後再考慮解決過擬合問題

(31)保持耐心

可能你的網路需要更多訓練時間才能開始給出有意義的預測。如果損失函式值一直在穩定下降,就讓它再訓練一會兒

(32)從訓練模式切換到測試模式

在某些網路實現中,有些層會使用batch norm或者dropout,它們在訓練和測試過程中的表現是不一樣的。將網路切換到合適的模式可以幫助它給出正確預測

(33)將訓練過程視覺化

監測每層的啟用值、權重和權重更新量,確保這些值的量級都匹配。比如,引數(權重和偏置)更新量的量級應該是1e-3左右考慮使用Tensorboard和Crayon這些視覺化庫。或者可以使用它們的乞丐版——打印出權重/偏置/啟用函式值注意那些啟用函式值平均情況下還遠大於0的層,考慮使用batch norm和ELUDeeplearning4j指出權重和偏置的直方圖應該滿足如下性質:

“對於權重,在一段時間以後,它們的直方圖形狀應該近似於高斯分佈(正態分佈)。對於偏置,它們應該從0開始,最後也呈高斯分佈(LSTM例外)。注意那些最後收斂到正/負無窮的引數,注意那些變得特別大的偏置。當你要解決分類問題而且輸入資料的類別非常不平衡時,在輸出層有時會出現這樣的現象”檢查每層引數的更新量,它們也應該滿足高斯分佈

(34)使用不同的最佳化器

一般情況下,最佳化器不應該是導致網路不訓練的罪魁禍首,除非選擇的超引數尤其糟糕。但是,選擇合適的最佳化器有助於以最短時間基本完成訓練過程。如果你是照著論文實現演算法,論文裡應該寫出用了哪個最佳化器。如果沒有,建議使用Adam或帶動量的SGD

(35)梯度爆炸/梯度消失

檢查引數的更新值,太大意味著出現了梯度爆炸,可以做一些梯度裁剪來避免檢查啟用值。deeplearning4j指出,“啟用值的標準差理想狀況下應該在0。5到2。0之間。否則說明出現了梯度爆炸或梯度消失”

(36)調整學習率

太小的學習率會導致模型收斂速度很慢,太大的學習率可以在初始階段迅速降低損失值,但是難以獲得一個好的解。試著成10倍增大/縮小當前學習率

(37)克服NaN帶來的問題

訓練RNN時更嚴重的問題是遇到了NaN(Non-a-Number)(我聽說是這樣)。可以試著使用如下方法應對

降低學習率,尤其是如果你在前100個迭代就遇到了NaN把0做除數,或者對0/負數求對數也會導致NaN參考Russell Stewart的這篇文章逐層計算,看哪一層出了NaN

更多深度學習相關資源歡迎點選下面瞭解更多按鈕獲取

Tags:資料訓練模型輸入網路