This repository is an implementation of the Diffusion Policy for Offline RL algorithm using JAX/Flax. It also marks my first attempt at constructing a relatively complex reinforcement learning system using technologies beyond PyTorch.
Install the required packages with the following command (adjust based on your environment):
```
pip install d4rl gym numpy jax flax optax chex distrax wandb
```

Then execute the following commands:
```
git clone https://github.com/dibyaghosh/jaxrl_m.git
cd jaxrl_m
pip install -e .
```

The `-e` flag installs the package in editable (development) mode, so the project's `jaxrl_m` dependency resolves to your local checkout.
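Once everything is installed, a quick way to confirm which backend JAX detected (this check is generic, not specific to this repository):

```python
import jax

# Reports the platform JAX will run on: "gpu" with a CUDA-enabled
# install, otherwise "cpu".
print(jax.default_backend())
print(jax.devices())
```

If this prints `cpu` on a machine with an NVIDIA GPU, you likely installed the CPU-only JAX wheel.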
From the project's root directory, run:
```
python run_<algo_name>.py
```

The code follows a file-naming convention:

- `hyper_<name>.py`: default hyperparameters and tuning configurations.
- `util_<name>.py`: utility functions for data loading, models, and other helpers.
- `model_<name>.py`: network architectures.
- `algo_<name>.py`: core logic of the RL agent, including creation, updates, and sampling.
- `run_<name>.py`: the entry point for running the program.
- `xxx_test.py`: test files.
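To illustrate how these pieces fit together, here is a hypothetical skeleton of a `run_<name>.py` entry point; the function names and hyperparameters are placeholders standing in for the corresponding modules, not this repository's actual API:

```python
# Hypothetical sketch of a run_<name>.py entry point.
# All names below are illustrative stand-ins, not the repo's real API.

def get_default_hyperparams():
    # Stand-in for hyper_<name>.py: default hyperparameters.
    return {"lr": 3e-4, "batch_size": 256, "num_steps": 3}

def load_dataset(env_name):
    # Stand-in for util_<name>.py: load an offline dataset.
    return [{"obs": 0.0, "action": 0.0}] * 10

def create_agent(hparams):
    # Stand-in for algo_<name>.py: build the initial agent state.
    return {"step": 0, "lr": hparams["lr"]}

def update(agent, batch):
    # Stand-in for one gradient update on a sampled batch.
    agent = dict(agent, step=agent["step"] + 1)
    return agent, {"loss": 1.0 / (agent["step"] + 1)}

def main():
    hparams = get_default_hyperparams()
    dataset = load_dataset("halfcheetah-medium-v2")
    agent = create_agent(hparams)
    for step in range(hparams["num_steps"]):
        batch = dataset[: hparams["batch_size"]]
        agent, info = update(agent, batch)
    return agent

if __name__ == "__main__":
    main()
```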
The training process and results can be monitored on the Weights & Biases platform.
Benchmark results on an RTX 4060 gaming laptop (compared with the original PyTorch implementation):
- Training speed: increased from ~38 to ~650 iterations per second, roughly a 17x speedup.
- GPU utilization: rose from ~20% to ~45%, a modest increase.
- GPU memory usage: grew from ~15% to ~70%, making fuller use of available GPU memory.
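Much of this speedup comes from `jax.jit`, which traces the update step once and compiles it with XLA, so each training iteration runs as precompiled device code instead of Python. A toy sketch of the pattern (a linear regression update, not this repository's actual loss):

```python
import jax
import jax.numpy as jnp

# Toy jitted update step; the real agent's loss and parameters are more
# complex, this only illustrates the compile-once, call-many pattern.
@jax.jit
def update(params, x, y, lr=0.1):
    def loss_fn(p):
        pred = p["w"] * x + p["b"]
        return jnp.mean((pred - y) ** 2)
    loss, grads = jax.value_and_grad(loss_fn)(params)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g,
                                        params, grads)
    return new_params, loss

params = {"w": jnp.array(0.0), "b": jnp.array(0.0)}
x = jnp.array([1.0, 2.0])
y = jnp.array([2.0, 4.0])
# The first call compiles; subsequent calls reuse the compiled kernel.
for _ in range(100):
    params, loss = update(params, x, y)
```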
- Thanks to JAX, Flax, Optax, Distrax, and other high-quality deep learning libraries for their elegant code and comprehensive documentation.
- Appreciation goes to jaxrl, jaxrl2, and jaxrl_m for their outstanding contributions to applying JAX/Flax in reinforcement learning.
- Special thanks to the original author of Diffusion Policy for Offline RL for providing a robust algorithm whose results remained highly reproducible even after the framework migration.
If you encounter any issues while using this project, please feel free to submit an issue or a pull request to help improve it. Wishing you success in your reinforcement learning research!