05-apply_bench.py
# %%
# !%load_ext autoreload
# !%autoreload 2
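# Apply BENCH change-model inference to simulated two-ROI BOLD data from a
# bilinear neural model (dcsem): project the simulated signals onto pre-fitted
# PCA components, train a change model on those components, infer which
# parameter changed in noisy test samples, and plot a confusion matrix.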
import pickle

import matplotlib.pyplot as plt
import numpy as np
from bench import change_model
from scipy.stats import uniform
from seaborn import heatmap
from sklearn.metrics import confusion_matrix
from dcsem.utils import stim_boxcar

from utils import set_style, simulate_bold

set_style()
# %%
# ======================================================================================
# Bilinear neural model parameters
NUM_LAYERS = 1
NUM_ROIS = 2
time = np.arange(100)
u = stim_boxcar([[0, 30, 1]]) # Input stimulus
# u = stim_boxcar([[0, 10, 1], [40, 10, 0.5], [50, 20, 1]])
# Parameters to set and estimate
params_to_set = ['a01', 'a10', 'c0', 'c1']
# Parameter bounds (used below to define the uniform priors)
bounds = {
    'a01': (0.0, 1.0),
    'a10': (0.0, 1.0),
    'c0': (0.0, 1.0),
    'c1': (0.0, 1.0),
}
# ======================================================================================
# %%
def calc_comps(method, **kwargs):
    # Define the allowed parameters
    allowed_keys = ['a01', 'a10', 'c0', 'c1']

    # Find invalid keys
    invalid_keys = [key for key in kwargs.keys() if key not in allowed_keys]

    # Assert that all keys are allowed
    assert (
        not invalid_keys
    ), f'Invalid parameter keys: {invalid_keys}. Allowed keys are: {allowed_keys}.'

    # Keep all arguments that are not None and convert them to numpy arrays
    params = {}
    for key, val in kwargs.items():
        if key == 'method':
            continue
        if val is not None:
            if not isinstance(val, (list, np.ndarray)):
                val = [val]
            if not isinstance(val, np.ndarray):
                val = np.array(val)
            params[key] = val

    # Assert that all values have the same length
    lengths = [len(v) for v in params.values()]
    assert all(
        length == lengths[0] for length in lengths
    ), 'All values must have the same length!'

    # Simulate the (noise-free) BOLD signals
    bold_true = simulate_bold(
        params,
        time=time,
        u=u,
        num_rois=NUM_ROIS,
    )
    bold_obsv = bold_true

    # Concatenate the two ROIs along the time axis and demean each sample
    tmp_bold = np.concatenate([bold_obsv[:, :, 0], bold_obsv[:, :, 1]], axis=1)
    tmp_bold_c = tmp_bold - np.mean(tmp_bold, axis=1, keepdims=True)

    # Project onto the pre-fitted PCA or ICA components
    if method == 'PCA':
        with open('models/pca.pkl', 'rb') as f:
            pca = pickle.load(f)
        components = pca.transform(tmp_bold_c)
    elif method == 'ICA':
        with open('models/ica.pkl', 'rb') as f:
            ica = pickle.load(f)
        components = ica.transform(tmp_bold_c)
    else:
        raise ValueError(f"Unknown method: {method}. Use 'PCA' or 'ICA'.")

    return components

# Check if the function works
comps = calc_comps('PCA', a01=[0.5, 1.0], a10=[1.0, 0.7])
print('PCA components:\n', comps)
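# The same sanity check can be run for the ICA projection; this assumes
# models/ica.pkl was fitted and saved alongside models/pca.pkl.
# comps_ica = calc_comps('ICA', a01=[0.5, 1.0], a10=[1.0, 0.7])
# print('ICA components:\n', comps_ica)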
# %%
priors = {
    'a01': uniform(loc=bounds['a01'][0], scale=bounds['a01'][1] - bounds['a01'][0]),
    'a10': uniform(loc=bounds['a10'][0], scale=bounds['a10'][1] - bounds['a10'][0]),
    'c0': uniform(loc=bounds['c0'][0], scale=bounds['c0'][1] - bounds['c0'][0]),
    'c1': uniform(loc=bounds['c1'][0], scale=bounds['c1'][1] - bounds['c1'][0]),
}
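# Equivalent, more compact way to build the same uniform priors from `bounds`
# (optional alternative):
# priors = {p: uniform(loc=lo, scale=hi - lo) for p, (lo, hi) in bounds.items()}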
tr = change_model.Trainer(
    forward_model=calc_comps,
    priors=priors,
    kwargs={'method': 'PCA'},
    measurement_names=['PC1', 'PC2', 'PC3', 'PC4'],
)
mdl = tr.train(n_samples=5000, verbose=True)
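# Optionally persist the trained change model for later reuse, assuming the
# model object is picklable; the path below is illustrative.
# with open('models/bench_change_model.pkl', 'wb') as f:
#     pickle.dump(mdl, f)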
# %%
n_test_samples = 2000
noise_level = 0.0001
effect_size = 0.3
n_repeats = 50
true_change, data, data2, sn = tr.generate_test_samples(
    n_samples=n_test_samples,
    n_repeats=n_repeats,
    effect_size=effect_size,
    noise_std=noise_level,
)
probs, inferred_change_bench, amount, _ = mdl.infer(data, data2 - data, sn)
print('Accuracy:', np.mean(inferred_change_bench == true_change))
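# Quick per-class accuracy sketch (assumes the labels are 1-D array-like);
# the normalized confusion matrix below shows the same breakdown.
true_arr = np.asarray(true_change)
pred_arr = np.asarray(inferred_change_bench)
for change in np.unique(true_arr):
    mask = true_arr == change
    print(f'{change}: accuracy = {np.mean(pred_arr[mask] == true_arr[mask]):.2f}')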
# %%
conf_mat = confusion_matrix(true_change, inferred_change_bench, normalize='true')
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
heatmap(
    conf_mat,
    annot=True,
    fmt='.2f',
    # cmap='coolwarm',
    cbar=False,
    square=True,
    xticklabels=mdl.model_names,
    yticklabels=mdl.model_names,
    ax=ax,
)
ax.set_xlabel('Inferred Change')
ax.set_ylabel('Actual Change')
plt.title('BENCH')
plt.savefig('results/confusion_matrix_bench.png')
plt.show()
# %%