diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 9fb1d012..64de484f 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,51 +1,45 @@
-## 如何参与OpenRL的建设
+## How to Contribute to OpenRL
-OpenRL社区欢迎任何人参与到OpenRL的建设中来,无论您是开发者还是用户,您的反馈和贡献都是我们前进的动力!
-您可以通过以下方式加入到OpenRL的贡献中来:
+The OpenRL community welcomes anyone to contribute to the development of OpenRL, whether you are a developer or a user. Your feedback and contributions are what drive us forward! You can contribute to OpenRL in the following ways:
-- 作为OpenRL的用户,发现OpenRL中的bug,并提交[issue](https://github.com/OpenRL-Lab/openrl/issues/new/choose)。
-- 作为OpenRL的用户,发现OpenRL文档中的错误,并提交[issue](https://github.com/OpenRL-Lab/openrl/issues/new/choose)。
-- 写测试代码,提升OpenRL的代码测试覆盖率(大家可以从[这里](https://app.codecov.io/gh/OpenRL-Lab/openrl)查到OpenRL的代码测试覆盖情况)。
- 您可以选择感兴趣的代码片段进行编写代码测试,
-- 作为OpenRL的开发者,为OpenRL修复已有的bug。
-- 作为OpenRL的开发者,为OpenRL添加新的环境和样例。
-- 作为OpenRL的开发者,为OpenRL添加新的算法。
+- As an OpenRL user, report bugs you discover in OpenRL by submitting an [issue](https://github.com/OpenRL-Lab/openrl/issues/new/choose).
+- As an OpenRL user, report errors you find in the OpenRL documentation by submitting an [issue](https://github.com/OpenRL-Lab/openrl/issues/new/choose).
+- Write test code to improve OpenRL's code coverage (the current coverage report is available [here](https://app.codecov.io/gh/OpenRL-Lab/openrl)). Pick any code snippet you are interested in and write tests for it.
+- As an OpenRL developer, fix existing bugs in OpenRL.
+- As an OpenRL developer, add new environments and examples to OpenRL.
+- As an OpenRL developer, add new algorithms to OpenRL.
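For the testing item above, a test is just a pytest-style function containing assertions. Here is a minimal sketch; the helper `flatten_obs` is hypothetical and invented purely for illustration — in practice you would import and exercise the actual OpenRL code you want to cover, and run everything through `make test`.

```python
# Hypothetical example of a pytest-style unit test. The function under test
# (flatten_obs) is invented for illustration; real tests should import and
# cover actual OpenRL code.

def flatten_obs(obs: dict) -> list:
    """Flatten a dict observation into a list of values, sorted by key."""
    return [obs[k] for k in sorted(obs)]

def test_flatten_obs():
    # pytest discovers functions named test_* and reports failed assertions.
    assert flatten_obs({"b": 2, "a": 1}) == [1, 2]
    assert flatten_obs({}) == []
```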
-## 贡献者手册
+## Contributor's Guide
-欢迎更多的人参与到OpenRL的开发中来,我们非常欢迎您的贡献!
+We welcome more people to join the development of OpenRL, and we greatly appreciate your contributions!
-- 如果您想要贡献新的功能,请先在请先创建一个新的[issue](https://github.com/OpenRL-Lab/openrl/issues/new/choose),
- 以便我们讨论这个功能的实现细节。如果该功能得到了大家的认可,您可以开始进行代码实现。
-- 您也可以在 [Issues](https://github.com/OpenRL-Lab/openrl/issues) 中查看未被实现的功能和仍然存的在bug,
-在对应的issue中进行回复,说明您想要解决该issue,然后开始进行代码实现。
+- If you want to contribute a new feature, please first create a new [issue](https://github.com/OpenRL-Lab/openrl/issues/new/choose)
+so that we can discuss the implementation details. Once the feature is approved, you can start implementing it.
+- You can also look through [Issues](https://github.com/OpenRL-Lab/openrl/issues) for unimplemented features and unresolved bugs,
+reply in the corresponding issue to say that you would like to work on it, and then start implementing.
-在您完成了代码实现之后,您需要拉取最新的`main`分支并进行合并。
-解决合并冲突后,
-您可以通过提交 [Pull Request](https://github.com/OpenRL-Lab/openrl/pulls)
-的方式将您的代码合并到OpenRL的main分支中。
+After completing your implementation, pull the latest `main` branch and merge it into your branch.
+After resolving any merge conflicts,
+you can merge your code into OpenRL's `main` branch by submitting a [Pull Request](https://github.com/OpenRL-Lab/openrl/pulls).
-在提交Pull Request前,您需要完成 [代码测试和代码格式化](#代码测试和代码格式化)。
+Before submitting a Pull Request, you need to complete [Code Testing and Code Formatting](#code-testing-and-code-formatting).
-然后,您的Pull Request需要通过GitHub上的自动化测试。
+Then, your Pull Request needs to pass automated testing on GitHub.
-最后,需要得到至少一个开发人员的review和批准,才能被合并到main分支中。
+Finally, your code must be reviewed and approved by at least one maintainer before it can be merged into the `main` branch.
-## 代码测试和代码格式化
+## Code Testing and Code Formatting
-在您提交Pull Request之前,您需要确保您的代码通过了单元测试,并且符合OpenRL的代码风格。
+Before submitting a Pull Request, make sure that your code passes the unit tests and conforms to OpenRL's code style.
-首先,您需要安装测试相关的包:`pip install -e ".[test]"`
+First, install the test-related packages: `pip install -e ".[test]"`
-然后,您需要确保单元测试通过,这可以通过执行`make test`来完成。
-
-然后,您需要执行`make format`来格式化您的代码。
-
-最后,您需要执行`make commit-checks`来检查您的代码是否符合OpenRL的代码风格。
-
-> 小技巧: OpenRL使用 [black](https://github.com/psf/black) 代码风格。
-您可以在您的编辑器中安装black的[插件](https://black.readthedocs.io/en/stable/integrations/editors.html),
-来帮助您自动格式化代码。
+Then, ensure that the unit tests pass by executing `make test`.
+
+Next, format your code by running `make format`.
+
+Lastly, run `make commit-checks` to verify that your code complies with OpenRL's code style.
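Taken together, the pre-PR checks above amount to the following command sequence, run from the repository root (a sketch of the workflow described in this section, not an exhaustive CI configuration):

```
pip install -e ".[test]"   # install test dependencies (once)
make test                  # run the unit tests
make format                # auto-format the code
make commit-checks         # verify code style compliance
```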
+> Tip: OpenRL uses the [black](https://github.com/psf/black) code style.
+You can install a black [editor plugin](https://black.readthedocs.io/en/stable/integrations/editors.html)
+to format your code automatically.
\ No newline at end of file
diff --git a/Gallery.md b/Gallery.md
index 37a653a7..bf8d22ef 100644
--- a/Gallery.md
+++ b/Gallery.md
@@ -6,6 +6,7 @@



+
 (Discrete Action Space)
@@ -15,6 +16,17 @@
 (Imitation Learning or Supervised Learning)
+## Algorithm List
+
+
+
+| Algorithm | Tags | Refs |
+|:-----------------------------------------:|:-----------------------------------------------------------------------------------------------------------------------:|:-------------------------------:|
+| [PPO](https://arxiv.org/abs/1707.06347) |  | [code](./examples/cartpole/) |
+| [MAPPO](https://arxiv.org/abs/2103.01955) |  | [code](./examples/mpe/) |
+| [JRPO](https://arxiv.org/abs/2302.07515) |  | [code](./examples/mpe/) |
+| [MAT](https://arxiv.org/abs/2205.14953) |  | [code](./examples/mpe/) |
+
## Demo List
diff --git a/README.md b/README.md
index cf4e453a..847470ab 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
-

+
---
@@ -12,7 +12,7 @@
[](https://hitsofcode.com/github/OpenRL-Lab/openrl/view?branch=main)
-[](https://codecov.io/gh/OpenRL-Lab/openrl)
+[](https://codecov.io/gh/OpenRL-Lab/openrl_release)
[](https://openrl-docs.readthedocs.io/zh/latest/?badge=latest)
[](https://openrl-docs.readthedocs.io/zh/latest/)
@@ -33,201 +33,235 @@ OpenRL-v0.0.11 is updated on May 19, 2023
The main branch is the latest version of OpenRL, which is under active development. If you just want to have a try with OpenRL, you can switch to the stable branch.
-## 欢迎来到OpenRL
-
-[English](./README_en.md) | [中文文档](https://openrl-docs.readthedocs.io/zh/latest/) | [Documentation](https://openrl-docs.readthedocs.io/en/latest/)
-
-OpenRL是一个开源的通用强化学习研究框架,支持单智能体、多智能体、自然语言等多种任务的训练。 OpenRL基于PyTorch进行开发,目标是为强化学习研究社区提供一个简单易用、灵活高效、可持续扩展的平台。
-目前,OpenRL支持的特性包括:
-
-- 简单易用且支持单智能体、多智能体训练的通用接口
-- 支持自然语言任务(如对话任务)的强化学习训练
-- 支持从[Hugging Face](https://huggingface.co/)上导入模型和数据
-- 支持LSTM,GRU,Transformer等模型
-- 支持多种训练加速,例如:自动混合精度训练,半精度策略网络收集数据等
-- 支持用户自定义训练模型、奖励模型、训练数据以及环境
-- 支持[gymnasium](https://gymnasium.farama.org/)环境
-- 支持字典观测空间
-- 支持[wandb](https://wandb.ai/),[tensorboardX](https://tensorboardx.readthedocs.io/en/latest/index.html)等主流训练可视化工具
-- 支持环境的串行和并行训练,同时保证两种模式下的训练效果一致
-- 中英文文档
-- 提供单元测试和代码覆盖测试
-- 符合Black Code Style和类型检查
-
-该框架经过了[OpenRL-Lab](https://github.com/OpenRL-Lab)的多次迭代并应用于学术研究,目前已经成为了一个成熟的强化学习框架。
-OpenRL-Lab将持续维护和更新OpenRL,欢迎大家加入我们的[开源社区](./CONTRIBUTING.md),一起为强化学习的发展做出贡献。
-关于OpenRL的更多信息,请参考[文档](https://openrl-docs.readthedocs.io/zh/latest/)。
-
-## 目录
-
-- [欢迎来到OpenRL](#欢迎来到openrl)
-- [目录](#目录)
-- [安装](#安装)
-- [使用Docker](#使用docker)
-- [快速上手](#快速上手)
-- [Gallery](#Gallery)
-- [使用OpenRL的项目](#使用OpenRL的项目)
-- [反馈和贡献](#反馈和贡献)
-- [维护人员](#维护人员)
-- [支持者](#支持者)
- - [↳ Contributors](#-contributors)
+## Welcome to OpenRL
+
+[中文介绍](README_zh.md) | [Documentation](https://openrl-docs.readthedocs.io/en/latest/) | [中文文档](https://openrl-docs.readthedocs.io/zh/latest/)
+
+OpenRL is an open-source general reinforcement learning research framework that supports training for various tasks
+such as single-agent, multi-agent, and natural language.
+Developed based on PyTorch, the goal of OpenRL is to provide a simple-to-use, flexible, efficient and sustainable platform for the reinforcement learning research community.
+
+Currently, the features supported by OpenRL include:
+
+- A simple-to-use universal interface that supports training for both single-agent and multi-agent
+
+- Reinforcement learning training support for natural language tasks (such as dialogue)
+
+- Importing models and datasets from [Hugging Face](https://huggingface.co/)
+
+- Support for models such as LSTM, GRU, Transformer etc.
+
+- Multiple training acceleration methods, including automatic mixed-precision training and data collection with a half-precision policy network
+
+- User-defined training models, reward models, training data and environment support
+
+- Support for [gymnasium](https://gymnasium.farama.org/) environments
+
+- Dictionary observation space support
+
+- Support for popular visualization tools such as [wandb](https://wandb.ai/) and [tensorboardX](https://tensorboardx.readthedocs.io/en/latest/index.html)
+
+- Serial or parallel environment training while ensuring consistent results in both modes
+
+- Chinese and English documentation
+
+- Provides unit testing and code coverage testing
+
+- Compliant with Black Code Style guidelines and type checking
+
+Algorithms currently supported by OpenRL (for more details, please refer to [Gallery](./Gallery.md)):
+- [Proximal Policy Optimization (PPO)](https://arxiv.org/abs/1707.06347)
+- [Multi-agent PPO (MAPPO)](https://arxiv.org/abs/2103.01955)
+- [Joint-ratio Policy Optimization (JRPO)](https://arxiv.org/abs/2302.07515)
+- [Multi-Agent Transformer (MAT)](https://arxiv.org/abs/2205.14953)
+
+Environments currently supported by OpenRL (for more details, please refer to [Gallery](./Gallery.md)):
+- [Gymnasium](https://gymnasium.farama.org/)
+- [MPE](https://github.com/openai/multiagent-particle-envs)
+- [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros)
+- [Gym Retro](https://github.com/openai/retro)
+
+This framework has undergone multiple iterations by the [OpenRL-Lab](https://github.com/OpenRL-Lab) team, which has applied it in academic research.
+It has now become a mature reinforcement learning framework.
+
+OpenRL-Lab will continue to maintain and update OpenRL, and we welcome everyone to join our [open-source community](./CONTRIBUTING.md)
+to contribute towards the development of reinforcement learning.
+
+For more information about OpenRL, please refer to the [documentation](https://openrl-docs.readthedocs.io/en/latest/).
+
+## Outline
+
+- [Welcome to OpenRL](#welcome-to-openrl)
+- [Outline](#outline)
+- [Installation](#installation)
+- [Use Docker](#use-docker)
+- [Quick Start](#quick-start)
+- [Gallery](#gallery)
+- [Projects Using OpenRL](#projects-using-openrl)
+- [Feedback and Contribution](#feedback-and-contribution)
+- [Maintainers](#maintainers)
+- [Supporters](#supporters)
+ - [↳ Contributors](#-contributors)
- [↳ Stargazers](#-stargazers)
- [↳ Forkers](#-forkers)
- [Citing OpenRL](#citing-openrl)
- [License](#license)
- [Acknowledgments](#acknowledgments)
-## 安装
+## Installation
+
+Users can directly install OpenRL via pip:
-用户可以直接通过pip安装OpenRL:
```bash
pip install openrl
```
-如果用户使用了Anaconda或者Miniconda,也可以通过conda安装OpenRL:
+If users are using Anaconda or Miniconda, they can also install OpenRL via conda:
+
```bash
conda install -c openrl openrl
```
-想要修改源码的用户也可以从源码安装OpenRL:
+Users who want to modify the source code can also install OpenRL from the source code:
+
```bash
git clone https://github.com/OpenRL-Lab/openrl.git && cd openrl
pip install -e .
```
-安装完成后,用户可以直接通过命令行查看OpenRL的版本:
+After installation, users can check the version of OpenRL through command line:
+
```bash
openrl --version
```
-**Tips**:无需安装,通过Colab在线试用OpenRL: [](https://colab.research.google.com/drive/15VBA-B7AJF8dBazzRcWAxJxZI7Pl9m-g?usp=sharing)
-
-## 使用Docker
+**Tips**: No installation required, try OpenRL online through Colab: [](https://colab.research.google.com/drive/15VBA-B7AJF8dBazzRcWAxJxZI7Pl9m-g?usp=sharing)
-OpenRL目前也提供了包含显卡支持和非显卡支持的Docker镜像。
-如果用户的电脑上没有英伟达显卡,则可以通过以下命令获取不包含显卡插件的镜像:
+## Use Docker
+OpenRL currently provides Docker images with and without GPU support.
+If the user's computer does not have an NVIDIA GPU, they can obtain an image without the GPU plugin using the following command:
```bash
sudo docker pull openrllab/openrl-cpu
```
-如果用户想要通过显卡加速训练,则可以通过以下命令获取:
+If the user wants to accelerate training with a GPU, they can obtain it using the following command:
```bash
sudo docker pull openrllab/openrl
```
-镜像拉取成功后,用户可以通过以下命令运行OpenRL的Docker镜像:
+After successfully pulling the image, users can run OpenRL's Docker image using the following commands:
```bash
-# 不带显卡加速
+# Without GPU acceleration
sudo docker run -it openrllab/openrl-cpu
-# 带显卡加速
+# With GPU acceleration
sudo docker run -it --gpus all --net host openrllab/openrl
```
-进入Docker镜像后,用户可以通过以下命令查看OpenRL的版本然后运行测例:
-```bash
-# 查看Docker镜像中OpenRL的版本
-openrl --version
-# 运行测例
-openrl --mode train --env CartPole-v1
+Once inside the Docker container, users can check OpenRL's version and then run test cases using these commands:
+```bash
+# Check OpenRL version in Docker container
+openrl --version
+# Run test case
+openrl --mode train --env CartPole-v1
```
+## Quick Start
-## 快速上手
-
-OpenRL为强化学习入门用户提供了简单易用的接口,
-下面是一个使用PPO算法训练`CartPole`环境的例子:
+OpenRL provides a simple and easy-to-use interface for beginners in reinforcement learning.
+Below is an example of using the PPO algorithm to train the `CartPole` environment:
```python
# train_ppo.py
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent
-env = make("CartPole-v1", env_num=9) # 创建环境,并设置环境并行数为9
-net = Net(env) # 创建神经网络
-agent = Agent(net) # 初始化智能体
-agent.train(total_time_steps=20000) # 开始训练,并设置环境运行总步数为20000
+env = make("CartPole-v1", env_num=9) # Create an environment and set the environment parallelism to 9.
+net = Net(env) # Create neural network.
+agent = Agent(net) # Initialize the agent.
+agent.train(total_time_steps=20000) # Start training and set the total number of environment steps to 20,000.
```
-使用OpenRL训练智能体只需要简单的四步:**创建环境**=>**初始化模型**=>**初始化智能体**=>**开始训练**!
+Training an agent using OpenRL only requires four simple steps:
+**Create Environment** => **Initialize Model** => **Initialize Agent** => **Start Training**!
-对于训练好的智能体,用户也可以方便地进行智能体的测试:
+After training, users can also easily test the agent:
```python
# train_ppo.py
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent
-agent = Agent(Net(make("CartPole-v1", env_num=9))) # 初始化训练器
+agent = Agent(Net(make("CartPole-v1", env_num=9))) # Initialize the agent.
agent.train(total_time_steps=20000)
-# 创建用于测试的环境,并设置环境并行数为9,设置渲染模式为group_human
+# Create an environment for testing, set the parallelism of the environment to 9, and set the rendering mode to group_human.
env = make("CartPole-v1", env_num=9, render_mode="group_human")
-agent.set_env(env) # 训练好的智能体设置需要交互的环境
-obs, info = env.reset() # 环境进行初始化,得到初始的观测值和环境信息
+agent.set_env(env) # Set the environment that the trained agent will interact with.
+obs, info = env.reset() # Initialize the environment to obtain initial observations and environmental information.
while True:
- action, _ = agent.act(obs) # 智能体根据环境观测输入预测下一个动作
- # 环境根据动作执行一步,得到下一个观测值、奖励、是否结束、环境信息
+ action, _ = agent.act(obs) # The agent predicts the next action based on environmental observations.
+    # The environment takes one step according to the action and returns the next observation, reward, done flag, and info.
obs, r, done, info = env.step(action)
if any(done): break
-env.close() # 关闭测试环境
+env.close() # Close test environment
```
-在普通笔记本电脑上执行以上代码,只需要几秒钟,便可以完成该智能体的训练和可视化测试:
+Executing the above code on a regular laptop takes only a few seconds
+to complete the training. The visualization of the trained agent is shown below:
-

+
-**Tips:** 用户还可以在终端中通过执行一行命令快速训练`CartPole`环境:
+**Tips:** Users can also quickly train the `CartPole` environment by executing a single command in the terminal:
```bash
openrl --mode train --env CartPole-v1
```
-对于多智能体、自然语言等任务的训练,OpenRL也提供了同样简单易用的接口。
+For training tasks such as multi-agent and natural language processing, OpenRL also provides a similarly simple and easy-to-use interface.
-关于如何进行多智能体训练、训练超参数设置、训练配置文件加载、wandb使用、保存gif动画等信息,请参考:
-- [多智能体训练例子](https://openrl-docs.readthedocs.io/zh/latest/quick_start/multi_agent_RL.html)
+For information on how to perform multi-agent training, set hyperparameters for training, load training configurations, use wandb, save GIF animations, etc., please refer to:
+- [Multi-Agent Training Example](https://openrl-docs.readthedocs.io/en/latest/quick_start/multi_agent_RL.html)
-关于自然语言任务训练、Hugging Face上模型(数据)加载、自定义训练模型(奖励模型)等信息,请参考:
-- [对话任务训练例子](https://openrl-docs.readthedocs.io/zh/latest/quick_start/train_nlp.html)
+For information on natural language task training, loading models/datasets on Hugging Face, customizing training models/reward models, etc., please refer to:
+- [Dialogue Task Training Example](https://openrl-docs.readthedocs.io/en/latest/quick_start/train_nlp.html)
-关于OpenRL的更多信息,请参考[文档](https://openrl-docs.readthedocs.io/zh/latest/)。
+For more information about OpenRL, please refer to the [documentation](https://openrl-docs.readthedocs.io/en/latest/).
## Gallery
-为了方便用户熟悉该框架,
-我们在[Gallery](./Gallery.md)中提供了更多使用OpenRL的示例和demo。
-也欢迎用户将自己的训练示例和demo贡献到Gallery中。
+To help users become familiar with the framework, we provide more examples and demos of using OpenRL in the [Gallery](./Gallery.md).
+Users are also welcome to contribute their own training examples and demos to the Gallery.
-## 使用OpenRL的研究项目
+## Projects Using OpenRL
-我们在 [OpenRL Project](./Project.md) 中列举了使用OpenRL的研究项目。
-如果你在研究项目中使用了OpenRL,也欢迎加入该列表。
-
-## 反馈和贡献
-- 有问题和发现bugs可以到 [Issues](https://github.com/OpenRL-Lab/openrl/issues) 处进行查询或提问
-- 加入QQ群:[OpenRL官方交流群](./docs/images/qq.png)
+We have listed research projects that use OpenRL in the [OpenRL Project](./Project.md).
+If you are using OpenRL in your research project, you are also welcome to join this list.
+## Feedback and Contribution
+- If you have any questions or find bugs, you can check or ask in the [Issues](https://github.com/OpenRL-Lab/openrl/issues).
+- Join the QQ group: [OpenRL Official Communication Group](docs/images/qq.png)
-

+
-- 加入 [slack](https://join.slack.com/t/openrlhq/shared_invite/zt-1tqwpvthd-Eeh0IxQ~DIaGqYXoW2IUQg) 群组,与我们一起讨论OpenRL的使用和开发。
-- 加入 [Discord](https://discord.gg/tyy96TGbep) 群组,与我们一起讨论OpenRL的使用和开发。
-- 发送邮件到: [huangshiyu@4paradigm.com](huangshiyu@4paradigm.com)
-- 加入 [GitHub Discussion](https://github.com/orgs/OpenRL-Lab/discussions)
-
-OpenRL框架目前还在持续开发和文档建设,欢迎加入我们让该项目变得更好:
+- Join the [slack](https://join.slack.com/t/openrlhq/shared_invite/zt-1tqwpvthd-Eeh0IxQ~DIaGqYXoW2IUQg) group to discuss OpenRL usage and development with us.
+- Join the [Discord](https://discord.gg/tyy96TGbep) group to discuss OpenRL usage and development with us.
+- Send an email to [huangshiyu@4paradigm.com](mailto:huangshiyu@4paradigm.com).
+- Join the [GitHub Discussion](https://github.com/orgs/OpenRL-Lab/discussions).
-- 如何贡献代码:阅读 [贡献者手册](./CONTRIBUTING.md)
-- [OpenRL开发计划](https://github.com/OpenRL-Lab/openrl/issues/2)
+The OpenRL framework is still under active development, and its documentation is continuously improving.
+We welcome you to join us in making this project better:
+- How to contribute code: Read the [Contributors' Guide](./CONTRIBUTING.md)
+- [OpenRL Roadmap](https://github.com/OpenRL-Lab/openrl/issues/2)
-## 维护人员
+## Maintainers
-目前,OpenRL由以下维护人员维护:
+OpenRL is currently maintained by the following people:
- [Shiyu Huang](https://huangshiyu13.github.io/)([@huangshiyu13](https://github.com/huangshiyu13))
- Wenze Chen([@Chen001117](https://github.com/Chen001117))
-欢迎更多的贡献者加入我们的维护团队 (发送邮件到[huangshiyu@4paradigm.com](huangshiyu@4paradigm.com)申请加入OpenRL团队)。
+We welcome more contributors to join our maintenance team (send an email to [huangshiyu@4paradigm.com](mailto:huangshiyu@4paradigm.com)
+to apply to join the OpenRL team).
-## 支持者
+## Supporters
### ↳ Contributors
@@ -245,7 +279,7 @@ OpenRL框架目前还在持续开发和文档建设,欢迎加入我们让该
## Citing OpenRL
-如果我们的工作对你有帮助,欢迎引用我们:
+If our work has been helpful to you, please feel free to cite us:
```latex
@misc{openrl2023,
title={OpenRL},
diff --git a/README_en.md b/README_en.md
deleted file mode 100644
index 4ce8a1b7..00000000
--- a/README_en.md
+++ /dev/null
@@ -1,297 +0,0 @@
-
-

-
-
----
-[](https://pypi.org/project/openrl/)
-
-[](https://anaconda.org/openrl/openrl)
-[](https://anaconda.org/openrl/openrl)
-[](https://anaconda.org/openrl/openrl)
-[](https://github.com/psf/black)
-
-
-[](https://hitsofcode.com/github/OpenRL-Lab/openrl/view?branch=main)
-[](https://codecov.io/gh/OpenRL-Lab/openrl_release)
-
-[](https://openrl-docs.readthedocs.io/zh/latest/?badge=latest)
-[](https://openrl-docs.readthedocs.io/zh/latest/)
-
-
-[](https://github.com/opendilab/OpenRL/stargazers)
-[](https://github.com/OpenRL-Lab/openrl/network)
-
-[](https://github.com/OpenRL-Lab/openrl/issues)
-[](https://github.com/OpenRL-Lab/openrl/pulls)
-[](https://github.com/OpenRL-Lab/openrl/graphs/contributors)
-[](https://github.com/OpenRL-Lab/openrl/blob/master/LICENSE)
-
-[](https://discord.gg/tyy96TGbep)
-[](https://join.slack.com/t/openrlhq/shared_invite/zt-1tqwpvthd-Eeh0IxQ~DIaGqYXoW2IUQg)
-
-OpenRL-v0.0.11 is updated on May 19, 2023
-
-The main branch is the latest version of OpenRL, which is under active development. If you just want to have a try with OpenRL, you can switch to the stable branch.
-
-## Welcome to OpenRL
-
-[中文介绍](./README.md) | [Documentation](https://openrl-docs.readthedocs.io/en/latest/) | [中文文档](https://openrl-docs.readthedocs.io/zh/latest/)
-
-OpenRL is an open-source general reinforcement learning research framework that supports training for various tasks
-such as single-agent, multi-agent, and natural language.
-Developed based on PyTorch, the goal of OpenRL is to provide a simple-to-use, flexible, efficient and sustainable platform for the reinforcement learning research community.
-
-Currently, the features supported by OpenRL include:
-
-- A simple-to-use universal interface that supports training for both single-agent and multi-agent
-
-- Reinforcement learning training support for natural language tasks (such as dialogue)
-
-- Importing models and datasets from [Hugging Face](https://huggingface.co/)
-
-- Support for models such as LSTM, GRU, Transformer etc.
-
-- Multiple training acceleration methods including automatic mixed precision training and data collecting wth half precision policy network
-
-- User-defined training models, reward models, training data and environment support
-
-- Support for [gymnasium](https://gymnasium.farama.org/) environments
-
-- Dictionary observation space support
-
-- Popular visualization tools such as [wandb](https://wandb.ai/), [tensorboardX](https://tensorboardx.readthedocs.io/en/latest/index.html) are supported
-
-- Serial or parallel environment training while ensuring consistent results in both modes
-
-- Chinese and English documentation
-
-- Provides unit testing and code coverage testing
-
-- Compliant with Black Code Style guidelines and type checking
-
-This framework has undergone multiple iterations by the [OpenRL-Lab](https://github.com/OpenRL-Lab) team which has applied it in academic research.
-It has now become a mature reinforcement learning framework.
-
-OpenRL-Lab will continue to maintain and update OpenRL, and we welcome everyone to join our [open-source community](./docs/CONTRIBUTING_en.md)
-to contribute towards the development of reinforcement learning.
-
-For more information about OpenRL, please refer to the [documentation](https://openrl-docs.readthedocs.io/en/latest/).
-
-## Outline
-
-- [Welcome to OpenRL](#welcome-to-openrl)
-- [Outline](#outline)
-- [Installation](#installation)
-- [Use Docker](#use-docker)
-- [Quick Start](#quick-start)
-- [Gallery](#gallery)
-- [Projects Using OpenRL](#projects-using-openrl)
-- [Feedback and Contribution](#feedback-and-contribution)
-- [Maintainers](#maintainers)
-- [Supporters](#supporters)
- - [↳ Contributors](#-contributors)
- - [↳ Stargazers](#-stargazers)
- - [↳ Forkers](#-forkers)
-- [Citing OpenRL](#citing-openrl)
-- [License](#license)
-- [Acknowledgments](#acknowledgments)
-
-## Installation
-
-Users can directly install OpenRL via pip:
-
-```bash
-pip install openrl
-```
-
-If users are using Anaconda or Miniconda, they can also install OpenRL via conda:
-
-```bash
-conda install -c openrl openrl
-```
-
-Users who want to modify the source code can also install OpenRL from the source code:
-
-```bash
-git clone https://github.com/OpenRL-Lab/openrl.git && cd openrl
-pip install -e .
-```
-
-After installation, users can check the version of OpenRL through command line:
-
-```bash
-openrl --version
-```
-
-**Tips**: No installation required, try OpenRL online through Colab: [](https://colab.research.google.com/drive/15VBA-B7AJF8dBazzRcWAxJxZI7Pl9m-g?usp=sharing)
-
-## Use Docker
-
-OpenRL currently provides Docker images with and without GPU support.
-If the user's computer does not have an NVIDIA GPU, they can obtain an image without the GPU plugin using the following command:
-```bash
-sudo docker pull openrllab/openrl-cpu
-```
-
-If the user wants to accelerate training with a GPU, they can obtain it using the following command:
-```bash
-sudo docker pull openrllab/openrl
-```
-
-After successfully pulling the image, users can run OpenRL's Docker image using the following commands:
-```bash
-# Without GPU acceleration
-sudo docker run -it openrllab/openrl-cpu
-# With GPU acceleration
-sudo docker run -it --gpus all --net host openrllab/openrl
-```
-
-Once inside the Docker container, users can check OpenRL's version and then run test cases using these commands:
-```bash
-# Check OpenRL version in Docker container
-openrl --version
-# Run test case
-openrl --mode train --env CartPole-v1
-```
-
-## Quick Start
-
-OpenRL provides a simple and easy-to-use interface for beginners in reinforcement learning.
-Below is an example of using the PPO algorithm to train the `CartPole` environment:
-```python
-# train_ppo.py
-from openrl.envs.common import make
-from openrl.modules.common import PPONet as Net
-from openrl.runners.common import PPOAgent as Agent
-env = make("CartPole-v1", env_num=9) # Create an environment and set the environment parallelism to 9.
-net = Net(env) # Create neural network.
-agent = Agent(net) # Initialize the agent.
-agent.train(total_time_steps=20000) # Start training and set the total number of steps to 20,000 for the running environment.
-```
-Training an agent using OpenRL only requires four simple steps:
-**Create Environment** => **Initialize Model** => **Initialize Agent** => **Start Training**!
-
-For a well-trained agent, users can also easily test the agent:
-```python
-# train_ppo.py
-from openrl.envs.common import make
-from openrl.modules.common import PPONet as Net
-from openrl.runners.common import PPOAgent as Agent
-agent = Agent(Net(make("CartPole-v1", env_num=9))) # Initialize trainer.
-agent.train(total_time_steps=20000)
-# Create an environment for test, set the parallelism of the environment to 9, and set the rendering mode to group_human.
-env = make("CartPole-v1", env_num=9, render_mode="group_human")
-agent.set_env(env) # The agent requires an interactive environment.
-obs, info = env.reset() # Initialize the environment to obtain initial observations and environmental information.
-while True:
- action, _ = agent.act(obs) # The agent predicts the next action based on environmental observations.
- # The environment takes one step according to the action, obtains the next observation, reward, whether it ends and environmental information.
- obs, r, done, info = env.step(action)
- if any(done): break
-env.close() # Close test environment
-```
-Executing the above code on a regular laptop only takes a few seconds
-to complete the training. Below shows the visualization of the agent:
-
-
-

-
-
-
-**Tips:** Users can also quickly train the `CartPole` environment by executing a command line in the terminal.
-```bash
-openrl --mode train --env CartPole-v1
-```
-
-For training tasks such as multi-agent and natural language processing, OpenRL also provides a similarly simple and easy-to-use interface.
-
-For information on how to perform multi-agent training, set hyperparameters for training, load training configurations, use wandb, save GIF animations, etc., please refer to:
-- [Multi-Agent Training Example](https://openrl-docs.readthedocs.io/en/latest/quick_start/multi_agent_RL.html)
-
-For information on natural language task training, loading models/datasets on Hugging Face, customizing training models/reward models, etc., please refer to:
-- [Dialogue Task Training Example](https://openrl-docs.readthedocs.io/en/latest/quick_start/train_nlp.html)
-
-For more information about OpenRL, please refer to the [documentation](https://openrl-docs.readthedocs.io/en/latest/).
-
-## Gallery
-
-In order to facilitate users' familiarity with the framework, we provide more examples and demos of using OpenRL in [Gallery](./Gallery.md).
-Users are also welcome to contribute their own training examples and demos to the Gallery.
-
-## Projects Using OpenRL
-
-We have listed research projects that use OpenRL in the [OpenRL Project](./Project.md).
-If you are using OpenRL in your research project, you are also welcome to join this list.
-
-## Feedback and Contribution
-- If you have any questions or find bugs, you can check or ask in the [Issues](https://github.com/OpenRL-Lab/openrl/issues).
-- Join the QQ group: [OpenRL Official Communication Group](./docs/images/qq.png)
-
-

-
-
-- Join the [slack](https://join.slack.com/t/openrlhq/shared_invite/zt-1tqwpvthd-Eeh0IxQ~DIaGqYXoW2IUQg) group to discuss OpenRL usage and development with us.
-- Join the [Discord](https://discord.gg/tyy96TGbep) group to discuss OpenRL usage and development with us.
-- Send an E-mail to: [huangshiyu@4paradigm.com](huangshiyu@4paradigm.com)
-- Join the [GitHub Discussion](https://github.com/orgs/OpenRL-Lab/discussions).
-
-The OpenRL framework is still under continuous development and documentation.
-We welcome you to join us in making this project better:
-- How to contribute code: Read the [Contributors' Guide](./docs/CONTRIBUTING_en.md)
-- [OpenRL Roadmap](https://github.com/OpenRL-Lab/openrl/issues/2)
-
-## Maintainers
-
-At present, OpenRL is maintained by the following maintainers:
-- [Shiyu Huang](https://huangshiyu13.github.io/)([@huangshiyu13](https://github.com/huangshiyu13))
-- Wenze Chen([@Chen001117](https://github.com/Chen001117))
-
-Welcome more contributors to join our maintenance team (send an E-mail to [huangshiyu@4paradigm.com](huangshiyu@4paradigm.com)
-to apply for joining the OpenRL team).
-
-## Supporters
-
-### ↳ Contributors
-
-
-
-
-
-### ↳ Stargazers
-
-[](https://github.com/OpenRL-Lab/openrl/stargazers)
-
-### ↳ Forkers
-
-[](https://github.com/OpenRL-Lab/openrl/network/members)
-
-## Citing OpenRL
-
-If our work has been helpful to you, please feel free to cite us:
-```latex
-@misc{openrl2023,
- title={OpenRL},
- author={OpenRL Contributors},
- publisher = {GitHub},
- howpublished = {\url{https://github.com/OpenRL-Lab/openrl}},
- year={2023},
-}
-```
-
-## Star History
-
-[](https://star-history.com/#OpenRL-Lab/openrl&Date)
-
-## License
-OpenRL under the Apache 2.0 license.
-
-## Acknowledgments
-The development of the OpenRL framework has drawn on the strengths of other reinforcement learning frameworks:
-
-- Stable-baselines3: https://github.com/DLR-RM/stable-baselines3
-- pytorch-a2c-ppo-acktr-gail: https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail
-- MAPPO: https://github.com/marlbenchmark/on-policy
-- Gymnasium: https://github.com/Farama-Foundation/Gymnasium
-- DI-engine: https://github.com/opendilab/DI-engine/
-- Tianshou: https://github.com/thu-ml/tianshou
-- RL4LMs: https://github.com/allenai/RL4LMs
diff --git a/README_zh.md b/README_zh.md
new file mode 100644
index 00000000..c6b204b2
--- /dev/null
+++ b/README_zh.md
@@ -0,0 +1,288 @@
+
+

+
+
+---
+[](https://pypi.org/project/openrl/)
+
+[](https://anaconda.org/openrl/openrl)
+[](https://anaconda.org/openrl/openrl)
+[](https://anaconda.org/openrl/openrl)
+[](https://github.com/psf/black)
+
+
+[](https://hitsofcode.com/github/OpenRL-Lab/openrl/view?branch=main)
+[](https://codecov.io/gh/OpenRL-Lab/openrl)
+
+[](https://openrl-docs.readthedocs.io/zh/latest/?badge=latest)
+[](https://openrl-docs.readthedocs.io/zh/latest/)
+
+
+[](https://github.com/opendilab/OpenRL/stargazers)
+[](https://github.com/OpenRL-Lab/openrl/network)
+
+[](https://github.com/OpenRL-Lab/openrl/issues)
+[](https://github.com/OpenRL-Lab/openrl/pulls)
+[](https://github.com/OpenRL-Lab/openrl/graphs/contributors)
+[](https://github.com/OpenRL-Lab/openrl/blob/master/LICENSE)
+
+[](https://discord.gg/tyy96TGbep)
+[](https://join.slack.com/t/openrlhq/shared_invite/zt-1tqwpvthd-Eeh0IxQ~DIaGqYXoW2IUQg)
+
+OpenRL-v0.0.11 is updated on May 19, 2023
+
+The main branch is the latest version of OpenRL, which is under active development. If you just want to have a try with OpenRL, you can switch to the stable branch.
+
+## Welcome to OpenRL
+
+[English](./README.md) | [Chinese Documentation](https://openrl-docs.readthedocs.io/zh/latest/) | [Documentation](https://openrl-docs.readthedocs.io/en/latest/)
+
+OpenRL is an open-source general reinforcement learning research framework that supports training for various tasks such as single-agent, multi-agent, and natural language tasks. Developed based on PyTorch, OpenRL aims to provide a simple-to-use, flexible, efficient, and sustainably extensible platform for the reinforcement learning research community.
+Currently, the features supported by OpenRL include:
+
+- A simple and easy-to-use universal interface that supports both single-agent and multi-agent training
+- Reinforcement learning training for natural language tasks (such as dialogue tasks)
+- Importing models and datasets from [Hugging Face](https://huggingface.co/)
+- Support for models such as LSTM, GRU, and Transformer
+- Multiple training acceleration techniques, e.g., automatic mixed-precision training and data collection with a half-precision policy network
+- User-defined training models, reward models, training data, and environments
+- Support for [gymnasium](https://gymnasium.farama.org/) environments
+- Support for dictionary observation spaces
+- Support for popular training visualization tools such as [wandb](https://wandb.ai/) and [tensorboardX](https://tensorboardx.readthedocs.io/en/latest/index.html)
+- Serial and parallel environment training, with consistent training results guaranteed in both modes
+- Chinese and English documentation
+- Unit tests and code coverage testing
+- Compliance with Black code style and type checking
+
+Algorithms currently supported by OpenRL (for more details, please refer to [Gallery](Gallery.md)):
+- [Proximal Policy Optimization (PPO)](https://arxiv.org/abs/1707.06347)
+- [Multi-agent PPO (MAPPO)](https://arxiv.org/abs/2103.01955)
+- [Joint-ratio Policy Optimization (JRPO)](https://arxiv.org/abs/2302.07515)
+- [Multi-Agent Transformer (MAT)](https://arxiv.org/abs/2205.14953)
+
+Environments currently supported by OpenRL (for more details, please refer to [Gallery](Gallery.md)):
+- [Gymnasium](https://gymnasium.farama.org/)
+- [MPE](https://github.com/openai/multiagent-particle-envs)
+- [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros)
+- [Gym Retro](https://github.com/openai/retro)
+
+
+This framework has undergone multiple iterations by [OpenRL-Lab](https://github.com/OpenRL-Lab) and has been applied in academic research; it has now become a mature reinforcement learning framework.
+OpenRL-Lab will continue to maintain and update OpenRL. Everyone is welcome to join our [open-source community](./docs/CONTRIBUTING_zh.md) and contribute to the development of reinforcement learning.
+For more information about OpenRL, please refer to the [documentation](https://openrl-docs.readthedocs.io/zh/latest/).
+
+## Contents
+
+- [Welcome to OpenRL](#welcome-to-openrl)
+- [Contents](#contents)
+- [Installation](#installation)
+- [Using Docker](#using-docker)
+- [Quick Start](#quick-start)
+- [Gallery](#gallery)
+- [Projects Using OpenRL](#projects-using-openrl)
+- [Feedback and Contribution](#feedback-and-contribution)
+- [Maintainers](#maintainers)
+- [Supporters](#supporters)
+  - [↳ Contributors](#-contributors)
+  - [↳ Stargazers](#-stargazers)
+  - [↳ Forkers](#-forkers)
+- [Citing OpenRL](#citing-openrl)
+- [License](#license)
+- [Acknowledgments](#acknowledgments)
+
+## Installation
+
+Users can install OpenRL directly via pip:
+```bash
+pip install openrl
+```
+
+If you are using Anaconda or Miniconda, you can also install OpenRL via conda:
+```bash
+conda install -c openrl openrl
+```
+
+Users who want to modify the source code can also install OpenRL from source:
+```bash
+git clone https://github.com/OpenRL-Lab/openrl.git && cd openrl
+pip install -e .
+```
+
+After installation, you can check the version of OpenRL directly from the command line:
+```bash
+openrl --version
+```
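If you prefer to check from Python, here is a minimal sketch using only the standard library (it assumes OpenRL was installed via pip under the package name `openrl`; the helper function name is ours):

```python
# Query the installed OpenRL version without importing the package itself.
from importlib.metadata import PackageNotFoundError, version


def get_openrl_version() -> str:
    """Return the installed openrl version, or 'not installed' if absent."""
    try:
        return version("openrl")
    except PackageNotFoundError:
        return "not installed"


print(get_openrl_version())
```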
+
+**Tips**: Try OpenRL online via Colab without installation: [](https://colab.research.google.com/drive/15VBA-B7AJF8dBazzRcWAxJxZI7Pl9m-g?usp=sharing)
+
+## Using Docker
+
+OpenRL currently provides Docker images both with and without GPU support.
+If your machine does not have an NVIDIA GPU, you can pull the image without the GPU plugin via the following command:
+
+```bash
+sudo docker pull openrllab/openrl-cpu
+```
+
+If you want to accelerate training with a GPU, you can pull this image instead:
+```bash
+sudo docker pull openrllab/openrl
+```
+
+After the image is pulled successfully, you can run OpenRL's Docker image via the following commands:
+```bash
+# without GPU acceleration
+sudo docker run -it openrllab/openrl-cpu
+# with GPU acceleration
+sudo docker run -it --gpus all --net host openrllab/openrl
+```
+
+After entering the Docker container, you can check the OpenRL version and then run a test example via the following commands:
+```bash
+# check the version of OpenRL in the Docker image
+openrl --version
+# run a test example
+openrl --mode train --env CartPole-v1
+```
+
+
+## Quick Start
+
+OpenRL provides a simple and easy-to-use interface for beginners in reinforcement learning.
+Below is an example of training the `CartPole` environment with the PPO algorithm:
+```python
+# train_ppo.py
+from openrl.envs.common import make
+from openrl.modules.common import PPONet as Net
+from openrl.runners.common import PPOAgent as Agent
+env = make("CartPole-v1", env_num=9) # create the environment and set the number of parallel environments to 9
+net = Net(env) # create the neural network
+agent = Agent(net) # initialize the agent
+agent.train(total_time_steps=20000) # start training, with a total of 20000 environment steps
+```
+Training an agent with OpenRL takes only four simple steps: **create environment** => **initialize model** => **initialize agent** => **start training**!
+
+For a trained agent, you can also conveniently evaluate it:
+```python
+# train_ppo.py
+from openrl.envs.common import make
+from openrl.modules.common import PPONet as Net
+from openrl.runners.common import PPOAgent as Agent
+agent = Agent(Net(make("CartPole-v1", env_num=9))) # initialize the trainer
+agent.train(total_time_steps=20000)
+# create an evaluation environment, set the number of parallel environments to 9, and set the render mode to group_human
+env = make("CartPole-v1", env_num=9, render_mode="group_human")
+agent.set_env(env) # set the environment for the trained agent to interact with
+obs, info = env.reset() # reset the environment to get the initial observation and info
+while True:
+    action, _ = agent.act(obs) # the agent predicts the next action based on the observation
+    # the environment takes one step with the action and returns the next observation, reward, done flags, and info
+    obs, r, done, info = env.step(action)
+    if any(done): break
+env.close() # close the evaluation environment
+```
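Note that with `env_num=9` the vectorized environment returns one done flag per parallel environment, which is why the evaluation loop above exits on `any(done)`. A minimal NumPy sketch of that termination check (illustrative only, independent of OpenRL):

```python
import numpy as np

env_num = 9
done = np.zeros(env_num, dtype=bool)  # one done flag per parallel environment

# no environment has finished yet, so the loop would keep stepping
assert not np.any(done)

done[3] = True  # a single environment reports termination
# any(done) is now True, so the evaluation loop breaks
assert np.any(done)
```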
+Running the code above on an ordinary laptop takes only a few seconds to complete the training and visualized evaluation of this agent:
+
+
+

+
+
+
+**Tips:** You can also quickly train the `CartPole` environment by running a single command in the terminal:
+```bash
+openrl --mode train --env CartPole-v1
+```
+
+For training multi-agent and natural language tasks, OpenRL provides the same simple and easy-to-use interface.
+
+For information on multi-agent training, hyperparameter settings, loading training configuration files, using wandb, saving GIF animations, etc., please refer to:
+- [Multi-agent training example](https://openrl-docs.readthedocs.io/zh/latest/quick_start/multi_agent_RL.html)
+
+For information on natural language task training, loading models (datasets) from Hugging Face, customizing training models (reward models), etc., please refer to:
+- [Dialogue task training example](https://openrl-docs.readthedocs.io/zh/latest/quick_start/train_nlp.html)
+
+For more information about OpenRL, please refer to the [documentation](https://openrl-docs.readthedocs.io/zh/latest/).
+
+## Gallery
+
+To help users become familiar with the framework,
+we provide more examples and demos of using OpenRL in the [Gallery](Gallery.md).
+Users are also welcome to contribute their own training examples and demos to the Gallery.
+
+## Projects Using OpenRL
+
+We list research projects that use OpenRL in [OpenRL Project](Project.md).
+If you have used OpenRL in a research project, you are also welcome to join this list.
+
+## Feedback and Contribution
+- For questions or bug reports, you can search or ask in [Issues](https://github.com/OpenRL-Lab/openrl/issues).
+- Join the QQ group: [OpenRL Official QQ Group](docs/images/qq.png)
+
+
+

+
+
+- Join the [slack](https://join.slack.com/t/openrlhq/shared_invite/zt-1tqwpvthd-Eeh0IxQ~DIaGqYXoW2IUQg) group to discuss the use and development of OpenRL with us.
+- Join the [Discord](https://discord.gg/tyy96TGbep) group to discuss the use and development of OpenRL with us.
+- Send an e-mail to: [huangshiyu@4paradigm.com](mailto:huangshiyu@4paradigm.com)
+- Join [GitHub Discussions](https://github.com/orgs/OpenRL-Lab/discussions)
+
+The OpenRL framework is still under continuous development and documentation, and we welcome you to join us in making this project better:
+
+- How to contribute code: read the [Contributors' Manual](./docs/CONTRIBUTING_zh.md)
+- [OpenRL Roadmap](https://github.com/OpenRL-Lab/openrl/issues/2)
+
+## Maintainers
+
+At present, OpenRL is maintained by the following maintainers:
+- [Shiyu Huang](https://huangshiyu13.github.io/)([@huangshiyu13](https://github.com/huangshiyu13))
+- Wenze Chen([@Chen001117](https://github.com/Chen001117))
+
+We welcome more contributors to join our maintenance team (send an e-mail to [huangshiyu@4paradigm.com](mailto:huangshiyu@4paradigm.com) to apply to join the OpenRL team).
+
+## Supporters
+
+### ↳ Contributors
+
+
+
+
+
+### ↳ Stargazers
+
+[](https://github.com/OpenRL-Lab/openrl/stargazers)
+
+### ↳ Forkers
+
+[](https://github.com/OpenRL-Lab/openrl/network/members)
+
+## Citing OpenRL
+
+If our work has been helpful to you, please feel free to cite us:
+```latex
+@misc{openrl2023,
+ title={OpenRL},
+ author={OpenRL Contributors},
+ publisher = {GitHub},
+ howpublished = {\url{https://github.com/OpenRL-Lab/openrl}},
+ year={2023},
+}
+```
+
+## Star History
+
+[](https://star-history.com/#OpenRL-Lab/openrl&Date)
+
+## License
+OpenRL is released under the Apache 2.0 license.
+
+## Acknowledgments
+The development of the OpenRL framework has drawn on the strengths of other reinforcement learning frameworks:
+
+- Stable-baselines3: https://github.com/DLR-RM/stable-baselines3
+- pytorch-a2c-ppo-acktr-gail: https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail
+- MAPPO: https://github.com/marlbenchmark/on-policy
+- Gymnasium: https://github.com/Farama-Foundation/Gymnasium
+- DI-engine: https://github.com/opendilab/DI-engine/
+- Tianshou: https://github.com/thu-ml/tianshou
+- RL4LMs: https://github.com/allenai/RL4LMs
diff --git a/docs/CONTRIBUTING_en.md b/docs/CONTRIBUTING_en.md
deleted file mode 100644
index 64de484f..00000000
--- a/docs/CONTRIBUTING_en.md
+++ /dev/null
@@ -1,45 +0,0 @@
-## How to Contribute to OpenRL
-
-The OpenRL community welcomes anyone to contribute to the development of OpenRL, whether you are a developer or a user. Your feedback and contributions are our driving force! You can join the contribution of OpenRL in the following ways:
-
-- As an OpenRL user, discover bugs in OpenRL and submit an [issue](https://github.com/OpenRL-Lab/openrl/issues/new/choose).
-- As an OpenRL user, discover errors in the documentation of OpenRL and submit an [issue](https://github.com/OpenRL-Lab/openrl/issues/new/choose).
-- Write test code to improve the code coverage of OpenRL (you can check the code coverage situation of OpenRL from [here](https://app.codecov.io/gh/OpenRL-Lab/openrl)). You can choose interested code snippets for writing test codes.
-- As an open-source developer, fix existing bugs for OpenRL.
-- As an open-source developer, add new environments and examples for OpenRL.
-- As an open-source developer, add new algorithms for OpenRL.
-
-## Contributing to OpenRL
-
-Welcome to contribute to the development of OpenRL. We appreciate your contribution!
-
-- If you want to contribute new features, please create a new [issue](https://github.com/OpenRL-Lab/openrl/issues/new/choose) first
-to discuss the implementation details of this feature. If the feature is approved by everyone, you can start implementing the code.
-- You can also check for unimplemented features and existing bugs in [Issues](https://github.com/OpenRL-Lab/openrl/issues),
-reply in the corresponding issue that you want to solve it, and then start implementing the code.
-
-After completing your code implementation, you need to pull the latest `main` branch and merge it.
-After resolving any merge conflicts,
-you can submit your code for merging into OpenRL's main branch through [Pull Request](https://github.com/OpenRL-Lab/openrl/pulls).
-
-Before submitting a Pull Request, you need to complete [Code Testing and Code Formatting](#code-testing-and-code-formatting).
-
-Then, your Pull Request needs to pass automated testing on GitHub.
-
-Finally, at least one maintainer's review and approval are required before being merged into the main branch.
-
-## Code Testing and Code Formatting
-
-Before submitting a Pull Request, make sure that your code passes unit tests and conforms with OpenRL's coding style.
-
-Firstly, you should install the test-related packages: `pip install -e ".[test]"`
-
-Then, ensure that unit tests pass by executing `make test`.
-
-Next, format your code by running `make format`.
-
-Lastly, run `make commit-checks` to check if your code complies with OpenRL's coding style.
-
-> Tip: OpenRL uses [black](https://github.com/psf/black) coding style.
-You can install black plugins in your editor as shown in the [official website](https://black.readthedocs.io/en/stable/integrations/editors.html)
-to help automatically format codes.
\ No newline at end of file
diff --git a/docs/CONTRIBUTING_zh.md b/docs/CONTRIBUTING_zh.md
new file mode 100644
index 00000000..9fb1d012
--- /dev/null
+++ b/docs/CONTRIBUTING_zh.md
@@ -0,0 +1,51 @@
+## How to Contribute to OpenRL
+
+The OpenRL community welcomes anyone to participate in building OpenRL. Whether you are a developer or a user, your feedback and contributions are our driving force!
+You can contribute to OpenRL in the following ways:
+
+- As an OpenRL user, discover bugs in OpenRL and submit an [issue](https://github.com/OpenRL-Lab/openrl/issues/new/choose).
+- As an OpenRL user, discover errors in the OpenRL documentation and submit an [issue](https://github.com/OpenRL-Lab/openrl/issues/new/choose).
+- Write test code to improve OpenRL's test coverage (you can check OpenRL's current code coverage [here](https://app.codecov.io/gh/OpenRL-Lab/openrl)).
+  You can pick code snippets you are interested in and write tests for them.
+- As an OpenRL developer, fix existing bugs in OpenRL.
+- As an OpenRL developer, add new environments and examples to OpenRL.
+- As an OpenRL developer, add new algorithms to OpenRL.
+
+## Contributors' Manual
+
+We welcome more people to participate in the development of OpenRL, and we greatly appreciate your contributions!
+
+- If you want to contribute a new feature, please first create a new [issue](https://github.com/OpenRL-Lab/openrl/issues/new/choose)
+  so that we can discuss the implementation details of the feature. Once the feature is approved, you can start implementing the code.
+- You can also check [Issues](https://github.com/OpenRL-Lab/openrl/issues) for unimplemented features and remaining bugs,
+reply in the corresponding issue to say that you want to solve it, and then start implementing the code.
+
+After completing your code implementation, you need to pull the latest `main` branch and merge it.
+After resolving any merge conflicts,
+you can submit a [Pull Request](https://github.com/OpenRL-Lab/openrl/pulls)
+to merge your code into OpenRL's main branch.
+
+Before submitting a Pull Request, you need to complete [Code Testing and Code Formatting](#code-testing-and-code-formatting).
+
+Then, your Pull Request needs to pass the automated tests on GitHub.
+
+Finally, it needs a review and approval from at least one maintainer before it can be merged into the main branch.
+
+## Code Testing and Code Formatting
+
+Before submitting a Pull Request, you need to make sure that your code passes the unit tests and conforms to OpenRL's code style.
+
+First, install the test-related packages: `pip install -e ".[test]"`
+
+Then, make sure the unit tests pass by running `make test`.
+
+Next, format your code by running `make format`.
+
+Finally, run `make commit-checks` to check whether your code conforms to OpenRL's code style.
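The steps above can be chained into one shell recipe (a sketch; it assumes you run it from the repository root with `make` available):

```bash
# install test dependencies, then test, format, and style-check in order
pip install -e ".[test]"
make test
make format
make commit-checks
```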
+
+> Tip: OpenRL uses the [black](https://github.com/psf/black) code style.
+You can install the black [plugin](https://black.readthedocs.io/en/stable/integrations/editors.html) in your editor
+to help you format your code automatically.
+
+
+
diff --git a/examples/mpe/README.md b/examples/mpe/README.md
index c1e25faa..f3ed00ef 100644
--- a/examples/mpe/README.md
+++ b/examples/mpe/README.md
@@ -1,7 +1,20 @@
## How to Use
-Users can train MPE via:
+Train MPE with [MAPPO](https://arxiv.org/abs/2103.01955) algorithm:
```shell
python train_ppo.py --config mpe_ppo.yaml
+```
+
+Train MPE with [JRPO](https://arxiv.org/abs/2302.07515) algorithm:
+
+```shell
+python train_ppo.py --config mpe_jrpo.yaml
+```
+
+
+Train MPE with [MAT](https://arxiv.org/abs/2205.14953) algorithm:
+
+```shell
+python train_mat.py --config mpe_mat.yaml
```
\ No newline at end of file
diff --git a/examples/mpe/mpe_jrpo.yaml b/examples/mpe/mpe_jrpo.yaml
new file mode 100644
index 00000000..6a97afc1
--- /dev/null
+++ b/examples/mpe/mpe_jrpo.yaml
@@ -0,0 +1,12 @@
+seed: 0
+lr: 7e-4
+critic_lr: 7e-4
+episode_length: 25
+run_dir: ./run_results/
+experiment_name: train_mpe_jrpo
+log_interval: 10
+use_recurrent_policy: true
+use_joint_action_loss: true
+use_valuenorm: true
+use_adv_normalize: true
+wandb_entity: openrl-lab
\ No newline at end of file
diff --git a/examples/mpe/mpe_mat.yaml b/examples/mpe/mpe_mat.yaml
new file mode 100644
index 00000000..1c2305b2
--- /dev/null
+++ b/examples/mpe/mpe_mat.yaml
@@ -0,0 +1,9 @@
+seed: 0
+lr: 7e-4
+episode_length: 25
+run_dir: ./run_results/
+experiment_name: train_mpe_mat
+log_interval: 10
+use_valuenorm: true
+use_adv_normalize: true
+wandb_entity: openrl-lab
\ No newline at end of file
diff --git a/examples/mpe/mpe_ppo.yaml b/examples/mpe/mpe_ppo.yaml
index 5b31e006..d058bb98 100644
--- a/examples/mpe/mpe_ppo.yaml
+++ b/examples/mpe/mpe_ppo.yaml
@@ -6,7 +6,7 @@ run_dir: ./run_results/
experiment_name: train_mpe
log_interval: 10
use_recurrent_policy: true
-use_joint_action_loss: true
+use_joint_action_loss: false
use_valuenorm: true
use_adv_normalize: true
wandb_entity: openrl-lab
\ No newline at end of file
diff --git a/examples/mpe/train_mat.py b/examples/mpe/train_mat.py
new file mode 100644
index 00000000..12816906
--- /dev/null
+++ b/examples/mpe/train_mat.py
@@ -0,0 +1,62 @@
+""""""
+
+import numpy as np
+
+from openrl.configs.config import create_config_parser
+from openrl.envs.common import make
+from openrl.envs.wrappers.mat_wrapper import MATWrapper
+from openrl.modules.common import MATNet as Net
+from openrl.runners.common import MATAgent as Agent
+
+
+def train():
+    # create the environment
+ env_num = 100
+ env = make(
+ "simple_spread",
+ env_num=env_num,
+ asynchronous=True,
+ )
+ env = MATWrapper(env)
+
+    # create the neural network
+ cfg_parser = create_config_parser()
+ cfg = cfg_parser.parse_args()
+ net = Net(env, cfg=cfg, device="cuda")
+
+    # initialize the trainer
+    agent = Agent(net, use_wandb=True)
+    # start training
+ agent.train(total_time_steps=5000000)
+ env.close()
+ agent.save("./mat_agent/")
+ return agent
+
+
+def evaluation(agent):
+    # render_mode = "group_human"
+    render_mode = None
+    env_num = 9
+    env = make(
+        "simple_spread", render_mode=render_mode, env_num=env_num, asynchronous=False
+ )
+ env = MATWrapper(env)
+ agent.load("./mat_agent/")
+ agent.set_env(env)
+ obs, info = env.reset(seed=0)
+ done = False
+ step = 0
+ total_reward = 0
+ while not np.any(done):
+        # the agent predicts the next action based on the observation
+ action, _ = agent.act(obs, deterministic=True)
+ obs, r, done, info = env.step(action)
+ step += 1
+ total_reward += np.mean(r)
+ print(f"total_reward: {total_reward}")
+ env.close()
+
+
+if __name__ == "__main__":
+ agent = train()
+ evaluation(agent)
diff --git a/openrl/algorithms/dqn.py b/openrl/algorithms/dqn.py
index 0c21c33c..2b50060a 100644
--- a/openrl/algorithms/dqn.py
+++ b/openrl/algorithms/dqn.py
@@ -220,7 +220,9 @@ def train(self, buffer, turn_on=True):
elif self._use_naive_recurrent:
raise NotImplementedError
else:
- data_generator = buffer.feed_forward_generator(None, self.num_mini_batch)
+ data_generator = buffer.feed_forward_generator(
+ None, self.num_mini_batch
+ )
for sample in data_generator:
(q_loss) = self.dqn_update(sample, turn_on)
diff --git a/openrl/algorithms/mat.py b/openrl/algorithms/mat.py
new file mode 100644
index 00000000..0b1e0b61
--- /dev/null
+++ b/openrl/algorithms/mat.py
@@ -0,0 +1,38 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+from openrl.algorithms.ppo import PPOAlgorithm
+
+
+class MATAlgorithm(PPOAlgorithm):
+ def construct_loss_list(self, policy_loss, dist_entropy, value_loss, turn_on):
+ loss_list = []
+
+ loss = (
+ policy_loss
+ - dist_entropy * self.entropy_coef
+ + value_loss * self.value_loss_coef
+ )
+ loss_list.append(loss)
+
+ return loss_list
+
+ def get_data_generator(self, buffer, advantages):
+ data_generator = buffer.feed_forward_generator_transformer(
+ advantages, self.num_mini_batch
+ )
+ return data_generator
diff --git a/openrl/algorithms/ppo.py b/openrl/algorithms/ppo.py
index 68e969f7..acd3d84e 100644
--- a/openrl/algorithms/ppo.py
+++ b/openrl/algorithms/ppo.py
@@ -110,43 +110,26 @@ def ppo_update(self, sample, turn_on=True):
for loss in loss_list:
loss.backward()
- if "transformer" in self.algo_module.models:
- if self._use_max_grad_norm:
- grad_norm = nn.utils.clip_grad_norm_(
- self.algo_module.models["transformer"].parameters(),
- self.max_grad_norm,
- )
- else:
- grad_norm = get_gard_norm(
- self.algo_module.models["transformer"].parameters()
- )
- critic_grad_norm = grad_norm
- actor_grad_norm = grad_norm
-
+
+ if self._use_share_model:
+ actor_para = self.algo_module.models["model"].get_actor_para()
else:
- if self._use_share_model:
- actor_para = self.algo_module.models["model"].get_actor_para()
- else:
- actor_para = self.algo_module.models["policy"].parameters()
+ actor_para = self.algo_module.models["policy"].parameters()
- if self._use_max_grad_norm:
- actor_grad_norm = nn.utils.clip_grad_norm_(
- actor_para, self.max_grad_norm
- )
- else:
- actor_grad_norm = get_gard_norm(actor_para)
+ if self._use_max_grad_norm:
+ actor_grad_norm = nn.utils.clip_grad_norm_(actor_para, self.max_grad_norm)
+ else:
+ actor_grad_norm = get_gard_norm(actor_para)
- if self._use_share_model:
- critic_para = self.algo_module.models["model"].get_critic_para()
- else:
- critic_para = self.algo_module.models["critic"].parameters()
+ if self._use_share_model:
+ critic_para = self.algo_module.models["model"].get_critic_para()
+ else:
+ critic_para = self.algo_module.models["critic"].parameters()
- if self._use_max_grad_norm:
- critic_grad_norm = nn.utils.clip_grad_norm_(
- critic_para, self.max_grad_norm
- )
- else:
- critic_grad_norm = get_gard_norm(critic_para)
+ if self._use_max_grad_norm:
+ critic_grad_norm = nn.utils.clip_grad_norm_(critic_para, self.max_grad_norm)
+ else:
+ critic_grad_norm = get_gard_norm(critic_para)
if self.use_amp:
for optimizer in self.algo_module.optimizers.values():
@@ -219,6 +202,16 @@ def to_single_np(self, input):
reshape_input = input.reshape(-1, self.agent_num, *input.shape[1:])
return reshape_input[:, 0, ...]
+ def construct_loss_list(self, policy_loss, dist_entropy, value_loss, turn_on):
+ loss_list = []
+ if turn_on:
+ final_p_loss = policy_loss - dist_entropy * self.entropy_coef
+
+ loss_list.append(final_p_loss)
+ final_v_loss = value_loss * self.value_loss_coef
+ loss_list.append(final_v_loss)
+ return loss_list
+
def prepare_loss(
self,
critic_obs_batch,
@@ -235,8 +228,6 @@ def prepare_loss(
active_masks_batch,
turn_on,
):
- loss_list = []
-
if self.use_joint_action_loss:
critic_obs_batch = self.to_single_np(critic_obs_batch)
rnn_states_critic_batch = self.to_single_np(rnn_states_critic_batch)
@@ -341,21 +332,30 @@ def prepare_loss(
active_masks_batch,
)
- if "transformer" in self.algo_module.models:
- loss = (
- policy_loss
- - dist_entropy * self.entropy_coef
- + value_loss * self.value_loss_coef
+ loss_list = self.construct_loss_list(
+ policy_loss, dist_entropy, value_loss, turn_on
+ )
+ return loss_list, value_loss, policy_loss, dist_entropy, ratio
+
+ def get_data_generator(self, buffer, advantages):
+ if self._use_recurrent_policy:
+ if self.use_joint_action_loss:
+ data_generator = buffer.recurrent_generator_v3(
+ advantages, self.num_mini_batch, self.data_chunk_length
+ )
+ else:
+ data_generator = buffer.recurrent_generator(
+ advantages, self.num_mini_batch, self.data_chunk_length
+ )
+ elif self._use_naive_recurrent:
+ data_generator = buffer.naive_recurrent_generator(
+ advantages, self.num_mini_batch
)
- loss_list.append(loss)
else:
- if turn_on:
- final_p_loss = policy_loss - dist_entropy * self.entropy_coef
-
- loss_list.append(final_p_loss)
- final_v_loss = value_loss * self.value_loss_coef
- loss_list.append(final_v_loss)
- return loss_list, value_loss, policy_loss, dist_entropy, ratio
+ data_generator = buffer.feed_forward_generator(
+ advantages, self.num_mini_batch
+ )
+ return data_generator
def train(self, buffer, turn_on=True):
if self._use_popart or self._use_valuenorm:
@@ -396,27 +396,7 @@ def train(self, buffer, turn_on=True):
train_info["reduced_policy_loss"] = 0
for _ in range(self.ppo_epoch):
- if "transformer" in self.algo_module.models:
- data_generator = buffer.feed_forward_generator_transformer(
- advantages, self.num_mini_batch
- )
- elif self._use_recurrent_policy:
- if self.use_joint_action_loss:
- data_generator = buffer.recurrent_generator_v3(
- advantages, self.num_mini_batch, self.data_chunk_length
- )
- else:
- data_generator = buffer.recurrent_generator(
- advantages, self.num_mini_batch, self.data_chunk_length
- )
- elif self._use_naive_recurrent:
- data_generator = buffer.naive_recurrent_generator(
- advantages, self.num_mini_batch
- )
- else:
- data_generator = buffer.feed_forward_generator(
- advantages, self.num_mini_batch
- )
+ data_generator = self.get_data_generator(buffer, advantages)
for sample in data_generator:
(
diff --git a/openrl/buffers/offpolicy_replay_data.py b/openrl/buffers/offpolicy_replay_data.py
index 27ed6898..fa34c660 100644
--- a/openrl/buffers/offpolicy_replay_data.py
+++ b/openrl/buffers/offpolicy_replay_data.py
@@ -159,9 +159,10 @@ def insert(
self.policy_obs[self.step + 1] = policy_obs.copy()
self.next_critic_obs[self.step + 1] = next_critic_obs.copy()
self.next_policy_obs[self.step + 1] = next_policy_obs.copy()
-
- self.rnn_states[self.step + 1] = rnn_states.copy()
- self.rnn_states_critic[self.step + 1] = rnn_states_critic.copy()
+ if rnn_states is not None:
+ self.rnn_states[self.step + 1] = rnn_states.copy()
+ if rnn_states_critic is not None:
+ self.rnn_states_critic[self.step + 1] = rnn_states_critic.copy()
self.actions[self.step] = actions.copy()
self.action_log_probs[self.step] = action_log_probs.copy()
self.value_preds[self.step] = value_preds.copy()
diff --git a/openrl/buffers/replay_data.py b/openrl/buffers/replay_data.py
index 2fc16253..768c265c 100644
--- a/openrl/buffers/replay_data.py
+++ b/openrl/buffers/replay_data.py
@@ -264,9 +264,10 @@ def insert(
else:
self.critic_obs[self.step + 1] = critic_obs.copy()
self.policy_obs[self.step + 1] = policy_obs.copy()
-
- self.rnn_states[self.step + 1] = rnn_states.copy()
- self.rnn_states_critic[self.step + 1] = rnn_states_critic.copy()
+ if rnn_states is not None:
+ self.rnn_states[self.step + 1] = rnn_states.copy()
+ if rnn_states_critic is not None:
+ self.rnn_states_critic[self.step + 1] = rnn_states_critic.copy()
self.actions[self.step] = actions.copy()
self.action_log_probs[self.step] = action_log_probs.copy()
self.value_preds[self.step] = value_preds.copy()
diff --git a/openrl/drivers/offpolicy_driver.py b/openrl/drivers/offpolicy_driver.py
index 0f8d498a..20375e56 100644
--- a/openrl/drivers/offpolicy_driver.py
+++ b/openrl/drivers/offpolicy_driver.py
@@ -199,9 +199,9 @@ def act(
actions = np.expand_dims(q_values.argmax(axis=-1), axis=-1)
if random.random() > epsilon:
- actions = np.random.randint(low=0,
- high=self.envs.action_space.n,
- size=actions.shape)
+ actions = np.random.randint(
+ low=0, high=self.envs.action_space.n, size=actions.shape
+ )
return (
q_values,
@@ -211,4 +211,3 @@ def act(
def compute_returns(self):
pass
-
diff --git a/openrl/drivers/onpolicy_driver.py b/openrl/drivers/onpolicy_driver.py
index 99adbf07..7504b942 100644
--- a/openrl/drivers/onpolicy_driver.py
+++ b/openrl/drivers/onpolicy_driver.py
@@ -69,16 +69,17 @@ def add2buffer(self, data):
rnn_states,
rnn_states_critic,
) = data
+ if rnn_states is not None:
+ rnn_states[dones] = np.zeros(
+ (dones.sum(), self.recurrent_N, self.hidden_size),
+ dtype=np.float32,
+ )
- rnn_states[dones] = np.zeros(
- (dones.sum(), self.recurrent_N, self.hidden_size),
- dtype=np.float32,
- )
-
- rnn_states_critic[dones] = np.zeros(
- (dones.sum(), *self.buffer.data.rnn_states_critic.shape[3:]),
- dtype=np.float32,
- )
+ if rnn_states_critic is not None:
+ rnn_states_critic[dones] = np.zeros(
+ (dones.sum(), *self.buffer.data.rnn_states_critic.shape[3:]),
+ dtype=np.float32,
+ )
masks = np.ones((self.n_rollout_threads, self.num_agents, 1), dtype=np.float32)
masks[dones] = np.zeros((dones.sum(), 1), dtype=np.float32)
@@ -187,10 +188,12 @@ def act(
action_log_probs = np.array(
np.split(_t2n(action_log_prob), self.n_rollout_threads)
)
- rnn_states = np.array(np.split(_t2n(rnn_states), self.n_rollout_threads))
- rnn_states_critic = np.array(
- np.split(_t2n(rnn_states_critic), self.n_rollout_threads)
- )
+ if rnn_states is not None:
+ rnn_states = np.array(np.split(_t2n(rnn_states), self.n_rollout_threads))
+ if rnn_states_critic is not None:
+ rnn_states_critic = np.array(
+ np.split(_t2n(rnn_states_critic), self.n_rollout_threads)
+ )
return (
values,
diff --git a/openrl/envs/vec_env/wrappers/base_wrapper.py b/openrl/envs/vec_env/wrappers/base_wrapper.py
index 656cfce6..590e56eb 100644
--- a/openrl/envs/vec_env/wrappers/base_wrapper.py
+++ b/openrl/envs/vec_env/wrappers/base_wrapper.py
@@ -113,9 +113,9 @@ def reset(self, **kwargs):
"""Reset all environments."""
return self.env.reset(**kwargs)
- def step(self, actions):
+ def step(self, actions, *args, **kwargs):
"""Step all environments."""
- return self.env.step(actions)
+ return self.env.step(actions, *args, **kwargs)
def close(self, **kwargs):
return self.env.close(**kwargs)
@@ -193,15 +193,14 @@ def reset(self, **kwargs):
observation = results
return self.observation(observation)
- def step(self, actions):
+ def step(self, actions, *args, **kwargs):
"""Modifies the observation returned from the environment ``step`` using the :meth:`observation`."""
- results = self.env.step(actions)
+ results = self.env.step(actions, *args, **kwargs)
if len(results) == 5:
observation, reward, termination, truncation, info = results
return (
self.observation(observation),
- observation,
reward,
termination,
truncation,
@@ -211,7 +210,6 @@ def step(self, actions):
observation, reward, done, info = results
return (
self.observation(observation),
- observation,
reward,
done,
info,
@@ -237,9 +235,9 @@ def observation(self, observation: ObsType) -> ObsType:
class VectorActionWrapper(VecEnvWrapper):
"""Wraps the vectorized environment to allow a modular transformation of the actions. Equivalent of :class:`~gym.ActionWrapper` for vectorized environments."""
- def step(self, actions: ActType):
+ def step(self, actions: ActType, *args, **kwargs):
"""Steps through the environment using a modified action by :meth:`action`."""
- return self.env.step(self.action(actions))
+ return self.env.step(self.action(actions), *args, **kwargs)
def actions(self, actions: ActType) -> ActType:
"""Transform the actions before sending them to the environment.
@@ -256,9 +254,9 @@ def actions(self, actions: ActType) -> ActType:
class VectorRewardWrapper(VecEnvWrapper):
"""Wraps the vectorized environment to allow a modular transformation of the reward. Equivalent of :class:`~gym.RewardWrapper` for vectorized environments."""
- def step(self, actions):
+ def step(self, actions, *args, **kwargs):
"""Steps through the environment returning a reward modified by :meth:`reward`."""
- results = self.env.step(actions)
+ results = self.env.step(actions, *args, **kwargs)
reward = self.reward(results[1])
return results[0], reward, *results[2:]
diff --git a/openrl/envs/wrappers/mat_wrapper.py b/openrl/envs/wrappers/mat_wrapper.py
new file mode 100644
index 00000000..fdf45199
--- /dev/null
+++ b/openrl/envs/wrappers/mat_wrapper.py
@@ -0,0 +1,50 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+from openrl.envs.vec_env.wrappers.base_wrapper import VectorObservationWrapper
+
+
+class MATWrapper(VectorObservationWrapper):
+ @property
+ def observation_space(
+ self,
+ ):
+ """Return the :attr:`Env` :attr:`observation_space` unless overwritten then the wrapper :attr:`observation_space` is used."""
+ if self._observation_space is None:
+ observation_space = self.env.observation_space
+ else:
+ observation_space = self._observation_space
+
+ if (
+ "critic" in observation_space.spaces.keys()
+ and "policy" in observation_space.spaces.keys()
+ ):
+ observation_space = observation_space["policy"]
+ return observation_space
+
+ def observation(self, observation):
+ if self._observation_space is None:
+ observation_space = self.env.observation_space
+ else:
+ observation_space = self._observation_space
+
+ if (
+ "critic" in observation_space.spaces.keys()
+ and "policy" in observation_space.spaces.keys()
+ ):
+ observation = observation["policy"]
+ return observation
diff --git a/openrl/modules/common/__init__.py b/openrl/modules/common/__init__.py
index e1f20fc3..cef64e38 100644
--- a/openrl/modules/common/__init__.py
+++ b/openrl/modules/common/__init__.py
@@ -1,9 +1,11 @@
from .base_net import BaseNet
from .dqn_net import DQNNet
+from .mat_net import MATNet
from .ppo_net import PPONet
__all__ = [
"BaseNet",
"PPONet",
"DQNNet",
+ "MATNet",
]
diff --git a/openrl/modules/common/mat_net.py b/openrl/modules/common/mat_net.py
new file mode 100644
index 00000000..cdf4e43a
--- /dev/null
+++ b/openrl/modules/common/mat_net.py
@@ -0,0 +1,44 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+
+from typing import Any, Dict, Union
+
+import gymnasium as gym
+import torch
+
+from openrl.modules.common.ppo_net import PPONet
+from openrl.modules.networks.MAT_network import MultiAgentTransformer
+
+
+class MATNet(PPONet):
+ def __init__(
+ self,
+ env: Union[gym.Env, str],
+ cfg=None,
+ device: Union[torch.device, str] = "cpu",
+ n_rollout_threads: int = 1,
+ model_dict: Union[Dict[str, Any], None] = None,
+ ) -> None:
+ # avoid sharing a mutable default dict between instances
+ if model_dict is None:
+ model_dict = {"model": MultiAgentTransformer}
+ cfg.use_share_model = True
+ super().__init__(
+ env=env,
+ cfg=cfg,
+ device=device,
+ n_rollout_threads=n_rollout_threads,
+ model_dict=model_dict,
+ )
diff --git a/openrl/modules/common/ppo_net.py b/openrl/modules/common/ppo_net.py
index 96907310..a95f0d3a 100644
--- a/openrl/modules/common/ppo_net.py
+++ b/openrl/modules/common/ppo_net.py
@@ -23,6 +23,7 @@
import torch
from openrl.configs.config import create_config_parser
+from openrl.modules.base_module import BaseModule
from openrl.modules.common.base_net import BaseNet
from openrl.modules.ppo_module import PPOModule
from openrl.utils.util import set_seed
@@ -36,6 +37,7 @@ def __init__(
device: Union[torch.device, str] = "cpu",
n_rollout_threads: int = 1,
model_dict: Optional[Dict[str, Any]] = None,
+ module_class: BaseModule = PPOModule,
) -> None:
super().__init__()
@@ -46,6 +48,7 @@ def __init__(
set_seed(cfg.seed)
env.reset(seed=cfg.seed)
+ cfg.num_agents = env.agent_num
cfg.n_rollout_threads = n_rollout_threads
cfg.learner_n_rollout_threads = cfg.n_rollout_threads
@@ -62,7 +65,7 @@ def __init__(
if isinstance(device, str):
device = torch.device(device)
- self.module = PPOModule(
+ self.module = module_class(
cfg=cfg,
policy_input_space=env.observation_space,
critic_input_space=env.observation_space,
diff --git a/openrl/modules/dqn_module.py b/openrl/modules/dqn_module.py
index 9d450150..a7d7d92f 100644
--- a/openrl/modules/dqn_module.py
+++ b/openrl/modules/dqn_module.py
@@ -102,7 +102,7 @@ def evaluate_actions(
masks,
available_actions=None,
masks_batch=None,
- critic_masks_batch=None
+ critic_masks_batch=None,
):
if masks_batch is None:
masks_batch = masks
diff --git a/openrl/modules/networks/MAT_network.py b/openrl/modules/networks/MAT_network.py
new file mode 100644
index 00000000..3dda25a7
--- /dev/null
+++ b/openrl/modules/networks/MAT_network.py
@@ -0,0 +1,484 @@
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+from openrl.buffers.utils.util import get_critic_obs_space, get_policy_obs_space
+from openrl.modules.networks.base_value_policy_network import BaseValuePolicyNetwork
+from openrl.modules.networks.utils.transformer_act import (
+ continuous_autoregreesive_act,
+ continuous_parallel_act,
+ discrete_autoregreesive_act,
+ discrete_parallel_act,
+)
+from openrl.modules.networks.utils.util import init
+from openrl.utils.util import check_v2 as check
+
+
+def init_(m, gain=0.01, activate=False):
+ if activate:
+ gain = nn.init.calculate_gain("relu")
+ return init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), gain=gain)
+
+
+class SelfAttention(nn.Module):
+ def __init__(self, n_embd, n_head, n_agent, masked=False):
+ super(SelfAttention, self).__init__()
+
+ assert n_embd % n_head == 0
+ self.masked = masked
+ self.n_head = n_head
+ # key, query, value projections for all heads
+ self.key = init_(nn.Linear(n_embd, n_embd))
+ self.query = init_(nn.Linear(n_embd, n_embd))
+ self.value = init_(nn.Linear(n_embd, n_embd))
+ # output projection
+ self.proj = init_(nn.Linear(n_embd, n_embd))
+ # causal mask to ensure that attention is only applied to the left in the input sequence
+ self.register_buffer(
+ "mask",
+ torch.tril(torch.ones(n_agent + 1, n_agent + 1)).view(
+ 1, 1, n_agent + 1, n_agent + 1
+ ),
+ )
+
+ self.att_bp = None
+
+ def forward(self, key, value, query):
+ B, L, D = query.size()
+
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+ k = (
+ self.key(key).view(B, L, self.n_head, D // self.n_head).transpose(1, 2)
+ ) # (B, nh, L, hs)
+ q = (
+ self.query(query).view(B, L, self.n_head, D // self.n_head).transpose(1, 2)
+ ) # (B, nh, L, hs)
+ v = (
+ self.value(value).view(B, L, self.n_head, D // self.n_head).transpose(1, 2)
+ ) # (B, nh, L, hs)
+
+ # causal attention: (B, nh, L, hs) x (B, nh, hs, L) -> (B, nh, L, L)
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+
+ # self.att_bp = F.softmax(att, dim=-1)
+
+ if self.masked:
+ att = att.masked_fill(self.mask[:, :, :L, :L] == 0, float("-inf"))
+ att = F.softmax(att, dim=-1)
+
+ y = att @ v # (B, nh, L, L) x (B, nh, L, hs) -> (B, nh, L, hs)
+ y = (
+ y.transpose(1, 2).contiguous().view(B, L, D)
+ ) # re-assemble all head outputs side by side
+
+ # output projection
+ y = self.proj(y)
+ return y
+
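When `masked=True`, the registered lower-triangular buffer zeroes out attention from any position to later positions. A minimal numpy sketch of that masking step, assuming a single head and no learned projections:

```python
import numpy as np

def causal_attention(q, k, v):
    # scores: (L, L); entry (i, j) is how strongly position i attends to j
    scores = q @ k.T / np.sqrt(k.shape[-1])
    L = q.shape[0]
    mask = np.tril(np.ones((L, L)))  # lower-triangular causal mask
    scores = np.where(mask == 0, -np.inf, scores)  # block j > i
    # softmax over the last axis; exp(-inf) contributes zero weight
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))
out, w = causal_attention(x, x, x)
# Above-diagonal weights are exactly zero: i never attends to j > i.
assert np.allclose(np.triu(w, k=1), 0.0)
# Each row of attention weights still sums to 1.
assert np.allclose(w.sum(axis=-1), 1.0)
```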
+
+class EncodeBlock(nn.Module):
+ """an unassuming Transformer block"""
+
+ def __init__(self, n_embd, n_head, n_agent):
+ super(EncodeBlock, self).__init__()
+
+ self.ln1 = nn.LayerNorm(n_embd)
+ self.ln2 = nn.LayerNorm(n_embd)
+ # self.attn = SelfAttention(n_embd, n_head, n_agent, masked=True)
+ self.attn = SelfAttention(n_embd, n_head, n_agent, masked=False)
+ self.mlp = nn.Sequential(
+ init_(nn.Linear(n_embd, 1 * n_embd), activate=True),
+ nn.GELU(),
+ init_(nn.Linear(1 * n_embd, n_embd)),
+ )
+
+ def forward(self, x):
+ x = self.ln1(x + self.attn(x, x, x))
+ x = self.ln2(x + self.mlp(x))
+ return x
+
+
+class DecodeBlock(nn.Module):
+ """an unassuming Transformer block"""
+
+ def __init__(self, n_embd, n_head, n_agent):
+ super(DecodeBlock, self).__init__()
+
+ self.ln1 = nn.LayerNorm(n_embd)
+ self.ln2 = nn.LayerNorm(n_embd)
+ self.ln3 = nn.LayerNorm(n_embd)
+ self.attn1 = SelfAttention(n_embd, n_head, n_agent, masked=True)
+ self.attn2 = SelfAttention(n_embd, n_head, n_agent, masked=True)
+ self.mlp = nn.Sequential(
+ init_(nn.Linear(n_embd, 1 * n_embd), activate=True),
+ nn.GELU(),
+ init_(nn.Linear(1 * n_embd, n_embd)),
+ )
+
+ def forward(self, x, rep_enc):
+ x = self.ln1(x + self.attn1(x, x, x))
+ x = self.ln2(rep_enc + self.attn2(key=x, value=x, query=rep_enc))
+ x = self.ln3(x + self.mlp(x))
+ return x
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self, state_dim, obs_dim, n_block, n_embd, n_head, n_agent, encode_state
+ ):
+ super(Encoder, self).__init__()
+
+ self.state_dim = state_dim
+ self.obs_dim = obs_dim
+ self.n_embd = n_embd
+ self.n_agent = n_agent
+ self.encode_state = encode_state
+ # self.agent_id_emb = nn.Parameter(torch.zeros(1, n_agent, n_embd))
+
+ self.state_encoder = nn.Sequential(
+ nn.LayerNorm(state_dim),
+ init_(nn.Linear(state_dim, n_embd), activate=True),
+ nn.GELU(),
+ )
+ self.obs_encoder = nn.Sequential(
+ nn.LayerNorm(obs_dim),
+ init_(nn.Linear(obs_dim, n_embd), activate=True),
+ nn.GELU(),
+ )
+
+ self.ln = nn.LayerNorm(n_embd)
+ self.blocks = nn.Sequential(
+ *[EncodeBlock(n_embd, n_head, n_agent) for _ in range(n_block)]
+ )
+ self.head = nn.Sequential(
+ init_(nn.Linear(n_embd, n_embd), activate=True),
+ nn.GELU(),
+ nn.LayerNorm(n_embd),
+ init_(nn.Linear(n_embd, 1)),
+ )
+
+ def forward(self, state, obs):
+ # state: (batch, n_agent, state_dim)
+ # obs: (batch, n_agent, obs_dim)
+ if self.encode_state:
+ state_embeddings = self.state_encoder(state)
+ x = state_embeddings
+ else:
+ obs_embeddings = self.obs_encoder(obs)
+ x = obs_embeddings
+
+ rep = self.blocks(self.ln(x))
+ v_loc = self.head(rep)
+
+ return v_loc, rep
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ obs_dim,
+ action_dim,
+ n_block,
+ n_embd,
+ n_head,
+ n_agent,
+ action_type="Discrete",
+ dec_actor=False,
+ share_actor=False,
+ ):
+ super(Decoder, self).__init__()
+
+ self.action_dim = action_dim
+ self.n_embd = n_embd
+ self.dec_actor = dec_actor
+ self.share_actor = share_actor
+ self.action_type = action_type
+
+ if action_type != "Discrete":
+ log_std = torch.ones(action_dim)
+ # log_std = torch.zeros(action_dim)
+ self.log_std = torch.nn.Parameter(log_std)
+ # self.log_std = torch.nn.Parameter(torch.zeros(action_dim))
+
+ if self.dec_actor:
+ if self.share_actor:
+ print("mac_dec!!!!!")
+ self.mlp = nn.Sequential(
+ nn.LayerNorm(obs_dim),
+ init_(nn.Linear(obs_dim, n_embd), activate=True),
+ nn.GELU(),
+ nn.LayerNorm(n_embd),
+ init_(nn.Linear(n_embd, n_embd), activate=True),
+ nn.GELU(),
+ nn.LayerNorm(n_embd),
+ init_(nn.Linear(n_embd, action_dim)),
+ )
+ else:
+ self.mlp = nn.ModuleList()
+ for n in range(n_agent):
+ actor = nn.Sequential(
+ nn.LayerNorm(obs_dim),
+ init_(nn.Linear(obs_dim, n_embd), activate=True),
+ nn.GELU(),
+ nn.LayerNorm(n_embd),
+ init_(nn.Linear(n_embd, n_embd), activate=True),
+ nn.GELU(),
+ nn.LayerNorm(n_embd),
+ init_(nn.Linear(n_embd, action_dim)),
+ )
+ self.mlp.append(actor)
+ else:
+ # self.agent_id_emb = nn.Parameter(torch.zeros(1, n_agent, n_embd))
+ if action_type == "Discrete":
+ self.action_encoder = nn.Sequential(
+ init_(nn.Linear(action_dim + 1, n_embd, bias=False), activate=True),
+ nn.GELU(),
+ )
+ else:
+ self.action_encoder = nn.Sequential(
+ init_(nn.Linear(action_dim, n_embd), activate=True), nn.GELU()
+ )
+ self.obs_encoder = nn.Sequential(
+ nn.LayerNorm(obs_dim),
+ init_(nn.Linear(obs_dim, n_embd), activate=True),
+ nn.GELU(),
+ )
+ self.ln = nn.LayerNorm(n_embd)
+ self.blocks = nn.Sequential(
+ *[DecodeBlock(n_embd, n_head, n_agent) for _ in range(n_block)]
+ )
+ self.head = nn.Sequential(
+ init_(nn.Linear(n_embd, n_embd), activate=True),
+ nn.GELU(),
+ nn.LayerNorm(n_embd),
+ init_(nn.Linear(n_embd, action_dim)),
+ )
+
+ def zero_std(self, device):
+ if self.action_type != "Discrete":
+ log_std = torch.zeros(self.action_dim).to(device)
+ self.log_std.data = log_std
+
+ # state, action, and return
+ def forward(self, action, obs_rep, obs):
+ # action: (batch, n_agent, action_dim); one-hot for discrete actions
+ # obs_rep: (batch, n_agent, n_embd)
+ if self.dec_actor:
+ if self.share_actor:
+ logit = self.mlp(obs)
+ else:
+ logit = []
+ for n in range(len(self.mlp)):
+ logit_n = self.mlp[n](obs[:, n, :])
+ logit.append(logit_n)
+ logit = torch.stack(logit, dim=1)
+ else:
+ action_embeddings = self.action_encoder(action)
+ x = self.ln(action_embeddings)
+ for block in self.blocks:
+ x = block(x, obs_rep)
+ logit = self.head(x)
+
+ return logit
+
+
+class MultiAgentTransformer(BaseValuePolicyNetwork):
+ def __init__(
+ self,
+ cfg,
+ input_space,
+ action_space,
+ device=torch.device("cpu"),
+ use_half=False,
+ ):
+ assert not use_half, "half precision is not supported by the MAT algorithm"
+ super(MultiAgentTransformer, self).__init__(cfg, device)
+
+ obs_dim = get_policy_obs_space(input_space)[0]
+ critic_obs_dim = get_critic_obs_space(input_space)[0]
+
+ n_agent = cfg.num_agents
+ n_block = cfg.n_block
+ n_embd = cfg.n_embd
+ n_head = cfg.n_head
+ encode_state = cfg.encode_state
+ dec_actor = cfg.dec_actor
+ share_actor = cfg.share_actor
+ if action_space.__class__.__name__ == "Box":
+ self.action_type = "Continuous"
+ action_dim = action_space.shape[0]
+ self.action_num = action_dim
+ else:
+ self.action_type = "Discrete"
+ action_dim = action_space.n
+ self.action_num = 1
+
+ self.n_agent = n_agent
+ self.obs_dim = obs_dim
+ self.critic_obs_dim = critic_obs_dim
+ self.action_dim = action_dim
+ self.tpdv = dict(dtype=torch.float32, device=device)
+ self._use_policy_active_masks = cfg.use_policy_active_masks
+ self.device = device
+
+ # state unused
+ state_dim = 37
+
+ self.encoder = Encoder(
+ state_dim, obs_dim, n_block, n_embd, n_head, n_agent, encode_state
+ )
+ self.decoder = Decoder(
+ obs_dim,
+ action_dim,
+ n_block,
+ n_embd,
+ n_head,
+ n_agent,
+ self.action_type,
+ dec_actor=dec_actor,
+ share_actor=share_actor,
+ )
+ self.to(device)
+
+ def zero_std(self):
+ if self.action_type != "Discrete":
+ self.decoder.zero_std(self.device)
+
+ def eval_actions(
+ self, obs, rnn_states, action, masks, available_actions=None, active_masks=None
+ ):
+ obs = obs.reshape(-1, self.n_agent, self.obs_dim)
+
+ action = action.reshape(-1, self.n_agent, self.action_num)
+
+ if available_actions is not None:
+ available_actions = available_actions.reshape(
+ -1, self.n_agent, self.action_dim
+ )
+
+ # state: (batch, n_agent, state_dim)
+ # obs: (batch, n_agent, obs_dim)
+ # action: (batch, n_agent, 1)
+ # available_actions: (batch, n_agent, act_dim)
+
+ # state unused
+ ori_shape = np.shape(obs)
+ state = np.zeros((*ori_shape[:-1], 37), dtype=np.float32)
+
+ state = check(state).to(**self.tpdv)
+ obs = check(obs).to(**self.tpdv)
+ action = check(action).to(**self.tpdv)
+
+ if available_actions is not None:
+ available_actions = check(available_actions).to(**self.tpdv)
+
+ batch_size = np.shape(state)[0]
+ v_loc, obs_rep = self.encoder(state, obs)
+ if self.action_type == "Discrete":
+ action = action.long()
+ action_log, entropy = discrete_parallel_act(
+ self.decoder,
+ obs_rep,
+ obs,
+ action,
+ batch_size,
+ self.n_agent,
+ self.action_dim,
+ self.tpdv,
+ available_actions,
+ )
+ else:
+ action_log, entropy = continuous_parallel_act(
+ self.decoder,
+ obs_rep,
+ obs,
+ action,
+ batch_size,
+ self.n_agent,
+ self.action_dim,
+ self.tpdv,
+ )
+ action_log = action_log.view(-1, self.action_num)
+ v_loc = v_loc.view(-1, 1)
+ entropy = entropy.view(-1, self.action_num)
+ if self._use_policy_active_masks and active_masks is not None:
+ entropy = (entropy * active_masks).sum() / active_masks.sum()
+ else:
+ entropy = entropy.mean()
+ return action_log, entropy, v_loc
+
+ def get_actions(
+ self,
+ obs,
+ rnn_states_actor=None,
+ masks=None,
+ available_actions=None,
+ deterministic=False,
+ ):
+ obs = obs.reshape(-1, self.n_agent, self.obs_dim)
+ if available_actions is not None:
+ available_actions = available_actions.reshape(
+ -1, self.n_agent, self.action_dim
+ )
+
+ # state unused
+ ori_shape = np.shape(obs)
+ state = np.zeros((*ori_shape[:-1], 37), dtype=np.float32)
+
+ state = check(state).to(**self.tpdv)
+ obs = check(obs).to(**self.tpdv)
+ if available_actions is not None:
+ available_actions = check(available_actions).to(**self.tpdv)
+
+ batch_size = np.shape(obs)[0]
+ v_loc, obs_rep = self.encoder(state, obs)
+ if self.action_type == "Discrete":
+ output_action, output_action_log = discrete_autoregreesive_act(
+ self.decoder,
+ obs_rep,
+ obs,
+ batch_size,
+ self.n_agent,
+ self.action_dim,
+ self.tpdv,
+ available_actions,
+ deterministic,
+ )
+ else:
+ output_action, output_action_log = continuous_autoregreesive_act(
+ self.decoder,
+ obs_rep,
+ obs,
+ batch_size,
+ self.n_agent,
+ self.action_dim,
+ self.tpdv,
+ deterministic,
+ )
+
+ output_action = output_action.reshape(-1, output_action.shape[-1])
+ output_action_log = output_action_log.reshape(-1, output_action_log.shape[-1])
+ return output_action, output_action_log, None
+
+ def get_values(self, critic_obs, rnn_states_critic=None, masks=None):
+ critic_obs = critic_obs.reshape(-1, self.n_agent, self.critic_obs_dim)
+
+ ori_shape = np.shape(critic_obs)
+ state = np.zeros((*ori_shape[:-1], 37), dtype=np.float32)
+
+ state = check(state).to(**self.tpdv)
+ obs = check(critic_obs).to(**self.tpdv)
+ v_tot, obs_rep = self.encoder(state, obs)
+
+ v_tot = v_tot.reshape(-1, v_tot.shape[-1])
+ return v_tot, None
+
+ def get_actor_para(self):
+ return self.parameters()
+
+ def get_critic_para(self):
+ return self.parameters()
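The `discrete_autoregreesive_act` helper imported above decodes agent actions sequentially: agent i's logits condition, through the masked decoder, on the actions already chosen for agents 0..i-1. A toy sketch of that loop, with a hypothetical `toy_decode` standing in for the transformer decoder (greedy selection only):

```python
import numpy as np

def autoregressive_act(decode, n_agent, action_dim):
    # Shifted action buffer, one-hot over action_dim + 1 slots: slot 0 is
    # a start token, mirroring the action_dim + 1 input of action_encoder.
    shifted = np.zeros((n_agent, action_dim + 1))
    shifted[0, 0] = 1.0
    actions = []
    for i in range(n_agent):
        logits = decode(shifted, i)      # sees only choices of agents < i
        a = int(np.argmax(logits))       # greedy / deterministic selection
        actions.append(a)
        if i + 1 < n_agent:
            shifted[i + 1, a + 1] = 1.0  # feed the chosen action forward
    return actions

# Hypothetical stand-in decoder: echoes the previous agent's slot index.
def toy_decode(shifted, i):
    prev_slot = int(np.argmax(shifted[i]))
    logits = np.zeros(3)
    logits[min(prev_slot, 2)] = 1.0
    return logits

assert autoregressive_act(toy_decode, n_agent=3, action_dim=3) == [0, 1, 2]
```

Each agent's chosen action is re-encoded into the shifted buffer before the next agent acts, which is what makes the joint policy autoregressive rather than independent.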
diff --git a/openrl/modules/networks/base_value_policy_network.py b/openrl/modules/networks/base_value_policy_network.py
new file mode 100644
index 00000000..20d6caa2
--- /dev/null
+++ b/openrl/modules/networks/base_value_policy_network.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2022 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+from abc import ABC, abstractmethod
+
+import torch.nn as nn
+
+from openrl.modules.utils.valuenorm import ValueNorm
+
+
+class BaseValuePolicyNetwork(ABC, nn.Module):
+ def __init__(self, cfg, device):
+ super(BaseValuePolicyNetwork, self).__init__()
+ self.device = device
+ self._use_valuenorm = cfg.use_valuenorm
+
+ if self._use_valuenorm:
+ self.value_normalizer = ValueNorm(1, device=self.device)
+ else:
+ self.value_normalizer = None
+
+ def forward(self, forward_type, *args, **kwargs):
+ if forward_type == "original":
+ return self.get_actions(*args, **kwargs)
+ elif forward_type == "eval_actions":
+ return self.eval_actions(*args, **kwargs)
+ elif forward_type == "get_values":
+ return self.get_values(*args, **kwargs)
+ else:
+ raise NotImplementedError
+
+ @abstractmethod
+ def get_actions(self, *args, **kwargs):
+ raise NotImplementedError
+
+ @abstractmethod
+ def eval_actions(self, *args, **kwargs):
+ raise NotImplementedError
+
+ @abstractmethod
+ def get_values(self, *args, **kwargs):
+ raise NotImplementedError
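`forward` above dispatches on a string tag so the single `nn.Module.forward` entry point can serve three roles (acting, evaluating actions, estimating values). The pattern in isolation, as plain Python with toy methods in place of the network heads:

```python
# Toy sketch of the string-tag dispatch used by BaseValuePolicyNetwork.
class Dispatcher:
    def forward(self, forward_type, *args, **kwargs):
        if forward_type == "original":
            return self.get_actions(*args, **kwargs)
        elif forward_type == "eval_actions":
            return self.eval_actions(*args, **kwargs)
        elif forward_type == "get_values":
            return self.get_values(*args, **kwargs)
        raise NotImplementedError(forward_type)

    # Stand-in heads; the real class leaves these abstract.
    def get_actions(self, x):
        return ("act", x)

    def eval_actions(self, x):
        return ("eval", x)

    def get_values(self, x):
        return ("value", x)

d = Dispatcher()
assert d.forward("original", 1) == ("act", 1)
assert d.forward("get_values", 2) == ("value", 2)
```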
diff --git a/openrl/modules/networks/q_network.py b/openrl/modules/networks/q_network.py
index ba0b63bb..ea070e71 100644
--- a/openrl/modules/networks/q_network.py
+++ b/openrl/modules/networks/q_network.py
@@ -18,7 +18,7 @@
import torch
import torch.nn as nn
-from openrl.buffers.utils.util import get_shape_from_obs_space_v2, get_critic_obs_space
+from openrl.buffers.utils.util import get_critic_obs_space, get_shape_from_obs_space_v2
from openrl.modules.networks.base_value_network import BaseValueNetwork
from openrl.modules.networks.utils.cnn import CNNBase
from openrl.modules.networks.utils.mix import MIXBase
diff --git a/openrl/modules/ppo_module.py b/openrl/modules/ppo_module.py
index 5d12ef70..7af9912c 100644
--- a/openrl/modules/ppo_module.py
+++ b/openrl/modules/ppo_module.py
@@ -130,7 +130,6 @@ def get_actions(
values, rnn_states_critic = self.models["critic"](
critic_obs, rnn_states_critic, masks
)
-
return values, actions, action_log_probs, rnn_states_actor, rnn_states_critic
def get_values(self, critic_obs, rnn_states_critic, masks):
diff --git a/openrl/runners/common/__init__.py b/openrl/runners/common/__init__.py
index 5c34d65a..eb32123b 100644
--- a/openrl/runners/common/__init__.py
+++ b/openrl/runners/common/__init__.py
@@ -1,5 +1,6 @@
from openrl.runners.common.chat_agent import Chat6BAgent, ChatAgent
from openrl.runners.common.dqn_agent import DQNAgent
+from openrl.runners.common.mat_agent import MATAgent
from openrl.runners.common.ppo_agent import PPOAgent
__all__ = [
@@ -7,4 +8,5 @@
"ChatAgent",
"Chat6BAgent",
"DQNAgent",
+ "MATAgent",
]
diff --git a/openrl/runners/common/mat_agent.py b/openrl/runners/common/mat_agent.py
new file mode 100644
index 00000000..e296dac5
--- /dev/null
+++ b/openrl/runners/common/mat_agent.py
@@ -0,0 +1,48 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+from typing import Type
+
+from openrl.algorithms.base_algorithm import BaseAlgorithm
+from openrl.algorithms.mat import MATAlgorithm
+from openrl.runners.common.base_agent import SelfAgent
+from openrl.runners.common.ppo_agent import PPOAgent
+from openrl.utils.logger import Logger
+
+
+class MATAgent(PPOAgent):
+ def train(
+ self: SelfAgent,
+ total_time_steps: int,
+ train_algo_class: Type[BaseAlgorithm] = MATAlgorithm,
+ ) -> None:
+ logger = Logger(
+ cfg=self._cfg,
+ project_name="MATAgent",
+ scenario_name=self._env.env_name,
+ wandb_entity=self._cfg.wandb_entity,
+ exp_name=self.exp_name,
+ log_path=self.run_dir,
+ use_wandb=self._use_wandb,
+ use_tensorboard=self._use_tensorboard,
+ )
+
+ super(MATAgent, self).train(
+ total_time_steps=total_time_steps,
+ train_algo_class=train_algo_class,
+ logger=logger,
+ )
diff --git a/openrl/runners/common/ppo_agent.py b/openrl/runners/common/ppo_agent.py
index 2b8993af..49aa3665 100644
--- a/openrl/runners/common/ppo_agent.py
+++ b/openrl/runners/common/ppo_agent.py
@@ -15,13 +15,14 @@
# limitations under the License.
""""""
-from typing import Dict, Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Type, Union
import gym
import numpy as np
import torch
-from openrl.algorithms.ppo import PPOAlgorithm as TrainAlgo
+from openrl.algorithms.base_algorithm import BaseAlgorithm
+from openrl.algorithms.ppo import PPOAlgorithm
from openrl.buffers import NormalReplayBuffer as ReplayBuffer
from openrl.buffers.utils.obs_data import ObsData
from openrl.drivers.onpolicy_driver import OnPolicyDriver as Driver
@@ -48,7 +49,12 @@ def __init__(
net, env, run_dir, env_num, rank, world_size, use_wandb, use_tensorboard
)
- def train(self: SelfAgent, total_time_steps: int) -> None:
+ def train(
+ self: SelfAgent,
+ total_time_steps: int,
+ train_algo_class: Type[BaseAlgorithm] = PPOAlgorithm,
+ logger: Optional[Logger] = None,
+ ) -> None:
self._cfg.num_env_steps = total_time_steps
self.config = {
@@ -59,7 +65,7 @@ def train(self: SelfAgent, total_time_steps: int) -> None:
"device": self.net.device,
}
- trainer = TrainAlgo(
+ trainer = train_algo_class(
cfg=self._cfg,
init_module=self.net.module,
device=self.net.device,
@@ -73,17 +79,17 @@ def train(self: SelfAgent, total_time_steps: int) -> None:
self._env.action_space,
data_client=None,
)
-
- logger = Logger(
- cfg=self._cfg,
- project_name="PPOAgent",
- scenario_name=self._env.env_name,
- wandb_entity=self._cfg.wandb_entity,
- exp_name=self.exp_name,
- log_path=self.run_dir,
- use_wandb=self._use_wandb,
- use_tensorboard=self._use_tensorboard,
- )
+ if logger is None:
+ logger = Logger(
+ cfg=self._cfg,
+ project_name="PPOAgent",
+ scenario_name=self._env.env_name,
+ wandb_entity=self._cfg.wandb_entity,
+ exp_name=self.exp_name,
+ log_path=self.run_dir,
+ use_wandb=self._use_wandb,
+ use_tensorboard=self._use_tensorboard,
+ )
driver = Driver(
config=self.config,
trainer=trainer,
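The refactor above replaces the hard-coded `TrainAlgo` with a `train_algo_class` parameter, so subclasses such as `MATAgent` can reuse `PPOAgent.train` while swapping in a different algorithm. The pattern in isolation, with toy classes rather than the OpenRL signatures:

```python
from typing import Type

# Toy stand-ins for BaseAlgorithm / PPOAlgorithm / MATAlgorithm.
class BaseAlgo:
    name = "base"

class PPOAlgo(BaseAlgo):
    name = "ppo"

class MATAlgo(BaseAlgo):
    name = "mat"

class Agent:
    def train(self, train_algo_class: Type[BaseAlgo] = PPOAlgo) -> str:
        trainer = train_algo_class()  # subclass chooses the algorithm
        return trainer.name           # stands in for the training loop

class MATAgentSketch(Agent):
    def train(self) -> str:
        # Reuse the parent's training flow; only the algorithm changes.
        return super().train(train_algo_class=MATAlgo)

assert Agent().train() == "ppo"
assert MATAgentSketch().train() == "mat"
```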
diff --git a/tests/test_algorithm/test_mat_algorithm.py b/tests/test_algorithm/test_mat_algorithm.py
new file mode 100644
index 00000000..4e4d9a98
--- /dev/null
+++ b/tests/test_algorithm/test_mat_algorithm.py
@@ -0,0 +1,82 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+import os
+import sys
+
+import numpy as np
+import pytest
+from gymnasium import spaces
+
+
+@pytest.fixture
+def obs_space():
+ return spaces.Box(low=-np.inf, high=+np.inf, shape=(1,), dtype=np.float32)
+
+
+@pytest.fixture
+def act_space():
+ return spaces.Discrete(2)
+
+
+@pytest.fixture(scope="module", params=[""])
+def config(request):
+ from openrl.configs.config import create_config_parser
+
+ cfg_parser = create_config_parser()
+ cfg = cfg_parser.parse_args(request.param.split())
+ return cfg
+
+
+@pytest.fixture
+def init_module(config, obs_space, act_space):
+ from openrl.modules.ppo_module import PPOModule
+
+ module = PPOModule(
+ config,
+ policy_input_space=obs_space,
+ critic_input_space=obs_space,
+ act_space=act_space,
+ )
+ return module
+
+
+@pytest.fixture
+def buffer_data(config, obs_space, act_space):
+ from openrl.buffers.normal_buffer import NormalReplayBuffer
+
+ buffer = NormalReplayBuffer(
+ config,
+ num_agents=1,
+ obs_space=obs_space,
+ act_space=act_space,
+ data_client=None,
+ episode_length=100,
+ )
+ return buffer.data
+
+
+@pytest.mark.unittest
+def test_mat_algorithm(config, init_module, buffer_data):
+ from openrl.algorithms.mat import MATAlgorithm
+
+ mat_algo = MATAlgorithm(config, init_module)
+ mat_algo.train(buffer_data)
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))