diff --git a/SMIL_torch_batch.py b/SMIL_torch_batch.py index fdeb69c..8d8290c 100644 --- a/SMIL_torch_batch.py +++ b/SMIL_torch_batch.py @@ -188,6 +188,7 @@ def forward(self, beta, pose, trans=None, simplify=False): G[self.parent[i]], self.rotate_translate(R_cube[:, i], J[:, i] - J[:, self.parent[i]]))) G = torch.stack(G, 1) + Jtr = G[..., :4, 3].clone() G = G - self.pack(torch.matmul(G, torch.cat([J, J.new_zeros(1).expand(*J.shape[:2], 1)], dim=2).unsqueeze(-1))) # T = torch.tensordot(self.weights, G, dims=([1], [1])) @@ -197,7 +198,6 @@ def forward(self, beta, pose, trans=None, simplify=False): T = torch.tensordot(G, self.weights, dims=([1], [1])).permute(0, 3, 1, 2) v = torch.matmul(T, torch.reshape(rest_shape_h, (batch_size, -1, 4, 1))).reshape(batch_size, -1, 4) - Jtr = self.regress_joints(v) if trans is not None: trans = trans.unsqueeze(1)