Skip to content

Commit

Permalink
Merge pull request #17 from nworb-cire/type-hints
Browse files Browse the repository at this point in the history
Update type hints
  • Loading branch information
haraschax authored Jun 19, 2024
2 parents e4c6dee + 182cc43 commit da72a0c
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions tinyphysics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from functools import partial
from hashlib import md5
from pathlib import Path
from typing import List, Union, Tuple
from typing import List, Union, Tuple, Dict
from tqdm.contrib.concurrent import process_map

from controllers import BaseController
Expand Down Expand Up @@ -42,14 +42,14 @@ def __init__(self):
self.vocab_size = VOCAB_SIZE
self.bins = np.linspace(LATACCEL_RANGE[0], LATACCEL_RANGE[1], self.vocab_size)

def encode(self, value: Union[float, np.ndarray]) -> Union[int, np.ndarray]:
def encode(self, value: Union[float, np.ndarray, List[float]]) -> Union[int, np.ndarray]:
value = self.clip(value)
return np.digitize(value, self.bins, right=True)

def decode(self, token: Union[int, np.ndarray]) -> Union[float, np.ndarray]:
return self.bins[token]

def clip(self, value: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
def clip(self, value: Union[float, np.ndarray, List[float]]) -> Union[float, np.ndarray]:
return np.clip(value, LATACCEL_RANGE[0], LATACCEL_RANGE[1])


Expand Down Expand Up @@ -142,7 +142,7 @@ def control_step(self, step_idx: int) -> None:
action = np.clip(action, STEER_RANGE[0], STEER_RANGE[1])
self.action_history.append(action)

def get_state_target_futureplan(self, step_idx: int) -> Tuple[State, float]:
def get_state_target_futureplan(self, step_idx: int) -> Tuple[State, float, FuturePlan]:
state = self.data.iloc[step_idx]
return (
State(roll_lataccel=state['roll_lataccel'], v_ego=state['v_ego'], a_ego=state['a_ego']),
Expand Down Expand Up @@ -174,7 +174,7 @@ def plot_data(self, ax, lines, axis_labels, title) -> None:
ax.set_xlabel(axis_labels[0])
ax.set_ylabel(axis_labels[1])

def compute_cost(self) -> dict:
def compute_cost(self) -> Dict[str, float]:
target = np.array(self.target_lataccel_history)[CONTROL_START_IDX:COST_END_IDX]
pred = np.array(self.current_lataccel_history)[CONTROL_START_IDX:COST_END_IDX]

Expand All @@ -183,7 +183,7 @@ def compute_cost(self) -> dict:
total_cost = (lat_accel_cost * LAT_ACCEL_COST_MULTIPLIER) + jerk_cost
return {'lataccel_cost': lat_accel_cost, 'jerk_cost': jerk_cost, 'total_cost': total_cost}

def rollout(self) -> float:
def rollout(self) -> Dict[str, float]:
if self.debug:
plt.ion()
fig, ax = plt.subplots(4, figsize=(12, 14), constrained_layout=True)
Expand Down

0 comments on commit da72a0c

Please sign in to comment.