This repository contains source code for experiments in the paper titled "RealAnt: An Open-Source Low-Cost Quadruped for Research in Real-World Reinforcement Learning" by Rinu Boney*, Jussi Sainio*, Mikko Kaivola, Arno Solin, and Juho Kannala. It consists of:
- Supporting software for reinforcement learning with the RealAnt robot
- PyTorch implementations of REDQ, TD3 and SAC algorithms
- MuJoCo and PyBullet environments of the RealAnt robot
RealAnt is a minimal and low-cost (~350€ in materials) physical version of the popular 'Ant' benchmark used in reinforcement learning. It can be built using easily available electronic components and a 3D printed body. Code for the RealAnt platform including the 3D models, microcontroller board firmware, Python interface and pose estimation is available here: https://github.com/OteRobotics/realant
Observation space (29-dim):
- x, y, and z velocities of the torso (3),
- z position of the torso (1),
- sin and cos values of Euler angles of the torso (6),
- velocities of Euler angles of the torso (3),
- angular positions of the joints (8), and
- angular velocities of the joints (8).
We rely on augmented reality (AR) tag tracking using ArUco tags for pose estimation.
Action space (8-dim): set-points for the angular positions of the robot joints.
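For reference, the sketch below splits a 29-dimensional observation into the named components listed above. It is only an illustration and not part of the codebase: the helper name and the exact slice ordering are our assumptions based on the list above.

```python
import numpy as np

def split_observation(obs):
    """Split a 29-dim RealAnt observation into named parts (ordering assumed as listed above)."""
    obs = np.asarray(obs)
    assert obs.shape == (29,)
    return {
        "torso_xyz_velocity": obs[0:3],    # x, y, z velocities of the torso
        "torso_z_position":   obs[3:4],    # z position of the torso
        "torso_euler_sincos": obs[4:10],   # sin and cos of the torso Euler angles
        "torso_euler_rates":  obs[10:13],  # velocities of the Euler angles
        "joint_positions":    obs[13:21],  # angular positions of the 8 joints
        "joint_velocities":   obs[21:29],  # angular velocities of the 8 joints
    }
```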
We consider three benchmark tasks:
- Stand upright.
- Turn 180 degrees.
- Walk forward as fast as possible.
The REDQ and TD3 algorithms successfully learn all three tasks. With REDQ, learning to stand takes around 5 minutes of experience, learning to turn takes around 5 minutes of experience, and learning to walk takes around 10 minutes of experience.
The training code is decoupled into a training client and a rollout server, which communicate using ZeroMQ. The training client (`train_client.py`) controls the whole learning process. At the beginning of each episode, it sends the latest policy weights to the rollout server (`rollout_server.py`). The rollout server loads the policy weights, collects the latest observations from the robot, and sends actions computed with the policy network back to the robot. After completing an episode, the rollout server sends the collected data back to the training client. The newly collected data is added to a replay buffer, and the agent is updated several times by sampling from this replay buffer.
The training client and rollout server can run on different machines. For example, data collection (with the rollout server) can be performed on a low-end computer, while training (with the training client) runs on a high-end computer.
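The exchange can be thought of as a simple request/reply pattern: weights go out, episode data comes back. The snippet below is a minimal sketch of that idea only; the port, message format, serialization, and helper names are illustrative assumptions, not the actual protocol implemented by `train_client.py` and `rollout_server.py`.

```python
# Minimal sketch of the weights-out / data-back exchange over ZeroMQ.
# Illustrative only: the real scripts define their own ports, message
# format, and serialization.
import pickle
import zmq

PORT = 5555  # assumed port, not necessarily the one used by the real scripts

def rollout_server_sketch():
    """Wait for policy weights, 'collect' an episode, and send it back."""
    sock = zmq.Context().socket(zmq.REP)
    sock.bind(f"tcp://*:{PORT}")
    while True:
        weights = pickle.loads(sock.recv())          # latest policy weights
        episode = {"obs": [], "act": [], "rew": []}  # stand-in for robot data
        sock.send(pickle.dumps(episode))

def train_client_sketch(n_episodes=3):
    """Send weights, receive episode data, and (conceptually) update the agent."""
    sock = zmq.Context().socket(zmq.REQ)
    sock.connect(f"tcp://localhost:{PORT}")
    for _ in range(n_episodes):
        weights = {"policy": None}                   # stand-in for network weights
        sock.send(pickle.dumps(weights))
        episode = pickle.loads(sock.recv())
        # In the real training client, the episode is added to a replay buffer
        # and the agent is updated by sampling from it.
```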
Set up the robot and run

```
python rollout_server.py
```

and

```
python train_client.py --n_episodes 250
```

for reinforcement learning with the robot.
The training client logs all data into a newly created experiment folder. After each episode, the robot position and orientation should be reset manually (if necessary). If training gets stuck due to a broken serial link or missing camera observations, restart the respective script(s) and the rollout server, then run `python train_client.py --resume <exp_folder>` to resume training.
We also provide plotting code:
- `visualize_episode.py` to visualize the observations, actions, and rewards during an episode.
- `visualize_returns.py` to plot the cumulative rewards of a training run.
The simulator environments of the RealAnt robot can be used as:

```python
import gym
import realant_sim

env = gym.make('RealAntMujoco-v0')  # MuJoCo version
env = gym.make('RealAntBullet-v0')  # PyBullet version
```
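As a quick sanity check, the environments can be driven with random actions using the classic Gym loop. The sketch below assumes the pre-gymnasium `reset()`/`step()` API (a 4-tuple return from `step`):

```python
import gym
import realant_sim

env = gym.make('RealAntMujoco-v0')
obs = env.reset()
for _ in range(100):
    action = env.action_space.sample()          # random 8-dim joint set-points
    obs, reward, done, info = env.step(action)  # classic Gym step API assumed
    if done:
        obs = env.reset()
env.close()
```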
The results reported in the paper can be reproduced by running:

```
python train.py
```
Optional arguments:
Parameter | Default | Description |
---|---|---|
--agent | redq | 'redq' or 'td3' or 'sac' |
--env | mujoco | 'mujoco' or 'pybullet' |
--task | walk | 'sleep' or 'turn' or 'walk' |
--seed | 1 | random seed |
--latency | 2 | number of steps by which observations are delayed, where 1 step = 0.05 s |
--xyz_noise_std | 0.01 | std of Gaussian noise added to body_xyz measurements |
--rpy_noise_std | 0.01 | std of Gaussian noise added to body_rpy measurements |
--min_obs_stack | 4 | number of past observations to be stacked |
--n_updates_mul | 8 | multiplier for the number of agent updates performed after each episode |
--critic_num_nets | 4 | number of critic networks for REDQ |
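For example, to train TD3 on the walking task in the MuJoCo environment with a fixed seed, the arguments above could be combined as:

```
python train.py --agent td3 --env mujoco --task walk --seed 1
```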
The PyBullet environment only supports the 'walk' task and does not support the latency or delay arguments.
This project is licensed under the terms of the MIT license. See LICENSE file for details.