External Memory

プログラミング周辺知識の備忘録メイン

deep Q-Networks(2)

前回のdeep Q-learningの続き
Deep Q-Networks (DQN) - External Memory

deep Q-Networks(DQN)によりenvironmentに依存せず、最適Q値を近似することが可能である。DQNは状態(今回は画像)を入力すると、actionごとのQ値を出力する。
environmentとして、かなり単純なゲームっぽいもの(Simple-seek,Simple-shooting)を実際に作成してネットワーク構造を変更せずに学習効果を確認した。

学習効果の確認においてはDQNによる予測actionとランダムなactionによる比較により、100回ゲームを行いその合計値から検証を行った。
ランダム選択の場合epsilon=1、完全なDQN依存の場合epsilon=0である。

DQN学習はepoch1000回ゲームを行い、epsilonはepsilon=1-0.9*step/epochとして大雑把にepsilonの減衰を適用した。

Simple_seek(max score 100)

4*4のマスの中でランダムに1つ当たりマスを指定し、初期位置(0,0)カーソルを移動させて当たりマスを探す。
actionはx,y座標のどちらかを+1カーソル移動の2種類のみである。
当たりマスを探し当てるとreward +1、4*4マスから外れるとreward -1である。

Simple_seek

epsilon = 1(random)
epoch:0 score:33
epoch:100 score:38
epoch:200 score:28
epoch:300 score:36
epoch:400 score:33
epoch:500 score:33
epoch:600 score:33
epoch:700 score:33
epoch:800 score:26
epoch:900 score:38
epoch:1000 score:21

epsilon = 0.05
epoch:0 score:36
epoch:100 score:63
epoch:200 score:89
epoch:300 score:94
epoch:400 score:95
epoch:500 score:94
epoch:600 score:98
epoch:700 score:97
epoch:800 score:94
epoch:900 score:97
epoch:1000 score:93

epsilon = 0
epoch:0 score:23
epoch:100 score:74
epoch:200 score:77
epoch:300 score:92
epoch:400 score:88
epoch:500 score:94
epoch:600 score:100
epoch:700 score:100
epoch:800 score:100
epoch:900 score:75
epoch:1000 score:90
Simple_shooting(max score 200)

9*5画面上でおおよそx方向に動くtargetを座標(5,9)からbulletで撃ち、当たった数をscoreとする。
actionは何もしないと、bullet発射の2種類であり、撃つタイミングがscoreに影響するものである。
targetに当たるとreward +1、外すとreward -1である。
targetとbulletは1ゲーム当たり2つに設定しているので、max scoreは200である。

Simple_shooting

epsilon = 1(random)
epoch:0 score:46
epoch:100 score:46
epoch:200 score:45
epoch:300 score:34
epoch:400 score:46
epoch:500 score:45
epoch:600 score:39
epoch:700 score:45
epoch:800 score:40
epoch:900 score:38
epoch:1000 score:48

epsilon = 0.05
epoch:0 score:70
epoch:100 score:138
epoch:200 score:182
epoch:300 score:185
epoch:400 score:188
epoch:500 score:189
epoch:600 score:186
epoch:700 score:191
epoch:800 score:185
epoch:900 score:181
epoch:1000 score:192

epsilon = 0
epoch:0 score:38
epoch:100 score:94
epoch:200 score:197
epoch:300 score:200
epoch:400 score:200
epoch:500 score:200
epoch:600 score:200
epoch:700 score:200
epoch:800 score:200
epoch:900 score:200
epoch:1000 score:200

これらの状態遷移は決定論的なのでepisilon=0でほぼmax scoreを出すように学習させることが出来た。
Sinple_seekの方は若干バラつきが大きく収束が遅かった。
報酬確定時点でゲームを打ち切ることで学習精度が向上したので報酬確定するまでの状態数の差あたりが原因だろうか。
なんにせよ、報酬確定後の余計な状態は切り捨てるほうがよさそうに感じた。


以下は検証で使ったコードである。
小さく複雑でない画像を入力するのでCNNでなく、Fully-Connectedのみで構成されたネットワークを使った。
(追記11/1 labelの計算に現在のパラメータを使い、一回前のパラメータを使っていないのでネットワークを複製してパラメータコピーなど修正必要だがそれなりに収束はしてくれている。)

env:Simple_shooting

class Simple_shooting(object):
    def __init__(self,num_actions,num_targets):
        self.actions =tuple(i for i in range(num_actions))
        self.img_height = 5
        self.img_width = 9
        self.num_targets = num_targets
        self.totalscore = 0
        
    def newgame(self):
        
        self.terminal = False
        self.reward = 0
        self.image = np.zeros(self.img_height*self.img_width)
        self.bullets = self.num_targets #equal targets
        self.pos_bullets = []
        self.pos_targets = random.sample(range((self.img_height-2)*self.img_width),self.num_targets)
                
        for i in self.pos_targets:
            self.image[i] = 1
            
        self.totalscore = 0
        
    def observe(self):
        return self.image, self.reward, self.terminal
    
    def preprocess(self,state):
        pass
    
    def excute(self,action):
        self.image = np.zeros(5*9)
        self.reward = 0        
        b_list = self.pos_bullets
        self.pos_bullets =[]
        
        for i in b_list:
            if i > self.img_width-1:
                self.pos_bullets.append(i-9)
                self.image[i-9] = -1
            else:
                self.reward -= 1
        
        if action == 1 and self.bullets > 0:
            self.image[self.img_height*self.img_width-5] = -1
            self.pos_bullets.append(self.img_height*self.img_width-5)
            self.bullets -= 1    
        
        self.pos_targets = [(i+1)%((self.img_height-2)*self.img_width) for i in self.pos_targets]
        for i in self.pos_targets:
            self.image[i] = 1
            
        hit = set(self.pos_targets) & set(self.pos_bullets)
        
        for i in hit:
            self.pos_bullets.remove(i)
            self.pos_targets.remove(i)
            self.image[i] = 0
            self.reward += 1
            self.totalscore += 1
            
        if len(self.pos_bullets) == 0 and self.bullets <= 0:
            self.terminal = True

env:Simple_seek

class Simple_seek(object):
    def __init__(self,num_actions):
        self.actions =tuple(i for i in range(num_actions))
        self.img_height = 4
        self.img_width = 4
        self.totalscore = 0
        
    def newgame(self):
        self.terminal = False
        self.reward = 0
        self.totalscore = 0
        self.image = np.zeros((self.img_height,self.img_width))
        pos = np.random.randint(1,16)
        self.t_x = pos%self.img_height
        self.t_y = pos//self.img_height
        self.image[self.t_y][self.t_x] = 1
        self.pointer = [0,0]
        self.image[0][0] = -1
    
    def observe(self):
        return self.image, self.reward, self.terminal
    
    def excute(self,action):
        self.image[self.pointer[1]][self.pointer[0]] = 0
        self.reward = 0
        
        if action == 0:
            self.pointer[1] += 1
        elif action == 1:
            self.pointer[0] += 1
                                          
        if self.pointer[0]>self.t_x or self.pointer[1]>self.t_y:
            self.reward -= 1
            self.terminal = True
        elif self.image[self.pointer[1]][self.pointer[0]] == 1:
            self.reward += 1
            self.terminal = True
            self.totalscore += 1
        else:
            self.image[self.pointer[1]][self.pointer[0]] = -1

agent

class Agent(object):
    
    def __init__(self,env):
        self.env = env        
        self.max_memory_size =1000
        self.replay_memory =[]
        self.current_pos = 0
        self.actions = env.actions
        self.learning_rate =0.001
        self.max_step = 1000
        self.discount_factor = 0.9
        self.batch_size =32
        self.y = self.nn_inference()
        
    def predict(self,state,epsilon=0.1,is_training=True):
        if not is_training:
            epsilon=0.05
            
        if np.random.rand() <= epsilon:
            action = np.random.choice(self.actions)
        else:
            action = self.actions[np.argmax(self.q_eval(state))]
        
        return action
    
    def store_transition(self,state,action,reward,n_state,terminal):
        if len(self.replay_memory) == self.max_memory_size:
            self.replay_memory[self.current_pos] = (state,action,reward,n_state,terminal)
            self.current_pos = (self.current_pos +1) % self.max_memory_size
        else:
            self.replay_memory.append((state,action,reward,n_state,terminal))

    def nn_inference(self):
        #simple_shooting
        self.x  = tf.placeholder(tf.float32, [None, 45])
        
        #simple_seek
        #self.x  = tf.placeholder(tf.float32, [None, 4,4])
        #inpt = tf.reshape(self.x, [-1,16])
        
        W_1 = tf.Variable(tf.truncated_normal([45, 128],stddev=0.05))
        b_1 = tf.Variable(tf.zeros([128]))
    
        hidden1 = tf.nn.relu(tf.matmul(self.x, W_1) + b_1)
        #hidden1 = tf.nn.relu(tf.matmul(inpt, W_1) + b_1)

        W_3 = tf.Variable(tf.truncated_normal([128, 2],stddev=0.05))
        b_3 = tf.Variable(tf.zeros([2]))
    
        output = tf.matmul(hidden1, W_3) + b_3
        
        return output
    
    def q_eval(self,state):
        return self.sess.run(self.y, feed_dict={self.x:[state]})[0]
    
    def train_update(self):
        state_batch = []
        labels = []
        r_mem_length = len(self.replay_memory)
        
        batchsize = min(r_mem_length,self.batch_size)
        r_mem_indexes = np.random.randint(0,r_mem_length,batchsize)
        
        for mem_index in r_mem_indexes:
            state,action,reward,n_state,terminal = self.replay_memory[mem_index]
            state_batch.append(state)
            
            q_vals= self.q_eval(state)
            
            if terminal:
                q_vals[action] = reward
            else:
                q_vals[action] = reward + self.discount_factor * np.max(self.q_eval(n_state))
            
            labels.append(q_vals)
            
        self.sess.run(self.train_op,feed_dict={self.x:state_batch,self.y_:labels})
        
    def train(self,epoch):
        
        self.y_ = tf.placeholder(tf.float32, [None, len(self.actions)])
        loss = tf.reduce_mean(tf.square(self.y_ - self.y))
        
        self.train_op = tf.train.RMSPropOptimizer(self.learning_rate).minimize(loss)
        #self.train_op = tf.train.GradientDescentOptimizer(self.learning_rate).minimize(loss)
        self.sess = tf.InteractiveSession()
        tf.global_variables_initializer().run()
        
        for i in range(epoch):
            self.env.newgame()
            state, reward, terminal = self.env.observe()
            step = 0
            
            while not terminal or step == self.max_step:
                  
                action = self.predict(state,epsilon=1-0.9*i/epoch)

                self.env.excute(action)                
                next_state, reward, terminal = self.env.observe()

                self.store_transition(state,action,reward,next_state,terminal)
                
                self.train_update()
                
                state = next_state
                step += 1
            
            if i%(epoch//10) == 0:
                #saver = tf.train.Saver(max_to_keep=11)
                #saver.save(self.sess,"tmp/dqn",global_step = i)
                self.test_play(100,i)   
        #saver.save(self.sess,"tmp/dqn",global_step = epoch)
        
    def test_play(self,times,epoch):
        
        score = 0
        
        for i in range(times):
            self.env.newgame()
            state, reward, terminal = self.env.observe()
            
            while not terminal:
                action = self.predict(state,is_training=False)
                self.env.excute(action)                
                state, _, terminal = self.env.observe()

            score += self.env.totalscore
        
        print("epoch:{} score:{}".format(epoch,score))
    
    
if __name__ == '__main__':

    env = Simple_shooting(2,2)
    #env = Simple_seek(2)
    agent = Agent(env)
    
    agent.train(1000)
    
    agent.test_play(100,1000)