External Memory

プログラミング周辺知識の備忘録メイン

deep Q-Networks(2)

前回のdeep Q-learningの続き
Deep Q-Networks (DQN) - External Memory

deep Q-Networks(DQN)によりenvironmentに依存せず、最適Q値を近似することが可能である。DQNは状態(今回は画像)を入力すると、actionごとのQ値を出力する。
environmentとして、かなり単純なゲームっぽいもの(Simple-seek,Simple-shooting)を実際に作成してネットワーク構造を変更せずに学習効果を確認した。

学習効果の確認においてはDQNによる予測actionとランダムなactionによる比較により、100回ゲームを行いその合計値から検証を行った。
ランダム選択の場合epsilon=1、完全なDQN依存の場合epsilon=0である。

DQN学習はepoch1000回ゲームを行い、epsilonはepsilon=1-0.9*step/epochとして大雑把にepsilonの減衰を適用した。

Simple_seek(max score 100)

4*4のマスの中でランダムに1つ当たりマスを指定し、初期位置(0,0)カーソルを移動させて当たりマスを探す。
actionはx,y座標のどちらかを+1カーソル移動の2種類のみである。
当たりマスを探し当てるとreward +1、4*4マスから外れるとreward -1である。

Simple_seek

epsilon = 1(random)
epoch:0 score:33
epoch:100 score:38
epoch:200 score:28
epoch:300 score:36
epoch:400 score:33
epoch:500 score:33
epoch:600 score:33
epoch:700 score:33
epoch:800 score:26
epoch:900 score:38
epoch:1000 score:21

epsilon = 0.05
epoch:0 score:36
epoch:100 score:63
epoch:200 score:89
epoch:300 score:94
epoch:400 score:95
epoch:500 score:94
epoch:600 score:98
epoch:700 score:97
epoch:800 score:94
epoch:900 score:97
epoch:1000 score:93

epsilon = 0
epoch:0 score:23
epoch:100 score:74
epoch:200 score:77
epoch:300 score:92
epoch:400 score:88
epoch:500 score:94
epoch:600 score:100
epoch:700 score:100
epoch:800 score:100
epoch:900 score:75
epoch:1000 score:90
Simple_shooting(max score 200)

9*5画面上でおおよそx方向に動くtargetを座標(5,9)からbulletで撃ち、当たった数をscoreとする。
actionは何もしないと、bullet発射の2種類であり、撃つタイミングがscoreに影響するものである。
targetに当たるとreward +1、外すとreward -1である。
targetとbulletは1ゲーム当たり2つに設定しているので、max scoreは200である。

Simple_shooting

epsilon = 1(random)
epoch:0 score:46
epoch:100 score:46
epoch:200 score:45
epoch:300 score:34
epoch:400 score:46
epoch:500 score:45
epoch:600 score:39
epoch:700 score:45
epoch:800 score:40
epoch:900 score:38
epoch:1000 score:48

epsilon = 0.05
epoch:0 score:70
epoch:100 score:138
epoch:200 score:182
epoch:300 score:185
epoch:400 score:188
epoch:500 score:189
epoch:600 score:186
epoch:700 score:191
epoch:800 score:185
epoch:900 score:181
epoch:1000 score:192

epsilon = 0
epoch:0 score:38
epoch:100 score:94
epoch:200 score:197
epoch:300 score:200
epoch:400 score:200
epoch:500 score:200
epoch:600 score:200
epoch:700 score:200
epoch:800 score:200
epoch:900 score:200
epoch:1000 score:200

これらの状態遷移は決定論的なのでepisilon=0でほぼmax scoreを出すように学習させることが出来た。
Sinple_seekの方は若干バラつきが大きく収束が遅かった。
報酬確定時点でゲームを打ち切ることで学習精度が向上したので報酬確定するまでの状態数の差あたりが原因だろうか。
なんにせよ、報酬確定後の余計な状態は切り捨てるほうがよさそうに感じた。


以下は検証で使ったコードである。
小さく複雑でない画像を入力するのでCNNでなく、Fully-Connectedのみで構成されたネットワークを使った。
(追記11/1 labelの計算に現在のパラメータを使い、一回前のパラメータを使っていないのでネットワークを複製してパラメータコピーなど修正必要だがそれなりに収束はしてくれている。)

env:Simple_shooting

class Simple_shooting(object):
    def __init__(self,num_actions,num_targets):
        self.actions =tuple(i for i in range(num_actions))
        self.img_height = 5
        self.img_width = 9
        self.num_targets = num_targets
        self.totalscore = 0
        
    def newgame(self):
        
        self.terminal = False
        self.reward = 0
        self.image = np.zeros(self.img_height*self.img_width)
        self.bullets = self.num_targets #equal targets
        self.pos_bullets = []
        self.pos_targets = random.sample(range((self.img_height-2)*self.img_width),self.num_targets)
                
        for i in self.pos_targets:
            self.image[i] = 1
            
        self.totalscore = 0
        
    def observe(self):
        return self.image, self.reward, self.terminal
    
    def preprocess(self,state):
        pass
    
    def excute(self,action):
        self.image = np.zeros(5*9)
        self.reward = 0        
        b_list = self.pos_bullets
        self.pos_bullets =[]
        
        for i in b_list:
            if i > self.img_width-1:
                self.pos_bullets.append(i-9)
                self.image[i-9] = -1
            else:
                self.reward -= 1
        
        if action == 1 and self.bullets > 0:
            self.image[self.img_height*self.img_width-5] = -1
            self.pos_bullets.append(self.img_height*self.img_width-5)
            self.bullets -= 1    
        
        self.pos_targets = [(i+1)%((self.img_height-2)*self.img_width) for i in self.pos_targets]
        for i in self.pos_targets:
            self.image[i] = 1
            
        hit = set(self.pos_targets) & set(self.pos_bullets)
        
        for i in hit:
            self.pos_bullets.remove(i)
            self.pos_targets.remove(i)
            self.image[i] = 0
            self.reward += 1
            self.totalscore += 1
            
        if len(self.pos_bullets) == 0 and self.bullets <= 0:
            self.terminal = True

env:Simple_seek

class Simple_seek(object):
    def __init__(self,num_actions):
        self.actions =tuple(i for i in range(num_actions))
        self.img_height = 4
        self.img_width = 4
        self.totalscore = 0
        
    def newgame(self):
        self.terminal = False
        self.reward = 0
        self.totalscore = 0
        self.image = np.zeros((self.img_height,self.img_width))
        pos = np.random.randint(1,16)
        self.t_x = pos%self.img_height
        self.t_y = pos//self.img_height
        self.image[self.t_y][self.t_x] = 1
        self.pointer = [0,0]
        self.image[0][0] = -1
    
    def observe(self):
        return self.image, self.reward, self.terminal
    
    def excute(self,action):
        self.image[self.pointer[1]][self.pointer[0]] = 0
        self.reward = 0
        
        if action == 0:
            self.pointer[1] += 1
        elif action == 1:
            self.pointer[0] += 1
                                          
        if self.pointer[0]>self.t_x or self.pointer[1]>self.t_y:
            self.reward -= 1
            self.terminal = True
        elif self.image[self.pointer[1]][self.pointer[0]] == 1:
            self.reward += 1
            self.terminal = True
            self.totalscore += 1
        else:
            self.image[self.pointer[1]][self.pointer[0]] = -1

agent

class Agent(object):
    
    def __init__(self,env):
        self.env = env        
        self.max_memory_size =1000
        self.replay_memory =[]
        self.current_pos = 0
        self.actions = env.actions
        self.learning_rate =0.001
        self.max_step = 1000
        self.discount_factor = 0.9
        self.batch_size =32
        self.y = self.nn_inference()
        
    def predict(self,state,epsilon=0.1,is_training=True):
        if not is_training:
            epsilon=0.05
            
        if np.random.rand() <= epsilon:
            action = np.random.choice(self.actions)
        else:
            action = self.actions[np.argmax(self.q_eval(state))]
        
        return action
    
    def store_transition(self,state,action,reward,n_state,terminal):
        if len(self.replay_memory) == self.max_memory_size:
            self.replay_memory[self.current_pos] = (state,action,reward,n_state,terminal)
            self.current_pos = (self.current_pos +1) % self.max_memory_size
        else:
            self.replay_memory.append((state,action,reward,n_state,terminal))

    def nn_inference(self):
        #simple_shooting
        self.x  = tf.placeholder(tf.float32, [None, 45])
        
        #simple_seek
        #self.x  = tf.placeholder(tf.float32, [None, 4,4])
        #inpt = tf.reshape(self.x, [-1,16])
        
        W_1 = tf.Variable(tf.truncated_normal([45, 128],stddev=0.05))
        b_1 = tf.Variable(tf.zeros([128]))
    
        hidden1 = tf.nn.relu(tf.matmul(self.x, W_1) + b_1)
        #hidden1 = tf.nn.relu(tf.matmul(inpt, W_1) + b_1)

        W_3 = tf.Variable(tf.truncated_normal([128, 2],stddev=0.05))
        b_3 = tf.Variable(tf.zeros([2]))
    
        output = tf.matmul(hidden1, W_3) + b_3
        
        return output
    
    def q_eval(self,state):
        return self.sess.run(self.y, feed_dict={self.x:[state]})[0]
    
    def train_update(self):
        state_batch = []
        labels = []
        r_mem_length = len(self.replay_memory)
        
        batchsize = min(r_mem_length,self.batch_size)
        r_mem_indexes = np.random.randint(0,r_mem_length,batchsize)
        
        for mem_index in r_mem_indexes:
            state,action,reward,n_state,terminal = self.replay_memory[mem_index]
            state_batch.append(state)
            
            q_vals= self.q_eval(state)
            
            if terminal:
                q_vals[action] = reward
            else:
                q_vals[action] = reward + self.discount_factor * np.max(self.q_eval(n_state))
            
            labels.append(q_vals)
            
        self.sess.run(self.train_op,feed_dict={self.x:state_batch,self.y_:labels})
        
    def train(self,epoch):
        
        self.y_ = tf.placeholder(tf.float32, [None, len(self.actions)])
        loss = tf.reduce_mean(tf.square(self.y_ - self.y))
        
        self.train_op = tf.train.RMSPropOptimizer(self.learning_rate).minimize(loss)
        #self.train_op = tf.train.GradientDescentOptimizer(self.learning_rate).minimize(loss)
        self.sess = tf.InteractiveSession()
        tf.global_variables_initializer().run()
        
        for i in range(epoch):
            self.env.newgame()
            state, reward, terminal = self.env.observe()
            step = 0
            
            while not terminal or step == self.max_step:
                  
                action = self.predict(state,epsilon=1-0.9*i/epoch)

                self.env.excute(action)                
                next_state, reward, terminal = self.env.observe()

                self.store_transition(state,action,reward,next_state,terminal)
                
                self.train_update()
                
                state = next_state
                step += 1
            
            if i%(epoch//10) == 0:
                #saver = tf.train.Saver(max_to_keep=11)
                #saver.save(self.sess,"tmp/dqn",global_step = i)
                self.test_play(100,i)   
        #saver.save(self.sess,"tmp/dqn",global_step = epoch)
        
    def test_play(self,times,epoch):
        
        score = 0
        
        for i in range(times):
            self.env.newgame()
            state, reward, terminal = self.env.observe()
            
            while not terminal:
                action = self.predict(state,is_training=False)
                self.env.excute(action)                
                state, _, terminal = self.env.observe()

            score += self.env.totalscore
        
        print("epoch:{} score:{}".format(epoch,score))
    
    
if __name__ == '__main__':

    env = Simple_shooting(2,2)
    #env = Simple_seek(2)
    agent = Agent(env)
    
    agent.train(1000)
    
    agent.test_play(100,1000)

Single Shot MultiBox Detector、Crafting GBD-Net

SSD: Single Shot MultiBox Detector

https://arxiv.org/abs/1512.02325

CNNにおいて普通ネットワークの先端に近づくにつれて、同サイズのフィルターサイズの受容野のスケールは大きくなる。
SSDは一つのニューラルネットワークより異なるスケールの特徴マップからの出力によって、直接クラス弁別とbounding-boxの決定を行って物体位置認識を行う。

これにより高速かつ高精度で物体位置認識の計算を行うことが可能となった。
SSD300におけるPASCAL VOC2007testセットの精度と速度はデータセットが07+12でmAP 74.3%、速度は59 FPSである。
SSD512ではmAP 76.8 %、速度は22 FPSである。(Nvidia Titan X)


特徴マップの作成にはベースCNNを用い、この文献ではVGG16である。
このマップから直接または、conv層を追加することによってクラスごとのスコアとboxのオフセットの出力を行う。
よって一つの層の特徴マップから(c+4)kmn個の出力が得られる。
cはクラス数、4はboxのオフセットパラメータ、kはboxの領域数、m,nはフィルターマップの領域の大きさを表す。
それぞれの特徴マップのセルごと1,2,3,1/2,1/3と1/1のスケール拡大boxの6種類のアスペクト比のbox(default box)から、信頼スコアとboxのオフセットを予測する。
default boxのスケールはサイズが同じ(3*3)conv層のスライドが2以上なら層ごとにスケールは2倍で解像度は下がる。
各特徴マップでのdefault boxのスケールは以下のように表される。

\displaystyle s_k=s_{min}+\frac{s_{max}-s_{min}}{m-1}(k-1)

最下層スケールs_minは0.2、最先端層のスケールはs_max=0.9でmは層数でkはcurrent層である。


学習においては誤差関数と逆伝播がend-to-endで適用される。
誤差関数は信頼スコアと位置誤差の重みαありの和として表される(文献ではα=1)。
それぞれソフトマックス誤差関数とsmoothL1誤差関数が用いられる。
領域にはground-truthに対してjaccard overlap 0.5以上をpositiveとしている。
またpositiveとnegativeのdefault boxの学習サンプル数の比率を調整するために、
信頼スコアでソートしたうえで比率が1/3となるように選び出している。


testセットの物体認識に対しては、恐らく信頼スコア値のtopいくつかを抽出してそのbox位置とクラスを予測とするのだろう。


Crafting GBD-Net for Object Detection

https://arxiv.org/pdf/1610.02579.pdf

GBD-Netは物体領域認識において、物体だけに焦点を当てるだけでなく
その周囲の環境や文脈をの情報を盛り込んで判別を行うニューラルネットワークである。

物体領域の判別において局所的な強い特徴量によって誤りを引き起こしたり、遮蔽物や局所的な類似特徴による判別ミスなどを引き起こす恐れがあった。
文献においてはうさ耳バンドを着用した人の顔や、動物の体表模様などを例として挙げている。


GBD-Net(gated bi-directional CNN)は、ゲート機能を有する双方向ネットワーク構造を持つCNNである。
双方向ネットワークは複数のスケールと解像度間でゲートを通して情報をお互いにやり取りすることで、大域的な特徴と局所的な細かい特徴間での相補的な効果により判別精度を向上させる。
この時ゲートは関連のない情報と関連性のある情報を制御する機能を持つ。
この文献においてGBD-Netはfast R-CNNの枠組みの下で実装される。

GBD-NetはCNN中の後半部分辺りに挿入される(BN-netではinception 4d、Res-Netでは269では234番目)。
しかしどこにでも挿入することは可能である。
GBD-Netの直前にRoi pooling layerを配置して様々な特徴マップサイズを、同じサイズのベクトルにする。
それぞれの解像度とスケールに対して特徴量は共有されないので各解像度に対し複数のCNNに分岐する。

Roi poolingで4つのスケールに分岐させた後、warpingさせて4つの特徴マップ\mathbf{f}^{-0.2},\mathbf{f}^{0.2},\mathbf{f}^{0.8},\mathbf{f}^{1.7}をGBD-netに入力をする。
\mathbf{f}^pのpはスケール倍率でサイズが224*224にwarpingするデフォルトの場合p=0.2(1.2倍)である。
\mathbf{h}_i^0\ i=1,2,3,4を入力、\mathbf{h}_i^3を出力とするGBD-Net構造は以下のように表される。
iはそれぞれのスケールである。

\mathbf{f}_i^1=\sigma(\mathbf{h}_i^0\otimes \mathbf{w}_i^1+\mathbf{b}_i^{0,1})+G_i^1\bullet \sigma(\mathbf{h}_{i-1}^1\otimes \mathbf{w}_{i-1,i}^1+\mathbf{b}_i^1)
\mathbf{f}_i^2=\sigma(\mathbf{h}_i^0\otimes \mathbf{w}_i^2+\mathbf{b}_i^{0,2})+G_i^2\bullet \sigma(\mathbf{h}_{i+1}^2\otimes \mathbf{w}_{i+1,i}^2+\mathbf{b}_i^2)
\mathbf{f}_i^3=\mathbf{h}_i^0+\beta max(\mathbf{h}_i^1,\mathbf{h}_i^2)

G_i^1=sigmoid(\mathbf{h}_{i-1}^0\otimes \mathbf{w}_{i-1,i}^g+\mathbf{b}_{i-1,i}^g)
G_i^2=sigmoid(\mathbf{h}_{i+1}^0\otimes \mathbf{w}_{i+1,i}^g+\mathbf{b}_{i+1,i}^g)

σ()はReLU、sigmoid()はシグモイド関数である。
w,bはそれぞれ重みとバイアス、Gはゲート関数を表す。
これらの表式はconv層と非線形層で実現される。


GBDの学習においてはpre-trainデータを大きく変化しないような初期値の設定が必要である。loss関数はcls lossとloc lossを足し合わせたもので、loc lossにはsmooth L1関数を用いる。

テスト時のregion proposalには拡張RPNを使ったCraftを用いる。

ILSVRC2016 ImageNetのobject detectionではmAP 66%、後にさらに1%弱ほど向上しているようである。

VOC2007 dataset 07+12でmAP 77.2%。baseより+4.1%である。

Deep Q-Networks (DQN)

強化学習は普通の深層学習より学習が能動的という意味でAlっぽいので面白そうではある。以下の有名そうな論文を読んで勉強のためのとっかかりとした。

Playing Atari with Deep Reinforcement Learning
https://arxiv.org/pdf/1312.5602.pdf

強化学習の全体像を知ったわけではないが、
最適化をニューラルネットワークに任せるという意味では、
政策やアルゴリズムをあまり考える手間が少なそうなのでおそらく取っ付きやすいのではないかと思う。


DQNでは最適化された行動価値関数Qを近似するためにCNNを用いる。
CNNはSGDを用いた重み更新を用いたQ-learningアルゴリズムの変形で学習される。
この時、相関データと非定常な分布を緩和するために以前の遷移をランダムにサンプリングするexperience replayを用いる。

最適化されたQは報酬の期待値が最大化され、time t、状態s、action a、報酬Rを用いて以下のように表される。

Q^*(s,a)=\max_{\pi}\mathbb{E}[R_t|s_t=s,a_t=a,\pi]


またこれはBellman equationと等価である。

Q^*(s,a)=\mathbb{E}_{s'\sim \mathcal{E}}[r+\gamma \max_{a'}Q(s',a')|s,a]
s',a'は次ステップの状態と行動、γは割引率である。

ニューラルネットワーク非線形近似を行うと(Q(s,a;\theta)\approx Q^*(s,a):θは重み)、
loss関数はiteration iを用いて以下のようになる。

L_i(\theta_i)=\mathbb{E}_{s,a\sim \rho(\cdot)}[(y_i-Q(s,a;\theta_i))^2]

ここでy_i=\mathbb{E}_{s'\sim \mathcal{E}}[r+\gamma \max_{a'}Q(s',a'|\theta_{i-1})|s,a]となり、
前回の重みθ_{i-1}を用いているが、文献中のAlgorithm 1ではそうはなっていない。
なぜi-1としているのかわからないが、step一回ごとに更新するのでs'、a'と行動分布を揃えるためだろうか。


"
追記11/1
Human-Level Control through Deep Reinforcement Learning
https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf
を読むと、やっぱり一回前のパラメータを使っている。
Qの変動が抑えられるのでactionの選択に関する影響に遅延が生じるため、発散や振動が発生しにくくなる。
"


このニューラルネットワークを用いた手法は環境に依存せず強化学習を行え、また政策を勘案する必要がない。よってatari gameの種類ごとにモデルやアルゴリズムの変更を行う必要がない。


Deep Q-learningの手続きは文献中Algorithm 1に記載がある。
ここでreplay memoryは最近のN個を保持しランダムサンプリングを行う。
これにより学習サンプルの強い相関や最大化行動の影響によるサンプルと偏りを緩和する。φは入力の前処理である。


前処理はRGBをグレースケールとしスケールを約半分にし次いでcroppingし正方形とイメージとし、履歴の最近4フレームにこの前処理を行い入力のための蓄積を行っている。
CNNにおけるQのパラメータ化においては、入力を状態パラメータ、出力を行動予測に対応するQ値としている。
これにより一回の順方向伝播ですべての行動のQ値を計算することが出来る。
CNNはconv2層、Fully-connected1層である。
conv 8*8 16ch(stride4) + conv 4*4 32ch(stride2) + FC256units + output(num of action 4-18units)。

ハイパーパラメータは全てのgameで報酬の設定を除き同じである。
報酬はスケールの関係上全てのプラス報酬を1、マイナス報酬を-1、その他0としている。

batchsize32のRMSProp optimaizerを用い、ε-greedyでは1 millionステップごとに1-0.1まで線形減衰させている。
テスト時はε=0.05である。
また計算コストの関係からスペースインベーダ以外のゲームでは4ステップごとに行動を選択し(インベーダは3)、そのステップ間はその行動を維持する。

また学習中に評価においてはQ値のトータル平均報酬を用いるのでなく、ある状態の集合からの最大Q値の平均を用いる。
これは重みの小さな変化によりサンプリングする状態分布が大きく変化するので非常にnoisyになり、学習性能の追跡がやりづらいからである。

全てのゲームで他のRL手法より高い平均スコアであり、
スペースインベーダ、Q*bert、Seaquest以外のゲームで人間より高いまた近い平均スコアを示している。インベーダなどは長期的なスケールでの戦略が要求されるからである。



Algorithm 1はpythonコードではおそらくこのような感じになる。
enviromentは空で、εの減衰などいろいろやってないし、当然動かない。
CNN部分は代わりにリカレントネットワークとかもあり得そうな気がするが。

import tensorflow as tf
import numpy as np

class Environment(object):
    def __init__(self):
        self.actions =(0,1)
    
    def newgame(self):
        pass
    
    def observe(self):
        pass
    
    def preprocess(self,state):
        pass
    
    def excute(self,action):
        pass
    

class Agent(object):
    
    def __init__(self,env):
        self.env = env
        self.replay_memory =[]
        self.max_memory_size =1000
        self.current_pos = 0
        self.actions = env.actions
        self.max_step = 1000
        self.discount_factor = 0.9
        self.batch_size =32
        self.y = self.nn_inference()
        
    def predict(self,state,epsilon=0.1,is_training=True):
        if not is_training:
            epsilon=0.05
            
        if np.random.rand() <= epsilon:
            action = np.random.choice(self.actions)
        else:
            action = self.actions[np.argmax(self.q_eval(state))]
        
        return action
    
    def store_transition(self,state,action,reward,n_state,terminal):
        self.replay_memory[self.current_pos] = (state,action,reward,n_state,terminal)
        self.current_pos = (self.current_pos +1) % self.max_memory_size

    def nn_inference(self):
        self.x  = tf.placeholder(tf.float32, [None, 84*84])
        
        W1 = tf.Variable(tf.truncated_normal([8, 8, 4, 16],stddev=0.05))
        b1 = tf.Variable(tf.zeros([16]))
        conv1 = tf.nn.relu(tf.nn.conv2d(self.x, W1, strides=[1,4,4,1], padding='SAME')+ b1)
        
        W2 = tf.Variable(tf.truncated_normal([4, 4, 16, 32],stddev=0.05))
        b2 = tf.Variable(tf.zeros([32]))
        conv2 = tf.nn.relu(tf.nn.conv2d(conv1, W2, strides=[1,2,2,1], padding='SAME')+ b2)
        
        W3 = tf.Variable(tf.truncated_normal([11*11*32, 256],stddev=0.05))
        b3 = tf.Variable(tf.zeros([256]))
        fc1 = tf.reshape(conv2, [-1, 256])
        fc1 = tf.nn.relu(tf.matmul(fc1, W3)+ b3)
        
        W4 = tf.Variable(tf.truncated_normal([256, self.actions],stddev=0.05))
        b4 = tf.Variable(tf.zeros([self.actions]))
        output = tf.matmul(fc1, W4)+ b4
        
        return output
    
    def q_eval(self,state):
        return self.sess.run(self.y, feed_dict={self.x:[state]})[0]
    
    def train_update(self):
        state_batch = []
        labels = []
        r_mem_length = len(self.replay_memory)
        
        batchsize = min(r_mem_length,self.batch_size)
        r_mem_indexes = np.random.randint(0,r_mem_length,batchsize)
        
        for mem_index in r_mem_indexes:
            state,action,reward,n_state,terminal = self.replay_memory[mem_index]
            state_batch.append(state)

            q_vals= self.q_eval(state)
            
            if terminal:
                q_vals[action] = reward
            else:
                q_vals[action] = reward + self.discount_factor * np.max(self.q_eval(n_state))
            
            labels.append(q_vals)
            
        self.sess.run(self.train_op,feed_dict={self.x:state_batch,self.y_:labels})
        
    def train(self,epoch):
        
        self.y_ = tf.placeholder(tf.float32, [None, self.actions])
        loss = tf.reduce_mean(tf.square(self.y_ - self.y))
        
        self.train_op = tf.train.RMSPropOptimizer(self.learning_rate).minimize(loss)
        
        self.sess = tf.InteractiveSession()
        tf.global_variables_initializer().run()
        
        for i in range(epoch):
            self.env.newgame()
            state, reward, terminal = self.env.observe()
            step = 0
            
            while not terminal or step == self.max_step:
                #phi = env.preprocess(state)                    
                action = agent.predict(state)
                #for i in range(4):
                self.env.excute(action)                
                next_state, reward, terminal = self.env.observe()
                
                #n_phi = env.preprocess(next_state)
                agent.store_transition(state,action,reward,next_state,terminal)
                
                agent.train_update()
                
                state = next_state
                step += 1
                
        saver = tf.train.Saver()
        saver.save(self.sess,"tmp/dqn",global_step = epoch)
    
    
if __name__ == '__main__':

    env = Environment()
    agent = Agent(env)
    
    agent.train(10000)
    
    test_play()