Add imitation baselines for offline RL (thu-ml#566)
Add imitation baselines for offline RL and make the choice of env/task and D4RL dataset explicit. On expert datasets, IL easily outperforms; after reading the D4RL paper, I will rerun the experiments on medium data.
nuance1979 authored Mar 12, 2022
1 parent 74f430e commit 9cb74e6
Showing 11 changed files with 659 additions and 222 deletions.
58 changes: 36 additions & 22 deletions examples/offline/README.md
@@ -10,25 +10,30 @@ We provide implementations of the BCQ and CQL algorithms for continuous control.

### Train

-Tianshou provides an `offline_trainer` for offline reinforcement learning. You can parse d4rl datasets into a `ReplayBuffer` , and set it as the parameter `buffer` of `offline_trainer`. `offline_bcq.py` is an example of offline RL using the d4rl dataset.
+Tianshou provides an `offline_trainer` for offline reinforcement learning. You can parse d4rl datasets into a `ReplayBuffer`, and set it as the parameter `buffer` of `offline_trainer`. `d4rl_bcq.py` is an example of offline RL using the d4rl dataset.
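For illustration, here is a minimal sketch of parsing a d4rl dataset into a Tianshou `ReplayBuffer` before handing it to `offline_trainer`. It assumes the `d4rl` package and its `qlearning_dataset` helper; the loader actually used by the example scripts may differ.

```python
import d4rl  # imported for its side effect: registers the offline datasets with gym
import gym

from tianshou.data import Batch, ReplayBuffer


def load_d4rl_buffer(task: str) -> ReplayBuffer:
    """Convert a d4rl dataset (e.g. 'halfcheetah-expert-v2') into a ReplayBuffer."""
    dataset = d4rl.qlearning_dataset(gym.make(task))
    buf = ReplayBuffer(size=len(dataset["observations"]))
    for i in range(len(dataset["observations"])):
        buf.add(
            Batch(
                obs=dataset["observations"][i],
                act=dataset["actions"][i],
                rew=dataset["rewards"][i],
                done=dataset["terminals"][i],
                obs_next=dataset["next_observations"][i],
                info={},
            )
        )
    return buf
```

The returned buffer can then be passed as the `buffer` argument of `offline_trainer`.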

-To train an agent with BCQ algorithm:
+## Results

-```bash
-python offline_bcq.py --task halfcheetah-expert-v1
-```
+### IL (Imitation Learning, a.k.a. Behavior Cloning)

-After 1M steps:
+| Environment    | Dataset               | IL       | Parameters |
+| -------------- | --------------------- | -------- | ---------- |
+| HalfCheetah-v2 | halfcheetah-expert-v2 | 11355.31 | `python3 d4rl_il.py --task HalfCheetah-v2 --expert-data-task halfcheetah-expert-v2` |
+| HalfCheetah-v2 | halfcheetah-medium-v2 | 5098.16  | `python3 d4rl_il.py --task HalfCheetah-v2 --expert-data-task halfcheetah-medium-v2` |

-![halfcheetah-expert-v1_reward](results/bcq/halfcheetah-expert-v1_reward.png)
+### BCQ

-`halfcheetah-expert-v1` is a mujoco environment. The setting of hyperparameters are similar to the off-policy algorithms in mujoco environment.
+| Environment    | Dataset               | BCQ      | Parameters |
+| -------------- | --------------------- | -------- | ---------- |
+| HalfCheetah-v2 | halfcheetah-expert-v2 | 11509.95 | `python3 d4rl_bcq.py --task HalfCheetah-v2 --expert-data-task halfcheetah-expert-v2` |
+| HalfCheetah-v2 | halfcheetah-medium-v2 | 5147.43  | `python3 d4rl_bcq.py --task HalfCheetah-v2 --expert-data-task halfcheetah-medium-v2` |

-## Results
+### CQL

-| Environment           | BCQ             |
-| --------------------- | --------------- |
-| halfcheetah-expert-v1 | 10624.0 ± 181.4 |
+| Environment    | Dataset               | CQL     | Parameters |
+| -------------- | --------------------- | ------- | ---------- |
+| HalfCheetah-v2 | halfcheetah-expert-v2 | 2864.37 | `python3 d4rl_cql.py --task HalfCheetah-v2 --expert-data-task halfcheetah-expert-v2` |
+| HalfCheetah-v2 | halfcheetah-medium-v2 | 6505.41 | `python3 d4rl_cql.py --task HalfCheetah-v2 --expert-data-task halfcheetah-medium-v2` |

## Discrete control

@@ -42,47 +47,56 @@ To run the CQL algorithm on Atari, you need to do the following things:
- Generate buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that a 1M Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error);
- Train offline model: `python3 atari_{bcq,cql,crr}.py --task {your_task} --load-buffer-name expert.hdf5`.
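As a companion to the buffer step above, here is a minimal sketch of saving and re-loading such a buffer with Tianshou's `save_hdf5`/`load_hdf5` helpers (the buffer sizes below are illustrative):

```python
from tianshou.data import VectorReplayBuffer

# buffer filled by the behavioral policy, e.g. atari_qrdqn.py --watch ... --save-buffer-name expert.hdf5
buffer = VectorReplayBuffer(total_size=1_000_000, buffer_num=10)
# ... collect transitions with a Collector here ...
buffer.save_hdf5("expert.hdf5")  # HDF5 scales to 1M Atari transitions; .pkl does not

# later, in the offline script (atari_{bcq,cql,crr}.py):
buffer = VectorReplayBuffer.load_hdf5("expert.hdf5")
```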

+### IL
+
+We test our IL implementation on two example tasks (unlike the authors' version, we use v4 instead of v0; one epoch corresponds to 10k gradient steps):
+
+| Task                   | Online QRDQN | Behavioral | IL                                | parameters |
+| ---------------------- | ------------ | ---------- | --------------------------------- | ---------- |
+| PongNoFrameskip-v4     | 20.5         | 6.8        | 20.0 (epoch 5)                    | `python3 atari_il.py --task PongNoFrameskip-v4 --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
+| BreakoutNoFrameskip-v4 | 394.3        | 46.9       | 121.9 (epoch 12, could be higher) | `python3 atari_il.py --task BreakoutNoFrameskip-v4 --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12` |

### BCQ

We test our BCQ implementation on two example tasks (unlike the authors' version, we use v4 instead of v0; one epoch corresponds to 10k gradient steps):

| Task | Online QRDQN | Behavioral | BCQ | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
-| PongNoFrameskip-v4     | 20.5         | 6.8        | 20.1 (epoch 5)                   | `python3 atari_bcq.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
-| BreakoutNoFrameskip-v4 | 394.3        | 46.9       | 64.6 (epoch 12, could be higher) | `python3 atari_bcq.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12` |
+| PongNoFrameskip-v4     | 20.5         | 6.8        | 20.1 (epoch 5)                   | `python3 atari_bcq.py --task PongNoFrameskip-v4 --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
+| BreakoutNoFrameskip-v4 | 394.3        | 46.9       | 64.6 (epoch 12, could be higher) | `python3 atari_bcq.py --task BreakoutNoFrameskip-v4 --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12` |

### CQL

We test our CQL implementation on two example tasks (unlike the authors' version, we use v4 instead of v0; one epoch corresponds to 10k gradient steps):

| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
-| PongNoFrameskip-v4     | 20.5         | 6.8        | 20.4 (epoch 5)   | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
-| BreakoutNoFrameskip-v4 | 394.3        | 46.9       | 129.4 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |
+| PongNoFrameskip-v4     | 20.5         | 6.8        | 20.4 (epoch 5)   | `python3 atari_cql.py --task PongNoFrameskip-v4 --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
+| BreakoutNoFrameskip-v4 | 394.3        | 46.9       | 129.4 (epoch 12) | `python3 atari_cql.py --task BreakoutNoFrameskip-v4 --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |

We reduce the size of the offline data to 10% and 1% of the above and get:

Buffer size 100000:

| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
-| PongNoFrameskip-v4     | 20.5         | 5.8        | 21 (epoch 5)    | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 5` |
-| BreakoutNoFrameskip-v4 | 394.3        | 41.4       | 40.8 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 12 --min-q-weight 20` |
+| PongNoFrameskip-v4     | 20.5         | 5.8        | 21 (epoch 5)    | `python3 atari_cql.py --task PongNoFrameskip-v4 --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 5` |
+| BreakoutNoFrameskip-v4 | 394.3        | 41.4       | 40.8 (epoch 12) | `python3 atari_cql.py --task BreakoutNoFrameskip-v4 --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 12 --min-q-weight 20` |

Buffer size 10000:

| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
-| PongNoFrameskip-v4     | 20.5         | nan        | 1.8 (epoch 5)   | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 5 --min-q-weight 1` |
-| BreakoutNoFrameskip-v4 | 394.3        | 31.7       | 22.5 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 12 --min-q-weight 10` |
+| PongNoFrameskip-v4     | 20.5         | nan        | 1.8 (epoch 5)   | `python3 atari_cql.py --task PongNoFrameskip-v4 --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 5 --min-q-weight 1` |
+| BreakoutNoFrameskip-v4 | 394.3        | 31.7       | 22.5 (epoch 12) | `python3 atari_cql.py --task BreakoutNoFrameskip-v4 --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 12 --min-q-weight 10` |

### CRR

We test our CRR implementation on two example tasks (unlike the authors' version, we use v4 instead of v0; one epoch corresponds to 10k gradient steps):

| Task | Online QRDQN | Behavioral | CRR | CRR w/ CQL | parameters |
| ---------------------- | ---------- | ---------- | ---------------- | ----------------- | ------------------------------------------------------------ |
-| PongNoFrameskip-v4     | 20.5         | 6.8        | -21 (epoch 5)    | 17.7 (epoch 5)    | `python3 atari_crr.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
-| BreakoutNoFrameskip-v4 | 394.3        | 46.9       | 23.3 (epoch 12)  | 76.9 (epoch 12)   | `python3 atari_crr.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |
+| PongNoFrameskip-v4     | 20.5         | 6.8        | -21 (epoch 5)    | 17.7 (epoch 5)    | `python3 atari_crr.py --task PongNoFrameskip-v4 --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
+| BreakoutNoFrameskip-v4 | 394.3        | 46.9       | 23.3 (epoch 12)  | 76.9 (epoch 12)   | `python3 atari_crr.py --task BreakoutNoFrameskip-v4 --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |

Note that CRR itself does not work well on Atari tasks, but adding the CQL loss/regularizer helps.
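For reference, the term that `--min-q-weight` scales can be sketched as the standard discrete CQL regularizer below (an illustration, not the exact code of `atari_cql.py`/`atari_crr.py`; `q` is assumed to be a `[batch, n_actions]` tensor of Q-values and `act` holds the dataset actions):

```python
import torch


def cql_regularizer(q: torch.Tensor, act: torch.Tensor) -> torch.Tensor:
    """Push down logsumexp_a Q(s, a) while pushing up Q(s, a_data)."""
    logsumexp_q = torch.logsumexp(q, dim=1)            # [batch]
    data_q = q.gather(1, act.unsqueeze(1)).squeeze(1)  # [batch]
    return (logsumexp_q - data_q).mean()


# inside the critic update (sketch):
# loss = td_loss + min_q_weight * cql_regularizer(q, batch.act)
```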
74 changes: 43 additions & 31 deletions examples/offline/atari_bcq.py
@@ -1,3 +1,5 @@
+#!/usr/bin/env python3

import argparse
import datetime
import os
@@ -9,12 +11,11 @@
from torch.utils.tensorboard import SummaryWriter

from examples.atari.atari_network import DQN
-from examples.atari.atari_wrapper import wrap_deepmind
+from examples.atari.atari_wrapper import make_atari_env
from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import ShmemVectorEnv
from tianshou.policy import DiscreteBCQPolicy
from tianshou.trainer import offline_trainer
-from tianshou.utils import TensorboardLogger
+from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.common import ActorCritic
from tianshou.utils.net.discrete import Actor

@@ -33,12 +34,21 @@ def get_args():
    parser.add_argument("--epoch", type=int, default=100)
    parser.add_argument("--update-per-epoch", type=int, default=10000)
    parser.add_argument("--batch-size", type=int, default=32)
-    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
+    parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512])
    parser.add_argument("--test-num", type=int, default=10)
-    parser.add_argument('--frames-stack', type=int, default=4)
+    parser.add_argument("--frames-stack", type=int, default=4)
+    parser.add_argument("--scale-obs", type=int, default=0)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=0.)
    parser.add_argument("--resume-path", type=str, default=None)
+    parser.add_argument("--resume-id", type=str, default=None)
+    parser.add_argument(
+        "--logger",
+        type=str,
+        default="tensorboard",
+        choices=["tensorboard", "wandb"],
+    )
+    parser.add_argument("--wandb-project", type=str, default="offline_atari.benchmark")
    parser.add_argument(
        "--watch",
        default=False,
@@ -56,35 +66,24 @@ def get_args():
    return args


-def make_atari_env(args):
-    return wrap_deepmind(args.task, frame_stack=args.frames_stack)
-
-
-def make_atari_env_watch(args):
-    return wrap_deepmind(
+def test_discrete_bcq(args=get_args()):
+    # envs
+    env, _, test_envs = make_atari_env(
        args.task,
+        args.seed,
+        1,
+        args.test_num,
+        scale=args.scale_obs,
        frame_stack=args.frames_stack,
-        episode_life=False,
-        clip_rewards=False
    )
-
-
-def test_discrete_bcq(args=get_args()):
-    # envs
-    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
-    # make environments
-    test_envs = ShmemVectorEnv(
-        [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
-    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
-    test_envs.seed(args.seed)
    # model
    feature_net = DQN(
        *args.state_shape, args.action_shape, device=args.device, features_only=True
@@ -118,9 +117,9 @@ def test_discrete_bcq(args=get_args()):
    # buffer
    assert os.path.exists(args.load_buffer_name), \
        "Please run atari_dqn.py first to get expert's data buffer."
-    if args.load_buffer_name.endswith('.pkl'):
+    if args.load_buffer_name.endswith(".pkl"):
        buffer = pickle.load(open(args.load_buffer_name, "rb"))
-    elif args.load_buffer_name.endswith('.hdf5'):
+    elif args.load_buffer_name.endswith(".hdf5"):
        buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
    else:
        print(f"Unknown buffer format: {args.load_buffer_name}")
@@ -130,16 +129,29 @@ def test_discrete_bcq(args=get_args()):
    test_collector = Collector(policy, test_envs, exploration_noise=True)

    # log
-    log_path = os.path.join(
-        args.logdir, args.task, 'bcq',
-        f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}'
-    )
+    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
+    args.algo_name = "bcq"
+    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
+    log_path = os.path.join(args.logdir, log_name)
+
+    # logger
+    if args.logger == "wandb":
+        logger = WandbLogger(
+            save_interval=1,
+            name=log_name.replace(os.path.sep, "__"),
+            run_id=args.resume_id,
+            config=args,
+            project=args.wandb_project,
+        )
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
-    logger = TensorboardLogger(writer, update_interval=args.log_interval)
+    if args.logger == "tensorboard":
+        logger = TensorboardLogger(writer)
+    else:  # wandb
+        logger.load(writer)

    def save_fn(policy):
-        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards):
        return False