[BUG] RuntimeError: index -9223372036854775808 is out of bounds for dimension 1 with size 1 #2402
Comments
Cc @albertbou92
I did use the CompositeDistribution class.

class Net_Policy(nn.Module):
    def __init__(self, num_cells, action_dims, device):
        super().__init__()
        self.policy_mlps = nn.ModuleList([
            nn.Sequential(
                nn.Linear(num_cells + 128, num_cells),
                nn.ReLU(),
                nn.Linear(num_cells, action_dim)
            ).to(device) for action_dim in action_dims
        ])

    def forward(self, hc):
        policies = [mlp(hc) for mlp in self.policy_mlps]
        policy_dict = {
            ("params", f"action{i + 1}",): policy for i, policy in enumerate(policies)
        }
        return policy_dict

    def get_dist(self, tensordict):
        params = self(tensordict['feature1'], tensordict['feature2'], ...)
        params_td = TensorDict(params, batch_size=tensordict.batch_size)
        return CompositeDistribution(
            params=params_td["params"],
            distribution_map={
                "action1": d.Categorical,
                "action2": d.Categorical,
                "action3": d.Categorical,
                "action4": d.Categorical,
                "action5": d.Categorical,
            },
        )
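For readers unfamiliar with the API, here is a minimal, self-contained sketch of how a CompositeDistribution built from a nested params TensorDict is sampled and queried for log-probabilities. The batch size and head sizes are illustrative only, and the exact log_prob output format differs between tensordict versions:

```python
import torch
from torch import distributions as d
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution

batch = 4
# Nested layout mirrors the ("params", "actionN", "logits") keys used above.
params = TensorDict(
    {
        "params": {
            "action1": {"logits": torch.randn(batch, 8)},
            "action2": {"logits": torch.randn(batch, 3)},
        }
    },
    batch_size=[batch],
)

dist = CompositeDistribution(
    params=params["params"],
    distribution_map={"action1": d.Categorical, "action2": d.Categorical},
)

sample = dist.sample()            # TensorDict containing an "action1" and an "action2" entry
log_prob = dist.log_prob(sample)  # per-component log-probs (aggregation varies by version)
print(sample)
print(log_prob)
```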
Hi! I will look into it. Will get back to you ASAP.
I can run a ProbabilisticActor with the latest TensorDict code and set return_log_prob=True.
Sure, I will provide a demo that can trigger this error. The code for the maze environment might not be correct, but it reliably triggers the "RuntimeError: index -9223372036854775808 is out of bounds for dimension 0 with size 8" error.

train.py

import warnings
from tensordict import TensorDict
from env import MazeEnv
from torch import nn, distributions as d
warnings.filterwarnings("ignore")
from torch import multiprocessing
import matplotlib.pyplot as plt
from torchrl.collectors import SyncDataCollector, MultiaSyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator, EGreedyModule
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
from collections import defaultdict
import torch
from tensordict.nn import TensorDictModule, TensorDictSequential, CompositeDistribution
from torchrl.envs.utils import check_env_specs, step_mdp
from tqdm import tqdm
import torch.nn.init as init
import numpy as np
import random
# Seed Python's built-in random module
random.seed(42)
# Seed NumPy
np.random.seed(42)
# Seed PyTorch on CPU and CUDA (if a GPU is used)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
# PPO hyperparameter settings
is_fork = multiprocessing.get_start_method() == "fork"
device = (
torch.device(0)
if torch.cuda.is_available() and not is_fork
else torch.device("cpu")
)
device = "cpu"
num_cells = 128 # number of cells in each layer i.e. output dim.
lr = 1e-4
max_grad_norm = 1.0
frames_per_batch = 3000
total_frames = 10*frames_per_batch * 200
sub_batch_size = 64
num_epochs = 30
clip_epsilon = 0.2
gamma = 0.9995
lmbda = 0.95
entropy_eps = 1e-4
greedy_epsilon = 0.5
# Initialize the maze environment
env = MazeEnv(device=device)
check_env_specs(env)
# Reset the environment
td = env.reset()
print("reset tensordict", td)
# Take a random step
td = env.rand_step(td)
print("random step tensordict", td)
# Debug: print the observation_spec
print("observation_spec:", env.observation_spec)
# Use an actual observation to initialize dummy_input
dummy_input = td["position"].unsqueeze(0).to(device)
print("dummy_input shape:", dummy_input.shape)
class Policy(nn.Module):
def __init__(self, num_cells, action_dims, device):
super().__init__()
self.policy_mlps = nn.ModuleList([
nn.Sequential(
nn.LazyLinear(num_cells),
nn.ReLU(),
nn.Linear(num_cells, action_dim)
).to(device) for action_dim in action_dims
])
def forward(self, hc):
policies = [mlp(hc) for mlp in self.policy_mlps]
# Build the dict that is returned
policy_dict = {
("params", "action1", "logits"): policies[0],
("params", "action2", "logits"): policies[1],
}
# print("policy_dict : ",policy_dict)
return policy_dict
def get_dist(self, tensordict):
params = self(tensordict['position'])
params_td = TensorDict(params, batch_size=tensordict.batch_size)
return CompositeDistribution(
params=params_td["params"],
distribution_map={
"action1": d.Categorical,
"action2": d.Categorical,
},
)
actor_net = Policy(128,[8,1],device=device)
# Value Net
value_net = nn.Sequential(
nn.LazyLinear(num_cells, device=device),
nn.Tanh(),
nn.LazyLinear(num_cells, device=device),
nn.Tanh(),
nn.LazyLinear(num_cells, device=device),
nn.Tanh(),
nn.LazyLinear(1, device=device),
)
# Initialize LazyLinear layers with dummy input
actor_net(dummy_input)
value_net(dummy_input)
# Apply Xavier initialization
def init_weights(m):
if isinstance(m, (nn.Linear, nn.LazyLinear)):
init.xavier_uniform_(m.weight)
if m.bias is not None:
init.zeros_(m.bias)
actor_net.apply(init_weights)
value_net.apply(init_weights)
policy_module = TensorDictModule(
actor_net, in_keys=["position"], out_keys=[
("params", "action1", "logits"),
("params", "action2", "logits"),
]
)
policy_module = ProbabilisticActor(
module=policy_module,
in_keys=["params"],
distribution_class=CompositeDistribution,
distribution_kwargs={
"distribution_map": {
"action1": d.Categorical,
"action2": d.Categorical,
},
},
return_log_prob=True,
)
policy_module(dummy_input)
value_module = ValueOperator(
module=value_net,
in_keys=["position"],
)
replay_buffer = ReplayBuffer(
storage=LazyTensorStorage(max_size=frames_per_batch),
sampler=SamplerWithoutReplacement(),
)
advantage_module = GAE(
gamma=gamma, lmbda=lmbda, value_network=value_module, average_gae=True
)
loss_module = ClipPPOLoss(
actor_network=policy_module,
critic_network=value_module,
clip_epsilon=clip_epsilon,
entropy_bonus=bool(entropy_eps),
entropy_coef=entropy_eps,
critic_coef=1.0,
loss_critic_type="smooth_l1",
normalize_advantage=True,
)
optim = torch.optim.Adam(loss_module.parameters(), lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optim, total_frames // frames_per_batch, 0.0
)
logs = defaultdict(list)
pbar = tqdm(total=total_frames)
eval_str = ""
if __name__ == '__main__':
env_maker = lambda: MazeEnv(device=device)
# Defined before the training loop starts
initial_epsilon = 0.2
final_epsilon = 0.01
collector = SyncDataCollector(
env,
policy=policy_module,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
split_trajs=False,
device=device,
exploration_type=ExplorationType.RANDOM
)
explorative_policy = TensorDictSequential(policy_module,
EGreedyModule(eps_init=initial_epsilon,
eps_end=final_epsilon,
spec=env.action_spec))
# Training loop
for i, tensordict_data in enumerate(collector):
# env.reset()
tensordict_data = explorative_policy(tensordict_data)
for _ in range(num_epochs):
advantage_module(tensordict_data)
data_view = tensordict_data.reshape(-1)
replay_buffer.extend(data_view.cpu())
for _ in range(frames_per_batch // sub_batch_size):
subdata = replay_buffer.sample(sub_batch_size)
loss_vals = loss_module(subdata.to(device))
loss_value = (
loss_vals["loss_objective"]
+ loss_vals["loss_critic"]
+ loss_vals["loss_entropy"]
)
loss_value.backward()
torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
optim.step()
optim.zero_grad()
logs["reward"].append(tensordict_data["next", "reward"].mean().item())
pbar.update(tensordict_data.numel())
cum_reward_str = (
f"average reward={logs['reward'][-1]: 4.5f} (init={logs['reward'][0]: 4.5f})"
)
logs["step_count"].append(tensordict_data["step_count"].max().item())
stepcount_str = f"step count (max): {logs['step_count'][-1]}"
logs["lr"].append(optim.param_groups[0]["lr"])
lr_str = f"lr policy: {logs['lr'][-1]: 4.5f}"
if (i + 1) % 10 == 0:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
eval_rollout = env.rollout(1000, policy_module)
logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
logs["eval reward (sum)"].append(
eval_rollout["next", "reward"].sum().item()
)
logs["eval step_count"].append(0)
eval_str = (
f"eval cumulative reward: {logs['eval reward (sum)'][-1]: 4.5f} "
f"(init: {logs['eval reward (sum)'][0]: 4.5f}), "
f"eval step-count: {logs['eval step_count'][-1]}"
)
del eval_rollout
if i % 20 == 0 and i > 1:
# Save a checkpoint
model_state = {
'policy_state_dict': policy_module.state_dict(),
'value_state_dict': value_module.state_dict(),
'optimizer_state_dict': optim.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'epoch': num_epochs
}
torch.save(model_state, 'ckpt/ppo_maze_model.pth')
print("Model saved!")
pbar.set_description(", ".join([eval_str, cum_reward_str, stepcount_str, lr_str]))
scheduler.step()
pbar.close()
collector.shutdown()
env.reset()
eval_rollout = env.rollout(2000, policy_module)
# print("评估结果:\n\n",eval_rollout['next'].values())
for k in eval_rollout['next'].keys():
if k == "position":
for v in eval_rollout['next'].get(k):
print("评估结果:\n", k + "\n", v.cpu().numpy())
# 结果展示
plt.figure(figsize=(10, 10))
plt.subplot(2, 2, 1)
plt.plot(logs["reward"])
plt.title("training rewards (average)")
plt.subplot(2, 2, 2)
plt.plot(logs["step_count"])
plt.title("Max step count (training)")
plt.subplot(2, 2, 3)
plt.plot(logs["eval reward (sum)"])
plt.title("Return (test)")
plt.subplot(2, 2, 4)
plt.plot(logs["eval step_count"])
plt.title("Max step count (test)")
plt.show()
board.py

class Board:
width = 100
height = 100
start_pos = (1, 1)
end_pos = (78, 78)
components = [[(30, 30), (50, 70)],
[(10, 8), (15, 10)],
[(12, 2), (15, 7)],
[(2, 10), (5, 11)],
[(2, 13), (5, 14)],
# [(15, 12), (18, 18)],
]
lines = []
def __init__(self):
self.width = 20
self.height = 20
self.start_pos = (0, 0)
self.end_pos = (18, 12)
self.components = [[(3, 3), (5, 7)],
[(10, 8), (17, 10)]]

env.py

from typing import Optional
import torch
from tensordict import TensorDict, TensorDictBase
from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec, DiscreteTensorSpec
from torchrl.envs import (
EnvBase,
)
from board import Board
# Define the maze environment
class MazeEnv(EnvBase):
metadata = {
"render_modes": ["human", "rgb_array"],
"render_fps": 30,
}
batch_locked = False
def __init__(self, maze_size=(Board.width, Board.height), start=(0, 0), goal=(8, 8), seed=None, device="cpu"):
self.maze_size = maze_size
self.start = Board.start_pos
# self.start=start
self.goal = Board.end_pos
td_params = self.gen_params()
super().__init__(device=device, batch_size=[])
self._make_spec(td_params)
if seed is None:
seed = torch.empty((), dtype=torch.int64).random_().item()
self.set_seed(seed)
self.visit_counts = torch.zeros(maze_size, device=device)
self.start = torch.tensor(self.start)
self.visit_counts[self.start[0].long(), self.start[1].long()] += 1
self.components = torch.tensor(Board.components, device=self.device)
@staticmethod
def gen_params(batch_size=None) -> TensorDictBase:
"""Returns a ``tensordict`` containing the maze parameters."""
if batch_size is None:
batch_size = []
td = TensorDict(
{
"params": TensorDict(
{
"maze_size": torch.tensor([Board.width, Board.height]),
"start": torch.tensor([Board.start_pos[0], Board.start_pos[1]]),
"goal": torch.tensor([Board.end_pos[0], Board.end_pos[1]]),
},
[],
)
},
[],
)
if batch_size:
td = td.expand(batch_size).contiguous()
return td
def _make_spec(self, td_params):
self.observation_spec = CompositeSpec(
position=BoundedTensorSpec(
low=torch.tensor([0, 0]),
high=td_params["params", "maze_size"] - 1,
shape=(2,),
dtype=torch.float32,
),
params=make_composite_from_td(td_params["params"]),
step_count=UnboundedContinuousTensorSpec(dtype=torch.int32, shape=()),
shape=(),
)
self.state_spec = self.observation_spec.clone()
# Changed here: the action space is now discrete
self.action_spec = CompositeSpec(
action1=DiscreteTensorSpec(n=8),
action2=DiscreteTensorSpec(n=1),
# Add more actions as needed
)  # action1: 8 move directions (4 cardinal + 4 diagonal)
self.reward_spec = UnboundedContinuousTensorSpec(shape=(*td_params.shape, 1))
def _reset(self, tensordict):
if tensordict is None or tensordict.is_empty():
tensordict = self.gen_params(batch_size=self.batch_size)
position = torch.tensor(self.start, device=self.device, dtype=torch.float32)  # changed to float32
out = TensorDict(
{
"position": position,
"params": tensordict["params"],
"step_count": torch.tensor(0, device=self.device, dtype=torch.int32), # 初始化 step_count
},
batch_size=tensordict.shape,
).to(self.device)
return out
def _step(self, tensordict):
position = tensordict["position"]
action = tensordict["action1"].squeeze(-1) # 这里应该是一个整数,表示动作
maze_size = tensordict["params", "maze_size"]
goal = tensordict["params", "goal"]
old_position = position.clone()
new_position = position.clone()
if action == 0: # left
new_position[0] = torch.clamp(new_position[0] - 1, 0, maze_size[0] - 1)
elif action == 1: # right
new_position[0] = torch.clamp(new_position[0] + 1, 0, maze_size[0] - 1)
elif action == 2: # up
new_position[1] = torch.clamp(new_position[1] - 1, 0, maze_size[1] - 1)
elif action == 3: # down
new_position[1] = torch.clamp(new_position[1] + 1, 0, maze_size[1] - 1)
elif action == 4: # left up
new_position[0] = torch.clamp(new_position[0] - 1, 0, maze_size[0] - 1)
new_position[1] = torch.clamp(new_position[1] - 1, 0, maze_size[1] - 1)
elif action == 5: # left down
new_position[0] = torch.clamp(new_position[0] - 1, 0, maze_size[0] - 1)
new_position[1] = torch.clamp(new_position[1] + 1, 0, maze_size[1] - 1)
elif action == 6: # right up
new_position[0] = torch.clamp(new_position[0] + 1, 0, maze_size[0] - 1)
new_position[1] = torch.clamp(new_position[1] - 1, 0, maze_size[1] - 1)
elif action == 7: # right down
new_position[0] = torch.clamp(new_position[0] + 1, 0, maze_size[0] - 1)
new_position[1] = torch.clamp(new_position[1] + 1, 0, maze_size[1] - 1)
reward = 0
done = torch.tensor(False, dtype=torch.bool, device=self.device)
# Check whether the goal has been reached
if torch.equal(new_position, goal):
reward += 500.0 # reward for reaching the goal
done = torch.tensor(True, dtype=torch.bool, device=self.device)
# Check for collisions with obstacles
if self.is_obstacle(new_position):
reward -= 10 # penalty for hitting an obstacle
new_position = old_position # stay at the old position
# Distance-based reward
old_distance = torch.norm(old_position - goal)
new_distance = torch.norm(new_position - goal)
distance_reward = old_distance - new_distance
reward += distance_reward * 10 # tune this coefficient to balance the rewards
# Step penalty
reward -= 0.1 # small penalty for every step
# Penalty for revisiting cells
self.visit_counts[new_position[0].long(), new_position[1].long()] += 1
visit_count = self.visit_counts[new_position[0].long(), new_position[1].long()]
repetition_penalty = -0.1 * torch.log(visit_count)
reward += repetition_penalty
# Reward for moving toward the goal
if new_distance < old_distance:
reward += 0.1 # reward for moving in the goal's direction
# Add the curiosity reward into the total reward
total_reward = reward
step_count = tensordict["step_count"] + 1 # increment step_count
# Reward for moving toward the goal
if new_distance < old_distance:
reward += 0.001 # reward for moving in the goal's direction
out = TensorDict(
{
"position": new_position,
"params": tensordict["params"],
"reward": torch.tensor(total_reward, device=self.device).view(*tensordict.shape, 1),
"done": done.view(*tensordict.shape, 1),
"step_count": step_count,
},
tensordict.shape,
).to(self.device)
return out
def _set_seed(self, seed: Optional[int]):
rng = torch.manual_seed(seed)
self.rng = rng
# def compute_curiosity_reward(self, position):
# # Compute the curiosity reward
# # Uses the inverse of the visit count as the reward; adjust as needed
# visit_count = self.visit_counts[position[0].long(), position[1].long()]
# curiosity_reward = 10 / (visit_count + 1) # +1 avoids division by zero
# return curiosity_reward.item()
def compute_curiosity_reward(self, position):
# Compute the curiosity reward
# Uses the inverse of the visit count as the reward; adjust as needed
visit_count = self.visit_counts[position[0].long(), position[1].long()]
curiosity_reward = 0.01 / (visit_count + 1) # +1 avoids division by zero
return curiosity_reward.item()
def generate_new_start_goal(self):
while True:
new_start = torch.randint(0, self.maze_size[0], (2,), device=self.device)
new_goal = torch.randint(0, self.maze_size[0], (2,), device=self.device)
if not self.is_obstacle(new_start) and not self.is_obstacle(new_goal):
return new_start, new_goal
def is_obstacle(self, position):
# Ensure position is a tensor
if not isinstance(position, torch.Tensor):
position = torch.tensor(position, device=self.device)
position = position.to(self.device).float()
# Check whether the position is out of bounds
if (position[0] < 0 or position[0] >= self.maze_size[0] or
position[1] < 0 or position[1] >= self.maze_size[1]):
return True
# Expand position to match the shape of components
position = position.unsqueeze(0).unsqueeze(0) # Shape: (1, 1, 2)
# Check whether the position falls inside any obstacle's bounding box
in_x_range = (self.components[:, 0, 0] <= position[:, :, 0]) & (position[:, :, 0] <= self.components[:, 1, 0])
in_y_range = (self.components[:, 0, 1] <= position[:, :, 1]) & (position[:, :, 1] <= self.components[:, 1, 1])
# Return True if the position lies inside any obstacle
return (in_x_range & in_y_range).any().item()
def make_composite_from_td(td):
composite = CompositeSpec(
{
key: make_composite_from_td(tensor)
if isinstance(tensor, TensorDictBase)
else UnboundedContinuousTensorSpec(
dtype=tensor.dtype, device=tensor.device, shape=tensor.shape
)
for key, tensor in td.items()
},
shape=td.shape,
)
return composite
Hi! Sorry for the late reply. Thanks for sharing the code. I have tested your script with the latest code of both torchrl and tensordict, with the following changes:
in env.py
in train.py
With the mentioned changes, your script runs fine for me.
Closing since we can't reproduce it - feel free to reopen if the problem persists.
I have already made modifications based on the suggestions above, but my code encounters issues when calculating the PPO loss. After debugging, I found that the line
Describe the bug
ProbabilisticActor cannot be configured with return_log_prob=True; it will throw an error in version 0.5.0, but switching back to version 0.4.0 resolves the issue.
To Reproduce
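The original reproduction snippet is not captured in this excerpt; the full demo lives in the comments above. As a rough, hedged sketch only, with illustrative module names and shapes rather than a confirmed minimal reproduction, the configuration in question looks like this:

```python
import torch
from torch import nn, distributions as d
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, CompositeDistribution
from torchrl.modules import ProbabilisticActor

class Policy(nn.Module):
    # Two categorical heads over a 2-dimensional "position" observation.
    def __init__(self):
        super().__init__()
        self.head1 = nn.Linear(2, 8)
        self.head2 = nn.Linear(2, 3)

    def forward(self, position):
        return self.head1(position), self.head2(position)

module = TensorDictModule(
    Policy(),
    in_keys=["position"],
    out_keys=[("params", "action1", "logits"), ("params", "action2", "logits")],
)
actor = ProbabilisticActor(
    module=module,
    in_keys=["params"],
    distribution_class=CompositeDistribution,
    distribution_kwargs={
        "distribution_map": {"action1": d.Categorical, "action2": d.Categorical}
    },
    return_log_prob=True,  # reported to work on 0.4.0 and fail on 0.5.0
)

td = TensorDict({"position": torch.zeros(4, 2)}, batch_size=[4])
actor(td)  # on the affected versions, the log-prob computation here (or in the PPO loss) is where the error is reported
print(td)
```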
Expected behavior
It should output the correct log_prob.
System info
Describe the characteristic of your environment:
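The environment details were not captured in this excerpt. As a hedged sketch, version information of the kind requested here can be collected like this (assuming the standard torchrl stack is installed):

```python
import sys
import numpy, torch, tensordict, torchrl

# Print the library versions relevant to this report.
print("torchrl:", torchrl.__version__)
print("tensordict:", tensordict.__version__)
print("torch:", torch.__version__)
print("numpy:", numpy.__version__)
print("python:", sys.version)
```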
Checklist