External Memory

プログラミング周辺知識の備忘録メイン

hidden layer(隠れ層)のunit数の違いによる学習制御について

ニューラルネットワークによる機械学習において、
それぞれのhidden layerごとのunit数によって
学習の制御がある程度可能であると直観的には思われる。

例えば、以下のようなことを考えた。

  • unit数が少ないとunitごとに相互作用が大きく(周囲の変化に敏感・従属的)なるので、特徴量の抽出および最適化(収束)がうまく働かない。
  • 逆にunit数が多いとアクティブなunitが疎となり、計算効率が悪いとか、相互作用が小さく学習がうまく進まない。また、パラメータ数の多さから余計な特徴まで拾って汎用性を失う。
  • これらのことから、うまくバランスのとれたunit数がある。
  • 入力データの複雑さや分類数によって、最適unit数が変わる。


(一方でテストデータに対して最適なバランスがあるとすれば、
それはテストデータに適応しただけという懸念もあり、
汎用性のためには多様なデータでテストをすべきではあると思う)


今回もMNISTデータを使用した。
主に隠れ層2層(一部3層)とし、それぞれのhidden層のunit数を変えて検証を行った。
また、誤差関数の出力も追加した。
学習の進み具合を確認したり正答率と比較して過学習などを確認することができるはず。


以下は使用したプログラムである。

import tensorflow as tf
import math
from tensorflow.examples.tutorials.mnist import input_data


mnist = input_data.read_data_sets("mnist/", one_hot=True)


class nnet():
    
    def __init__(self,units,l_rate,step):
        self.units = units
        self.l_rate = l_rate
        self.step = step
        self.x = tf.placeholder(tf.float32, [None, units[0]])
        self.y_ = tf.placeholder(tf.float32,[None, units[-1]])
        
    def build(self):
        y = self.x
        for i in range(len(self.units)-1):
            y = self.layer(y,self.units[i],self.units[i+1],i)
            
        cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels = self.y_,logits = y))        
        return cross_entropy,y
            
    def layer(self,inpt,unit,next_unit,i):
        W = tf.Variable(tf.truncated_normal([unit, next_unit],
                            stddev=1.0 / math.sqrt(float(unit))))
        b = tf.Variable(tf.zeros([next_unit]))
        
        if len(self.units)-2 != i:
            return tf.nn.relu(tf.matmul(inpt, W) + b)
        else:
            return tf.matmul(inpt, W) + b
    
    def train(self):
        cross_entropy,y = self.build()
        train_step = tf.train.AdamOptimizer(self.l_rate).minimize(cross_entropy)
        
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(self.y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
                
        sess = tf.InteractiveSession()
        tf.global_variables_initializer().run()
    
        for i in range(self.step):
            batch_xs, batch_ys = mnist.train.next_batch(50)
            sess.run(train_step, feed_dict={self.x: batch_xs, self.y_: batch_ys})
            
            if i % 1000 == 0:
                print("{0:<25}{1:>5}{2[0]:>15.4f}{2[1]:>15.4f}".format(",".join(map(str,self.units)),i,sess.run([cross_entropy,accuracy], feed_dict={self.x: mnist.test.images,
                                         self.y_: mnist.test.labels})))
            
        return sess,cross_entropy,y
            
    def print_accurency(self):
        sess,cross_entropy,y = self.train()
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(self.y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        print("{0:<25}{1:>5}{2[0]:>15.4f}{2[1]:>15.4f}".format(",".join(map(str,self.units)),self.step,sess.run([cross_entropy,accuracy], feed_dict={self.x: mnist.test.images,
                                         self.y_: mnist.test.labels})))
    
    
if __name__ == '__main__':
    print("{0:<25}{1:>5}{2:>15}{3:>15}".format("units","step","loss","accurancy"))
    print("-------------------------------------------------------------")
    print("***same size***")
    units_s1 = [784,392,392,10]
    units_s2 = [784,784,784,10]
    units_s3 = [784,1024,1024,10]
    nets1 = nnet(units_s1,0.001,10000)
    nets1.print_accurency()
    nets2 = nnet(units_s2,0.001,10000)
    nets2.print_accurency()
    nets3 = nnet(units_s3,0.001,10000)
    nets3.print_accurency()
    print("***decreasing size***")
    units_d1 = [784,392,100,10]
    units_d2 = [784,784,392,10]
    units_d3 = [784,784,100,10]
    netd1 = nnet(units_d1,0.001,10000)
    netd1.print_accurency()
    netd2 = nnet(units_d2,0.001,10000)
    netd2.print_accurency()
    netd3 = nnet(units_d3,0.001,10000)
    netd3.print_accurency()
    print("***increasing size***")
    units_i1 = [784,784,1024,10]
    units_i2 = [784,1024,1568,10]
    neti1 = nnet(units_i1,0.001,10000)
    neti1.print_accurency()
    neti2 = nnet(units_i2,0.001,10000)
    neti2.print_accurency()
    print("***3 hidden layer***")
    units_31 = [784,784,784,784,10]
    units_32 = [784,392,256,100,10]
    net31 = nnet(units_31,0.001,10000)
    net31.print_accurency()
    net32 = nnet(units_32,0.001,10000)
    net32.print_accurency()

末尾に実際の出力を載せた。
見やすさのため、データ挿入と改行を加えた(それでもわかりにくいが)。


結果を羅列すると

  • hidden layerのunit数を入力データのunit数と同じである場合で最良の結果が得られた(正答率0.985)。しかし、unit数が大きい場合や小さい場合と比較してそれほど有意差があるわけではない。
  • decreasing size形状(ピラミッド型)も比較的結果は良好だった。
  • unit数が入力データunit数よりも大きい場合は、誤差関数の値から学習が途中でうまく進まなくなる傾向が見られたので、ノイズを拾っているのかもしれない。
  • hidden layer数を2→3層にしても正答率が上がらなかった。


単純なニューラルネットワークでunit数の最適化では
顕著に正答率が向上するわけでもないようである。

正答率向上を狙うなら、アルゴリズムの変更や新しいハイパーパラメータ導入や変更を考えたほうが早いかもしれない。

入力データの複雑さなどにもよると思うが、
hidden layerのunit数を入力unit数と同じ数とするのが
第一選択として無難であるように思う。
また、unit数をdecreasing形状(ピラミッド型)も、
もう一つの選択としてはいいかもしれない。
unit数を極端に大きくしたり小さくすべきではない。

ピラミッド型で結果が良好であったことに関して、
徐々に特徴を抽象化させていると考えれば聞こえはいいかもしれないが
他の結果と結びつかず、よくわからない部分がある。
いくつかの要素が重なって影響しているのだろうか。

>python nnet.py
units                     step           loss      accurancy
-------------------------------------------------------------
***same size***
(784,100,100,10のデータは後で追加挿入)
784,100,100,10               0         2.2757         0.1661
784,100,100,10            1000         0.1627         0.9501
784,100,100,10            2000         0.1029         0.9681
784,100,100,10            3000         0.0968         0.9691
784,100,100,10            4000         0.0914         0.9716
784,100,100,10            5000         0.0791         0.9767
784,100,100,10            6000         0.0757         0.9765
784,100,100,10            7000         0.0799         0.9760
784,100,100,10            8000         0.0810         0.9767
784,100,100,10            9000         0.0744         0.9790
784,100,100,10           10000         0.0912         0.9775

784,392,392,10               0         2.2246         0.3222
784,392,392,10            1000         0.1323         0.9591
784,392,392,10            2000         0.0740         0.9758
784,392,392,10            3000         0.0866         0.9752
784,392,392,10            4000         0.0755         0.9788
784,392,392,10            5000         0.0822         0.9761
784,392,392,10            6000         0.0838         0.9784
784,392,392,10            7000         0.0832         0.9792
784,392,392,10            8000         0.1185         0.9735
784,392,392,10            9000         0.0985         0.9786
784,392,392,10           10000         0.0909         0.9788

784,784,784,10               0         2.1951         0.2831
784,784,784,10            1000         0.1036         0.9698
784,784,784,10            2000         0.0949         0.9729
784,784,784,10            3000         0.0674         0.9803
784,784,784,10            4000         0.0899         0.9747
784,784,784,10            5000         0.0844         0.9778
784,784,784,10            6000         0.0763         0.9800
784,784,784,10            7000         0.0852         0.9791
784,784,784,10            8000         0.0861         0.9820
784,784,784,10            9000         0.0871         0.9817
784,784,784,10           10000         0.0796         0.9852

784,1024,1024,10             0         2.1686         0.4042
784,1024,1024,10          1000         0.1070         0.9646
784,1024,1024,10          2000         0.0902         0.9716
784,1024,1024,10          3000         0.0844         0.9759
784,1024,1024,10          4000         0.0748         0.9766
784,1024,1024,10          5000         0.1080         0.9717
784,1024,1024,10          6000         0.0906         0.9794
784,1024,1024,10          7000         0.0776         0.9802
784,1024,1024,10          8000         0.0969         0.9789
784,1024,1024,10          9000         0.0816         0.9821
784,1024,1024,10         10000         0.1145         0.9793

***decreasing size***
784,392,100,10               0         2.2494         0.1666
784,392,100,10            1000         0.1273         0.9592
784,392,100,10            2000         0.0829         0.9735
784,392,100,10            3000         0.0808         0.9760
784,392,100,10            4000         0.0703         0.9798
784,392,100,10            5000         0.0672         0.9814
784,392,100,10            6000         0.0759         0.9776
784,392,100,10            7000         0.0780         0.9798
784,392,100,10            8000         0.0906         0.9784
784,392,100,10            9000         0.0765         0.9801
784,392,100,10           10000         0.0814         0.9818

784,784,392,10               0         2.1779         0.3722
784,784,392,10            1000         0.1029         0.9658
784,784,392,10            2000         0.0858         0.9741
784,784,392,10            3000         0.0655         0.9806
784,784,392,10            4000         0.0787         0.9761
784,784,392,10            5000         0.0727         0.9794
784,784,392,10            6000         0.0767         0.9780
784,784,392,10            7000         0.0918         0.9748
784,784,392,10            8000         0.0806         0.9795
784,784,392,10            9000         0.0749         0.9810
784,784,392,10           10000         0.0818         0.9816

784,784,100,10               0         2.2184         0.3194
784,784,100,10            1000         0.1115         0.9654
784,784,100,10            2000         0.0844         0.9742
784,784,100,10            3000         0.0760         0.9765
784,784,100,10            4000         0.0803         0.9765
784,784,100,10            5000         0.0808         0.9756
784,784,100,10            6000         0.0740         0.9808
784,784,100,10            7000         0.0712         0.9793
784,784,100,10            8000         0.0864         0.9782
784,784,100,10            9000         0.0769         0.9805
784,784,100,10           10000         0.0794         0.9813

***increasing size***
784,784,1024,10              0         2.2237         0.1310
784,784,1024,10           1000         0.1015         0.9672
784,784,1024,10           2000         0.1167         0.9648
784,784,1024,10           3000         0.0875         0.9752
784,784,1024,10           4000         0.0641         0.9796
784,784,1024,10           5000         0.1064         0.9719
784,784,1024,10           6000         0.0741         0.9790
784,784,1024,10           7000         0.0785         0.9794
784,784,1024,10           8000         0.0917         0.9779
784,784,1024,10           9000         0.1052         0.9779
784,784,1024,10          10000         0.0983         0.9798

784,1024,1568,10             0         2.1147         0.2857
784,1024,1568,10          1000         0.0919         0.9706
784,1024,1568,10          2000         0.0913         0.9718
784,1024,1568,10          3000         0.1061         0.9748
784,1024,1568,10          4000         0.0804         0.9783
784,1024,1568,10          5000         0.0822         0.9787
784,1024,1568,10          6000         0.1009         0.9761
784,1024,1568,10          7000         0.1074         0.9744
784,1024,1568,10          8000         0.0897         0.9801
784,1024,1568,10          9000         0.1069         0.9786
784,1024,1568,10         10000         0.1149         0.9753

***3 hidden layer***
784,784,784,784,10           0         2.2147         0.2800
784,784,784,784,10        1000         0.1029         0.9696
784,784,784,784,10        2000         0.0868         0.9729
784,784,784,784,10        3000         0.1062         0.9689
784,784,784,784,10        4000         0.0862         0.9771
784,784,784,784,10        5000         0.0898         0.9751
784,784,784,784,10        6000         0.1172         0.9761
784,784,784,784,10        7000         0.0894         0.9796
784,784,784,784,10        8000         0.1042         0.9796
784,784,784,784,10        9000         0.0885         0.9830
784,784,784,784,10       10000         0.0884         0.9815

784,392,256,100,10           0         2.2668         0.1804
784,392,256,100,10        1000         0.1087         0.9655
784,392,256,100,10        2000         0.0930         0.9713
784,392,256,100,10        3000         0.0841         0.9755
784,392,256,100,10        4000         0.0838         0.9762
784,392,256,100,10        5000         0.0781         0.9797
784,392,256,100,10        6000         0.0912         0.9759
784,392,256,100,10        7000         0.0856         0.9770
784,392,256,100,10        8000         0.0963         0.9776
784,392,256,100,10        9000         0.0982         0.9771
784,392,256,100,10       10000         0.0830         0.9812