[BUG] RuntimeError: index -9223372036854775808 is out of bounds for dimension 1 with size 1 #2402

Closed
Sui-Xing opened this issue Aug 26, 2024 · 8 comments
Labels: bug Something isn't working

@Sui-Xing

Describe the bug

ProbabilisticActor cannot be configured with return_log_prob=True: it throws an error in version 0.5.0, while switching back to version 0.4.0 resolves the issue.

To Reproduce

policy_module = TensorDictModule(net_policy,
                                 in_keys=['hidden'],
                                 out_keys=[
                                     ("params", "action1", "logits"),
                                     ("params", "action2", "logits"),
                                     ("params", "action3", "logits"),
                                     ("params", "action4", "logits"),
                                     ("params", "action5", "logits")
                                 ])
actor = ProbabilisticActor(
    module=policy_module,
    in_keys=["params"],
    distribution_class=CompositeDistribution,
    distribution_kwargs={
        "distribution_map": {
            "action1": d.Categorical,
            "action2": d.Categorical,
            "action3": d.Categorical,
            "action4": d.Categorical,
            "action5": d.Categorical
        },
    },
    return_log_prob=True,
)
net_value = Net_Value(num_cells, device=device)
net_value.apply(init_weights)
net_value(hidden)

value_module = ValueOperator(
    module=net_value,
    in_keys=["hidden"],
    out_keys=["state_action_value"]
)

a_c_model = ActorCriticOperator(shared_module, actor, value_module)

test_td = a_c_model.get_policy_operator()(td)
Traceback (most recent call last):
  File "********", line 159, in <module>
    test_td = a_c_model.get_policy_operator()(td)
  File "E:\tools\miniconda\envs\***\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "E:\tools\miniconda\envs\***\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "E:\tools\miniconda\envs\***\lib\site-packages\tensordict\nn\common.py", line 297, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "E:\tools\miniconda\envs\***\lib\site-packages\tensordict\_contextlib.py", line 127, in decorate_context
    return func(*args, **kwargs)
  File "E:\tools\miniconda\envs\***\lib\site-packages\tensordict\nn\utils.py", line 293, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "E:\tools\miniconda\envs\***\lib\site-packages\tensordict\nn\probabilistic.py", line 655, in forward
    return self.module[-1](tensordict_out, _requires_sample=self._requires_sample)
  File "E:\tools\miniconda\envs\***\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "E:\tools\miniconda\envs\***\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "E:\tools\miniconda\envs\***\lib\site-packages\tensordict\nn\common.py", line 297, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "E:\tools\miniconda\envs\***\lib\site-packages\tensordict\_contextlib.py", line 127, in decorate_context
    return func(*args, **kwargs)
  File "E:\tools\miniconda\envs\***\lib\site-packages\tensordict\nn\utils.py", line 293, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "E:\tools\miniconda\envs\***\lib\site-packages\tensordict\nn\probabilistic.py", line 439, in forward
    tensordict_out = dist.log_prob(tensordict_out)
  File "E:\tools\miniconda\envs\***\lib\site-packages\tensordict\nn\distributions\composite.py", line 150, in log_prob
    d[_add_suffix(name, "_log_prob")] = lp = dist.log_prob(sample.get(name))
  File "E:\tools\miniconda\envs\***\lib\site-packages\torch\distributions\categorical.py", line 142, in log_prob
    return log_pmf.gather(-1, value).squeeze(-1)
RuntimeError: index -9223372036854775808 is out of bounds for dimension 1 with size 1

Expected behavior

It should output the correct log_prob.
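For reference, a standalone CompositeDistribution over Categorical heads is expected to return finite log-probabilities for its own samples. A minimal sketch of that expectation, assuming the tensordict 0.5 API (the shapes and key names below are arbitrary placeholders):

import torch
from torch import distributions as d
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution

# Two categorical heads parameterized by logits, as in the reproduction above.
params = TensorDict({
    "action1": {"logits": torch.randn(4, 8)},
    "action2": {"logits": torch.randn(4, 3)},
}, batch_size=[4])

dist = CompositeDistribution(
    params=params,
    distribution_map={"action1": d.Categorical, "action2": d.Categorical},
)
sample = dist.sample()        # TensorDict with one index tensor per head
print(dist.log_prob(sample))  # expected: finite log-probs, no out-of-bounds indexing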

System info

Describe the characteristic of your environment:

  • pip install
  • Python version: 3.10
  • Versions of any other relevant libraries: torch==2.4.0, torchrl==0.5.0, tensordict==0.5.0
  • CUDA 11.8 or CPU

Checklist

  • I have checked, and did not find any similar issues.
  • I have read the documentation (required)
@Sui-Xing Sui-Xing added the bug Something isn't working label Aug 26, 2024
@vmoens
Contributor

vmoens commented Aug 26, 2024

Cc @albertbou92
Could it be caused by pytorch/tensordict#961?

@Sui-Xing
Author

Cc @albertbou92 Could it be caused by pytorch/tensordict#961?

I did use the CompositeDistribution class.

class Net_Policy(nn.Module):
    def __init__(self, num_cells, action_dims, device):
        super().__init__()
        self.policy_mlps = nn.ModuleList([
            nn.Sequential(
                nn.Linear(num_cells + 128, num_cells),
                nn.ReLU(),
                nn.Linear(num_cells, action_dim)
            ).to(device) for action_dim in action_dims
        ])

    def forward(self, hc):
        policies = [mlp(hc) for mlp in self.policy_mlps]
        policy_dict = {
            ("params", f"action{i + 1}"): policy for i, policy in enumerate(policies)
        }
        return policy_dict

    def get_dist(self, tensordict):
        params = self(tensordict['feature1'], tensordict['feature2'],...)
        params_td = TensorDict(params, batch_size=tensordict.batch_size)
        return CompositeDistribution(
            params=params_td["params"],
            distribution_map={
                "action1": d.Categorical,
                "action2": d.Categorical,
                "action3": d.Categorical,
                "action4": d.Categorical,
                "action5": d.Categorical,
            },
        )

@Sui-Xing Sui-Xing reopened this Aug 26, 2024
@albertbou92
Contributor

Hi! I will look into it and get back to you ASAP.

@albertbou92
Contributor

I can run a ProbabilisticActor with the latest TensorDict code and return_log_prob=True set, with no issues.
Could you provide a minimal, fully reproducible code example? I can look into it further.

@Sui-Xing
Author

I can run a ProbabilisticActor with the latest TensorDict code and return_log_prob=True set, with no issues. Could you provide a minimal, fully reproducible code example? I can look into it further.

Sure, I will provide a demo that triggers this error. The maze-environment code itself may not be correct, but it reliably triggers the "RuntimeError: index -9223372036854775808 is out of bounds for dimension 0 with size 8" error.
I'm not sure, but I suspect it might be caused by the get_dist function.

train.py

import warnings

from tensordict import TensorDict

from env import MazeEnv
from torch import nn, distributions as d
warnings.filterwarnings("ignore")
from torch import multiprocessing
import matplotlib.pyplot as plt
from torchrl.collectors import SyncDataCollector, MultiaSyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator, EGreedyModule
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
from collections import defaultdict
import torch
from tensordict.nn import TensorDictModule, TensorDictSequential, CompositeDistribution
from torch import nn
from torchrl.envs.utils import check_env_specs, step_mdp
from tqdm import tqdm
import torch.nn.init as init

import torch
import numpy as np
import random

# Seed Python's built-in random module
random.seed(42)

# Seed NumPy
np.random.seed(42)

# Seed PyTorch on CPU and CUDA (if a GPU is used)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# PPO hyperparameters
is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)
device = "cpu"
num_cells = 128  # number of cells in each layer i.e. output dim.
lr = 1e-4
max_grad_norm = 1.0

frames_per_batch = 3000
total_frames =  10*frames_per_batch * 200

sub_batch_size = 64
num_epochs = 30
clip_epsilon = 0.2
gamma = 0.9995
lmbda = 0.95
entropy_eps = 1e-4
greedy_epsilon = 0.5

# Initialize the maze environment
env = MazeEnv(device=device)
check_env_specs(env)

# Reset the environment
td = env.reset()
print("reset tensordict", td)

# Take a random step
td = env.rand_step(td)
print("random step tensordict", td)

# Debug: print the observation_spec
print("observation_spec:", env.observation_spec)

# Use an actual observation to initialize dummy_input
dummy_input = td["position"].unsqueeze(0).to(device)
print("dummy_input shape:", dummy_input.shape)

class Policy(nn.Module):
    def __init__(self, num_cells, action_dims, device):
        super().__init__()
        self.policy_mlps = nn.ModuleList([
            nn.Sequential(
                nn.LazyLinear(num_cells),
                nn.ReLU(),
                nn.Linear(num_cells, action_dim)
            ).to(device) for action_dim in action_dims
        ])


    def forward(self, hc):
        policies = [mlp(hc) for mlp in self.policy_mlps]
        # Build the returned dictionary of logits
        policy_dict = {
            ("params", "action1", "logits"): policies[0],
            ("params", "action2", "logits"): policies[1],
        }
        # print("policy_dict: ", policy_dict)
        return policy_dict

    def get_dist(self, tensordict):
        params = self(tensordict['position'])
        params_td = TensorDict(params, batch_size=tensordict.batch_size)
        return CompositeDistribution(
            params=params_td["params"],
            distribution_map={
                "action1": d.Categorical,
                "action2": d.Categorical,
            },
        )

actor_net = Policy(128,[8,1],device=device)
# Value Net
value_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(1, device=device),
)

# Initialize LazyLinear layers with dummy input
actor_net(dummy_input)
value_net(dummy_input)


# Apply Xavier initialization
def init_weights(m):
    if isinstance(m, (nn.Linear, nn.LazyLinear)):
        init.xavier_uniform_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)


actor_net.apply(init_weights)
value_net.apply(init_weights)

policy_module = TensorDictModule(
    actor_net, in_keys=["position"], out_keys=[
                                     ("params", "action1", "logits"),
                                     ("params", "action2", "logits"),
                                 ]
)

policy_module = ProbabilisticActor(
    module=policy_module,
    in_keys=["params"],
    distribution_class=CompositeDistribution,
    distribution_kwargs={
        "distribution_map": {
            "action1": d.Categorical,
            "action2": d.Categorical,
        },
    },
    return_log_prob=True,
)
policy_module(dummy_input)

value_module = ValueOperator(
    module=value_net,
    in_keys=["position"],
)

replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(max_size=frames_per_batch),
    sampler=SamplerWithoutReplacement(),
)

advantage_module = GAE(
    gamma=gamma, lmbda=lmbda, value_network=value_module, average_gae=True
)

loss_module = ClipPPOLoss(
    actor_network=policy_module,
    critic_network=value_module,
    clip_epsilon=clip_epsilon,
    entropy_bonus=bool(entropy_eps),
    entropy_coef=entropy_eps,
    critic_coef=1.0,
    loss_critic_type="smooth_l1",
    normalize_advantage=True,
)

optim = torch.optim.Adam(loss_module.parameters(), lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optim, total_frames // frames_per_batch, 0.0
)

logs = defaultdict(list)

pbar = tqdm(total=total_frames)
eval_str = ""

if __name__ == '__main__':
    env_maker = lambda: MazeEnv(device=device)
    # Defined before the training loop starts
    initial_epsilon = 0.2
    final_epsilon = 0.01
    collector = SyncDataCollector(
        env,
        policy=policy_module,
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
        split_trajs=False,
        device=device,
        exploration_type=ExplorationType.RANDOM

    )

    explorative_policy = TensorDictSequential(policy_module,
                                              EGreedyModule(eps_init=initial_epsilon,
                                                            eps_end=final_epsilon,
                                                            spec=env.action_spec))

    # Training loop
    for i, tensordict_data in enumerate(collector):
        # env.reset()

        tensordict_data = explorative_policy(tensordict_data)
        for _ in range(num_epochs):
            advantage_module(tensordict_data)
            data_view = tensordict_data.reshape(-1)
            replay_buffer.extend(data_view.cpu())
            for _ in range(frames_per_batch // sub_batch_size):
                subdata = replay_buffer.sample(sub_batch_size)
                loss_vals = loss_module(subdata.to(device))
                loss_value = (
                        loss_vals["loss_objective"]
                        + loss_vals["loss_critic"]
                        + loss_vals["loss_entropy"]
                )

                loss_value.backward()
                torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
                optim.step()
                optim.zero_grad()

        logs["reward"].append(tensordict_data["next", "reward"].mean().item())
        pbar.update(tensordict_data.numel())
        cum_reward_str = (
            f"average reward={logs['reward'][-1]: 4.5f} (init={logs['reward'][0]: 4.5f})"
        )
        logs["step_count"].append(tensordict_data["step_count"].max().item())
        stepcount_str = f"step count (max): {logs['step_count'][-1]}"
        logs["lr"].append(optim.param_groups[0]["lr"])
        lr_str = f"lr policy: {logs['lr'][-1]: 4.5f}"
        if (i + 1) % 10 == 0:
            with set_exploration_type(ExplorationType.MODE), torch.no_grad():
                eval_rollout = env.rollout(1000, policy_module)

                logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
                logs["eval reward (sum)"].append(
                    eval_rollout["next", "reward"].sum().item()
                )
                logs["eval step_count"].append(0)
                eval_str = (
                    f"eval cumulative reward: {logs['eval reward (sum)'][-1]: 4.5f} "
                    f"(init: {logs['eval reward (sum)'][0]: 4.5f}), "
                    f"eval step-count: {logs['eval step_count'][-1]}"
                )
                del eval_rollout
        if i % 20 == 0 and i > 1:
            # After the training loop ends
            model_state = {
                'policy_state_dict': policy_module.state_dict(),
                'value_state_dict': value_module.state_dict(),
                'optimizer_state_dict': optim.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'epoch': num_epochs
            }
            torch.save(model_state, 'ckpt/ppo_maze_model.pth')
            print("Model saved!")
        pbar.set_description(", ".join([eval_str, cum_reward_str, stepcount_str, lr_str]))

        scheduler.step()
    pbar.close()
    collector.shutdown()

    env.reset()
    eval_rollout = env.rollout(2000, policy_module)

    # print("Evaluation results:\n\n", eval_rollout['next'].values())
    for k in eval_rollout['next'].keys():
        if k == "position":
            for v in eval_rollout['next'].get(k):
                print("Evaluation result:\n", k + "\n", v.cpu().numpy())

    # Plot the results
    plt.figure(figsize=(10, 10))
    plt.subplot(2, 2, 1)
    plt.plot(logs["reward"])
    plt.title("training rewards (average)")
    plt.subplot(2, 2, 2)
    plt.plot(logs["step_count"])
    plt.title("Max step count (training)")
    plt.subplot(2, 2, 3)
    plt.plot(logs["eval reward (sum)"])
    plt.title("Return (test)")
    plt.subplot(2, 2, 4)
    plt.plot(logs["eval step_count"])
    plt.title("Max step count (test)")
    plt.show()

board.py

class Board:
    width = 100
    height = 100
    start_pos = (1, 1)
    end_pos = (78, 78)
    components = [[(30, 30), (50, 70)],
                  [(10, 8), (15, 10)],
                  [(12, 2), (15, 7)],
                  [(2, 10), (5, 11)],
                  [(2, 13), (5, 14)],
                  # [(15, 12), (18, 18)],
                  ]
    lines = []

    def __init__(self):
        self.width = 20
        self.height = 20
        self.start_pos = (0, 0)
        self.end_pos = (18, 12)
        self.components = [[(3, 3), (5, 7)],
                           [(10, 8), (17, 10)]]

env.py

from typing import Optional

import torch
from tensordict import TensorDict, TensorDictBase

from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec, DiscreteTensorSpec
from torchrl.envs import (
    EnvBase,
)

from board import Board


# Maze environment definition
class MazeEnv(EnvBase):
    metadata = {
        "render_modes": ["human", "rgb_array"],
        "render_fps": 30,
    }
    batch_locked = False

    def __init__(self, maze_size=(Board.width, Board.height), start=(0, 0), goal=(8, 8), seed=None, device="cpu"):
        self.maze_size = maze_size
        self.start = Board.start_pos
        # self.start=start
        self.goal = Board.end_pos

        td_params = self.gen_params()
        super().__init__(device=device, batch_size=[])
        self._make_spec(td_params)
        if seed is None:
            seed = torch.empty((), dtype=torch.int64).random_().item()
        self.set_seed(seed)
        self.visit_counts = torch.zeros(maze_size, device=device)
        self.start = torch.tensor(self.start)

        self.visit_counts[self.start[0].long(), self.start[1].long()] += 1
        self.components = torch.tensor(Board.components, device=self.device)

    @staticmethod
    def gen_params(batch_size=None) -> TensorDictBase:
        """Returns a ``tensordict`` containing the maze parameters."""
        if batch_size is None:
            batch_size = []
        td = TensorDict(
            {
                "params": TensorDict(
                    {
                        "maze_size": torch.tensor([Board.width, Board.height]),
                        "start": torch.tensor([Board.start_pos[0], Board.start_pos[1]]),
                        "goal": torch.tensor([Board.end_pos[0], Board.end_pos[1]]),
                    },
                    [],
                )
            },
            [],
        )
        if batch_size:
            td = td.expand(batch_size).contiguous()
        return td

    def _make_spec(self, td_params):
        self.observation_spec = CompositeSpec(
            position=BoundedTensorSpec(
                low=torch.tensor([0, 0]),
                high=td_params["params", "maze_size"] - 1,
                shape=(2,),
                dtype=torch.float32,
            ),
            params=make_composite_from_td(td_params["params"]),
            step_count=UnboundedContinuousTensorSpec(dtype=torch.int32, shape=()),
            shape=(),
        )
        self.state_spec = self.observation_spec.clone()

        # Changed here: the action space is now discrete
        self.action_spec = CompositeSpec(
            action1=DiscreteTensorSpec(n=8),
            action2=DiscreteTensorSpec(n=1),
            # Add more actions as needed
        )  # 4 would mean the four moves: up, down, left, right

        self.reward_spec = UnboundedContinuousTensorSpec(shape=(*td_params.shape, 1))

    def _reset(self, tensordict):
        if tensordict is None or tensordict.is_empty():
            tensordict = self.gen_params(batch_size=self.batch_size)

        position = torch.tensor(self.start, device=self.device, dtype=torch.float32)  # changed to float32
        out = TensorDict(
            {
                "position": position,
                "params": tensordict["params"],
                "step_count": torch.tensor(0, device=self.device, dtype=torch.int32),  # initialize step_count
            },
            batch_size=tensordict.shape,
        ).to(self.device)
        return out



    def _step(self, tensordict):
        position = tensordict["position"]
        action = tensordict["action1"].squeeze(-1)  # should be an integer encoding the action
        maze_size = tensordict["params", "maze_size"]
        goal = tensordict["params", "goal"]

        old_position = position.clone()
        new_position = position.clone()
        if action == 0:  # left
            new_position[0] = torch.clamp(new_position[0] - 1, 0, maze_size[0] - 1)
        elif action == 1:  # right
            new_position[0] = torch.clamp(new_position[0] + 1, 0, maze_size[0] - 1)
        elif action == 2:  # up
            new_position[1] = torch.clamp(new_position[1] - 1, 0, maze_size[1] - 1)
        elif action == 3:  # down
            new_position[1] = torch.clamp(new_position[1] + 1, 0, maze_size[1] - 1)
        elif action == 4:  # left up
            new_position[0] = torch.clamp(new_position[0] - 1, 0, maze_size[0] - 1)
            new_position[1] = torch.clamp(new_position[1] - 1, 0, maze_size[1] - 1)
        elif action == 5:  # left down
            new_position[0] = torch.clamp(new_position[0] - 1, 0, maze_size[0] - 1)
            new_position[1] = torch.clamp(new_position[1] + 1, 0, maze_size[1] - 1)
        elif action == 6:  # right up
            new_position[0] = torch.clamp(new_position[0] + 1, 0, maze_size[0] - 1)
            new_position[1] = torch.clamp(new_position[1] - 1, 0, maze_size[1] - 1)
        elif action == 7:  # right down
            new_position[0] = torch.clamp(new_position[0] + 1, 0, maze_size[0] - 1)
            new_position[1] = torch.clamp(new_position[1] + 1, 0, maze_size[1] - 1)

        reward = 0
        done = torch.tensor(False, dtype=torch.bool, device=self.device)

        # Check whether the goal has been reached
        if torch.equal(new_position, goal):
            reward += 500.0  # reward for reaching the goal
            done = torch.tensor(True, dtype=torch.bool, device=self.device)

        # Check for collision with an obstacle
        if self.is_obstacle(new_position):
            reward -= 10  # penalty for hitting an obstacle
            new_position = old_position  # stay in place

        # Distance-based reward
        old_distance = torch.norm(old_position - goal)
        new_distance = torch.norm(new_position - goal)
        distance_reward = old_distance - new_distance
        reward += distance_reward * 10  # scale factor to balance the rewards

        # Step penalty
        reward -= 0.1  # small penalty per step

        # Penalty for revisiting positions
        self.visit_counts[new_position[0].long(), new_position[1].long()] += 1
        visit_count = self.visit_counts[new_position[0].long(), new_position[1].long()]
        repetition_penalty = -0.1 * torch.log(visit_count)
        reward += repetition_penalty

        # Reward for moving toward the goal
        if new_distance < old_distance:
            reward += 0.1  # reward for moving toward the goal

        # Add the curiosity reward to the total reward
        total_reward = reward
        step_count = tensordict["step_count"] + 1  # increment step_count


        # Reward for moving toward the goal
        if new_distance < old_distance:
            reward += 0.001  # reward for moving toward the goal
        out = TensorDict(
            {
                "position": new_position,
                "params": tensordict["params"],
                "reward": torch.tensor(total_reward, device=self.device).view(*tensordict.shape, 1),
                "done": done.view(*tensordict.shape, 1),
                "step_count": step_count,
            },
            tensordict.shape,
        ).to(self.device)
        return out

    def _set_seed(self, seed: Optional[int]):
        rng = torch.manual_seed(seed)
        self.rng = rng

    # def compute_curiosity_reward(self, position):
    #     # Compute the curiosity reward
    #     # Uses the reciprocal of the visit count as the reward; adjust as needed
    #     visit_count = self.visit_counts[position[0].long(), position[1].long()]
    #     curiosity_reward = 10 / (visit_count + 1)  # add 1 to avoid dividing by zero
    #     return curiosity_reward.item()


    def compute_curiosity_reward(self, position):
        # Compute the curiosity reward
        # Uses the reciprocal of the visit count as the reward; adjust as needed
        visit_count = self.visit_counts[position[0].long(), position[1].long()]
        curiosity_reward = 0.01 / (visit_count + 1)  # add 1 to avoid dividing by zero
        return curiosity_reward.item()

    def generate_new_start_goal(self):
        while True:
            new_start = torch.randint(0, self.maze_size[0], (2,), device=self.device)
            new_goal = torch.randint(0, self.maze_size[0], (2,), device=self.device)

            if not self.is_obstacle(new_start) and not self.is_obstacle(new_goal):
                return new_start, new_goal

    def is_obstacle(self, position):
        # Ensure position is a 2D tensor
        if not isinstance(position, torch.Tensor):
            position = torch.tensor(position, device=self.device)
        position = position.to(self.device).float()

        # Expand position to match the shape of components
        position = position.unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, 2)

        # Check whether position falls within each obstacle's bounds
        in_x_range = (self.components[:, 0, 0] <= position[:, :, 0]) & (position[:, :, 0] <= self.components[:, 1, 0])
        in_y_range = (self.components[:, 0, 1] <= position[:, :, 1]) & (position[:, :, 1] <= self.components[:, 1, 1])

        # Return True if position lies within any obstacle
        return (in_x_range & in_y_range).any().item()


def is_obstacle(self, position):
    # Ensure position is a 2D tensor
    if not isinstance(position, torch.Tensor):
        position = torch.tensor(position, device=self.device)
    position = position.to(self.device).float()

    # Check for out-of-bounds positions
    if (position[0] < 0 or position[0] >= self.maze_size[0] or
            position[1] < 0 or position[1] >= self.maze_size[1]):
        return True

    # Expand position to match the shape of components
    position = position.unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, 2)

    # Check whether position falls within each obstacle's bounds
    in_x_range = (self.components[:, 0, 0] <= position[:, :, 0]) & (position[:, :, 0] <= self.components[:, 1, 0])
    in_y_range = (self.components[:, 0, 1] <= position[:, :, 1]) & (position[:, :, 1] <= self.components[:, 1, 1])

    # Return True if position lies within any obstacle
    return (in_x_range & in_y_range).any().item()


def make_composite_from_td(td):
    composite = CompositeSpec(
        {
            key: make_composite_from_td(tensor)
            if isinstance(tensor, TensorDictBase)
            else UnboundedContinuousTensorSpec(
                dtype=tensor.dtype, device=tensor.device, shape=tensor.shape
            )
            for key, tensor in td.items()
        },
        shape=td.shape,
    )
    return composite

@albertbou92
Contributor

albertbou92 commented Sep 2, 2024

Hi! Sorry for the late reply.

Thanks for sharing the code.

I have tested your script with the latest code of both torchrl and tensordict; I cloned the repos and installed them as explained in the README. I could run your code without errors after only minor changes. The main issue is that some functionality is not yet 100% adapted to composite distributions (we are on it), but your script is mostly fine.

  1. I only had to modify the structure of the action space to be a tree with root "action"

in env.py

self.action_spec = CompositeSpec(
    {
        "action": {
            "action1": DiscreteTensorSpec(n=8),
            "action2": DiscreteTensorSpec(n=1),
        }
    }
)

in train.py

policy_module = ProbabilisticActor(
    module=policy_module,
    in_keys=["params"],
    distribution_class=CompositeDistribution,
    distribution_kwargs={
        "distribution_map": {
            "action1": d.Categorical,
            "action2": d.Categorical,
        },
        "name_map": {
            "action1": ("action", "action1"),
            "action2": ("action", "action2"),
        },
    },
    return_log_prob=True,
)
policy_module(dummy_input)
  2. We are currently adapting the losses to work with composite distributions. The code for the A2C and PPO losses will be integrated in this PR: [BugFix] Allow for composite action distributions in PPO/A2C losses #2391, which will probably be merged soon. @vmoens

  3. The EGreedyModule and the entropy bonus of PPO need to be reviewed to work better with composite distributions. For now I had to comment out the EGreedyModule code and set entropy_bonus=False for PPO (see the sketch below).
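For instance, a rough sketch of the point-3 workarounds applied to the train.py above (it reuses the names from your script and is only a temporary patch, not a definitive fix):

# (a) Disable the entropy bonus in the PPO loss for now:
loss_module = ClipPPOLoss(
    actor_network=policy_module,
    critic_network=value_module,
    clip_epsilon=clip_epsilon,
    entropy_bonus=False,  # the entropy term is not yet composite-aware
    critic_coef=1.0,
    loss_critic_type="smooth_l1",
    normalize_advantage=True,
)

# (b) Skip the EGreedyModule wrapper and use the collected batch as-is:
# explorative_policy = TensorDictSequential(policy_module, EGreedyModule(...))
# tensordict_data = explorative_policy(tensordict_data)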

With the mentioned changes, your script runs fine for me.

@vmoens
Contributor

vmoens commented Sep 17, 2024

Closing since we can't reproduce this; feel free to reopen if the problem persists.

@Sui-Xing
Author

I have already made the modifications suggested above, but my code now runs into issues when computing the PPO loss. After debugging, I found that the line gain1 = log_weight.exp() * advantage always yields a tensor of all zeros. I also found that this may be because the result of return self.log_prob_composite(sample, include_sum=True) is very small (e.g., -300, -200, etc.). I can't figure out why self.log_prob_composite computes such values, and I hope someone can help me with this issue.
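As a side note on the numbers above, a tiny check with hypothetical values shows why such magnitudes would zero things out: in float32, exp() of a log-weight around -200 or -300 underflows to exactly zero, so any gain computed from it is zero as well.

import torch

# Hypothetical log-weights only, to illustrate the underflow.
log_weight = torch.tensor([-300.0, -200.0, -5.0])
advantage = torch.ones(3)
print(log_weight.exp() * advantage)  # tensor([0.0000, 0.0000, 0.0067])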
