Skip to content

Commit

Permalink
Merge pull request #847 from sabify:patch-4
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 451449794
Change-Id: I8decf25a3dd011c3b29c6d7cc22cd0a5c01df8a3
  • Loading branch information
lanctot committed May 28, 2022
2 parents 9acbbc9 + b4148dc commit ca60e95
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions open_spiel/python/pytorch/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import collections
import math
import sys
import numpy as np
from scipy import stats
import torch
Expand All @@ -29,7 +30,7 @@
"Transition",
"info_state action reward next_info_state is_final_step legal_actions_mask")

ILLEGAL_ACTION_LOGITS_PENALTY = -1e9
ILLEGAL_ACTION_LOGITS_PENALTY = sys.float_info.min


class SonnetLinear(nn.Module):
Expand Down Expand Up @@ -323,9 +324,11 @@ def learn(self):
self._q_values = self._q_network(info_states)
self._target_q_values = self._target_q_network(next_info_states).detach()

illegal_actions = 1 - legal_actions_mask
illegal_logits = illegal_actions * ILLEGAL_ACTION_LOGITS_PENALTY
max_next_q = torch.max(self._target_q_values + illegal_logits, dim=1)[0]
illegal_actions_mask = 1 - legal_actions_mask
legal_target_q_values = self._target_q_values.masked_fill(
illegal_actions_mask, ILLEGAL_ACTION_LOGITS_PENALTY)
max_next_q = torch.max(legal_target_q_values, dim=1)[0]

target = (
rewards + (1 - are_final_steps) * self._discount_factor * max_next_q)
action_indices = torch.stack([
Expand Down

0 comments on commit ca60e95

Please sign in to comment.