Skip to content

Commit

Permalink
Adding fix to make using NESTML a bit more robust + another NESTML ex…
Browse files Browse the repository at this point in the history
…ample
  • Loading branch information
kaeldai committed Jun 25, 2024
1 parent 1b58f44 commit 4eb01f9
Show file tree
Hide file tree
Showing 17 changed files with 271 additions and 1 deletion.
2 changes: 1 addition & 1 deletion bmtk/simulator/pointnet/modules/spikes_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from bmtk.simulator.pointnet.modules.sim_module import SimulatorMod
from bmtk.simulator.pointnet.io_tools import io
from bmtk.utils.reports.spike_trains import SpikeTrains
from bmtk.simulator.bionet.pyfunction_cache import py_modules
from bmtk.simulator.pointnet.pyfunction_cache import py_modules


class SpikesInputsMod(SimulatorMod):
Expand Down
6 changes: 6 additions & 0 deletions bmtk/simulator/pointnet/pointnetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ def __init__(self, **properties):
self._nodes_table = {}
self._gid2nestid = {}

self._nest_modules = []

self._gid_map = GidPool()
self._virtual_gids = GidPool()

Expand Down Expand Up @@ -131,6 +133,10 @@ def add_weight_function(self, fnc, name=None, **kwargs):
fnc_name = name if name is not None else function.__name__
self.__weight_functions[fnc_name] = functools.partial(fnc)

def add_nest_module(self, module_path):
if module_path not in self._nest_modules:
self._nest_modules.append(module_path)

def set_default_weight_function(self, fnc):
self.add_weight_function(fnc, 'default_weight_fnc', overwrite=True)

Expand Down
51 changes: 51 additions & 0 deletions bmtk/simulator/pointnet/pointsimulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import nest
from six import string_types
from six import moves
from pathlib import Path
import platform

from bmtk.simulator.core.simulator import Simulator
from bmtk.simulator.pointnet.config import Config
Expand Down Expand Up @@ -72,6 +74,7 @@ def __init__(self, graph, dt=0.001, overwrite=True, print_time=False, n_thread=1
# TODO: move this into it's own function and make sure it is called before network is built
nest.ResetKernel()
nest.SetKernelStatus({"resolution": self._dt, "overwrite_files": self._overwrite, "print_time": print_time, "local_num_threads": n_thread})
self._load_nest_modules()

@property
def tstart(self):
Expand Down Expand Up @@ -131,6 +134,54 @@ def _get_block_trial(self, duration):
data_res = -1
return n, res, data_res

def _add_library_path(self, lib_path):
if isinstance(lib_path, Path):
lib_path = lib_path.as_posix()
system = platform.system()
env_var = "LD_LIBRARY_PATH" if system in ['Linux', 'Windows', 'Java'] else 'DYLD_LIBRARY_PATH'
env_val = os.environ.get(env_var, '')

if lib_path not in env_val:
os.environ[env_var] = os.pathsep.join([lib_path, env_val])


def _load_nest_modules(self):
# If there is a "nest_modules" entry in the configuration "components" then go through and add them to
# the network _nest_modules list. They will be processed the same as calling network.add_nest_module().
components_path = self.net._components.get('nest_modules', [])
if isinstance(components_path, str):
self.net.add_nest_module(components_path)
elif isinstance(components_path, list):
for cpath in components_path:
self.net.add_nest_module(cpath)
else:
raise ValueError('Unable to load components/nest_modules value.')

# Go through all added nest modules and try to call nest.Install() for them.
for module in self.net._nest_modules:
if Path(module).is_dir():
# If module is a directory then add path to LD_LIBRARY_PATH then try to load all the .so/.a library binaries in directory
lib_dir = Path(module).resolve()
self._add_library_path(lib_dir)
so_search = Path(lib_dir) / '*.so'
a_search = Path(lib_dir) / '*.a'

for lib_file in glob.glob(so_search.as_posix()) + glob.glob(a_search.as_posix()): # lib_path.parent.resolve()
module_name = Path(lib_file).name
nest.Install(module_name)

elif Path(module).is_file():
# If user tries to pass in a path to a library binary
lib_path = Path(module)
lib_dir = lib_path.parent.resolve()
lib_filename = lib_path.name
self._add_library_path(lib_dir)
nest.Install(lib_filename)

else:
# If user just tries nest.Install('mymodule')
nest.Install(module)

'''
def set_spikes_recordings(self):
# TODO: Pass in output-dir and file name to save to
Expand Down
38 changes: 38 additions & 0 deletions examples/point_nestml_izh/config.circuit.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
{
"manifest": {
"$BASE_DIR": ".",
"$NETWORK_DIR": "$BASE_DIR/network",
"$MODELS_DIR": "$BASE_DIR/../point_components"
},

"components": {
"point_neuron_models_dir": "$MODELS_DIR/cell_models",
"synaptic_models_dir": "$MODELS_DIR/synaptic_models",
"nest_modules": "components/nestml/nestml_izh_module.so"
},

"networks": {
"nodes": [
{
"nodes_file": "$NETWORK_DIR/cortex_nodes.h5",
"node_types_file": "$NETWORK_DIR/cortex_node_types.csv"
},
{
"nodes_file": "$NETWORK_DIR/thalamus_nodes.h5",
"node_types_file": "$NETWORK_DIR/thalamus_node_types.csv"
}
],
"edges": [
{
"edges_file": "$NETWORK_DIR/cortex_cortex_edges.h5",
"edge_types_file": "$NETWORK_DIR/cortex_cortex_edge_types.csv",
"enabled": false
},
{
"edges_file": "$NETWORK_DIR/thalamus_cortex_edges.h5",
"edge_types_file": "$NETWORK_DIR/thalamus_cortex_edge_types.csv",
"enabled": true
}
]
}
}
41 changes: 41 additions & 0 deletions examples/point_nestml_izh/config.simulation.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
{
"manifest": {
"$BASE_DIR": "${configdir}",
"$NETWORK_DIR": "$BASE_DIR/network",
"$MODELS_DIR": "$BASE_DIR/../point_components",
"$OUTPUT_DIR": "$BASE_DIR/output",
"$INPUT_DIR": "$BASE_DIR/inputs"
},

"run": {
"tstop": 3000.0,
"dt": 0.001,
"block_run": false,
"block_size": 1000.0
},

"inputs": {
"thalamus_spikes": {
"input_type": "spikes",
"module": "sonata",
"input_file": "$INPUT_DIR/thalamus_spikes.h5",
"node_set": "thalamus"
}
},

"reports": {
},

"output": {
"log_file": "log.txt",
"spikes_file": "spikes.h5",
"spikes_file_csv": "spikes.csv",
"output_dir": "$OUTPUT_DIR",
"overwrite_output_dir": true,
"quiet_simulator": true
},

"target_simulator":"NEST",
"rebuild_nestml": true,
"network": "config.circuit.json"
}
38 changes: 38 additions & 0 deletions examples/point_nestml_izh/custom_izh_neuron.nestml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@

model custom_izh_neuron:

state:
v mV = -65 mV # Membrane potential in mV
u real = 0 # Membrane potential recovery variable

equations:
v' = (.04 * v * v / mV + 5 * v + (140 - u) * mV + (I_e * GOhm)) / ms
u' = a * (b * v - u * mV) / (mV * ms)

parameters:
a real = .02 # describes time scale of recovery variable
b real = .2 # sensitivity of recovery variable
c mV = -65 mV # after-spike reset value of v
d real = 8. # after-spike reset value of u

input:
spikes <- spike
I_e pA <- continuous

output:
spike

update:
integrate_odes()

onReceive(spikes):
# add synaptic current
v += spikes * mV * s

onCondition(v >= 30mV):
# threshold crossing
v = c
u += d
emit_spike()


54 changes: 54 additions & 0 deletions examples/point_nestml_izh/generate_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# import matplotlib.pyplot as plt
import nest
import numpy as np
import os
import re

from pynestml.codegeneration.nest_code_generator_utils import NESTCodeGeneratorUtils


nestml_izh_model = '''
model custom_izh_neuron:
state:
v mV = -65 mV # Membrane potential in mV
u real = 0 # Membrane potential recovery variable
equations:
v' = (.04 * v * v / mV + 5 * v + (140 - u) * mV + (I_e * GOhm)) / ms
u' = a * (b * v - u * mV) / (mV * ms)
parameters:
a real = .02 # describes time scale of recovery variable
b real = .2 # sensitivity of recovery variable
c mV = -65 mV # after-spike reset value of v
d real = 8. # after-spike reset value of u
input:
spikes <- spike
I_e pA <- continuous
output:
spike
update:
integrate_odes()
onReceive(spikes):
# add synaptic current
v += spikes * mV * s
onCondition(v >= 30mV):
# threshold crossing
v = c
u += d
emit_spike()
'''

# generate and build code
module_name, neuron_model_name_adapt_curr = NESTCodeGeneratorUtils.generate_code_for(
nestml_izh_model,
module_name='nestml_izh_module',
target_path='components/nestml'
)
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
edge_type_id target_query source_query dynamics_params syn_weight delay model_template
100 * ei=='e' ExcToInh.json 2.0 1.5 static_synapse
101 * ei=='i' InhToExc.json -1.5 1.5 static_synapse
Binary file not shown.
3 changes: 3 additions & 0 deletions examples/point_nestml_izh/network/cortex_node_types.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
node_type_id ei model_template dynamics_params pop_name model_type
100 e nestml:custom_izh_neuron_nestml iaf_psc_delta_exc.json LIF_exc point_neuron
101 i nestml:custom_izh_neuron_nestml iaf_psc_delta_inh.json LIF_inh point_neuron
Binary file added examples/point_nestml_izh/network/cortex_nodes.h5
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
edge_type_id target_query source_query dynamics_params syn_weight delay model_template
100 ei=='e' * ExcToExc.json 220.0 1.5 static_synapse
101 ei=='i' * ExcToExc.json 5.0 1.5 static_synapse
Binary file not shown.
2 changes: 2 additions & 0 deletions examples/point_nestml_izh/network/thalamus_node_types.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
node_type_id ei pop_name model_type
100 e input_network virtual
Binary file not shown.
31 changes: 31 additions & 0 deletions examples/point_nestml_izh/run_pointnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import os, sys

from bmtk.simulator import pointnet
import nest

# nest.ResetKernel()
# nest.Install('nestml_39f5da2aef404887ac26009e1449206e_module')


# neuron = nest.Create('ornstein_uhlenbeck_noise_neuron')
# print(list(neuron))

def run(config_file):
configure = pointnet.Config.from_json(config_file)
configure.build_env()

network = pointnet.PointNetwork.from_config(configure)
# network.add_nest_module('/local1/workspace/bmtk/examples/point_nestml_izh/components/nestml/')
# network.add_nest_module('nestml_izh_module')
# network.add_nest_module('/local1/workspace/bmtk/examples/point_nestml_izh/components/nestml/nestml_izh_module.so')


# nest.Install('nestml_39f5da2aef404887ac26009e1449206e_module')
sim = pointnet.PointSimulator.from_config(configure, network)
sim.run()


if __name__ == '__main__':
# Find the appropriate config.json file
run('config.simulation.json')
# run('config.simulation_perturbations.json')

0 comments on commit 4eb01f9

Please sign in to comment.