2019年5月17日 星期五

[筆記][tensorflow]Batch Normalization Layer用法(Key Variable not found in checkpoint)

最近在改寫3D Faster RCNN時遇到了一些使用上的問題 ,
因此稍微筆記一下。

首先介紹一下Batch Normalization~
Batch Normalization 最早是在2015年被提出,
主要是解決network在訓練過程中遇到的Covariate shift問題
透過正規化minibatch,加速網絡訓練及穩定性。

最常用來解釋Batch nomalization對於增加訓練網絡的穩定性的方式為下圖:


圖片中我們可以看到,
加入Batch normalization的每一層網絡輸出的分布比較常態,
沒有發生有效值被shift到兩個極端,
這對於使用tanh或sigmoid兩種activation function的網絡來說有很大的幫助,
因為這兩個函數最敏感(變化最大)的區域有一定的範圍,
輸入的分佈超過範圍就會有飽和的問題。


因此如果加入Batch normalization,
會讓訓練避免陷入飽和的狀態。

而ReLU雖然不會遇到飽和的問題,
但是使用Batch Normalization還是可以避免梯度爆炸與加速訓練。

目前我謹了解其物理意義,
對於Batch Normalization的數學模型與詳細細節有興趣的朋友也可以去找他的原始paper或是其他資源來看。
或是之後我有空看完也會來努力補上。

接下來是Batch Normalization的tensorflow用法,
我使用的tensorflow版本是1.12。

要將Batch Normalization layer加在convolution layer的activation function之前,
這邊我用的是tf.layers.batch_normalization,
beta值、gamma值、mean與variance等超參數都使用預設值,
這些值可以自己調整,
(另外如果用的是tf.nn.batch_normalization,
就要自己先定義存超參數的tensor。)

要正確的訓練記得將training這個argument設為true。

接下來就是我遇到問題的地方。
因為batch_normalization的參數不屬於trainable variables,
若沒有設定tf.train.Saver儲存global_variable,
儲存的model將不包含訓練好的batch_normalization layer,
inference的時候會出現找不到的錯誤:Key Variable not found in checkpoint
因此在宣告saver時要加入save global variable的設定 var_list=tf.global_variables()



就可以解決了~~


本篇主要是筆記一下
(1) batch normalization的物理意義
(2) tensorflow的寫法
(3) save batch normalization model時的注意事項

參考資料:





https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization

https://www.tensorflow.org/api_docs/python/tf/nn/batch_normalization

沒有留言:

張貼留言