External Memory

プログラミング周辺知識の備忘録メイン

TensorFlow exampleプログラムのmnist_deep.pyにあるドロップアウトについて

ドロップアウトの効果を確認したく、過学習状態を無理やり作っていろいろ試してみたが
なかなかうまくいかなかった。

そこでTensorFlow exampleプログラムの
tensorflow/mnist_deep.py at r1.2 · tensorflow/tensorflow · GitHub

ドロップアウト処理を行っているので、
これを利用して(こちらで使いやすいように若干プログラムを変更、内容はほぼ同じ)
ドロップアウト有無で比較することにより、ドロップアウトの効果を確認した。


プログラムを眺めている時に肝心な点に気が付いた。
学習時はドロップアウトを使用するkeep率(keep_prob=0.5)が、
テスト時はkeep率は1.0である。
見落としていた点であるが、ドロップアウトは学習時のための処理だから、
テスト時はドロップアウトを行わなない方がしっくりくる。
一応確認のためテスト時のドロップアウト処理についてお互いの比較を行った。


mnist_deep.pyは畳み込みニューラルネットワーク(CNN)を使用している。
CNNの性質から画像や動画認識特化かと思ったら、recommend systemや自然言語処理にも使われているらしい。


以下が作成したプログラムである。

import tensorflow as tf
import math
from tensorflow.examples.tutorials.mnist import input_data 

mnist = input_data.read_data_sets("mnist/", one_hot=True)

PIXEL = 784 
R = 10 

class cnn():
    
    def __init__(self,l_rate,step,keep_p=1.0):
        self.l_rate = l_rate
        self.step = step
        self.keep_p = keep_p
        self.x  = tf.placeholder(tf.float32, [None, PIXEL])
        self.y_ = tf.placeholder(tf.float32, [None, R])
        self.y = tf.reshape(self.x, [-1, 28, 28, 1])
        self.keep_prob = tf.placeholder(tf.float32)
        
    def conv2d(self,p_height,p_width,ch,next_ch):
        
        W = tf.Variable(tf.truncated_normal([p_height, p_width, ch, next_ch],
                            stddev=1.0 / math.sqrt(float(p_height * p_width))))
        b = tf.Variable(tf.zeros([next_ch]))

        self.y = tf.nn.relu(tf.nn.conv2d(self.y, W, strides=[1,1,1,1], padding='SAME')+ b)
    
    def max_pool(self):
        self.y = tf.nn.max_pool(self.y, ksize=[1,2,2,1],
            strides=[1,2,2,1], padding='SAME')
        
    def ful_connect(self,unit,next_unit):
        W = tf.Variable(tf.truncated_normal([unit, next_unit],
                        stddev=1.0 / math.sqrt(float(unit))))
        b = tf.Variable(tf.zeros([next_unit]))
        self.y = tf.reshape(self.y, [-1, unit])
        self.y = tf.nn.relu(tf.matmul(self.y, W)+ b)
        
    def drop_out(self):
        self.y = tf.nn.dropout(self.y, self.keep_prob)
        
    def train(self,unit):
        W = tf.Variable(tf.truncated_normal([unit, R],
                    stddev=1.0 / math.sqrt(float(unit))))
        b = tf.Variable(tf.zeros([R]))

        self.y = tf.matmul(self.y, W)+ b
        
        cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels = self.y_,logits = self.y))
        train_step = tf.train.AdamOptimizer(self.l_rate).minimize(cross_entropy)
        
        correct_prediction = tf.equal(tf.argmax(self.y, 1), tf.argmax(self.y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    
        sess = tf.InteractiveSession()
        tf.global_variables_initializer().run()
        
        for i in range(self.step):
            batch_xs, batch_ys = mnist.train.next_batch(50)
            sess.run(train_step, feed_dict={self.x: batch_xs, self.y_: batch_ys,self.keep_prob:self.keep_p})
            
            if i % 2000 == 0:
                print("{0:>5}{1[0]:>15.4f}{1[1]:>15.4f}".format(i,sess.run([cross_entropy,accuracy], feed_dict={self.x: mnist.test.images,
                                         self.y_: mnist.test.labels,self.keep_prob:1.0})))
        return sess,cross_entropy
    
    def print_accurency(self,unit):
        sess,cross_entropy = self.train(unit)
        correct_prediction = tf.equal(tf.argmax(self.y, 1), tf.argmax(self.y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        print("{0:>5}{1[0]:>15.4f}{1[1]:>15.4f}".format(self.step,sess.run([cross_entropy,accuracy], feed_dict={self.x: mnist.test.images,
                                         self.y_: mnist.test.labels,self.keep_prob:1.0})))
    
if __name__ == '__main__':
    print("{0:>5}{1:>15}{2:>15}".format("step","loss","accurancy"))
    print("--------------------------------------")
    print("*** keep_prob = 0.5 ***")
    cnn1 = cnn(0.0001,20000,0.5)
    cnn1.conv2d(5,5,1,32)
    cnn1.max_pool()
    cnn1.conv2d(5,5,32,64)
    cnn1.max_pool()
    cnn1.ful_connect(7*7*64,1024)
    cnn1.drop_out()
    cnn1.print_accurency(1024)
    print("*** keep_prob = 0.8 ***")
    cnn2 = cnn(0.0001,20000,0.8)
    cnn2.conv2d(5,5,1,32)
    cnn2.max_pool()
    cnn2.conv2d(5,5,32,64)
    cnn2.max_pool()
    cnn2.ful_connect(7*7*64,1024)
    cnn2.drop_out()
    cnn2.print_accurency(1024)
    print("*** keep_prob = 1.0 ***")
    cnn3 = cnn(0.0001,20000)
    cnn3.conv2d(5,5,1,32)
    cnn3.max_pool()
    cnn3.conv2d(5,5,32,64)
    cnn3.max_pool()
    cnn3.ful_connect(7*7*64,1024)
    cnn3.drop_out()
    cnn3.print_accurency(1024)


出力は以下となる。
こちらはテスト時はドロップアウト処理を行っていない。
学習時のみで行っている。

>python mnist-cnn.py
 step           loss      accurancy
--------------------------------------
*** keep_prob = 0.5 ***
    0         2.2298         0.2129
 2000         0.0457         0.9850
 4000         0.0293         0.9895
 6000         0.0308         0.9890
 8000         0.0259         0.9913
10000         0.0223         0.9924
12000         0.0233         0.9926
14000         0.0226         0.9932
16000         0.0277         0.9913
18000         0.0229         0.9929
20000         0.0252         0.9921
*** keep_prob = 0.8 ***
    0         2.2767         0.1582
 2000         0.0441         0.9846
 4000         0.0378         0.9877
 6000         0.0302         0.9904
 8000         0.0291         0.9895
10000         0.0276         0.9919
12000         0.0282         0.9915
14000         0.0264         0.9910
16000         0.0347         0.9899
18000         0.0317         0.9916
20000         0.0347         0.9911
*** keep_prob = 1.0 ***
    0         2.3039         0.1228
 2000         0.0452         0.9844
 4000         0.0374         0.9868
 6000         0.0284         0.9896
 8000         0.0335         0.9894
10000         0.0249         0.9910
12000         0.0246         0.9924
14000         0.0379         0.9896
16000         0.0269         0.9924
18000         0.0370         0.9915
20000         0.0332         0.9912

"keep_prob=1.0"はkeep率1.0のため実質ドロップアウト処理は行われていない。

正答率からみると、ドロップアウトの効果はわずかにあるように見える。
またstep数に対するlossの値の増加傾向も、
keep率が0.5の場合は他と比べて抑えられている。
後者はおそらく過学習の特徴が出ているのだろう。

もっと明確に確認するなら学習用データでの正答率とlossも確認すべきではあるが。
他の数字画像データではもっと顕著に有意差が出るかもしれない。


もうひとつ、以下の出力データはテスト時もドロップアウト処理を行ったものである。

出力データを比較して、
テスト時はドロップアウトの処理を行わないほうがいいと思われる。
keep率が小さい場合は顕著に正答率、lossともに悪化するので特にそうである。

>python mnist-cnn.py
 step           loss      accurancy
--------------------------------------
*** keep_prob = 0.5 ***
    0         2.5310         0.1409
 2000         0.0678         0.9779
 4000         0.0465         0.9846
 6000         0.0383         0.9883
 8000         0.0348         0.9893
10000         0.0346         0.9878
12000         0.0380         0.9886
14000         0.0323         0.9893
16000         0.0315         0.9911
18000         0.0314         0.9910
20000         0.0377         0.9894
*** keep_prob = 0.8 ***
    0         2.3156         0.1184
 2000         0.0631         0.9797
 4000         0.0362         0.9877
 6000         0.0346         0.9886
 8000         0.0310         0.9892
10000         0.0346         0.9886
12000         0.0288         0.9912
14000         0.0347         0.9902
16000         0.0318         0.9909
18000         0.0268         0.9917
20000         0.0319         0.9912
*** keep_prob = 1.0 ***
    0         2.2096         0.2297
 2000         0.0486         0.9837
 4000         0.0428         0.9857
 6000         0.0314         0.9888
 8000         0.0349         0.9892
10000         0.0260         0.9918
12000         0.0241         0.9924
14000         0.0275         0.9923
16000         0.0313         0.9914
18000         0.0254         0.9926
20000         0.0248         0.9927


また前回のニューラルネットワークunit数検証のプログラムを用いて
ドロップアウトを使って実行した。
keep率は0.8と0.5で行ったが、以下の出力結果は0.5の時のものである。
前回のドロップアウトなしの結果は以下のURLにある。
taka74k4.hatenablog.com


こちらはドロップアウト処理により、
正答率に変化が見られないもしくは、逆に低下した。

ドロップアウトは使いどころを選ぶようある。

今回はドロップアウトは最後の出力の場面で使用している。
CNNの場合は畳み込み層やpooling層でうまく特徴を抽出できるような構造だから、
最後の場面だけのドロップアウト処理でもうまくいくかもしれないが、
すべての層が全結合層の場合は、途中の層間でもドロップアウト処理をしたほうがいいかもしれないし、別の工夫が必要かもしれない。

>python nnet.py

units                     step           loss      accurancy
-------------------------------------------------------------
***same size***
784,100,100,10               0         2.2678         0.1839
784,100,100,10            1000         0.1601         0.9505
784,100,100,10            2000         0.1172         0.9637
784,100,100,10            3000         0.1010         0.9686
784,100,100,10            4000         0.0900         0.9730
784,100,100,10            5000         0.0874         0.9752
784,100,100,10            6000         0.0944         0.9703
784,100,100,10            7000         0.0969         0.9726
784,100,100,10            8000         0.0941         0.9731
784,100,100,10            9000         0.0830         0.9784
784,100,100,10           10000         0.0882         0.9769

784,392,392,10               0         2.2548         0.3129
784,392,392,10            1000         0.1228         0.9607
784,392,392,10            2000         0.0840         0.9752
784,392,392,10            3000         0.0836         0.9736
784,392,392,10            4000         0.0717         0.9787
784,392,392,10            5000         0.0723         0.9795
784,392,392,10            6000         0.0752         0.9801
784,392,392,10            7000         0.0732         0.9825
784,392,392,10            8000         0.0925         0.9770
784,392,392,10            9000         0.0832         0.9793
784,392,392,10           10000         0.0769         0.9805

784,784,784,10               0         2.2183         0.1888
784,784,784,10            1000         0.1021         0.9692
784,784,784,10            2000         0.0844         0.9732
784,784,784,10            3000         0.0931         0.9722
784,784,784,10            4000         0.0836         0.9767
784,784,784,10            5000         0.0846         0.9767
784,784,784,10            6000         0.0885         0.9786
784,784,784,10            7000         0.0868         0.9796
784,784,784,10            8000         0.0972         0.9786
784,784,784,10            9000         0.1160         0.9752
784,784,784,10           10000         0.1139         0.9795

784,1024,1024,10             0         2.1906         0.2948
784,1024,1024,10          1000         0.1075         0.9681
784,1024,1024,10          2000         0.0863         0.9756
784,1024,1024,10          3000         0.0790         0.9784
784,1024,1024,10          4000         0.0940         0.9741
784,1024,1024,10          5000         0.0856         0.9778
784,1024,1024,10          6000         0.0976         0.9756
784,1024,1024,10          7000         0.1002         0.9772
784,1024,1024,10          8000         0.1054         0.9763
784,1024,1024,10          9000         0.1033         0.9805
784,1024,1024,10         10000         0.0885         0.9803

***decreasing size***
784,392,100,10               0         2.2625         0.2460
784,392,100,10            1000         0.1406         0.9545
784,392,100,10            2000         0.0907         0.9706
784,392,100,10            3000         0.0835         0.9747
784,392,100,10            4000         0.0783         0.9776
784,392,100,10            5000         0.0788         0.9776
784,392,100,10            6000         0.0851         0.9765
784,392,100,10            7000         0.0777         0.9790
784,392,100,10            8000         0.0856         0.9803
784,392,100,10            9000         0.0888         0.9793
784,392,100,10           10000         0.0932         0.9804

784,784,392,10               0         2.2167         0.1935
784,784,392,10            1000         0.1138         0.9639
784,784,392,10            2000         0.0966         0.9704
784,784,392,10            3000         0.0691         0.9793
784,784,392,10            4000         0.0810         0.9769
784,784,392,10            5000         0.0879         0.9751
784,784,392,10            6000         0.0793         0.9790
784,784,392,10            7000         0.0760         0.9804
784,784,392,10            8000         0.0856         0.9796
784,784,392,10            9000         0.0780         0.9820
784,784,392,10           10000         0.0886         0.9813

784,784,100,10               0         2.2621         0.1553
784,784,100,10            1000         0.1164         0.9654
784,784,100,10            2000         0.0857         0.9740
784,784,100,10            3000         0.0789         0.9774
784,784,100,10            4000         0.0748         0.9788
784,784,100,10            5000         0.0746         0.9799
784,784,100,10            6000         0.0774         0.9791
784,784,100,10            7000         0.0803         0.9779
784,784,100,10            8000         0.0869         0.9793
784,784,100,10            9000         0.0885         0.9803
784,784,100,10           10000         0.0999         0.9778

***increasing size***
784,784,1024,10              0         2.1892         0.2807
784,784,1024,10           1000         0.1363         0.9565
784,784,1024,10           2000         0.0859         0.9731
784,784,1024,10           3000         0.0786         0.9783
784,784,1024,10           4000         0.0799         0.9792
784,784,1024,10           5000         0.0853         0.9788
784,784,1024,10           6000         0.0766         0.9798
784,784,1024,10           7000         0.0787         0.9800
784,784,1024,10           8000         0.0971         0.9789
784,784,1024,10           9000         0.1066         0.9762
784,784,1024,10          10000         0.0909         0.9808

784,1024,1568,10             0         2.1720         0.2661
784,1024,1568,10          1000         0.1125         0.9681
784,1024,1568,10          2000         0.0875         0.9748
784,1024,1568,10          3000         0.0862         0.9747
784,1024,1568,10          4000         0.0763         0.9800
784,1024,1568,10          5000         0.0811         0.9796
784,1024,1568,10          6000         0.0864         0.9797
784,1024,1568,10          7000         0.1049         0.9789
784,1024,1568,10          8000         0.1037         0.9771
784,1024,1568,10          9000         0.0884         0.9814
784,1024,1568,10         10000         0.1216         0.9775

***3 hidden layer***
784,784,784,784,10           0         2.2269         0.3175
784,784,784,784,10        1000         0.1222         0.9643
784,784,784,784,10        2000         0.0875         0.9747
784,784,784,784,10        3000         0.0785         0.9783
784,784,784,784,10        4000         0.0838         0.9793
784,784,784,784,10        5000         0.0882         0.9770
784,784,784,784,10        6000         0.1022         0.9766
784,784,784,784,10        7000         0.0863         0.9811
784,784,784,784,10        8000         0.0903         0.9811
784,784,784,784,10        9000         0.0985         0.9809
784,784,784,784,10       10000         0.1202         0.9774

784,392,256,100,10           0         2.2843         0.1934
784,392,256,100,10        1000         0.1164         0.9646
784,392,256,100,10        2000         0.0979         0.9727
784,392,256,100,10        3000         0.0794         0.9780
784,392,256,100,10        4000         0.0914         0.9751
784,392,256,100,10        5000         0.0776         0.9792
784,392,256,100,10        6000         0.0770         0.9799
784,392,256,100,10        7000         0.0783         0.9801
784,392,256,100,10        8000         0.1013         0.9798
784,392,256,100,10        9000         0.0833         0.9817
784,392,256,100,10       10000         0.1020         0.9786