Skip to content

Commit

Permalink
Iterate over Python range instead jnp.arange.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 399392126
  • Loading branch information
RLaxDev authored and RLaxDev committed Oct 7, 2021
1 parent caf976f commit 8f72691
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions rlax/_src/multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def lambda_returns(
# Work backwards to compute `G_{T-1}`, ..., `G_0`.
returns = []
g = v_t[-1]
for i in jnp.arange(v_t.shape[0] - 1, -1, -1):
for i in reversed(range(v_t.shape[0])):
g = r_t[i] + discount_t[i] * ((1-lambda_[i]) * v_t[i] + lambda_[i] * g)
returns.insert(0, g)

Expand Down Expand Up @@ -262,7 +262,7 @@ def importance_corrected_td_errors(

# Work backwards to compute `delta_{T-1}`, ..., `delta_0`.
delta, errors = 0.0, []
for i in jnp.arange(one_step_delta.shape[0] - 1, -1, -1):
for i in reversed(range(one_step_delta.shape[0])):
delta = one_step_delta[i] + discount_t[i] * rho_t[i] * lambda_[i] * delta
errors.insert(0, delta)

Expand Down Expand Up @@ -424,7 +424,7 @@ def general_off_policy_returns_from_q_and_v(
# Work backwards to compute `G_K-1`, ..., `G_1`, `G_0`.
g = r_t[-1] + discount_t[-1] * v_t[-1] # G_K-1.
returns = [g]
for i in jnp.arange(q_t.shape[0] - 1, -1, -1): # [K - 2, ..., 0]
for i in reversed(range(q_t.shape[0])): # [K - 2, ..., 0]
g = r_t[i] + discount_t[i] * (v_t[i] - c_t[i] * q_t[i] + c_t[i] * g)
returns.insert(0, g)

Expand Down
4 changes: 2 additions & 2 deletions rlax/_src/vtrace.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def vtrace(
# Work backwards computing the td-errors.
err = 0.0
errors = []
for i in jnp.arange(v_t.shape[0] - 1, -1, -1):
for i in reversed(range(v_t.shape[0])):
err = td_errors[i] + discount_t[i] * c_t[i] * err
errors.insert(0, err)

Expand Down Expand Up @@ -141,7 +141,7 @@ def leaky_vtrace(
# Work backwards computing the td-errors.
err = 0.0
errors = []
for i in jnp.arange(v_t.shape[0] - 1, -1, -1):
for i in reversed(range(v_t.shape[0])):
err = td_errors[i] + discount_t[i] * c_t[i] * err
errors.insert(0, err)

Expand Down

0 comments on commit 8f72691

Please sign in to comment.