From d1bb7abded13fea4d1d222aa8695a8ebdd9b042e Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 13 Feb 2023 21:01:37 +0000 Subject: [PATCH 1/5] init --- torchrl/envs/libs/gym.py | 8 +-- torchrl/envs/libs/openml.py | 119 ++++++++++++++++++++++++++++++++++++ 2 files changed, 123 insertions(+), 4 deletions(-) create mode 100644 torchrl/envs/libs/openml.py diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 21ed74f0fe5..4f9364a58b8 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -20,11 +20,11 @@ UnboundedContinuousTensorSpec, ) -from ..._utils import implement_for -from ...data.utils import numpy_to_torch_dtype_dict +from torchrl._utils import implement_for +from torchrl.data.utils import numpy_to_torch_dtype_dict -from ..gym_like import default_info_dict_reader, GymLikeEnv -from ..utils import _classproperty +from torchrl.envs.gym_like import default_info_dict_reader, GymLikeEnv +from torchrl.envs.utils import _classproperty IMPORT_ERROR = None _has_gym = False diff --git a/torchrl/envs/libs/openml.py b/torchrl/envs/libs/openml.py new file mode 100644 index 00000000000..f7e4e498cb3 --- /dev/null +++ b/torchrl/envs/libs/openml.py @@ -0,0 +1,119 @@ +import numpy as np +from tensordict.tensordict import TensorDict, TensorDictBase + +from torchrl.data import UnboundedContinuousTensorSpec, CompositeSpec, UnboundedDiscreteTensorSpec +from torchrl.envs import EnvBase +from sklearn.datasets import fetch_openml +from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, StandardScaler, LabelEncoder +import torch + +X, y = fetch_openml('adult', version=1, return_X_y=True) + +def _get_data(dataset_name): + if dataset_name in ['adult_num', 'adult_onehot']: + X, y = fetch_openml('adult', version=1, return_X_y=True) + is_NaN = X.isna() + row_has_NaN = is_NaN.any(axis=1) + X = X[~row_has_NaN] + # y = y[~row_has_NaN] + y = X["occupation"] + X = X.drop(["occupation"],axis=1) + cat_ix = X.select_dtypes(include=['category']).columns + num_ix = X.select_dtypes(include=['int64', 'float64']).columns + encoder = LabelEncoder() + # now apply the transformation to all the columns: + for col in cat_ix: + X[col] = encoder.fit_transform(X[col]) + y = encoder.fit_transform(y) + if dataset_name == 'adult_onehot': + cat_features = OneHotEncoder(sparse=False).fit_transform(X[cat_ix]) + num_features = StandardScaler().fit_transform(X[num_ix]) + X = np.concatenate((num_features, cat_features), axis=1) + else: + X = StandardScaler().fit_transform(X) + elif dataset_name in ['mushroom_num', 'mushroom_onehot']: + X, y = fetch_openml('mushroom', version=1, return_X_y=True) + encoder = LabelEncoder() + # now apply the transformation to all the columns: + for col in X.columns: + X[col] = encoder.fit_transform(X[col]) + # X = X.drop(["veil-type"],axis=1) + y = encoder.fit_transform(y) + if dataset_name == 'mushroom_onehot': + X = OneHotEncoder(sparse=False).fit_transform(X) + else: + X = StandardScaler().fit_transform(X) + elif dataset_name == 'covertype': + # https://www.openml.org/d/150 + # there are some 0/1 features -> consider just numeric + X, y = fetch_openml('covertype', version=3, return_X_y=True) + X = StandardScaler().fit_transform(X) + y = LabelEncoder().fit_transform(y) + elif dataset_name == 'shuttle': + # https://www.openml.org/d/40685 + # all numeric, no missing values + X, y = fetch_openml('shuttle', version=1, return_X_y=True) + X = StandardScaler().fit_transform(X) + y = LabelEncoder().fit_transform(y) + elif dataset_name == 'magic': + # https://www.openml.org/d/1120 + # all numeric, no missing values + X, y = fetch_openml('MagicTelescope', version=1, return_X_y=True) + X = StandardScaler().fit_transform(X) + y = LabelEncoder().fit_transform(y) + else: + raise RuntimeError('Dataset does not exist') + return TensorDict({"X": X, "y": y}, X.shape[:1]) + +def make_composite_from_td(td): + # custom funtion to convert a tensordict in a similar spec structure + # of unbounded values. + composite = CompositeSpec( + { + key: make_composite_from_td(tensor) + if isinstance(tensor, TensorDictBase) + else UnboundedContinuousTensorSpec( + dtype=tensor.dtype, device=tensor.device, shape=tensor.shape + ) if tensor.dtype in (torch.float16, torch.float32, torch.float64) else + UnboundedDiscreteTensorSpec( + dtype=tensor.dtype, device=tensor.device, shape=tensor.shape + ) + for key, tensor in td.items() + }, + shape=td.shape, + ) + return composite + + +class OpenMLEnv(EnvBase): + def __init__(self, dataset_name, device="cpu", batch_size=None): + if batch_size is None: + batch_size = [1] + self.dataset_name = dataset_name + self._data = _get_data(dataset_name) + super().__init__(device=device, batch_size=batch_size) + self.observation_spec = make_composite_from_td(self._data[:self.batch_size.numel()]) + self.action_spec = self.observation_spec["y"].clone() + self.reward_spec = UnboundedContinuousTensorSpec(shape=(*self.batch_size, 1)) + + def _reset(self, tensordict): + r_id = torch.randint(self._data.shape[0], (self.batch_size.numel(),)) + data = self._data[r_id] + return data + + def _step( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + action = tensordict["action"] + reward = (action == tensordict["y"]).float().unsqueeze(-1) + done = torch.ones_like(reward, dtype=torch.bool) + return TensorDict({ + "done": done, + "reward": reward, + "X": tensordict["X"], + "y": tensordict["y"], + }, self.batch_size) + + def _set_seed(self, seed): + self.rng = torch.random.manual_seed(seed) From fbb929253790dd1d6837c13425c023475bfe9bf0 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 17 Mar 2023 13:55:31 +0000 Subject: [PATCH 2/5] amend --- .circleci/config.yml | 59 +++ .../linux_examples/scripts/run_test.sh | 1 + .../scripts_sklearn/environment.yml | 20 + .../linux_libs/scripts_sklearn/install.sh | 51 +++ .../scripts_sklearn/post_process.sh | 6 + .../scripts_sklearn/run-clang-format.py | 356 ++++++++++++++++++ .../linux_libs/scripts_sklearn/run_test.sh | 27 ++ .../linux_libs/scripts_sklearn/setup_env.sh | 50 +++ docs/source/reference/data.rst | 4 +- docs/source/reference/envs.rst | 1 + examples/bandits/dqn.py | 117 ++++++ test/test_libs.py | 71 +++- torchrl/collectors/collectors.py | 5 +- torchrl/data/datasets/d4rl.py | 36 +- torchrl/data/datasets/openml.py | 150 ++++++++ torchrl/data/replay_buffers/replay_buffers.py | 9 + torchrl/envs/libs/gym.py | 4 +- torchrl/envs/libs/openml.py | 172 +++++---- torchrl/modules/models/models.py | 1 + .../modules/tensordict_module/exploration.py | 4 +- torchrl/objectives/dqn.py | 6 +- torchrl/objectives/utils.py | 2 + 22 files changed, 1052 insertions(+), 100 deletions(-) create mode 100644 .circleci/unittest/linux_libs/scripts_sklearn/environment.yml create mode 100755 .circleci/unittest/linux_libs/scripts_sklearn/install.sh create mode 100755 .circleci/unittest/linux_libs/scripts_sklearn/post_process.sh create mode 100755 .circleci/unittest/linux_libs/scripts_sklearn/run-clang-format.py create mode 100755 .circleci/unittest/linux_libs/scripts_sklearn/run_test.sh create mode 100755 .circleci/unittest/linux_libs/scripts_sklearn/setup_env.sh create mode 100644 examples/bandits/dqn.py create mode 100644 torchrl/data/datasets/openml.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 6a254816111..8515c743663 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -529,6 +529,61 @@ jobs: - store_test_results: path: test-results + unittest_linux_sklearn_gpu: + <<: *binary_common + machine: + image: ubuntu-2004-cuda-11.4:202110-01 + resource_class: gpu.nvidia.medium + environment: + image_name: "nvidia/cudagl:11.4.0-base" + TAR_OPTIONS: --no-same-owner + PYTHON_VERSION: << parameters.python_version >> + CU_VERSION: << parameters.cu_version >> + + steps: + - checkout + - designate_upload_channel + - run: + name: Generate cache key + # This will refresh cache on Sundays, nightly build should generate new cache. + command: echo "$(date +"%Y-%U")" > .circleci-weekly + - restore_cache: + keys: + - env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_libs/scripts_sklearn/environment.yml" }}-{{ checksum ".circleci-weekly" }} + - run: + name: Setup + command: docker run -e PYTHON_VERSION -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux_libs/scripts_sklearn/setup_env.sh + - save_cache: + + key: env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_libs/scripts_sklearn/environment.yml" }}-{{ checksum ".circleci-weekly" }} + + paths: + - conda + - env + - run: + # Here we create an envlist file that contains some env variables that we want the docker container to be aware of. + # Normally, the CIRCLECI variable is set and available on all CI workflows: https://circleci.com/docs/2.0/env-vars/#built-in-environment-variables. + # They're available in all the other workflows (OSX and Windows). + # But here, we're running the unittest_linux_gpu workflows in a docker container, where those variables aren't accessible. + # So instead we dump the variables we need in env.list and we pass that file when invoking "docker run". + name: export CIRCLECI env var + command: echo "CIRCLECI=true" >> ./env.list + - run: + name: Install torchrl + command: docker run -e PYTHON_VERSION -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux_libs/scripts_sklearn/install.sh + - run: + name: Run tests + command: docker run --env-file ./env.list -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux_libs/scripts_sklearn/run_test.sh + - run: + name: Codecov upload + command: | + bash <(curl -s https://codecov.io/bash) -Z -F sklearn-gpu + - run: + name: Post Process + command: docker run -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux_libs/scripts_sklearn/post_process.sh + - store_test_results: + path: test-results + unittest_linux_jumanji_gpu: <<: *binary_common @@ -1283,6 +1338,10 @@ workflows: # cu_version: cu117 # name: unittest_linux_d4rl_gpu_py3.8 # python_version: '3.8' + - unittest_linux_sklearn_gpu: + cu_version: cu117 + name: unittest_linux_sklearn_gpu_py3.8 + python_version: '3.8' - unittest_linux_jumanji_gpu: cu_version: cu117 name: unittest_linux_jumanji_gpu_py3.8 diff --git a/.circleci/unittest/linux_examples/scripts/run_test.sh b/.circleci/unittest/linux_examples/scripts/run_test.sh index f8fd2154010..ef16b63c636 100755 --- a/.circleci/unittest/linux_examples/scripts/run_test.sh +++ b/.circleci/unittest/linux_examples/scripts/run_test.sh @@ -231,6 +231,7 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/iql/iql_onli mode=offline \ collector_devices=cuda:0 +python .circleci/unittest/helpers/coverage_run_parallel.py examples/bandits/dqn.py --n_steps=100 coverage combine coverage xml -i diff --git a/.circleci/unittest/linux_libs/scripts_sklearn/environment.yml b/.circleci/unittest/linux_libs/scripts_sklearn/environment.yml new file mode 100644 index 00000000000..5cf208f00c1 --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_sklearn/environment.yml @@ -0,0 +1,20 @@ +channels: + - pytorch + - defaults +dependencies: + - pip + - pip: + - hypothesis + - future + - cloudpickle + - pytest + - pytest-cov + - pytest-mock + - pytest-instafail + - pytest-rerunfailures + - expecttest + - pyyaml + - scipy + - hydra-core + - scikit-learn + - pandas diff --git a/.circleci/unittest/linux_libs/scripts_sklearn/install.sh b/.circleci/unittest/linux_libs/scripts_sklearn/install.sh new file mode 100755 index 00000000000..437900b3323 --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_sklearn/install.sh @@ -0,0 +1,51 @@ +#!/usr/bin/env bash + +unset PYTORCH_VERSION +# For unittest, nightly PyTorch is used as the following section, +# so no need to set PYTORCH_VERSION. +# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. +apt-get update && apt-get install -y git wget gcc g++ +#apt-get update && apt-get install -y git wget freeglut3 freeglut3-dev + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +if [ "${CU_VERSION:-}" == cpu ] ; then + version="cpu" +else + if [[ ${#CU_VERSION} -eq 4 ]]; then + CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" + elif [[ ${#CU_VERSION} -eq 5 ]]; then + CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" + fi + echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)" + version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" +fi + + +# submodules +git submodule sync && git submodule update --init --recursive + +printf "Installing PyTorch with %s\n" "${CU_VERSION}" +if [ "${CU_VERSION:-}" == cpu ] ; then + # conda install -y pytorch torchvision cpuonly -c pytorch-nightly + # use pip to install pytorch as conda can frequently pick older release +# conda install -y pytorch cpuonly -c pytorch-nightly + pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall +else + pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu116 --force-reinstall +fi + +# install tensordict +pip install git+https://github.com/pytorch-labs/tensordict.git + +# smoke test +python -c "import functorch;import tensordict" + +printf "* Installing torchrl\n" +pip3 install -e . + +# smoke test +python -c "import torchrl" diff --git a/.circleci/unittest/linux_libs/scripts_sklearn/post_process.sh b/.circleci/unittest/linux_libs/scripts_sklearn/post_process.sh new file mode 100755 index 00000000000..e97bf2a7b1b --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_sklearn/post_process.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env diff --git a/.circleci/unittest/linux_libs/scripts_sklearn/run-clang-format.py b/.circleci/unittest/linux_libs/scripts_sklearn/run-clang-format.py new file mode 100755 index 00000000000..5783a885d86 --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_sklearn/run-clang-format.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python +""" +MIT License + +Copyright (c) 2017 Guillaume Papin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +A wrapper script around clang-format, suitable for linting multiple files +and to use for continuous integration. + +This is an alternative API for the clang-format command line. +It runs over multiple files and directories in parallel. +A diff output is produced and a sensible exit code is returned. + +""" + +import argparse +import difflib +import fnmatch +import multiprocessing +import os +import signal +import subprocess +import sys +import traceback +from functools import partial + +try: + from subprocess import DEVNULL # py3k +except ImportError: + DEVNULL = open(os.devnull, "wb") + + +DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu" + + +class ExitStatus: + SUCCESS = 0 + DIFF = 1 + TROUBLE = 2 + + +def list_files(files, recursive=False, extensions=None, exclude=None): + if extensions is None: + extensions = [] + if exclude is None: + exclude = [] + + out = [] + for file in files: + if recursive and os.path.isdir(file): + for dirpath, dnames, fnames in os.walk(file): + fpaths = [os.path.join(dirpath, fname) for fname in fnames] + for pattern in exclude: + # os.walk() supports trimming down the dnames list + # by modifying it in-place, + # to avoid unnecessary directory listings. + dnames[:] = [ + x + for x in dnames + if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern) + ] + fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)] + for f in fpaths: + ext = os.path.splitext(f)[1][1:] + if ext in extensions: + out.append(f) + else: + out.append(file) + return out + + +def make_diff(file, original, reformatted): + return list( + difflib.unified_diff( + original, + reformatted, + fromfile=f"{file}\t(original)", + tofile=f"{file}\t(reformatted)", + n=3, + ) + ) + + +class DiffError(Exception): + def __init__(self, message, errs=None): + super().__init__(message) + self.errs = errs or [] + + +class UnexpectedError(Exception): + def __init__(self, message, exc=None): + super().__init__(message) + self.formatted_traceback = traceback.format_exc() + self.exc = exc + + +def run_clang_format_diff_wrapper(args, file): + try: + ret = run_clang_format_diff(args, file) + return ret + except DiffError: + raise + except Exception as e: + raise UnexpectedError(f"{file}: {e.__class__.__name__}: {e}", e) + + +def run_clang_format_diff(args, file): + try: + with open(file, encoding="utf-8") as f: + original = f.readlines() + except OSError as exc: + raise DiffError(str(exc)) + invocation = [args.clang_format_executable, file] + + # Use of utf-8 to decode the process output. + # + # Hopefully, this is the correct thing to do. + # + # It's done due to the following assumptions (which may be incorrect): + # - clang-format will returns the bytes read from the files as-is, + # without conversion, and it is already assumed that the files use utf-8. + # - if the diagnostics were internationalized, they would use utf-8: + # > Adding Translations to Clang + # > + # > Not possible yet! + # > Diagnostic strings should be written in UTF-8, + # > the client can translate to the relevant code page if needed. + # > Each translation completely replaces the format string + # > for the diagnostic. + # > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation + + try: + proc = subprocess.Popen( + invocation, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + encoding="utf-8", + ) + except OSError as exc: + raise DiffError( + f"Command '{subprocess.list2cmdline(invocation)}' failed to start: {exc}" + ) + proc_stdout = proc.stdout + proc_stderr = proc.stderr + + # hopefully the stderr pipe won't get full and block the process + outs = list(proc_stdout.readlines()) + errs = list(proc_stderr.readlines()) + proc.wait() + if proc.returncode: + raise DiffError( + "Command '{}' returned non-zero exit status {}".format( + subprocess.list2cmdline(invocation), proc.returncode + ), + errs, + ) + return make_diff(file, original, outs), errs + + +def bold_red(s): + return "\x1b[1m\x1b[31m" + s + "\x1b[0m" + + +def colorize(diff_lines): + def bold(s): + return "\x1b[1m" + s + "\x1b[0m" + + def cyan(s): + return "\x1b[36m" + s + "\x1b[0m" + + def green(s): + return "\x1b[32m" + s + "\x1b[0m" + + def red(s): + return "\x1b[31m" + s + "\x1b[0m" + + for line in diff_lines: + if line[:4] in ["--- ", "+++ "]: + yield bold(line) + elif line.startswith("@@ "): + yield cyan(line) + elif line.startswith("+"): + yield green(line) + elif line.startswith("-"): + yield red(line) + else: + yield line + + +def print_diff(diff_lines, use_color): + if use_color: + diff_lines = colorize(diff_lines) + sys.stdout.writelines(diff_lines) + + +def print_trouble(prog, message, use_colors): + error_text = "error:" + if use_colors: + error_text = bold_red(error_text) + print(f"{prog}: {error_text} {message}", file=sys.stderr) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--clang-format-executable", + metavar="EXECUTABLE", + help="path to the clang-format executable", + default="clang-format", + ) + parser.add_argument( + "--extensions", + help=f"comma separated list of file extensions (default: {DEFAULT_EXTENSIONS})", + default=DEFAULT_EXTENSIONS, + ) + parser.add_argument( + "-r", + "--recursive", + action="store_true", + help="run recursively over directories", + ) + parser.add_argument("files", metavar="file", nargs="+") + parser.add_argument("-q", "--quiet", action="store_true") + parser.add_argument( + "-j", + metavar="N", + type=int, + default=0, + help="run N clang-format jobs in parallel (default number of cpus + 1)", + ) + parser.add_argument( + "--color", + default="auto", + choices=["auto", "always", "never"], + help="show colored diff (default: auto)", + ) + parser.add_argument( + "-e", + "--exclude", + metavar="PATTERN", + action="append", + default=[], + help="exclude paths matching the given glob-like pattern(s) from recursive search", + ) + + args = parser.parse_args() + + # use default signal handling, like diff return SIGINT value on ^C + # https://bugs.python.org/issue14229#msg156446 + signal.signal(signal.SIGINT, signal.SIG_DFL) + try: + signal.SIGPIPE + except AttributeError: + # compatibility, SIGPIPE does not exist on Windows + pass + else: + signal.signal(signal.SIGPIPE, signal.SIG_DFL) + + colored_stdout = False + colored_stderr = False + if args.color == "always": + colored_stdout = True + colored_stderr = True + elif args.color == "auto": + colored_stdout = sys.stdout.isatty() + colored_stderr = sys.stderr.isatty() + + version_invocation = [args.clang_format_executable, "--version"] + try: + subprocess.check_call(version_invocation, stdout=DEVNULL) + except subprocess.CalledProcessError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + return ExitStatus.TROUBLE + except OSError as e: + print_trouble( + parser.prog, + f"Command '{subprocess.list2cmdline(version_invocation)}' failed to start: {e}", + use_colors=colored_stderr, + ) + return ExitStatus.TROUBLE + + retcode = ExitStatus.SUCCESS + files = list_files( + args.files, + recursive=args.recursive, + exclude=args.exclude, + extensions=args.extensions.split(","), + ) + + if not files: + return + + njobs = args.j + if njobs == 0: + njobs = multiprocessing.cpu_count() + 1 + njobs = min(len(files), njobs) + + if njobs == 1: + # execute directly instead of in a pool, + # less overhead, simpler stacktraces + it = (run_clang_format_diff_wrapper(args, file) for file in files) + pool = None + else: + pool = multiprocessing.Pool(njobs) + it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files) + while True: + try: + outs, errs = next(it) + except StopIteration: + break + except DiffError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + retcode = ExitStatus.TROUBLE + sys.stderr.writelines(e.errs) + except UnexpectedError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + sys.stderr.write(e.formatted_traceback) + retcode = ExitStatus.TROUBLE + # stop at the first unexpected error, + # something could be very wrong, + # don't process all files unnecessarily + if pool: + pool.terminate() + break + else: + sys.stderr.writelines(errs) + if outs == []: + continue + if not args.quiet: + print_diff(outs, use_color=colored_stdout) + if retcode == ExitStatus.SUCCESS: + retcode = ExitStatus.DIFF + return retcode + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.circleci/unittest/linux_libs/scripts_sklearn/run_test.sh b/.circleci/unittest/linux_libs/scripts_sklearn/run_test.sh new file mode 100755 index 00000000000..d357c033d34 --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_sklearn/run_test.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +apt-get update && apt-get install -y git gcc +ln -s /usr/bin/swig3.0 /usr/bin/swig + +export PYTORCH_TEST_WITH_SLOW='1' +python -m torch.utils.collect_env +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' + +root_dir="$(git rev-parse --show-toplevel)" +env_dir="${root_dir}/env" +lib_dir="${env_dir}/lib" + +conda deactivate && conda activate ./env + +# this workflow only tests the libs +python -c "import sklearn, pandas" + +python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 20 --capture no -k TestOpenML +coverage combine +coverage xml -i diff --git a/.circleci/unittest/linux_libs/scripts_sklearn/setup_env.sh b/.circleci/unittest/linux_libs/scripts_sklearn/setup_env.sh new file mode 100755 index 00000000000..b8073884ff3 --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_sklearn/setup_env.sh @@ -0,0 +1,50 @@ +#!/usr/bin/env bash + +# This script is for setting up environment in which unit test is ran. +# To speed up the CI time, the resulting environment is cached. +# +# Do not install PyTorch and torchvision here, otherwise they also get cached. + +set -e + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +# Avoid error: "fatal: unsafe repository" +apt-get update && apt-get install -y git wget gcc g++ unzip + +git config --global --add safe.directory '*' +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" + +cd "${root_dir}" + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +# 1. Install conda at ./conda +if [ ! -d "${conda_dir}" ]; then + printf "* Installing conda\n" + wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" + bash ./miniconda.sh -b -f -p "${conda_dir}" +fi +eval "$(${conda_dir}/bin/conda shell.bash hook)" + +# 2. Create test environment at ./env +printf "python: ${PYTHON_VERSION}\n" +if [ ! -d "${env_dir}" ]; then + printf "* Creating a test environment\n" + conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" +fi +conda activate "${env_dir}" + + +# 3. Install Conda dependencies +printf "* Installing dependencies (except PyTorch)\n" +echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" +cat "${this_dir}/environment.yml" + +pip install pip --upgrade + +conda env update --file "${this_dir}/environment.yml" --prune diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index c115514650c..abf1e4ab175 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -174,7 +174,8 @@ Here's an example: Installing dependencies is the responsibility of the user. For D4RL, a clone of `the repository `_ is needed as - the latest wheels are not published on PyPI. + the latest wheels are not published on PyPI. For OpenML, `scikit-learn `_ and + `pandas `_ are required. .. autosummary:: :toctree: generated/ @@ -183,6 +184,7 @@ Here's an example: .. currentmodule:: torchrl.data.datasets D4RLExperienceReplay + OpenMLExperienceReplay TensorSpec ---------- diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 427c5df4d20..8b661bfa391 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -406,5 +406,6 @@ Libraries habitat.HabitatEnv jumanji.JumanjiEnv jumanji.JumanjiWrapper + openml.OpenMLEnv vmas.VmasEnv vmas.VmasWrapper diff --git a/examples/bandits/dqn.py b/examples/bandits/dqn.py new file mode 100644 index 00000000000..6b43b08d4b7 --- /dev/null +++ b/examples/bandits/dqn.py @@ -0,0 +1,117 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse + +import torch +import tqdm +from torch import nn + +from torchrl.envs.libs.openml import OpenMLEnv +from torchrl.envs.utils import set_exploration_mode +from torchrl.modules import DistributionalQValueActor, EGreedyWrapper, MLP, QValueActor +from torchrl.objectives import DistributionalDQNLoss, DQNLoss + +parser = argparse.ArgumentParser() + +# Add arguments +parser.add_argument("--batch_size", type=int, default=256, help="batch size") +parser.add_argument("--n_steps", type=int, default=10000, help="number of steps") +parser.add_argument( + "--eps_greedy", type=float, default=0.1, help="epsilon-greedy parameter" +) +parser.add_argument("--lr", type=float, default=2e-4, help="learning rate") +parser.add_argument("--wd", type=float, default=1e-4, help="weight decay") +parser.add_argument("--n_cells", type=int, default=128, help="number of cells") +parser.add_argument( + "--distributional", action="store_true", help="enable distributional Q-learning" +) +parser.add_argument( + "--dataset", + default="adult_onehot", + choices=[ + "adult_num", + "adult_onehot", + "mushroom_num", + "mushroom_onehot", + "covertype", + "shuttle", + "magic", + ], + help="OpenML dataset", +) + +if __name__ == "__main__": + # Parse arguments + args = parser.parse_args() + + # Access arguments + batch_size = args.batch_size + n_steps = args.n_steps + eps_greedy = args.eps_greedy + lr = args.lr + wd = args.wd + n_cells = args.n_cells + distributional = args.distributional + dataset = args.dataset + + env = OpenMLEnv(dataset, batch_size=[batch_size]) + n_actions = env.action_spec.space.n + if distributional: + # does not really make sense since the value is either 0 or 1 and hopefully we + # should always predict 1 + nbins = 2 + model = MLP( + out_features=(nbins, n_actions), + depth=3, + num_cells=n_cells, + activation_class=nn.Tanh, + ) + actor = DistributionalQValueActor( + model, support=torch.arange(2), action_space="categorical" + ) + actor(env.reset()) + loss = DistributionalDQNLoss( + actor, + gamma=0.0, + ) + else: + model = MLP( + out_features=n_actions, depth=3, num_cells=n_cells, activation_class=nn.Tanh + ) + actor = QValueActor(model, action_space="categorical") + actor(env.reset()) + loss = DQNLoss(actor, gamma=0.0, loss_function="smooth_l1") + policy = EGreedyWrapper( + actor, eps_greedy, 0.0, annealing_num_steps=n_steps, spec=env.action_spec + ) + optim = torch.optim.Adam(loss.parameters(), lr, weight_decay=wd) + + pbar = tqdm.tqdm(range(n_steps)) + + init_r = None + init_loss = None + for i in pbar: + with set_exploration_mode("random"): + data = env.step(policy(env.reset())) + loss_vals = loss(data) + loss_val = sum( + value for key, value in loss_vals.items() if key.startswith("loss") + ) + loss_val.backward() + optim.step() + optim.zero_grad() + if i % 10 == 0: + test_data = env.step(policy(env.reset())) + if init_r is None: + init_r = test_data["next", "reward"].sum() / env.numel() + if init_loss is None: + init_loss = loss_val.detach().item() + pbar.set_description( + f"reward: {test_data['next', 'reward'].sum() / env.numel(): 4.4f} (init={init_r: 4.4f}), " + f"training reward {data['next', 'reward'].sum() / env.numel() : 4.4f}, " + f"loss {loss_val: 4.4f} (init: {init_loss: 4.4f})" + ) + policy.step() diff --git a/test/test_libs.py b/test/test_libs.py index ae564d72f71..8a0f820c87f 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -23,20 +23,48 @@ from torchrl._utils import implement_for from torchrl.collectors import MultiaSyncDataCollector from torchrl.collectors.collectors import RandomPolicy -from torchrl.data.datasets.d4rl import _has_d4rl, D4RL_ERR, D4RLExperienceReplay +from torchrl.data.datasets.d4rl import D4RLExperienceReplay +from torchrl.data.datasets.openml import OpenMLExperienceReplay from torchrl.data.replay_buffers import SamplerWithoutReplacement -from torchrl.envs import EnvCreator, ParallelEnv +from torchrl.envs import ( + Compose, + DoubleToFloat, + EnvCreator, + ParallelEnv, + RenameTransform, +) from torchrl.envs.libs.brax import _has_brax, BraxEnv from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv, DMControlWrapper from torchrl.envs.libs.gym import _has_gym, _is_from_pixels, GymEnv, GymWrapper from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv +from torchrl.envs.libs.openml import OpenMLEnv from torchrl.envs.libs.vmas import _has_vmas, VmasEnv, VmasWrapper from torchrl.envs.utils import check_env_specs from torchrl.envs.vec_env import _has_envpool, MultiThreadedEnvWrapper, SerialEnv from torchrl.modules import ActorCriticOperator, MLP, SafeModule, ValueOperator + +D4RL_ERR = None +try: + import d4rl # noqa + + _has_d4rl = True +except Exception as err: + # many things can wrong when importing d4rl :( + _has_d4rl = False + D4RL_ERR = err + +SKLEARN_ERR = None +try: + import sklearn # noqa + + _has_sklearn = True +except ModuleNotFoundError as err: + _has_sklearn = False + SKLEARN_ERR = err + if _has_gym: try: import gymnasium as gym @@ -1194,6 +1222,45 @@ def test_d4rl_iteration(self, task, split_trajs): print(f"completed test after {time.time()-t0}s") +@pytest.mark.skipif(not _has_sklearn, reason=f"Scikit-learn not found: {SKLEARN_ERR}") +@pytest.mark.parametrize( + "dataset", + [ + "adult_num", + "adult_onehot", + "mushroom_num", + "mushroom_onehot", + "covertype", + "shuttle", + "magic", + ], +) +class TestOpenML: + @pytest.mark.parametrize("batch_size", [(), (2,), (2, 3)]) + def test_env(self, dataset, batch_size): + env = OpenMLEnv(dataset, batch_size=batch_size) + td = env.reset() + assert td.shape == torch.Size(batch_size) + td = env.rand_step(td) + assert td.shape == torch.Size(batch_size) + assert "index" not in td.keys() + check_env_specs(env) + + def test_data(self, dataset): + data = OpenMLExperienceReplay( + dataset, + batch_size=2048, + transform=Compose( + RenameTransform(["X"], ["observation"]), + DoubleToFloat(["observation"]), + ), + ) + # check that dataset eventually runs out + for i, _ in enumerate(data): # noqa: B007 + continue + assert len(data) // 2048 in (i, i - 1) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 9e2640522ca..d7dabbf70fc 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -750,7 +750,10 @@ def rollout(self) -> TensorDictBase: self._tensordict_out.lock() self._step_and_maybe_reset() - if self.interruptor is not None and self.interruptor.collection_stopped(): + if ( + self.interruptor is not None + and self.interruptor.collection_stopped() + ): break return self._tensordict_out diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index 3ebb41a7bd9..438b6fe064b 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -5,6 +5,8 @@ from typing import Callable, Optional +import gym # noqa + import numpy as np import torch @@ -16,15 +18,6 @@ from torchrl.data.replay_buffers.storages import LazyMemmapStorage from torchrl.data.replay_buffers.writers import Writer -D4RL_ERR = None -try: - import d4rl, gym # noqa - - _has_d4rl = True -except ModuleNotFoundError as err: - _has_d4rl = False - D4RL_ERR = err - class D4RLExperienceReplay(TensorDictReplayBuffer): """An Experience replay class for D4RL. @@ -93,6 +86,18 @@ class D4RLExperienceReplay(TensorDictReplayBuffer): """ + D4RL_ERR = None + + @classmethod + def _import_d4rl(cls): + try: + import d4rl # noqa + + cls._has_d4rl = True + except ModuleNotFoundError as err: + cls._has_d4rl = False + cls.D4RL_ERR = err + def __init__( self, name, @@ -108,8 +113,10 @@ def __init__( **env_kwargs, ): - if not _has_d4rl: - raise ImportError("Could not import d4rl") from D4RL_ERR + type(self)._import_d4rl() + + if not self._has_d4rl: + raise ImportError("Could not import d4rl") from self.D4RL_ERR self.from_env = from_env if from_env: dataset = self._get_dataset_from_env(name, env_kwargs) @@ -134,6 +141,13 @@ def __init__( def _get_dataset_direct(self, name, env_kwargs): from torchrl.envs.libs.gym import GymWrapper + try: + import d4rl + except ModuleNotFoundError: + raise ModuleNotFoundError( + "d4rl not found or not importable" + ) from self.D4RL_ERR + env = GymWrapper(gym.make(name)) dataset = d4rl.qlearning_dataset(env._env, **env_kwargs) diff --git a/torchrl/data/datasets/openml.py b/torchrl/data/datasets/openml.py new file mode 100644 index 00000000000..78b90793682 --- /dev/null +++ b/torchrl/data/datasets/openml.py @@ -0,0 +1,150 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, Optional + +import numpy as np +from tensordict.tensordict import TensorDict + +from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer +from torchrl.data.replay_buffers import Sampler, SamplerWithoutReplacement, Writer + + +class OpenMLExperienceReplay(TensorDictReplayBuffer): + """An experience replay for OpenML data. + + This class provides an easy entry point for public datasets. + See "Dua, D. and Graff, C. (2017) UCI Machine Learning Repository. http://archive.ics.uci.edu/ml" + + The data is accessed via scikit-learn. Make sure sklearn and pandas are + installed before retrieving the data: + + .. code-block:: + + $ pip install scikit-learn pandas -U + + Args: + name (str): the following datasets are supported: + ``"adult_num"``, ``"adult_onehot"``, ``"mushroom_num"``, ``"mushroom_onehot"``, + ``"covertype"``, ``"shuttle"`` and ``"magic"``. + batch_size (int): the batch size to use during sampling. + sampler (Sampler, optional): the sampler to be used. If none is provided + a default RandomSampler() will be used. + writer (Writer, optional): the writer to be used. If none is provided + a default RoundRobinWriter() will be used. + collate_fn (callable, optional): merges a list of samples to form a + mini-batch of Tensor(s)/outputs. Used when using batched + loading from a map-style dataset. + pin_memory (bool): whether pin_memory() should be called on the rb + samples. + prefetch (int, optional): number of next batches to be prefetched + using multithreading. + transform (Transform, optional): Transform to be executed when sample() is called. + To chain transforms use the :obj:`Compose` class. + + """ + + def __init__( + self, + name: str, + batch_size: int, + sampler: Optional[Sampler] = None, + writer: Optional[Writer] = None, + collate_fn: Optional[Callable] = None, + pin_memory: bool = False, + prefetch: Optional[int] = None, + transform: Optional["Transform"] = None, # noqa-F821 + ): + + if sampler is None: + sampler = SamplerWithoutReplacement() + + dataset = self._get_data( + name, + ) + self.max_outcome_val = dataset["y"].max().item() + + storage = LazyMemmapStorage(dataset.shape[0]) + super().__init__( + batch_size=batch_size, + storage=storage, + sampler=sampler, + writer=writer, + collate_fn=collate_fn, + pin_memory=pin_memory, + prefetch=prefetch, + transform=transform, + ) + self.extend(dataset) + + @classmethod + def _get_data(cls, dataset_name): + try: + import pandas # noqa: F401 + from sklearn.datasets import fetch_openml + from sklearn.preprocessing import ( + LabelEncoder, + OneHotEncoder, + StandardScaler, + ) + except ModuleNotFoundError: + raise ModuleNotFoundError( + "Make sure scikit-learn and pandas are installed before " + f"creating a {cls.__name__} instance." + ) + if dataset_name in ["adult_num", "adult_onehot"]: + X, y = fetch_openml("adult", version=1, return_X_y=True) + is_NaN = X.isna() + row_has_NaN = is_NaN.any(axis=1) + X = X[~row_has_NaN] + # y = y[~row_has_NaN] + y = X["occupation"] + X = X.drop(["occupation"], axis=1) + cat_ix = X.select_dtypes(include=["category"]).columns + num_ix = X.select_dtypes(include=["int64", "float64"]).columns + encoder = LabelEncoder() + # now apply the transformation to all the columns: + for col in cat_ix: + X[col] = encoder.fit_transform(X[col]) + y = encoder.fit_transform(y) + if dataset_name == "adult_onehot": + cat_features = OneHotEncoder(sparse=False).fit_transform(X[cat_ix]) + num_features = StandardScaler().fit_transform(X[num_ix]) + X = np.concatenate((num_features, cat_features), axis=1) + else: + X = StandardScaler().fit_transform(X) + elif dataset_name in ["mushroom_num", "mushroom_onehot"]: + X, y = fetch_openml("mushroom", version=1, return_X_y=True) + encoder = LabelEncoder() + # now apply the transformation to all the columns: + for col in X.columns: + X[col] = encoder.fit_transform(X[col]) + # X = X.drop(["veil-type"],axis=1) + y = encoder.fit_transform(y) + if dataset_name == "mushroom_onehot": + X = OneHotEncoder(sparse=False).fit_transform(X) + else: + X = StandardScaler().fit_transform(X) + elif dataset_name == "covertype": + # https://www.openml.org/d/150 + # there are some 0/1 features -> consider just numeric + X, y = fetch_openml("covertype", version=3, return_X_y=True) + X = StandardScaler().fit_transform(X) + y = LabelEncoder().fit_transform(y) + elif dataset_name == "shuttle": + # https://www.openml.org/d/40685 + # all numeric, no missing values + X, y = fetch_openml("shuttle", version=1, return_X_y=True) + X = StandardScaler().fit_transform(X) + y = LabelEncoder().fit_transform(y) + elif dataset_name == "magic": + # https://www.openml.org/d/1120 + # all numeric, no missing values + X, y = fetch_openml("MagicTelescope", version=1, return_X_y=True) + X = StandardScaler().fit_transform(X) + y = LabelEncoder().fit_transform(y) + else: + raise RuntimeError("Dataset does not exist") + return TensorDict({"X": X, "y": y}, X.shape[:1]) diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 8ed1db370f7..2ad9b3d65b9 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -176,6 +176,15 @@ def __getitem__(self, index: Union[int, torch.Tensor]) -> Any: if not isinstance(index, INT_CLASSES): data = self._collate_fn(data) + if self._transform is not None: + is_td = True + if not isinstance(data, TensorDictBase): + data = TensorDict({"data": data}, []) + is_td = False + data = self._transform(data) + if not is_td: + data = data["data"] + return data def state_dict(self) -> Dict[str, Any]: diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 030cbdcdc2e..78c93ac55eb 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -8,6 +8,8 @@ from warnings import warn import torch + +from torchrl._utils import implement_for from torchrl.data.tensor_specs import ( BinaryDiscreteTensorSpec, BoundedTensorSpec, @@ -19,8 +21,6 @@ TensorSpec, UnboundedContinuousTensorSpec, ) - -from torchrl._utils import implement_for from torchrl.data.utils import numpy_to_torch_dtype_dict from torchrl.envs.gym_like import default_info_dict_reader, GymLikeEnv diff --git a/torchrl/envs/libs/openml.py b/torchrl/envs/libs/openml.py index f7e4e498cb3..8cbc9dfb5b4 100644 --- a/torchrl/envs/libs/openml.py +++ b/torchrl/envs/libs/openml.py @@ -1,81 +1,34 @@ -import numpy as np -from tensordict.tensordict import TensorDict, TensorDictBase +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. -from torchrl.data import UnboundedContinuousTensorSpec, CompositeSpec, UnboundedDiscreteTensorSpec -from torchrl.envs import EnvBase -from sklearn.datasets import fetch_openml -from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, StandardScaler, LabelEncoder import torch +from tensordict.tensordict import TensorDict, TensorDictBase -X, y = fetch_openml('adult', version=1, return_X_y=True) +from torchrl.data import ( + CompositeSpec, + DiscreteTensorSpec, + UnboundedContinuousTensorSpec, + UnboundedDiscreteTensorSpec, +) +from torchrl.data.datasets.openml import OpenMLExperienceReplay +from torchrl.data.replay_buffers import SamplerWithoutReplacement +from torchrl.envs import Compose, DoubleToFloat, EnvBase, RenameTransform -def _get_data(dataset_name): - if dataset_name in ['adult_num', 'adult_onehot']: - X, y = fetch_openml('adult', version=1, return_X_y=True) - is_NaN = X.isna() - row_has_NaN = is_NaN.any(axis=1) - X = X[~row_has_NaN] - # y = y[~row_has_NaN] - y = X["occupation"] - X = X.drop(["occupation"],axis=1) - cat_ix = X.select_dtypes(include=['category']).columns - num_ix = X.select_dtypes(include=['int64', 'float64']).columns - encoder = LabelEncoder() - # now apply the transformation to all the columns: - for col in cat_ix: - X[col] = encoder.fit_transform(X[col]) - y = encoder.fit_transform(y) - if dataset_name == 'adult_onehot': - cat_features = OneHotEncoder(sparse=False).fit_transform(X[cat_ix]) - num_features = StandardScaler().fit_transform(X[num_ix]) - X = np.concatenate((num_features, cat_features), axis=1) - else: - X = StandardScaler().fit_transform(X) - elif dataset_name in ['mushroom_num', 'mushroom_onehot']: - X, y = fetch_openml('mushroom', version=1, return_X_y=True) - encoder = LabelEncoder() - # now apply the transformation to all the columns: - for col in X.columns: - X[col] = encoder.fit_transform(X[col]) - # X = X.drop(["veil-type"],axis=1) - y = encoder.fit_transform(y) - if dataset_name == 'mushroom_onehot': - X = OneHotEncoder(sparse=False).fit_transform(X) - else: - X = StandardScaler().fit_transform(X) - elif dataset_name == 'covertype': - # https://www.openml.org/d/150 - # there are some 0/1 features -> consider just numeric - X, y = fetch_openml('covertype', version=3, return_X_y=True) - X = StandardScaler().fit_transform(X) - y = LabelEncoder().fit_transform(y) - elif dataset_name == 'shuttle': - # https://www.openml.org/d/40685 - # all numeric, no missing values - X, y = fetch_openml('shuttle', version=1, return_X_y=True) - X = StandardScaler().fit_transform(X) - y = LabelEncoder().fit_transform(y) - elif dataset_name == 'magic': - # https://www.openml.org/d/1120 - # all numeric, no missing values - X, y = fetch_openml('MagicTelescope', version=1, return_X_y=True) - X = StandardScaler().fit_transform(X) - y = LabelEncoder().fit_transform(y) - else: - raise RuntimeError('Dataset does not exist') - return TensorDict({"X": X, "y": y}, X.shape[:1]) -def make_composite_from_td(td): +def _make_composite_from_td(td): # custom funtion to convert a tensordict in a similar spec structure # of unbounded values. composite = CompositeSpec( { - key: make_composite_from_td(tensor) + key: _make_composite_from_td(tensor) if isinstance(tensor, TensorDictBase) else UnboundedContinuousTensorSpec( dtype=tensor.dtype, device=tensor.device, shape=tensor.shape - ) if tensor.dtype in (torch.float16, torch.float32, torch.float64) else - UnboundedDiscreteTensorSpec( + ) + if tensor.dtype in (torch.float16, torch.float32, torch.float64) + else UnboundedDiscreteTensorSpec( dtype=tensor.dtype, device=tensor.device, shape=tensor.shape ) for key, tensor in td.items() @@ -86,34 +39,95 @@ def make_composite_from_td(td): class OpenMLEnv(EnvBase): + """An environment interface to OpenML data to be used in bandits contexts. + + Args: + dataset_name (str): the following datasets are supported: + ``"adult_num"``, ``"adult_onehot"``, ``"mushroom_num"``, ``"mushroom_onehot"``, + ``"covertype"``, ``"shuttle"`` and ``"magic"``. + device (torch.device or compatible, optional): the device where the input + and output data is to be expected. Defaults to ``"cpu"``. + batch_size (torch.Size or compatible, optional): the batch size of the environment, + ie. the number of elements samples and returned when a :meth:`~.reset` is + called. Defaults to an empty batch size, ie. one element is sampled + at a time. + + Examples: + >>> env = OpenMLEnv("adult_onehot", batch_size=[2, 3]) + >>> print(env.reset()) + TensorDict( + fields={ + done: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([2, 3, 106]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False), + y: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([2, 3]), + device=cpu, + is_shared=False) + + """ + def __init__(self, dataset_name, device="cpu", batch_size=None): if batch_size is None: - batch_size = [1] + batch_size = torch.Size([]) + else: + batch_size = torch.Size(batch_size) self.dataset_name = dataset_name - self._data = _get_data(dataset_name) + self._data = OpenMLExperienceReplay( + dataset_name, + batch_size=batch_size.numel(), + sampler=SamplerWithoutReplacement(drop_last=True), + transform=Compose( + RenameTransform(["X"], ["observation"]), + DoubleToFloat(["observation"]), + ), + ) super().__init__(device=device, batch_size=batch_size) - self.observation_spec = make_composite_from_td(self._data[:self.batch_size.numel()]) - self.action_spec = self.observation_spec["y"].clone() + self.observation_spec = _make_composite_from_td( + self._data[: self.batch_size.numel()] + .reshape(self.batch_size) + .exclude("index") + ) + self.action_spec = DiscreteTensorSpec( + self._data.max_outcome_val + 1, shape=self.batch_size, device=self.device + ) self.reward_spec = UnboundedContinuousTensorSpec(shape=(*self.batch_size, 1)) def _reset(self, tensordict): - r_id = torch.randint(self._data.shape[0], (self.batch_size.numel(),)) - data = self._data[r_id] + data = self._data.sample() + data = data.exclude("index") + data = data.reshape(self.batch_size).to(self.device) return data def _step( self, tensordict: TensorDictBase, ) -> TensorDictBase: - action = tensordict["action"] + action = tensordict.get("action") + y = tensordict.get("y", None) + if y is None: + raise KeyError( + "did not find the 'y' key in the input tensordict. " + "Make sure you call env.step() on a tensordict that results " + "from env.reset()." + ) + + if action.shape != y.shape: + raise RuntimeError( + f"Action and outcome shape differ: {action.shape} vs {y.shape}." + ) reward = (action == tensordict["y"]).float().unsqueeze(-1) done = torch.ones_like(reward, dtype=torch.bool) - return TensorDict({ - "done": done, - "reward": reward, - "X": tensordict["X"], - "y": tensordict["y"], - }, self.batch_size) + td = TensorDict( + { + "done": done, + "reward": reward, + **tensordict.select(*self.observation_spec.keys()), + }, + self.batch_size, + device=self.device, + ) + return td.select().set("next", td) def _set_seed(self, seed): self.rng = torch.random.manual_seed(seed) diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index ba33932da76..068836bb4d4 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -169,6 +169,7 @@ def __init__( _out_features_num = out_features if not isinstance(out_features, Number): + print(out_features, type(out_features)) _out_features_num = prod(out_features) self.out_features = out_features self._out_features_num = _out_features_num diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 3b16d8e6d2b..7ea094aa2d3 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -124,9 +124,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if spec is not None: if isinstance(spec, CompositeSpec): spec = spec[self.action_key] - out = ( - cond * spec.rand(tensordict.shape).to(out.device) + (1 - cond) * out - ) + out = cond * spec.rand().to(out.device) + (1 - cond) * out else: raise RuntimeError( "spec must be provided by the policy or directly to the exploration wrapper." diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index d79f202fca4..65bf5f9bbd8 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -101,6 +101,9 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict: pred_val = td_copy.get("action_value") if self.action_space == "categorical": + if action.shape != pred_val.shape: + # unsqueeze the action if it lacks on trailing singleton dim + action = action.unsqueeze(-1) pred_val_index = torch.gather(pred_val, -1, index=action).squeeze(-1) else: action = action.to(torch.float) @@ -180,11 +183,12 @@ def _log_ps_a_default(action, action_log_softmax, batch_size, atoms): @staticmethod def _log_ps_a_categorical(action, action_log_softmax): # Reshaping action of shape `[*batch_sizes, 1]` to `[*batch_sizes, atoms, 1]` for gather. + if action.shape[-1] != 1: + action = action.unsqueeze(-1) action = action.unsqueeze(-2) new_shape = [-1] * len(action.shape) new_shape[-2] = action_log_softmax.shape[-2] # calculating atoms action = action.expand(new_shape) - return torch.gather(action_log_softmax, -1, index=action).squeeze(-1) def forward(self, input_tensordict: TensorDictBase) -> TensorDict: diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index e3d1050bc3b..1d0fcdf6b25 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -334,6 +334,8 @@ def next_state_value( rewards = tensordict.get(("next", "reward")).squeeze(-1) done = tensordict.get(("next", "done")).squeeze(-1) + if done.all() or gamma == 0: + return rewards if pred_next_val is None: next_td = step_mdp(tensordict) # next_observation -> observation From 150b2c431d5bcdd0faedfbe2ef738eb8e254d345 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 27 Mar 2023 09:03:16 +0100 Subject: [PATCH 3/5] bf --- test/test_collector.py | 2 +- torchrl/collectors/collectors.py | 22 ++++++++++++---------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index fe264585c47..07e1d591607 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1251,7 +1251,7 @@ def env_fn(seed): frames_per_batch=50, total_frames=200, device="cpu", - interrupter=interruptor, + interruptor=interruptor, split_trajs=False, ) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 288f81a8c42..853c6c8970e 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -86,6 +86,8 @@ class _Interruptor: by a lock to ensure thread-safety. """ + # interrupter vs interruptor: google trends seems to indicate that "or" is more + # widely used than "er" even if my IDE complains about that... def __init__(self): self._collect = True self._lock = mp.Lock() @@ -400,7 +402,7 @@ class SyncDataCollector(DataCollectorBase): tensordict is added to a replay buffer for instance, the whole content of the buffer will be identical. Default is False. - interrupter (_Interruptor, optional): + interruptor (_Interruptor, optional): An _Interruptor object that can be used from outside the class to control rollout collection. The _Interruptor class has methods ´start_collection´ and ´stop_collection´, which allow to implement strategies such as preeptively stopping rollout collection. @@ -482,7 +484,7 @@ def __init__( exploration_mode: str = DEFAULT_EXPLORATION_MODE, return_same_td: bool = False, reset_when_done: bool = True, - interrupter=None, + interruptor=None, ): self.closed = True @@ -616,7 +618,7 @@ def __init__( self.split_trajs = split_trajs self._has_been_done = None self._exclude_private_keys = True - self.interrupter = interrupter + self.interruptor = interruptor # for RPC def next(self): @@ -1139,10 +1141,10 @@ def device_err_msg(device_name, devices_list): self.preemptive_threshold = np.clip(preemptive_threshold, 0.0, 1.0) manager = _InterruptorManager() manager.start() - self.interrupter = manager._Interruptor() + self.interruptor = manager._Interruptor() else: self.preemptive_threshold = 1.0 - self.interrupter = None + self.interruptor = None self._run_processes() self._exclude_private_keys = True @@ -1195,7 +1197,7 @@ def _run_processes(self) -> None: "exploration_mode": self.exploration_mode, "reset_when_done": self.reset_when_done, "idx": i, - "interruptor": self.interrupter, + "interruptor": self.interruptor, } proc = mp.Process(target=_main_async_collector, kwargs=kwargs) # proc.daemon can't be set as daemonic processes may be launched by the process itself @@ -1504,13 +1506,13 @@ def iterator(self) -> Iterator[TensorDictBase]: i += 1 max_traj_idx = None - if self.interrupter is not None and self.preemptive_threshold < 1.0: - self.interrupter.start_collection() + if self.interruptor is not None and self.preemptive_threshold < 1.0: + self.interruptor.start_collection() while self.queue_out.qsize() < int( self.num_workers * self.preemptive_threshold ): continue - self.interrupter.stop_collection() + self.interruptor.stop_collection() # Now wait for stragglers to return while self.queue_out.qsize() < int(self.num_workers): continue @@ -1941,7 +1943,7 @@ def _main_async_collector( exploration_mode=exploration_mode, reset_when_done=reset_when_done, return_same_td=True, - interrupter=interruptor, + interruptor=interruptor, ) if verbose: print("Sync data collector created") From 440687a3391fe011d7f9a2dac4f462822b34a10a Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 27 Mar 2023 09:15:34 +0100 Subject: [PATCH 4/5] bf --- torchrl/data/datasets/d4rl.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index 86d88d2f3e2..3edd02f1db9 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -156,7 +156,8 @@ def _get_dataset_direct(self, name, env_kwargs): if not self._has_d4rl: raise ImportError("Could not import d4rl") from self.D4RL_ERR - import d4rl, gym + import d4rl + import gym env = GymWrapper(gym.make(name)) dataset = d4rl.qlearning_dataset(env._env, **env_kwargs) @@ -227,6 +228,7 @@ def _get_dataset_from_env(self, name, env_kwargs): raise RuntimeError("env_kwargs cannot be passed with using from_env=True") # we do a local import to avoid circular import issues from torchrl.envs.libs.gym import GymWrapper + import gym env = GymWrapper(gym.make(name)) dataset = make_tensordict( From c5b957e2db5fe669b2fc395f833e3b4ebf05b5a1 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 27 Mar 2023 11:47:25 +0100 Subject: [PATCH 5/5] lint --- torchrl/data/datasets/d4rl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index 3edd02f1db9..087793937f3 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -226,9 +226,10 @@ def _get_dataset_from_env(self, name, env_kwargs): """ if env_kwargs: raise RuntimeError("env_kwargs cannot be passed with using from_env=True") + import gym + # we do a local import to avoid circular import issues from torchrl.envs.libs.gym import GymWrapper - import gym env = GymWrapper(gym.make(name)) dataset = make_tensordict(