forked from paschalidoud/superquadric_parsing
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
63 lines (54 loc) · 1.52 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import pandas as pd
import pickle
import matplotlib
# Select matplotlib's non-interactive "agg" backend before pyplot is
# imported below, presumably so plots can be rendered to files on machines
# without a display (e.g. training servers).
matplotlib.use("agg")
import matplotlib.pyplot as plt
import seaborn as sns
# Apply seaborn's default plotting theme globally as an import-time side
# effect of this module.
sns.set()
def parse_train_test_splits(train_test_splits_file, model_tags):
    """Read a split CSV and collect the model ids of each split.

    The file is read with the explicit column names
    ``id, synsetId, subSynsetId, modelId, split``. Because ``names`` is
    passed, pandas treats every line (including a header line, if the file
    has one) as data. Only rows whose ``synsetId`` matches one of
    ``model_tags`` are kept.

    Args:
        train_test_splits_file: Path to the splits file; its name must end
            in "csv".
        model_tags: Iterable of synset ids (str or int) to keep, or a single
            str id. Ids are compared after stripping surrounding whitespace
            and leading zeros. So "02691156", "2691156" and 2691156 all
            select the same rows, whether pandas parsed the column as
            numbers or kept it as strings.

    Returns:
        dict with keys "train", "test" and "val". Each maps to the list of
        ``modelId`` values of the kept rows in that split, in file order.

    Raises:
        ValueError: if the file name does not end in "csv".
    """
    if not train_test_splits_file.endswith("csv"):
        raise ValueError(
            "Input file %s is not csv" % (train_test_splits_file,)
        )
    df = pd.read_csv(
        train_test_splits_file,
        names=["id", "synsetId", "subSynsetId", "modelId", "split"]
    )

    # A lone string would otherwise be iterated character by character.
    if isinstance(model_tags, str):
        model_tags = [model_tags]

    def _normalize(value):
        # pandas may infer synsetId as int (dropping leading zeros), as float
        # (when NaNs are present) or keep it as str (e.g. when a header line
        # is read as data), so map every form to one canonical text.
        if isinstance(value, float) and value.is_integer():
            value = int(value)
        return str(value).strip().lstrip("0")

    wanted = {_normalize(tag) for tag in model_tags}
    # Element-wise membership test; an empty tag list keeps no rows.
    keep_from_model = df["synsetId"].map(_normalize).isin(wanted)
    # Keep only the rows from the model we want
    df_from_model = df[keep_from_model]

    splits = {}
    for split_name in ("train", "test", "val"):
        in_split = df_from_model["split"] == split_name
        splits[split_name] = df_from_model[in_split].modelId.values.tolist()
    return splits
def get_colors(M):
    """Return seaborn's "Paired" colour palette.

    ``M`` is accepted but never read, so the returned palette does not
    depend on it. NOTE(review): callers that need more colours than the
    palette holds must cycle through it themselves; confirm at call sites.
    """
    palette_name = "Paired"
    palette = sns.color_palette(palette_name)
    return palette
def store_primitive_parameters(
    size,
    shape,
    rotation,
    location,
    tapering,
    probability,
    color,
    filepath
):
    """Pickle a set of primitive parameters to ``filepath``.

    The arguments are stored unchanged in a dict with the keys ``size``,
    ``shape``, ``rotation``, ``location``, ``tapering``, ``probability``
    and ``color``.

    Args:
        size, shape, rotation, location, tapering, probability, color:
            Values to store; they must be picklable.
        filepath: Destination path. It is opened in "wb" mode, so an
            existing file is overwritten.
    """
    primitive_params = dict(
        size=size,
        shape=shape,
        rotation=rotation,
        location=location,
        tapering=tapering,
        probability=probability,
        color=color
    )
    # The context manager flushes and closes the handle deterministically.
    # The original left the file object open until garbage collection.
    with open(filepath, "wb") as f:
        pickle.dump(primitive_params, f)