Skip to content

Commit

Permalink
Upon exit, replay buffer persists to disk, and can be optionally pre-p…
Browse files Browse the repository at this point in the history
…opulated from disk before training begins.
  • Loading branch information
fred-drake committed Apr 25, 2020
1 parent f5dd3d2 commit 04e63ea
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
22 changes: 20 additions & 2 deletions muzero.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import importlib
import os
import time
import pickle

import numpy
import ray
Expand Down Expand Up @@ -52,7 +53,7 @@ def __init__(self, game_name):
# Weights used to initialize workers
self.muzero_weights = models.MuZeroNetwork(self.config).get_weights()

def train(self):
def train(self, replay_buffer_path=None):
ray.init()
os.makedirs(self.config.results_path, exist_ok=True)

Expand All @@ -64,6 +65,17 @@ def train(self):
copy.deepcopy(self.muzero_weights), self.game_name, self.config,
)
replay_buffer_worker = replay_buffer.ReplayBuffer.remote(self.config)

# Pre-load buffer if pulling from persistent storage.
# BUG FIX: the original called pickle.load(open(path, 'rb')) but the
# variable in scope is `replay_buffer_path`, so it raised NameError.
if replay_buffer_path is not None:
    if os.path.exists(replay_buffer_path):
        # Context manager guarantees the file handle is closed even if
        # unpickling fails.
        with open(replay_buffer_path, 'rb') as buffer_file:
            buffer = pickle.load(buffer_file)
        # `buffer` maps game ids to game histories; feed each stored
        # history into the remote replay-buffer actor.
        for game_history in buffer.values():
            replay_buffer_worker.save_game.remote(game_history)
        print("Loaded {} games from replay buffer.".format(len(buffer)))
    else:
        print("Warning: Replay buffer path '{}' doesn't exist. Using empty buffer.".format(replay_buffer_path))

self_play_workers = [
self_play.SelfPlay.remote(
copy.deepcopy(self.muzero_weights),
Expand All @@ -88,6 +100,11 @@ def train(self):
self._logging_loop(shared_storage_worker, replay_buffer_worker)

self.muzero_weights = ray.get(shared_storage_worker.get_weights.remote())

# Persist replay buffer to disk
print("\n\nPersisting replay buffer games to disk...")
ray.get(replay_buffer_worker.persist_buffer.remote())

# End running actors
ray.shutdown()

Expand Down Expand Up @@ -260,7 +277,8 @@ def load_model(self, path=None):
choice = input("Invalid input, enter a number listed above: ")
choice = int(choice)
if choice == 0:
muzero.train()
path = input("Enter path for existing replay buffer, or ENTER if none: ")
muzero.train(path if len(path.strip()) > 0 else None)
elif choice == 1:
path = input("Enter a path to the model.weights: ")
while not os.path.isfile(path):
Expand Down
5 changes: 5 additions & 0 deletions replay_buffer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import collections
import copy
import pickle
import os

import numpy
import ray
Expand Down Expand Up @@ -52,6 +54,9 @@ def save_game(self, game_history):
self.total_samples -= len(self.buffer[del_id].priorities)
del self.buffer[del_id]

def persist_buffer(self):
    """Serialize the in-memory game buffer to disk.

    Pickles ``self.buffer`` to ``<results_path>/replay_buffer.pkl`` so a
    later run can pre-populate its replay buffer from this file.

    FIX: the original passed a bare ``open(...)`` handle to
    ``pickle.dump`` and never closed it; a ``with`` block guarantees the
    file is flushed and closed even if pickling raises.
    """
    target = os.path.join(self.config.results_path, 'replay_buffer.pkl')
    with open(target, 'wb') as buffer_file:
        pickle.dump(self.buffer, buffer_file)

def get_self_play_count(self):
    """Return the current value of ``self.self_play_count``."""
    count = self.self_play_count
    return count

Expand Down

0 comments on commit 04e63ea

Please sign in to comment.