A3C——一种异步强化学习方法

 ̄綄美尐妖づ 提交于 2020-10-31 06:35:22

目录

1、简介2、算法细节3、代码3.1 主结构3.2 Actor Critic 网络3.3 Worker3.4 Worker并行工作4、参考

1、简介

A3C是Google DeepMind 提出的一种解决Actor-Critic不收敛问题的算法。我们知道DQN中很重要的一点是他具有经验池,可以降低数据之间的相关性,而A3C则提出降低数据之间的相关性的另一种方法:异步

简单来说:A3C会创建多个并行的环境, 让多个拥有副结构的 agent 同时在这些并行环境上更新主结构中的参数. 并行中的 agent 们互不干扰, 而主结构的参数更新受到副结构提交更新的不连续性干扰, 所以更新的相关性被降低, 收敛性提高.

2、算法细节

A3C的算法实际上就是将Actor-Critic放在了多个线程中进行同步训练. 可以想象成几个人同时在玩一样的游戏, 而他们玩游戏的经验都会同步上传到一个中央大脑. 然后他们又从中央大脑中获取最新的玩游戏方法。

这样, 对于这几个人, 他们的好处是: 中央大脑汇集了所有人的经验, 是最会玩游戏的一个, 他们能时不时获取到中央大脑的必杀招, 用在自己的场景中.

对于中央大脑的好处是: 中央大脑最怕一个人的连续性更新, 不只基于一个人推送更新这种方式能打消这种连续性. 使中央大脑不必像DQN,DDPG那样的记忆库也能很好的更新。


为了达到这个目的,我们要有两套体系, 可以看作中央大脑拥有 global net和他的参数, 每位玩家有一个 global net的副本 local net, 可以定时向 global net推送更新, 然后定时从 global net那获取综合版的更新.

如果在 tensorboard 中查看我们今天要建立的体系, 这就是你会看到的。

W_0就是第0个 worker, 每个 worker都可以分享 global_net


如果我们调用 sync中的 pull, 这个 worker就会从 global_net中获取到最新的参数.


如果我们调用sync中的push, 这个worker就会将自己的个人更新推送去global_net.

3、代码

这次我们也是使用连续动作环境Pendulum做例子。

3.1 主结构


我们使用了 Normal distribution 来选择动作, 所以在搭建神经网络的时候, actor这边要输出动作的均值和方差. 然后放入 Normal distribution 去选择动作. 计算 actor loss的时候我们还需要使用到 critic提供的 TD error作为 gradient ascent 的导向.

critic只需要得到他对于 state的价值就好了. 用于计算 TD error.

3.2 Actor Critic 网络

这里因为代码有点多,有些部分会使用伪代码,完整代码最后会附上链接。

我们将ActorCritic合并成一整套系统, 这样方便运行.

 1# 这个 class 可以被调用生成一个 global net.
2# 也能被调用生成一个 worker 的 net, 因为他们的结构是一样的,
3# 所以这个 class 可以被重复利用.
4class ACNet(object):
5    def __init__(self, globalAC=None):
6        # 当创建 worker 网络的时候, 我们传入之前创建的 globalAC 给这个 worker
7        if 这是 global:   # 判断当下建立的网络是 local 还是 global
8            with tf.variable_scope('Global_Net'):
9                self._build_net()
10        else:
11            with tf.variable_scope('worker'):
12                self._build_net()
13
14            # 接着计算 critic loss 和 actor loss
15            # 用这两个 loss 计算要推送的 gradients
16
17            with tf.name_scope('sync'):  # 同步
18                with tf.name_scope('pull'):
19                    # 更新去 global
20                with tf.name_scope('push'):
21                    # 获取 global 参数
22
23    def _build_net(self):
24        # 在这里搭建 Actor 和 Critic 的网络
25        return 均值, 方差, state_value
26
27    def update_global(self, feed_dict):
28        # 进行 push 操作
29
30    def pull_global(self):
31        # 进行 pull 操作
32
33    def choose_action(self, s):
34        # 根据 s 选动作

这些只是在创建网络而已,worker还有属于自己的class, 用来执行在每个线程里的工作.

3.3 Worker

每个worker有自己的class, class 里面有他的工作内容work

 1class Worker(object):
2    def __init__(self, name, globalAC):
3        self.env = gym.make(GAME).unwrapped # 创建自己的环境
4        self.name = name    # 自己的名字
5        self.AC = ACNet(name, globalAC) # 自己的 local net, 并绑定上 globalAC
6
7    def work(self):
8        # s, a, r 的缓存, 用于 n_steps 更新
9        buffer_s, buffer_a, buffer_r = [], [], []
10        while not COORD.should_stop() and GLOBAL_EP < MAX_GLOBAL_EP:
11            s = self.env.reset()
12
13            for ep_t in range(MAX_EP_STEP):
14                a = self.AC.choose_action(s)
15                s_, r, done, info = self.env.step(a)
16
17                buffer_s.append(s)  # 添加各种缓存
18                buffer_a.append(a)
19                buffer_r.append(r)
20
21                # 每 UPDATE_GLOBAL_ITER 步 或者回合完了, 进行 sync 操作
22                if total_step % UPDATE_GLOBAL_ITER == 0 or done:
23                    # 获得用于计算 TD error 的 下一 state 的 value
24                    if done:
25                        v_s_ = 0   # terminal
26                    else:
27                        v_s_ = SESS.run(self.AC.v, {self.AC.s: s_[np.newaxis, :]})[00]
28
29                    buffer_v_target = []    # 下 state value 的缓存, 用于算 TD
30                    for r in buffer_r[::-1]:    # 进行 n_steps forward view
31                        v_s_ = r + GAMMA * v_s_
32                        buffer_v_target.append(v_s_)
33                    buffer_v_target.reverse()
34
35                    buffer_s, buffer_a, buffer_v_target = np.vstack(buffer_s), np.vstack(buffer_a), np.vstack(buffer_v_target)
36
37                    feed_dict = {
38                        self.AC.s: buffer_s,
39                        self.AC.a_his: buffer_a,
40                        self.AC.v_target: buffer_v_target,
41                    }
42
43                    self.AC.update_global(feed_dict)    # 推送更新去 globalAC
44                    buffer_s, buffer_a, buffer_r = [], [], []   # 清空缓存
45                    self.AC.pull_global()   # 获取 globalAC 的最新参数
46
47                s = s_
48                if done:
49                    GLOBAL_EP += 1  # 加一回合
50                    break   # 结束这回合

3.4 Worker并行工作

这里是重点,也就是Worker并行工作的计算

 1    GLOBAL_AC = ACNet(GLOBAL_NET_SCOPE)  # 建立 Global AC
2    workers = []
3    for i in range(N_WORKERS):  # 创建 worker, 之后在并行
4        workers.append(Worker(GLOBAL_AC))   # 每个 worker 都有共享这个 global AC
5
6COORD = tf.train.Coordinator()  # Tensorflow 用于并行的工具
7
8worker_threads = []
9for worker in workers:
10    job = lambda: worker.work()
11    t = threading.Thread(target=job)    # 添加一个工作线程
12    t.start()
13    worker_threads.append(t)
14COORD.join(worker_threads)  # tf 的线程调度

电脑里CPU有几个核就可以建立多少个worker, 也就可以把它们放在CPU核数个线程中并行探索更新. 最后的学习结果可以用这个获取 moving average 的 reward 的图来概括.

完整代码链接:

https://github.com/cristianoc20/RL_learning/tree/master/A3C

4、参考

  1. https://medium.com/emergent-future/simple-reinforcement-learning-with-tensorflow-part-8-asynchronous-actor-critic-agents-a3c-c88f72a5e9f2

  2. https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/6-3-A3C/


本文分享自微信公众号 - 计算机视觉漫谈()。
如有侵权,请联系 support@oschina.cn 删除。
本文参与“OSC源创计划”,欢迎正在阅读的你也加入,一起分享。

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!