Understanding some RNN logic for QMIX SMAX algorithm #119
Description
Hey all, hope everyone is doing well! What follows may be a bit of a dumb question, but I just wanted to clarify how this works for my own algorithm development, which builds on your excellent code.
The QMIX code uses a `ScannedRNN` class: you pass in a sequence of observations and dones, and wherever a done flag is true, the hidden state is reset before the corresponding obs at that timestep is passed through:
```python
class ScannedRNN(nn.Module):
    @partial(
        nn.scan,
        variable_broadcast="params",
        in_axes=0,
        out_axes=0,
        split_rngs={"params": False},
    )
    @nn.compact
    def __call__(self, carry, x):
        """Applies the module."""
        rnn_state = carry
        ins, resets = x
        hidden_size = ins.shape[-1]
        rnn_state = jnp.where(
            resets[:, np.newaxis],
            self.initialize_carry(hidden_size, *ins.shape[:-1]),
            rnn_state,
        )
        new_rnn_state, y = nn.GRUCell(hidden_size)(rnn_state, ins)
        return new_rnn_state, y
```
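Just to check my own reading of the `jnp.where` line: here's a minimal, self-contained sketch (not the actual JaxMARL code; a zero state stands in for `initialize_carry`) of the per-timestep reset it performs:

```python
import jax.numpy as jnp
import numpy as np

def reset_where_done(rnn_state, resets):
    """Replace the carry with a fresh (zero) state wherever resets is True."""
    init = jnp.zeros_like(rnn_state)  # stand-in for self.initialize_carry(...)
    return jnp.where(resets[:, np.newaxis], init, rnn_state)

state = jnp.ones((2, 4))           # batch of 2 envs, hidden size 4
resets = jnp.array([True, False])  # env 0 just finished an episode
out = reset_where_done(state, resets)
# env 0's carry is re-initialized to zeros; env 1's carry is untouched
```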
This makes sense to me. However, I noticed that when data is actually collected, any given stored timestep actually pairs the last obs with the new done, rather than the last obs with the last done.
Doesn't that mean that when the RNN resets the hidden state and then consumes the observation at that timestep, we're actually feeding in the previous observation (which belongs to the episode that just ended) as the first step of the RNN's new sequence, instead of the current/new observation generated after the environment was reset when the done came back `True`?
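To make the concern concrete, here's a toy two-episode timeline (hypothetical labels, not the actual rollout code) laid out with the "last obs + new done" convention described above:

```python
# Episode A produces obs a0, a1, a2 (the step taken from a2 returns done=True);
# episode B then starts with obs b0. Each stored timestep t pairs the obs the
# agent acted on with the done returned by that same step.
obs_seq  = ["a0", "a1", "a2", "b0", "b1"]
done_seq = [False, False, True, False, False]

# If the ScannedRNN consumes (obs_seq[t], done_seq[t]) as-is, then at t=2 it
# resets the hidden state and feeds "a2" -- an obs from the episode that just
# ended -- as the first input of the "new" sequence, while episode B's first
# obs "b0" only arrives at t=3 with done=False.
fed_right_after_reset = [o for o, d in zip(obs_seq, done_seq) if d]
print(fed_right_after_reset)  # ['a2']
```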