PARL强化学习——用人工智能玩合成大西瓜!

项目地址

https://github.com/Sharpiless/PARL-DQN-daxigua

项目视频链接:

https://www.bilibili.com/video/BV1Tz4y1U7HE

https://www.bilibili.com/video/BV1Wy4y1n73E

https://www.bilibili.com/video/BV1gN411d7dr

项目背景:

最近突然一款小游戏火爆到冲上热搜了,刚试着玩了一会,完全停不下来的感觉。它的规则非常简单,适合各个年龄段的玩家,打开游戏页面按照滑落顺序合成更大的西瓜,趣味性十足。从玩法上看,《合成大西瓜》与前几年走红过的《2048》类似,通过控制水果下落速度和位置进行合成,两个小体型水果合成更大体型水果,大西瓜则是一连串合成进化的终点。但是相比于2048》游戏具有固定的位置,《合成大西瓜》使用了物理模组,使得对局状态随机性大大增加,强化学习算法的训练难度也随之增加。

 

项目介绍:

本项目基于PARL强化学习框架进行开发,算法使用DQN算法,以当前游戏状态的特诊为输入,输出下一步的动作。强化学习作为AI技术发展的重要分支,除了应用于模拟器和游戏领域,在工业领域也正取得长足的进步。强化学习的主要思想是基于机器人(agent)和环境(environment)的交互学习,其中agent通过action影响environment,environment返回reward和state,整个交互过程是一个马尔可夫决策过程。在交互学习的过程中,没有人的示范,而是让机器自主去做一个动作,让机器拥有自我学习和自我思考的能力。强化学习能够解决很多有监督学习方法无法解决的问题。

其中PARL框架是一个高性能、灵活的强化学习框架,通过组件化的模块可以轻松搭建 DQN/PPO/A3C等经典算法,同时具备扩展多simulator并行的能力。算法使用DQN算法,DQN(Deep Q-learning Network)算法是经典强化学习算法 Q-Learning和深度神经网络的结合,并采用了经验回放、目标网络等技巧, 学习出端到端的控制算法。同时我们也可以使用PARL框架提供的诸多算法,仅需一行代码即可替换。

游戏环境:

其中游戏我是用pygame进行重构,从而可以更好地与PARL框架进行交互,并且可以使用多进程的方法加速训练过程。

1. 游戏共有11种水果:

 

2. 碰撞检测:

1. def setup_collision_handler(self):

2.         def post_solve_bird_line(arbiter, space, data):

3.             if not self.lock:

4.                 self.lock = True

5.                 b1, b2 = NoneNone

6.                 i = arbiter.shapes[0].collision_type + 1

7.                 x1, y1 = arbiter.shapes[0].body.position

8.                 x2, y2 = arbiter.shapes[1].body.position

3. 奖励机制:

每合成一种水果,reward加相应的分数

水果

分数

樱桃

2

橘子

3

...

...

西瓜

10

大西瓜

100

1. if i < 11:

2.     self.last_score = self.score

3.     self.score += i

4. elif i == 11:

5.     self.last_score = self.score

6.     self.score += 100

4. 惩罚机制:

如果一次action后 1s(即新旧水果生成间隔)没有成功合成水果,则reward减去放下水果的分数

1. _, reward, _ = self.next_frame(action=action)

2. for _ in range(int(self.create_time * self.FPS)):

3.     _, nreward, _ = self.next_frame(action=None, generate=False)

4.     reward += nreward

5.     if reward == 0:

6.         reward = -i

5. 输入特征:

之前的版本(https://aistudio.baidu.com/aistudio/projectdetail/1540300)输入特征为游戏截图,采用ResNet提取特征

但是直接原图输入使得模型很难学习到有效的特征

因此新版本使用pygame接口获取当前状态

1. def get_feature(self, N_class=12Keep=15):

2.         # 特征工程

3.         c_t = self.i

4.         # 自身类别

5.         feature_t = np.zeros((1, N_class + 1), dtype=np.float)

6.         feature_t[0c_t] = 1.

7.         feature_t[00] = 0.5

8.         feature_p = np.zeros((Keep, N_class + 1), dtype=np.float)

9.         Xcs = []

10.         Ycs = []

11.         Ts = []

12.         for i, ball in enumerate(self.balls):

13.             if ball:

14.                 x = int(ball.body.position[0])

15.                 y = int(ball.body.position[1])

16.                 t = self.fruits[i].type

17.                 Xcs.append(x/self.WIDTH)

18.                 Ycs.append(y/self.HEIGHT)

19.                 Ts.append(t)

20.         sorted_id = sorted_index(Ycs)

21.         for i, id_ in enumerate(sorted_id):

22.             if i == Keep:

23.                 break

24.             feature_p[i, Ts[id_]] = 1.

25.             feature_p[i, 0] = Xcs[id_]

26.             feature_p[i, -1] = Ycs[id_]

27. 

28.         image = np.concatenate((feature_t, feature_p), axis=0)

29.         return image

注:N_class = 水果类别数 + 1

feature_t:

用于表示当前手中水果类别的ont-hot向量;

feature_p:

用于表示当前游戏状态,大小为(Keep, N_class + 1)

Keep 表示只关注当前位置最高的 Keep 个水果

N_class - 1 是某个水果类别的ont-hot向量, 0 位置为 x 坐标,-1 位置为 y 坐标(归一化)

 

项目复现:

1. 安装依赖库

其中游戏代码使用pygame重构

物理模块使用pymunk

注:paddlepaddle版本为1.8.0,parl版本为1.3.1

1. # !pip install pygame -i https://mirror.baidu.com/pypi/simple

2. # !pip install parl==1.3.1 -i https://mirror.baidu.com/pypi/simple

3. # !pip install pymunk

1. # !unzip work/code.zip -d ./

2. 设置环境变量

由于notebook无法显示pygame界面,所以我们设置如下环境变量

1. import os

2. os.putenv('SDL_VIDEODRIVER''fbcon')

3. os.environ["SDL_VIDEODRIVER"] = "dummy"

3. 构建多层神经网络

该版本使用两层全连接层

卷积神经网络版本为:https://aistudio.baidu.com/aistudio/projectdetail/1540300

1. import parl

2. from parl import layers

3. 

4. class Model(parl.Model):

5.     def __init__(self, act_dim):

6.         hid1_size = 256

7.         hid2_size = 256

8.         # 3层全连接网络

9.         self.fc1 = layers.fc(size=hid1_size, act='relu')

10.         self.fc2 = layers.fc(size=hid2_size, act='relu')

11.         self.fc3 = layers.fc(size=act_dim, act=None)

12. 

13.     def value(self, obs):

14.         # 定义网络

15.         # 输入state,输出所有action对应的Q,[Q(s,a1), Q(s,a2), Q(s,a3)...]

16.         h1 = self.fc1(obs)

17.         h2 = self.fc2(h1)

18.         Q = self.fc3(h2)

19.         return Q

4. 构建DQN算法、Agent和经验池

1. from parl.algorithms import DQN # 也可以直接从parl库中导入DQN算法

2. import pygame

3. 

4. class Agent(parl.Agent):

5.     def __init__(self,

6.                  algorithm,

7.                  obs_dim,

8.                  act_dim,

9.                  e_greed=0.1,

10.                  e_greed_decrement=0):

11.         assert isinstance(obs_dim, int)

12.         assert isinstance(act_dim, int)

13.         self.obs_dim = obs_dim

14.         self.act_dim = act_dim

15.         super(Agentself).__init__(algorithm)

16. 

17.         self.global_step = 0

18.         self.update_target_steps = 200  

19. 

20.         self.e_greed = e_greed  # 有一定概率随机选取动作,探索

21.         self.e_greed_decrement = e_greed_decrement  # 随着训练逐步收敛,探索的程度慢慢降低

22. 

23.     def build_program(self):

24.         self.pred_program = fluid.Program()

25.         self.learn_program = fluid.Program()

26. 

27.         with fluid.program_guard(self.pred_program):  # 搭建计算图用于 预测动作,定义输入输出变量

28.             obs = layers.data(

29.                 name='obs', shape=[self.obs_dim], dtype='float32')

30.             self.value = self.alg.predict(obs)

31. 

32.         with fluid.program_guard(self.learn_program):  # 搭建计算图用于 更新Q网络,定义输入输出变量

33.             obs = layers.data(

34.                 name='obs', shape=[self.obs_dim], dtype='float32')

35.             action = layers.data(name='act', shape=[1], dtype='int32')

36.             reward = layers.data(name='reward', shape=[], dtype='float32')

37.             next_obs = layers.data(

38.                 name='next_obs', shape=[self.obs_dim], dtype='float32')

39.             terminal = layers.data(name='terminal', shape=[], dtype='bool')

40.             self.cost = self.alg.learn(obs, action, reward, next_obs, terminal)

41. 

42.     def sample(self, obs):

43.         sample = np.random.rand()  # 产生0~1之间的小数

44.         if sample < self.e_greed:

45.             act = np.random.randint(self.act_dim)  # 探索:每个动作都有概率被选择

46.         else:

47.             act = self.predict(obs)  # 选择最优动作

48.         self.e_greed = max(

49.             0.01self.e_greed -self.e_greed_decrement)  # 随着训练逐步收敛,探索的程度慢慢降低

50.         return act

51. 

52.     def predict(self, obs):  # 选择最优动作

53.         obs = np.expand_dims(obs, axis=0)

54.         pred_Q = self.fluid_executor.run(

55.             self.pred_program,

56.             feed={'obs': obs.astype('float32')},

57.             fetch_list=[self.value])[0]

58.         pred_Q = np.squeeze(pred_Q, axis=0)

59.         act = np.argmax(pred_Q)  # 选择Q最大的下标,即对应的动作

60.         return act

61. 

62.     def learn(self, obs, act, reward, next_obs, terminal):

63.         # 每隔200个training steps同步一次model和target_model的参数

64.         if self.global_step %self.update_target_steps == 0:

65.             self.alg.sync_target()

66.         self.global_step += 1

67. 

68.         act = np.expand_dims(act, -1)

69.         feed = {

70.             'obs': obs.astype('float32'),

71.             'act': act.astype('int32'),

72.             'reward': reward,

73.             'next_obs': next_obs.astype('float32'),

74.             'terminal': terminal

75.         }

76.         cost = self.fluid_executor.run(

77.             self.learn_program, feed=feed, fetch_list=[self.cost])[0]  # 训练一次网络

78.         return cost

79. 

80. 

81. 

82. class ReplayMemory(object):

83.     def __init__(self, max_size):

84.         self.buffer = collections.deque(maxlen=max_size)

85. 

86.     # 增加一条经验到经验池中

87.     def append(self, exp):

88.         self.buffer.append(exp)

89. 

90.     # 从经验池中选取N条经验出来

91.     def sample(self, batch_size):

92.         mini_batch = random.sample(self.buffer, batch_size)

93.         obs_batch, action_batch, reward_batch, next_obs_batch, done_batch = [], [], [], [], []

94. 

95.         for experience in mini_batch:

96.             s, a, r, s_p, done = experience

97.             obs_batch.append(s)

98.             action_batch.append(a)

99.             reward_batch.append(r)

100.             next_obs_batch.append(s_p)

101.             done_batch.append(done)

102. 

103.         return np.array(obs_batch).astype('float32'), \

104.             np.array(action_batch).astype('float32'), np.array(reward_batch).astype('float32'),\

105.             np.array(next_obs_batch).astype('float32'), np.array(done_batch).astype('float32')

106. 

107.     def __len__(self):

108.         return len(self.buffer)

109. 

110. 

111. # 训练一个episode

112. def run_episode(env, agent, rpm, episode):

113.     total_reward = 0

114.     env.reset()

115.     action = np.random.randint(0, env.action_num - 1)

116.     obs, _, _ = env.next(action)

117.     step = 0

118.     while True:

119.         step += 1

120.         action = agent.sample(obs)  # 采样动作,所有动作都有概率被尝试到

121.         next_obs, reward, done = env.next(action)

122.         rpm.append((obs, action, reward, next_obs, done))

123. 

124.         # train model

125.         if (len(rpm) > MEMORY_WARMUP_SIZE)and (step % LEARN_FREQ == 0):

126.             (batch_obs, batch_action, batch_reward, batch_next_obs,

127.              batch_done) = rpm.sample(BATCH_SIZE)

128.             train_loss = agent.learn(batch_obs, batch_action, batch_reward,

129.                                      batch_next_obs,

130.                                      batch_done)  # s,a,r,s',done

131. 

132.         total_reward += reward

133.         obs = next_obs

134.         if done:

135.             break

136.         if not step % 20:

137.             logger.info('step:{} e_greed:{} action:{} reward:{}'.format(

138.                 step, agent.e_greed, action, reward))

139.         if not step % 500:

140.             image = pygame.surfarray.array3d(

141.                  pygame.display.get_surface()).copy()

142.             image = np.flip(image[:, :, [21,0]], 0)

143.             image = np.rot90(image, 3)

144.             img_pt = os.path.join('outputs','snapshoot_{}_{}.jpg'.format(episode, step))

145.             cv2.imwrite(img_pt, image)

146.     return total_reward

1. pygame 2.0.1 (SDL 2.0.14Python 3.7.4)

2. Hello from the pygame community. https://www.pygame.org/contribute.html

5. 创建游戏和Agent实例

1. from State2NN import AI_Board

2. 

3. env = AI_Board()  

4. action_dim = env.action_num  

5. obs_shape = 16 * 13  

6. e_greed = 0.2

7. 

8. rpm = ReplayMemory(MEMORY_SIZE)  # DQN的经验回放池

9. 

10. # 根据parl框架构建agent

11. model = Model(act_dim=action_dim)

12. algorithm = DQN(model, act_dim=action_dim, gamma=GAMMA, lr=LEARNING_RATE)

13. agent = Agent(

14.     algorithm,

15.     obs_dim=obs_shape,

16.     act_dim=action_dim,

17.     e_greed=e_greed,  # 有一定概率随机选取动作,探索

18.     e_greed_decrement=1e-6)  # 随着训练逐步收敛,探索的程度慢慢降低

6. 训练模型

1. from State2NN import AI_Board

2. import os

3. 

4. dirs = ['weights''outputs']

5. for d in dirs:

6.     if not os.path.exists(d):

7.         os.mkdir(d)

8. 

9. # 先往经验池里存一些数据,避免最开始训练的时候样本丰富度不够

10. while len(rpm) < MEMORY_WARMUP_SIZE:

11.     run_episode(env, agent, rpm, episode=0)

12. 

13. max_episode = 2000

14. 

15. # 开始训练

16. episode = 0

17. while episode < max_episode:  # 训练max_episode个回合,test部分不计算入episode数量

18.     # train part

19.     for i in range(050):

20.         total_reward = run_episode(env, agent, rpm, episode+1)

21.         episode += 1

22.         save_path ='./weights/dqn_model_episode_{}.ckpt'.format(episode)

23.         agent.save(save_path)

24.         print('-[INFO] episode:{}, model saved at {}'.format(episode, save_path))

25.         env.reset()

26. 

27. # 训练结束,保存模型

28. save_path = './final.ckpt'

29. agent.save(save_path)

 

https://mp.weixin.qq.com/s/3bvcI18bnVNXGC7cTMiPIA