# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import torch


def cross_entropy_loss(
    log_policy: torch.Tensor, action: torch.Tensor, inplace: bool = False
) -> torch.Tensor:
"""Returns the cross entropy loss defined as the log-softmax value indexed by the action index.
Supports discrete (integer) actions or one-hot encodings.
Args:
log_policy: Tensor of the log_softmax values of the policy.
action: Integer or one-hot representation of the actions undertaken. Must have a shape log_policy.shape[:-1]
(integer representation) or log_policy.shape (one-hot).
inplace: fills log_policy in-place with 0.0 at non-selected actions before summing along the last dimensions.
This is usually faster but it will change the value of log-policy in place, which may lead to unwanted
behaviors.
"""
    if action.shape == log_policy.shape:
        # One-hot branch: the action tensor selects one entry per row of log_policy.
        if action.dtype not in (torch.bool, torch.long, torch.uint8):
            raise TypeError(
                f"Cross-entropy loss with {action.dtype} dtype is not permitted"
            )
        if not ((action == 1).sum(-1) == 1).all():
            raise RuntimeError(
                "Expected the action tensor to be a one-hot encoding of the actions taken, "
                "but got more/less than one non-null boolean index on the last dimension"
            )
        if inplace:
            # Zero out the non-selected entries in-place, then sum: only the
            # selected log-probability survives on the last dimension.
            cross_entropy = log_policy.masked_fill_(~action.to(torch.bool), 0.0).sum(-1)
        else:
            cross_entropy = (log_policy * action).sum(-1)
    elif action.shape == log_policy.shape[:-1]:
        # Integer branch: gather the log-probability of each selected action.
        cross_entropy = torch.gather(log_policy, dim=-1, index=action[..., None])
        cross_entropy.squeeze_(-1)
    else:
        raise RuntimeError(
            f"unexpected action shape in cross_entropy_loss with log_policy.shape={log_policy.shape} "
            f"and action.shape={action.shape}"
        )
    return cross_entropy
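

# A minimal usage sketch (not part of the original module): it checks that the
# integer and one-hot action paths return the same values. The tensor shapes and
# the action-space size below are illustrative assumptions, not values from the
# source.
if __name__ == "__main__":
    log_policy = torch.log_softmax(torch.randn(4, 6), dim=-1)
    action = torch.randint(0, 6, (4,))  # integer actions, shape log_policy.shape[:-1]
    one_hot = torch.nn.functional.one_hot(action, num_classes=6)  # shape log_policy.shape
    out_int = cross_entropy_loss(log_policy, action)
    out_hot = cross_entropy_loss(log_policy, one_hot)
    assert torch.allclose(out_int, out_hot)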