External Memory

プログラミング周辺知識の備忘録メイン

MNISTデータのニューラルネットワーク学習

機械学習の練習によく使用されるMNISTの手書き数字データを利用して、
実際にプログラムを作成して学習を行った。

モデルは簡単な2つのニューラルネットワーク
入力と出力のみの2層、及びhidden layerを1層挿入した3層ネットワークです。

import tensorflow as tf
import math
import time
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("mnist/", one_hot=True)

def n_net2(units,l_rate):
    
    T = time.time()
    x = tf.placeholder(tf.float32, [None, units[0]])
    
    W_1 = tf.Variable(
        tf.truncated_normal([units[0], units[1]],
                            stddev=1.0 / math.sqrt(float(units[0]))))
    
    b_1 = tf.Variable(tf.zeros([units[1]]))
    y = tf.matmul(x, W_1) + b_1
    
    y_ = tf.placeholder(tf.float32, [None, units[-1]])

    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))

    train_step = tf.train.AdamOptimizer(l_rate).minimize(cross_entropy)

    sess = tf.InteractiveSession()
    tf.global_variables_initializer().run()
    
    for _ in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
        
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                                         y_: mnist.test.labels}))  
    print("time: {:.4f} s".format(time.time()-T))

    
def n_net3(units,l_rate,keep_p=1.0,lmd=0.0):
    
    T = time.time()
    x = tf.placeholder(tf.float32, [None, units[0]])
    
    W_1 = tf.Variable(
        tf.truncated_normal([units[0], units[1]],
                            stddev=1.0 / math.sqrt(float(units[0]))))
    b_1 = tf.Variable(tf.zeros([units[1]]))
    
    hidden1 = tf.nn.relu(tf.matmul(x, W_1) + b_1)
    
    
    
    keep_prob = tf.placeholder(tf.float32)
    drop_fc = tf.nn.dropout(hidden1,keep_prob)
    
    
    W_h1 = tf.Variable(
        tf.truncated_normal([units[1], units[-1]],
                            stddev=1.0 / math.sqrt(float(units[1]))))
    b_h1 = tf.Variable(tf.zeros([units[-1]]))
    
    y = tf.matmul(drop_fc, W_h1) + b_h1
    
    y_ = tf.placeholder(tf.float32,[None,units[-1]])

    
    L2_n = tf.constant(0.0, dtype=tf.float32)    
    if lmd > 0.0:
        L2_sqr = tf.nn.l2_loss(W_1) + tf.nn.l2_loss(W_h1)
        L2_n = lmd * L2_sqr
        
    cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels = y_,logits = y))
    train_step = tf.train.AdamOptimizer(l_rate).minimize(cross_entropy + L2_n)
    
    
    sess = tf.InteractiveSession()
    tf.global_variables_initializer().run()
    
    
    for _ in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys, keep_prob: keep_p})
        
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                                         y_: mnist.test.labels, keep_prob: keep_p}))
    print("time: {:.4f} s".format(time.time()-T))
    
if __name__ == '__main__':
    units1 = [784,10]
    n_net2(units1,0.001)
    
    units2 = [784,392,10]
    n_net3(units2,0.001)


input_data.read_data_sets関数を使って、
画像ファイルがない場合のみweb上からダウンロードを行い、データを読み込む。
read_data_sets関数は引数"validation_data=5000"を使って、
validation_dataの5000個データをmnist.validation.imagesで抽出することもできるので、
ハイパーパラメータの調整に使用する場合などに便利です。

今回はOptimizerにAdamを使用しています。
tensorflowのoptimizerを見比べて一番無難そうだったからです。
Adamについては以下を参照。
https://arxiv.org/pdf/1412.6980.pdf

AdagradやRMSPropのいいとこ取り改善ver.です。
傾きの影響を受けにくい性質など、ステップサイズをうまく制御しています。

mini_batch_next関数はランダムシャッフルされてから出てくる。

ドロップアウト、L2正則化の使用は逆に分類精度が低下した。
データ数や汎用化にとっての質が十分だっで過適合になりにくかったからか。
これらを適用すると、学習速度が落ちることがある。
これらの効果を確認するなら、
データ数を少なくして検証してみるのが良いかもしれない。

出力
>python mnist.py
Extracting mnist/train-images-idx3-ubyte.gz
Extracting mnist/train-labels-idx1-ubyte.gz
Extracting mnist/t10k-images-idx3-ubyte.gz
Extracting mnist/t10k-labels-idx1-ubyte.gz
0.9133
time: 0.8999 s
0.969
time: 5.1246 s

99%以上まで正答率を上げるのがどれくらい難しいかわからないが、
すぐに達成できても、検証サンプルとしては手軽に使えそうである。