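Project overview
This project reproduces the DQN model from deep reinforcement learning with PaddlePaddle and reaches scores on classic Atari games comparable to those reported in the paper. The model takes game images as input and predicts the next control signal end to end. The project must be run in a GPU environment. Original paper: Human-level Control Through Deep Reinforcement Learning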
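Problem statement:
Limitations of reinforcement learning agents in complex environments: an agent has to learn an effective representation of its input from a high-dimensional space and apply it to new environments, which is a challenging task.
Limitations of traditional reinforcement learning: it only works where effective handcrafted feature representations are available, or where the state space is observable and low-dimensional.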
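To address these problems, the paper combines reinforcement learning with a deep convolutional neural network (CNN) and proposes the Deep Q-Network (DQN), which learns directly from high-dimensional sensory input and trains the agent with end-to-end reinforcement learning. The work has two main highlights:
1. Traditional reinforcement learning can be unstable or even diverge; the paper proposes two ways to address this instability.
a) Experience replay, which randomly samples over the stored data, removes correlations between observations, and smooths over changes in the data distribution.
b) Periodic updates of the target parameters during Q-learning updates, which reduce the correlation between the action values Q and the target values.
2. A deep convolutional neural network (CNN) serves as the action-value function. Drawing on the hierarchical feature representations of CNNs and their biological inspiration, the network takes an action a observed in state s and computes the value of the current state using discount factor γ; the ultimate goal is to maximize the cumulative future reward.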
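The periodic target update in b) and the discounted value in the second highlight can be sketched as follows. This is a minimal illustration in plain Python, not the project's code: `td_targets`, `sync_target`, the value of `GAMMA`, and the toy Q-tables are all assumptions made for the example, and a lookup table stands in for the CNN.

```python
import copy

GAMMA = 0.99  # discount factor γ (value assumed for illustration)


def td_targets(batch, target_q, gamma=GAMMA):
    """Compute r + γ * max_a' Q_target(s', a') for each transition.

    batch: list of (state, action, reward, next_state, is_over) tuples.
    target_q: dict mapping state -> list of Q-values, one per action.
    """
    targets = []
    for _state, _action, reward, next_state, is_over in batch:
        if is_over:
            # No future reward after a terminal state.
            targets.append(reward)
        else:
            targets.append(reward + gamma * max(target_q[next_state]))
    return targets


def sync_target(online_q):
    """Copy the online parameters into a frozen target copy."""
    return copy.deepcopy(online_q)


# Toy example with two states and two actions (hypothetical values).
online_q = {"s0": [0.0, 1.0], "s1": [2.0, 0.5]}
target_q = sync_target(online_q)

# Later updates to the online table do not affect the target copy
# until the next sync, so the targets stay fixed in between.
online_q["s1"][0] = 10.0

batch = [("s0", 1, 1.0, "s1", False), ("s1", 0, -1.0, "s0", True)]
print(td_targets(batch, target_q))
```

Because the targets are computed from the frozen copy, a change to the online values does not immediately shift the values they are being regressed toward, which is the decorrelation effect the text describes.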
Reference blog post: a walkthrough of the paper Human-Level Control Through Deep Reinforcement Learning
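Common dependencies can be installed with pip install -r requirement.txt
gym[atari] can be installed with pip install git+https://github.com/Kojoley/atari-py.git (for convenience, this project ships the gym[atari] source code in the work directory, which can be installed directly with python setup.py install)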
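File structure
|--rom_files # ROM files for the Atari games
|--saved_model # best-performing model parameters saved during training
|--result # game visualizations produced in the test phase
|--play.py # plays the game, computes the average reward, and saves the visualization
|--expreplay.py # experience pool; each training step samples a batch of transitions from it
|--DQN_agent.py # network configuration
|--atari_wrapper.py # Atari wrappers
|--train.py # training script
|--atari.py # returns state, r, isOver, info
|--requirement.txt # dependencies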
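The experience pool listed above under expreplay.py can be pictured with a short sketch. The `ReplayBuffer` class, its method names, and the capacity below are hypothetical and are not taken from the project; the sketch only shows the general idea of a fixed-size pool with uniform random sampling.

```python
import random
from collections import deque


class ReplayBuffer:
    """Fixed-size pool of transitions; the oldest entries are dropped first."""

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def append(self, state, action, reward, is_over):
        self.buffer.append((state, action, reward, is_over))

    def sample(self, batch_size, rng=random):
        # Uniform sampling without replacement breaks up the temporal
        # correlation between consecutive observations.
        return rng.sample(list(self.buffer), batch_size)

    def __len__(self):
        return len(self.buffer)


pool = ReplayBuffer(capacity=5)
for step in range(8):
    # Integers stand in for real game frames here.
    pool.append(state=step, action=step % 2, reward=0.0, is_over=False)

minibatch = pool.sample(batch_size=3, rng=random.Random(0))
print(len(pool), minibatch)
```

With a capacity of 5 and 8 appended transitions, only the 5 most recent ones remain available for sampling.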
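Model structure
A CNN is used to compute Q; the network architecture is shown in the figure below:
As the figure above shows, the network takes 4 consecutive preprocessed 84×84 images as input, passes them through two convolutional layers and two fully connected layers, and outputs a vector containing the Q-value of each action. The network in the figure is deliberately simple and serves only as an illustration; in practice it can be redesigned to fit the problem.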
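To make the data flow concrete, the shapes through such a network can be worked out by hand. The text only fixes the input (4 stacked 84×84 frames) and the layer counts; the filter counts, kernel sizes, strides, hidden width, and number of actions below are assumptions chosen for illustration and may differ from DQN_agent.py.

```python
def conv_out(size, kernel, stride):
    """Output width of an unpadded convolution along one dimension."""
    return (size - kernel) // stride + 1


# Input: 4 stacked 84x84 frames (from the text).
channels, size = 4, 84

# Hypothetical layer settings: (filters, kernel, stride). Not from the project.
conv_layers = [(16, 8, 4), (32, 4, 2)]

for filters, kernel, stride in conv_layers:
    size = conv_out(size, kernel, stride)
    channels = filters
    print("conv output:", (channels, size, size))

flat = channels * size * size  # input width of the first fully connected layer
hidden = 256                   # hypothetical hidden width
num_actions = 6                # hypothetical number of actions
print("fc shapes:", (flat, hidden), (hidden, num_actions))
```

The last fully connected layer has one output per action, which matches the description of the output as a vector of Q-values.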
# Install the libraries required to run the program
!pip install -r requirement.txt
!cd work && tar zxf atari-py-1.2.1.tar.gz
!cd work/atari-py-1.2.1/ && python setup.py install >/dev/null 2>&1
DEPRECATION: Python 2.7 will reach the end of its life on January 1st, 2020. Please upgrade your Python as Python 2.7 won’t be maintained after that date. A future version of pip will drop support for Python 2.7.
Looking in indexes: https://pypi.mirrors.ustc.edu.cn/simple/
Requirement already satisfied: numpy in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from -r requirement.txt (line 1)) (1.16.2)
Requirement already satisfied: gym in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from -r requirement.txt (line 2)) (0.12.1)
Requirement already satisfied: tqdm in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from -r requirement.txt (line 3)) (4.36.1)
Requirement already satisfied: opencv-python in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from -r requirement.txt (line 4)) (4.0.1.23)
Requirement already satisfied: paddlepaddle-gpu>=1.0.0 in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from -r requirement.txt (line 5)) (1.5.1.post97)
Requirement already satisfied: requests>=2.0 in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from gym->-r requirement.txt (line 2)) (2.22.0)
Requirement already satisfied: six in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from gym->-r requirement.txt (line 2)) (1.12.0)
Requirement already satisfied: scipy in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from gym->-r requirement.txt (line 2)) (1.2.1)
Requirement already satisfied: pyglet>=1.2.0 in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from gym->-r requirement.txt (line 2)) (1.3.2)
Requirement already satisfied: pyyaml in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from paddlepaddle-gpu>=1.0.0->-r requirement.txt (line 5)) (5.1)
Requirement already satisfied: funcsigs in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from paddlepaddle-gpu>=1.0.0->-r requirement.txt (line 5)) (1.0.2)
Requirement already satisfied: protobuf>=3.1.0 in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from paddlepaddle-gpu>=1.0.0->-r requirement.txt (line 5)) (3.7.1)
Requirement already satisfied: graphviz in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from paddlepaddle-gpu>=1.0.0->-r requirement.txt (line 5)) (0.10.1)
Requirement already satisfied: recordio>=0.1.0 in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from paddlepaddle-gpu>=1.0.0->-r requirement.txt (line 5)) (0.1.5)
Requirement already satisfied: nltk<=3.4,>=3.2.2; python_version < “3.5” in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from paddlepaddle-gpu>=1.0.0->-r requirement.txt (line 5)) (3.4)
Requirement already satisfied: matplotlib<=2.2.4; python_version < “3.6” in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from paddlepaddle-gpu>=1.0.0->-r requirement.txt (line 5)) (2.2.3)
Requirement already satisfied: Pillow in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from paddlepaddle-gpu>=1.0.0->-r requirement.txt (line 5)) (6.0.0)
Requirement already satisfied: decorator in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from paddlepaddle-gpu>=1.0.0->-r requirement.txt (line 5)) (4.4.0)
Requirement already satisfied: rarfile in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from paddlepaddle-gpu>=1.0.0->-r requirement.txt (line 5)) (3.0)
Requirement already satisfied: prettytable in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from paddlepaddle-gpu>=1.0.0->-r requirement.txt (line 5)) (0.7.2)
Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from requests>=2.0->gym->-r requirement.txt (line 2)) (3.0.4)
Requirement already satisfied: idna<2.9,>=2.5 in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from requests>=2.0->gym->-r requirement.txt (line 2)) (2.8)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from requests>=2.0->gym->-r requirement.txt (line 2)) (1.25.3)
Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from requests>=2.0->gym->-r requirement.txt (line 2)) (2019.3.9)
Requirement already satisfied: future in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from pyglet>=1.2.0->gym->-r requirement.txt (line 2)) (0.17.1)
Requirement already satisfied: setuptools in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from protobuf>=3.1.0->paddlepaddle-gpu>=1.0.0->-r requirement.txt (line 5)) (41.0.0)
Requirement already satisfied: singledispatch in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from nltk<=3.4,>=3.2.2; python_version < “3.5”->paddlepaddle-gpu>=1.0.0->-r requirement.txt (line 5)) (3.4.0.3)
Requirement already satisfied: python-dateutil>=2.1 in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from matplotlib<=2.2.4; python_version < “3.6”->paddlepaddle-gpu>=1.0.0->-r requirement.txt (line 5)) (2.8.0)
Requirement already satisfied: subprocess32 in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from matplotlib<=2.2.4; python_version < “3.6”->paddlepaddle-gpu>=1.0.0->-r requirement.txt (line 5)) (3.5.3)
Requirement already satisfied: cycler>=0.10 in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from matplotlib<=2.2.4; python_version < “3.6”->paddlepaddle-gpu>=1.0.0->-r requirement.txt (line 5)) (0.10.0)
Requirement already satisfied: backports.functools-lru-cache in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from matplotlib<=2.2.4; python_version < “3.6”->paddlepaddle-gpu>=1.0.0->-r requirement.txt (line 5)) (1.5)
Requirement already satisfied: pytz in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from matplotlib<=2.2.4; python_version < “3.6”->paddlepaddle-gpu>=1.0.0->-r requirement.txt (line 5)) (2018.9)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from matplotlib<=2.2.4; python_version < “3.6”->paddlepaddle-gpu>=1.0.0->-r requirement.txt (line 5)) (2.4.0)
Requirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from matplotlib<=2.2.4; python_version < “3.6”->paddlepaddle-gpu>=1.0.0->-r requirement.txt (line 5)) (1.0.1)
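Training the model
Train the model with 『python train.py --rom ./rom_files/pong.bin --use_cuda --alg DQN』, as shown below. To train on more games, ROM files can be downloaded here: game ROM link
Training takes a long time and needs many iterations; if the total number of iterations is set too low, the model may perform poorly. With enough training, the average reward on Pong can reach 21.
The 『--test_every_steps』 parameter sets how many steps pass between tests; the default is 100000.
The test phase plays 30 games by default and prints their average reward.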
!python train.py --rom ./rom_files/pong.bin --use_cuda --alg DQN --total_step 10000 --test_every_steps 5000
W1202 13:38:37.191267 573 device_context.cc:259] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 9.2, Runtime API Version: 9.0
W1202 13:38:37.201860 573 device_context.cc:267] device: 0, cuDNN Version: 7.3.
Memory warmup: 50074it [00:56, 892.81it/s]
[train]exploration:1.048905: 10%|█ | 1021/10000 [00:08<01:12, 124.54it/s]testing
[train]exploration:1.048905: 10%|█ | 1021/10000 [00:20<01:12, 124.54it/s]eval_agent done, (steps, eval_reward): (1022, -21.0)
[train]exploration:1.044461: 55%|█████▍ | 5465/10000 [02:33<01:15, 60.20it/s]testing
[train]exploration:1.044461: 55%|█████▍ | 5465/10000 [02:50<01:15, 60.20it/s]eval_agent done, (steps, eval_reward): (5466, -21.0)
[train]exploration:1.039241: : 10685it [05:08, 34.62it/s]
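The test phase described above (a fixed number of games whose rewards are averaged) can be sketched as below. `evaluate` and `fake_pong_episode` are hypothetical names, and the random episode is only a stand-in for a real game played by the agent; this is not how train.py is implemented.

```python
import random


def evaluate(play_episode, n_episodes=30):
    """Play n_episodes games and return the mean total reward."""
    rewards = [play_episode() for _ in range(n_episodes)]
    return sum(rewards) / len(rewards)


_rng = random.Random(0)


def fake_pong_episode():
    # Stand-in for a real game: a Pong score lies between -21 and 21.
    return float(_rng.randint(-21, 21))


print("eval_reward:", evaluate(fake_pong_episode))
```

An agent that loses every point in every game would average -21.0, which is the value the short training run above reports.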
Testing the model:
Play the game with the best model saved during training and compute the average reward
python play.py --rom ./rom_files/pong.bin --use_cuda --model_path saved_model/DQN-pong --viz 0.01
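This command plays the game with visualization. Because the AI Studio environment does not support live video display, the gameplay is saved frame by frame to the result directory and a corresponding .avi video is generated. To watch the game, download the video and play it locally. When the game ends, the reward obtained in that game is printed on the command line.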
!python play.py --rom ./rom_files/pong.bin --use_cuda --model_path saved_model/DQN-pong --viz 0.01
W1202 13:53:07.896930 612 device_context.cc:259] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 9.2, Runtime API Version: 9.0
W1202 13:53:07.901751 612 device_context.cc:267] device: 0, cuDNN Version: 7.3.
eval agent: 100%|██████████| 1/1 [00:05<00:00, 5.06s/it]
Average reward of epidose: -21.0
Source: CSDN
Author: PaddlePaddle开发者
Link: https://blog.csdn.net/PaddleLover/article/details/103463789