diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9fb1d012..64de484f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,51 +1,45 @@ -## 如何参与OpenRL的建设 +## How to Contribute to OpenRL -OpenRL社区欢迎任何人参与到OpenRL的建设中来,无论您是开发者还是用户,您的反馈和贡献都是我们前进的动力! -您可以通过以下方式加入到OpenRL的贡献中来: +The OpenRL community welcomes anyone to contribute to the development of OpenRL, whether you are a developer or a user. Your feedback and contributions are our driving force! You can contribute to OpenRL in the following ways: -- 作为OpenRL的用户,发现OpenRL中的bug,并提交[issue](https://github.com/OpenRL-Lab/openrl/issues/new/choose)。 -- 作为OpenRL的用户,发现OpenRL文档中的错误,并提交[issue](https://github.com/OpenRL-Lab/openrl/issues/new/choose)。 -- 写测试代码,提升OpenRL的代码测试覆盖率(大家可以从[这里](https://app.codecov.io/gh/OpenRL-Lab/openrl)查到OpenRL的代码测试覆盖情况)。 - 您可以选择感兴趣的代码片段进行编写代码测试, -- 作为OpenRL的开发者,为OpenRL修复已有的bug。 -- 作为OpenRL的开发者,为OpenRL添加新的环境和样例。 -- 作为OpenRL的开发者,为OpenRL添加新的算法。 +- As an OpenRL user, report bugs you find in OpenRL by submitting an [issue](https://github.com/OpenRL-Lab/openrl/issues/new/choose). +- As an OpenRL user, report errors you find in the OpenRL documentation by submitting an [issue](https://github.com/OpenRL-Lab/openrl/issues/new/choose). +- Write test code to improve the code coverage of OpenRL (you can check OpenRL's current code coverage [here](https://app.codecov.io/gh/OpenRL-Lab/openrl)). You can pick whichever parts of the code interest you and write tests for them. +- As an OpenRL developer, fix existing bugs in OpenRL. +- As an OpenRL developer, add new environments and examples to OpenRL. +- As an OpenRL developer, add new algorithms to OpenRL. -## 贡献者手册 +## Contributors' Guide -欢迎更多的人参与到OpenRL的开发中来,我们非常欢迎您的贡献! +We welcome more people to join the development of OpenRL and greatly appreciate your contributions! 
-- 如果您想要贡献新的功能,请先在请先创建一个新的[issue](https://github.com/OpenRL-Lab/openrl/issues/new/choose), - 以便我们讨论这个功能的实现细节。如果该功能得到了大家的认可,您可以开始进行代码实现。 -- 您也可以在 [Issues](https://github.com/OpenRL-Lab/openrl/issues) 中查看未被实现的功能和仍然存的在bug, -在对应的issue中进行回复,说明您想要解决该issue,然后开始进行代码实现。 +- If you want to contribute a new feature, please create a new [issue](https://github.com/OpenRL-Lab/openrl/issues/new/choose) first +to discuss its implementation details. Once the feature is approved, you can start implementing it. +- You can also check for unimplemented features and existing bugs in [Issues](https://github.com/OpenRL-Lab/openrl/issues), +reply in the corresponding issue to say that you would like to work on it, and then start implementing the code. -在您完成了代码实现之后,您需要拉取最新的`main`分支并进行合并。 -解决合并冲突后, -您可以通过提交 [Pull Request](https://github.com/OpenRL-Lab/openrl/pulls) -的方式将您的代码合并到OpenRL的main分支中。 +After completing your implementation, you need to pull the latest `main` branch and merge it into your branch. +After resolving any merge conflicts, +you can submit a [Pull Request](https://github.com/OpenRL-Lab/openrl/pulls) to merge your code into OpenRL's main branch. -在提交Pull Request前,您需要完成 [代码测试和代码格式化](#代码测试和代码格式化)。 +Before submitting a Pull Request, you need to complete [Code Testing and Code Formatting](#code-testing-and-code-formatting). -然后,您的Pull Request需要通过GitHub上的自动化测试。 +Then, your Pull Request needs to pass the automated tests on GitHub. -最后,需要得到至少一个开发人员的review和批准,才能被合并到main分支中。 +Finally, your code must be reviewed and approved by at least one maintainer before it can be merged into the main branch. -## 代码测试和代码格式化 +## Code Testing and Code Formatting -在您提交Pull Request之前,您需要确保您的代码通过了单元测试,并且符合OpenRL的代码风格。 +Before submitting a Pull Request, make sure that your code passes the unit tests and conforms to OpenRL's coding style. 
-首先,您需要安装测试相关的包:`pip install -e ".[test]"` +First, install the test-related packages: `pip install -e ".[test]"` -然后,您需要确保单元测试通过,这可以通过执行`make test`来完成。 - -然后,您需要执行`make format`来格式化您的代码。 - -最后,您需要执行`make commit-checks`来检查您的代码是否符合OpenRL的代码风格。 - -> 小技巧: OpenRL使用 [black](https://github.com/psf/black) 代码风格。 -您可以在您的编辑器中安装black的[插件](https://black.readthedocs.io/en/stable/integrations/editors.html), -来帮助您自动格式化代码。 +Then, ensure that the unit tests pass by executing `make test`. +Next, format your code by running `make format`. +Lastly, run `make commit-checks` to check whether your code complies with OpenRL's coding style. +> Tip: OpenRL uses the [black](https://github.com/psf/black) code style. +You can install a black [editor plugin](https://black.readthedocs.io/en/stable/integrations/editors.html) +to help format your code automatically. \ No newline at end of file diff --git a/Gallery.md b/Gallery.md index 37a653a7..bf8d22ef 100644 --- a/Gallery.md +++ b/Gallery.md @@ -6,6 +6,7 @@ ![sparse](https://img.shields.io/badge/-sparse%20reward-orange) ![offline](https://img.shields.io/badge/-offlineRL-darkblue) ![selfplay](https://img.shields.io/badge/-selfplay-blue) +![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) ![discrete](https://img.shields.io/badge/-discrete-brightgreen) (Discrete Action Space) @@ -15,6 +16,17 @@ ![IL](https://img.shields.io/badge/-IL/SL-purple) (Imitation Learning or Supervised Learning) +## Algorithm List + +
+ +| Algorithm | Tags | Refs | +|:-----------------------------------------:|:-----------------------------------------------------------------------------------------------------------------------:|:-------------------------------:| +| [PPO](https://arxiv.org/abs/1707.06347) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/cartpole/) | +| [MAPPO](https://arxiv.org/abs/2103.01955) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [code](./examples/mpe/) | +| [JRPO](https://arxiv.org/abs/2302.07515) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [code](./examples/mpe/) | +| [MAT](https://arxiv.org/abs/2205.14953) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [code](./examples/mpe/) | +
## Demo List diff --git a/README.md b/README.md index cf4e453a..847470ab 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@
- +
--- @@ -12,7 +12,7 @@ [![Hits-of-Code](https://hitsofcode.com/github/OpenRL-Lab/openrl?branch=main)](https://hitsofcode.com/github/OpenRL-Lab/openrl/view?branch=main) -[![codecov](https://codecov.io/gh/OpenRL-Lab/openrl/branch/main/graph/badge.svg?token=T6BqaTiT0l)](https://codecov.io/gh/OpenRL-Lab/openrl) +[![codecov](https://codecov.io/gh/OpenRL-Lab/openrl_release/branch/main/graph/badge.svg?token=4FMEYMR83U)](https://codecov.io/gh/OpenRL-Lab/openrl_release) [![Documentation Status](https://readthedocs.org/projects/openrl-docs/badge/?version=latest)](https://openrl-docs.readthedocs.io/zh/latest/?badge=latest) [![Read the Docs](https://img.shields.io/readthedocs/openrl-docs-zh?label=%E4%B8%AD%E6%96%87%E6%96%87%E6%A1%A3)](https://openrl-docs.readthedocs.io/zh/latest/) @@ -33,201 +33,235 @@ OpenRL-v0.0.11 is updated on May 19, 2023 The main branch is the latest version of OpenRL, which is under active development. If you just want to have a try with OpenRL, you can switch to the stable branch. 
-## 欢迎来到OpenRL - -[English](./README_en.md) | [中文文档](https://openrl-docs.readthedocs.io/zh/latest/) | [Documentation](https://openrl-docs.readthedocs.io/en/latest/) - -OpenRL是一个开源的通用强化学习研究框架,支持单智能体、多智能体、自然语言等多种任务的训练。 OpenRL基于PyTorch进行开发,目标是为强化学习研究社区提供一个简单易用、灵活高效、可持续扩展的平台。 -目前,OpenRL支持的特性包括: - -- 简单易用且支持单智能体、多智能体训练的通用接口 -- 支持自然语言任务(如对话任务)的强化学习训练 -- 支持从[Hugging Face](https://huggingface.co/)上导入模型和数据 -- 支持LSTM,GRU,Transformer等模型 -- 支持多种训练加速,例如:自动混合精度训练,半精度策略网络收集数据等 -- 支持用户自定义训练模型、奖励模型、训练数据以及环境 -- 支持[gymnasium](https://gymnasium.farama.org/)环境 -- 支持字典观测空间 -- 支持[wandb](https://wandb.ai/),[tensorboardX](https://tensorboardx.readthedocs.io/en/latest/index.html)等主流训练可视化工具 -- 支持环境的串行和并行训练,同时保证两种模式下的训练效果一致 -- 中英文文档 -- 提供单元测试和代码覆盖测试 -- 符合Black Code Style和类型检查 - -该框架经过了[OpenRL-Lab](https://github.com/OpenRL-Lab)的多次迭代并应用于学术研究,目前已经成为了一个成熟的强化学习框架。 -OpenRL-Lab将持续维护和更新OpenRL,欢迎大家加入我们的[开源社区](./CONTRIBUTING.md),一起为强化学习的发展做出贡献。 -关于OpenRL的更多信息,请参考[文档](https://openrl-docs.readthedocs.io/zh/latest/)。 - -## 目录 - -- [欢迎来到OpenRL](#欢迎来到openrl) -- [目录](#目录) -- [安装](#安装) -- [使用Docker](#使用docker) -- [快速上手](#快速上手) -- [Gallery](#Gallery) -- [使用OpenRL的项目](#使用OpenRL的项目) -- [反馈和贡献](#反馈和贡献) -- [维护人员](#维护人员) -- [支持者](#支持者) - - [↳ Contributors](#-contributors) +## Welcome to OpenRL + +[中文介绍](README_zh.md) | [Documentation](https://openrl-docs.readthedocs.io/en/latest/) | [中文文档](https://openrl-docs.readthedocs.io/zh/latest/) + +OpenRL is an open-source general reinforcement learning research framework that supports training for various tasks +such as single-agent, multi-agent, and natural language. +Developed based on PyTorch, the goal of OpenRL is to provide a simple-to-use, flexible, efficient and sustainable platform for the reinforcement learning research community. 
+ +Currently, the features supported by OpenRL include: + +- A simple-to-use universal interface that supports both single-agent and multi-agent training + +- Reinforcement learning training support for natural language tasks (such as dialogue) + +- Importing models and datasets from [Hugging Face](https://huggingface.co/) + +- Support for models such as LSTM, GRU, and Transformer + +- Multiple training acceleration methods, including automatic mixed precision training and data collection with a half-precision policy network + +- Support for user-defined training models, reward models, training data, and environments + +- Support for [gymnasium](https://gymnasium.farama.org/) environments + +- Dictionary observation space support + +- Support for popular visualization tools such as [wandb](https://wandb.ai/) and [tensorboardX](https://tensorboardx.readthedocs.io/en/latest/index.html) + +- Serial or parallel environment training while ensuring consistent results in both modes + +- Chinese and English documentation + +- Unit tests and code coverage testing + +- Compliant with Black Code Style guidelines and type checking + +Algorithms currently supported by OpenRL (for more details, please refer to [Gallery](./Gallery.md)): +- [Proximal Policy Optimization (PPO)](https://arxiv.org/abs/1707.06347) +- [Multi-agent PPO (MAPPO)](https://arxiv.org/abs/2103.01955) +- [Joint-ratio Policy Optimization (JRPO)](https://arxiv.org/abs/2302.07515) +- [Multi-Agent Transformer (MAT)](https://arxiv.org/abs/2205.14953) + +Environments currently supported by OpenRL (for more details, please refer to [Gallery](./Gallery.md)): +- [Gymnasium](https://gymnasium.farama.org/) +- [MPE](https://github.com/openai/multiagent-particle-envs) +- [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros) +- [Gym Retro](https://github.com/openai/retro) + +This framework has undergone multiple iterations by the [OpenRL-Lab](https://github.com/OpenRL-Lab) team, which has applied it in 
academic research. +It has now become a mature reinforcement learning framework. + +OpenRL-Lab will continue to maintain and update OpenRL, and we welcome everyone to join our [open-source community](./CONTRIBUTING.md) +to contribute towards the development of reinforcement learning. + +For more information about OpenRL, please refer to the [documentation](https://openrl-docs.readthedocs.io/en/latest/). + +## Outline + +- [Welcome to OpenRL](#welcome-to-openrl) +- [Outline](#outline) +- [Installation](#installation) +- [Use Docker](#use-docker) +- [Quick Start](#quick-start) +- [Gallery](#gallery) +- [Projects Using OpenRL](#projects-using-openrl) +- [Feedback and Contribution](#feedback-and-contribution) +- [Maintainers](#maintainers) +- [Supporters](#supporters) + - [↳ Contributors](#-contributors) - [↳ Stargazers](#-stargazers) - [↳ Forkers](#-forkers) - [Citing OpenRL](#citing-openrl) - [License](#license) - [Acknowledgments](#acknowledgments) -## 安装 +## Installation + +Users can directly install OpenRL via pip: -用户可以直接通过pip安装OpenRL: ```bash pip install openrl ``` -如果用户使用了Anaconda或者Miniconda,也可以通过conda安装OpenRL: +If users are using Anaconda or Miniconda, they can also install OpenRL via conda: + ```bash conda install -c openrl openrl ``` -想要修改源码的用户也可以从源码安装OpenRL: +Users who want to modify the source code can also install OpenRL from the source code: + ```bash git clone https://github.com/OpenRL-Lab/openrl.git && cd openrl pip install -e . 
``` -安装完成后,用户可以直接通过命令行查看OpenRL的版本: +After installation, users can check the version of OpenRL through command line: + ```bash openrl --version ``` -**Tips**:无需安装,通过Colab在线试用OpenRL: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/15VBA-B7AJF8dBazzRcWAxJxZI7Pl9m-g?usp=sharing) - -## 使用Docker +**Tips**: No installation required, try OpenRL online through Colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/15VBA-B7AJF8dBazzRcWAxJxZI7Pl9m-g?usp=sharing) -OpenRL目前也提供了包含显卡支持和非显卡支持的Docker镜像。 -如果用户的电脑上没有英伟达显卡,则可以通过以下命令获取不包含显卡插件的镜像: +## Use Docker +OpenRL currently provides Docker images with and without GPU support. +If the user's computer does not have an NVIDIA GPU, they can obtain an image without the GPU plugin using the following command: ```bash sudo docker pull openrllab/openrl-cpu ``` -如果用户想要通过显卡加速训练,则可以通过以下命令获取: +If the user wants to accelerate training with a GPU, they can obtain it using the following command: ```bash sudo docker pull openrllab/openrl ``` -镜像拉取成功后,用户可以通过以下命令运行OpenRL的Docker镜像: +After successfully pulling the image, users can run OpenRL's Docker image using the following commands: ```bash -# 不带显卡加速 +# Without GPU acceleration sudo docker run -it openrllab/openrl-cpu -# 带显卡加速 +# With GPU acceleration sudo docker run -it --gpus all --net host openrllab/openrl ``` -进入Docker镜像后,用户可以通过以下命令查看OpenRL的版本然后运行测例: -```bash -# 查看Docker镜像中OpenRL的版本 -openrl --version -# 运行测例 -openrl --mode train --env CartPole-v1 +Once inside the Docker container, users can check OpenRL's version and then run test cases using these commands: +```bash +# Check OpenRL version in Docker container +openrl --version +# Run test case +openrl --mode train --env CartPole-v1 ``` +## Quick Start -## 快速上手 - -OpenRL为强化学习入门用户提供了简单易用的接口, -下面是一个使用PPO算法训练`CartPole`环境的例子: +OpenRL provides a simple and easy-to-use interface for beginners in reinforcement 
learning. +Below is an example of using the PPO algorithm to train the `CartPole` environment: ```python # train_ppo.py from openrl.envs.common import make from openrl.modules.common import PPONet as Net from openrl.runners.common import PPOAgent as Agent -env = make("CartPole-v1", env_num=9) # 创建环境,并设置环境并行数为9 -net = Net(env) # 创建神经网络 -agent = Agent(net) # 初始化智能体 -agent.train(total_time_steps=20000) # 开始训练,并设置环境运行总步数为20000 +env = make("CartPole-v1", env_num=9) # Create an environment and set the environment parallelism to 9. +net = Net(env) # Create neural network. +agent = Agent(net) # Initialize the agent. +agent.train(total_time_steps=20000) # Start training and set the total number of steps to 20,000 for the running environment. ``` -使用OpenRL训练智能体只需要简单的四步:**创建环境**=>**初始化模型**=>**初始化智能体**=>**开始训练**! +Training an agent using OpenRL only requires four simple steps: +**Create Environment** => **Initialize Model** => **Initialize Agent** => **Start Training**! -对于训练好的智能体,用户也可以方便地进行智能体的测试: +For a well-trained agent, users can also easily test the agent: ```python # train_ppo.py from openrl.envs.common import make from openrl.modules.common import PPONet as Net from openrl.runners.common import PPOAgent as Agent -agent = Agent(Net(make("CartPole-v1", env_num=9))) # 初始化训练器 +agent = Agent(Net(make("CartPole-v1", env_num=9))) # Initialize trainer. agent.train(total_time_steps=20000) -# 创建用于测试的环境,并设置环境并行数为9,设置渲染模式为group_human +# Create an environment for test, set the parallelism of the environment to 9, and set the rendering mode to group_human. env = make("CartPole-v1", env_num=9, render_mode="group_human") -agent.set_env(env) # 训练好的智能体设置需要交互的环境 -obs, info = env.reset() # 环境进行初始化,得到初始的观测值和环境信息 +agent.set_env(env) # The agent requires an interactive environment. +obs, info = env.reset() # Initialize the environment to obtain initial observations and environmental information. 
while True: - action, _ = agent.act(obs) # 智能体根据环境观测输入预测下一个动作 - # 环境根据动作执行一步,得到下一个观测值、奖励、是否结束、环境信息 + action, _ = agent.act(obs) # The agent predicts the next action based on environmental observations. + # The environment takes one step according to the action and returns the next observation, reward, done flag, and environment info. obs, r, done, info = env.step(action) if any(done): break -env.close() # 关闭测试环境 +env.close() # Close the test environment. ``` -在普通笔记本电脑上执行以上代码,只需要几秒钟,便可以完成该智能体的训练和可视化测试: +Executing the above code on an ordinary laptop takes only a few seconds +to complete the training. The visualization of the trained agent is shown below:
- +
-**Tips:** 用户还可以在终端中通过执行一行命令快速训练`CartPole`环境: +**Tips:** Users can also quickly train the `CartPole` environment by executing a single command in the terminal: ```bash openrl --mode train --env CartPole-v1 ``` -对于多智能体、自然语言等任务的训练,OpenRL也提供了同样简单易用的接口。 +For training tasks such as multi-agent and natural language processing, OpenRL also provides a similarly simple and easy-to-use interface. -关于如何进行多智能体训练、训练超参数设置、训练配置文件加载、wandb使用、保存gif动画等信息,请参考: -- [多智能体训练例子](https://openrl-docs.readthedocs.io/zh/latest/quick_start/multi_agent_RL.html) +For information on how to perform multi-agent training, set hyperparameters for training, load training configurations, use wandb, save GIF animations, etc., please refer to: +- [Multi-Agent Training Example](https://openrl-docs.readthedocs.io/en/latest/quick_start/multi_agent_RL.html) -关于自然语言任务训练、Hugging Face上模型(数据)加载、自定义训练模型(奖励模型)等信息,请参考: -- [对话任务训练例子](https://openrl-docs.readthedocs.io/zh/latest/quick_start/train_nlp.html) +For information on natural language task training, loading models/datasets on Hugging Face, customizing training models/reward models, etc., please refer to: +- [Dialogue Task Training Example](https://openrl-docs.readthedocs.io/en/latest/quick_start/train_nlp.html) -关于OpenRL的更多信息,请参考[文档](https://openrl-docs.readthedocs.io/zh/latest/)。 +For more information about OpenRL, please refer to the [documentation](https://openrl-docs.readthedocs.io/en/latest/). ## Gallery -为了方便用户熟悉该框架, -我们在[Gallery](./Gallery.md)中提供了更多使用OpenRL的示例和demo。 -也欢迎用户将自己的训练示例和demo贡献到Gallery中。 +To help users become familiar with the framework, we provide more examples and demos of using OpenRL in [Gallery](./Gallery.md). +Users are also welcome to contribute their own training examples and demos to the Gallery. 
-## 使用OpenRL的研究项目 +## Projects Using OpenRL -我们在 [OpenRL Project](./Project.md) 中列举了使用OpenRL的研究项目。 -如果你在研究项目中使用了OpenRL,也欢迎加入该列表。 - -## 反馈和贡献 -- 有问题和发现bugs可以到 [Issues](https://github.com/OpenRL-Lab/openrl/issues) 处进行查询或提问 -- 加入QQ群:[OpenRL官方交流群](./docs/images/qq.png) +We have listed research projects that use OpenRL in the [OpenRL Project](./Project.md). +If you are using OpenRL in your research project, you are also welcome to join this list. +## Feedback and Contribution +- If you have any questions or find bugs, you can check or ask in the [Issues](https://github.com/OpenRL-Lab/openrl/issues). +- Join the QQ group: [OpenRL Official Communication Group](docs/images/qq.png)
- +
-- 加入 [slack](https://join.slack.com/t/openrlhq/shared_invite/zt-1tqwpvthd-Eeh0IxQ~DIaGqYXoW2IUQg) 群组,与我们一起讨论OpenRL的使用和开发。 -- 加入 [Discord](https://discord.gg/tyy96TGbep) 群组,与我们一起讨论OpenRL的使用和开发。 -- 发送邮件到: [huangshiyu@4paradigm.com](huangshiyu@4paradigm.com) -- 加入 [GitHub Discussion](https://github.com/orgs/OpenRL-Lab/discussions) +- Join the [slack](https://join.slack.com/t/openrlhq/shared_invite/zt-1tqwpvthd-Eeh0IxQ~DIaGqYXoW2IUQg) group to discuss OpenRL usage and development with us. +- Join the [Discord](https://discord.gg/tyy96TGbep) group to discuss OpenRL usage and development with us. +- Send an E-mail to: [huangshiyu@4paradigm.com](huangshiyu@4paradigm.com) +- Join the [GitHub Discussion](https://github.com/orgs/OpenRL-Lab/discussions). -OpenRL框架目前还在持续开发和文档建设,欢迎加入我们让该项目变得更好: +The OpenRL framework is still being actively developed and documented. +We welcome you to join us in making this project better: -- 如何贡献代码:阅读 [贡献者手册](./CONTRIBUTING.md) -- [OpenRL开发计划](https://github.com/OpenRL-Lab/openrl/issues/2) +- How to contribute code: Read the [Contributors' Guide](./CONTRIBUTING.md) +- [OpenRL Roadmap](https://github.com/OpenRL-Lab/openrl/issues/2) -## 维护人员 +## Maintainers -目前,OpenRL由以下维护人员维护: +At present, OpenRL is maintained by the following people: - [Shiyu Huang](https://huangshiyu13.github.io/)([@huangshiyu13](https://github.com/huangshiyu13)) - Wenze Chen([@Chen001117](https://github.com/Chen001117)) -欢迎更多的贡献者加入我们的维护团队 (发送邮件到[huangshiyu@4paradigm.com](huangshiyu@4paradigm.com)申请加入OpenRL团队)。 +We welcome more contributors to join our maintenance team (send an E-mail to [huangshiyu@4paradigm.com](huangshiyu@4paradigm.com) +to apply for joining the OpenRL team). 
-## 支持者 +## Supporters ### ↳ Contributors @@ -245,7 +279,7 @@ OpenRL框架目前还在持续开发和文档建设,欢迎加入我们让该 ## Citing OpenRL -如果我们的工作对你有帮助,欢迎引用我们: +If our work has been helpful to you, please feel free to cite us: ```latex @misc{openrl2023, title={OpenRL}, diff --git a/README_en.md b/README_en.md deleted file mode 100644 index 4ce8a1b7..00000000 --- a/README_en.md +++ /dev/null @@ -1,297 +0,0 @@ -
- -
- ---- -[![PyPI](https://img.shields.io/pypi/v/openrl)](https://pypi.org/project/openrl/) -![PyPI - Python Version](https://img.shields.io/pypi/pyversions/openrl) -[![Anaconda-Server Badge](https://anaconda.org/openrl/openrl/badges/version.svg)](https://anaconda.org/openrl/openrl) -[![Anaconda-Server Badge](https://anaconda.org/openrl/openrl/badges/latest_release_date.svg)](https://anaconda.org/openrl/openrl) -[![Anaconda-Server Badge](https://anaconda.org/openrl/openrl/badges/downloads.svg)](https://anaconda.org/openrl/openrl) -[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) - - -[![Hits-of-Code](https://hitsofcode.com/github/OpenRL-Lab/openrl?branch=main)](https://hitsofcode.com/github/OpenRL-Lab/openrl/view?branch=main) -[![codecov](https://codecov.io/gh/OpenRL-Lab/openrl_release/branch/main/graph/badge.svg?token=4FMEYMR83U)](https://codecov.io/gh/OpenRL-Lab/openrl_release) - -[![Documentation Status](https://readthedocs.org/projects/openrl-docs/badge/?version=latest)](https://openrl-docs.readthedocs.io/zh/latest/?badge=latest) -[![Read the Docs](https://img.shields.io/readthedocs/openrl-docs-zh?label=%E4%B8%AD%E6%96%87%E6%96%87%E6%A1%A3)](https://openrl-docs.readthedocs.io/zh/latest/) - -![GitHub Org's stars](https://img.shields.io/github/stars/OpenRL-Lab) -[![GitHub stars](https://img.shields.io/github/stars/OpenRL-Lab/openrl)](https://github.com/opendilab/OpenRL/stargazers) -[![GitHub forks](https://img.shields.io/github/forks/OpenRL-Lab/openrl)](https://github.com/OpenRL-Lab/openrl/network) -![GitHub commit activity](https://img.shields.io/github/commit-activity/m/OpenRL-Lab/openrl) -[![GitHub issues](https://img.shields.io/github/issues/OpenRL-Lab/openrl)](https://github.com/OpenRL-Lab/openrl/issues) -[![GitHub pulls](https://img.shields.io/github/issues-pr/OpenRL-Lab/openrl)](https://github.com/OpenRL-Lab/openrl/pulls) 
-[![Contributors](https://img.shields.io/github/contributors/OpenRL-Lab/openrl)](https://github.com/OpenRL-Lab/openrl/graphs/contributors) -[![GitHub license](https://img.shields.io/github/license/OpenRL-Lab/openrl)](https://github.com/OpenRL-Lab/openrl/blob/master/LICENSE) - -[![Embark](https://img.shields.io/badge/discord-OpenRL-%237289da.svg?logo=discord)](https://discord.gg/tyy96TGbep) -[![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&)](https://join.slack.com/t/openrlhq/shared_invite/zt-1tqwpvthd-Eeh0IxQ~DIaGqYXoW2IUQg) - -OpenRL-v0.0.11 is updated on May 19, 2023 - -The main branch is the latest version of OpenRL, which is under active development. If you just want to have a try with OpenRL, you can switch to the stable branch. - -## Welcome to OpenRL - -[中文介绍](./README.md) | [Documentation](https://openrl-docs.readthedocs.io/en/latest/) | [中文文档](https://openrl-docs.readthedocs.io/zh/latest/) - -OpenRL is an open-source general reinforcement learning research framework that supports training for various tasks -such as single-agent, multi-agent, and natural language. -Developed based on PyTorch, the goal of OpenRL is to provide a simple-to-use, flexible, efficient and sustainable platform for the reinforcement learning research community. - -Currently, the features supported by OpenRL include: - -- A simple-to-use universal interface that supports training for both single-agent and multi-agent - -- Reinforcement learning training support for natural language tasks (such as dialogue) - -- Importing models and datasets from [Hugging Face](https://huggingface.co/) - -- Support for models such as LSTM, GRU, Transformer etc. 
- -- Multiple training acceleration methods including automatic mixed precision training and data collecting wth half precision policy network - -- User-defined training models, reward models, training data and environment support - -- Support for [gymnasium](https://gymnasium.farama.org/) environments - -- Dictionary observation space support - -- Popular visualization tools such as [wandb](https://wandb.ai/), [tensorboardX](https://tensorboardx.readthedocs.io/en/latest/index.html) are supported - -- Serial or parallel environment training while ensuring consistent results in both modes - -- Chinese and English documentation - -- Provides unit testing and code coverage testing - -- Compliant with Black Code Style guidelines and type checking - -This framework has undergone multiple iterations by the [OpenRL-Lab](https://github.com/OpenRL-Lab) team which has applied it in academic research. -It has now become a mature reinforcement learning framework. - -OpenRL-Lab will continue to maintain and update OpenRL, and we welcome everyone to join our [open-source community](./docs/CONTRIBUTING_en.md) -to contribute towards the development of reinforcement learning. - -For more information about OpenRL, please refer to the [documentation](https://openrl-docs.readthedocs.io/en/latest/). 
- -## Outline - -- [Welcome to OpenRL](#welcome-to-openrl) -- [Outline](#outline) -- [Installation](#installation) -- [Use Docker](#use-docker) -- [Quick Start](#quick-start) -- [Gallery](#gallery) -- [Projects Using OpenRL](#projects-using-openrl) -- [Feedback and Contribution](#feedback-and-contribution) -- [Maintainers](#maintainers) -- [Supporters](#supporters) - - [↳ Contributors](#-contributors) - - [↳ Stargazers](#-stargazers) - - [↳ Forkers](#-forkers) -- [Citing OpenRL](#citing-openrl) -- [License](#license) -- [Acknowledgments](#acknowledgments) - -## Installation - -Users can directly install OpenRL via pip: - -```bash -pip install openrl -``` - -If users are using Anaconda or Miniconda, they can also install OpenRL via conda: - -```bash -conda install -c openrl openrl -``` - -Users who want to modify the source code can also install OpenRL from the source code: - -```bash -git clone https://github.com/OpenRL-Lab/openrl.git && cd openrl -pip install -e . -``` - -After installation, users can check the version of OpenRL through command line: - -```bash -openrl --version -``` - -**Tips**: No installation required, try OpenRL online through Colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/15VBA-B7AJF8dBazzRcWAxJxZI7Pl9m-g?usp=sharing) - -## Use Docker - -OpenRL currently provides Docker images with and without GPU support. 
-If the user's computer does not have an NVIDIA GPU, they can obtain an image without the GPU plugin using the following command: -```bash -sudo docker pull openrllab/openrl-cpu -``` - -If the user wants to accelerate training with a GPU, they can obtain it using the following command: -```bash -sudo docker pull openrllab/openrl -``` - -After successfully pulling the image, users can run OpenRL's Docker image using the following commands: -```bash -# Without GPU acceleration -sudo docker run -it openrllab/openrl-cpu -# With GPU acceleration -sudo docker run -it --gpus all --net host openrllab/openrl -``` - -Once inside the Docker container, users can check OpenRL's version and then run test cases using these commands: -```bash -# Check OpenRL version in Docker container -openrl --version -# Run test case -openrl --mode train --env CartPole-v1 -``` - -## Quick Start - -OpenRL provides a simple and easy-to-use interface for beginners in reinforcement learning. -Below is an example of using the PPO algorithm to train the `CartPole` environment: -```python -# train_ppo.py -from openrl.envs.common import make -from openrl.modules.common import PPONet as Net -from openrl.runners.common import PPOAgent as Agent -env = make("CartPole-v1", env_num=9) # Create an environment and set the environment parallelism to 9. -net = Net(env) # Create neural network. -agent = Agent(net) # Initialize the agent. -agent.train(total_time_steps=20000) # Start training and set the total number of steps to 20,000 for the running environment. -``` -Training an agent using OpenRL only requires four simple steps: -**Create Environment** => **Initialize Model** => **Initialize Agent** => **Start Training**! 
- -For a well-trained agent, users can also easily test the agent: -```python -# train_ppo.py -from openrl.envs.common import make -from openrl.modules.common import PPONet as Net -from openrl.runners.common import PPOAgent as Agent -agent = Agent(Net(make("CartPole-v1", env_num=9))) # Initialize trainer. -agent.train(total_time_steps=20000) -# Create an environment for test, set the parallelism of the environment to 9, and set the rendering mode to group_human. -env = make("CartPole-v1", env_num=9, render_mode="group_human") -agent.set_env(env) # The agent requires an interactive environment. -obs, info = env.reset() # Initialize the environment to obtain initial observations and environmental information. -while True: - action, _ = agent.act(obs) # The agent predicts the next action based on environmental observations. - # The environment takes one step according to the action, obtains the next observation, reward, whether it ends and environmental information. - obs, r, done, info = env.step(action) - if any(done): break -env.close() # Close test environment -``` -Executing the above code on a regular laptop only takes a few seconds -to complete the training. Below shows the visualization of the agent: - -
- -
- - -**Tips:** Users can also quickly train the `CartPole` environment by executing a command line in the terminal. -```bash -openrl --mode train --env CartPole-v1 -``` - -For training tasks such as multi-agent and natural language processing, OpenRL also provides a similarly simple and easy-to-use interface. - -For information on how to perform multi-agent training, set hyperparameters for training, load training configurations, use wandb, save GIF animations, etc., please refer to: -- [Multi-Agent Training Example](https://openrl-docs.readthedocs.io/en/latest/quick_start/multi_agent_RL.html) - -For information on natural language task training, loading models/datasets on Hugging Face, customizing training models/reward models, etc., please refer to: -- [Dialogue Task Training Example](https://openrl-docs.readthedocs.io/en/latest/quick_start/train_nlp.html) - -For more information about OpenRL, please refer to the [documentation](https://openrl-docs.readthedocs.io/en/latest/). - -## Gallery - -In order to facilitate users' familiarity with the framework, we provide more examples and demos of using OpenRL in [Gallery](./Gallery.md). -Users are also welcome to contribute their own training examples and demos to the Gallery. - -## Projects Using OpenRL - -We have listed research projects that use OpenRL in the [OpenRL Project](./Project.md). -If you are using OpenRL in your research project, you are also welcome to join this list. - -## Feedback and Contribution -- If you have any questions or find bugs, you can check or ask in the [Issues](https://github.com/OpenRL-Lab/openrl/issues). -- Join the QQ group: [OpenRL Official Communication Group](./docs/images/qq.png) -
-
-
-- Join the [slack](https://join.slack.com/t/openrlhq/shared_invite/zt-1tqwpvthd-Eeh0IxQ~DIaGqYXoW2IUQg) group to discuss OpenRL usage and development with us.
-- Join the [Discord](https://discord.gg/tyy96TGbep) group to discuss OpenRL usage and development with us.
-- Send an E-mail to: [huangshiyu@4paradigm.com](huangshiyu@4paradigm.com)
-- Join the [GitHub Discussion](https://github.com/orgs/OpenRL-Lab/discussions).
-
-The OpenRL framework is still under continuous development and documentation.
-We welcome you to join us in making this project better:
-- How to contribute code: Read the [Contributors' Guide](./docs/CONTRIBUTING_en.md)
-- [OpenRL Roadmap](https://github.com/OpenRL-Lab/openrl/issues/2)
-
-## Maintainers
-
-At present, OpenRL is maintained by the following maintainers:
-- [Shiyu Huang](https://huangshiyu13.github.io/)([@huangshiyu13](https://github.com/huangshiyu13))
-- Wenze Chen([@Chen001117](https://github.com/Chen001117))
-
-Welcome more contributors to join our maintenance team (send an E-mail to [huangshiyu@4paradigm.com](huangshiyu@4paradigm.com)
-to apply for joining the OpenRL team).
-
-## Supporters
-
-### ↳ Contributors
-
-
-
-
-### ↳ Stargazers
-
-[![Stargazers repo roster for @OpenRL-Lab/openrl](https://reporoster.com/stars/OpenRL-Lab/openrl)](https://github.com/OpenRL-Lab/openrl/stargazers)
-
-### ↳ Forkers
-
-[![Forkers repo roster for @OpenRL-Lab/openrl](https://reporoster.com/forks/OpenRL-Lab/openrl)](https://github.com/OpenRL-Lab/openrl/network/members)
-
-## Citing OpenRL
-
-If our work has been helpful to you, please feel free to cite us:
-```latex
-@misc{openrl2023,
-    title={OpenRL},
-    author={OpenRL Contributors},
-    publisher = {GitHub},
-    howpublished = {\url{https://github.com/OpenRL-Lab/openrl}},
-    year={2023},
-}
-```
-
-## Star History
-
-[![Star History Chart](https://api.star-history.com/svg?repos=OpenRL-Lab/openrl&type=Date)](https://star-history.com/#OpenRL-Lab/openrl&Date)
-
-## License
-OpenRL under the Apache 2.0 license.
-
-## Acknowledgments
-The development of the OpenRL framework has drawn on the strengths of other reinforcement learning frameworks:
-
-- Stable-baselines3: https://github.com/DLR-RM/stable-baselines3
-- pytorch-a2c-ppo-acktr-gail: https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail
-- MAPPO: https://github.com/marlbenchmark/on-policy
-- Gymnasium: https://github.com/Farama-Foundation/Gymnasium
-- DI-engine: https://github.com/opendilab/DI-engine/
-- Tianshou: https://github.com/thu-ml/tianshou
-- RL4LMs: https://github.com/allenai/RL4LMs
diff --git a/README_zh.md b/README_zh.md
new file mode 100644
index 00000000..c6b204b2
--- /dev/null
+++ b/README_zh.md
@@ -0,0 +1,288 @@
+
+ +
+
+---
+[![PyPI](https://img.shields.io/pypi/v/openrl)](https://pypi.org/project/openrl/)
+![PyPI - Python Version](https://img.shields.io/pypi/pyversions/openrl)
+[![Anaconda-Server Badge](https://anaconda.org/openrl/openrl/badges/version.svg)](https://anaconda.org/openrl/openrl)
+[![Anaconda-Server Badge](https://anaconda.org/openrl/openrl/badges/latest_release_date.svg)](https://anaconda.org/openrl/openrl)
+[![Anaconda-Server Badge](https://anaconda.org/openrl/openrl/badges/downloads.svg)](https://anaconda.org/openrl/openrl)
+[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
+
+
+[![Hits-of-Code](https://hitsofcode.com/github/OpenRL-Lab/openrl?branch=main)](https://hitsofcode.com/github/OpenRL-Lab/openrl/view?branch=main)
+[![codecov](https://codecov.io/gh/OpenRL-Lab/openrl/branch/main/graph/badge.svg?token=T6BqaTiT0l)](https://codecov.io/gh/OpenRL-Lab/openrl)
+
+[![Documentation Status](https://readthedocs.org/projects/openrl-docs/badge/?version=latest)](https://openrl-docs.readthedocs.io/zh/latest/?badge=latest)
+[![Read the Docs](https://img.shields.io/readthedocs/openrl-docs-zh?label=%E4%B8%AD%E6%96%87%E6%96%87%E6%A1%A3)](https://openrl-docs.readthedocs.io/zh/latest/)
+
+![GitHub Org's stars](https://img.shields.io/github/stars/OpenRL-Lab)
+[![GitHub stars](https://img.shields.io/github/stars/OpenRL-Lab/openrl)](https://github.com/OpenRL-Lab/openrl/stargazers)
+[![GitHub forks](https://img.shields.io/github/forks/OpenRL-Lab/openrl)](https://github.com/OpenRL-Lab/openrl/network)
+![GitHub commit activity](https://img.shields.io/github/commit-activity/m/OpenRL-Lab/openrl)
+[![GitHub issues](https://img.shields.io/github/issues/OpenRL-Lab/openrl)](https://github.com/OpenRL-Lab/openrl/issues)
+[![GitHub pulls](https://img.shields.io/github/issues-pr/OpenRL-Lab/openrl)](https://github.com/OpenRL-Lab/openrl/pulls)
+[![Contributors](https://img.shields.io/github/contributors/OpenRL-Lab/openrl)](https://github.com/OpenRL-Lab/openrl/graphs/contributors)
+[![GitHub license](https://img.shields.io/github/license/OpenRL-Lab/openrl)](https://github.com/OpenRL-Lab/openrl/blob/master/LICENSE)
+
+[![Embark](https://img.shields.io/badge/discord-OpenRL-%237289da.svg?logo=discord)](https://discord.gg/tyy96TGbep)
+[![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&)](https://join.slack.com/t/openrlhq/shared_invite/zt-1tqwpvthd-Eeh0IxQ~DIaGqYXoW2IUQg)
+
+OpenRL v0.0.11 was released on May 19, 2023
+
+The main branch is the latest version of OpenRL, which is under active development. If you just want to have a try with OpenRL, you can switch to the stable branch.
+
+## 欢迎来到OpenRL
+
+[English](./README.md) | [中文文档](https://openrl-docs.readthedocs.io/zh/latest/) | [Documentation](https://openrl-docs.readthedocs.io/en/latest/)
+
+OpenRL是一个开源的通用强化学习研究框架,支持单智能体、多智能体、自然语言等多种任务的训练。OpenRL基于PyTorch进行开发,目标是为强化学习研究社区提供一个简单易用、灵活高效、可持续扩展的平台。
+目前,OpenRL支持的特性包括:
+
+- 简单易用且支持单智能体、多智能体训练的通用接口
+- 支持自然语言任务(如对话任务)的强化学习训练
+- 支持从[Hugging Face](https://huggingface.co/)上导入模型和数据
+- 支持LSTM,GRU,Transformer等模型
+- 支持多种训练加速,例如:自动混合精度训练,半精度策略网络收集数据等
+- 支持用户自定义训练模型、奖励模型、训练数据以及环境
+- 支持[gymnasium](https://gymnasium.farama.org/)环境
+- 支持字典观测空间
+- 支持[wandb](https://wandb.ai/),[tensorboardX](https://tensorboardx.readthedocs.io/en/latest/index.html)等主流训练可视化工具
+- 支持环境的串行和并行训练,同时保证两种模式下的训练效果一致
+- 中英文文档
+- 提供单元测试和代码覆盖测试
+- 符合Black Code Style和类型检查
+
+OpenRL目前支持的算法(更多详情请参考 [Gallery](Gallery.md)):
+- [Proximal Policy Optimization (PPO)](https://arxiv.org/abs/1707.06347)
+- [Multi-agent PPO (MAPPO)](https://arxiv.org/abs/2103.01955)
+- [Joint-ratio Policy Optimization (JRPO)](https://arxiv.org/abs/2302.07515)
+- [Multi-Agent Transformer (MAT)](https://arxiv.org/abs/2205.14953)
+
+OpenRL目前支持的环境(更多详情请参考 [Gallery](Gallery.md)):
+- [Gymnasium](https://gymnasium.farama.org/)
+- 
[MPE](https://github.com/openai/multiagent-particle-envs)
+- [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros)
+- [Gym Retro](https://github.com/openai/retro)
+
+
+该框架经过了[OpenRL-Lab](https://github.com/OpenRL-Lab)的多次迭代并应用于学术研究,目前已经成为了一个成熟的强化学习框架。
+OpenRL-Lab将持续维护和更新OpenRL,欢迎大家加入我们的[开源社区](./docs/CONTRIBUTING_zh.md),一起为强化学习的发展做出贡献。
+关于OpenRL的更多信息,请参考[文档](https://openrl-docs.readthedocs.io/zh/latest/)。
+
+## 目录
+
+- [欢迎来到OpenRL](#欢迎来到openrl)
+- [目录](#目录)
+- [安装](#安装)
+- [使用Docker](#使用docker)
+- [快速上手](#快速上手)
+- [Gallery](#Gallery)
+- [使用OpenRL的项目](#使用OpenRL的项目)
+- [反馈和贡献](#反馈和贡献)
+- [维护人员](#维护人员)
+- [支持者](#支持者)
+  - [↳ Contributors](#-contributors)
+  - [↳ Stargazers](#-stargazers)
+  - [↳ Forkers](#-forkers)
+- [Citing OpenRL](#citing-openrl)
+- [License](#license)
+- [Acknowledgments](#acknowledgments)
+
+## 安装
+
+用户可以直接通过pip安装OpenRL:
+```bash
+pip install openrl
+```
+
+如果用户使用了Anaconda或者Miniconda,也可以通过conda安装OpenRL:
+```bash
+conda install -c openrl openrl
+```
+
+想要修改源码的用户也可以从源码安装OpenRL:
+```bash
+git clone https://github.com/OpenRL-Lab/openrl.git && cd openrl
+pip install -e .
+```
+
+安装完成后,用户可以直接通过命令行查看OpenRL的版本:
+```bash
+openrl --version
+```
+
+**Tips**:无需安装,通过Colab在线试用OpenRL: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/15VBA-B7AJF8dBazzRcWAxJxZI7Pl9m-g?usp=sharing)
+
+## 使用Docker
+
+OpenRL目前也提供了包含显卡支持和非显卡支持的Docker镜像。
+如果用户的电脑上没有英伟达显卡,则可以通过以下命令获取不包含显卡插件的镜像:
+
+```bash
+sudo docker pull openrllab/openrl-cpu
+```
+
+如果用户想要通过显卡加速训练,则可以通过以下命令获取:
+```bash
+sudo docker pull openrllab/openrl
+```
+
+镜像拉取成功后,用户可以通过以下命令运行OpenRL的Docker镜像:
+```bash
+# 不带显卡加速
+sudo docker run -it openrllab/openrl-cpu
+# 带显卡加速
+sudo docker run -it --gpus all --net host openrllab/openrl
+```
+
+进入Docker镜像后,用户可以通过以下命令查看OpenRL的版本然后运行测例:
+```bash
+# 查看Docker镜像中OpenRL的版本
+openrl --version
+# 运行测例
+openrl --mode train --env CartPole-v1
+```
+
+
+## 快速上手
+
+OpenRL为强化学习入门用户提供了简单易用的接口,
+下面是一个使用PPO算法训练`CartPole`环境的例子:
+```python
+# train_ppo.py
+from openrl.envs.common import make
+from openrl.modules.common import PPONet as Net
+from openrl.runners.common import PPOAgent as Agent
+env = make("CartPole-v1", env_num=9)  # 创建环境,并设置环境并行数为9
+net = Net(env)  # 创建神经网络
+agent = Agent(net)  # 初始化智能体
+agent.train(total_time_steps=20000)  # 开始训练,并设置环境运行总步数为20000
+```
+使用OpenRL训练智能体只需要简单的四步:**创建环境**=>**初始化模型**=>**初始化智能体**=>**开始训练**!
+
+对于训练好的智能体,用户也可以方便地进行智能体的测试:
+```python
+# train_ppo.py
+from openrl.envs.common import make
+from openrl.modules.common import PPONet as Net
+from openrl.runners.common import PPOAgent as Agent
+agent = Agent(Net(make("CartPole-v1", env_num=9)))  # 初始化训练器
+agent.train(total_time_steps=20000)
+# 创建用于测试的环境,并设置环境并行数为9,设置渲染模式为group_human
+env = make("CartPole-v1", env_num=9, render_mode="group_human")
+agent.set_env(env)  # 训练好的智能体设置需要交互的环境
+obs, info = env.reset()  # 环境进行初始化,得到初始的观测值和环境信息
+while True:
+    action, _ = agent.act(obs)  # 智能体根据环境观测输入预测下一个动作
+    # 环境根据动作执行一步,得到下一个观测值、奖励、是否结束、环境信息
+    obs, r, done, info = env.step(action)
+    if any(done): break
+env.close()  # 关闭测试环境
+```
+在普通笔记本电脑上执行以上代码,只需要几秒钟,便可以完成该智能体的训练和可视化测试:
+
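The evaluation snippet above always has the same loop shape: reset, act, step, and stop once any of the parallel environment copies reports done. A dependency-free sketch of that loop shape — the `ToyVecEnv` below is an illustrative stand-in, not OpenRL's vectorized-environment API:

```python
import random


class ToyVecEnv:
    """Toy stand-in for a vectorized environment with env_num parallel copies.

    NOT OpenRL's API -- only mimics the reset/step signatures used above.
    """

    def __init__(self, env_num, horizon=5):
        self.env_num = env_num
        self.horizon = horizon  # every copy terminates after `horizon` steps
        self.t = 0

    def reset(self):
        self.t = 0
        return [0.0] * self.env_num, {}  # (observations, info)

    def step(self, actions):
        assert len(actions) == self.env_num
        self.t += 1
        obs = [float(self.t)] * self.env_num
        rewards = [1.0] * self.env_num
        dones = [self.t >= self.horizon] * self.env_num
        return obs, rewards, dones, {}


def evaluate(env, policy):
    """reset -> act -> step, stopping once any parallel copy is done."""
    obs, info = env.reset()
    mean_return = 0.0
    while True:
        actions = [policy(o) for o in obs]  # one action per parallel copy
        obs, rewards, dones, info = env.step(actions)
        mean_return += sum(rewards) / len(rewards)
        if any(dones):
            break
    return mean_return


print(evaluate(ToyVecEnv(env_num=9), policy=lambda o: random.choice([0, 1])))  # 5.0
```

With `horizon=5` the loop runs five steps of mean reward 1.0 before `any(dones)` fires, so the return is 5.0 regardless of the policy.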
+ +
+
+
+**Tips:** 用户还可以在终端中通过执行一行命令快速训练`CartPole`环境:
+```bash
+openrl --mode train --env CartPole-v1
+```
+
+对于多智能体、自然语言等任务的训练,OpenRL也提供了同样简单易用的接口。
+
+关于如何进行多智能体训练、训练超参数设置、训练配置文件加载、wandb使用、保存gif动画等信息,请参考:
+- [多智能体训练例子](https://openrl-docs.readthedocs.io/zh/latest/quick_start/multi_agent_RL.html)
+
+关于自然语言任务训练、Hugging Face上模型(数据)加载、自定义训练模型(奖励模型)等信息,请参考:
+- [对话任务训练例子](https://openrl-docs.readthedocs.io/zh/latest/quick_start/train_nlp.html)
+
+关于OpenRL的更多信息,请参考[文档](https://openrl-docs.readthedocs.io/zh/latest/)。
+
+## Gallery
+
+为了方便用户熟悉该框架,
+我们在[Gallery](Gallery.md)中提供了更多使用OpenRL的示例和demo。
+也欢迎用户将自己的训练示例和demo贡献到Gallery中。
+
+## 使用OpenRL的研究项目
+
+我们在 [OpenRL Project](Project.md) 中列举了使用OpenRL的研究项目。
+如果你在研究项目中使用了OpenRL,也欢迎加入该列表。
+
+## 反馈和贡献
+- 有问题和发现bugs可以到 [Issues](https://github.com/OpenRL-Lab/openrl/issues) 处进行查询或提问
+- 加入QQ群:[OpenRL官方交流群](docs/images/qq.png)
+
+ +
+
+- 加入 [slack](https://join.slack.com/t/openrlhq/shared_invite/zt-1tqwpvthd-Eeh0IxQ~DIaGqYXoW2IUQg) 群组,与我们一起讨论OpenRL的使用和开发。
+- 加入 [Discord](https://discord.gg/tyy96TGbep) 群组,与我们一起讨论OpenRL的使用和开发。
+- 发送邮件到: [huangshiyu@4paradigm.com](mailto:huangshiyu@4paradigm.com)
+- 加入 [GitHub Discussion](https://github.com/orgs/OpenRL-Lab/discussions)
+
+OpenRL框架目前还在持续开发和文档建设,欢迎加入我们让该项目变得更好:
+
+- 如何贡献代码:阅读 [贡献者手册](./docs/CONTRIBUTING_zh.md)
+- [OpenRL开发计划](https://github.com/OpenRL-Lab/openrl/issues/2)
+
+## 维护人员
+
+目前,OpenRL由以下维护人员维护:
+- [Shiyu Huang](https://huangshiyu13.github.io/)([@huangshiyu13](https://github.com/huangshiyu13))
+- Wenze Chen([@Chen001117](https://github.com/Chen001117))
+
+欢迎更多的贡献者加入我们的维护团队(发送邮件到[huangshiyu@4paradigm.com](mailto:huangshiyu@4paradigm.com)申请加入OpenRL团队)。
+
+## 支持者
+
+### ↳ Contributors
+
+
+
+
+### ↳ Stargazers
+
+[![Stargazers repo roster for @OpenRL-Lab/openrl](https://reporoster.com/stars/OpenRL-Lab/openrl)](https://github.com/OpenRL-Lab/openrl/stargazers)
+
+### ↳ Forkers
+
+[![Forkers repo roster for @OpenRL-Lab/openrl](https://reporoster.com/forks/OpenRL-Lab/openrl)](https://github.com/OpenRL-Lab/openrl/network/members)
+
+## Citing OpenRL
+
+如果我们的工作对你有帮助,欢迎引用我们:
+```latex
+@misc{openrl2023,
+    title={OpenRL},
+    author={OpenRL Contributors},
+    publisher = {GitHub},
+    howpublished = {\url{https://github.com/OpenRL-Lab/openrl}},
+    year={2023},
+}
+```
+
+## Star History
+
+[![Star History Chart](https://api.star-history.com/svg?repos=OpenRL-Lab/openrl&type=Date)](https://star-history.com/#OpenRL-Lab/openrl&Date)
+
+## License
+OpenRL is released under the Apache 2.0 license.
+
+## Acknowledgments
+The development of the OpenRL framework has drawn on the strengths of other reinforcement learning frameworks:
+
+- Stable-baselines3: https://github.com/DLR-RM/stable-baselines3
+- pytorch-a2c-ppo-acktr-gail: https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail
+- MAPPO: https://github.com/marlbenchmark/on-policy
+- Gymnasium: https://github.com/Farama-Foundation/Gymnasium
+- DI-engine: https://github.com/opendilab/DI-engine/
+- Tianshou: https://github.com/thu-ml/tianshou
+- RL4LMs: https://github.com/allenai/RL4LMs
diff --git a/docs/CONTRIBUTING_en.md b/docs/CONTRIBUTING_en.md
deleted file mode 100644
index 64de484f..00000000
--- a/docs/CONTRIBUTING_en.md
+++ /dev/null
@@ -1,45 +0,0 @@
-## How to Contribute to OpenRL
-
-The OpenRL community welcomes anyone to contribute to the development of OpenRL, whether you are a developer or a user. Your feedback and contributions are our driving force! You can join the contribution of OpenRL in the following ways:
-
-- As an OpenRL user, discover bugs in OpenRL and submit an [issue](https://github.com/OpenRL-Lab/openrl/issues/new/choose).
-- As an OpenRL user, discover errors in the documentation of OpenRL and submit an [issue](https://github.com/OpenRL-Lab/openrl/issues/new/choose).
-- Write test code to improve the code coverage of OpenRL (you can check the code coverage situation of OpenRL from [here](https://app.codecov.io/gh/OpenRL-Lab/openrl)). You can choose interested code snippets for writing test codes.
-- As an open-source developer, fix existing bugs for OpenRL.
-- As an open-source developer, add new environments and examples for OpenRL.
-- As an open-source developer, add new algorithms for OpenRL.
-
-## Contributing to OpenRL
-
-Welcome to contribute to the development of OpenRL. We appreciate your contribution!
-
-- If you want to contribute new features, please create a new [issue](https://github.com/OpenRL-Lab/openrl/issues/new/choose) first
-to discuss the implementation details of this feature. If the feature is approved by everyone, you can start implementing the code.
-- You can also check for unimplemented features and existing bugs in [Issues](https://github.com/OpenRL-Lab/openrl/issues),
-reply in the corresponding issue that you want to solve it, and then start implementing the code.
-
-After completing your code implementation, you need to pull the latest `main` branch and merge it.
-After resolving any merge conflicts,
-you can submit your code for merging into OpenRL's main branch through [Pull Request](https://github.com/OpenRL-Lab/openrl/pulls).
-
-Before submitting a Pull Request, you need to complete [Code Testing and Code Formatting](#code-testing-and-code-formatting).
-
-Then, your Pull Request needs to pass automated testing on GitHub.
-
-Finally, at least one maintainer's review and approval are required before being merged into the main branch.
-
-## Code Testing and Code Formatting
-
-Before submitting a Pull Request, make sure that your code passes unit tests and conforms with OpenRL's coding style.
-
-Firstly, you should install the test-related packages: `pip install -e ".[test]"`
-
-Then, ensure that unit tests pass by executing `make test`.
-
-Next, format your code by running `make format`.
-
-Lastly, run `make commit-checks` to check if your code complies with OpenRL's coding style.
-
-> Tip: OpenRL uses [black](https://github.com/psf/black) coding style.
-You can install black plugins in your editor as shown in the [official website](https://black.readthedocs.io/en/stable/integrations/editors.html)
-to help automatically format codes.
\ No newline at end of file
diff --git a/docs/CONTRIBUTING_zh.md b/docs/CONTRIBUTING_zh.md
new file mode 100644
index 00000000..9fb1d012
--- /dev/null
+++ b/docs/CONTRIBUTING_zh.md
@@ -0,0 +1,51 @@
+## 如何参与OpenRL的建设
+
+OpenRL社区欢迎任何人参与到OpenRL的建设中来,无论您是开发者还是用户,您的反馈和贡献都是我们前进的动力!
+您可以通过以下方式加入到OpenRL的贡献中来:
+
+- 作为OpenRL的用户,发现OpenRL中的bug,并提交[issue](https://github.com/OpenRL-Lab/openrl/issues/new/choose)。
+- 作为OpenRL的用户,发现OpenRL文档中的错误,并提交[issue](https://github.com/OpenRL-Lab/openrl/issues/new/choose)。
+- 写测试代码,提升OpenRL的代码测试覆盖率(大家可以从[这里](https://app.codecov.io/gh/OpenRL-Lab/openrl)查到OpenRL的代码测试覆盖情况)。
+  您可以选择感兴趣的代码片段进行编写代码测试。
+- 作为OpenRL的开发者,为OpenRL修复已有的bug。
+- 作为OpenRL的开发者,为OpenRL添加新的环境和样例。
+- 作为OpenRL的开发者,为OpenRL添加新的算法。
+
+## 贡献者手册
+
+欢迎更多的人参与到OpenRL的开发中来,我们非常欢迎您的贡献!
+
+- 如果您想要贡献新的功能,请先创建一个新的[issue](https://github.com/OpenRL-Lab/openrl/issues/new/choose),
+  以便我们讨论这个功能的实现细节。如果该功能得到了大家的认可,您可以开始进行代码实现。
+- 您也可以在 [Issues](https://github.com/OpenRL-Lab/openrl/issues) 中查看未被实现的功能和仍然存在的bug,
+在对应的issue中进行回复,说明您想要解决该issue,然后开始进行代码实现。
+
+在您完成了代码实现之后,您需要拉取最新的`main`分支并进行合并。
+解决合并冲突后,
+您可以通过提交 [Pull Request](https://github.com/OpenRL-Lab/openrl/pulls)
+的方式将您的代码合并到OpenRL的main分支中。
+
+在提交Pull Request前,您需要完成 [代码测试和代码格式化](#代码测试和代码格式化)。
+
+然后,您的Pull Request需要通过GitHub上的自动化测试。
+
+最后,需要得到至少一个开发人员的review和批准,才能被合并到main分支中。
+
+## 代码测试和代码格式化
+
+在您提交Pull Request之前,您需要确保您的代码通过了单元测试,并且符合OpenRL的代码风格。
+
+首先,您需要安装测试相关的包:`pip install -e ".[test]"`
+
+然后,您需要确保单元测试通过,这可以通过执行`make test`来完成。
+
+然后,您需要执行`make format`来格式化您的代码。
+
+最后,您需要执行`make commit-checks`来检查您的代码是否符合OpenRL的代码风格。
+
+> 小技巧: OpenRL使用 [black](https://github.com/psf/black) 代码风格。
+您可以在您的编辑器中安装black的[插件](https://black.readthedocs.io/en/stable/integrations/editors.html),
+来帮助您自动格式化代码。
+
+
+
diff --git a/examples/mpe/README.md b/examples/mpe/README.md
index c1e25faa..f3ed00ef 100644
--- a/examples/mpe/README.md
+++ b/examples/mpe/README.md
@@ -1,7 +1,20 @@
 ## How to Use
 
-Users can train MPE via:
+Train MPE with 
[MAPPO](https://arxiv.org/abs/2103.01955) algorithm:
 
 ```shell
 python train_ppo.py --config mpe_ppo.yaml
+```
+
+Train MPE with [JRPO](https://arxiv.org/abs/2302.07515) algorithm:
+
+```shell
+python train_ppo.py --config mpe_jrpo.yaml
+```
+
+
+Train MPE with [MAT](https://arxiv.org/abs/2205.14953) algorithm:
+
+```shell
+python train_mat.py --config mpe_mat.yaml
 ```
\ No newline at end of file
diff --git a/examples/mpe/mpe_jrpo.yaml b/examples/mpe/mpe_jrpo.yaml
new file mode 100644
index 00000000..6a97afc1
--- /dev/null
+++ b/examples/mpe/mpe_jrpo.yaml
@@ -0,0 +1,12 @@
+seed: 0
+lr: 7e-4
+critic_lr: 7e-4
+episode_length: 25
+run_dir: ./run_results/
+experiment_name: train_mpe_jrpo
+log_interval: 10
+use_recurrent_policy: true
+use_joint_action_loss: true
+use_valuenorm: true
+use_adv_normalize: true
+wandb_entity: openrl-lab
\ No newline at end of file
diff --git a/examples/mpe/mpe_mat.yaml b/examples/mpe/mpe_mat.yaml
new file mode 100644
index 00000000..1c2305b2
--- /dev/null
+++ b/examples/mpe/mpe_mat.yaml
@@ -0,0 +1,9 @@
+seed: 0
+lr: 7e-4
+episode_length: 25
+run_dir: ./run_results/
+experiment_name: train_mpe_mat
+log_interval: 10
+use_valuenorm: true
+use_adv_normalize: true
+wandb_entity: openrl-lab
\ No newline at end of file
diff --git a/examples/mpe/mpe_ppo.yaml b/examples/mpe/mpe_ppo.yaml
index 5b31e006..d058bb98 100644
--- a/examples/mpe/mpe_ppo.yaml
+++ b/examples/mpe/mpe_ppo.yaml
@@ -6,7 +6,7 @@ run_dir: ./run_results/
 experiment_name: train_mpe
 log_interval: 10
 use_recurrent_policy: true
-use_joint_action_loss: true
+use_joint_action_loss: false
 use_valuenorm: true
 use_adv_normalize: true
 wandb_entity: openrl-lab
\ No newline at end of file
diff --git a/examples/mpe/train_mat.py b/examples/mpe/train_mat.py
new file mode 100644
index 00000000..12816906
--- /dev/null
+++ b/examples/mpe/train_mat.py
@@ -0,0 +1,62 @@
+""""""
+
+import numpy as np
+
+from openrl.configs.config import create_config_parser
+from openrl.envs.common import make
+from 
openrl.envs.wrappers.mat_wrapper import MATWrapper
+from openrl.modules.common import MATNet as Net
+from openrl.runners.common import MATAgent as Agent
+
+
+def train():
+    # create the environment
+    env_num = 100
+    env = make(
+        "simple_spread",
+        env_num=env_num,
+        asynchronous=True,
+    )
+    env = MATWrapper(env)
+
+    # create the neural network
+    cfg_parser = create_config_parser()
+    cfg = cfg_parser.parse_args()
+    net = Net(env, cfg=cfg, device="cuda")
+
+    # initialize the trainer
+    agent = Agent(net, use_wandb=True)
+    # start training
+    agent.train(total_time_steps=5000000)
+    env.close()
+    agent.save("./mat_agent/")
+    return agent
+
+
+def evaluation(agent):
+    # render_mode = "group_human"
+    render_mode = None
+    env_num = 9
+    env = make(
+        "simple_spread", render_mode=render_mode, env_num=env_num, asynchronous=False
+    )
+    env = MATWrapper(env)
+    agent.load("./mat_agent/")
+    agent.set_env(env)
+    obs, info = env.reset(seed=0)
+    done = False
+    step = 0
+    total_reward = 0
+    while not np.any(done):
+        # the agent predicts the next action from the observation
+        action, _ = agent.act(obs, deterministic=True)
+        obs, r, done, info = env.step(action)
+        step += 1
+        total_reward += np.mean(r)
+    print(f"total_reward: {total_reward}")
+    env.close()
+
+
+if __name__ == "__main__":
+    agent = train()
+    evaluation(agent)
diff --git a/openrl/algorithms/dqn.py b/openrl/algorithms/dqn.py
index 0c21c33c..2b50060a 100644
--- a/openrl/algorithms/dqn.py
+++ b/openrl/algorithms/dqn.py
@@ -220,7 +220,9 @@ def train(self, buffer, turn_on=True):
         elif self._use_naive_recurrent:
             raise NotImplementedError
         else:
-            data_generator = buffer.feed_forward_generator(None, self.num_mini_batch)
+            data_generator = buffer.feed_forward_generator(
+                None, self.num_mini_batch
+            )
 
         for sample in data_generator:
             (q_loss) = self.dqn_update(sample, turn_on)
diff --git a/openrl/algorithms/mat.py b/openrl/algorithms/mat.py
new file mode 100644
index 00000000..0b1e0b61
--- /dev/null
+++ b/openrl/algorithms/mat.py
@@ -0,0 +1,38 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The 
OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +from openrl.algorithms.ppo import PPOAlgorithm + + +class MATAlgorithm(PPOAlgorithm): + def construct_loss_list(self, policy_loss, dist_entropy, value_loss, turn_on): + loss_list = [] + + loss = ( + policy_loss + - dist_entropy * self.entropy_coef + + value_loss * self.value_loss_coef + ) + loss_list.append(loss) + + return loss_list + + def get_data_generator(self, buffer, advantages): + data_generator = buffer.feed_forward_generator_transformer( + advantages, self.num_mini_batch + ) + return data_generator diff --git a/openrl/algorithms/ppo.py b/openrl/algorithms/ppo.py index 68e969f7..acd3d84e 100644 --- a/openrl/algorithms/ppo.py +++ b/openrl/algorithms/ppo.py @@ -110,43 +110,26 @@ def ppo_update(self, sample, turn_on=True): for loss in loss_list: loss.backward() - if "transformer" in self.algo_module.models: - if self._use_max_grad_norm: - grad_norm = nn.utils.clip_grad_norm_( - self.algo_module.models["transformer"].parameters(), - self.max_grad_norm, - ) - else: - grad_norm = get_gard_norm( - self.algo_module.models["transformer"].parameters() - ) - critic_grad_norm = grad_norm - actor_grad_norm = grad_norm - + # else: + if self._use_share_model: + actor_para = self.algo_module.models["model"].get_actor_para() else: - if self._use_share_model: - actor_para = self.algo_module.models["model"].get_actor_para() - else: - actor_para = self.algo_module.models["policy"].parameters() + 
actor_para = self.algo_module.models["policy"].parameters() - if self._use_max_grad_norm: - actor_grad_norm = nn.utils.clip_grad_norm_( - actor_para, self.max_grad_norm - ) - else: - actor_grad_norm = get_gard_norm(actor_para) + if self._use_max_grad_norm: + actor_grad_norm = nn.utils.clip_grad_norm_(actor_para, self.max_grad_norm) + else: + actor_grad_norm = get_gard_norm(actor_para) - if self._use_share_model: - critic_para = self.algo_module.models["model"].get_critic_para() - else: - critic_para = self.algo_module.models["critic"].parameters() + if self._use_share_model: + critic_para = self.algo_module.models["model"].get_critic_para() + else: + critic_para = self.algo_module.models["critic"].parameters() - if self._use_max_grad_norm: - critic_grad_norm = nn.utils.clip_grad_norm_( - critic_para, self.max_grad_norm - ) - else: - critic_grad_norm = get_gard_norm(critic_para) + if self._use_max_grad_norm: + critic_grad_norm = nn.utils.clip_grad_norm_(critic_para, self.max_grad_norm) + else: + critic_grad_norm = get_gard_norm(critic_para) if self.use_amp: for optimizer in self.algo_module.optimizers.values(): @@ -219,6 +202,16 @@ def to_single_np(self, input): reshape_input = input.reshape(-1, self.agent_num, *input.shape[1:]) return reshape_input[:, 0, ...] 
+ def construct_loss_list(self, policy_loss, dist_entropy, value_loss, turn_on): + loss_list = [] + if turn_on: + final_p_loss = policy_loss - dist_entropy * self.entropy_coef + + loss_list.append(final_p_loss) + final_v_loss = value_loss * self.value_loss_coef + loss_list.append(final_v_loss) + return loss_list + def prepare_loss( self, critic_obs_batch, @@ -235,8 +228,6 @@ def prepare_loss( active_masks_batch, turn_on, ): - loss_list = [] - if self.use_joint_action_loss: critic_obs_batch = self.to_single_np(critic_obs_batch) rnn_states_critic_batch = self.to_single_np(rnn_states_critic_batch) @@ -341,21 +332,30 @@ def prepare_loss( active_masks_batch, ) - if "transformer" in self.algo_module.models: - loss = ( - policy_loss - - dist_entropy * self.entropy_coef - + value_loss * self.value_loss_coef + loss_list = self.construct_loss_list( + policy_loss, dist_entropy, value_loss, turn_on + ) + return loss_list, value_loss, policy_loss, dist_entropy, ratio + + def get_data_generator(self, buffer, advantages): + if self._use_recurrent_policy: + if self.use_joint_action_loss: + data_generator = buffer.recurrent_generator_v3( + advantages, self.num_mini_batch, self.data_chunk_length + ) + else: + data_generator = buffer.recurrent_generator( + advantages, self.num_mini_batch, self.data_chunk_length + ) + elif self._use_naive_recurrent: + data_generator = buffer.naive_recurrent_generator( + advantages, self.num_mini_batch ) - loss_list.append(loss) else: - if turn_on: - final_p_loss = policy_loss - dist_entropy * self.entropy_coef - - loss_list.append(final_p_loss) - final_v_loss = value_loss * self.value_loss_coef - loss_list.append(final_v_loss) - return loss_list, value_loss, policy_loss, dist_entropy, ratio + data_generator = buffer.feed_forward_generator( + advantages, self.num_mini_batch + ) + return data_generator def train(self, buffer, turn_on=True): if self._use_popart or self._use_valuenorm: @@ -396,27 +396,7 @@ def train(self, buffer, turn_on=True): 
train_info["reduced_policy_loss"] = 0 for _ in range(self.ppo_epoch): - if "transformer" in self.algo_module.models: - data_generator = buffer.feed_forward_generator_transformer( - advantages, self.num_mini_batch - ) - elif self._use_recurrent_policy: - if self.use_joint_action_loss: - data_generator = buffer.recurrent_generator_v3( - advantages, self.num_mini_batch, self.data_chunk_length - ) - else: - data_generator = buffer.recurrent_generator( - advantages, self.num_mini_batch, self.data_chunk_length - ) - elif self._use_naive_recurrent: - data_generator = buffer.naive_recurrent_generator( - advantages, self.num_mini_batch - ) - else: - data_generator = buffer.feed_forward_generator( - advantages, self.num_mini_batch - ) + data_generator = self.get_data_generator(buffer, advantages) for sample in data_generator: ( diff --git a/openrl/buffers/offpolicy_replay_data.py b/openrl/buffers/offpolicy_replay_data.py index 27ed6898..fa34c660 100644 --- a/openrl/buffers/offpolicy_replay_data.py +++ b/openrl/buffers/offpolicy_replay_data.py @@ -159,9 +159,10 @@ def insert( self.policy_obs[self.step + 1] = policy_obs.copy() self.next_critic_obs[self.step + 1] = next_critic_obs.copy() self.next_policy_obs[self.step + 1] = next_policy_obs.copy() - - self.rnn_states[self.step + 1] = rnn_states.copy() - self.rnn_states_critic[self.step + 1] = rnn_states_critic.copy() + if rnn_states is not None: + self.rnn_states[self.step + 1] = rnn_states.copy() + if rnn_states_critic is not None: + self.rnn_states_critic[self.step + 1] = rnn_states_critic.copy() self.actions[self.step] = actions.copy() self.action_log_probs[self.step] = action_log_probs.copy() self.value_preds[self.step] = value_preds.copy() diff --git a/openrl/buffers/replay_data.py b/openrl/buffers/replay_data.py index 2fc16253..768c265c 100644 --- a/openrl/buffers/replay_data.py +++ b/openrl/buffers/replay_data.py @@ -264,9 +264,10 @@ def insert( else: self.critic_obs[self.step + 1] = critic_obs.copy() 
self.policy_obs[self.step + 1] = policy_obs.copy() - - self.rnn_states[self.step + 1] = rnn_states.copy() - self.rnn_states_critic[self.step + 1] = rnn_states_critic.copy() + if rnn_states is not None: + self.rnn_states[self.step + 1] = rnn_states.copy() + if rnn_states_critic is not None: + self.rnn_states_critic[self.step + 1] = rnn_states_critic.copy() self.actions[self.step] = actions.copy() self.action_log_probs[self.step] = action_log_probs.copy() self.value_preds[self.step] = value_preds.copy() diff --git a/openrl/drivers/offpolicy_driver.py b/openrl/drivers/offpolicy_driver.py index 0f8d498a..20375e56 100644 --- a/openrl/drivers/offpolicy_driver.py +++ b/openrl/drivers/offpolicy_driver.py @@ -199,9 +199,9 @@ def act( actions = np.expand_dims(q_values.argmax(axis=-1), axis=-1) if random.random() > epsilon: - actions = np.random.randint(low=0, - high=self.envs.action_space.n, - size=actions.shape) + actions = np.random.randint( + low=0, high=self.envs.action_space.n, size=actions.shape + ) return ( q_values, @@ -211,4 +211,3 @@ def act( def compute_returns(self): pass - diff --git a/openrl/drivers/onpolicy_driver.py b/openrl/drivers/onpolicy_driver.py index 99adbf07..7504b942 100644 --- a/openrl/drivers/onpolicy_driver.py +++ b/openrl/drivers/onpolicy_driver.py @@ -69,16 +69,17 @@ def add2buffer(self, data): rnn_states, rnn_states_critic, ) = data + if rnn_states is not None: + rnn_states[dones] = np.zeros( + (dones.sum(), self.recurrent_N, self.hidden_size), + dtype=np.float32, + ) - rnn_states[dones] = np.zeros( - (dones.sum(), self.recurrent_N, self.hidden_size), - dtype=np.float32, - ) - - rnn_states_critic[dones] = np.zeros( - (dones.sum(), *self.buffer.data.rnn_states_critic.shape[3:]), - dtype=np.float32, - ) + if rnn_states_critic is not None: + rnn_states_critic[dones] = np.zeros( + (dones.sum(), *self.buffer.data.rnn_states_critic.shape[3:]), + dtype=np.float32, + ) masks = np.ones((self.n_rollout_threads, self.num_agents, 1), dtype=np.float32) 
masks[dones] = np.zeros((dones.sum(), 1), dtype=np.float32) @@ -187,10 +188,12 @@ def act( action_log_probs = np.array( np.split(_t2n(action_log_prob), self.n_rollout_threads) ) - rnn_states = np.array(np.split(_t2n(rnn_states), self.n_rollout_threads)) - rnn_states_critic = np.array( - np.split(_t2n(rnn_states_critic), self.n_rollout_threads) - ) + if rnn_states is not None: + rnn_states = np.array(np.split(_t2n(rnn_states), self.n_rollout_threads)) + if rnn_states_critic is not None: + rnn_states_critic = np.array( + np.split(_t2n(rnn_states_critic), self.n_rollout_threads) + ) return ( values, diff --git a/openrl/envs/vec_env/wrappers/base_wrapper.py b/openrl/envs/vec_env/wrappers/base_wrapper.py index 656cfce6..590e56eb 100644 --- a/openrl/envs/vec_env/wrappers/base_wrapper.py +++ b/openrl/envs/vec_env/wrappers/base_wrapper.py @@ -113,9 +113,9 @@ def reset(self, **kwargs): """Reset all environments.""" return self.env.reset(**kwargs) - def step(self, actions): + def step(self, actions, *args, **kwargs): """Step all environments.""" - return self.env.step(actions) + return self.env.step(actions, *args, **kwargs) def close(self, **kwargs): return self.env.close(**kwargs) @@ -193,15 +193,14 @@ def reset(self, **kwargs): observation = results return self.observation(observation) - def step(self, actions): + def step(self, actions, *args, **kwargs): """Modifies the observation returned from the environment ``step`` using the :meth:`observation`.""" - results = self.env.step(actions) + results = self.env.step(actions, *args, **kwargs) if len(results) == 5: observation, reward, termination, truncation, info = results return ( self.observation(observation), - observation, reward, termination, truncation, @@ -211,7 +210,6 @@ def step(self, actions): observation, reward, done, info = results return ( self.observation(observation), - observation, reward, done, info, @@ -237,9 +235,9 @@ def observation(self, observation: ObsType) -> ObsType: class 
VectorActionWrapper(VecEnvWrapper): """Wraps the vectorized environment to allow a modular transformation of the actions. Equivalent of :class:`~gym.ActionWrapper` for vectorized environments.""" - def step(self, actions: ActType): + def step(self, actions: ActType, *args, **kwargs): """Steps through the environment using a modified action by :meth:`action`.""" - return self.env.step(self.action(actions)) + return self.env.step(self.action(actions), *args, **kwargs) def actions(self, actions: ActType) -> ActType: """Transform the actions before sending them to the environment. @@ -256,9 +254,9 @@ def actions(self, actions: ActType) -> ActType: class VectorRewardWrapper(VecEnvWrapper): """Wraps the vectorized environment to allow a modular transformation of the reward. Equivalent of :class:`~gym.RewardWrapper` for vectorized environments.""" - def step(self, actions): + def step(self, actions, *args, **kwargs): """Steps through the environment returning a reward modified by :meth:`reward`.""" - results = self.env.step(actions) + results = self.env.step(actions, *args, **kwargs) reward = self.reward(results[1]) return results[0], reward, *results[2:] diff --git a/openrl/envs/wrappers/mat_wrapper.py b/openrl/envs/wrappers/mat_wrapper.py new file mode 100644 index 00000000..fdf45199 --- /dev/null +++ b/openrl/envs/wrappers/mat_wrapper.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +from openrl.envs.vec_env.wrappers.base_wrapper import VectorObservationWrapper + + +class MATWrapper(VectorObservationWrapper): + @property + def observation_space( + self, + ): + """Return the :attr:`Env` :attr:`observation_space` unless overwritten then the wrapper :attr:`observation_space` is used.""" + if self._observation_space is None: + observation_space = self.env.observation_space + else: + observation_space = self._observation_space + + if ( + "critic" in observation_space.spaces.keys() + and "policy" in observation_space.spaces.keys() + ): + observation_space = observation_space["policy"] + return observation_space + + def observation(self, observation): + if self._observation_space is None: + observation_space = self.env.observation_space + else: + observation_space = self._observation_space + + if ( + "critic" in observation_space.spaces.keys() + and "policy" in observation_space.spaces.keys() + ): + observation = observation["policy"] + return observation diff --git a/openrl/modules/common/__init__.py b/openrl/modules/common/__init__.py index e1f20fc3..cef64e38 100644 --- a/openrl/modules/common/__init__.py +++ b/openrl/modules/common/__init__.py @@ -1,9 +1,11 @@ from .base_net import BaseNet from .dqn_net import DQNNet +from .mat_net import MATNet from .ppo_net import PPONet __all__ = [ "BaseNet", "PPONet", "DQNNet", + "MATNet", ] diff --git a/openrl/modules/common/mat_net.py b/openrl/modules/common/mat_net.py new file mode 100644 index 00000000..cdf4e43a --- /dev/null +++ b/openrl/modules/common/mat_net.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" + +from typing import Any, Dict, Union + +import gymnasium as gym +import torch + +from openrl.modules.common.ppo_net import PPONet +from openrl.modules.networks.MAT_network import MultiAgentTransformer + + +class MATNet(PPONet): + def __init__( + self, + env: Union[gym.Env, str], + cfg=None, + device: Union[torch.device, str] = "cpu", + n_rollout_threads: int = 1, + model_dict: Dict[str, Any] = {"model": MultiAgentTransformer}, + ) -> None: + cfg.use_share_model = True + super().__init__( + env=env, + cfg=cfg, + device=device, + n_rollout_threads=n_rollout_threads, + model_dict=model_dict, + ) diff --git a/openrl/modules/common/ppo_net.py b/openrl/modules/common/ppo_net.py index 96907310..a95f0d3a 100644 --- a/openrl/modules/common/ppo_net.py +++ b/openrl/modules/common/ppo_net.py @@ -23,6 +23,7 @@ import torch from openrl.configs.config import create_config_parser +from openrl.modules.base_module import BaseModule from openrl.modules.common.base_net import BaseNet from openrl.modules.ppo_module import PPOModule from openrl.utils.util import set_seed @@ -36,6 +37,7 @@ def __init__( device: Union[torch.device, str] = "cpu", n_rollout_threads: int = 1, model_dict: Optional[Dict[str, Any]] = None, + module_class: BaseModule = PPOModule, ) -> None: super().__init__() @@ -46,6 +48,7 @@ def __init__( set_seed(cfg.seed) env.reset(seed=cfg.seed) + cfg.num_agents = env.agent_num cfg.n_rollout_threads = n_rollout_threads cfg.learner_n_rollout_threads = cfg.n_rollout_threads @@ -62,7 +65,7 @@ def __init__( if isinstance(device, str): device = 
torch.device(device) - self.module = PPOModule( + self.module = module_class( cfg=cfg, policy_input_space=env.observation_space, critic_input_space=env.observation_space, diff --git a/openrl/modules/dqn_module.py b/openrl/modules/dqn_module.py index 9d450150..a7d7d92f 100644 --- a/openrl/modules/dqn_module.py +++ b/openrl/modules/dqn_module.py @@ -102,7 +102,7 @@ def evaluate_actions( masks, available_actions=None, masks_batch=None, - critic_masks_batch=None + critic_masks_batch=None, ): if masks_batch is None: masks_batch = masks diff --git a/openrl/modules/networks/MAT_network.py b/openrl/modules/networks/MAT_network.py new file mode 100644 index 00000000..3dda25a7 --- /dev/null +++ b/openrl/modules/networks/MAT_network.py @@ -0,0 +1,484 @@ +import math + +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F + +from openrl.buffers.utils.util import get_critic_obs_space, get_policy_obs_space +from openrl.modules.networks.base_value_policy_network import BaseValuePolicyNetwork +from openrl.modules.networks.utils.transformer_act import ( + continuous_autoregreesive_act, + continuous_parallel_act, + discrete_autoregreesive_act, + discrete_parallel_act, +) +from openrl.modules.networks.utils.util import init +from openrl.utils.util import check_v2 as check + + +def init_(m, gain=0.01, activate=False): + if activate: + gain = nn.init.calculate_gain("relu") + return init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), gain=gain) + + +class SelfAttention(nn.Module): + def __init__(self, n_embd, n_head, n_agent, masked=False): + super(SelfAttention, self).__init__() + + assert n_embd % n_head == 0 + self.masked = masked + self.n_head = n_head + # key, query, value projections for all heads + self.key = init_(nn.Linear(n_embd, n_embd)) + self.query = init_(nn.Linear(n_embd, n_embd)) + self.value = init_(nn.Linear(n_embd, n_embd)) + # output projection + self.proj = init_(nn.Linear(n_embd, n_embd)) + # if self.masked: + # 
causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer( + "mask", + torch.tril(torch.ones(n_agent + 1, n_agent + 1)).view( + 1, 1, n_agent + 1, n_agent + 1 + ), + ) + + self.att_bp = None + + def forward(self, key, value, query): + B, L, D = query.size() + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + k = ( + self.key(key).view(B, L, self.n_head, D // self.n_head).transpose(1, 2) + ) # (B, nh, L, hs) + q = ( + self.query(query).view(B, L, self.n_head, D // self.n_head).transpose(1, 2) + ) # (B, nh, L, hs) + v = ( + self.value(value).view(B, L, self.n_head, D // self.n_head).transpose(1, 2) + ) # (B, nh, L, hs) + + # causal attention: (B, nh, L, hs) x (B, nh, hs, L) -> (B, nh, L, L) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + + # self.att_bp = F.softmax(att, dim=-1) + + if self.masked: + att = att.masked_fill(self.mask[:, :, :L, :L] == 0, float("-inf")) + att = F.softmax(att, dim=-1) + + y = att @ v # (B, nh, L, L) x (B, nh, L, hs) -> (B, nh, L, hs) + y = ( + y.transpose(1, 2).contiguous().view(B, L, D) + ) # re-assemble all head outputs side by side + + # output projection + y = self.proj(y) + return y + + +class EncodeBlock(nn.Module): + """an unassuming Transformer block""" + + def __init__(self, n_embd, n_head, n_agent): + super(EncodeBlock, self).__init__() + + self.ln1 = nn.LayerNorm(n_embd) + self.ln2 = nn.LayerNorm(n_embd) + # self.attn = SelfAttention(n_embd, n_head, n_agent, masked=True) + self.attn = SelfAttention(n_embd, n_head, n_agent, masked=False) + self.mlp = nn.Sequential( + init_(nn.Linear(n_embd, 1 * n_embd), activate=True), + nn.GELU(), + init_(nn.Linear(1 * n_embd, n_embd)), + ) + + def forward(self, x): + x = self.ln1(x + self.attn(x, x, x)) + x = self.ln2(x + self.mlp(x)) + return x + + +class DecodeBlock(nn.Module): + """an unassuming Transformer block""" + + def __init__(self, n_embd, n_head, n_agent): + 
super(DecodeBlock, self).__init__() + + self.ln1 = nn.LayerNorm(n_embd) + self.ln2 = nn.LayerNorm(n_embd) + self.ln3 = nn.LayerNorm(n_embd) + self.attn1 = SelfAttention(n_embd, n_head, n_agent, masked=True) + self.attn2 = SelfAttention(n_embd, n_head, n_agent, masked=True) + self.mlp = nn.Sequential( + init_(nn.Linear(n_embd, 1 * n_embd), activate=True), + nn.GELU(), + init_(nn.Linear(1 * n_embd, n_embd)), + ) + + def forward(self, x, rep_enc): + x = self.ln1(x + self.attn1(x, x, x)) + x = self.ln2(rep_enc + self.attn2(key=x, value=x, query=rep_enc)) + x = self.ln3(x + self.mlp(x)) + return x + + +class Encoder(nn.Module): + def __init__( + self, state_dim, obs_dim, n_block, n_embd, n_head, n_agent, encode_state + ): + super(Encoder, self).__init__() + + self.state_dim = state_dim + self.obs_dim = obs_dim + self.n_embd = n_embd + self.n_agent = n_agent + self.encode_state = encode_state + # self.agent_id_emb = nn.Parameter(torch.zeros(1, n_agent, n_embd)) + + self.state_encoder = nn.Sequential( + nn.LayerNorm(state_dim), + init_(nn.Linear(state_dim, n_embd), activate=True), + nn.GELU(), + ) + self.obs_encoder = nn.Sequential( + nn.LayerNorm(obs_dim), + init_(nn.Linear(obs_dim, n_embd), activate=True), + nn.GELU(), + ) + + self.ln = nn.LayerNorm(n_embd) + self.blocks = nn.Sequential( + *[EncodeBlock(n_embd, n_head, n_agent) for _ in range(n_block)] + ) + self.head = nn.Sequential( + init_(nn.Linear(n_embd, n_embd), activate=True), + nn.GELU(), + nn.LayerNorm(n_embd), + init_(nn.Linear(n_embd, 1)), + ) + + def forward(self, state, obs): + # state: (batch, n_agent, state_dim) + # obs: (batch, n_agent, obs_dim) + if self.encode_state: + state_embeddings = self.state_encoder(state) + x = state_embeddings + else: + obs_embeddings = self.obs_encoder(obs) + x = obs_embeddings + + rep = self.blocks(self.ln(x)) + v_loc = self.head(rep) + + return v_loc, rep + + +class Decoder(nn.Module): + def __init__( + self, + obs_dim, + action_dim, + n_block, + n_embd, + n_head, + 
n_agent, + action_type="Discrete", + dec_actor=False, + share_actor=False, + ): + super(Decoder, self).__init__() + + self.action_dim = action_dim + self.n_embd = n_embd + self.dec_actor = dec_actor + self.share_actor = share_actor + self.action_type = action_type + + if action_type != "Discrete": + log_std = torch.ones(action_dim) + # log_std = torch.zeros(action_dim) + self.log_std = torch.nn.Parameter(log_std) + # self.log_std = torch.nn.Parameter(torch.zeros(action_dim)) + + if self.dec_actor: + if self.share_actor: + print("mac_dec!!!!!") + self.mlp = nn.Sequential( + nn.LayerNorm(obs_dim), + init_(nn.Linear(obs_dim, n_embd), activate=True), + nn.GELU(), + nn.LayerNorm(n_embd), + init_(nn.Linear(n_embd, n_embd), activate=True), + nn.GELU(), + nn.LayerNorm(n_embd), + init_(nn.Linear(n_embd, action_dim)), + ) + else: + self.mlp = nn.ModuleList() + for n in range(n_agent): + actor = nn.Sequential( + nn.LayerNorm(obs_dim), + init_(nn.Linear(obs_dim, n_embd), activate=True), + nn.GELU(), + nn.LayerNorm(n_embd), + init_(nn.Linear(n_embd, n_embd), activate=True), + nn.GELU(), + nn.LayerNorm(n_embd), + init_(nn.Linear(n_embd, action_dim)), + ) + self.mlp.append(actor) + else: + # self.agent_id_emb = nn.Parameter(torch.zeros(1, n_agent, n_embd)) + if action_type == "Discrete": + self.action_encoder = nn.Sequential( + init_(nn.Linear(action_dim + 1, n_embd, bias=False), activate=True), + nn.GELU(), + ) + else: + self.action_encoder = nn.Sequential( + init_(nn.Linear(action_dim, n_embd), activate=True), nn.GELU() + ) + self.obs_encoder = nn.Sequential( + nn.LayerNorm(obs_dim), + init_(nn.Linear(obs_dim, n_embd), activate=True), + nn.GELU(), + ) + self.ln = nn.LayerNorm(n_embd) + self.blocks = nn.Sequential( + *[DecodeBlock(n_embd, n_head, n_agent) for _ in range(n_block)] + ) + self.head = nn.Sequential( + init_(nn.Linear(n_embd, n_embd), activate=True), + nn.GELU(), + nn.LayerNorm(n_embd), + init_(nn.Linear(n_embd, action_dim)), + ) + + def zero_std(self, device): + if 
self.action_type != "Discrete": + log_std = torch.zeros(self.action_dim).to(device) + self.log_std.data = log_std + + # state, action, and return + def forward(self, action, obs_rep, obs): + # action: (batch, n_agent, action_dim), one-hot/logits? + # obs_rep: (batch, n_agent, n_embd) + if self.dec_actor: + if self.share_actor: + logit = self.mlp(obs) + else: + logit = [] + for n in range(len(self.mlp)): + logit_n = self.mlp[n](obs[:, n, :]) + logit.append(logit_n) + logit = torch.stack(logit, dim=1) + else: + action_embeddings = self.action_encoder(action) + x = self.ln(action_embeddings) + for block in self.blocks: + x = block(x, obs_rep) + logit = self.head(x) + + return logit + + +class MultiAgentTransformer(BaseValuePolicyNetwork): + def __init__( + self, + cfg, + input_space, + action_space, + device=torch.device("cpu"), + use_half=False, + ): + assert use_half == False, "half precision not supported for MAT algorithm" + super(MultiAgentTransformer, self).__init__(cfg, device) + + obs_dim = get_policy_obs_space(input_space)[0] + critic_obs_dim = get_critic_obs_space(input_space)[0] + + n_agent = cfg.num_agents + n_block = cfg.n_block + n_embd = cfg.n_embd + n_head = cfg.n_head + encode_state = cfg.encode_state + dec_actor = cfg.dec_actor + share_actor = cfg.share_actor + if action_space.__class__.__name__ == "Box": + self.action_type = "Continuous" + action_dim = action_space.shape[0] + self.action_num = action_dim + else: + self.action_type = "Discrete" + action_dim = action_space.n + self.action_num = 1 + + self.n_agent = n_agent + self.obs_dim = obs_dim + self.critic_obs_dim = critic_obs_dim + self.action_dim = action_dim + self.tpdv = dict(dtype=torch.float32, device=device) + self._use_policy_active_masks = cfg.use_policy_active_masks + self.device = device + + # state unused + state_dim = 37 + + self.encoder = Encoder( + state_dim, obs_dim, n_block, n_embd, n_head, n_agent, encode_state + ) + self.decoder = Decoder( + obs_dim, + action_dim, + n_block, + 
n_embd, + n_head, + n_agent, + self.action_type, + dec_actor=dec_actor, + share_actor=share_actor, + ) + self.to(device) + + def zero_std(self): + if self.action_type != "Discrete": + self.decoder.zero_std(self.device) + + def eval_actions( + self, obs, rnn_states, action, masks, available_actions=None, active_masks=None + ): + obs = obs.reshape(-1, self.n_agent, self.obs_dim) + + action = action.reshape(-1, self.n_agent, self.action_num) + + if available_actions is not None: + available_actions = available_actions.reshape( + -1, self.n_agent, self.action_dim + ) + + # state: (batch, n_agent, state_dim) + # obs: (batch, n_agent, obs_dim) + # action: (batch, n_agent, 1) + # available_actions: (batch, n_agent, act_dim) + + # state unused + ori_shape = np.shape(obs) + state = np.zeros((*ori_shape[:-1], 37), dtype=np.float32) + + state = check(state).to(**self.tpdv) + obs = check(obs).to(**self.tpdv) + action = check(action).to(**self.tpdv) + + if available_actions is not None: + available_actions = check(available_actions).to(**self.tpdv) + + batch_size = np.shape(state)[0] + v_loc, obs_rep = self.encoder(state, obs) + if self.action_type == "Discrete": + action = action.long() + action_log, entropy = discrete_parallel_act( + self.decoder, + obs_rep, + obs, + action, + batch_size, + self.n_agent, + self.action_dim, + self.tpdv, + available_actions, + ) + else: + action_log, entropy = continuous_parallel_act( + self.decoder, + obs_rep, + obs, + action, + batch_size, + self.n_agent, + self.action_dim, + self.tpdv, + ) + action_log = action_log.view(-1, self.action_num) + v_loc = v_loc.view(-1, 1) + entropy = entropy.view(-1, self.action_num) + if self._use_policy_active_masks and active_masks is not None: + entropy = (entropy * active_masks).sum() / active_masks.sum() + else: + entropy = entropy.mean() + return action_log, entropy, v_loc + + def get_actions( + self, + obs, + rnn_states_actor=None, + masks=None, + available_actions=None, + deterministic=False, + ): + obs 
= obs.reshape(-1, self.n_agent, self.obs_dim)
+        if available_actions is not None:
+            available_actions = available_actions.reshape(
+                -1, self.n_agent, self.action_dim
+            )
+
+        # state unused
+        ori_shape = np.shape(obs)
+        state = np.zeros((*ori_shape[:-1], 37), dtype=np.float32)
+
+        state = check(state).to(**self.tpdv)
+        obs = check(obs).to(**self.tpdv)
+        if available_actions is not None:
+            available_actions = check(available_actions).to(**self.tpdv)
+
+        batch_size = np.shape(obs)[0]
+        v_loc, obs_rep = self.encoder(state, obs)
+        if self.action_type == "Discrete":
+            output_action, output_action_log = discrete_autoregreesive_act(
+                self.decoder,
+                obs_rep,
+                obs,
+                batch_size,
+                self.n_agent,
+                self.action_dim,
+                self.tpdv,
+                available_actions,
+                deterministic,
+            )
+        else:
+            output_action, output_action_log = continuous_autoregreesive_act(
+                self.decoder,
+                obs_rep,
+                obs,
+                batch_size,
+                self.n_agent,
+                self.action_dim,
+                self.tpdv,
+                deterministic,
+            )
+
+        output_action = output_action.reshape(-1, output_action.shape[-1])
+        output_action_log = output_action_log.reshape(-1, output_action_log.shape[-1])
+        return output_action, output_action_log, None
+
+    def get_values(self, critic_obs, rnn_states_critic=None, masks=None):
+        critic_obs = critic_obs.reshape(-1, self.n_agent, self.critic_obs_dim)
+
+        ori_shape = np.shape(critic_obs)
+        state = np.zeros((*ori_shape[:-1], 37), dtype=np.float32)
+
+        state = check(state).to(**self.tpdv)
+        obs = check(critic_obs).to(**self.tpdv)
+        v_tot, obs_rep = self.encoder(state, obs)
+
+        v_tot = v_tot.reshape(-1, v_tot.shape[-1])
+        return v_tot, None
+
+    def get_actor_para(self):
+        return self.parameters()
+
+    def get_critic_para(self):
+        return self.parameters()
diff --git a/openrl/modules/networks/base_value_policy_network.py b/openrl/modules/networks/base_value_policy_network.py
new file mode 100644
index 00000000..20d6caa2
--- /dev/null
+++ b/openrl/modules/networks/base_value_policy_network.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env 
python +# -*- coding: utf-8 -*- +# Copyright 2022 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +from abc import ABC, abstractmethod + +import torch.nn as nn + +from openrl.modules.utils.valuenorm import ValueNorm + + +class BaseValuePolicyNetwork(ABC, nn.Module): + def __init__(self, cfg, device): + super(BaseValuePolicyNetwork, self).__init__() + self.device = device + self._use_valuenorm = cfg.use_valuenorm + + if self._use_valuenorm: + self.value_normalizer = ValueNorm(1, device=self.device) + else: + self.value_normalizer = None + + def forward(self, forward_type, *args, **kwargs): + if forward_type == "original": + return self.get_actions(*args, **kwargs) + elif forward_type == "eval_actions": + return self.eval_actions(*args, **kwargs) + elif forward_type == "get_values": + return self.get_values(*args, **kwargs) + else: + raise NotImplementedError + + @abstractmethod + def get_actions(self, *args, **kwargs): + raise NotImplementedError + + @abstractmethod + def eval_actions(self, *args, **kwargs): + raise NotImplementedError + + @abstractmethod + def get_values(self, *args, **kwargs): + raise NotImplementedError diff --git a/openrl/modules/networks/q_network.py b/openrl/modules/networks/q_network.py index ba0b63bb..ea070e71 100644 --- a/openrl/modules/networks/q_network.py +++ b/openrl/modules/networks/q_network.py @@ -18,7 +18,7 @@ import torch import torch.nn as nn -from openrl.buffers.utils.util import 
get_shape_from_obs_space_v2, get_critic_obs_space +from openrl.buffers.utils.util import get_critic_obs_space, get_shape_from_obs_space_v2 from openrl.modules.networks.base_value_network import BaseValueNetwork from openrl.modules.networks.utils.cnn import CNNBase from openrl.modules.networks.utils.mix import MIXBase diff --git a/openrl/modules/ppo_module.py b/openrl/modules/ppo_module.py index 5d12ef70..7af9912c 100644 --- a/openrl/modules/ppo_module.py +++ b/openrl/modules/ppo_module.py @@ -130,7 +130,6 @@ def get_actions( values, rnn_states_critic = self.models["critic"]( critic_obs, rnn_states_critic, masks ) - return values, actions, action_log_probs, rnn_states_actor, rnn_states_critic def get_values(self, critic_obs, rnn_states_critic, masks): diff --git a/openrl/runners/common/__init__.py b/openrl/runners/common/__init__.py index 5c34d65a..eb32123b 100644 --- a/openrl/runners/common/__init__.py +++ b/openrl/runners/common/__init__.py @@ -1,5 +1,6 @@ from openrl.runners.common.chat_agent import Chat6BAgent, ChatAgent from openrl.runners.common.dqn_agent import DQNAgent +from openrl.runners.common.mat_agent import MATAgent from openrl.runners.common.ppo_agent import PPOAgent __all__ = [ @@ -7,4 +8,5 @@ "ChatAgent", "Chat6BAgent", "DQNAgent", + "MATAgent", ] diff --git a/openrl/runners/common/mat_agent.py b/openrl/runners/common/mat_agent.py new file mode 100644 index 00000000..e296dac5 --- /dev/null +++ b/openrl/runners/common/mat_agent.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +from typing import Type + +from openrl.algorithms.base_algorithm import BaseAlgorithm +from openrl.algorithms.mat import MATAlgorithm +from openrl.runners.common.base_agent import SelfAgent +from openrl.runners.common.ppo_agent import PPOAgent +from openrl.utils.logger import Logger + + +class MATAgent(PPOAgent): + def train( + self: SelfAgent, + total_time_steps: int, + train_algo_class: Type[BaseAlgorithm] = MATAlgorithm, + ) -> None: + logger = Logger( + cfg=self._cfg, + project_name="MATAgent", + scenario_name=self._env.env_name, + wandb_entity=self._cfg.wandb_entity, + exp_name=self.exp_name, + log_path=self.run_dir, + use_wandb=self._use_wandb, + use_tensorboard=self._use_tensorboard, + ) + + super(MATAgent, self).train( + total_time_steps=total_time_steps, + train_algo_class=train_algo_class, + logger=logger, + ) diff --git a/openrl/runners/common/ppo_agent.py b/openrl/runners/common/ppo_agent.py index 2b8993af..49aa3665 100644 --- a/openrl/runners/common/ppo_agent.py +++ b/openrl/runners/common/ppo_agent.py @@ -15,13 +15,14 @@ # limitations under the License. 
"""""" -from typing import Dict, Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Type, Union import gym import numpy as np import torch -from openrl.algorithms.ppo import PPOAlgorithm as TrainAlgo +from openrl.algorithms.base_algorithm import BaseAlgorithm +from openrl.algorithms.ppo import PPOAlgorithm from openrl.buffers import NormalReplayBuffer as ReplayBuffer from openrl.buffers.utils.obs_data import ObsData from openrl.drivers.onpolicy_driver import OnPolicyDriver as Driver @@ -48,7 +49,12 @@ def __init__( net, env, run_dir, env_num, rank, world_size, use_wandb, use_tensorboard ) - def train(self: SelfAgent, total_time_steps: int) -> None: + def train( + self: SelfAgent, + total_time_steps: int, + train_algo_class: Type[BaseAlgorithm] = PPOAlgorithm, + logger: Optional[Logger] = None, + ) -> None: self._cfg.num_env_steps = total_time_steps self.config = { @@ -59,7 +65,7 @@ def train(self: SelfAgent, total_time_steps: int) -> None: "device": self.net.device, } - trainer = TrainAlgo( + trainer = train_algo_class( cfg=self._cfg, init_module=self.net.module, device=self.net.device, @@ -73,17 +79,17 @@ def train(self: SelfAgent, total_time_steps: int) -> None: self._env.action_space, data_client=None, ) - - logger = Logger( - cfg=self._cfg, - project_name="PPOAgent", - scenario_name=self._env.env_name, - wandb_entity=self._cfg.wandb_entity, - exp_name=self.exp_name, - log_path=self.run_dir, - use_wandb=self._use_wandb, - use_tensorboard=self._use_tensorboard, - ) + if logger is None: + logger = Logger( + cfg=self._cfg, + project_name="PPOAgent", + scenario_name=self._env.env_name, + wandb_entity=self._cfg.wandb_entity, + exp_name=self.exp_name, + log_path=self.run_dir, + use_wandb=self._use_wandb, + use_tensorboard=self._use_tensorboard, + ) driver = Driver( config=self.config, trainer=trainer, diff --git a/tests/test_algorithm/test_mat_algorithm.py b/tests/test_algorithm/test_mat_algorithm.py new file mode 100644 index 00000000..4e4d9a98 --- 
/dev/null +++ b/tests/test_algorithm/test_mat_algorithm.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +import os +import sys + +import numpy as np +import pytest +from gymnasium import spaces + + +@pytest.fixture +def obs_space(): + return spaces.Box(low=-np.inf, high=+np.inf, shape=(1,), dtype=np.float32) + + +@pytest.fixture +def act_space(): + return spaces.Discrete(2) + + +@pytest.fixture(scope="module", params=[""]) +def config(request): + from openrl.configs.config import create_config_parser + + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args(request.param.split()) + return cfg + + +@pytest.fixture +def init_module(config, obs_space, act_space): + from openrl.modules.ppo_module import PPOModule + + module = PPOModule( + config, + policy_input_space=obs_space, + critic_input_space=obs_space, + act_space=act_space, + ) + return module + + +@pytest.fixture +def buffer_data(config, obs_space, act_space): + from openrl.buffers.normal_buffer import NormalReplayBuffer + + buffer = NormalReplayBuffer( + config, + num_agents=1, + obs_space=obs_space, + act_space=act_space, + data_client=None, + episode_length=100, + ) + return buffer.data + + +@pytest.mark.unittest +def test_mat_algorithm(config, init_module, buffer_data): + from openrl.algorithms.mat import MATAlgorithm + + mat_algo = MATAlgorithm(config, init_module) + 
mat_algo.train(buffer_data) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
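
For reviewers: the `SelfAttention` module added in `MAT_network.py` applies a lower-triangular (`torch.tril`) mask so each position only attends to itself and earlier positions, which is what makes the decoder's agent-by-agent action generation autoregressive. A minimal standalone sketch of that masking logic (plain numpy, single head; the function name and toy shapes here are hypothetical, not part of this PR):

```python
import numpy as np


def causal_self_attention(q, k, v):
    """Single-head scaled dot-product attention with a lower-triangular
    (causal) mask, mirroring the masked=True path of SelfAttention.
    q, k, v: arrays of shape (L, D)."""
    L, D = q.shape
    att = q @ k.T / np.sqrt(D)                # (L, L) attention scores
    mask = np.tril(np.ones((L, L)))           # same shape of mask as torch.tril above
    att = np.where(mask == 0, -np.inf, att)   # block attention to later positions
    att = np.exp(att - att.max(axis=-1, keepdims=True))
    att = att / att.sum(axis=-1, keepdims=True)  # row-wise softmax; masked entries -> 0
    return att @ v                            # (L, D) weighted values


q = k = v = np.eye(3, 4)
out = causal_self_attention(q, k, v)
# Row 0 can only attend to position 0, so its output equals v[0].
print(np.allclose(out[0], v[0]))
```

Because row `i` of the mask zeroes out columns `j > i`, agent `i`'s logits in the decoder depend only on the representations of agents `0..i`, matching the sequential `discrete_autoregreesive_act` / `continuous_autoregreesive_act` sampling loop.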