forked from commaai/controls_challenge
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patheval.py
103 lines (83 loc) · 4.18 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import argparse
import base64
import numpy as np
import pandas as pd
import seaborn as sns
from io import BytesIO
from matplotlib import pyplot as plt
from pathlib import Path
from tqdm import tqdm
from tinyphysics import TinyPhysicsModel, TinyPhysicsSimulator, CONTROLLERS, CONTROL_START_IDX
# Apply seaborn's default theme to every matplotlib figure created below.
sns.set_theme()
# Number of leading segments whose rollouts are kept for the report's
# "Sample Rollouts" plots.
SAMPLE_ROLLOUTS = 5
def img2base64(fig):
    """Render ``fig`` as a PNG and return it as an ASCII base64 string.

    ``fig`` only needs a ``savefig(file, format=...)`` method; in this file
    it is a matplotlib figure. The result is suitable for embedding in a
    ``data:image/png;base64,...`` URI.
    """
    with BytesIO() as png_stream:
        fig.savefig(png_stream, format='png')
        encoded = base64.b64encode(png_stream.getvalue())
    return encoded.decode("ascii")
def create_report(test, baseline, sample_rollouts, costs):
    """Build the HTML comparison report and write it to ``report.html``.

    The report contains a histogram per cost type comparing both
    controllers, a table of mean costs per controller, and one
    lateral-acceleration plot per sample rollout. Figures are embedded
    inline as base64 PNG data URIs.

    Args:
        test: Name of the test controller, shown in the report header.
        baseline: Name of the baseline controller, shown in the header.
        sample_rollouts: Iterable of dicts with the keys ``'seg'``,
            ``'desired_lataccel'``, ``'test_controller_lataccel'`` and
            ``'baseline_controller_lataccel'``. At most ``SAMPLE_ROLLOUTS``
            entries are plotted; fewer entries yield fewer subplots.
        costs: Records accepted by ``pd.DataFrame`` that carry a
            ``'controller'`` column (``'test'`` / ``'baseline'``) and the
            columns ``'lataccel_cost'``, ``'jerk_cost'`` and
            ``'total_cost'``.
    """
    cost_columns = ['lataccel_cost', 'jerk_cost', 'total_cost']

    res = [
        "<h1>Comma Controls Challenge: Report</h1>",
        f"<b>Test Controller: {test}, Baseline Controller: {baseline}</b>",
        "<h2>Aggregate Costs</h2>",
    ]

    res_df = pd.DataFrame(costs)
    fig, axs = plt.subplots(ncols=3, figsize=(18, 6), sharey=True)
    # Fixed bin edges 0, 10, ..., 990; costs beyond the last edge are not drawn.
    bins = np.arange(0, 1000, 10)
    for ax, cost in zip(axs, cost_columns):
        for controller in ['test', 'baseline']:
            ax.hist(res_df[res_df['controller'] == controller][cost], bins=bins, label=controller, alpha=0.5)
        ax.set_xlabel('Cost')
        ax.set_ylabel('Frequency')
        ax.set_title(f'Cost Distribution: {cost}')
        ax.legend()
    res.append(f'<img src="data:image/png;base64,{img2base64(fig)}" alt="Plot">')
    # The figure is fully serialized above; close it so pyplot stops tracking it.
    plt.close(fig)

    mean_costs = res_df.groupby('controller').agg({column: 'mean' for column in cost_columns})
    res.append(mean_costs.round(3).reset_index().to_html(index=False))

    res.append("<h2>Sample Rollouts</h2>")
    # Size the figure by the rollouts actually supplied (capped at
    # SAMPLE_ROLLOUTS) so short runs do not produce blank axes.
    rollouts = list(sample_rollouts)[:SAMPLE_ROLLOUTS]
    if rollouts:
        # squeeze=False always returns a 2-D array of axes, so a single
        # rollout is handled the same way as several.
        fig, axs = plt.subplots(
            ncols=1,
            nrows=len(rollouts),
            figsize=(15, 3 * len(rollouts)),
            sharex=True,
            squeeze=False,
        )
        for ax, rollout in zip(axs[:, 0], rollouts):
            ax.plot(rollout['desired_lataccel'], label='Desired Lateral Acceleration')
            ax.plot(rollout['test_controller_lataccel'], label='Test Controller Lateral Acceleration')
            ax.plot(rollout['baseline_controller_lataccel'], label='Baseline Controller Lateral Acceleration')
            ax.set_xlabel('Step')
            ax.set_ylabel('Lateral Acceleration')
            ax.set_title(f"Segment: {rollout['seg']}")
            # Vertical dashed line at the step index CONTROL_START_IDX.
            ax.axline((CONTROL_START_IDX, 0), (CONTROL_START_IDX, 1), color='black', linestyle='--', alpha=0.5, label='Control Start')
            ax.legend()
        fig.tight_layout()
        res.append(f'<img src="data:image/png;base64,{img2base64(fig)}" alt="Plot">')
        plt.close(fig)
    else:
        res.append("<p>No sample rollouts available.</p>")

    with open("report.html", "w", encoding="utf-8") as fob:
        fob.write("\n".join(res))
    print("Report saved to: './report.html'")
# Script entry point: roll out the test and baseline controllers over the
# first --num_segs entries of --data_path (sorted by path), collect their
# costs, and write the HTML report.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--data_path", type=str, required=True)
    parser.add_argument("--num_segs", type=int, default=100)
    parser.add_argument("--test_controller", default='simple', choices=CONTROLLERS.keys())
    parser.add_argument("--baseline_controller", default='simple', choices=CONTROLLERS.keys())
    args = parser.parse_args()

    # Validate user input with a real check rather than `assert`, which is
    # stripped under `python -O`. Done before loading the model to fail fast.
    data_path = Path(args.data_path)
    if not data_path.is_dir():
        parser.error("data_path should be a directory")

    tinyphysicsmodel = TinyPhysicsModel(args.model_path, debug=False)

    # NOTE(review): each controller is instantiated once and reused for every
    # segment; if a controller keeps internal state between calls, that state
    # carries over from one rollout to the next — confirm this is intended.
    test_controller = CONTROLLERS[args.test_controller]()
    baseline_controller = CONTROLLERS[args.baseline_controller]()

    costs = []
    sample_rollouts = []
    files = sorted(data_path.iterdir())[:args.num_segs]
    for d, data_file in enumerate(tqdm(files, total=len(files))):
        test_sim = TinyPhysicsSimulator(tinyphysicsmodel, str(data_file), controller=test_controller, debug=False)
        test_cost = test_sim.rollout()
        baseline_sim = TinyPhysicsSimulator(tinyphysicsmodel, str(data_file), controller=baseline_controller, debug=False)
        baseline_cost = baseline_sim.rollout()

        # Keep the full histories only for the first SAMPLE_ROLLOUTS segments;
        # these feed the per-segment plots in the report.
        if d < SAMPLE_ROLLOUTS:
            sample_rollouts.append({
                'seg': data_file.stem,
                'test_controller': args.test_controller,
                'baseline_controller': args.baseline_controller,
                'desired_lataccel': test_sim.target_lataccel_history,
                'test_controller_lataccel': test_sim.current_lataccel_history,
                'baseline_controller_lataccel': baseline_sim.current_lataccel_history,
            })

        # Costs are recorded for every segment, one row per controller.
        costs.append({'seg': data_file.stem, 'controller': 'test', **test_cost})
        costs.append({'seg': data_file.stem, 'controller': 'baseline', **baseline_cost})

    create_report(args.test_controller, args.baseline_controller, sample_rollouts, costs)