Skip to content

Commit

Permalink
fix tqdm
Browse files Browse the repository at this point in the history
  • Loading branch information
juliusfrost committed Sep 19, 2021
1 parent 4fc336a commit 24b18da
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
6 changes: 4 additions & 2 deletions minerl_rllib/envs/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
from minerl_rllib.generate_kmeans import main as generate_kmeans


class LazyMineRLEnv(gym.Env):
class LazyMineRLEnv:
def __init__(self, env_spec, **kwargs):
self._kwargs = kwargs
self.env_spec: EnvSpec = env_spec
self._env = None
self.observation_space = self.env_spec.observation_space
self.action_space = self.env_spec.action_space
super().__init__()

def init_env(self):
Expand All @@ -30,7 +32,7 @@ def init_env(self):
def reset(self, **kwargs):
if self._env is None:
self.init_env()
return self._env.reset()
return self._env.reset(**kwargs)

def step(self, action):
return self._env.step(action)
Expand Down
12 changes: 10 additions & 2 deletions minerl_rllib/generate_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,20 @@
from sklearn.cluster import KMeans
from tqdm import tqdm

from minerl_rllib.envs.utils import patch_data_pipeline

parser = argparse.ArgumentParser()

parser.add_argument("--env", default=None)
parser.add_argument("--num-actions", type=int, default=32)
parser.add_argument("--data-dir", default=os.getenv("MINERL_DATA_ROOT", "data"))
parser.add_argument("--overwrite", action="store_true")
parser.add_argument("--use-tqdm", action="store_true")


def main(args=None):
args = parser.parse_args(args=args)
patch_data_pipeline()
if args.env is None:
env_list = []
for env_name in os.listdir(args.data_dir):
Expand All @@ -39,7 +43,10 @@ def main(args=None):

data = minerl.data.make(env_name)
actions = []
for trajectory_name in tqdm(list(data.get_trajectory_names())):
iter = list(data.get_trajectory_names())
if args.use_tqdm:
iter = tqdm(iter)
for trajectory_name in iter:
try:
for _, action, _, _, _ in data.load_data(trajectory_name):
actions.append(action["vector"])
Expand All @@ -53,7 +60,8 @@ def main(args=None):
print(kmeans)
np.save(file, kmeans.cluster_centers_)

if len(env_list) == 1:
if args.env is not None:
assert isinstance(return_path, str)
return return_path


Expand Down

0 comments on commit 24b18da

Please sign in to comment.