diff --git a/doc/running/running_nest_compartmental.rst b/doc/running/running_nest_compartmental.rst
index da3a4583d..7cbf5a2d6 100644
--- a/doc/running/running_nest_compartmental.rst
+++ b/doc/running/running_nest_compartmental.rst
@@ -19,7 +19,7 @@ Writing a compartmental NESTML model
 Defining the membrane potential variable
 ----------------------------------------
 
-One variable in the model represents the local membrane potential in a compartment. By default, it is called ``v_comp``. (This name is defined in the compartmental code generator options as the ``compartmental_variable_name`` option.). This variable needs to be defined as a state in any compartmental model to be referenced in the equations describing channels and synapses.
+One variable in the model represents the local membrane potential in a compartment. By default, it is called ``v_comp``. (This name is defined in the compartmental code generator options as the ``compartmental_variable_name`` option.) This variable needs to be defined as a state variable in any compartmental model so that it can be referenced in the equations describing channels and receptors.
 
 .. code-block:: nestml
 
@@ -113,12 +113,12 @@ The only difference here is that the equation that is marked with the ``@mechani
 For a complete example, please see `concmech.nestml `_ and its associated unit test, `test__concmech_model.py `_.
 
-Synapse description
+Receptor description
 --------------------
 
-Here synapse models are based on convolutions over a buffer of incoming spikes. This means that the equation for the
+Here receptor models are based on convolutions over a buffer of incoming spikes. This means that the equation for the
 current-contribution must contain a convolve() call and a description of the kernel used for that convolution is needed.
-The descriptor for synapses is ``@mechanism::receptor``.
+The descriptor for receptors is ``@mechanism::receptor``.
 
 .. code-block:: nestml
 
@@ -160,8 +160,29 @@ For a complete example, please see `continuous_test.nestml `_
+
+Synapses
+--------
+We have also changed the way synapses interact with the neuron. The background is that NESTML STDP synapse models are always co-generated with the postsynaptic neuron model in order to communicate certain variables, such as the postsynaptic spike trace. We found this to be insufficient; instead, synapse models are now fully integrated with the receptor mechanisms of the neuron. This enables the user to access any receptor variables and other mechanism values within the synapse model by simply declaring them as states, without requiring further assignments. Another result of this merge is that all ODE equations in the synapse model are integrated implicitly and continuously at each timestep, making explicit calls to ``integrate_odes()`` unnecessary.
+
+An example of such a model is implemented here:
+`third_factor_stdp_synapse.nestml `_
 
 Technical Notes
 ---------------
diff --git a/pynestml/cocos/co_co_cm_channel_model.py b/pynestml/cocos/co_co_cm_channel_model.py
index bc556d9b2..968c09e64 100644
--- a/pynestml/cocos/co_co_cm_channel_model.py
+++ b/pynestml/cocos/co_co_cm_channel_model.py
@@ -26,10 +26,10 @@ class CoCoCmChannelModel(CoCo):
 
     @classmethod
-    def check_co_co(cls, model: ASTModel):
+    def check_co_co(cls, model: ASTModel, global_info):
         """
         Checks if this compartmental condition applies to the handed over neuron.
         If yes, it checks the presence of expected functions and declarations.
         :param model: a single neuron instance.
""" - return ChannelProcessing.check_co_co(model) + return ChannelProcessing.check_co_co(model, global_info) diff --git a/pynestml/cocos/co_co_cm_concentration_model.py b/pynestml/cocos/co_co_cm_concentration_model.py index 88eeea042..fce1cf9dd 100644 --- a/pynestml/cocos/co_co_cm_concentration_model.py +++ b/pynestml/cocos/co_co_cm_concentration_model.py @@ -27,10 +27,10 @@ class CoCoCmConcentrationModel(CoCo): @classmethod - def check_co_co(cls, model: ASTModel): + def check_co_co(cls, model: ASTModel, global_info): """ Check if this compartmental condition applies to the handed over neuron. If yes, it checks the presence of expected functions and declarations. :param model: a single neuron instance. """ - return ConcentrationProcessing.check_co_co(model) + return ConcentrationProcessing.check_co_co(model, global_info) diff --git a/pynestml/cocos/co_co_cm_continuous_input_model.py b/pynestml/cocos/co_co_cm_continuous_input_model.py index c3e6eb2fa..38cd47d6e 100644 --- a/pynestml/cocos/co_co_cm_continuous_input_model.py +++ b/pynestml/cocos/co_co_cm_continuous_input_model.py @@ -26,11 +26,11 @@ class CoCoCmContinuousInputModel(CoCo): @classmethod - def check_co_co(cls, neuron: ASTModel): + def check_co_co(cls, neuron: ASTModel, global_info): """ Checks if this compartmental condition applies to the handed over neuron. If yes, it checks the presence of expected functions and declarations. :param neuron: a single neuron instance. :type neuron: ast_neuron """ - return ContinuousInputProcessing.check_co_co(neuron) + return ContinuousInputProcessing.check_co_co(neuron, global_info) diff --git a/pynestml/cocos/co_co_cm_global.py b/pynestml/cocos/co_co_cm_global.py new file mode 100644 index 000000000..a48e15364 --- /dev/null +++ b/pynestml/cocos/co_co_cm_global.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# +# co_co_cm_global.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +from pynestml.cocos.co_co import CoCo +from pynestml.meta_model.ast_model import ASTModel +from pynestml.utils.global_processing import GlobalProcessing + + +class CoCoCmGlobal(CoCo): + @classmethod + def check_co_co(cls, neuron: ASTModel): + """ + Checks if this compartmental condition applies to the handed over neuron. + If yes, it checks the presence of expected functions and declarations. + :param neuron: a single neuron instance. + :type neuron: ast_neuron + """ + return GlobalProcessing.check_co_co(neuron) diff --git a/pynestml/cocos/co_co_cm_mech_shared_code.py b/pynestml/cocos/co_co_cm_mech_shared_code.py new file mode 100644 index 000000000..f22b72186 --- /dev/null +++ b/pynestml/cocos/co_co_cm_mech_shared_code.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# +# co_co_cm_mech_shared_code.py +# +# This file is part of NEST. 
+#
+# Copyright (C) 2004 The NEST Initiative
+#
+# NEST is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 2 of the License, or
+# (at your option) any later version.
+#
+# NEST is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with NEST. If not, see .
+
+from pynestml.cocos.co_co import CoCo
+from pynestml.utils.logger import LoggingLevel, Logger
+from pynestml.utils.messages import Messages
+from pynestml.utils.channel_processing import ChannelProcessing
+from pynestml.utils.concentration_processing import ConcentrationProcessing
+from pynestml.utils.receptor_processing import ReceptorProcessing
+from pynestml.utils.continuous_input_processing import ContinuousInputProcessing
+from pynestml.meta_model.ast_model import ASTModel
+
+
+class CoCoCmMechSharedCode(CoCo):
+    @classmethod
+    def check_co_co(cls, model: ASTModel):
+        """
+        Checks that no state, parameter or internal variable is shared between different compartmental
+        mechanisms (channels, concentrations, receptors, continuous inputs); an error is logged for every
+        variable that is used by more than one mechanism.
+        :param model: a single neuron instance.
+        """
+        chan_info = ChannelProcessing.get_mechs_info(model)
+        conc_info = ConcentrationProcessing.get_mechs_info(model)
+        rec_info = ReceptorProcessing.get_mechs_info(model)
+        con_in_info = ContinuousInputProcessing.get_mechs_info(model)
+
+        used_vars = dict()
+        all_info = chan_info | conc_info | rec_info | con_in_info
+        for info_name, info in all_info.items():
+            all_vars = list(set(info['States'].keys()) | set(info["Parameters"].keys()) | set(
+                info["Internals"].keys()))  # + [e.get_name() for e in info["Dependencies"]["global"]]
+            for var in all_vars:
+                if var not in used_vars.keys():
+                    used_vars[var] = list()
+                used_vars[var].append(info_name)
+
+        for var_name, var in used_vars.items():
+            if len(var) > 1:
+                code, message = Messages.cm_shared_variables_not_allowed(var_name, var)
+                Logger.log_message(error_position=None,
+                                   code=code, message=message,
+                                   log_level=LoggingLevel.ERROR)
diff --git a/pynestml/cocos/co_co_cm_receptor_model.py b/pynestml/cocos/co_co_cm_receptor_model.py
new file mode 100644
index 000000000..d7a0afc56
--- /dev/null
+++ b/pynestml/cocos/co_co_cm_receptor_model.py
@@ -0,0 +1,36 @@
+# -*- coding: utf-8 -*-
+#
+# co_co_cm_receptor_model.py
+#
+# This file is part of NEST.
+#
+# Copyright (C) 2004 The NEST Initiative
+#
+# NEST is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 2 of the License, or
+# (at your option) any later version.
+#
+# NEST is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with NEST. If not, see .
+
+from pynestml.cocos.co_co import CoCo
+from pynestml.meta_model.ast_model import ASTModel
+from pynestml.utils.receptor_processing import ReceptorProcessing
+
+
+class CoCoCmReceptorModel(CoCo):
+
+    @classmethod
+    def check_co_co(cls, model: ASTModel, global_info):
+        """
+        Checks if this compartmental condition applies to the handed over neuron.
+        If yes, it checks the presence of expected functions and declarations.
+        :param model: a single neuron instance.
+ """ + return ReceptorProcessing.check_co_co(model, global_info) diff --git a/pynestml/cocos/co_co_cm_synapse_model.py b/pynestml/cocos/co_co_cm_synapse_model.py index 5359e15cf..d5fc72f5f 100644 --- a/pynestml/cocos/co_co_cm_synapse_model.py +++ b/pynestml/cocos/co_co_cm_synapse_model.py @@ -31,6 +31,6 @@ def check_co_co(cls, model: ASTModel): """ Checks if this compartmental condition applies to the handed over neuron. If yes, it checks the presence of expected functions and declarations. - :param model: a single neuron instance. + :param model: a single synapse instance. """ return SynapseProcessing.check_co_co(model) diff --git a/pynestml/cocos/co_co_v_comp_exists.py b/pynestml/cocos/co_co_v_comp_exists.py index 51308f2cc..3d4bc975a 100644 --- a/pynestml/cocos/co_co_v_comp_exists.py +++ b/pynestml/cocos/co_co_v_comp_exists.py @@ -49,6 +49,9 @@ def check_co_co(cls, neuron: ASTModel): if not FrontendConfiguration.get_target_platform().upper() == 'NEST_COMPARTMENTAL': return + if not isinstance(neuron, ASTModel): + return + enforced_variable_name = NESTCompartmentalCodeGenerator._default_options["compartmental_variable_name"] state_blocks = neuron.get_state_blocks() diff --git a/pynestml/cocos/co_cos_manager.py b/pynestml/cocos/co_cos_manager.py index 6858151f0..d016acbeb 100644 --- a/pynestml/cocos/co_cos_manager.py +++ b/pynestml/cocos/co_cos_manager.py @@ -22,11 +22,11 @@ from typing import Union from pynestml.cocos.co_co_all_variables_defined import CoCoAllVariablesDefined +from pynestml.cocos.co_co_cm_global import CoCoCmGlobal +from pynestml.cocos.co_co_cm_mech_shared_code import CoCoCmMechSharedCode +from pynestml.cocos.co_co_cm_synapse_model import CoCoCmSynapseModel from pynestml.cocos.co_co_cm_channel_model import CoCoCmChannelModel -from pynestml.cocos.co_co_cm_concentration_model import CoCoCmConcentrationModel from pynestml.cocos.co_co_cm_continuous_input_model import CoCoCmContinuousInputModel -from pynestml.cocos.co_co_cm_synapse_model import CoCoCmSynapseModel -from pynestml.cocos.co_co_convolve_has_correct_parameter import CoCoConvolveHasCorrectParameter from pynestml.cocos.co_co_convolve_cond_correctly_built import CoCoConvolveCondCorrectlyBuilt from pynestml.cocos.co_co_correct_numerator_of_unit import CoCoCorrectNumeratorOfUnit from pynestml.cocos.co_co_correct_order_in_equation import CoCoCorrectOrderInEquation @@ -36,13 +36,11 @@ from pynestml.cocos.co_co_function_calls_consistent import CoCoFunctionCallsConsistent from pynestml.cocos.co_co_function_unique import CoCoFunctionUnique from pynestml.cocos.co_co_illegal_expression import CoCoIllegalExpression -from pynestml.cocos.co_co_input_port_not_assigned_to import CoCoInputPortNotAssignedTo from pynestml.cocos.co_co_integrate_odes_params_correct import CoCoIntegrateODEsParamsCorrect from pynestml.cocos.co_co_inline_expressions_have_rhs import CoCoInlineExpressionsHaveRhs from pynestml.cocos.co_co_inline_expression_not_assigned_to import CoCoInlineExpressionNotAssignedTo from pynestml.cocos.co_co_inline_max_one_lhs import CoCoInlineMaxOneLhs from pynestml.cocos.co_co_input_port_not_assigned_to import CoCoInputPortNotAssignedTo -from pynestml.cocos.co_co_input_port_qualifier_unique import CoCoInputPortQualifierUnique from pynestml.cocos.co_co_internals_assigned_only_in_internals_block import CoCoInternalsAssignedOnlyInInternalsBlock from pynestml.cocos.co_co_integrate_odes_called_if_equations_defined import CoCoIntegrateOdesCalledIfEquationsDefined from pynestml.cocos.co_co_invariant_is_boolean import 
CoCoInvariantIsBoolean
@@ -61,6 +59,10 @@
 from pynestml.cocos.co_co_resolution_func_used import CoCoResolutionOrStepsFuncUsed
 from pynestml.cocos.co_co_simple_delta_function import CoCoSimpleDeltaFunction
 from pynestml.cocos.co_co_state_variables_initialized import CoCoStateVariablesInitialized
+from pynestml.cocos.co_co_convolve_has_correct_parameter import CoCoConvolveHasCorrectParameter
+from pynestml.cocos.co_co_cm_receptor_model import CoCoCmReceptorModel
+from pynestml.cocos.co_co_cm_concentration_model import CoCoCmConcentrationModel
+from pynestml.cocos.co_co_input_port_qualifier_unique import CoCoInputPortQualifierUnique
 from pynestml.cocos.co_co_timestep_function_legally_used import CoCoTimestepFuncLegallyUsed
 from pynestml.cocos.co_co_user_defined_function_correctly_defined import CoCoUserDefinedFunctionCorrectlyDefined
 from pynestml.cocos.co_co_v_comp_exists import CoCoVCompDefined
@@ -70,6 +72,7 @@
 from pynestml.cocos.co_co_vector_parameter_declared_in_right_block import CoCoVectorParameterDeclaredInRightBlock
 from pynestml.cocos.co_co_vector_variable_in_non_vector_declaration import CoCoVectorVariableInNonVectorDeclaration
 from pynestml.frontend.frontend_configuration import FrontendConfiguration
+from pynestml.utils.global_processing import GlobalProcessing
 from pynestml.meta_model.ast_model import ASTModel
 from pynestml.utils.logger import Logger
@@ -142,17 +145,25 @@ def check_v_comp_requirement(cls, neuron: ASTModel):
         CoCoVCompDefined.check_co_co(neuron)
 
     @classmethod
-    def check_compartmental_model(cls, neuron: ASTModel) -> None:
+    def check_compartmental_neuron_model(cls, neuron: ASTModel) -> None:
         """
         collects all relevant information for the different compartmental mechanism classes for later code-generation
 
         searches for inlines or odes with decorator @mechanism:: and performs a base and, depending on type, specific
        information collection process. See nestml documentation on compartmental code generation.
         """
-        CoCoCmChannelModel.check_co_co(neuron)
-        CoCoCmConcentrationModel.check_co_co(neuron)
-        CoCoCmSynapseModel.check_co_co(neuron)
-        CoCoCmContinuousInputModel.check_co_co(neuron)
+        cls.check_v_comp_requirement(neuron)
+        CoCoCmGlobal.check_co_co(neuron)
+        global_info = GlobalProcessing.get_global_info(neuron)
+        CoCoCmChannelModel.check_co_co(neuron, global_info)
+        CoCoCmConcentrationModel.check_co_co(neuron, global_info)
+        CoCoCmReceptorModel.check_co_co(neuron, global_info)
+        CoCoCmContinuousInputModel.check_co_co(neuron, global_info)
+        CoCoCmMechSharedCode.check_co_co(neuron)
+
+    @classmethod
+    def check_compartmental_synapse_model(cls, synapse: ASTModel) -> None:
+        """
+        Runs the compartmental context conditions that apply to a synapse model.
+        :param synapse: a single synapse instance.
+        """
+        CoCoCmSynapseModel.check_co_co(synapse)
 
     @classmethod
     def check_inline_expressions_have_rhs(cls, model: ASTModel):
@@ -428,7 +439,7 @@ def check_co_co_nest_random_functions_legally_used(cls, model: ASTModel):
         CoCoNestRandomFunctionsLegallyUsed.check_co_co(model)
 
     @classmethod
-    def check_cocos(cls, model: ASTModel, after_ast_rewrite: bool = False):
+    def check_cocos(cls, model: ASTModel, after_ast_rewrite: bool = False, syn_model: bool = False):
         """
         Checks all context conditions.
         :param model: a single model object.
+        :param syn_model: if True, the model is a synapse model, for which the neuron-specific compartmental checks are skipped.
@@ -443,8 +454,8 @@ def check_cocos(cls, model: ASTModel, after_ast_rewrite: bool = False):
         cls.check_variables_defined_before_usage(model)
         if FrontendConfiguration.get_target_platform().upper() == 'NEST_COMPARTMENTAL':
             # XXX: TODO: refactor this out; define a ``cocos_from_target_name()`` in the frontend instead.
- cls.check_v_comp_requirement(model) - cls.check_compartmental_model(model) + if not syn_model: + cls.check_compartmental_neuron_model(model) cls.check_inline_expressions_have_rhs(model) cls.check_inline_has_max_one_lhs(model) cls.check_input_ports_not_assigned_to(model) diff --git a/pynestml/codegeneration/nest_compartmental_code_generator.py b/pynestml/codegeneration/nest_compartmental_code_generator.py index 82be5734f..e7d6b23f2 100644 --- a/pynestml/codegeneration/nest_compartmental_code_generator.py +++ b/pynestml/codegeneration/nest_compartmental_code_generator.py @@ -19,7 +19,7 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . -from typing import Any, Dict, List, Mapping, Optional +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union import datetime import os @@ -32,6 +32,8 @@ import pynestml from pynestml.cocos.co_cos_manager import CoCosManager from pynestml.codegeneration.code_generator import CodeGenerator +from pynestml.codegeneration.code_generator_utils import CodeGeneratorUtils +from pynestml.codegeneration.nest_code_generator import NESTCodeGenerator from pynestml.codegeneration.nest_assignments_helper import NestAssignmentsHelper from pynestml.codegeneration.nest_declarations_helper import NestDeclarationsHelper from pynestml.codegeneration.printers.constant_printer import ConstantPrinter @@ -57,8 +59,9 @@ from pynestml.meta_model.ast_variable import ASTVariable from pynestml.symbol_table.symbol_table import SymbolTable from pynestml.symbols.symbol import SymbolKind +from pynestml.utils.global_info_enricher import GlobalInfoEnricher +from pynestml.utils.global_processing import GlobalProcessing from pynestml.transformers.inline_expression_expansion_transformer import InlineExpressionExpansionTransformer -from pynestml.utils.ast_vector_parameter_setter_and_printer import ASTVectorParameterSetterAndPrinter from pynestml.utils.ast_vector_parameter_setter_and_printer_factory import ASTVectorParameterSetterAndPrinterFactory from pynestml.utils.mechanism_processing import MechanismProcessing from pynestml.utils.channel_processing import ChannelProcessing @@ -72,8 +75,11 @@ from pynestml.utils.logger import LoggingLevel from pynestml.utils.messages import Messages from pynestml.utils.model_parser import ModelParser -from pynestml.utils.syns_info_enricher import SynsInfoEnricher +from pynestml.utils.string_utils import removesuffix from pynestml.utils.synapse_processing import SynapseProcessing +from pynestml.utils.syns_info_enricher import SynsInfoEnricher +from pynestml.utils.recs_info_enricher import RecsInfoEnricher +from pynestml.utils.receptor_processing import ReceptorProcessing from pynestml.visitors.ast_random_number_generator_visitor import ASTRandomNumberGeneratorVisitor from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor @@ -85,6 +91,8 @@ class NESTCompartmentalCodeGenerator(CodeGenerator): Options: - **neuron_parent_class**: The C++ class from which the generated NESTML neuron class inherits. Examples: ``"ArchivingNode"``, ``"StructuralPlasticityNode"``. Default: ``"ArchivingNode"``. - **neuron_parent_class_include**: The C++ header filename to include that contains **neuron_parent_class**. Default: ``"archiving_node.h"``. + - **neuron_synapse_pairs**: List of pairs of (neuron, synapse) model names. + - **synapse_models**: List of synapse model names. Instructs the code generator that models with these names are synapse models. 
- **preserve_expressions**: Set to True, or a list of strings corresponding to individual variable names, to disable internal rewriting of expressions, and return same output as input expression where possible. Only applies to variables specified as first-order differential equations. (This parameter is passed to ODE-toolbox.) - **simplify_expression**: For all expressions ``expr`` that are rewritten by ODE-toolbox: the contents of this parameter string are ``eval()``ed in Python to obtain the final output expression. Override for custom expression simplification steps. Example: ``sympy.simplify(expr)``. Default: ``"sympy.logcombine(sympy.powsimp(sympy.expand(expr)))"``. (This parameter is passed to ODE-toolbox.) - **templates**: Path containing jinja templates used to generate code for NEST simulator. @@ -92,9 +100,14 @@ class NESTCompartmentalCodeGenerator(CodeGenerator): - **model_templates**: A list of the jinja templates or a relative path to a directory containing the templates related to the neuron model(s). - **module_templates**: A list of the jinja templates or a relative path to a directory containing the templates related to generating the NEST module. - **nest_version**: A string identifying the version of NEST Simulator to generate code for. The string corresponds to the NEST Simulator git repository tag or git branch name, for instance, ``"v2.20.2"`` or ``"master"``. The default is the empty string, which causes the NEST version to be automatically identified from the ``nest`` Python module. + - **delay_variable**: A mapping identifying, for each synapse (the name of which is given as a key), the variable or parameter in the model that corresponds with the NEST ``Connection`` class delay property. + - **weight_variable**: Like ``delay_variable``, but for synaptic weight. 
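+    - **neuron_models**: List of neuron model names. Instructs the code generator that models with these names are neuron models.
+
+    A minimal sketch of how the co-generation options introduced here might be passed to the frontend (the model names, file names and variable names below are hypothetical placeholders, not defaults):
+
+    .. code-block:: python
+
+       from pynestml.frontend.pynestml_frontend import generate_nest_compartmental_target
+
+       generate_nest_compartmental_target(
+           input_path=["cm_neuron.nestml", "stdp_synapse.nestml"],  # hypothetical model files
+           codegen_opts={
+               "neuron_synapse_pairs": [{"neuron": "cm_neuron", "synapse": "stdp_synapse"}],
+               "synapse_models": ["stdp_synapse"],
+               "delay_variable": {"stdp_synapse": "d"},    # maps each synapse name to its delay parameter
+               "weight_variable": {"stdp_synapse": "w"}})  # maps each synapse name to its weight variable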
""" _default_options = { + "neuron_synapse_pairs": [], + "neuron_models": [], + "synapse_models": [], "neuron_parent_class": "ArchivingNode", "neuron_parent_class_include": "archiving_node.h", "preserve_expressions": True, @@ -111,7 +124,10 @@ class NESTCompartmentalCodeGenerator(CodeGenerator): "cm_tree_@NEURON_NAME@.h.jinja2"]}, "module_templates": ["setup"]}, "nest_version": "", - "compartmental_variable_name": "v_comp"} + "compartmental_variable_name": "v_comp", + "delay_variable": {}, + "weight_variable": {} + } _variable_matching_template = r"(\b)({})(\b)" _model_templates = dict() @@ -120,6 +136,8 @@ class NESTCompartmentalCodeGenerator(CodeGenerator): def __init__(self, options: Optional[Mapping[str, Any]] = None): super().__init__(options) + self._nest_code_generator = NESTCodeGenerator(options) + # auto-detect NEST Simulator installed version if not self.option_exists("nest_version") or not self.get_option("nest_version"): from pynestml.codegeneration.nest_tools import NESTTools @@ -192,15 +210,20 @@ def raise_helper(self, msg): raise TemplateRuntimeError(msg) def set_options(self, options: Mapping[str, Any]) -> Mapping[str, Any]: + self._nest_code_generator.set_options(options) ret = super().set_options(options) self.setup_template_env() return ret def generate_code(self, models: List[ASTModel]) -> None: - self.analyse_transform_neurons(models) - self.generate_neurons(models) - self.generate_module_code(models) + neurons, synapses = CodeGeneratorUtils.get_model_types_from_names(models, synapse_models=self.get_option( + "synapse_models")) + synapses_per_neuron = self.arrange_synapses_per_neuron(neurons, synapses) + self.analyse_transform_neurons(neurons) + self.analyse_transform_synapses(synapses) + self.generate_compartmental_neurons(neurons, synapses_per_neuron) + self.generate_module_code(neurons) def generate_module_code(self, neurons: List[ASTModel]) -> None: """t @@ -278,6 +301,9 @@ def get_cm_syns_main_file_prefix(self, neuron): def get_cm_syns_tree_file_prefix(self, neuron): return "cm_tree_" + neuron.get_name() + def get_stdp_synapse_main_file_prefix(self, synapse): + return synapse.get_name() + def analyse_transform_neurons(self, neurons: List[ASTModel]) -> None: """ Analyse and transform a list of neurons. @@ -290,11 +316,92 @@ def analyse_transform_neurons(self, neurons: List[ASTModel]) -> None: spike_updates = self.analyse_neuron(neuron) neuron.spike_updates = spike_updates - def create_ode_indict(self, - neuron: ASTModel, - parameters_block: ASTBlockWithVariables, - kernel_buffers: Mapping[ASTKernel, - ASTInputPort]): + equations_block = neuron.get_equations_blocks()[0] + kernel_buffers = ASTUtils.generate_kernel_buffers(neuron, equations_block) + + analytic_solver, numeric_solver = self._nest_code_generator.ode_toolbox_analysis(neuron, kernel_buffers) + + delta_factors = ASTUtils.get_delta_factors_(neuron, equations_block) + + spike_updates, post_spike_updates = self._nest_code_generator.get_spike_update_expressions(neuron, + kernel_buffers, + [analytic_solver, + numeric_solver], + delta_factors) + + neuron.spike_updates = spike_updates + neuron.post_spike_updates = post_spike_updates + + def analyse_transform_synapses(self, synapses: List[ASTModel]) -> None: + """ + Analyse and transform a list of synapses. + :param synapses: a list of synapses. 
+ """ + for synapse in synapses: + Logger.log_message(None, None, "Analysing/transforming synapse {}.".format(synapse.get_name()), None, + LoggingLevel.INFO) + SynapseProcessing.process(synapse, self.get_option("neuron_synapse_pairs")) + self.analyse_synapse(synapse) + + def analyse_synapse(self, synapse: ASTModel): # -> Dict[str, ASTAssignment]: + """ + Analyse and transform a single synapse. + :param synapse: a single synapse. + """ + """ + equations_block = synapse.get_equations_blocks()[0] + ASTUtils.replace_convolve_calls_with_buffers_(synapse, equations_block) + ASTUtils.add_timestep_symbol(synapse) + self.update_symbol_table(synapse) + """ + + code, message = Messages.get_start_processing_model(synapse.get_name()) + Logger.log_message(synapse, code, message, synapse.get_source_position(), LoggingLevel.INFO) + + spike_updates = {} + if synapse.get_equations_blocks(): + if len(synapse.get_equations_blocks()) > 1: + raise Exception("Only one equations block per model supported for now") + + equations_block = synapse.get_equations_blocks()[0] + + kernel_buffers = ASTUtils.generate_kernel_buffers(synapse, equations_block) + + # substitute inline expressions with each other + # such that no inline expression references another inline expression; + # deference inline_expressions inside ode_equations + InlineExpressionExpansionTransformer().transform(synapse) + + delta_factors = ASTUtils.get_delta_factors_(synapse, equations_block) + ASTUtils.replace_convolve_calls_with_buffers_(synapse, equations_block) + + analytic_solver, numeric_solver = self.ode_toolbox_analysis(synapse, kernel_buffers) + self.analytic_solver[synapse.get_name()] = analytic_solver + self.numeric_solver[synapse.get_name()] = numeric_solver + + ASTUtils.remove_initial_values_for_kernels(synapse) + kernels = ASTUtils.remove_kernel_definitions_from_equations_block(synapse) + ASTUtils.update_initial_values_for_odes(synapse, [analytic_solver, numeric_solver]) + ASTUtils.remove_ode_definitions_from_equations_block(synapse) + ASTUtils.create_initial_values_for_kernels(synapse, [analytic_solver, numeric_solver], kernels) + ASTUtils.create_integrate_odes_combinations(synapse) + ASTUtils.replace_variable_names_in_expressions(synapse, [analytic_solver, numeric_solver]) + self.update_symbol_table(synapse, True) + + else: + self.update_symbol_table(synapse, True) + + synapse_name_stripped = removesuffix(removesuffix(synapse.name.split("_with_")[0], "_"), + FrontendConfiguration.suffix) + # special case for NEST delay variable (state or parameter) + + ASTUtils.update_blocktype_for_common_parameters(synapse) + # assert synapse_name_stripped in self.get_option("delay_variable").keys(), "Please specify a delay variable for synapse '" + synapse_name_stripped + "' in the code generator options" + # assert ASTUtils.get_variable_by_name(synapse, self.get_option("delay_variable")[synapse_name_stripped]), "Delay variable '" + self.get_option("delay_variable")[synapse_name_stripped] + "' not found in synapse '" + synapse_name_stripped + "'" + + return spike_updates + + def create_ode_indict(self, neuron: ASTModel, parameters_block: ASTBlockWithVariables, kernel_buffers: Mapping[ASTKernel, ASTInputPort]): odetoolbox_indict = self.transform_ode_and_kernels_to_json( neuron, parameters_block, kernel_buffers) odetoolbox_indict["options"] = {} @@ -303,11 +410,7 @@ def create_ode_indict(self, return odetoolbox_indict - def ode_solve_analytically(self, - neuron: ASTModel, - parameters_block: ASTBlockWithVariables, - kernel_buffers: Mapping[ASTKernel, - 
ASTInputPort]): + def ode_solve_analytically(self, neuron: ASTModel, parameters_block: ASTBlockWithVariables, kernel_buffers: Mapping[ASTKernel, ASTInputPort]): odetoolbox_indict = self.create_ode_indict( neuron, parameters_block, kernel_buffers) @@ -332,8 +435,8 @@ def ode_toolbox_analysis(self, neuron: ASTModel, """ Prepare data for ODE-toolbox input format, invoke ODE-toolbox analysis via its API, and return the output. """ - assert len(neuron.get_equations_blocks()) == 1, "Only one equations block supported for now" - assert len(neuron.get_parameters_blocks()) == 1, "Only one parameters block supported for now" + assert len(neuron.get_equations_blocks()) <= 1, "Only one equations block supported for now" + assert len(neuron.get_parameters_blocks()) <= 1, "Only one parameters block supported for now" equations_block = neuron.get_equations_blocks()[0] @@ -342,7 +445,9 @@ def ode_toolbox_analysis(self, neuron: ASTModel, # no equations defined -> no changes to the neuron return None, None - parameters_block = neuron.get_parameters_blocks()[0] + parameters_block = None + if len(neuron.get_parameters_blocks()): + parameters_block = neuron.get_parameters_blocks()[0] solver_result, analytic_solver = self.ode_solve_analytically( neuron, parameters_block, kernel_buffers) @@ -415,8 +520,8 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]: Logger.log_message(neuron, code, message, neuron.get_source_position(), LoggingLevel.INFO) - assert len(neuron.get_equations_blocks()) == 1, "Only one equations block supported for now" - assert len(neuron.get_state_blocks()) == 1, "Only one state block supported for now" + assert len(neuron.get_equations_blocks()) <= 1, "Only one equations block supported for now" + assert len(neuron.get_state_blocks()) <= 1, "Only one state block supported for now" equations_block = neuron.get_equations_blocks()[0] @@ -494,11 +599,6 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]: ASTUtils.update_initial_values_for_odes( neuron, [analytic_solver, numeric_solver]) - # remove differential equations from equations block - # those are now resolved into zero order variables and their - # corresponding updates - ASTUtils.remove_ode_definitions_from_equations_block(neuron) - # restore state variables that were referenced by kernels # and set their initial values by those suggested by ODE-toolbox ASTUtils.create_initial_values_for_kernels( @@ -515,13 +615,8 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]: # conventions of ODE-toolbox ASTUtils.replace_convolution_aliasing_inlines(neuron) - # add propagator variables calculated by odetoolbox into internal blocks - if self.analytic_solver[neuron.get_name()] is not None: - neuron = ASTUtils.add_declarations_to_internals( - neuron, self.analytic_solver[neuron.get_name()]["propagators"]) - # generate how to calculate the next spike update - self.update_symbol_table(neuron, kernel_buffers) + self.update_symbol_table(neuron) # find any spike update expressions defined by the user spike_updates = self.get_spike_update_expressions( neuron, kernel_buffers, [analytic_solver, numeric_solver], delta_factors) @@ -565,7 +660,7 @@ def getUniqueSuffix(self, neuron: ASTModel) -> str: underscore_pos = ret.find("_") return ret - def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: + def _get_neuron_model_namespace(self, neuron: ASTModel, paired_synapse: ASTModel = None) -> Dict: """ Returns a standard namespace for generating neuron code for NEST :param neuron: a single neuron 
instance @@ -597,7 +692,20 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: namespace["nest_printer"] = self._nest_printer namespace["nestml_printer"] = NESTMLPrinter() namespace["type_symbol_printer"] = self._type_symbol_printer - namespace["vector_printer_factory"] = ASTVectorParameterSetterAndPrinterFactory(neuron, self._printer_no_origin) + + class VectorPrinter(): + def __init__(self, neuron, printer): + self.printer = ASTVectorParameterSetterAndPrinterFactory(neuron, printer) + self.std_vector_parameter = None + + def print(self, expression, index="i"): + self.std_vector_parameter = index + index_printer = self.printer.create_ast_vector_parameter_setter_and_printer(index) + return index_printer.print(expression) + + vector_printer = VectorPrinter(neuron, self._printer_no_origin) + + namespace["vector_printer"] = vector_printer # NESTML syntax keywords namespace["PyNestMLLexer"] = {} @@ -622,6 +730,8 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: "neuron_parent_class_include") namespace["PredefinedUnits"] = pynestml.symbols.predefined_units.PredefinedUnits + namespace["PredefinedFunctions"] = pynestml.symbols.predefined_functions.PredefinedFunctions + namespace["UnitTypeSymbol"] = pynestml.symbols.unit_type_symbol.UnitTypeSymbol namespace["SymbolKind"] = pynestml.symbols.symbol.SymbolKind @@ -704,8 +814,8 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: namespace["chan_info"] = ChannelProcessing.get_mechs_info(neuron) namespace["chan_info"] = ChanInfoEnricher.enrich_with_additional_info(neuron, namespace["chan_info"]) - namespace["syns_info"] = SynapseProcessing.get_mechs_info(neuron) - namespace["syns_info"] = SynsInfoEnricher.enrich_with_additional_info(neuron, namespace["syns_info"]) + namespace["recs_info"] = ReceptorProcessing.get_mechs_info(neuron) + namespace["recs_info"] = RecsInfoEnricher.enrich_with_additional_info(neuron, namespace["recs_info"]) namespace["conc_info"] = ConcentrationProcessing.get_mechs_info(neuron) namespace["conc_info"] = ConcInfoEnricher.enrich_with_additional_info(neuron, namespace["conc_info"]) @@ -713,12 +823,32 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: namespace["con_in_info"] = ContinuousInputProcessing.get_mechs_info(neuron) namespace["con_in_info"] = ConInInfoEnricher.enrich_with_additional_info(neuron, namespace["con_in_info"]) + if paired_synapse: + namespace["syns_info"] = SynapseProcessing.get_syn_info(paired_synapse) + namespace["syns_info"] = SynsInfoEnricher.enrich_with_additional_info(paired_synapse, + namespace["syns_info"], + namespace["chan_info"], + namespace["recs_info"], + namespace["conc_info"], + namespace["con_in_info"]) + else: + namespace["syns_info"] = dict() + + namespace["global_info"] = GlobalProcessing.get_global_info(neuron) + namespace["global_info"] = GlobalInfoEnricher.enrich_with_additional_info(neuron, namespace["global_info"]) + chan_info_string = MechanismProcessing.print_dictionary(namespace["chan_info"], 0) - syns_info_string = MechanismProcessing.print_dictionary(namespace["syns_info"], 0) + recs_info_string = MechanismProcessing.print_dictionary(namespace["recs_info"], 0) conc_info_string = MechanismProcessing.print_dictionary(namespace["conc_info"], 0) con_in_info_string = MechanismProcessing.print_dictionary(namespace["con_in_info"], 0) + if paired_synapse: + syns_info_string = MechanismProcessing.print_dictionary(namespace["syns_info"], 0) + else: + syns_info_string = "" + global_info_string = 
MechanismProcessing.print_dictionary(namespace["global_info"], 0) + code, message = Messages.get_mechs_dictionary_info(chan_info_string, recs_info_string, conc_info_string, + con_in_info_string, syns_info_string, global_info_string) - code, message = Messages.get_mechs_dictionary_info(chan_info_string, syns_info_string, conc_info_string, con_in_info_string) Logger.log_message(None, code, message, None, LoggingLevel.DEBUG) neuron_specific_filenames = { @@ -734,16 +864,19 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: namespace["types_printer"] = self._type_symbol_printer + # python utils + namespace["set"] = set + return namespace - def update_symbol_table(self, neuron, kernel_buffers): + def update_symbol_table(self, neuron, syn_model=False): """ Update symbol table and scope. """ SymbolTable.delete_model_scope(neuron.get_name()) symbol_table_visitor = ASTSymbolTableVisitor() neuron.accept(symbol_table_visitor) - CoCosManager.check_cocos(neuron, after_ast_rewrite=True) + CoCosManager.check_cocos(neuron, after_ast_rewrite=True, syn_model=syn_model) SymbolTable.add_model_scope(neuron.get_name(), neuron.get_scope()) def _get_ast_variable(self, neuron, var_name) -> Optional[ASTVariable]: @@ -947,3 +1080,47 @@ def transform_ode_and_kernels_to_json( )] = self._ode_toolbox_printer.print(decl.get_expression()) return odetoolbox_indict + + def generate_compartmental_neuron_code(self, neuron: ASTModel, paired_synapse=None) -> None: + self.generate_model_code(neuron.get_name(), + model_templates=self._model_templates["neuron"], + template_namespace=self._get_neuron_model_namespace(neuron, paired_synapse), + model_name_escape_string="@NEURON_NAME@") + + def generate_compartmental_neurons(self, neurons: Sequence[ASTModel], paired_synapses: dict) -> None: + """ + Generate code for the given neurons. + + :param neurons: a list of neurons. 
+ """ + from pynestml.frontend.frontend_configuration import FrontendConfiguration + neuron_index = 0 + for neuron in neurons: + paired_syn_exists = False + for synapse in paired_synapses[neuron.get_name()]: + paired_syn_exists = True + self.generate_compartmental_neuron_code(neuron, synapse) + if not Logger.has_errors(neuron): + code, message = Messages.get_code_generated(neuron.get_name(), + FrontendConfiguration.get_target_path()) + Logger.log_message(neuron, code, message, neuron.get_source_position(), LoggingLevel.INFO) + if not paired_syn_exists: + self.generate_compartmental_neuron_code(neuron) + if not Logger.has_errors(neuron): + code, message = Messages.get_code_generated(neuron.get_name(), + FrontendConfiguration.get_target_path()) + Logger.log_message(neuron, code, message, neuron.get_source_position(), LoggingLevel.INFO) + neuron_index += 1 + + def arrange_synapses_per_neuron(self, neurons: Sequence[ASTModel], synapses: Sequence[ASTModel]): + paired_synapses = dict() + for neuron in neurons: + paired_synapses[neuron.get_name()] = list() + + neuron_synapse_pairs = self.get_option("neuron_synapse_pairs") + for pair in neuron_synapse_pairs: + for synapse in synapses: + if synapse.get_name() == (pair["synapse"] + "_nestml"): + paired_synapses[pair["neuron"] + "_nestml"].append(synapse) + + return paired_synapses diff --git a/pynestml/codegeneration/printers/cpp_function_call_printer.py b/pynestml/codegeneration/printers/cpp_function_call_printer.py index 11beba1bd..188f64f05 100644 --- a/pynestml/codegeneration/printers/cpp_function_call_printer.py +++ b/pynestml/codegeneration/printers/cpp_function_call_printer.py @@ -83,6 +83,9 @@ def _print_function_call_format_string(self, function_call: ASTFunctionCall) -> """ function_name = function_call.get_name() + if function_name == PredefinedFunctions.HEAVISIDE: + return '({!s} > 0)' + if function_name == PredefinedFunctions.CLIP: # the arguments of this function must be swapped and are therefore [v_max, v_min, v] return 'std::min({2!s}, std::max({1!s}, {0!s}))' @@ -90,7 +93,7 @@ def _print_function_call_format_string(self, function_call: ASTFunctionCall) -> if function_name == PredefinedFunctions.MAX: return 'std::max({!s}, {!s})' - if function_name == PredefinedFunctions.MIN: + if function_name == PredefinedFunctions.MIN or function_name == "Min": return 'std::min({!s}, {!s})' if function_name == PredefinedFunctions.ABS: diff --git a/pynestml/codegeneration/printers/nest_variable_printer.py b/pynestml/codegeneration/printers/nest_variable_printer.py index 1516a984d..79d5e038e 100644 --- a/pynestml/codegeneration/printers/nest_variable_printer.py +++ b/pynestml/codegeneration/printers/nest_variable_printer.py @@ -20,7 +20,6 @@ # along with NEST. If not, see . 
from __future__ import annotations - from typing import Dict, Optional from pynestml.codegeneration.nest_code_generator_utils import NESTCodeGeneratorUtils @@ -33,7 +32,6 @@ from pynestml.symbols.predefined_variables import PredefinedVariables from pynestml.symbols.symbol import SymbolKind from pynestml.symbols.unit_type_symbol import UnitTypeSymbol -from pynestml.symbols.variable_symbol import BlockType from pynestml.utils.logger import Logger, LoggingLevel from pynestml.utils.messages import Messages diff --git a/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/FunctionDeclaration.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/FunctionDeclaration.jinja2 index 13c8c5ab1..cf9cf2d4d 100644 --- a/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/FunctionDeclaration.jinja2 +++ b/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/FunctionDeclaration.jinja2 @@ -1,4 +1,4 @@ -{%- macro FunctionDeclaration(ast_function, namespace_prefix) -%} +{%- macro FunctionDeclaration(ast_function, namespace_prefix, pass_by_reference = false) -%} {%- with function_symbol = ast_function.get_scope().resolve_to_symbol(ast_function.get_name(), SymbolKind.FUNCTION) -%} {%- if function_symbol is none -%} {{ raise('Cannot resolve the method ' + ast_function.get_name()) }} @@ -8,7 +8,7 @@ {%- for param in ast_function.get_parameters() %} {%- with typeSym = param.get_data_type().get_type_symbol() -%} {%- filter indent(1, True) -%} -{{ type_symbol_printer.print(typeSym) }} {{ param.get_name() }} +{{ type_symbol_printer.print(typeSym) }}{% if pass_by_reference %}&{% endif %} {{ param.get_name() }} {%- if not loop.last -%} , {%- endif -%} diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/@NEURON_NAME@.cpp.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/@NEURON_NAME@.cpp.jinja2 index 9fc7bf66f..3ac5c9e2d 100644 --- a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/@NEURON_NAME@.cpp.jinja2 +++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/@NEURON_NAME@.cpp.jinja2 @@ -205,6 +205,19 @@ nest::{{neuronSpecificFileNamesCmSyns["main"]}}::set_status( const DictionaryDat * recordables map */ init_recordables_pointers_(); + + /** + * Set new values for any recordable and potentially evolving states + * explicitly defined in the model or implicitly created during + * code generation. 
+   */
+  std::map< Name, double* > recordables = c_tree_.get_recordables();
+  for ( auto recordable : recordables )
+  {
+    if ( statusdict->known( recordable.first ) )
+    {
+      *recordable.second = getValue< double >( statusdict, recordable.first );
+      statusdict->remove( recordable.first );
+    }
+  }
 }
 
 void
 nest::{{neuronSpecificFileNamesCmSyns["main"]}}::add_compartment_( DictionaryDatum& dd )
@@ -323,6 +336,8 @@ nest::{{neuronSpecificFileNamesCmSyns["main"]}}::update( Time const& origin, con
         SpikeEvent se;
         kernel().event_delivery_manager.send( *this, se, lag );
+
+        c_tree_.neuron_currents.postsynaptic_synaptic_processing();
       }
 
     logger_.record_data( origin.get_steps() + lag );
@@ -365,4 +380,251 @@ nest::{{neuronSpecificFileNamesCmSyns["main"]}}::handle( DataLoggingRequest& e )
   logger_.handle( e );
 }
 
+{%- if paired_synapse is defined %}
+// -------------------------------------------------------------------------
+// Methods for neuron/synapse co-generation
+// -------------------------------------------------------------------------
+
+inline double
+{{neuronName}}::get_spiketime_ms() const
+{
+  return last_spike_;
+}
+
+void
+{{neuronName}}::register_stdp_connection( double t_first_read, double delay )
+{
+  // Mark all entries in the deque, which we will not read in the future, as read by
+  // this input, so that we safely increment the incoming number of
+  // connections afterwards without leaving spikes in the history.
+  // For details see bug #218. MH 08-04-22
+
+  for ( std::deque< histentry__{{neuronName}} >::iterator runner = history_.begin();
+        runner != history_.end() and ( t_first_read - runner->t_ > -1.0 * nest::kernel().connection_manager.get_stdp_eps() );
+        ++runner )
+  {
+    ( runner->access_counter_ )++;
+  }
+
+  n_incoming_++;
+
+  max_delay_ = std::max( delay, max_delay_ );
+}
+
+
+void
+{{neuronName}}::get_history__( double t1,
+  double t2,
+  std::deque< histentry__{{neuronName}} >::iterator* start,
+  std::deque< histentry__{{neuronName}} >::iterator* finish )
+{
+  *finish = history_.end();
+  if ( history_.empty() )
+  {
+    *start = *finish;
+    return;
+  }
+  std::deque< histentry__{{neuronName}} >::reverse_iterator runner = history_.rbegin();
+  const double t2_lim = t2 + nest::kernel().connection_manager.get_stdp_eps();
+  const double t1_lim = t1 + nest::kernel().connection_manager.get_stdp_eps();
+  while ( runner != history_.rend() and runner->t_ >= t2_lim )
+  {
+    ++runner;
+  }
+  *finish = runner.base();
+  while ( runner != history_.rend() and runner->t_ >= t1_lim )
+  {
+    runner->access_counter_++;
+    ++runner;
+  }
+  *start = runner.base();
+}
+
+void
+{{neuronName}}::set_spiketime( nest::Time const& t_sp, double offset )
+{
+  {{neuron_parent_class}}::set_spiketime( t_sp, offset );
+
+  unsigned int num_transferred_variables = 0;
+{%- for var in transferred_variables %}
+  ++num_transferred_variables; {# XXX: TODO: make this into a const member variable #}
+{%- endfor %}
+
+  const double t_sp_ms = t_sp.get_ms() - offset;
+
+  if ( n_incoming_ )
+  {
+    // prune all spikes from history which are no longer needed
+    // only remove a spike if:
+    // - its access counter indicates it has been read out by all connected
+    //   STDP synapses, and
+    // - there is another, later spike, that is strictly more than
+    //   (min_global_delay + max_delay_ + eps) away from the new spike (at t_sp_ms)
+    while ( history_.size() > 1 )
+    {
+      const double next_t_sp = history_[ 1 ].t_;
+      // Note that ``access_counter`` now has an extra multiplicative factor (on top of ``n_incoming_``) equal to the number of trace values that exist, so that spikes
are removed from the history only after they have been read out for the sake of computing each trace. + // see https://www.frontiersin.org/files/Articles/1382/fncom-04-00141-r1/image_m/fncom-04-00141-g003.jpg (Potjans et al. 2010) + + if ( history_.front().access_counter_ >= n_incoming_ * num_transferred_variables + and t_sp_ms - next_t_sp > max_delay_ + nest::Time::delay_steps_to_ms(nest::kernel().connection_manager.get_min_delay()) + nest::kernel().connection_manager.get_stdp_eps() ) + { + history_.pop_front(); + } + else + { + break; + } + } + + if (history_.size() > 0) + { + assert(history_.back().t_ == last_spike_); + + /** + * print extra on-emit statements transferred from synapse + **/ + +{%- filter indent(4, True) %} +{%- for stmt in extra_on_emit_spike_stmts_from_synapse %} +{%- include "directives_cpp/Statement.jinja2" %} +{%- endfor %} +{%- endfilter %} + + /** + * print updates due to convolutions + **/ + +{%- for _, spike_update in post_spike_updates.items() %} + {{ printer.print(utils.get_variable_by_name(astnode, spike_update.get_variable().get_complete_name())) }} += 1.; +{%- endfor %} + + last_spike_ = t_sp_ms; + history_.push_back( histentry__{{neuronName}}( last_spike_, 0) ); + } + else + { + last_spike_ = t_sp_ms; + } +} + + +void +{{neuronName}}::clear_history() +{ + last_spike_ = -1.0; + history_.clear(); +} + + +{# + generate getter functions for the transferred variables +#} + +{%- for var in transferred_variables %} +{%- with variable_symbol = transferred_variables_syms[var] %} + +{%- if not var == variable_symbol.get_symbol_name() %} +{{ raise('Error in resolving variable to symbol') }} +{%- endif %} + +double +{{neuronName}}::get_{{var}}( double t, const bool before_increment ) +{ +#ifdef DEBUG + std::cout << "{{neuronName}}::get_{{var}}: getting value at t = " << t << std::endl; +#endif + + // case when the neuron has not yet spiked + if ( history_.empty() ) + { +#ifdef DEBUG + std::cout << "{{neuronName}}::get_{{var}}: \thistory empty, returning initial value = " << {{var}}__iv << std::endl; +#endif + // return initial value + return {{var}}__iv; + } + + // search for the latest post spike in the history buffer that came strictly before `t` + int i = history_.size() - 1; + double eps = 0.; + if ( before_increment ) + { + eps = nest::kernel().connection_manager.get_stdp_eps(); + } + while ( i >= 0 ) + { + if ( t - history_[ i ].t_ >= eps ) + { +#ifdef DEBUG + std::cout<<"{{neuronName}}::get_{{var}}: \tspike occurred at history[i].t_ = " << history_[i].t_ << std::endl; +#endif + + /** + * update state variables transferred from synapse from `history[i].t_` to `t` + * + * variables that will be integrated: {{ purely_numeric_state_variables_moved + analytic_state_variables_moved }} + **/ + + if ( t - history_[ i ].t_ >= nest::kernel().connection_manager.get_stdp_eps() ) + { + const double old___h = V_.__h; + V_.__h = t - history_[i].t_; + assert(V_.__h > 0); + recompute_internal_variables(true); + + V_.__h = old___h; + recompute_internal_variables(true); + } + +#ifdef DEBUG + std::cout << "{{neuronName}}::get_{{var}}: \treturning " << {{ printer.print(utils.get_variable_by_name(astnode, var)) }} << std::endl; +#endif + return {{ printer.print(utils.get_variable_by_name(astnode, var)) }}; // type: {{declarations.print_variable_type(variable_symbol)}} + } + --i; + } + + // this case occurs when the trace was requested at a time precisely at that of the first spike in the history + if ( (!before_increment) and t == history_[ 0 ].t_) + { + +#ifdef DEBUG + std::cout << 
"{{neuronName}}::get_{{var}}: \ttrace requested at exact time of history entry 0, returning " << {{ printer.print(utils.get_variable_by_name(astnode, variable_symbol.get_symbol_name())) }} << std::endl; +#endif + return {{ printer.print(utils.get_variable_by_name(astnode, variable_symbol.get_symbol_name())) }}; + } + + // this case occurs when the trace was requested at a time before the first spike in the history + // return initial value propagated in time +#ifdef DEBUG + std::cout << "{{neuronName}}::get_{{var}}: \tfall-through, returning initial value = " << {{var}}__iv << std::endl; +#endif + + if (t == 0.) + { + return 0.; // initial value for convolution is always 0 + } + + /** + * update state variables transferred from synapse from initial condition to `t` + * + * variables that will be integrated: {{ purely_numeric_state_variables_moved + analytic_state_variables_moved }} + **/ + + const double old___h = V_.__h; + V_.__h = t; // from time 0 to the requested time + assert(V_.__h > 0); + recompute_internal_variables(true); + + V_.__h = old___h; + recompute_internal_variables(true); + + return {{ printer.print(utils.get_variable_by_name(astnode, var)) }}; +} +{%- endwith -%} +{%- endfor %} + +{%- endif %} + } // namespace diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/@NEURON_NAME@.h.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/@NEURON_NAME@.h.jinja2 index a3c9f3665..c09a0ba6d 100644 --- a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/@NEURON_NAME@.h.jinja2 +++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/@NEURON_NAME@.h.jinja2 @@ -20,8 +20,8 @@ * */ -#ifndef CM_DEFAULT_H -#define CM_DEFAULT_H +#ifndef CM_{{neuronSpecificFileNamesCmSyns["main"].upper()}} +#define CM_{{neuronSpecificFileNamesCmSyns["main"].upper()}} // Includes from nestkernel: #include "archiving_node.h" @@ -323,12 +323,11 @@ inline size_t inline size_t {{neuronSpecificFileNamesCmSyns["main"]}}::handles_test_event( CurrentEvent&, size_t receptor_type ) { - // if get_compartment returns nullptr, raise the error - if ( not c_tree_.get_compartment( long( receptor_type ), c_tree_.get_root(), 0 ) ) + if ( ( receptor_type < 0 ) or ( receptor_type >= static_cast< size_t >( syn_buffers_.size() ) ) ) { std::ostringstream msg; - msg << "Valid current receptor ports for " << get_name() << " are in "; - msg << "[" << 0 << ", " << c_tree_.get_size() << "["; + msg << "Valid spike receptor ports for " << get_name() << " are in "; + msg << "[" << 0 << ", " << syn_buffers_.size() << "["; throw UnknownPort( receptor_type, msg.str() ); } return receptor_type; diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/CompoundStatement.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/CompoundStatement.jinja2 new file mode 100644 index 000000000..8439f340c --- /dev/null +++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/CompoundStatement.jinja2 @@ -0,0 +1,18 @@ +{# + Handles the compound statement. 
+  @grammar: Compound_Stmt = IF_Stmt | FOR_Stmt | WHILE_Stmt;
+#}
+{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %}
+{%- if stmt.is_if_stmt() %}
+{%- with ast = stmt.get_if_stmt() %}
+{%- include "cm_directives_cpp/IfStatement.jinja2" %}
+{%- endwith %}
+{%- elif stmt.is_for_stmt() %}
+{%- with ast = stmt.get_for_stmt() %}
+{%- include "cm_directives_cpp/ForStatement.jinja2" %}
+{%- endwith %}
+{%- elif stmt.is_while_stmt() %}
+{%- with ast = stmt.get_while_stmt() %}
+{%- include "cm_directives_cpp/WhileStatement.jinja2" %}
+{%- endwith %}
+{%- endif %}
diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/Declaration.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/Declaration.jinja2
new file mode 100644
index 000000000..b72323c90
--- /dev/null
+++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/Declaration.jinja2
@@ -0,0 +1,21 @@
+{#
+  Generates C++ declaration
+  @param ast ASTDeclaration
+#}
+{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %}
+{%- for variable in declarations.get_variables(ast) %}
+{%- if ast.has_size_parameter() %}
+{{declarations.print_variable_type(variable)}} {{variable.get_symbol_name()}}(P_.{{declarations.print_size_parameter(ast)}});
+{%- if ast.has_expression() %}
+for (long i=0; i < get_{{declarations.print_size_parameter(ast)}}(); i++) {
+  {{variable.get_symbol_name()}}[i] = {{printer.print(ast.get_expression())}};
+}
+{%- endif %}
+{%- else %}
+{%- if ast.has_expression() %}
+{{variable.get_symbol_name()}}[{{ printer.std_vector_parameter }}] = {{printer.print(ast.get_expression())}};
+{%- else %}
+{{variable.get_symbol_name()}}[{{ printer.std_vector_parameter }}] = 0;
+{%- endif %}
+{%- endif %}
+{%- endfor -%}
diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/FunctionCall.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/FunctionCall.jinja2
new file mode 100644
index 000000000..fb2d13f0c
--- /dev/null
+++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/FunctionCall.jinja2
@@ -0,0 +1,11 @@
+{#
+  Generates C++ function call
+  @param ast ASTFunctionCall
+#}
+{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %}
+{%- if ast.get_name() == PredefinedFunctions.EMIT_SPIKE %}
+{%- include "cm_directives_cpp/PredefinedFunction_emit_spike.jinja2" %}
+{%- else %}
+{# call to a non-predefined function #}
+{{ printer.print(ast) }};
+{%- endif %}
diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/IfStatement.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/IfStatement.jinja2
new file mode 100644
index 000000000..4ec143389
--- /dev/null
+++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/IfStatement.jinja2
@@ -0,0 +1,33 @@
+{#
+  Generates C++ if..then..else statement
+  @param ast ASTIfStmt
+#}
+{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %}
+if ({{ printer.print(ast.get_if_clause().get_condition()) }})
+{
+{%- filter indent(2, True) %}
+{%- with ast = ast.get_if_clause().get_stmts_body() %}
+{%- include "cm_directives_cpp/StmtsBody.jinja2" %}
+{%- endwith %}
+{%- endfilter %}
+{%- for elif in ast.get_elif_clauses() %}
+}
+else if ({{ printer.print(elif.get_condition()) }})
+{
+{%- filter indent(2, True) %} +{%- with ast = elif.get_stmts_body() %} +{%- include "cm_directives_cpp/StmtsBody.jinja2" %} +{%- endwith %} +{%- endfilter %} +{%- endfor %} +{%- if ast.has_else_clause() %} +} +else +{ +{%- filter indent(2, True) %} +{%- with ast = ast.get_else_clause().get_stmts_body() %} +{%- include "cm_directives_cpp/StmtsBody.jinja2" %} +{%- endwith %} +{%- endfilter %} +{%- endif %} +} diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/PredefinedFunction_emit_spike.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/PredefinedFunction_emit_spike.jinja2 new file mode 100644 index 000000000..02da8a9f9 --- /dev/null +++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/PredefinedFunction_emit_spike.jinja2 @@ -0,0 +1,19 @@ +{# + Generates code for an emit_spike() function call + @param ast ASTFunctionCall +#} +{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %} + +/** + * generated code for emit_spike() function +**/ +{% if ast.get_args() | length == 0 %} +{#- no parameters -- emit_spike() called from within neuron #} +set_spiketime(nest::Time::step(origin.get_steps() + lag + 1)); +nest::SpikeEvent se; +nest::kernel().event_delivery_manager.send(*this, se, lag); +{%- else %} +delayed_spikes[{{ printer.std_vector_parameter }}].push({{ printer.print(ast.get_args()[0]) }}); +{%- endif %} + + diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/SmallStatement.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/SmallStatement.jinja2 new file mode 100644 index 000000000..cd9631768 --- /dev/null +++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/SmallStatement.jinja2 @@ -0,0 +1,22 @@ +{# + Generates a single small statement into equivalent C++ syntax. + @param stmt ASTSmallStmt +#} +{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %} +{%- if stmt.is_assignment() %} +{%- with ast = stmt.get_assignment() %} +{%- include "directives_cpp/Assignment.jinja2" %} +{%- endwith %} +{%- elif stmt.is_function_call() %} +{%- with ast = stmt.get_function_call() %} +{%- include "cm_directives_cpp/FunctionCall.jinja2" %} +{%- endwith %} +{%- elif stmt.is_declaration() %} +{%- with ast = stmt.get_declaration() %} +{%- include "cm_directives_cpp/Declaration.jinja2" %} +{%- endwith %} +{%- elif stmt.is_return_stmt() %} +{%- with ast = stmt.get_return_stmt() %} +{%- include "directives_cpp/ReturnStatement.jinja2" %} +{%- endwith %} +{%- endif %} diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/Statement.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/Statement.jinja2 new file mode 100644 index 000000000..b0cade8e0 --- /dev/null +++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/Statement.jinja2 @@ -0,0 +1,16 @@ +{# + Generates a single statement, either simple or compound, into equivalent C++ syntax.
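+ A small statement (assignment, function call, declaration or return) maps to a single C++ statement via SmallStatement.jinja2; a compound statement (if/for/while) recurses via CompoundStatement.jinja2.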
+ @param stmt ASTStmt (a small or a compound statement) +#} +{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %} +{%- if stmt.has_comment() %} +{{stmt.print_comment('//')}}{%- endif %} +{%- if stmt.is_small_stmt() %} +{%- with stmt = stmt.small_stmt %} +{%- include "cm_directives_cpp/SmallStatement.jinja2" %} +{%- endwith %} +{%- elif stmt.is_compound_stmt() %} +{%- with stmt = stmt.compound_stmt %} +{%- include "cm_directives_cpp/CompoundStatement.jinja2" %} +{%- endwith %} +{%- endif %} diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/StmtsBody.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/StmtsBody.jinja2 new file mode 100644 index 000000000..7bdea7e83 --- /dev/null +++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/StmtsBody.jinja2 @@ -0,0 +1,11 @@ +{# + Handles an ASTStmtsBody +#} +{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %} +{%- for statement in ast.get_stmts() %} +{%- with stmt = statement %} +{%- filter indent(2) %} +{%- include "cm_directives_cpp/Statement.jinja2" %} +{%- endfilter %} +{%- endwith %} +{%- endfor %} diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_neuroncurrents_@NEURON_NAME@.cpp.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_neuroncurrents_@NEURON_NAME@.cpp.jinja2 index 3cc3e4730..fcbee900c 100644 --- a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_neuroncurrents_@NEURON_NAME@.cpp.jinja2 +++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_neuroncurrents_@NEURON_NAME@.cpp.jinja2 @@ -38,10 +38,10 @@ along with NEST. If not, see . {{ pure_variable_name~"_"~ion_channel_name }} {%- endmacro -%} -{% macro render_time_resolution_variable(synapse_info) -%} +{% macro render_time_resolution_variable(receptor_info) -%} {# we assume here that there is only one such variable! #} {%- with %} -{%- for analytic_helper_name, analytic_helper_info in synapse_info["analytic_helpers"].items() -%} +{%- for analytic_helper_name, analytic_helper_info in receptor_info["analytic_helpers"].items() -%} {%- if analytic_helper_info["is_time_resolution"] -%} {{ analytic_helper_name }} {%- endif -%} @@ -82,7 +82,8 @@ along with NEST. If not, see .
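{#- render_channel_function renders one user-defined NESTML function as an inline member function of the generated channel class; the emitted signature looks roughly like "inline double nest::Na{{cm_unique_suffix}}::tau_m_fun(double v_comp)" (names here are purely illustrative) #}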
{% macro render_channel_function(function, ion_channel_name) -%} {%- with %} -inline {{ function_declaration.FunctionDeclaration(function, "nest::"~ion_channel_name~cm_unique_suffix~"::") }} + {%- set printer = printer_no_origin %} +inline {{ function_declaration.FunctionDeclaration(function, "nest::"~ion_channel_name~cm_unique_suffix~"::", true) }} { {%- filter indent(2,True) %} {%- with ast = function.get_stmts_body() %} @@ -95,11 +96,11 @@ inline {{ function_declaration.FunctionDeclaration(function, "nest::"~ion_channe {% macro render_vectorized_channel_function(function, ion_channel_name) -%} {%- with %} -{{ vectorized_function_declaration.FunctionDeclaration(function, "nest::"~ion_channel_name~cm_unique_suffix~"::") }} +{{ vectorized_function_declaration.FunctionDeclaration(function, "nest::"~ion_channel_name~cm_unique_suffix~"::", true) }} { {%- filter indent(2,True) %} {%- with ast = function.get_stmts_body() %} -{%- include "directives_cpp/VectorizedStmtsBody.jinja2" %} +{%- include "directives_cpp/VectorizedBlock.jinja2" %} {%- endwith %} {%- endfilter %} } @@ -127,6 +128,13 @@ inline {{ function_declaration.FunctionDeclaration(function, "nest::"~ion_channe {%- endwith -%} {%- endmacro -%} +{% macro render_variable_type(variable) -%} +{%- with -%} + {%- set symbol = variable.get_scope().resolve_to_symbol(variable.name, SymbolKind.VARIABLE) -%} + {{ types_printer.print(symbol.type_symbol) }} +{%- endwith -%} +{%- endmacro %} + {%- with %} {%- for ion_channel_name, channel_info in chan_info.items() %} @@ -157,22 +165,49 @@ void nest::{{ion_channel_name}}{{cm_unique_suffix}}::new_channel(std::size_t com // state variable {{pure_variable_name}} {%- set variable = variable_info["ASTVariable"] %} {%- set rhs_expression = variable_info["rhs_expression"] %} - {{variable.name}}.push_back({{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("neuron_"+ion_channel_name+"_channel_count").print(rhs_expression) -}}); + {{variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+ion_channel_name+"_channel_count-1") -}}); {%- endfor %} {% for variable_type, variable_info in channel_info["Parameters"].items() %} // channel parameter {{variable_type }} {%- set variable = variable_info["ASTVariable"] %} {%- set rhs_expression = variable_info["rhs_expression"] %} - {{variable.name}}.push_back({{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("neuron_"+ion_channel_name+"_channel_count").print(rhs_expression) -}}); + {{variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+ion_channel_name+"_channel_count-1") -}}); {%- endfor %} {% for variable_type, variable_info in channel_info["Internals"].items() %} // channel parameter {{variable_type }} {%- set variable = variable_info["ASTVariable"] %} {%- set rhs_expression = variable_info["rhs_expression"] %} - {{variable.name}}.push_back({{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("neuron_"+ion_channel_name+"_channel_count").print(rhs_expression) -}}); + {{variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+ion_channel_name+"_channel_count-1") -}}); {%- endfor %} + + {% for state in channel_info["Dependencies"]["global"] %} + {{ printer_no_origin.print(state) }}.push_back(0); + {% endfor %} + + // set propagators to ode toolbox returned value + {%- for convolution, convolution_info in channel_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in 
convolution_info["analytic_solution"]["propagators"].items()%} + {{state_variable_name}}.push_back(0); + {%- endfor %} + {%- endfor %} + + // initial values for kernel state variables, set to zero + {%- for convolution, convolution_info in channel_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} + {{state_variable_name}}.push_back(0); + {%- endfor %} + {%- endfor %} + + {%- for variable_info in channel_info["SecondaryInlineExpressions"] %} + {%- if variable_info.is_recordable %} + {%- set variable = variable_info.variable_name %} + {{ variable }}.push_back(0); + {%- endif %} + {%- endfor %} + + self_spikes.push_back(false); } } @@ -208,7 +243,7 @@ void nest::{{ion_channel_name}}{{cm_unique_suffix}}::new_channel(std::size_t com // state variable {{pure_variable_name }} {%- set variable = variable_info["ASTVariable"] %} {%- set rhs_expression = variable_info["rhs_expression"] %} - {{ variable.name}}.push_back({{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("neuron_"+ion_channel_name+"_channel_count").print(rhs_expression) -}}); + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+ion_channel_name+"_channel_count-1") -}}); {%- endfor %} {%- with %} @@ -235,7 +270,7 @@ void nest::{{ion_channel_name}}{{cm_unique_suffix}}::new_channel(std::size_t com // channel parameter {{variable_type }} {%- set variable = variable_info["ASTVariable"] %} {%- set rhs_expression = variable_info["rhs_expression"] %} - {{ variable.name}}.push_back({{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("neuron_"+ion_channel_name+"_channel_count").print(rhs_expression) -}}); + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+ion_channel_name+"_channel_count-1") -}}); {%- endfor %} {%- with %} @@ -252,8 +287,35 @@ void nest::{{ion_channel_name}}{{cm_unique_suffix}}::new_channel(std::size_t com // state variable {{pure_variable_name }} {%- set variable = variable_info["ASTVariable"] %} {%- set rhs_expression = variable_info["rhs_expression"] %} - {{ variable.name}}.push_back({{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("neuron_"+ion_channel_name+"_channel_count").print(rhs_expression) -}}); + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+ion_channel_name+"_channel_count-1") -}}); {%- endfor %} + + {% for state in channel_info["Dependencies"]["global"] %} + {{ printer_no_origin.print(state) }}.push_back(0); + {% endfor %} + + // set propagators to ode toolbox returned value + {%- for convolution, convolution_info in channel_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} + {{state_variable_name}}.push_back(0); + {%- endfor %} + {%- endfor %} + + // initial values for kernel state variables, set to zero + {%- for convolution, convolution_info in channel_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} + {{state_variable_name}}.push_back(0); + {%- endfor %} + {%- endfor %} + + {%- for variable_info in channel_info["SecondaryInlineExpressions"] %} + {%- if variable_info.is_recordable %} + {%- set variable = variable_info.variable_name %} + {{ variable }}.push_back(0); + {%- endif %} + {%- endfor %} + + self_spikes.push_back(false); } } @@ -275,15 +337,54 @@ 
nest::{{ion_channel_name}}{{cm_unique_suffix}}::append_recordables(std::map< Nam } if(!found_rec) ( *recordables )[ Name( std::string("{{variable.name}}") + std::to_string(compartment_idx))] = &zero_recordable; {%- endfor %} + {%- for variable_info in channel_info["SecondaryInlineExpressions"] %} + {%- if variable_info.is_recordable %} + {%- set variable_name = variable_info.variable_name %} + found_rec = false; + for(size_t chan_id = 0; chan_id < neuron_{{ ion_channel_name }}_channel_count; chan_id++){ + if(compartment_association[chan_id] == compartment_idx){ + ( *recordables )[ Name( std::string("{{variable_name}}") + std::to_string(compartment_idx))] = &{{variable_name}}[chan_id]; + found_rec = true; + } + } + if(!found_rec) ( *recordables )[ Name( std::string("{{variable_name}}") + std::to_string(compartment_idx))] = &zero_recordable; + {%- endif %} + {%- endfor %} {% endwith %} found_rec = false; for(size_t chan_id = 0; chan_id < neuron_{{ ion_channel_name }}_channel_count; chan_id++){ if(compartment_association[chan_id] == compartment_idx){ - ( *recordables )[ Name( std::string("i_tot_{{ion_channel_name}}") + std::to_string(compartment_idx))] = &i_tot_{{ion_channel_name}}[chan_id]; + ( *recordables )[ Name( std::string("{{ion_channel_name}}") + std::to_string(compartment_idx))] = &i_tot_{{ion_channel_name}}[chan_id]; found_rec = true; } } - if(!found_rec) ( *recordables )[ Name( std::string("i_tot_{{ion_channel_name}}") + std::to_string(compartment_idx))] = &zero_recordable; + if(!found_rec) ( *recordables )[ Name( std::string("{{ion_channel_name}}") + std::to_string(compartment_idx))] = &zero_recordable; +} + + // initialization channel +{%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %} +void nest::{{ion_channel_name}}{{cm_unique_suffix}}::calibrate() { +{%- else %} +void nest::{{ion_channel_name}}{{cm_unique_suffix}}::pre_run_hook() { +{%- endif %} + {% if "time_resolution_var" in channel_info %} + std::vector< double > {{ printer_no_origin.print(channel_info["time_resolution_var"]) }}(neuron_{{ ion_channel_name }}_channel_count, Time::get_resolution().get_ms()); + {% endif %} + for(std::size_t i = 0; i < neuron_{{ ion_channel_name }}_channel_count; i++){ + // set propagators to ode toolbox returned value + {%- for convolution, convolution_info in channel_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} + {{state_variable_name}}[i] = {{ vector_printer.print(state_variable_info["init_expression"], "i") }}; + {%- endfor %} + {%- endfor %} + + // initial values for kernel state variables, set to zero + {%- for convolution, convolution_info in channel_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} + {{state_variable_name}}[i] = 0; + {%- endfor %} + {%- endfor %} + } } std::pair< std::vector< double >, std::vector< double > > nest::{{ion_channel_name}}{{cm_unique_suffix}}::f_numstep(std::vector< double > v_comp{% for ode in channel_info["Dependencies"]["concentrations"] %}, std::vector< double > {{ode.lhs.name}}{% endfor %}{% if channel_info["Dependencies"]["receptors"]|length %} @@ -296,7 +397,9 @@ std::pair< std::vector< double >, std::vector< double > > nest::{{ion_channel_na std::vector< double > d_i_tot_dv(neuron_{{ ion_channel_name }}_channel_count, 0.); - {% if channel_info["ODEs"].items()|length %} std::vector< double > {{
printer_no_origin.print(channel_info["time_resolution_var"]) }}(neuron_{{ ion_channel_name }}_channel_count, Time::get_resolution().get_ms()); {% endif %} + {% if "time_resolution_var" in channel_info %} + std::vector< double > {{ printer_no_origin.print(channel_info["time_resolution_var"]) }}(neuron_{{ ion_channel_name }}_channel_count, Time::get_resolution().get_ms()); + {% endif %} {%- for ode_variable, ode_info in channel_info["ODEs"].items() %} {%- for propagator, propagator_info in ode_info["transformed_solutions"][0]["propagators"].items() %} @@ -305,12 +408,19 @@ std::pair< std::vector< double >, std::vector< double > > nest::{{ion_channel_na {%- endfor %} #pragma omp simd for(std::size_t i = 0; i < neuron_{{ ion_channel_name }}_channel_count; i++){ + {%- for convolution, convolution_info in channel_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items() %} + {{state_variable_name}}[i] = {{ vector_printer.print(state_variable_info["update_expression"], "i") }}; + {{state_variable_name}}[i] += self_spikes[i] * {{ vector_printer.print(state_variable_info["init_expression"], "i") }}; + {%- endfor %} + {%- endfor %} + {%- for ode_variable, ode_info in channel_info["ODEs"].items() %} {%- for propagator, propagator_info in ode_info["transformed_solutions"][0]["propagators"].items() %} - {{ propagator }}[i] = {{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("i").print(propagator_info["init_expression"]) }}; + {{ propagator }}[i] = {{ vector_printer.print(propagator_info["init_expression"], "i") }}; {%- endfor %} {%- for state, state_solution_info in ode_info["transformed_solutions"][0]["states"].items() %} - {{state}}[i] = {{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("i").print(state_solution_info["update_expression"]) }}; + {{state}}[i] = {{ vector_printer.print(state_solution_info["update_expression"], "i") }}; {%- endfor %} {%- endfor %} @@ -318,17 +428,68 @@ std::pair< std::vector< double >, std::vector< double > > nest::{{ion_channel_na {%- set inline_expression_d = channel_info["inline_derivative"] %} // compute the conductance of the {{ion_channel_name}} channel - this->i_tot_{{ion_channel_name}}[i] = {{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("i").print(inline_expression.get_expression()) }}; + this->i_tot_{{ion_channel_name}}[i] = {{ vector_printer.print(inline_expression.get_expression(), "i") }}; // derivative - d_i_tot_dv[i] = {{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("i").print(inline_expression_d) }}; + d_i_tot_dv[i] = {{ vector_printer.print(inline_expression_d, "i") }}; g_val[i] = - d_i_tot_dv[i]; i_val[i] = this->i_tot_{{ion_channel_name}}[i] - d_i_tot_dv[i] * v_comp[i]; } + f_update(); + + //update recordable inlines + for(std::size_t i = 0; i < neuron_{{ ion_channel_name }}_channel_count; i++){ + {%- for variable_info in channel_info["SecondaryInlineExpressions"] %} + {%- if variable_info.is_recordable %} + {%- set variable = variable_info.variable_name %} + {%- set rhs_expression = variable_info.expression %} + {{ variable }}[i] = {{ vector_printer.print(rhs_expression, "i") }}; + {%- endif %} + {%- endfor %} + } + return std::make_pair(g_val, i_val); } +void nest::{{ion_channel_name}}{{cm_unique_suffix}}::f_update() +{ + double __resolution = Time::get_resolution().get_ms(); + for(std::size_t i = 0; i < neuron_{{ ion_channel_name }}_channel_count; i++){ + {%- if 
channel_info["Blocks"] %} + {%- if channel_info["Blocks"]["UpdateBlock"] %} + {%- set function = channel_info["Blocks"]["UpdateBlock"] %} + {%- filter indent(2,True) %} + {%- with ast = function.get_stmts_body() %} + {%- set printer = vector_printer %} + {%- include "cm_directives_cpp/StmtsBody.jinja2" %} + {%- endwith %} + {%- endfilter %} + {%- endif %} + {%- endif %} + self_spikes[i] = false; + } +} + +void nest::{{ion_channel_name}}{{cm_unique_suffix}}::f_self_spike() +{ + double __resolution = Time::get_resolution().get_ms(); + for(std::size_t i = 0; i < neuron_{{ ion_channel_name }}_channel_count; i++){ + self_spikes[i] = true; + {%- if channel_info["Blocks"] %} + {%- if channel_info["Blocks"]["SelfSpikesFunction"] %} + {%- set function = channel_info["Blocks"]["SelfSpikesFunction"] %} + {%- filter indent(2,True) %} + {%- with ast = function.get_stmts_body() %} + {%- set printer = vector_printer %} + {%- include "cm_directives_cpp/StmtsBody.jinja2" %} + {%- endwith %} + {%- endfilter %} + {%- endif %} + {%- endif %} + } +} + {%- for function in channel_info["Functions"] %} {{render_channel_function(function, ion_channel_name)}} {%- endfor %} @@ -381,26 +542,60 @@ void nest::{{concentration_name}}{{cm_unique_suffix}}::new_concentration(std::si {{concentration_name}}.push_back(0); compartment_association.push_back(comp_ass); + {%- with %} + {%- for variable_type, variable_info in concentration_info["ODEs"].items() %} + {%- set variable_name = variable_type %} + {%- set dynamic_variable = render_dynamic_channel_variable_name(variable_type, ion_channel_name) %} + {%- set rhs_expression = variable_info["transformed_solutions"][0]["states"][variable_name]["init_expression"] %} + {%- if variable_name == concentration_name %} + {{ concentration_name }}[neuron_{{ concentration_name }}_concentration_count-1] = {{ vector_printer.print(rhs_expression, "neuron_"+concentration_name+"_concentration_count-1") -}}; + {%- endif %} + {%- endfor %} + {% endwith %} + {%- for pure_variable_name, variable_info in concentration_info["States"].items() %} // state variable {{pure_variable_name }} {%- set variable = variable_info["ASTVariable"] %} {%- set rhs_expression = variable_info["rhs_expression"] %} - {{ variable.name}}.push_back({{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("neuron_"+concentration_name+"_concentration_count").print(rhs_expression) -}}); + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+concentration_name+"_concentration_count-1") -}}); {%- endfor %} {% for variable_type, variable_info in concentration_info["Parameters"].items() %} // channel parameter {{variable_type }} {%- set variable = variable_info["ASTVariable"] %} {%- set rhs_expression = variable_info["rhs_expression"] %} - {{ variable.name}}.push_back({{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("neuron_"+concentration_name+"_concentration_count").print(rhs_expression) -}}); + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+concentration_name+"_concentration_count-1") -}}); {%- endfor %} {% for variable_type, variable_info in concentration_info["Internals"].items() %} // channel parameter {{variable_type }} {%- set variable = variable_info["ASTVariable"] %} {%- set rhs_expression = variable_info["rhs_expression"] %} - {{ variable.name}}.push_back({{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("neuron_"+concentration_name+"_concentration_count").print(rhs_expression) -}}); + {{ 
variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+concentration_name+"_concentration_count-1") -}}); {%- endfor %} + + // set propagators to ode toolbox returned value + {%- for convolution, convolution_info in concentration_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} + {{state_variable_name}}.push_back(0); + {%- endfor %} + {%- endfor %} + + // initial values for kernel state variables, set to zero + {%- for convolution, convolution_info in concentration_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} + {{state_variable_name}}.push_back(0); + {%- endfor %} + {%- endfor %} + + {%- for variable_info in concentration_info["SecondaryInlineExpressions"] %} + {%- if variable_info.is_recordable %} + {%- set variable = variable_info.variable_name %} + {{ variable }}.push_back(0); + {%- endif %} + {%- endfor %} + + self_spikes.push_back(false); } } @@ -431,11 +626,23 @@ void nest::{{concentration_name}}{{cm_unique_suffix}}::new_concentration(std::si {{concentration_name}}.push_back(0); compartment_association.push_back(comp_ass); + {%- with %} + {%- for variable_type, variable_info in concentration_info["ODEs"].items() %} + {%- set variable_name = variable_type %} + {%- set dynamic_variable = render_dynamic_channel_variable_name(variable_type, ion_channel_name) %} + {%- set rhs_expression = variable_info["transformed_solutions"][0]["states"][variable_name]["init_expression"] %} + {%- if variable_name == concentration_name %} + {{ concentration_name }}[neuron_{{ concentration_name }}_concentration_count-1] = {{ vector_printer.print(rhs_expression, "neuron_"+concentration_name+"_concentration_count-1") -}}; + {%- endif %} + {%- endfor %} + {% endwith %} + + {%- for pure_variable_name, variable_info in concentration_info["States"].items() %} // state variable {{pure_variable_name }} {%- set variable = variable_info["ASTVariable"] %} {%- set rhs_expression = variable_info["rhs_expression"] %} - {{ variable.name}}.push_back({{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("neuron_"+concentration_name+"_concentration_count").print(rhs_expression) -}}); + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+concentration_name+"_concentration_count-1") -}}); {%- endfor %} {%- with %} @@ -462,7 +669,7 @@ void nest::{{concentration_name}}{{cm_unique_suffix}}::new_concentration(std::si // channel parameter {{variable_type }} {%- set variable = variable_info["ASTVariable"] %} {%- set rhs_expression = variable_info["rhs_expression"] %} - {{ variable.name}}.push_back({{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("neuron_"+concentration_name+"_concentration_count").print(rhs_expression) -}}); + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+concentration_name+"_concentration_count-1") -}}); {%- endfor %} {%- with %} @@ -479,8 +686,31 @@ void nest::{{concentration_name}}{{cm_unique_suffix}}::new_concentration(std::si // state variable {{pure_variable_name }} {%- set variable = variable_info["ASTVariable"] %} {%- set rhs_expression = variable_info["rhs_expression"] %} - {{ variable.name}}.push_back({{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("neuron_"+concentration_name+"_concentration_count").print(rhs_expression) -}}); + {{ 
variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+concentration_name+"_concentration_count-1") -}}); {%- endfor %} + + // set propagators to ode toolbox returned value + {%- for convolution, convolution_info in concentration_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} + {{state_variable_name}}.push_back(0); + {%- endfor %} + {%- endfor %} + + // initial values for kernel state variables, set to zero + {%- for convolution, convolution_info in concentration_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} + {{state_variable_name}}.push_back(0); + {%- endfor %} + {%- endfor %} + + {%- for variable_info in concentration_info["SecondaryInlineExpressions"] %} + {%- if variable_info.is_recordable %} + {%- set variable = variable_info.variable_name %} + {{ variable }}.push_back(0); + {%- endif %} + {%- endfor %} + + self_spikes.push_back(false); } } @@ -502,6 +732,20 @@ nest::{{ concentration_name }}{{cm_unique_suffix}}::append_recordables(std::map< } if(!found_rec) ( *recordables )[ Name( std::string("{{variable.name}}") + std::to_string(compartment_idx))] = &zero_recordable; {%- endfor %} + + {%- for variable_info in concentration_info["SecondaryInlineExpressions"] %} + {%- if variable_info.is_recordable %} + {%- set variable_name = variable_info.variable_name %} + found_rec = false; + for(size_t conc_id = 0; conc_id < neuron_{{ concentration_name }}_concentration_count; conc_id++){ + if(compartment_association[conc_id] == compartment_idx){ + ( *recordables )[ Name( std::string("{{variable_name}}") + std::to_string(compartment_idx))] = &{{variable_name}}[conc_id]; + found_rec = true; + } + } + if(!found_rec) ( *recordables )[ Name( std::string("{{variable_name}}") + std::to_string(compartment_idx))] = &zero_recordable; + {%- endif %} + {%- endfor %} {% endwith %} found_rec = false; for(size_t conc_id = 0; conc_id < neuron_{{ concentration_name }}_concentration_count; conc_id++){ @@ -513,12 +757,47 @@ nest::{{ concentration_name }}{{cm_unique_suffix}}::append_recordables(std::map< if(!found_rec) ( *recordables )[ Name( std::string("{{concentration_name}}") + std::to_string(compartment_idx))] = &zero_recordable; } + // initialization concentration +{%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %} +void nest::{{ concentration_name }}{{cm_unique_suffix}}::calibrate() { +{%- else %} +void nest::{{ concentration_name }}{{cm_unique_suffix}}::pre_run_hook() { +{%- endif %} + {% if "time_resolution_var" in concentration_info %} + std::vector< double > {{ printer_no_origin.print(concentration_info["time_resolution_var"]) }}(neuron_{{ concentration_name }}_concentration_count, Time::get_resolution().get_ms()); + {% endif %} + for(std::size_t concentration_id = 0; concentration_id < neuron_{{ concentration_name }}_concentration_count; concentration_id++){ + // states + {%- for pure_variable_name, variable_info in concentration_info["States"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name }}[concentration_id] = {{ vector_printer.print(rhs_expression, "concentration_id") }}; + {%- endfor %} + + // set propagators to ode toolbox returned value + {%- for convolution, convolution_info in
concentration_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} + {{state_variable_name}}[concentration_id] = {{ vector_printer.print(state_variable_info["init_expression"], "concentration_id") }}; + {%- endfor %} + {%- endfor %} + + // initial values for kernel state variables, set to zero + {%- for convolution, convolution_info in concentration_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} + {{state_variable_name}}[concentration_id] = 0; + {%- endfor %} + {%- endfor %} + } +} + void nest::{{ concentration_name }}{{cm_unique_suffix}}::f_numstep(std::vector< double > v_comp{% for ode in concentration_info["Dependencies"]["concentrations"] %}, std::vector< double > {{ode.lhs.name}}{% endfor %}{% if concentration_info["Dependencies"]["receptors"]|length %} {% endif %}{% for inline in concentration_info["Dependencies"]["receptors"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if concentration_info["Dependencies"]["channels"]|length %} {% endif %}{% for inline in concentration_info["Dependencies"]["channels"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if concentration_info["Dependencies"]["continuous"]|length %} {% endif %}{% for inline in concentration_info["Dependencies"]["continuous"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}) { + {% if "time_resolution_var" in concentration_info %} std::vector< double > {{ printer_no_origin.print(concentration_info["time_resolution_var"]) }}(neuron_{{ concentration_name }}_concentration_count, Time::get_resolution().get_ms()); + {% endif %} {%- for ode_variable, ode_info in concentration_info["ODEs"].items() %} {%- for propagator, propagator_info in ode_info["transformed_solutions"][0]["propagators"].items() %} @@ -528,15 +807,73 @@ void nest::{{ concentration_name }}{{cm_unique_suffix}}::f_numstep(std::vector< #pragma omp simd for(std::size_t i = 0; i < neuron_{{ concentration_name }}_concentration_count; i++){ + {%- for convolution, convolution_info in concentration_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items() %} + {{state_variable_name}}[i] = {{ vector_printer.print(state_variable_info["update_expression"], "i") }}; + {{state_variable_name}}[i] += self_spikes[i] * {{ vector_printer.print(state_variable_info["init_expression"], "i") }}; + {%- endfor %} + {%- endfor %} + {%- for ode_variable, ode_info in concentration_info["ODEs"].items() %} {%- for propagator, propagator_info in ode_info["transformed_solutions"][0]["propagators"].items() %} - {{ propagator }}[i] = {{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("i").print(propagator_info["init_expression"]) }}; + {{ propagator }}[i] = {{ vector_printer.print(propagator_info["init_expression"], "i") }}; {%- endfor %} {%- for state, state_solution_info in ode_info["transformed_solutions"][0]["states"].items() %} - {{state}}[i] = {{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("i").print(state_solution_info["update_expression"]) }}; + {{state}}[i] = {{ vector_printer.print(state_solution_info["update_expression"], "i") }}; {%- endfor %} {%- endfor %} } + + f_update(); + + //update recordable inlines + for(std::size_t i = 0; i < neuron_{{ concentration_name }}_concentration_count; i++){ + {%- 
for variable_info in concentration_info["SecondaryInlineExpressions"] %} + {%- if variable_info.is_recordable %} + {%- set variable = variable_info.variable_name %} + {%- set rhs_expression = variable_info.expression %} + {{ variable }}[i] = {{ vector_printer.print(rhs_expression, "i") }}; + {%- endif %} + {%- endfor %} + } +} + +void nest::{{concentration_name}}{{cm_unique_suffix}}::f_update() +{ + double __resolution = Time::get_resolution().get_ms(); + for(std::size_t i = 0; i < neuron_{{ concentration_name }}_concentration_count; i++){ + {%- if concentration_info["Blocks"] %} + {%- if concentration_info["Blocks"]["UpdateBlock"] %} + {%- set function = concentration_info["Blocks"]["UpdateBlock"] %} + {%- filter indent(2,True) %} + {%- with ast = function.get_stmts_body() %} + {%- set printer = vector_printer %} + {%- include "cm_directives_cpp/StmtsBody.jinja2" %} + {%- endwith %} + {%- endfilter %} + {%- endif %} + {%- endif %} + self_spikes[i] = false; + } +} + +void nest::{{concentration_name}}{{cm_unique_suffix}}::f_self_spike() +{ + double __resolution = Time::get_resolution().get_ms(); + for(std::size_t i = 0; i < neuron_{{ concentration_name }}_concentration_count; i++){ + self_spikes[i] = true; + {%- if concentration_info["Blocks"] %} + {%- if concentration_info["Blocks"]["SelfSpikesFunction"] %} + {%- set function = concentration_info["Blocks"]["SelfSpikesFunction"] %} + {%- filter indent(2,True) %} + {%- with ast = function.get_stmts_body() %} + {%- set printer = vector_printer %} + {%- include "cm_directives_cpp/StmtsBody.jinja2" %} + {%- endwith %} + {%- endfilter %} + {%- endif %} + {%- endif %} + } } {%- for function in concentration_info["Functions"] %} @@ -565,238 +902,326 @@ std::vector< double > nest::{{concentration_name}}{{cm_unique_suffix}}::distribu {% endwith %} -////////////////////////////////////// synapses +////////////////////////////////////// receptors -{%- for synapse_name, synapse_info in syns_info.items() %} -// {{synapse_name}} synapse //////////////////////////////////////////////////////////////// +{%- for receptor_name, receptor_info in recs_info.items() %} +// {{receptor_name}} receptor //////////////////////////////////////////////////////////////// -void nest::{{synapse_name}}{{cm_unique_suffix}}::new_synapse(std::size_t comp_ass, const long syn_index) +void nest::{{receptor_name}}{{cm_unique_suffix}}::new_receptor(std::size_t comp_ass, const long rec_index) { - neuron_{{ synapse_name }}_synapse_count++; - i_tot_{{synapse_name}}.push_back(0); + neuron_{{ receptor_name }}_receptor_count++; + i_tot_{{receptor_name}}.push_back(0); compartment_association.push_back(comp_ass); - syn_idx.push_back(syn_index); + rec_idx.push_back(rec_index); - {%- for pure_variable_name, variable_info in synapse_info["States"].items() %} + {%- for pure_variable_name, variable_info in receptor_info["States"].items() %} // state variable {{pure_variable_name }} {%- set variable = variable_info["ASTVariable"] %} {%- set rhs_expression = variable_info["rhs_expression"] %} - {{ variable.name}}.push_back({{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("neuron_"+synapse_name+"_synapse_count").print(rhs_expression) -}}); + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+receptor_name+"_receptor_count-1") -}}); {%- endfor %} - {% for variable_type, variable_info in synapse_info["Parameters"].items() %} - // synapse parameter {{variable_type }} + {% for variable_type, variable_info in receptor_info["Parameters"].items() %} + 
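{#- one entry is appended per receptor instance: every state and parameter is stored as a per-instance vector indexed alongside compartment_association #}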
// receptor parameter {{variable_type }} {%- set variable = variable_info["ASTVariable"] %} {%- set rhs_expression = variable_info["rhs_expression"] %} - {{ variable.name}}.push_back({{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("neuron_"+synapse_name+"_synapse_count").print(rhs_expression) -}}); + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+receptor_name+"_receptor_count-1") -}}); {%- endfor %} // set propagators to ode toolbox returned value - {%- for convolution, convolution_info in synapse_info["convolutions"].items() %} + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} {{state_variable_name}}.push_back(0); {%- endfor %} {%- endfor %} // initial values for kernel state variables, set to zero - {%- for convolution, convolution_info in synapse_info["convolutions"].items() %} + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} {{state_variable_name}}.push_back(0); {%- endfor %} {%- endfor %} // user declared internals in order they were declared - {%- for internal_name, internal_declaration in synapse_info["internals_used_declared"] %} + {%- for internal_name, internal_declaration in receptor_info["internals_used_declared"] %} {{internal_name}}.push_back(0); {%- endfor %} + + {%- for variable_info in receptor_info["SecondaryInlineExpressions"] %} + {%- if variable_info.is_recordable %} + {%- set variable = variable_info.variable_name %} + {{ variable }}.push_back(0); + {%- endif %} + {%- endfor %} + + self_spikes.push_back(false); } -void nest::{{synapse_name}}{{cm_unique_suffix}}::new_synapse(std::size_t comp_ass, const long syn_index, const DictionaryDatum& synapse_params) -/* update {{synapse}} synapse parameters and states */ +void nest::{{receptor_name}}{{cm_unique_suffix}}::new_receptor(std::size_t comp_ass, const long rec_index, const DictionaryDatum& receptor_params) +/* update {{receptor_name}} receptor parameters and states */ { - neuron_{{ synapse_name }}_synapse_count++; + neuron_{{ receptor_name }}_receptor_count++; compartment_association.push_back(comp_ass); - i_tot_{{synapse_name}}.push_back(0); - syn_idx.push_back(syn_index); + i_tot_{{receptor_name}}.push_back(0); + rec_idx.push_back(rec_index); - {%- for pure_variable_name, variable_info in synapse_info["States"].items() %} + {%- for pure_variable_name, variable_info in receptor_info["States"].items() %} // state variable {{pure_variable_name }} {%- set variable = variable_info["ASTVariable"] %} {%- set rhs_expression = variable_info["rhs_expression"] %} - {{ variable.name}}.push_back({{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("neuron_"+synapse_name+"_synapse_count").print(rhs_expression) -}}); + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+receptor_name+"_receptor_count-1") -}}); {%- endfor %} {%- with %} - {%- for variable_type, variable_info in synapse_info["States"].items() %} + {%- for variable_type, variable_info in receptor_info["States"].items() %} {%- set variable = variable_info["ASTVariable"] %} - if( synapse_params->known( "{{variable.name}}" ) ) - {{variable.name}}[neuron_{{ synapse_name }}_synapse_count-1] = getValue< double >( synapse_params, "{{variable.name}}" ); + if( receptor_params->known( 
"{{variable.name}}" ) ) + {{variable.name}}[neuron_{{ receptor_name }}_receptor_count-1] = getValue< double >( receptor_params, "{{variable.name}}" ); {%- endfor %} {% endwith %} {%- with %} - {%- for variable_type, variable_info in synapse_info["ODEs"].items() %} + {%- for variable_type, variable_info in receptor_info["ODEs"].items() %} {%- set variable_name = variable_type %} {%- set dynamic_variable = render_dynamic_channel_variable_name(variable_type, ion_channel_name) %} // {{concentration_name}} concentration ODE state {{dynamic_variable }} - if( synapse_params->known( "{{variable_name}}" ) ) - {{variable_name}}[neuron_{{ synapse_name }}_synapse_count-1] = getValue< double >( synapse_params, "{{variable_name}}" ); + if( receptor_params->known( "{{variable_name}}" ) ) + {{variable_name}}[neuron_{{ receptor_name }}_receptor_count-1] = getValue< double >( receptor_params, "{{variable_name}}" ); {%- endfor %} {% endwith %} - {% for variable_type, variable_info in synapse_info["Parameters"].items() %} - // synapse parameter {{variable_type }} + {% for variable_type, variable_info in receptor_info["Parameters"].items() %} + // receptor parameter {{variable_type }} {%- set variable = variable_info["ASTVariable"] %} {%- set rhs_expression = variable_info["rhs_expression"] %} - {{ variable.name}}.push_back({{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("neuron_"+synapse_name+"_synapse_count").print(rhs_expression) -}}); + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+receptor_name+"_receptor_count-1") -}}); {%- endfor %} {%- with %} - {%- for variable_type, variable_info in synapse_info["Parameters"].items() %} + {%- for variable_type, variable_info in receptor_info["Parameters"].items() %} {%- set variable = variable_info["ASTVariable"] %} - if( synapse_params->known( "{{variable.name}}" ) ) - {{variable.name}}[neuron_{{ synapse_name }}_synapse_count-1] = getValue< double >( synapse_params, "{{variable.name}}" ); + if( receptor_params->known( "{{variable.name}}" ) ) + {{variable.name}}[neuron_{{ receptor_name }}_receptor_count-1] = getValue< double >( receptor_params, "{{variable.name}}" ); {%- endfor %} {% endwith %} // set propagators to ode toolbox returned value - {%- for convolution, convolution_info in synapse_info["convolutions"].items() %} + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} {{state_variable_name}}.push_back(0); {%- endfor %} {%- endfor %} // initial values for kernel state variables, set to zero - {%- for convolution, convolution_info in synapse_info["convolutions"].items() %} + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} {{state_variable_name}}.push_back(0); {%- endfor %} {%- endfor %} // user declared internals in order they were declared - {%- for internal_name, internal_declaration in synapse_info["internals_used_declared"] %} + {%- for internal_name, internal_declaration in receptor_info["internals_used_declared"] %} {{internal_name}}.push_back(0); {%- endfor %} + + {%- for variable_info in receptor_info["SecondaryInlineExpressions"] %} + {%- if variable_info.is_recordable %} + {%- set variable = variable_info.variable_name %} + {{ variable }}.push_back(0); + {%- endif %} + {%- endfor %} + + 
self_spikes.push_back(false); } void -nest::{{synapse_name}}{{cm_unique_suffix}}::append_recordables(std::map< Name, double* >* recordables, const long compartment_idx) +nest::{{receptor_name}}{{cm_unique_suffix}}::append_recordables(std::map< Name, double* >* recordables, const long compartment_idx) { - {%- for convolution, convolution_info in synapse_info["convolutions"].items() %} - for(size_t syns_id = 0; syns_id < neuron_{{ synapse_name }}_synapse_count; syns_id++){ - if(compartment_association[syns_id] == compartment_idx){ - ( *recordables )[ Name( "{{convolution_info["kernel"]["name"]}}" + std::to_string(syns_id) )] = &{{convolution}}[syns_id]; + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} + for(size_t recs_id = 0; recs_id < neuron_{{ receptor_name }}_receptor_count; recs_id++){ + if(compartment_association[recs_id] == compartment_idx){ + ( *recordables )[ Name( "{{convolution_info["kernel"]["name"]}}" + std::to_string(rec_idx[recs_id]) )] = &{{convolution}}[recs_id]; + } + } + + {%- for variable_info in receptor_info["SecondaryInlineExpressions"] %} + {%- if variable_info.is_recordable %} + {%- set variable_name = variable_info.variable_name %} + found_rec = false; + for(size_t recs_id = 0; recs_id < neuron_{{ receptor_name }}_receptor_count; recs_id++){ + if(compartment_association[recs_id] == compartment_idx){ + ( *recordables )[ Name( std::string("{{variable_name}}") + std::to_string(compartment_idx))] = &{{variable_name}}[recs_id]; + found_rec = true; } } + if(!found_rec) ( *recordables )[ Name( std::string("{{variable_name}}") + std::to_string(compartment_idx))] = &zero_recordable; + {%- endif %} {%- endfor %} - for(size_t syns_id = 0; syns_id < neuron_{{ synapse_name }}_synapse_count; syns_id++){ - if(compartment_association[syns_id] == compartment_idx){ - ( *recordables )[ Name( "i_tot_{{synapse_name}}" + std::to_string(syns_id) )] = &i_tot_{{synapse_name}}[syns_id]; + + {%- endfor %} + for(size_t recs_id = 0; recs_id < neuron_{{ receptor_name }}_receptor_count; recs_id++){ + if(compartment_association[recs_id] == compartment_idx){ + ( *recordables )[ Name( "{{receptor_name}}" + std::to_string(rec_idx[recs_id]) )] = &i_tot_{{receptor_name}}[recs_id]; } } } {%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %} -void nest::{{synapse_name}}{{cm_unique_suffix}}::calibrate() +void nest::{{receptor_name}}{{cm_unique_suffix}}::calibrate() {%- else %} -void nest::{{synapse_name}}{{cm_unique_suffix}}::pre_run_hook() +void nest::{{receptor_name}}{{cm_unique_suffix}}::pre_run_hook() {%- endif %} { - std::vector< double > {{ printer_no_origin.print(synapse_info["analytic_helpers"]["__h"]["ASTVariable"]) }}(neuron_{{ synapse_name }}_synapse_count, Time::get_resolution().get_ms()); + std::vector< double > {{ printer_no_origin.print(receptor_info["analytic_helpers"]["__h"]["ASTVariable"]) }}(neuron_{{ receptor_name }}_receptor_count, Time::get_resolution().get_ms()); - {%- for state_name, state_declaration in synapse_info["States"].items() %} - std::vector< double > {{state_name}} = (neuron_{{ synapse_name }}_synapse_count, {{ printer_no_origin.print(state_declaration["rhs_expression"])}}); + {%- for state_name, state_declaration in receptor_info["States"].items() %} + std::vector< double > {{state_name}}(neuron_{{ receptor_name }}_receptor_count, {{ printer_no_origin.print(state_declaration["rhs_expression"])}}); {%- endfor %} - for(std::size_t i = 0; i < 
neuron_{{ synapse_name }}_synapse_count; i++){ + for(std::size_t i = 0; i < neuron_{{ receptor_name }}_receptor_count; i++){ // set propagators to ode toolbox returned value - {%- for convolution, convolution_info in synapse_info["convolutions"].items() %} + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} - {{state_variable_name}}[i] = {{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("i").print(state_variable_info["init_expression"]) }}; + {{state_variable_name}}[i] = {{ vector_printer.print(state_variable_info["init_expression"], "i") }}; {%- endfor %} {%- endfor %} // initial values for kernel state variables, set to zero - {%- for convolution, convolution_info in synapse_info["convolutions"].items() %} + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} {{state_variable_name}}[i] = 0; {%- endfor %} {%- endfor %} // user declared internals in order they were declared - {%- for internal_name, internal_declaration in synapse_info["internals_used_declared"] %} - {{internal_name}}[i] = {{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("i").print(internal_declaration.get_expression()) }}; + {%- for internal_name, internal_declaration in receptor_info["internals_used_declared"] %} + {{internal_name}}[i] = {{ vector_printer.print(internal_declaration.get_expression(), "i") }}; {%- endfor %} - {{synapse_info["buffer_name"]}}_[i]->clear(); + s_val = std::vector< double >(neuron_{{ receptor_name }}_receptor_count, 0.); + + {{receptor_info["buffer_name"]}}_[i]->clear(); } } -std::pair< std::vector< double >, std::vector< double > > nest::{{synapse_name}}{{cm_unique_suffix}}::f_numstep( std::vector< double > v_comp, const long lag {% for ode in synapse_info["Dependencies"]["concentrations"] %}, std::vector< double > {{ode.lhs.name}}{% endfor %}{% if synapse_info["Dependencies"]["receptors"]|length %} - {% endif %}{% for inline in synapse_info["Dependencies"]["receptors"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if synapse_info["Dependencies"]["channels"]|length %} - {% endif %}{% for inline in synapse_info["Dependencies"]["channels"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if synapse_info["Dependencies"]["continuous"]|length %} - {% endif %}{% for inline in synapse_info["Dependencies"]["continuous"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}) +std::pair< std::vector< double >, std::vector< double > > nest::{{receptor_name}}{{cm_unique_suffix}}::f_numstep(std::vector< double > v_comp, const long lag {% for ode in receptor_info["Dependencies"]["concentrations"] %}, std::vector< double > {{ode.lhs.name}}{% endfor %}{% if receptor_info["Dependencies"]["receptors"]|length %} + {% endif %}{% for inline in receptor_info["Dependencies"]["receptors"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if receptor_info["Dependencies"]["channels"]|length %} + {% endif %}{% for inline in receptor_info["Dependencies"]["channels"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if receptor_info["Dependencies"]["continuous"]|length %} + {% endif %}{% for inline in receptor_info["Dependencies"]["continuous"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}) { - std::vector< 
double > g_val(neuron_{{ synapse_name }}_synapse_count, 0.); - std::vector< double > i_val(neuron_{{ synapse_name }}_synapse_count, 0.); - std::vector< double > d_i_tot_dv(neuron_{{ synapse_name }}_synapse_count, 0.); + std::vector< double > g_val(neuron_{{ receptor_name }}_receptor_count, 0.); + std::vector< double > i_val(neuron_{{ receptor_name }}_receptor_count, 0.); + std::vector< double > d_i_tot_dv(neuron_{{ receptor_name }}_receptor_count, 0.); - {%- for ode_variable, ode_info in synapse_info["ODEs"].items() %} + {%- for ode_variable, ode_info in receptor_info["ODEs"].items() %} {%- for propagator, propagator_info in ode_info["transformed_solutions"][0]["propagators"].items() %} - std::vector< double > {{ propagator }}(neuron_{{ synapse_name }}_synapse_count, 0); + std::vector< double > {{ propagator }}(neuron_{{ receptor_name }}_receptor_count, 0); {%- endfor %} {%- endfor %} - {% if synapse_info["ODEs"].items()|length %} std::vector< double > {{ printer_no_origin.print(synapse_info["analytic_helpers"]["__h"]["ASTVariable"]) }}(neuron_{{ synapse_name }}_synapse_count, Time::get_resolution().get_ms()); {% endif %} + {% if receptor_info["ODEs"].items()|length %} std::vector< double > {{ printer_no_origin.print(receptor_info["analytic_helpers"]["__h"]["ASTVariable"]) }}(neuron_{{ receptor_name }}_receptor_count, Time::get_resolution().get_ms()); {% endif %} - std::vector < double > s_val(neuron_{{ synapse_name }}_synapse_count, 0); - - for(std::size_t i = 0; i < neuron_{{ synapse_name }}_synapse_count; i++){ + for(std::size_t i = 0; i < neuron_{{ receptor_name }}_receptor_count; i++){ // get spikes - s_val[i] = {{synapse_info["buffer_name"]}}_[i]->get_value( lag ); // * g_norm_; + s_val[i] = {{receptor_info["buffer_name"]}}_[i]->get_value( lag ); // * g_norm_; } //update ODE state variable #pragma omp simd - for(std::size_t i = 0; i < neuron_{{ synapse_name }}_synapse_count; i++){ - {%- for ode_variable, ode_info in synapse_info["ODEs"].items() %} + for(std::size_t i = 0; i < neuron_{{ receptor_name }}_receptor_count; i++){ + {%- for ode_variable, ode_info in receptor_info["ODEs"].items() %} {%- for propagator, propagator_info in ode_info["transformed_solutions"][0]["propagators"].items() %} - {{ propagator }}[i] = {{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("i").print(propagator_info["init_expression"]) }}; + {{ propagator }}[i] = {{ vector_printer.print(propagator_info["init_expression"], "i") }}; {%- endfor %} {%- for state, state_solution_info in ode_info["transformed_solutions"][0]["states"].items() %} - {{state}}[i] = {{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("i").print(state_solution_info["update_expression"]) }}; + {{state}}[i] = {{ vector_printer.print(state_solution_info["update_expression"], "i") }}; {%- endfor %} {%- endfor %} - // update kernel state variable / compute synaptic conductance - {%- for convolution, convolution_info in synapse_info["convolutions"].items() %} + // update kernel state variable / compute receptor conductance + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items() %} - {{state_variable_name}}[i] = {{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("i").print(state_variable_info["update_expression"]) }}; - {{state_variable_name}}[i] += s_val[i] * {{ 
vector_printer_factory.create_ast_vector_parameter_setter_and_printer("i").print(state_variable_info["init_expression"]) }}; + {{state_variable_name}}[i] = {{ vector_printer.print(state_variable_info["update_expression"], "i") }}; + {%- if convolution_info["spikes"]["name"] == "self_spikes" %} + {{state_variable_name}}[i] += self_spikes[i] * {{ vector_printer.print(state_variable_info["init_expression"], "i") }}; + {%- else %} + {{state_variable_name}}[i] += s_val[i] * {{ vector_printer.print(state_variable_info["init_expression"], "i") }}; + {%- endif %} {%- endfor %} {%- endfor %} // total current // this expression should be the transformed inline expression - this->i_tot_{{synapse_name}}[i] = {{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("i").print(synapse_info["root_expression"].get_expression()) }}; + this->i_tot_{{receptor_name}}[i] = {{ vector_printer.print(receptor_info["root_expression"].get_expression(), "i") }}; // derivative of that expression // voltage derivative of total current // compute derivative with respect to current with sympy - d_i_tot_dv[i] = {{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("i").print(synapse_info["inline_expression_d"]) }}; + d_i_tot_dv[i] = {{ vector_printer.print(receptor_info["inline_derivative"], "i") }}; // for numerical integration g_val[i] = - d_i_tot_dv[i]; - i_val[i] = this->i_tot_{{synapse_name}}[i] - d_i_tot_dv[i] * v_comp[i]; + i_val[i] = this->i_tot_{{receptor_name}}[i] - d_i_tot_dv[i] * v_comp[i]; } + f_update(); + + //update recordable inlines + for(std::size_t i = 0; i < neuron_{{ receptor_name }}_receptor_count; i++){ + {%- for variable_info in receptor_info["SecondaryInlineExpressions"] %} + {%- if variable_info.is_recordable %} + {%- set variable = variable_info.variable_name %} + {%- set rhs_expression = variable_info.expression %} + {{ variable }}[i] = {{ vector_printer.print(rhs_expression, "i") }}; + {%- endif %} + {%- endfor %} + } + return std::make_pair(g_val, i_val); } -{%- for function in synapse_info["functions_used"] %} -inline {{ function_declaration.FunctionDeclaration(function, "nest::"~synapse_name~cm_unique_suffix~"::") }} +void nest::{{receptor_name}}{{cm_unique_suffix}}::f_update() +{ + double __resolution = Time::get_resolution().get_ms(); + for(std::size_t i = 0; i < neuron_{{ receptor_name }}_receptor_count; i++){ + {%- if receptor_info["Blocks"] %} + {%- if receptor_info["Blocks"]["UpdateBlock"] %} + {%- set function = receptor_info["Blocks"]["UpdateBlock"] %} + {%- filter indent(2,True) %} + {%- with ast = function.get_stmts_body() %} + {%- set printer = vector_printer %} + {%- include "cm_directives_cpp/StmtsBody.jinja2" %} + {%- endwith %} + {%- endfilter %} + {%- endif %} + {%- endif %} + self_spikes[i] = false; + } +} + +void nest::{{receptor_name}}{{cm_unique_suffix}}::f_self_spike() +{ + double __resolution = Time::get_resolution().get_ms(); + for(std::size_t i = 0; i < neuron_{{ receptor_name }}_receptor_count; i++){ + self_spikes[i] = true; + {%- if receptor_info["Blocks"] %} + {%- if receptor_info["Blocks"]["SelfSpikesFunction"] %} + {%- set function = receptor_info["Blocks"]["SelfSpikesFunction"] %} + {%- filter indent(2,True) %} + {%- with ast = function.get_stmts_body() %} + {%- set printer = vector_printer %} + {%- include "cm_directives_cpp/StmtsBody.jinja2" %} + {%- endwith %} + {%- endfilter %} + {%- endif %} + {%- endif %} + } +} + +{%- for function in receptor_info["functions_used"] %} +inline {{ 
function_declaration.FunctionDeclaration(function, "nest::"~receptor_name~cm_unique_suffix~"::", true) }} { {%- filter indent(2,True) %} {%- with ast = function.get_stmts_body() %} @@ -806,118 +1231,738 @@ inline {{ function_declaration.FunctionDeclaration(function, "nest::"~synapse_na } {%- endfor %} -void nest::{{synapse_name}}{{cm_unique_suffix}}::get_currents_per_compartment(std::vector< double >& compartment_to_current){ +void nest::{{receptor_name}}{{cm_unique_suffix}}::get_currents_per_compartment(std::vector< double >& compartment_to_current){ for(std::size_t comp_id = 0; comp_id < compartment_to_current.size(); comp_id++){ compartment_to_current[comp_id] = 0; } - for(std::size_t syn_id = 0; syn_id < neuron_{{ synapse_name }}_synapse_count; syn_id++){ - compartment_to_current[this->compartment_association[syn_id]] += this->i_tot_{{synapse_name}}[syn_id]; + for(std::size_t rec_id = 0; rec_id < neuron_{{ receptor_name }}_receptor_count; rec_id++){ + compartment_to_current[this->compartment_association[rec_id]] += this->i_tot_{{receptor_name}}[rec_id]; } } -std::vector< double > nest::{{synapse_name}}{{cm_unique_suffix}}::distribute_shared_vector(std::vector< double > shared_vector){ - std::vector< double > distributed_vector(this->neuron_{{ synapse_name }}_synapse_count, 0.0); - for(std::size_t syn_id = 0; syn_id < this->neuron_{{ synapse_name }}_synapse_count; syn_id++){ - distributed_vector[syn_id] = shared_vector[compartment_association[syn_id]]; +std::vector< double > nest::{{receptor_name}}{{cm_unique_suffix}}::distribute_shared_vector(std::vector< double > shared_vector){ + std::vector< double > distributed_vector(this->neuron_{{ receptor_name }}_receptor_count, 0.0); + for(std::size_t rec_id = 0; rec_id < this->neuron_{{ receptor_name }}_receptor_count; rec_id++){ + distributed_vector[rec_id] = shared_vector[compartment_association[rec_id]]; } return distributed_vector; } -// {{synapse_name}} synapse end /////////////////////////////////////////////////////////// +// {{receptor_name}} receptor end /////////////////////////////////////////////////////////// {%- endfor %} +////////////////////////////////////// receptors with synapses attached +{%- for synapse_name, synapse_info in syns_info.items() %} +{%- for receptor_name, receptor_info in recs_info.items() %} +// {{receptor_name}} receptor //////////////////////////////////////////////////////////////// -////////////////////////////////////// continuous inputs - -{%- for continuous_name, continuous_info in con_in_info.items() %} -// {{continuous_name}} continuous input /////////////////////////////////////////////////// - -void nest::{{continuous_name}}{{cm_unique_suffix}}::new_continuous_input(std::size_t comp_ass, const long con_in_index) +void nest::{{receptor_name}}{{cm_unique_suffix}}_con_{{synapse_name}}::new_receptor(std::size_t comp_ass, const long syn_index) { - neuron_{{ continuous_name }}_continuous_input_count++; - i_tot_{{continuous_name}}.push_back(0); + neuron_{{ receptor_name }}_receptor_count++; + i_tot_{{receptor_name}}.push_back(0); compartment_association.push_back(comp_ass); - continuous_idx.push_back(con_in_index); + syn_idx.push_back(syn_index); + delayed_spikes.push_back(TimerDeque(Time::get_resolution().get_ms())); - {%- for pure_variable_name, variable_info in continuous_info["States"].items() %} + {%- for pure_variable_name, variable_info in receptor_info["States"].items() %} // state variable {{pure_variable_name }} {%- set variable = variable_info["ASTVariable"] %} {%- set rhs_expression = 
variable_info["rhs_expression"] %} - {{ variable.name}}.push_back({{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("neuron_"+continuous_name+"_continuous_input_count").print(rhs_expression) -}}); + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+receptor_name+"_receptor_count-1") -}}); {%- endfor %} - {% for variable_type, variable_info in continuous_info["Parameters"].items() %} - // parameter {{variable_type }} + {% for variable_type, variable_info in receptor_info["Parameters"].items() %} + // receptor parameter {{variable_type }} {%- set variable = variable_info["ASTVariable"] %} {%- set rhs_expression = variable_info["rhs_expression"] %} - {{ variable.name}}.push_back({{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("neuron_"+continuous_name+"_continuous_input_count").print(rhs_expression) -}}); + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+receptor_name+"_receptor_count-1") -}}); {%- endfor %} + // set propagators to ode toolbox returned value + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} + {{state_variable_name}}.push_back(0); + {%- endfor %} + {%- endfor %} + + // initial values for kernel state variables, set to zero + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} + {{state_variable_name}}.push_back(0); + {%- endfor %} + {%- endfor %} + // user declared internals in order they were declared - {%- for internal_name, internal_declaration in continuous_info["internals_used_declared"] %} + {%- for internal_name, internal_declaration in receptor_info["internals_used_declared"] %} {{internal_name}}.push_back(0); {%- endfor %} -} -void nest::{{continuous_name}}{{cm_unique_suffix}}::new_continuous_input(std::size_t comp_ass, const long con_in_index, const DictionaryDatum& con_in_params) -/* update {{continuous_name}} continuous input parameters and states */ -{ - neuron_{{ continuous_name }}_continuous_input_count++; - compartment_association.push_back(comp_ass); - i_tot_{{continuous_name}}.push_back(0); - continuous_idx.push_back(con_in_index); + {%- for variable_info in receptor_info["SecondaryInlineExpressions"] %} + {%- if variable_info.is_recordable %} + {%- set variable = variable_info.variable_name %} + {{ variable }}.push_back(0); + {%- endif %} + {%- endfor %} - {%- for pure_variable_name, variable_info in continuous_info["States"].items() %} + + //synapse components: + {%- for pure_variable_name, variable_info in synapse_info["States"].items() %} // state variable {{pure_variable_name }} {%- set variable = variable_info["ASTVariable"] %} {%- set rhs_expression = variable_info["rhs_expression"] %} - {{ variable.name}}.push_back({{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("neuron_"+continuous_name+"_continuous_input_count").print(rhs_expression) -}}); - {%- endfor %} - {%- for variable_type, variable_info in continuous_info["States"].items() %} - {%- set variable = variable_info["ASTVariable"] %} - if( con_in_params->known( "{{variable.name}}" ) ) - {{variable.name}}[neuron_{{ continuous_name }}_continuous_input_count-1] = getValue< double >( con_in_params, "{{variable.name}}" ); + {{ variable.name}}.push_back({{ 
vector_printer.print(rhs_expression, "neuron_"+receptor_name+"_receptor_count-1") -}}); {%- endfor %} - {%- with %} - {%- for variable_type, variable_info in continuous_info["ODEs"].items() %} - {%- set variable_name = variable_type %} - {%- set dynamic_variable = render_dynamic_channel_variable_name(variable_type, ion_channel_name) %} - // {{continuous_name}} concentration ODE state {{dynamic_variable }} - if( con_in_params->known( "{{variable_name}}" ) ) - {{variable_name}}[neuron_{{ continuous_name }}_continuous_input_count-1] = getValue< double >( con_in_params, "{{variable_name}}" ); - {%- endfor %} - {% endwith %} - - {% for variable_type, variable_info in continuous_info["Parameters"].items() %} - // continuous parameter {{variable_type }} + {% for variable_type, variable_info in synapse_info["Parameters"].items() %} + // synapse parameter {{variable_type }} {%- set variable = variable_info["ASTVariable"] %} {%- set rhs_expression = variable_info["rhs_expression"] %} - {{ variable.name}}.push_back({{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("neuron_"+continuous_name+"_continuous_input_count").print(rhs_expression) -}}); + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+receptor_name+"_receptor_count-1") -}}); {%- endfor %} - - {%- with %} - {%- for variable_type, variable_info in continuous_info["Parameters"].items() %} - {%- set variable = variable_info["ASTVariable"] %} - if( con_in_params->known( "{{variable.name}}" ) ) - {{variable.name}}[neuron_{{ continuous_name }}_continuous_input_count-1] = getValue< double >( con_in_params, "{{variable.name}}" ); {%- endfor %} - {% endwith %} // user declared internals in order they were declared - {%- for internal_name, internal_declaration in continuous_info["internals_used_declared"] %} + {%- for internal_name, internal_declaration in synapse_info["Internals"].items() %} {{internal_name}}.push_back(0); {%- endfor %} -} + + {%- for convolution, convolution_info in synapse_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} + {{state_variable_name}}.push_back(0); + {%- endfor %} + {%- endfor %} + + {%- for convolution, convolution_info in synapse_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} + {{state_variable_name}}.push_back(0); + {%- endfor %} + {%- endfor %} + + {%- for inline_name, inline in synapse_info["Inlines"].items() %} + {{inline_name}}.push_back(0); + {%- endfor %} + + {%- with %} + {%- for in_function_declaration in synapse_info["InFunctionDeclarationsVars"] %} + {%- for variable in declarations.get_variables(in_function_declaration) %} + {{variable.get_symbol_name()}}.push_back(0); + {%- endfor %} + {%- endfor %} + {%- endwith %} + + self_spikes.push_back(false); +} + +void nest::{{receptor_name}}{{cm_unique_suffix}}_con_{{synapse_name}}::new_receptor(std::size_t comp_ass, const long syn_index, const DictionaryDatum& receptor_params) +// update {{receptor_name}} receptor parameters +{ + neuron_{{ receptor_name }}_receptor_count++; + compartment_association.push_back(comp_ass); + i_tot_{{receptor_name}}.push_back(0); + syn_idx.push_back(syn_index); + + if( receptor_params->known( "{{synapse_info["DelayVariable"]}}" ) ) + delayed_spikes.push_back(TimerDeque(getValue< double >( receptor_params, "{{synapse_info["DelayVariable"]}}" ))); + + else
delayed_spikes.push_back(TimerDeque(Time::get_resolution().get_ms())); + + {%- for pure_variable_name, variable_info in receptor_info["States"].items() %} + // state variable {{pure_variable_name }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+receptor_name+"_receptor_count-1") -}}); + {%- endfor %} + {%- with %} + {%- for variable_type, variable_info in receptor_info["States"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + if( receptor_params->known( "{{variable.name}}" ) ) + {{variable.name}}[neuron_{{ receptor_name }}_receptor_count-1] = getValue< double >( receptor_params, "{{variable.name}}" ); + {%- endfor %} + {% endwith %} + + {%- with %} + {%- for variable_type, variable_info in receptor_info["ODEs"].items() %} + {%- set variable_name = variable_type %} + {%- set dynamic_variable = render_dynamic_channel_variable_name(variable_type, ion_channel_name) %} + // {{receptor_name}} receptor ODE state {{dynamic_variable }} + if( receptor_params->known( "{{variable_name}}" ) ) + {{variable_name}}[neuron_{{ receptor_name }}_receptor_count-1] = getValue< double >( receptor_params, "{{variable_name}}" ); + {%- endfor %} + {% endwith %} + + {% for variable_type, variable_info in receptor_info["Parameters"].items() %} + // receptor parameter {{variable_type }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+receptor_name+"_receptor_count-1") -}}); + {%- endfor %} + + {%- with %} + {%- for variable_type, variable_info in receptor_info["Parameters"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + if( receptor_params->known( "{{variable.name}}" ) ) + {{variable.name}}[neuron_{{ receptor_name }}_receptor_count-1] = getValue< double >( receptor_params, "{{variable.name}}" ); + {%- endfor %} + {% endwith %} + + // set propagators to ode toolbox returned value + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} + {{state_variable_name}}.push_back(0); + {%- endfor %} + {%- endfor %} + + // initial values for kernel state variables, set to zero + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} + {{state_variable_name}}.push_back(0); + {%- endfor %} + {%- endfor %} + + // user declared internals in order they were declared + {%- for internal_name, internal_declaration in receptor_info["internals_used_declared"] %} + {{internal_name}}.push_back(0); + {%- endfor %} + + {%- for variable_info in receptor_info["SecondaryInlineExpressions"] %} + {%- if variable_info.is_recordable %} + {%- set variable = variable_info.variable_name %} + {{ variable }}.push_back(0); + {%- endif %} + {%- endfor %} + + + //synapse components: + {%- for pure_variable_name, variable_info in synapse_info["States"].items() %} + // state variable {{pure_variable_name }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+receptor_name+"_receptor_count-1")
-}}); + {%- endfor %} + {%- for variable_type, variable_info in synapse_info["States"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + if( receptor_params->known( "{{variable.name}}" ) ) + {{variable.name}}[neuron_{{ receptor_name }}_receptor_count-1] = getValue< double >( receptor_params, "{{variable.name}}" ); + {%- endfor %} + + {% for variable_type, variable_info in synapse_info["Parameters"].items() %} + // synapse parameter {{variable_type }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+receptor_name+"_receptor_count-1") -}}); + {%- endfor %} + {%- for variable_type, variable_info in synapse_info["Parameters"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + if( receptor_params->known( "{{variable.name}}" ) ) + {{variable.name}}[neuron_{{ receptor_name }}_receptor_count-1] = getValue< double >( receptor_params, "{{variable.name}}" ); + {%- endfor %} + + // user declared internals in order they were declared + {%- for internal_name, internal_declaration in synapse_info["Internals"].items() %} + {{internal_name}}.push_back(0); + {%- endfor %} + + {%- for convolution, convolution_info in synapse_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} + {{state_variable_name}}.push_back(0); + {%- endfor %} + {%- endfor %} + + {%- for convolution, convolution_info in synapse_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} + {{state_variable_name}}.push_back(0); + {%- endfor %} + {%- endfor %} + + {%- for inline_name, inline in synapse_info["Inlines"].items() %} + {{inline_name}}.push_back(0); + {%- endfor %} + + {%- for in_function_declaration in synapse_info["InFunctionDeclarationsVars"] %} + {%- for variable in declarations.get_variables(in_function_declaration) %} + {{variable.get_symbol_name()}}.push_back(0); + {%- endfor %} + {%- endfor %} + + self_spikes.push_back(false); +} + +void +nest::{{receptor_name}}{{cm_unique_suffix}}_con_{{synapse_name}}::append_recordables(std::map< Name, double* >* recordables, const long compartment_idx) +{ + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} + for(size_t recs_id = 0; recs_id < neuron_{{ receptor_name }}_receptor_count; recs_id++){ + if(compartment_association[recs_id] == compartment_idx){ + ( *recordables )[ Name( "{{convolution_info["kernel"]["name"]}}" + std::to_string(syn_idx[recs_id]) )] = &{{convolution}}[recs_id]; + } + } + {%- endfor %} + + {%- for variable_info in receptor_info["SecondaryInlineExpressions"] %} + {%- if variable_info.is_recordable %} + {%- set variable_name = variable_info.variable_name %} + found_rec = false; + for(size_t recs_id = 0; recs_id < neuron_{{ receptor_name }}_receptor_count; recs_id++){ + if(compartment_association[recs_id] == compartment_idx){ + ( *recordables )[ Name( std::string("{{variable_name}}") + std::to_string(compartment_idx))] = &{{variable_name}}[recs_id]; + found_rec = true; + } + } + if(!found_rec) ( *recordables )[ Name( std::string("{{variable_name}}") + std::to_string(compartment_idx))] = &zero_recordable; + {%- endif %} + {%- endfor %} + + for(size_t recs_id = 0; recs_id < neuron_{{ receptor_name }}_receptor_count; recs_id++){ + if(compartment_association[recs_id] ==
compartment_idx){ + ( *recordables )[ Name( "{{receptor_name}}_{{synapse_name}}" + std::to_string(syn_idx[recs_id]) )] = &i_tot_{{receptor_name}}[recs_id]; + } + } + + //synapse states + {%- for pure_variable_name, variable_info in synapse_info["States"].items() %} + for(size_t recs_id = 0; recs_id < neuron_{{ receptor_name }}_receptor_count; recs_id++){ + if(compartment_association[recs_id] == compartment_idx){ + ( *recordables )[ Name( "{{pure_variable_name}}" + std::to_string(syn_idx[recs_id]) )] = &{{pure_variable_name}}[recs_id]; + } + } + {%- endfor %} + + {%- for inline_name, inline in synapse_info["Inlines"].items() %} + for(size_t recs_id = 0; recs_id < neuron_{{ receptor_name }}_receptor_count; recs_id++){ + if(compartment_association[recs_id] == compartment_idx){ + ( *recordables )[ Name( "{{inline_name}}" + std::to_string(syn_idx[recs_id]) )] = &{{inline_name}}[recs_id]; + } + } + {%- endfor %} +} + +{%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %} +void nest::{{receptor_name}}{{cm_unique_suffix}}_con_{{synapse_name}}::calibrate() +{%- else %} +void nest::{{receptor_name}}{{cm_unique_suffix}}_con_{{synapse_name}}::pre_run_hook() +{%- endif %} +{ + + std::vector< double > {{ printer_no_origin.print(receptor_info["analytic_helpers"]["__h"]["ASTVariable"]) }}(neuron_{{ receptor_name }}_receptor_count, Time::get_resolution().get_ms()); + + {%- for state_name, state_declaration in receptor_info["States"].items() %} + std::vector< double > {{state_name}}(neuron_{{ receptor_name }}_receptor_count, {{ printer_no_origin.print(state_declaration["rhs_expression"])}}); + {%- endfor %} + + for(std::size_t i = 0; i < neuron_{{ receptor_name }}_receptor_count; i++){ + // set propagators to ode toolbox returned value + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} + {{state_variable_name}}[i] = {{ vector_printer.print(state_variable_info["init_expression"], "i") }}; + {%- endfor %} + {%- endfor %} + + {%- for convolution, convolution_info in synapse_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} + {{state_variable_name}}[i] = {{ vector_printer.print(state_variable_info["init_expression"], "i") }}; + {%- endfor %} + {%- endfor %} + + // initial values for kernel state variables, set to zero + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} + {{state_variable_name}}[i] = 0; + {%- endfor %} + {%- endfor %} + + // user declared internals in order they were declared + {%- for internal_name, internal_declaration in receptor_info["internals_used_declared"] %} + {{internal_name}}[i] = {{ vector_printer.print(internal_declaration.get_expression(), "i") }}; + {%- endfor %} + + s_val = std::vector< double >(neuron_{{ receptor_name }}_receptor_count, 0.); + + {{receptor_info["buffer_name"]}}_[i]->clear(); + } +} + +void nest::{{receptor_name}}{{cm_unique_suffix}}_con_{{synapse_name}}::postsynaptic_synaptic_processing(){ + for(std::size_t i = 0; i < neuron_{{ receptor_name }}_receptor_count; i++) { + {%- set function = synapse_info["PostSpikeFunction"] %} + {%- filter indent(2,True) %} + {%- with ast =
function.get_stmts_body() %} + {%- set printer = vector_printer %} + {%- include "cm_directives_cpp/StmtsBody.jinja2" %} + {%- endwith %} + {%- endfilter %} + } +} + {%- with %} + {%- set conc_dep = set(receptor_info["Dependencies"]["concentrations"]).union(synapse_info["Dependencies"]["concentrations"])%} + {%- set rec_dep = set(receptor_info["Dependencies"]["receptors"]).union(synapse_info["Dependencies"]["receptors"])%} + {%- set chan_dep = set(receptor_info["Dependencies"]["channels"]).union(synapse_info["Dependencies"]["channels"])%} + {%- set con_in_dep = set(receptor_info["Dependencies"]["continuous"]).union(synapse_info["Dependencies"]["continuous"])%} +std::pair< std::vector< double >, std::vector< double > > nest::{{receptor_name}}{{cm_unique_suffix}}_con_{{synapse_name}}::f_numstep(std::vector< double > v_comp, const long lag {% for ode in conc_dep %}, std::vector< double > {{ode.lhs.name}}{% endfor %}{% if rec_dep|length %} + {% endif %}{% for inline in rec_dep %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if chan_dep|length %} + {% endif %}{% for inline in chan_dep %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if con_in_dep|length %} + {% endif %}{% for inline in con_in_dep %}, std::vector< double > {{inline.variable_name}}{% endfor %}) +{ + {%- endwith %} + + for(std::size_t i = 0; i < neuron_{{ receptor_name }}_receptor_count; i++){ + // get spikes + s_val[i] = {{receptor_info["buffer_name"]}}_[i]->get_value( lag ); // * g_norm_; + } + //synaptic processing: + //presynaptic spike processing + #pragma omp simd + for(std::size_t i = 0; i < neuron_{{ receptor_name }}_receptor_count; i++){ + if(s_val[i]!=0) { + {%- set function = synapse_info["PreSpikeFunction"] %} + {%- filter indent(2,True) %} + {%- with ast = function.get_stmts_body() %} + {%- set printer = vector_printer %} + {%- include "cm_directives_cpp/StmtsBody.jinja2" %} + {%- endwith %} + {%- endfilter %} + } + } + //presynaptic spike processing end + //continuous synaptic processing + for(std::size_t i = 0; i < neuron_{{ receptor_name }}_receptor_count; i++){ + //inlines and convolutions + {%- for convolution, convolution_info in synapse_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items() %} + {{state_variable_name}}[i] = {{ vector_printer.print(state_variable_info["update_expression"], "i") }}; + {{state_variable_name}}[i] += {%- if convolution_info["post_port"] %}self_spikes[i]{%- else %}s_val[i]{%- endif %} * {{ vector_printer.print(state_variable_info["init_expression"], "i") }}; + {%- endfor %} + {%- endfor %} + {%- for inline, inline_info in synapse_info["Inlines"].items() %} + {{ inline }}[i] = {{ vector_printer.print(inline_info["inline_expression"].get_expression(), "i") }}; + {%- endfor %} + //update block + {%- if synapse_info["UpdateBlock"] %} + {%- set function = synapse_info["UpdateBlock"] %} + {%- filter indent(2,True) %} + {%- with ast = function.get_stmts_body() %} + {%- set printer = vector_printer %} + {%- include "cm_directives_cpp/StmtsBody.jinja2" %} + {%- endwith %} + {%- endfilter %} + {%- endif %} + } + + {% if synapse_info["ODEs"].items()|length %} + std::vector< double > {{ printer_no_origin.print(synapse_info["time_resolution_var"]) }}(neuron_{{ receptor_name }}_receptor_count, Time::get_resolution().get_ms()); + {% endif %} + {%- for ode_variable, ode_info in synapse_info["ODEs"].items() %} + {%- for propagator, propagator_info in 
ode_info["transformed_solutions"][0]["propagators"].items() %} + std::vector< double > {{ propagator }}(neuron_{{ receptor_name }}_receptor_count, 0); + {%- endfor %} + {%- endfor %} + for(std::size_t i = 0; i < neuron_{{ receptor_name }}_receptor_count; i++){ + {%- for ode_variable, ode_info in synapse_info["ODEs"].items() %} + {%- for propagator, propagator_info in ode_info["transformed_solutions"][0]["propagators"].items() %} + {{ propagator }}[i] = {{ vector_printer.print(propagator_info["init_expression"], "i") }}; + {%- endfor %} + {%- for state, state_solution_info in ode_info["transformed_solutions"][0]["states"].items() %} + {{state}}[i] = {{ vector_printer.print(state_solution_info["update_expression"], "i") }}; + {%- endfor %} + {%- endfor %} + } + + for(std::size_t i = 0; i < neuron_{{ receptor_name }}_receptor_count; i++){ + // receive (potentially) delayed spikes + s_val[i] = delayed_spikes[i].tick(); + } + + std::vector< double > g_val(neuron_{{ receptor_name }}_receptor_count, 0.); + std::vector< double > i_val(neuron_{{ receptor_name }}_receptor_count, 0.); + std::vector< double > d_i_tot_dv(neuron_{{ receptor_name }}_receptor_count, 0.); + + {%- for ode_variable, ode_info in receptor_info["ODEs"].items() %} + {%- for propagator, propagator_info in ode_info["transformed_solutions"][0]["propagators"].items() %} + std::vector< double > {{ propagator }}(neuron_{{ receptor_name }}_receptor_count, 0); + {%- endfor %} + {%- endfor %} + + {% if receptor_info["ODEs"].items()|length %} std::vector< double > {{ printer_no_origin.print(receptor_info["analytic_helpers"]["__h"]["ASTVariable"]) }}(neuron_{{ receptor_name }}_receptor_count, Time::get_resolution().get_ms()); {% endif %} + + + + //update ODE state variable + #pragma omp simd + for(std::size_t i = 0; i < neuron_{{ receptor_name }}_receptor_count; i++){ + {%- for ode_variable, ode_info in receptor_info["ODEs"].items() %} + {%- for propagator, propagator_info in ode_info["transformed_solutions"][0]["propagators"].items() %} + {{ propagator }}[i] = {{ vector_printer.print(propagator_info["init_expression"], "i") }}; + {%- endfor %} + {%- for state, state_solution_info in ode_info["transformed_solutions"][0]["states"].items() %} + {{state}}[i] = {{ vector_printer.print(state_solution_info["update_expression"], "i") }}; + {%- endfor %} + {%- endfor %} + + // update kernel state variable / compute synaptic conductance + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items() %} + {{state_variable_name}}[i] = {{ vector_printer.print(state_variable_info["update_expression"], "i") }}; + {%- if convolution_info["spikes"]["name"] == "self_spikes" %} + {{state_variable_name}}[i] += self_spikes[i] * {{ vector_printer.print(state_variable_info["init_expression"], "i") }}; + {%- else %} + {{state_variable_name}}[i] += s_val[i] * {{ vector_printer.print(state_variable_info["init_expression"], "i") }}; + {%- endif %} + {%- endfor %} + {%- endfor %} + + // total current + // this expression should be the transformed inline expression + + this->i_tot_{{receptor_name}}[i] = {{ vector_printer.print(receptor_info["root_expression"].get_expression(), "i") }}; + + // derivative of that expression + // voltage derivative of total current + // compute derivative with respect to current with sympy + d_i_tot_dv[i] = {{ vector_printer.print(receptor_info["inline_derivative"], "i") }}; + + // for numerical 
integration + g_val[i] = - d_i_tot_dv[i]; + i_val[i] = this->i_tot_{{receptor_name}}[i] - d_i_tot_dv[i] * v_comp[i]; + } + + f_update(); + + //update recordable inlines + for(std::size_t i = 0; i < neuron_{{ receptor_name }}_receptor_count; i++){ + {%- for variable_info in receptor_info["SecondaryInlineExpressions"] %} + {%- if variable_info.is_recordable %} + {%- set variable = variable_info.variable_name %} + {%- set rhs_expression = variable_info.expression %} + {{ variable }}[i] = {{ vector_printer.print(rhs_expression, "i") }}; + {%- endif %} + {%- endfor %} + } + + return std::make_pair(g_val, i_val); + +} + +void nest::{{receptor_name}}{{cm_unique_suffix}}_con_{{synapse_name}}::f_update() +{ + double __resolution = Time::get_resolution().get_ms(); + for(std::size_t i = 0; i < neuron_{{ receptor_name }}_receptor_count; i++){ + {%- if receptor_info["Blocks"] %} + {%- if receptor_info["Blocks"]["UpdateBlock"] %} + {%- set function = receptor_info["Blocks"]["UpdateBlock"] %} + {%- filter indent(2,True) %} + {%- with ast = function.get_stmts_body() %} + {%- set printer = vector_printer %} + {%- include "cm_directives_cpp/StmtsBody.jinja2" %} + {%- endwith %} + {%- endfilter %} + {%- endif %} + {%- endif %} + self_spikes[i] = false; + } +} + +void nest::{{receptor_name}}{{cm_unique_suffix}}_con_{{synapse_name}}::f_self_spike() +{ + double __resolution = Time::get_resolution().get_ms(); + for(std::size_t i = 0; i < neuron_{{ receptor_name }}_receptor_count; i++){ + self_spikes[i] = true; + {%- if receptor_info["Blocks"] %} + {%- if receptor_info["Blocks"]["SelfSpikesFunction"] %} + {%- set function = receptor_info["Blocks"]["SelfSpikesFunction"] %} + {%- filter indent(2,True) %} + {%- with ast = function.get_stmts_body() %} + {%- set printer = vector_printer %} + {%- include "cm_directives_cpp/StmtsBody.jinja2" %} + {%- endwith %} + {%- endfilter %} + {%- endif %} + {%- endif %} + } +} + +{%- for function in receptor_info["functions_used"] %} +inline {{ function_declaration.FunctionDeclaration(function, "nest::"~receptor_name~cm_unique_suffix~"_con_"~synapse_name~"::", true) }} +{ +{%- filter indent(2,True) %} +{%- with ast = function.get_stmts_body() %} + {%- set printer = vector_printer %} +{%- include "cm_directives_cpp/StmtsBody.jinja2" %} +{%- endwith %} +{%- endfilter %} +} +{%- endfor %} + +void nest::{{receptor_name}}{{cm_unique_suffix}}_con_{{synapse_name}}::get_currents_per_compartment(std::vector< double >& compartment_to_current){ + for(std::size_t comp_id = 0; comp_id < compartment_to_current.size(); comp_id++){ + compartment_to_current[comp_id] = 0; + } + for(std::size_t syn_id = 0; syn_id < neuron_{{ receptor_name }}_receptor_count; syn_id++){ + compartment_to_current[this->compartment_association[syn_id]] += this->i_tot_{{receptor_name}}[syn_id]; + } +} + +std::vector< double > nest::{{receptor_name}}{{cm_unique_suffix}}_con_{{synapse_name}}::distribute_shared_vector(std::vector< double > shared_vector){ + std::vector< double > distributed_vector(this->neuron_{{ receptor_name }}_receptor_count, 0.0); + for(std::size_t syn_id = 0; syn_id < this->neuron_{{ receptor_name }}_receptor_count; syn_id++){ + distributed_vector[syn_id] = shared_vector[compartment_association[syn_id]]; + } + return distributed_vector; +} + +// {{receptor_name}}_{{synapse_name}} receptor end /////////////////////////////////////////////////////////// +{%- endfor %} +{%- endfor %} + + +////////////////////////////////////// continuous inputs + +{%- for continuous_name, continuous_info in con_in_info.items() %} +// {{continuous_name}}
continuous input /////////////////////////////////////////////////// + +void nest::{{continuous_name}}{{cm_unique_suffix}}::new_continuous_input(std::size_t comp_ass, const long con_in_index) +{ + neuron_{{ continuous_name }}_continuous_input_count++; + i_tot_{{continuous_name}}.push_back(0); + compartment_association.push_back(comp_ass); + continuous_idx.push_back(con_in_index); + + {%- for pure_variable_name, variable_info in continuous_info["States"].items() %} + // state variable {{pure_variable_name }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+continuous_name+"_continuous_input_count-1") -}}); + {%- endfor %} + + {% for variable_type, variable_info in continuous_info["Parameters"].items() %} + // parameter {{variable_type }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+continuous_name+"_continuous_input_count-1") -}}); + {%- endfor %} + + // user declared internals in order they were declared + {%- for internal_name, internal_declaration in continuous_info["internals_used_declared"] %} + {{internal_name}}.push_back(0); + {%- endfor %} + + // set propagators to ode toolbox returned value + {%- for convolution, convolution_info in continuous_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} + {{state_variable_name}}.push_back(0); + {%- endfor %} + {%- endfor %} + + // initial values for kernel state variables, set to zero + {%- for convolution, convolution_info in continuous_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} + {{state_variable_name}}.push_back(0); + {%- endfor %} + {%- endfor %} + + {%- for variable_info in continuous_info["SecondaryInlineExpressions"] %} + {%- if variable_info.is_recordable %} + {%- set variable = variable_info.variable_name %} + {{ variable }}.push_back(0); + {%- endif %} + {%- endfor %} + + self_spikes.push_back(false); +} + +void nest::{{continuous_name}}{{cm_unique_suffix}}::new_continuous_input(std::size_t comp_ass, const long con_in_index, const DictionaryDatum& con_in_params) +/* update {{continuous_name}} continuous input parameters and states */ +{ + neuron_{{ continuous_name }}_continuous_input_count++; + compartment_association.push_back(comp_ass); + i_tot_{{continuous_name}}.push_back(0); + continuous_idx.push_back(con_in_index); + + {%- for pure_variable_name, variable_info in continuous_info["States"].items() %} + // state variable {{pure_variable_name }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+continuous_name+"_continuous_input_count-1") -}}); + {%- endfor %} + {%- for variable_type, variable_info in continuous_info["States"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + if( con_in_params->known( "{{variable.name}}" ) ) + {{variable.name}}[neuron_{{ continuous_name }}_continuous_input_count-1] = getValue< double >( con_in_params, "{{variable.name}}" ); + {%- endfor %} + + {%- with %} + {%- for variable_type, variable_info in continuous_info["ODEs"].items() %} + {%- set
variable_name = variable_type %} + {%- set dynamic_variable = render_dynamic_channel_variable_name(variable_type, ion_channel_name) %} + // {{continuous_name}} continuous input ODE state {{dynamic_variable }} + if( con_in_params->known( "{{variable_name}}" ) ) + {{variable_name}}[neuron_{{ continuous_name }}_continuous_input_count-1] = getValue< double >( con_in_params, "{{variable_name}}" ); + {%- endfor %} + {% endwith %} + + {% for variable_type, variable_info in continuous_info["Parameters"].items() %} + // continuous parameter {{variable_type }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+continuous_name+"_continuous_input_count-1") -}}); + {%- endfor %} + + {%- with %} + {%- for variable_type, variable_info in continuous_info["Parameters"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + if( con_in_params->known( "{{variable.name}}" ) ) + {{variable.name}}[neuron_{{ continuous_name }}_continuous_input_count-1] = getValue< double >( con_in_params, "{{variable.name}}" ); + {%- endfor %} + {% endwith %} + + // user declared internals in order they were declared + {%- for internal_name, internal_declaration in continuous_info["internals_used_declared"] %} + {{internal_name}}.push_back(0); + {%- endfor %} + + // set propagators to ode toolbox returned value + {%- for convolution, convolution_info in continuous_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} + {{state_variable_name}}.push_back(0); + {%- endfor %} + {%- endfor %} + + // initial values for kernel state variables, set to zero + {%- for convolution, convolution_info in continuous_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} + {{state_variable_name}}.push_back(0); + {%- endfor %} + {%- endfor %} + + {%- for variable_info in continuous_info["SecondaryInlineExpressions"] %} + {%- if variable_info.is_recordable %} + {%- set variable = variable_info.variable_name %} + {{ variable }}.push_back(0); + {%- endif %} + {%- endfor %} + + self_spikes.push_back(false); +} void nest::{{continuous_name}}{{cm_unique_suffix}}::append_recordables(std::map< Name, double* >* recordables, const long compartment_idx) { for(size_t con_in_id = 0; con_in_id < neuron_{{ continuous_name }}_continuous_input_count; con_in_id++){ if(compartment_association[con_in_id] == compartment_idx){ - ( *recordables )[ Name( "i_tot_{{continuous_name}}" + std::to_string(con_in_id) )] = &i_tot_{{continuous_name}}[con_in_id]; + ( *recordables )[ Name( "{{continuous_name}}" + std::to_string(continuous_idx[con_in_id]))] = &i_tot_{{continuous_name}}[con_in_id]; } } + + {%- for variable_info in continuous_info["SecondaryInlineExpressions"] %} + {%- if variable_info.is_recordable %} + {%- set variable_name = variable_info.variable_name %} + found_rec = false; + for(size_t con_in_id = 0; con_in_id < neuron_{{ continuous_name }}_continuous_input_count; con_in_id++){ + if(compartment_association[con_in_id] == compartment_idx){ + ( *recordables )[ Name( std::string("{{variable_name}}") + std::to_string(compartment_idx))] = &{{variable_name}}[con_in_id]; + found_rec = true; + } + } + if(!found_rec) ( *recordables )[ Name( std::string("{{variable_name}}") + std::to_string(compartment_idx))] = &zero_recordable; + {%-
endif %} + {%- endfor %} } {%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %} @@ -926,20 +1971,36 @@ void nest::{{continuous_name}}{{cm_unique_suffix}}::calibrate() void nest::{{continuous_name}}{{cm_unique_suffix}}::pre_run_hook() {%- endif %} { + {% if "time_resolution_var" in continuous_info %} + std::vector< double > {{ printer_no_origin.print(continuous_info["time_resolution_var"]) }}(neuron_{{ continuous_name }}_continuous_input_count, Time::get_resolution().get_ms()); + {% endif %} + for(std::size_t i = 0; i < neuron_{{ continuous_name }}_continuous_input_count; i++){ + // user declared internals in order they were declared + {%- for internal_name, internal_declaration in continuous_info["Internals"] %} + {{internal_name}}[i] = {{ vector_printer.print(internal_declaration.get_expression(), "i") }}; + {%- endfor %} - // user declared internals in order they were declared - {%- for internal_name, internal_declaration in continuous_info["Internals"] %} - {{internal_name}}[i] = {{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("i").print(internal_declaration.get_expression()) }}; + // set propagators to ode toolbox returned value + {%- for convolution, convolution_info in continuous_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} + {{state_variable_name}}[i] = {{ vector_printer.print(state_variable_info["init_expression"], "i") }}; + {%- endfor %} + {%- endfor %} + + // initial values for kernel state variables, set to zero + {%- for convolution, convolution_info in continuous_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} + {{state_variable_name}}[i] = 0; + {%- endfor %} {%- endfor %} - for(std::size_t i = 0; i < neuron_{{ continuous_name }}_continuous_input_count; i++){ {% for port_name, port_info in continuous_info["Continuous"].items() %} {{port_name}}_[i]->clear(); {% endfor %} } } -std::pair< std::vector< double >, std::vector< double > > nest::{{continuous_name}}{{cm_unique_suffix}}::f_numstep( std::vector< double > v_comp, const long lag {% for ode in continuous_info["Dependencies"]["concentrations"] %}, std::vector< double > {{ode.lhs.name}}{% endfor %}{% if continuous_info["Dependencies"]["receptors"]|length %} +std::pair< std::vector< double >, std::vector< double > > nest::{{continuous_name}}{{cm_unique_suffix}}::f_numstep(std::vector< double > v_comp, const long lag {% for ode in continuous_info["Dependencies"]["concentrations"] %}, std::vector< double > {{ode.lhs.name}}{% endfor %}{% if continuous_info["Dependencies"]["receptors"]|length %} {% endif %}{% for inline in continuous_info["Dependencies"]["receptors"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if continuous_info["Dependencies"]["channels"]|length %} {% endif %}{% for inline in continuous_info["Dependencies"]["channels"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if continuous_info["Dependencies"]["continuous"]|length %} {% endif %}{% for inline in continuous_info["Dependencies"]["continuous"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}) @@ -948,7 +2009,7 @@ std::pair< std::vector< double >, std::vector< double > > nest::{{continuous_nam std::vector< double > i_val(neuron_{{ continuous_name }}_continuous_input_count, 0.); std::vector< 
double > d_i_tot_dv(neuron_{{ continuous_name }}_continuous_input_count, 0.); - {% if continuous_info["ODEs"].items()|length %} + {% if "time_resolution_var" in continuous_info %} std::vector< double > {{ printer_no_origin.print(continuous_info["time_resolution_var"]) }}(neuron_{{ continuous_name }}_continuous_input_count, Time::get_resolution().get_ms()); {% endif %} @@ -970,34 +2031,92 @@ std::pair< std::vector< double >, std::vector< double > > nest::{{continuous_nam #pragma omp simd for(std::size_t i = 0; i < neuron_{{ continuous_name }}_continuous_input_count; i++){ + {%- for convolution, convolution_info in continuous_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items() %} + {{state_variable_name}}[i] = {{ vector_printer.print(state_variable_info["update_expression"], "i") }}; + {{state_variable_name}}[i] += self_spikes[i] * {{ vector_printer.print(state_variable_info["init_expression"], "i") }}; + {%- endfor %} + {%- endfor %} + //update ODE state variable {%- for ode_variable, ode_info in continuous_info["ODEs"].items() %} {%- for propagator, propagator_info in ode_info["transformed_solutions"][0]["propagators"].items() %} - {{ propagator }}[i] = {{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("i").print(propagator_info["init_expression"]) }}; + {{ propagator }}[i] = {{ vector_printer.print(propagator_info["init_expression"], "i") }}; {%- endfor %} {%- for state, state_solution_info in ode_info["transformed_solutions"][0]["states"].items() %} - {{state}}[i] = {{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("i").print(state_solution_info["update_expression"]) }}; + {{state}}[i] = {{ vector_printer.print(state_solution_info["update_expression"], "i") }}; {%- endfor %} {%- endfor %} // total current // this expression should be the transformed inline expression - this->i_tot_{{continuous_name}}[i] = {{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("i").print(continuous_info["root_expression"].get_expression()) }}; + this->i_tot_{{continuous_name}}[i] = {{ vector_printer.print(continuous_info["root_expression"].get_expression(), "i") }}; // derivative of that expression // voltage derivative of total current // compute derivative with respect to current with sympy - d_i_tot_dv[i] = {{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("i").print(continuous_info["inline_derivative"]) }}; + d_i_tot_dv[i] = {{ vector_printer.print(continuous_info["inline_derivative"], "i") }}; // for numerical integration g_val[i] = - d_i_tot_dv[i]; i_val[i] = this->i_tot_{{continuous_name}}[i] - d_i_tot_dv[i] * v_comp[i]; } + f_update(); + + //update recordable inlines + for(std::size_t i = 0; i < neuron_{{ continuous_name }}_continuous_input_count; i++){ + {%- for variable_info in continuous_info["SecondaryInlineExpressions"] %} + {%- if variable_info.is_recordable %} + {%- set variable = variable_info.variable_name %} + {%- set rhs_expression = variable_info.expression %} + {{ variable }}[i] = {{ vector_printer.print(rhs_expression, "i") }}; + {%- endif %} + {%- endfor %} + } + return std::make_pair(g_val, i_val); } +void nest::{{continuous_name}}{{cm_unique_suffix}}::f_update() +{ + double __resolution = Time::get_resolution().get_ms(); + for(std::size_t i = 0; i < neuron_{{ continuous_name }}_continuous_input_count; i++){ + {%- if continuous_info["Blocks"] %} + {%- if continuous_info["Blocks"]["UpdateBlock"] %} + 
{%- set function = continuous_info["Blocks"]["UpdateBlock"] %} + {%- filter indent(2,True) %} + {%- with ast = function.get_stmts_body() %} + {%- set printer = vector_printer %} + {%- include "cm_directives_cpp/StmtsBody.jinja2" %} + {%- endwith %} + {%- endfilter %} + {%- endif %} + {%- endif %} + self_spikes[i] = false; + } +} + +void nest::{{continuous_name}}{{cm_unique_suffix}}::f_self_spike() +{ + double __resolution = Time::get_resolution().get_ms(); + for(std::size_t i = 0; i < neuron_{{ continuous_name }}_continuous_input_count; i++){ + self_spikes[i] = true; + {%- if continuous_info["Blocks"] %} + {%- if continuous_info["Blocks"]["SelfSpikesFunction"] %} + {%- set function = continuous_info["Blocks"]["SelfSpikesFunction"] %} + {%- filter indent(2,True) %} + {%- with ast = function.get_stmts_body() %} + {%- set printer = vector_printer %} + {%- include "cm_directives_cpp/StmtsBody.jinja2" %} + {%- endwith %} + {%- endfilter %} + {%- endif %} + {%- endif %} + } +} + {%- for function in continuous_info["Functions"] %} inline {{ function_declaration.FunctionDeclaration(function, "nest::"~continuous_name~cm_unique_suffix~"::") }} { diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_neuroncurrents_@NEURON_NAME@.h.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_neuroncurrents_@NEURON_NAME@.h.jinja2 index 576abcda0..0f9e244cd 100644 --- a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_neuroncurrents_@NEURON_NAME@.h.jinja2 +++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_neuroncurrents_@NEURON_NAME@.h.jinja2 @@ -21,8 +21,8 @@ along with NEST. If not, see <http://www.gnu.org/licenses/>. {%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif -%} {%- import 'directives_cpp/FunctionDeclaration.jinja2' as function_declaration with context %} -#ifndef SYNAPSES_NEAT_H_{{cm_unique_suffix | upper }} -#define SYNAPSES_NEAT_H_{{cm_unique_suffix | upper }} +#ifndef RECEPTORS_NEAT_H_{{cm_unique_suffix | upper }} +#define RECEPTORS_NEAT_H_{{cm_unique_suffix | upper }} #include #include @@ -37,6 +37,11 @@ along with NEST. If not, see <http://www.gnu.org/licenses/>.
{%- endwith -%} {%- endmacro %} +//elementwise vector operations: +#include +#include +#include + namespace nest { @@ -68,12 +73,37 @@ private: std::vector< {{ render_variable_type(variable) }} > {{ variable.name }} = {}; {%- endfor %} + // recordable inlines + {%- for variable_info in channel_info["SecondaryInlineExpressions"] %} + {%- if variable_info.is_recordable %} + {%- set variable = variable_info.variable_name %} + {%- set rhs_expression = variable_info.expression %} + std::vector< double > {{ variable }} = {}; + {%- endif %} + {%- endfor %} + + // propagators, initialized via pre_run_hook() or calibrate() + {%- for convolution, convolution_info in channel_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} + std::vector< double > {{state_variable_name}}; + {%- endfor %} + {%- endfor %} + + // kernel state variables, initialized via pre_run_hook() or calibrate() + {%- for convolution, convolution_info in channel_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} + std::vector< double > {{state_variable_name}}; + {%- endfor %} + {%- endfor %} + // ion-channel root-inline value std::vector< double > i_tot_{{ion_channel_name}} = {}; //zero recordable variable in case of zero contribution channel double zero_recordable = 0; + std::vector< bool > self_spikes; + public: // constructor, destructor {{ion_channel_name}}{{cm_unique_suffix}}(){}; @@ -89,25 +119,28 @@ public: // initialization channel {%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %} - void calibrate() { + void calibrate(); {%- else %} - void pre_run_hook() { + void pre_run_hook(); {%- endif %} - }; void append_recordables(std::map< Name, double* >* recordables, const long compartment_idx); // numerical integration step - std::pair< std::vector< double >, std::vector< double > > f_numstep( std::vector< double > v_comp{% for ode in channel_info["Dependencies"]["concentrations"] %}, std::vector< double > {{ode.lhs.name}}{% endfor %}{% if channel_info["Dependencies"]["receptors"]|length %} + std::pair< std::vector< double >, std::vector< double > > f_numstep(std::vector< double > v_comp{% for ode in channel_info["Dependencies"]["concentrations"] %}, std::vector< double > {{ode.lhs.name}}{% endfor %}{% if channel_info["Dependencies"]["receptors"]|length %} {% endif %}{% for inline in channel_info["Dependencies"]["receptors"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if channel_info["Dependencies"]["channels"]|length %} {% endif %}{% for inline in channel_info["Dependencies"]["channels"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if channel_info["Dependencies"]["continuous"]|length %} {% endif %}{% for inline in channel_info["Dependencies"]["continuous"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}); + + void f_update(); + void f_self_spike(); + // function declarations {%- for function in channel_info["Functions"] %} #pragma omp declare simd - __attribute__((always_inline)) inline {{ function_declaration.FunctionDeclaration(function) }}; + __attribute__((always_inline)) inline {{ function_declaration.FunctionDeclaration(function, pass_by_reference = true) }}; {%- endfor %} // root_inline getter @@ -148,12 +181,37 @@ private: std::vector< {{ render_variable_type(variable) }} > {{ 
variable.name }} = {}; {%- endfor %} + // recordable inlines + {%- for variable_info in concentration_info["SecondaryInlineExpressions"] %} + {%- if variable_info.is_recordable %} + {%- set variable = variable_info.variable_name %} + {%- set rhs_expression = variable_info.expression %} + std::vector< double > {{ variable }} = {}; + {%- endif %} + {%- endfor %} + + // propagators, initialized via pre_run_hook() or calibrate() + {%- for convolution, convolution_info in concentration_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} + std::vector< double > {{state_variable_name}}; + {%- endfor %} + {%- endfor %} + + // kernel state variables, initialized via pre_run_hook() or calibrate() + {%- for convolution, convolution_info in concentration_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} + std::vector< double > {{state_variable_name}}; + {%- endfor %} + {%- endfor %} + // concentration value (root-ode state) std::vector< double > {{concentration_name}} = {}; //zero recordable variable in case of zero contribution concentration double zero_recordable = 0; + std::vector< bool > self_spikes; + public: // constructor, destructor {{ concentration_name }}{{cm_unique_suffix}}(){}; @@ -169,31 +227,25 @@ public: // initialization concentration {%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %} - void calibrate() { + void calibrate(); {%- else %} - void pre_run_hook() { + void pre_run_hook(); {%- endif %} - for(std::size_t concentration_id = 0; concentration_id < neuron_{{ concentration_name }}_concentration_count; concentration_id++){ - // states - {%- for pure_variable_name, variable_info in concentration_info["States"].items() %} - {%- set variable = variable_info["ASTVariable"] %} - {%- set rhs_expression = variable_info["rhs_expression"] %} - {{ variable.name }}[concentration_id] = {{ vector_printer_factory.create_ast_vector_parameter_setter_and_printer("concentration_id").print(rhs_expression) }}; - {%- endfor %} - } - }; void append_recordables(std::map< Name, double* >* recordables, const long compartment_idx); // numerical integration step - void f_numstep( std::vector< double > v_comp{% for ode in concentration_info["Dependencies"]["concentrations"] %}, std::vector< double > {{ode.lhs.name}}{% endfor %}{% if concentration_info["Dependencies"]["receptors"]|length %} + void f_numstep(std::vector< double > v_comp{% for ode in concentration_info["Dependencies"]["concentrations"] %}, std::vector< double > {{ode.lhs.name}}{% endfor %}{% if concentration_info["Dependencies"]["receptors"]|length %} {% endif %}{% for inline in concentration_info["Dependencies"]["receptors"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if concentration_info["Dependencies"]["channels"]|length %} {% endif %}{% for inline in concentration_info["Dependencies"]["channels"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if concentration_info["Dependencies"]["continuous"]|length %} {% endif %}{% for inline in concentration_info["Dependencies"]["continuous"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}); + void f_update(); + void f_self_spike(); + // function declarations {%- for function in concentration_info["Functions"] %} #pragma omp declare simd - 
__attribute__((always_inline)) inline {{ function_declaration.FunctionDeclaration(function) }}; + __attribute__((always_inline)) inline {{ function_declaration.FunctionDeclaration(function, pass_by_reference = true) }}; {%- endfor %} // root_ode getter @@ -206,12 +258,12 @@ public: {% endwith -%} -////////////////////////////////////////////////// synapses +////////////////////////////////////////////////// receptors -{% macro render_time_resolution_variable(synapse_info) -%} +{% macro render_time_resolution_variable(receptor_info) -%} {# we assume here that there is only one such variable ! #} {%- with %} -{%- for analytic_helper_name, analytic_helper_info in synapse_info["analytic_helpers"].items() -%} +{%- for analytic_helper_name, analytic_helper_info in receptor_info["analytic_helpers"].items() -%} {%- if analytic_helper_info["is_time_resolution"] -%} {{ analytic_helper_name }} {%- endif -%} @@ -219,71 +271,289 @@ public: {% endwith %} {%- endmacro %} +{%- with %} +{%- for receptor_name, receptor_info in recs_info.items() %} + +class {{receptor_name}}{{cm_unique_suffix}}{ +private: + // global receptor index + std::vector< long > rec_idx = {}; + + // propagators, initialized via pre_run_hook() or calibrate() + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} + std::vector< double > {{state_variable_name}}; + {%- endfor %} + {%- endfor %} + + // kernel state variables, initialized via pre_run_hook() or calibrate() + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} + std::vector< double > {{state_variable_name}}; + {%- endfor %} + {%- endfor %} + + // user defined parameters, initialized via pre_run_hook() or calibrate() + {%- for param_name, param_declaration in receptor_info["Parameters"].items() %} + std::vector< double > {{param_name}}; + {%- endfor %} + + // states + {%- for pure_variable_name, variable_info in receptor_info["States"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + std::vector<{{ render_variable_type(variable) }}> {{ variable.name }} = {}; + {%- endfor %} + + // recordable inlines + {%- for variable_info in receptor_info["SecondaryInlineExpressions"] %} + {%- if variable_info.is_recordable %} + {%- set variable = variable_info.variable_name %} + {%- set rhs_expression = variable_info.expression %} + std::vector< double > {{ variable }} = {}; + {%- endif %} + {%- endfor %} + + std::vector < double > s_val = {}; + + std::vector< double > i_tot_{{receptor_name}} = {}; + + // user declared internals in order they were declared, initialized via pre_run_hook() or calibrate() + {%- for internal_name, internal_declaration in receptor_info["internals_used_declared"] %} + std::vector< double > {{internal_name}}; + {%- endfor %} + + // spike buffer + std::vector< RingBuffer* > {{receptor_info["buffer_name"]}}_; + + std::vector< bool > self_spikes; + +public: + // constructor, destructor + {{receptor_name}}{{cm_unique_suffix}}(){}; + ~{{receptor_name}}{{cm_unique_suffix}}(){}; + + void new_receptor(std::size_t comp_ass, const long rec_index); + void new_receptor(std::size_t comp_ass, const long rec_index, const DictionaryDatum& receptor_params); + + //number of receptors + std::size_t neuron_{{
+    //number of receptors
+    std::size_t neuron_{{ receptor_name }}_receptor_count = 0;
+
+    std::vector< size_t > compartment_association = {};
+
+    // numerical integration step
+    std::pair< std::vector< double >, std::vector< double > > f_numstep(std::vector< double > v_comp, const long lag {% for ode in receptor_info["Dependencies"]["concentrations"] %}, std::vector< double > {{ode.lhs.name}}{% endfor %}{% if receptor_info["Dependencies"]["receptors"]|length %}
+    {% endif %}{% for inline in receptor_info["Dependencies"]["receptors"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if receptor_info["Dependencies"]["channels"]|length %}
+    {% endif %}{% for inline in receptor_info["Dependencies"]["channels"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if receptor_info["Dependencies"]["continuous"]|length %}
+    {% endif %}{% for inline in receptor_info["Dependencies"]["continuous"] %}, std::vector< double > {{inline.variable_name}}{% endfor %});
+
+    void f_update();
+    void f_self_spike();
+
+    // calibration
+{%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %}
+    void calibrate();
+{%- else %}
+    void pre_run_hook();
+{%- endif %}
+    void append_recordables(std::map< Name, double* >* recordables, const long compartment_idx);
+    void set_buffer_ptr( std::vector< RingBuffer >& rec_buffers )
+    {
+        for(std::size_t i = 0; i < rec_idx.size(); i++){
+            {{receptor_info["buffer_name"]}}_.push_back(&(rec_buffers[rec_idx[i]]));
+        }
+    };
+
+    // function declarations
+    {%- for function in receptor_info["Functions"] %}
+    #pragma omp declare simd
+    __attribute__((always_inline)) inline {{ function_declaration.FunctionDeclaration(function, pass_by_reference = true) -}};
+
+    {% endfor %}
+
+    // root_inline getter
+    void get_currents_per_compartment(std::vector< double >& compartment_to_current);
+
+    std::vector< double > distribute_shared_vector(std::vector< double > shared_vector);
+
+};
+
+{% endfor -%}
+{% endwith -%}
+
+////////////////////////////////////////////////// receptors with synapses attached
+

 {%- with %}
 {%- for synapse_name, synapse_info in syns_info.items() %}
+{%- for receptor_name, receptor_info in recs_info.items() %}

-class {{synapse_name}}{{cm_unique_suffix}}{
+class {{receptor_name}}{{cm_unique_suffix}}_con_{{synapse_name}}{
 private:
-    // global synapse index
+
+    // dendritic delay buffer class
+    class TimerDeque {
+    public:
+        TimerDeque(double delay) :
+            countdown(nest::Time::delay_ms_to_steps(delay))
+        {}
+
+        void push(double value) {
+            int local_countdown = countdown - total_countdown;
+            total_countdown = countdown;
+            dq.push_back(std::make_pair(local_countdown, value));
+        }
+
+        double tick() {
+            if (!dq.empty()) {
+                dq.front().first--;
+                total_countdown--;
+                if (dq.front().first == 0) {
+                    double value = dq.front().second;
+                    dq.pop_front();
+                    return value;
+                }
+            }
+            return 0;
+        }
+
+    private:
+        std::deque< std::pair< int, double > > dq;
+
+        const int countdown;
+        int total_countdown = 0;
+    };
+
+    // delayed spikes queue
+    std::vector< TimerDeque > delayed_spikes;
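+
+    // Assumed usage of the delay line above: push(w) enqueues a weighted spike
+    // to be delivered after the full dendritic delay; tick() is called once per
+    // timestep and yields 0 until a spike is due, then that spike's value.
+    // Each entry stores its countdown relative to the previously queued spike,
+    // so only the head of the deque needs decrementing each step.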
+
+    // global receptor index
     std::vector< long > syn_idx = {};

     // propagators, initialized via pre_run_hook() or calibrate()
-    {%- for convolution, convolution_info in synapse_info["convolutions"].items() %}
+    {%- for convolution, convolution_info in receptor_info["convolutions"].items() %}
     {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%}
     std::vector< double > {{state_variable_name}};
     {%- endfor %}
     {%- endfor %}

     // kernel state variables, initialized via pre_run_hook() or calibrate()
-    {%- for convolution, convolution_info in synapse_info["convolutions"].items() %}
+    {%- for convolution, convolution_info in receptor_info["convolutions"].items() %}
     {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%}
     std::vector< double > {{state_variable_name}};
     {%- endfor %}
     {%- endfor %}

     // user defined parameters, initialized via pre_run_hook() or calibrate()
-    {%- for param_name, param_declaration in synapse_info["Parameters"].items() %}
+    {%- for param_name, param_declaration in receptor_info["Parameters"].items() %}
     std::vector< double > {{param_name}};
     {%- endfor %}

     // states
-    {%- for pure_variable_name, variable_info in synapse_info["States"].items() %}
+    {%- for pure_variable_name, variable_info in receptor_info["States"].items() %}
     {%- set variable = variable_info["ASTVariable"] %}
     {%- set rhs_expression = variable_info["rhs_expression"] %}
     std::vector<{{ render_variable_type(variable) }}> {{ variable.name }} = {};
     {%- endfor %}

-    std::vector< double > i_tot_{{synapse_name}} = {};
+    // recordable inlines
+    {%- for variable_info in receptor_info["SecondaryInlineExpressions"] %}
+    {%- if variable_info.is_recordable %}
+    {%- set variable = variable_info.variable_name %}
+    {%- set rhs_expression = variable_info.expression %}
+    std::vector< double > {{ variable }} = {};
+    {%- endif %}
+    {%- endfor %}
+
+    std::vector < double > s_val = {};
+
+    std::vector< double > i_tot_{{receptor_name}} = {};

     // user declared internals in order they were declared, initialized via pre_run_hook() or calibrate()
-    {%- for internal_name, internal_declaration in synapse_info["internals_used_declared"] %}
+    {%- for internal_name, internal_declaration in receptor_info["internals_used_declared"] %}
     std::vector< double > {{internal_name}};
     {%- endfor %}

     // spike buffer
-    std::vector< RingBuffer* > {{synapse_info["buffer_name"]}}_;
+    std::vector< RingBuffer* > {{receptor_info["buffer_name"]}}_;
+
+    //synapse related variables:
+    // user defined parameters, initialized via pre_run_hook() or calibrate()
+    {%- for param_name, param_declaration in synapse_info["Parameters"].items() %}
+    std::vector< double > {{param_name}};
+    {%- endfor %}
+
+    // states
+    {%- for pure_variable_name, variable_info in synapse_info["States"].items() %}
+    {%- set variable = variable_info["ASTVariable"] %}
+    {%- set rhs_expression = variable_info["rhs_expression"] %}
+    std::vector<{{ render_variable_type(variable) }}> {{ variable.name }} = {};
+    {%- endfor %}
+
+    // user declared internals in order they were declared, initialized via pre_run_hook() or calibrate()
+    {%- for internal_name, internal_declaration in synapse_info["Internals"].items() %}
+    std::vector< double > {{internal_name}};
+    {%- endfor %}
+
+    {%- with %}
+    {%- for in_function_declaration in synapse_info["InFunctionDeclarationsVars"] %}
+    {%- for variable in declarations.get_variables(in_function_declaration) %}
+    std::vector<{{declarations.print_variable_type(variable)}}> {{variable.get_symbol_name()}} = {};
+    {%- endfor %}
+    {%- endfor %}
+    {%- endwith %}
+
+    {%- for convolution, convolution_info in synapse_info["convolutions"].items() %}
+    {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%}
+    std::vector< double > {{state_variable_name}};
+    {%- endfor %}
+    {%- endfor %}
+
+    {%- for convolution, convolution_info in synapse_info["convolutions"].items() %}
+    {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%}
+    std::vector< double > {{state_variable_name}};
+    {%- endfor %}
+    {%- endfor %}
+
+    {%- for inline_name, inline in synapse_info["Inlines"].items() %}
+    std::vector< double > {{inline_name}};
+    {%- endfor %}
+
+    std::vector< bool > self_spikes;

 public:
     // constructor, destructor
-    {{synapse_name}}{{cm_unique_suffix}}(){};
-    ~{{synapse_name}}{{cm_unique_suffix}}(){};
+    {{receptor_name}}{{cm_unique_suffix}}_con_{{synapse_name}}(){};
+    ~{{receptor_name}}{{cm_unique_suffix}}_con_{{synapse_name}}(){};

-    void new_synapse(std::size_t comp_ass, const long syn_index);
-    void new_synapse(std::size_t comp_ass, const long syn_index, const DictionaryDatum& synapse_params);
+    void new_receptor(std::size_t comp_ass, const long syn_index);
+    void new_receptor(std::size_t comp_ass, const long syn_index, const DictionaryDatum& receptor_params);

-    //number of synapses
-    std::size_t neuron_{{ synapse_name }}_synapse_count = 0;
+    //number of receptors
+    std::size_t neuron_{{ receptor_name }}_receptor_count = 0;

     std::vector< size_t > compartment_association = {};

     // numerical integration step
-    std::pair< std::vector< double >, std::vector< double > > f_numstep( std::vector< double > v_comp, const long lag {% for ode in synapse_info["Dependencies"]["concentrations"] %}, std::vector< double > {{ode.lhs.name}}{% endfor %}{% if synapse_info["Dependencies"]["receptors"]|length %}
-    {% endif %}{% for inline in synapse_info["Dependencies"]["receptors"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if synapse_info["Dependencies"]["channels"]|length %}
-    {% endif %}{% for inline in synapse_info["Dependencies"]["channels"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if synapse_info["Dependencies"]["continuous"]|length %}
-    {% endif %}{% for inline in synapse_info["Dependencies"]["continuous"] %}, std::vector< double > {{inline.variable_name}}{% endfor %});
+    {%- with %}
+    {%- set conc_dep = set(receptor_info["Dependencies"]["concentrations"]).union(synapse_info["Dependencies"]["concentrations"])%}
+    {%- set rec_dep = set(receptor_info["Dependencies"]["receptors"]).union(synapse_info["Dependencies"]["receptors"])%}
+    {%- set chan_dep = set(receptor_info["Dependencies"]["channels"]).union(synapse_info["Dependencies"]["channels"])%}
+    {%- set con_in_dep = set(receptor_info["Dependencies"]["continuous"]).union(synapse_info["Dependencies"]["continuous"])%}
+    std::pair< std::vector< double >, std::vector< double > > f_numstep(std::vector< double > v_comp, const long lag {% for ode in conc_dep %}, std::vector< double > {{ode.lhs.name}}{% endfor %}{% if rec_dep|length %}
+    {% endif %}{% for inline in rec_dep %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if chan_dep|length %}
+    {% endif %}{% for inline in chan_dep %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if con_in_dep|length %}
+    {% endif %}{% for inline in con_in_dep %}, std::vector< double > {{inline.variable_name}}{% endfor %});
+    {%- endwith %}
+
+    void f_update();
+    void f_self_spike();
+
+    void postsynaptic_synaptic_processing();
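+    // postsynaptic_synaptic_processing() presumably runs the synapse-side
+    // updates (e.g. an STDP weight change) that are triggered by a
+    // postsynaptic spike; it is invoked from
+    // NeuronCurrents::postsynaptic_synaptic_processing() below.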
{{synapse_info["buffer_name"]}}_.push_back(&(syn_buffers[syn_idx[i]])); + {{receptor_info["buffer_name"]}}_.push_back(&(syn_buffers[syn_idx[i]])); } }; // function declarations - {%- for function in synapse_info["Functions"] %} + {%- for function in receptor_info["Functions"] %} #pragma omp declare simd - __attribute__((always_inline)) inline {{ function_declaration.FunctionDeclaration(function) -}}; + __attribute__((always_inline)) inline {{ function_declaration.FunctionDeclaration(function, pass_by_reference = true) -}}; {% endfor %} @@ -311,9 +581,16 @@ public: std::vector< double > distribute_shared_vector(std::vector< double > shared_vector); + void get_history__( double t1, + double t2, + std::deque< histentry >::iterator* start, + std::deque< histentry >::iterator* finish ); + + }; {% endfor -%} +{% endfor %} {% endwith -%} @@ -340,6 +617,29 @@ private: }; {%- endfor %} + // recordable inlines + {%- for variable_info in continuous_info["SecondaryInlineExpressions"] %} + {%- if variable_info.is_recordable %} + {%- set variable = variable_info.variable_name %} + {%- set rhs_expression = variable_info.expression %} + std::vector< double > {{ variable }} = {}; + {%- endif %} + {%- endfor %} + + // propagators, initialized via pre_run_hook() or calibrate() + {%- for convolution, convolution_info in continuous_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} + std::vector< double > {{state_variable_name}}; + {%- endfor %} + {%- endfor %} + + // kernel state variables, initialized via pre_run_hook() or calibrate() + {%- for convolution, convolution_info in continuous_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} + std::vector< double > {{state_variable_name}}; + {%- endfor %} + {%- endfor %} + std::vector< double > i_tot_{{continuous_name}} = {}; // user declared internals in order they were declared, initialized via pre_run_hook() or calibrate() @@ -354,6 +654,8 @@ private: std::vector< RingBuffer* > {{ port_name }}_; {% endfor %} + std::vector< bool > self_spikes; + public: // constructor, destructor {{continuous_name}}{{cm_unique_suffix}}(){}; @@ -368,11 +670,14 @@ public: std::vector< size_t > compartment_association = {}; // numerical integration step - std::pair< std::vector< double >, std::vector< double > > f_numstep( std::vector< double > v_comp, const long lag {% for ode in continuous_info["Dependencies"]["concentrations"] %}, std::vector< double > {{ode.lhs.name}}{% endfor %}{% if continuous_info["Dependencies"]["receptors"]|length %} + std::pair< std::vector< double >, std::vector< double > > f_numstep(std::vector< double > v_comp, const long lag {% for ode in continuous_info["Dependencies"]["concentrations"] %}, std::vector< double > {{ode.lhs.name}}{% endfor %}{% if continuous_info["Dependencies"]["receptors"]|length %} {% endif %}{% for inline in continuous_info["Dependencies"]["receptors"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if continuous_info["Dependencies"]["channels"]|length %} {% endif %}{% for inline in continuous_info["Dependencies"]["channels"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if continuous_info["Dependencies"]["continuous"]|length %} {% endif %}{% for inline in continuous_info["Dependencies"]["continuous"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}); + void f_update(); + void 
+
+
 };

 {% endfor -%}
+{% endfor %}
 {% endwith -%}


@@ -340,6 +617,29 @@ private:
     };
     {%- endfor %}

+    // recordable inlines
+    {%- for variable_info in continuous_info["SecondaryInlineExpressions"] %}
+    {%- if variable_info.is_recordable %}
+    {%- set variable = variable_info.variable_name %}
+    {%- set rhs_expression = variable_info.expression %}
+    std::vector< double > {{ variable }} = {};
+    {%- endif %}
+    {%- endfor %}
+
+    // propagators, initialized via pre_run_hook() or calibrate()
+    {%- for convolution, convolution_info in continuous_info["convolutions"].items() %}
+    {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%}
+    std::vector< double > {{state_variable_name}};
+    {%- endfor %}
+    {%- endfor %}
+
+    // kernel state variables, initialized via pre_run_hook() or calibrate()
+    {%- for convolution, convolution_info in continuous_info["convolutions"].items() %}
+    {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%}
+    std::vector< double > {{state_variable_name}};
+    {%- endfor %}
+    {%- endfor %}
+
     std::vector< double > i_tot_{{continuous_name}} = {};

     // user declared internals in order they were declared, initialized via pre_run_hook() or calibrate()
@@ -354,6 +654,8 @@ private:
     std::vector< RingBuffer* > {{ port_name }}_;
     {% endfor %}

+    std::vector< bool > self_spikes;
+
 public:
     // constructor, destructor
     {{continuous_name}}{{cm_unique_suffix}}(){};
@@ -368,11 +670,14 @@ public:
     std::vector< size_t > compartment_association = {};

     // numerical integration step
-    std::pair< std::vector< double >, std::vector< double > > f_numstep( std::vector< double > v_comp, const long lag {% for ode in continuous_info["Dependencies"]["concentrations"] %}, std::vector< double > {{ode.lhs.name}}{% endfor %}{% if continuous_info["Dependencies"]["receptors"]|length %}
+    std::pair< std::vector< double >, std::vector< double > > f_numstep(std::vector< double > v_comp, const long lag {% for ode in continuous_info["Dependencies"]["concentrations"] %}, std::vector< double > {{ode.lhs.name}}{% endfor %}{% if continuous_info["Dependencies"]["receptors"]|length %}
    {% endif %}{% for inline in continuous_info["Dependencies"]["receptors"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if continuous_info["Dependencies"]["channels"]|length %}
    {% endif %}{% for inline in continuous_info["Dependencies"]["channels"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if continuous_info["Dependencies"]["continuous"]|length %}
    {% endif %}{% for inline in continuous_info["Dependencies"]["continuous"] %}, std::vector< double > {{inline.variable_name}}{% endfor %});

+    void f_update();
+    void f_self_spike();
+
     // calibration
{%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %}
     void calibrate();
@@ -392,7 +697,7 @@ public:
     // function declarations
     {%- for function in continuous_info["Functions"] %}
     #pragma omp declare simd
-    __attribute__((always_inline)) inline {{ function_declaration.FunctionDeclaration(function) -}};
+    __attribute__((always_inline)) inline {{ function_declaration.FunctionDeclaration(function, pass_by_reference = true) -}};

     {% endfor %}

@@ -411,11 +716,13 @@ public:

{%- set channel_suffix = "_chan_" %}
{%- set concentration_suffix = "_conc_" %}
-{%- set synapse_suffix = "_syn_" %}
+{%- set receptor_suffix = "_syn_" %}
{%- set continuous_suffix = "_con_in_" %}

class NeuronCurrents{{cm_unique_suffix}} {
private:
+
+    bool initialized = false;
    //mechanisms
    // ion channels
{% with %}
@@ -429,12 +736,18 @@ private:
    {{concentration_name}}{{cm_unique_suffix}} {{concentration_name}}{{concentration_suffix}};
{% endfor -%}
{% endwith %}
-    // synapses
+    // receptors
{% with %}
-    {%- for synapse_name, synapse_info in syns_info.items() %}
-    {{synapse_name}}{{cm_unique_suffix}} {{synapse_name}}{{synapse_suffix}};
+    {%- for receptor_name, receptor_info in recs_info.items() %}
+    {{receptor_name}}{{cm_unique_suffix}} {{receptor_name}}{{receptor_suffix}};
{% endfor -%}
{% endwith %}
+    // receptors with synapses
+{%- for receptor_name, receptor_info in recs_info.items() %}
+    {%- for synapse_name, synapse_info in syns_info.items() %}
+    {{receptor_name}}{{cm_unique_suffix}}_con_{{synapse_name}} {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}};
+    {% endfor -%}
+{% endfor -%}
    // continuous inputs
{% with %}
    {%- for continuous_name, continuous_info in con_in_info.items() %}
@@ -460,11 +773,20 @@ private:
    std::vector < std::pair< std::size_t, int > > {{concentration_name}}{{concentration_suffix}}_con_area;
    {% endfor -%}
{% endwith %}
-    // synapses
+    // receptors
+{% with %}
+    {%- for receptor_name, receptor_info in recs_info.items() %}
+    std::vector < double > {{receptor_name}}{{receptor_suffix}}_shared_current;
+    std::vector < std::pair< std::size_t, int > > {{receptor_name}}{{receptor_suffix}}_con_area;
+    {% endfor -%}
+{% endwith %}
+    // receptors with synapses
{% with %}
+    {%- for receptor_name, receptor_info in recs_info.items() %}
    {%- for synapse_name, synapse_info in syns_info.items() %}
-    std::vector < double > {{synapse_name}}{{synapse_suffix}}_shared_current;
-    std::vector < std::pair< std::size_t, int > > {{synapse_name}}{{synapse_suffix}}_con_area;
+    std::vector < double > {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}_shared_current;
+    std::vector < std::pair< std::size_t, int > > {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}_con_area;
+    {% endfor -%}
    {% endfor -%}
{% endwith %}
    // continuous inputs
@@ -487,6 +809,7 @@ public:
{%- else %}
    void pre_run_hook() {
{%- endif %}
+        if(!initialized){
        // initialization of ion channels
{%- with %}
{%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %}
@@ -496,8 +819,13 @@ public:
        {%- for concentration_name, concentration_info in conc_info.items() %}
        {{concentration_name}}{{concentration_suffix}}.calibrate();
        {% endfor -%}
+        {%- for receptor_name, receptor_info in recs_info.items() %}
+        {{receptor_name}}{{receptor_suffix}}.calibrate();
+        {% endfor -%}
+        {%- for receptor_name, receptor_info in recs_info.items() %}
        {%- for synapse_name, synapse_info in syns_info.items() %}
-        {{synapse_name}}{{synapse_suffix}}.calibrate();
+        {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.calibrate();
+        {% endfor -%}
        {% endfor -%}
        {%- for continuous_name, continuous_info in con_in_info.items() %}
        {{continuous_name}}{{continuous_suffix}}.calibrate();
@@ -509,13 +837,19 @@ public:
        {%- for concentration_name, concentration_info in conc_info.items() %}
        {{concentration_name}}{{concentration_suffix}}.pre_run_hook();
        {% endfor -%}
+        {%- for receptor_name, receptor_info in recs_info.items() %}
+        {{receptor_name}}{{receptor_suffix}}.pre_run_hook();
+        {% endfor -%}
+        {%- for receptor_name, receptor_info in recs_info.items() %}
        {%- for synapse_name, synapse_info in syns_info.items() %}
-        {{synapse_name}}{{synapse_suffix}}.pre_run_hook();
+        {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.pre_run_hook();
+        {% endfor -%}
        {% endfor -%}
        {%- for continuous_name, continuous_info in con_in_info.items() %}
        {{continuous_name}}{{continuous_suffix}}.pre_run_hook();
        {% endfor -%}
{%- endif %}
+        int con_end_index;
        {%- for ion_channel_name, channel_info in chan_info.items() %}
        if({{ion_channel_name}}{{channel_suffix}}.neuron_{{ ion_channel_name }}_channel_count){
@@ -541,18 +875,32 @@ public:
            }
        }
        {% endfor -%}
+        {%- for receptor_name, receptor_info in recs_info.items() %}
+        if({{receptor_name}}{{receptor_suffix}}.neuron_{{ receptor_name }}_receptor_count){
+            con_end_index = int({{receptor_name}}{{receptor_suffix}}.compartment_association[0]);
+            {{receptor_name}}{{receptor_suffix}}_con_area.push_back(std::pair< std::size_t, int >(0, con_end_index));
+        }
+        for(std::size_t syn_id = 0; syn_id < {{receptor_name}}{{receptor_suffix}}.neuron_{{ receptor_name }}_receptor_count; syn_id++){
+            if(!({{receptor_name}}{{receptor_suffix}}.compartment_association[syn_id] == size_t(int(syn_id) + con_end_index))){
+                con_end_index = int({{receptor_name}}{{receptor_suffix}}.compartment_association[syn_id]) - int(syn_id);
+                {{receptor_name}}{{receptor_suffix}}_con_area.push_back(std::pair< std::size_t, int >(syn_id, con_end_index));
+            }
+        }
+        {% endfor -%}
+        {%- for receptor_name, receptor_info in recs_info.items() %}
        {%- for synapse_name, synapse_info in syns_info.items() %}
-        if({{synapse_name}}{{synapse_suffix}}.neuron_{{ synapse_name }}_synapse_count){
-            con_end_index = int({{synapse_name}}{{synapse_suffix}}.compartment_association[0]);
-            {{synapse_name}}{{synapse_suffix}}_con_area.push_back(std::pair< std::size_t, int >(0, con_end_index));
-        }
-        for(std::size_t syn_id = 0; syn_id < {{synapse_name}}{{synapse_suffix}}.neuron_{{ synapse_name }}_synapse_count; syn_id++){
-            if(!({{synapse_name}}{{synapse_suffix}}.compartment_association[syn_id] == size_t(int(syn_id) + con_end_index))){
-                con_end_index = int({{synapse_name}}{{synapse_suffix}}.compartment_association[syn_id]) - int(syn_id);
-                {{synapse_name}}{{synapse_suffix}}_con_area.push_back(std::pair< std::size_t, int >(syn_id, con_end_index));
+        if({{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.neuron_{{ receptor_name }}_receptor_count){
+            con_end_index = int({{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.compartment_association[0]);
+            {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}_con_area.push_back(std::pair< std::size_t, int >(0, con_end_index));
+        }
+        for(std::size_t syn_id = 0; syn_id < {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.neuron_{{ receptor_name }}_receptor_count; syn_id++){
+            if(!({{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.compartment_association[syn_id] == size_t(int(syn_id) + con_end_index))){
+                con_end_index = int({{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.compartment_association[syn_id]) - int(syn_id);
+                {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}_con_area.push_back(std::pair< std::size_t, int >(syn_id, con_end_index));
            }
        }
        {% endfor -%}
+        {% endfor -%}
        {%- for continuous_name, continuous_info in con_in_info.items() %}
        if({{continuous_name}}{{continuous_suffix}}.neuron_{{ continuous_name }}_continuous_input_count){
            con_end_index = int({{continuous_name}}{{continuous_suffix}}.compartment_association[0]);
@@ -566,6 +914,8 @@ public:
        }
        {% endfor -%}
{% endwith -%}
+        initialized = true;
+        }
    };

    void add_mechanism( const std::string& type, const std::size_t compartment_id, const long multi_mech_index = 0)
@@ -588,13 +938,23 @@ public:
        }
        {% endfor -%}

+        {%- for receptor_name, receptor_info in recs_info.items() %}
+        if ( type == "{{receptor_name}}" )
+        {
+            {{receptor_name}}{{receptor_suffix}}.new_receptor(compartment_id, multi_mech_index);
+            mech_found = true;
+        }
+        {% endfor -%}
+
+        {%- for receptor_name, receptor_info in recs_info.items() %}
        {%- for synapse_name, synapse_info in syns_info.items() %}
-        if ( type == "{{synapse_name}}" )
+        if ( type == "{{receptor_name}}_{{synapse_name}}" )
        {
-            {{synapse_name}}{{synapse_suffix}}.new_synapse(compartment_id, multi_mech_index);
+            {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.new_receptor(compartment_id, multi_mech_index);
            mech_found = true;
        }
        {% endfor -%}
+        {% endfor -%}

        {%- for continuous_name, continuous_info in con_in_info.items() %}
        if ( type == "{{continuous_name}}" )
@@ -607,6 +967,7 @@ public:
{% endwith -%}
        if(!mech_found)
        {
+            throw BadProperty( type + " mechanism does not exist." );
            assert( false );
        }
    };
@@ -631,13 +992,23 @@ public:
        }
        {% endfor -%}

+        {%- for receptor_name, receptor_info in recs_info.items() %}
+        if ( type == "{{receptor_name}}" )
+        {
+            {{receptor_name}}{{receptor_suffix}}.new_receptor(compartment_id, multi_mech_index, mechanism_params);
+            mech_found = true;
+        }
+        {% endfor -%}
+
+        {%- for receptor_name, receptor_info in recs_info.items() %}
        {%- for synapse_name, synapse_info in syns_info.items() %}
-        if ( type == "{{synapse_name}}" )
+        if ( type == "{{receptor_name}}_{{synapse_name}}" )
        {
-            {{synapse_name}}{{synapse_suffix}}.new_synapse(compartment_id, multi_mech_index, mechanism_params);
+            {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.new_receptor(compartment_id, multi_mech_index, mechanism_params);
            mech_found = true;
        }
        {% endfor -%}
+        {% endfor -%}

        {%- for continuous_name, continuous_info in con_in_info.items() %}
        if ( type == "{{continuous_name}}" )
@@ -649,6 +1020,7 @@ public:
{% endwith -%}
        if(!mech_found)
        {
+            throw BadProperty( type + " mechanism does not exist." );
            assert( false );
        }
    };
@@ -672,8 +1044,14 @@ public:
        this->{{concentration_name}}{{concentration_suffix}}_shared_concentration.push_back(0.0);
        {% endfor -%}

+        {%- for receptor_name, receptor_info in recs_info.items() %}
+        this->{{receptor_name}}{{receptor_suffix}}_shared_current.push_back(0.0);
+        {% endfor -%}
+
+        {%- for receptor_name, receptor_info in recs_info.items() %}
        {%- for synapse_name, synapse_info in syns_info.items() %}
-        this->{{synapse_name}}{{synapse_suffix}}_shared_current.push_back(0.0);
+        this->{{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}_shared_current.push_back(0.0);
+        {% endfor -%}
        {% endfor -%}

        {%- for continuous_name, continuous_info in con_in_info.items() %}
@@ -701,8 +1079,14 @@ public:
        this->{{concentration_name}}{{concentration_suffix}}_shared_concentration.push_back(0.0);
        {% endfor -%}

+        {%- for receptor_name, receptor_info in recs_info.items() %}
+        this->{{receptor_name}}{{receptor_suffix}}_shared_current.push_back(0.0);
+        {% endfor -%}
+
+        {%- for receptor_name, receptor_info in recs_info.items() %}
        {%- for synapse_name, synapse_info in syns_info.items() %}
-        this->{{synapse_name}}{{synapse_suffix}}_shared_current.push_back(0.0);
+        this->{{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}_shared_current.push_back(0.0);
+        {% endfor -%}
        {% endfor -%}

        {%- for continuous_name, continuous_info in con_in_info.items() %}
@@ -713,16 +1097,29 @@ public:
    void add_receptor_info( ArrayDatum& ad, long compartment_index )
    {
{%- with %}
+        {%- for receptor_name, receptor_info in recs_info.items() %}
+        for( std::size_t syn_it = 0; syn_it != {{receptor_name}}{{receptor_suffix}}.neuron_{{receptor_name}}_receptor_count; syn_it++)
+        {
+            DictionaryDatum dd = DictionaryDatum( new Dictionary );
+            def< long >( dd, names::receptor_idx, syn_it );
+            def< long >( dd, names::comp_idx, compartment_index );
+            def< std::string >( dd, names::receptor_type, "{{receptor_name}}" );
+            ad.push_back( dd );
+        }
+        {% endfor -%}
+
+        {%- for receptor_name, receptor_info in recs_info.items() %}
        {%- for synapse_name, synapse_info in syns_info.items() %}
-        for( std::size_t syn_it = 0; syn_it != {{synapse_name}}{{synapse_suffix}}.neuron_{{synapse_name}}_synapse_count; syn_it++)
+        for( std::size_t syn_it = 0; syn_it != {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.neuron_{{receptor_name}}_receptor_count; syn_it++)
        {
            DictionaryDatum dd = DictionaryDatum( new Dictionary );
            def< long >( dd, names::receptor_idx, syn_it );
            def< long >( dd, names::comp_idx, compartment_index );
-            def< std::string >( dd, names::receptor_type, "{{synapse_name}}" );
+            def< std::string >( dd, names::receptor_type, "{{receptor_name}}_{{synapse_name}}" );
            ad.push_back( dd );
        }
        {% endfor -%}
+        {% endfor -%}

        {%- for continuous_name, continuous_info in con_in_info.items() %}
        for( std::size_t con_it = 0; con_it != {{continuous_name}}{{continuous_suffix}}.neuron_{{continuous_name}}_continuous_input_count; con_it++)
@@ -740,11 +1137,16 @@ public:

    void set_buffers( std::vector< RingBuffer >& buffers)
    {
-        // spike and continuous buffers for synapses and continuous inputs
+        // spike and continuous buffers for receptors and continuous inputs
{%- with %}
+        {%- for receptor_name, receptor_info in recs_info.items() %}
+        {{receptor_name}}{{ receptor_suffix }}.set_buffer_ptr( buffers );
+        {% endfor -%}
+        {%- for receptor_name, receptor_info in recs_info.items() %}
        {%- for synapse_name, synapse_info in syns_info.items() %}
-        {{synapse_name}}{{ synapse_suffix }}.set_buffer_ptr( buffers );
+        {{receptor_name}}{{ receptor_suffix }}_con_{{synapse_name}}.set_buffer_ptr( buffers );
+        {% endfor -%}
        {% endfor -%}
        {%- for continuous_name, continuous_info in con_in_info.items() %}
        {{continuous_name}}{{ continuous_suffix }}.set_buffer_ptr( buffers );
@@ -771,10 +1173,19 @@ public:
        {% endfor %}
{% endwith %}

-        // append synapse state variables to recordables
+        // append receptor state variables to recordables
{%- with %}
+        {%- for receptor_name, receptor_info in recs_info.items() %}
+        {{receptor_name}}{{receptor_suffix}}.append_recordables( &recordables, compartment_idx );
+        {% endfor %}
+{% endwith %}
+
+        // append receptor with synapse state variables to recordables
+{%- with %}
+        {%- for receptor_name, receptor_info in recs_info.items() %}
        {%- for synapse_name, synapse_info in syns_info.items() %}
-        {{synapse_name}}{{synapse_suffix}}.append_recordables( &recordables, compartment_idx );
+        {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.append_recordables( &recordables, compartment_idx );
+        {% endfor %}
        {% endfor %}
{% endwith %}

@@ -791,8 +1202,16 @@ public:
    std::vector< std::pair< double, double > > f_numstep( std::vector< double > v_comp_vec, const long lag )
    {
        std::vector< std::pair< double, double > > comp_to_gi(compartment_number, std::make_pair(0., 0.));
-{%- for synapse_name, synapse_info in syns_info.items() %}
-        {{synapse_name}}{{synapse_suffix}}.get_currents_per_compartment({{synapse_name}}{{synapse_suffix}}_shared_current);
+{%- for receptor_name, receptor_info in recs_info.items() %}
+        {{receptor_name}}{{receptor_suffix}}.get_currents_per_compartment({{receptor_name}}{{receptor_suffix}}_shared_current);
+{% endfor %}
+{%- for receptor_name, receptor_info in recs_info.items() %}
+        {%- for synapse_name, synapse_info in syns_info.items() %}
+        {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.get_currents_per_compartment({{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}_shared_current);
+        for(size_t i = 0; i < {{receptor_name}}{{receptor_suffix}}_shared_current.size(); i++){
+            {{receptor_name}}{{receptor_suffix}}_shared_current[i] += {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}_shared_current[i];
+        }
+        {% endfor %}
{% endfor %}
{%- for continuous_name, continuous_info in con_in_info.items() %}
        {{continuous_name}}{{continuous_suffix}}.get_currents_per_compartment({{continuous_name}}{{continuous_suffix}}_shared_current);
@@ -804,11 +1223,12 @@ public:
        {{ion_channel_name}}{{channel_suffix}}.get_currents_per_compartment({{ion_channel_name}}{{channel_suffix}}_shared_current);
        {% endfor -%}

+
{%- with %}
        {%- for concentration_name, concentration_info in conc_info.items() %}
        // computation of {{ concentration_name }} concentration
-        {{ concentration_name }}{{concentration_suffix}}.f_numstep( {{ concentration_name }}{{concentration_suffix}}.distribute_shared_vector(v_comp_vec){% for ode in concentration_info["Dependencies"]["concentrations"] %}, {{ concentration_name }}{{concentration_suffix}}.distribute_shared_vector({{ode.lhs.name}}{{concentration_suffix}}_shared_concentration){% endfor %}{% if concentration_info["Dependencies"]["receptors"]|length %}
-        {% endif %}{% for inline in concentration_info["Dependencies"]["receptors"] %}, {{ concentration_name }}{{concentration_suffix}}.distribute_shared_vector({{inline.variable_name}}{{synapse_suffix}}_shared_current){% endfor %}{% if concentration_info["Dependencies"]["channels"]|length %}
+        {{ concentration_name }}{{concentration_suffix}}.f_numstep({{ concentration_name }}{{concentration_suffix}}.distribute_shared_vector(v_comp_vec){% for ode in concentration_info["Dependencies"]["concentrations"] %}, {{ concentration_name }}{{concentration_suffix}}.distribute_shared_vector({{ode.lhs.name}}{{concentration_suffix}}_shared_concentration){% endfor %}{% if concentration_info["Dependencies"]["receptors"]|length %}
+        {% endif %}{% for inline in concentration_info["Dependencies"]["receptors"] %}, {{ concentration_name }}{{concentration_suffix}}.distribute_shared_vector({{inline.variable_name}}{{receptor_suffix}}_shared_current){% endfor %}{% if concentration_info["Dependencies"]["channels"]|length %}
        {% endif %}{% for inline in concentration_info["Dependencies"]["channels"] %}, {{ concentration_name }}{{concentration_suffix}}.distribute_shared_vector({{inline.variable_name}}{{channel_suffix}}_shared_current){% endfor %}{% if concentration_info["Dependencies"]["continuous"]|length %}
        {% endif %}{% for inline in concentration_info["Dependencies"]["continuous"] %}, {{ concentration_name }}{{concentration_suffix}}.distribute_shared_vector({{inline.variable_name}}{{continuous_suffix}}_shared_current){% endfor %});
@@ -821,8 +1241,8 @@ public:
{%- with %}
        {%- for ion_channel_name, channel_info in chan_info.items() %}
        // contribution of {{ion_channel_name}} channel
-        gi_mech = {{ion_channel_name}}{{channel_suffix}}.f_numstep( {{ion_channel_name}}{{channel_suffix}}.distribute_shared_vector(v_comp_vec){% for ode in channel_info["Dependencies"]["concentrations"] %}, {{ion_channel_name}}{{channel_suffix}}.distribute_shared_vector({{ode.lhs.name}}{{concentration_suffix}}_shared_concentration){% endfor %}{% if channel_info["Dependencies"]["receptors"]|length %}
-        {% endif %}{% for inline in channel_info["Dependencies"]["receptors"] %}, {{ion_channel_name}}{{channel_suffix}}.distribute_shared_vector({{inline.variable_name}}{{synapse_suffix}}_shared_current){% endfor %}{% if channel_info["Dependencies"]["channels"]|length %}
+        gi_mech = {{ion_channel_name}}{{channel_suffix}}.f_numstep({{ion_channel_name}}{{channel_suffix}}.distribute_shared_vector(v_comp_vec){% for ode in channel_info["Dependencies"]["concentrations"] %}, {{ion_channel_name}}{{channel_suffix}}.distribute_shared_vector({{ode.lhs.name}}{{concentration_suffix}}_shared_concentration){% endfor %}{% if channel_info["Dependencies"]["receptors"]|length %}
+        {% endif %}{% for inline in channel_info["Dependencies"]["receptors"] %}, {{ion_channel_name}}{{channel_suffix}}.distribute_shared_vector({{inline.variable_name}}{{receptor_suffix}}_shared_current){% endfor %}{% if channel_info["Dependencies"]["channels"]|length %}
        {% endif %}{% for inline in channel_info["Dependencies"]["channels"] %}, {{ion_channel_name}}{{channel_suffix}}.distribute_shared_vector({{inline.variable_name}}{{channel_suffix}}_shared_current){% endfor %}{% if channel_info["Dependencies"]["continuous"]|length %}
        {% endif %}{% for inline in channel_info["Dependencies"]["continuous"] %}, {{ion_channel_name}}{{channel_suffix}}.distribute_shared_vector({{inline.variable_name}}{{continuous_suffix}}_shared_current){% endfor %});
@@ -853,19 +1273,59 @@ public:
{% endwith -%}

{%- with %}
-        {%- for synapse_name, synapse_info in syns_info.items() %}
-        // contribution of {{synapse_name}} synapses
-        gi_mech = {{synapse_name}}{{synapse_suffix}}.f_numstep( {{synapse_name}}{{synapse_suffix}}.distribute_shared_vector(v_comp_vec), lag {% for ode in synapse_info["Dependencies"]["concentrations"] %}, {{synapse_name}}{{synapse_suffix}}.distribute_shared_vector({{ode.lhs.name}}{{concentration_suffix}}_shared_concentration){% endfor %}{% if synapse_info["Dependencies"]["receptors"]|length %}
-        {% endif %}{% for inline in synapse_info["Dependencies"]["receptors"] %}, {{synapse_name}}{{synapse_suffix}}.distribute_shared_vector({{inline.variable_name}}{{synapse_suffix}}_shared_current){% endfor %}{% if synapse_info["Dependencies"]["channels"]|length %}
-        {% endif %}{% for inline in synapse_info["Dependencies"]["channels"] %}, {{synapse_name}}{{synapse_suffix}}.distribute_shared_vector({{inline.variable_name}}{{channel_suffix}}_shared_current){% endfor %}{% if synapse_info["Dependencies"]["continuous"]|length %}
-        {% endif %}{% for inline in synapse_info["Dependencies"]["continuous"] %}, {{synapse_name}}{{synapse_suffix}}.distribute_shared_vector({{inline.variable_name}}{{continuous_suffix}}_shared_current){% endfor %});
+        {%- for receptor_name, receptor_info in recs_info.items() %}
+        // contribution of {{receptor_name}} receptors
+        gi_mech = {{receptor_name}}{{receptor_suffix}}.f_numstep({{receptor_name}}{{receptor_suffix}}.distribute_shared_vector(v_comp_vec), lag {% for ode in receptor_info["Dependencies"]["concentrations"] %}, {{receptor_name}}{{receptor_suffix}}.distribute_shared_vector({{ode.lhs.name}}{{concentration_suffix}}_shared_concentration){% endfor %}{% if receptor_info["Dependencies"]["receptors"]|length %}
+        {% endif %}{% for inline in receptor_info["Dependencies"]["receptors"] %}, {{receptor_name}}{{receptor_suffix}}.distribute_shared_vector({{inline.variable_name}}{{receptor_suffix}}_shared_current){% endfor %}{% if receptor_info["Dependencies"]["channels"]|length %}
+        {% endif %}{% for inline in receptor_info["Dependencies"]["channels"] %}, {{receptor_name}}{{receptor_suffix}}.distribute_shared_vector({{inline.variable_name}}{{channel_suffix}}_shared_current){% endfor %}{% if receptor_info["Dependencies"]["continuous"]|length %}
+        {% endif %}{% for inline in receptor_info["Dependencies"]["continuous"] %}, {{receptor_name}}{{receptor_suffix}}.distribute_shared_vector({{inline.variable_name}}{{continuous_suffix}}_shared_current){% endfor %});
+
+        con_area_count = {{receptor_name}}{{receptor_suffix}}_con_area.size();
+        if(con_area_count > 0){
+            for(std::size_t con_area_index = 0; con_area_index < con_area_count-1; con_area_index++){
+                std::size_t con_area = {{receptor_name}}{{receptor_suffix}}_con_area[con_area_index].first;
+                std::size_t next_con_area = {{receptor_name}}{{receptor_suffix}}_con_area[con_area_index+1].first;
+                int offset = {{receptor_name}}{{receptor_suffix}}_con_area[con_area_index].second;
+
+                #pragma omp simd
+                for(std::size_t syn_id = con_area; syn_id < next_con_area; syn_id++){
+                    comp_to_gi[syn_id+offset].first += gi_mech.first[syn_id];
+                    comp_to_gi[syn_id+offset].second += gi_mech.second[syn_id];
+                }
+            }
+
+            std::size_t con_area = {{receptor_name}}{{receptor_suffix}}_con_area[con_area_count-1].first;
+            int offset = {{receptor_name}}{{receptor_suffix}}_con_area[con_area_count-1].second;
+
+            #pragma omp simd
+            for(std::size_t syn_id = con_area; syn_id < {{receptor_name}}{{receptor_suffix}}.neuron_{{ receptor_name }}_receptor_count; syn_id++){
+                comp_to_gi[syn_id+offset].first += gi_mech.first[syn_id];
+                comp_to_gi[syn_id+offset].second += gi_mech.second[syn_id];
+            }
+        }
+        {% endfor -%}
+{% endwith -%}
-        con_area_count = {{synapse_name}}{{synapse_suffix}}_con_area.size();
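+        // The _con_area pairs built in pre_run_hook() mark contiguous runs of
+        // instance-to-compartment associations (start index, offset): within a
+        // run, compartment index == instance index + offset, which is what
+        // makes the plain #pragma omp simd accumulation loops here valid.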
+{%- with %}
+        {%- for receptor_name, receptor_info in recs_info.items() %}
+        {%- for synapse_name, synapse_info in syns_info.items() %}
+        // contribution of {{receptor_name}}_{{synapse_name}} receptors
+        {%- with %}
+        {%- set conc_dep = set(receptor_info["Dependencies"]["concentrations"]).union(synapse_info["Dependencies"]["concentrations"])%}
+        {%- set rec_dep = set(receptor_info["Dependencies"]["receptors"]).union(synapse_info["Dependencies"]["receptors"])%}
+        {%- set chan_dep = set(receptor_info["Dependencies"]["channels"]).union(synapse_info["Dependencies"]["channels"])%}
+        {%- set con_in_dep = set(receptor_info["Dependencies"]["continuous"]).union(synapse_info["Dependencies"]["continuous"])%}
+        gi_mech = {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.f_numstep({{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.distribute_shared_vector(v_comp_vec), lag {% for ode in conc_dep %}, {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.distribute_shared_vector({{ode.lhs.name}}{{concentration_suffix}}_shared_concentration){% endfor %}{% if rec_dep|length %}
+        {% endif %}{% for inline in rec_dep %}, {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.distribute_shared_vector({{inline.variable_name}}{{receptor_suffix}}_shared_current){% endfor %}{% if chan_dep|length %}
+        {% endif %}{% for inline in chan_dep %}, {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.distribute_shared_vector({{inline.variable_name}}{{channel_suffix}}_shared_current){% endfor %}{% if con_in_dep|length %}
+        {% endif %}{% for inline in con_in_dep %}, {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.distribute_shared_vector({{inline.variable_name}}{{continuous_suffix}}_shared_current){% endfor %});
+        {%- endwith %}
+        con_area_count = {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}_con_area.size();
        if(con_area_count > 0){
            for(std::size_t con_area_index = 0; con_area_index < con_area_count-1; con_area_index++){
-                std::size_t con_area = {{synapse_name}}{{synapse_suffix}}_con_area[con_area_index].first;
-                std::size_t next_con_area = {{synapse_name}}{{synapse_suffix}}_con_area[con_area_index+1].first;
-                int offset = {{synapse_name}}{{synapse_suffix}}_con_area[con_area_index].second;
+                std::size_t con_area = {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}_con_area[con_area_index].first;
+                std::size_t next_con_area = {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}_con_area[con_area_index+1].first;
+                int offset = {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}_con_area[con_area_index].second;

                #pragma omp simd
                for(std::size_t syn_id = con_area; syn_id < next_con_area; syn_id++){
@@ -874,23 +1334,24 @@ public:
                }
            }

-            std::size_t con_area = {{synapse_name}}{{synapse_suffix}}_con_area[con_area_count-1].first;
-            int offset = {{synapse_name}}{{synapse_suffix}}_con_area[con_area_count-1].second;
+            std::size_t con_area = {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}_con_area[con_area_count-1].first;
+            int offset = {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}_con_area[con_area_count-1].second;

            #pragma omp simd
-            for(std::size_t syn_id = con_area; syn_id < {{synapse_name}}{{synapse_suffix}}.neuron_{{ synapse_name }}_synapse_count; syn_id++){
+            for(std::size_t syn_id = con_area; syn_id < {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.neuron_{{ receptor_name }}_receptor_count; syn_id++){
                comp_to_gi[syn_id+offset].first += gi_mech.first[syn_id];
                comp_to_gi[syn_id+offset].second += gi_mech.second[syn_id];
            }
        }
        {% endfor -%}
+        {% endfor -%}
{% endwith -%}

{%- with %}
        {%- for continuous_name, continuous_info in con_in_info.items() %}
        // contribution of {{continuous_name}} continuous inputs
-        gi_mech = {{continuous_name}}{{continuous_suffix}}.f_numstep( {{continuous_name}}{{continuous_suffix}}.distribute_shared_vector(v_comp_vec), lag {% for ode in continuous_info["Dependencies"]["concentrations"] %}, {{continuous_name}}{{continuous_suffix}}.distribute_shared_vector({{ode.lhs.name}}{{concentration_suffix}}_shared_concentration){% endfor %}{% if continuous_info["Dependencies"]["receptors"]|length %}
-        {% endif %}{% for inline in continuous_info["Dependencies"]["receptors"] %}, {{continuous_name}}{{continuous_suffix}}.distribute_shared_vector({{inline.variable_name}}{{synapse_suffix}}_shared_current){% endfor %}{% if continuous_info["Dependencies"]["channels"]|length %}
+        gi_mech = {{continuous_name}}{{continuous_suffix}}.f_numstep({{continuous_name}}{{continuous_suffix}}.distribute_shared_vector(v_comp_vec), lag {% for ode in continuous_info["Dependencies"]["concentrations"] %}, {{continuous_name}}{{continuous_suffix}}.distribute_shared_vector({{ode.lhs.name}}{{concentration_suffix}}_shared_concentration){% endfor %}{% if continuous_info["Dependencies"]["receptors"]|length %}
+        {% endif %}{% for inline in continuous_info["Dependencies"]["receptors"] %}, {{continuous_name}}{{continuous_suffix}}.distribute_shared_vector({{inline.variable_name}}{{receptor_suffix}}_shared_current){% endfor %}{% if continuous_info["Dependencies"]["channels"]|length %}
        {% endif %}{% for inline in continuous_info["Dependencies"]["channels"] %}, {{continuous_name}}{{continuous_suffix}}.distribute_shared_vector({{inline.variable_name}}{{channel_suffix}}_shared_current){% endfor %}{% if continuous_info["Dependencies"]["continuous"]|length %}
        {% endif %}{% for inline in continuous_info["Dependencies"]["continuous"] %}, {{continuous_name}}{{continuous_suffix}}.distribute_shared_vector({{inline.variable_name}}{{continuous_suffix}}_shared_current){% endfor %});
@@ -922,8 +1383,27 @@ public:

        return comp_to_gi;
    };
+
+    void postsynaptic_synaptic_processing(){
+        {%- for receptor_name, receptor_info in recs_info.items() %}
+        {%- for synapse_name, synapse_info in syns_info.items() %}
+        {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.postsynaptic_synaptic_processing();
+        {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.f_self_spike();
+        {% endfor -%}
+        {{receptor_name}}{{receptor_suffix}}.f_self_spike();
+        {% endfor -%}
+        {%- for ion_channel_name, channel_info in chan_info.items() %}
+        {{ion_channel_name}}{{channel_suffix}}.f_self_spike();
+        {% endfor -%}
+        {%- for concentration_name, concentration_info in conc_info.items() %}
+        {{concentration_name}}{{concentration_suffix}}.f_self_spike();
+        {% endfor -%}
+        {%- for continuous_name, continuous_info in con_in_info.items() %}
+        {{continuous_name}}{{continuous_suffix}}.f_self_spike();
+        {% endfor -%}
+    };
};

} // namespace

-#endif /* #ifndef SYNAPSES_NEAT_H_{{cm_unique_suffix | upper }} */
+#endif /* #ifndef RECEPTORS_NEAT_H_{{cm_unique_suffix | upper }} */
diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_tree_@NEURON_NAME@.cpp.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_tree_@NEURON_NAME@.cpp.jinja2
index a70f43213..19d97ebc8 100644
--- a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_tree_@NEURON_NAME@.cpp.jinja2
+++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_tree_@NEURON_NAME@.cpp.jinja2
@@ -213,18 +213,18 @@ nest::CompTree{{cm_unique_suffix}}::add_compartment( const long parent_index, co
     if( comp_param_copy->known( "{{variable_type}}" ) ) comp_param_copy->remove("{{variable_type}}");
     {%- endfor %}
{%- endfor %}
-{%- for synapse_name, synapse_info in syns_info.items() %}
-    {%- for variable_type, variable_info in synapse_info["Parameters"].items() %}
+{%- for receptor_name, receptor_info in recs_info.items() %}
+    {%- for variable_type, variable_info in receptor_info["Parameters"].items() %}
     {%- set variable = variable_info["ASTVariable"] %}
     if( comp_param_copy->known( "{{variable.name}}" ) ) comp_param_copy->remove("{{variable.name}}");
     {%- endfor %}
{%- endfor %}
-{%- for synapse_name, synapse_info in syns_info.items() %}
-    {%- for variable_type, variable_info in synapse_info["States"].items() %}
+{%- for receptor_name, receptor_info in recs_info.items() %}
+    {%- for variable_type, variable_info in receptor_info["States"].items() %}
     {%- set variable = variable_info["ASTVariable"] %}
     if( comp_param_copy->known( "{{variable.name}}" ) ) comp_param_copy->remove("{{variable.name}}");
     {%- endfor %}
-    {%- for variable_type, variable_info in synapse_info["ODEs"].items() %}
+    {%- for variable_type, variable_info in receptor_info["ODEs"].items() %}
     if( comp_param_copy->known( "{{variable_type}}" ) ) comp_param_copy->remove("{{variable_type}}");
     {%- endfor %}
{%- endfor %}
@@ -243,6 +243,18 @@ nest::CompTree{{cm_unique_suffix}}::add_compartment( const long parent_index, co
     if( comp_param_copy->known( "{{variable.name}}" ) ) comp_param_copy->remove("{{variable.name}}");
     {%- endfor %}
{%- endfor %}
+//global vars
+    {%- for variable_type, variable_info in global_info["Parameters"].items() %}
+    {%- set variable = variable_info["ASTVariable"] %}
+    if( comp_param_copy->known( "{{variable.name}}" ) ) comp_param_copy->remove("{{variable.name}}");
+    {%- endfor %}
+    {%- for variable_type, variable_info in global_info["ODEs"].items() %}
+    if( comp_param_copy->known( "{{variable_type}}" ) ) comp_param_copy->remove("{{variable_type}}");
+    {%- endfor %}
+    {%- for variable_type, variable_info in global_info["States"].items() %}
+    {%- set variable = variable_info["ASTVariable"] %}
+    if( comp_param_copy->known( "{{variable.name}}" ) ) comp_param_copy->remove("{{variable.name}}");
+    {%- endfor %}

  if(!comp_param_copy->empty()){
    std::string msg = "Following parameters are invalid: ";
@@ -440,7 +452,7 @@ nest::CompTree{{cm_unique_suffix}}::set_leafs()
}

/**
- * Initializes pointers for the spike buffers for all synapse receptors
+ * Initializes pointers for the spike buffers for all receptors
 */
void
nest::CompTree{{cm_unique_suffix}}::set_syn_buffers( std::vector< RingBuffer >& syn_buffers )
diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/setup/common/ModuleClassMaster.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/setup/common/ModuleClassMaster.jinja2
index 51c823e81..474fe70a7 100644
--- a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/setup/common/ModuleClassMaster.jinja2
+++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/setup/common/ModuleClassMaster.jinja2
@@ -51,6 +51,9 @@
{% for neuron in neurons %}
#include "{{perNeuronFileNamesCm[neuron.get_name()]["main"]}}.h"
{% endfor %}
+{% for synapse in synapses %}
+#include "{{perSynapseFileNamesCm[synapse.get_name()]["main"]}}.h"
+{% endfor %}

class {{moduleName}} : public nest::NESTExtensionInterface
{
@@ -68,4 +71,7 @@ void {{moduleName}}::initialize()
{%- for neuron in neurons %}
  nest::register_{{perNeuronFileNamesCm[neuron.get_name()]["main"]}}("{{perNeuronFileNamesCm[neuron.get_name()]["main"]}}");
{%- endfor %}
-}
\ No newline at end of file
+{%- for synapse in synapses %}
+  nest::register_{{perSynapseFileNamesCm[synapse.get_name()]["main"]}}("{{perSynapseFileNamesCm[synapse.get_name()]["main"]}}");
+{%- endfor %}
+}
diff --git a/pynestml/frontend/pynestml_frontend.py b/pynestml/frontend/pynestml_frontend.py
index 89e2bb7db..c852dea2d 100644
--- a/pynestml/frontend/pynestml_frontend.py
+++ b/pynestml/frontend/pynestml_frontend.py
@@ -84,7 +84,7 @@ def transformers_from_target_name(target_name: str, options: Optional[Mapping[st
         options = synapse_post_neuron_co_generation.set_options(options)
         transformers.append(synapse_post_neuron_co_generation)

-    if target_name.upper() == "NEST":
+    if target_name.upper() in ["NEST"]:
         from pynestml.transformers.synapse_post_neuron_transformer import SynapsePostNeuronTransformer

         # co-generate neuron and synapse
@@ -497,8 +497,11 @@ def process() -> bool:
     # validation -- check cocos for models that do not have errors already
     excluded_models = []
     for model in models:
+        syn_model = False
+        if "neuron_synapse_pairs" in FrontendConfiguration.get_codegen_opts():
+            syn_model = model.name in [(pair["synapse"] + "_nestml") for pair in FrontendConfiguration.get_codegen_opts()["neuron_synapse_pairs"]]
         if not Logger.has_errors(model.name):
-            CoCosManager.check_cocos(model)
+            CoCosManager.check_cocos(model, syn_model=syn_model)

         if Logger.has_errors(model.name):
             code, message = Messages.get_model_contains_errors(model.get_name())
diff --git a/pynestml/symbols/predefined_functions.py b/pynestml/symbols/predefined_functions.py
index efdf483a3..993889f2d 100644
--- a/pynestml/symbols/predefined_functions.py
+++ b/pynestml/symbols/predefined_functions.py
@@ -30,6 +30,7 @@ class PredefinedFunctions:
     This class is used to represent all predefined functions of NESTML.
     """

+    HEAVISIDE = "Heaviside"
     TIME_RESOLUTION = "resolution"
     TIME_TIMESTEP = "timestep"
     TIME_STEPS = "steps"
@@ -114,6 +115,18 @@ def register_function(cls, name, params, return_type, element_reference):
                                  element_reference=element_reference, is_predefined=True)
         cls.name2function[name] = symbol

+    @classmethod
+    def __register_heaviside_function(cls):
+        """
+        Registers the Heaviside step function, which returns 0 for negative input and 1 otherwise.
+        """
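+        # Hypothetical NESTML-side usage (illustrative only, not from this patch):
+        #     gate real = Heaviside(v_comp - 30.)   # 1.0 once v_comp - 30. >= 0, else 0.0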
+ """ + params = list() + params.append(PredefinedTypes.get_real_type()) + symbol = FunctionSymbol(name=cls.HEAVISIDE, param_types=params, + return_type=PredefinedTypes.get_real_type(), + element_reference=None, is_predefined=True) + cls.name2function[cls.HEAVISIDE] = symbol + @classmethod def __register_time_steps_function(cls): """ diff --git a/pynestml/transformers/inline_expression_expansion_transformer.py b/pynestml/transformers/inline_expression_expansion_transformer.py index 065c03b8f..ba2497850 100644 --- a/pynestml/transformers/inline_expression_expansion_transformer.py +++ b/pynestml/transformers/inline_expression_expansion_transformer.py @@ -82,23 +82,24 @@ def make_inline_expressions_self_contained(self, inline_expressions: List[ASTInl from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor for source in inline_expressions: - source_position = source.get_source_position() - for target in inline_expressions: - matcher = re.compile(self._variable_matching_template.format(source.get_variable_name())) - target_definition = str(target.get_expression()) - target_definition = re.sub(matcher, "(" + str(source.get_expression()) + ")", target_definition) - old_parent = target.expression.parent_ - target.expression = ModelParser.parse_expression(target_definition) - target.expression.update_scope(source.get_scope()) - target.expression.parent_ = old_parent - target.expression.accept(ASTParentVisitor()) - target.expression.accept(ASTSymbolTableVisitor()) - - def log_set_source_position(node): - if node.get_source_position().is_added_source_position(): - node.set_source_position(source_position) - - target.expression.accept(ASTHigherOrderVisitor(visit_funcs=log_set_source_position)) + if "mechanism" not in [e.namespace for e in source.get_decorators()]: + source_position = source.get_source_position() + for target in inline_expressions: + matcher = re.compile(self._variable_matching_template.format(source.get_variable_name())) + target_definition = str(target.get_expression()) + target_definition = re.sub(matcher, "(" + str(source.get_expression()) + ")", target_definition) + old_parent = target.expression.parent_ + target.expression = ModelParser.parse_expression(target_definition) + target.expression.update_scope(source.get_scope()) + target.expression.parent_ = old_parent + target.expression.accept(ASTParentVisitor()) + target.expression.accept(ASTSymbolTableVisitor()) + + def log_set_source_position(node): + if node.get_source_position().is_added_source_position(): + node.set_source_position(source_position) + + target.expression.accept(ASTHigherOrderVisitor(visit_funcs=log_set_source_position)) return inline_expressions diff --git a/pynestml/utils/ast_global_information_collector.py b/pynestml/utils/ast_global_information_collector.py new file mode 100644 index 000000000..242c9516a --- /dev/null +++ b/pynestml/utils/ast_global_information_collector.py @@ -0,0 +1,530 @@ +# -*- coding: utf-8 -*- +# +# ast_global_information_collector.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with NEST. If not, see <http://www.gnu.org/licenses/>.

+from collections import defaultdict
+
+from pynestml.symbols.predefined_units import PredefinedUnits
+from pynestml.visitors.ast_visitor import ASTVisitor
+
+
+class ASTGlobalInformationCollector(object):
+    """
+    This class is part of the compartmental code generation process.
+
+    Collects information about the parts of the code that are relevant within the update or OnReceive(self_spikes) blocks.
+    """
+    collector_visitor = None
+    synapse = None
+
+    @classmethod
+    def __init__(cls, neuron):
+        cls.neuron = neuron
+        cls.collector_visitor = ASTMechanismInformationCollectorVisitor()
+        neuron.accept(cls.collector_visitor)
+
+    @classmethod
+    def collect_update_block(cls, synapse, global_info):
+        update_block_collector_visitor = ASTUpdateBlockVisitor()
+        synapse.accept(update_block_collector_visitor)
+        global_info["UpdateBlock"] = update_block_collector_visitor.update_block
+        return global_info
+
+    @classmethod
+    def collect_self_spike_function(cls, neuron, global_info):
+        on_receive_collector_visitor = ASTOnReceiveBlockCollectorVisitor()
+        neuron.accept(on_receive_collector_visitor)
+
+        for function in on_receive_collector_visitor.all_on_receive_blocks:
+            if function.get_port_name() == "self_spikes":
+                global_info["SelfSpikesFunction"] = function
+
+        return global_info
+
+    @classmethod
+    def extend_variables_with_initialisations(cls, neuron, global_info):
+        """Collects initialization expressions for all variables and parameters contained in global_info."""
+        var_init_visitor = VariableInitializationVisitor(global_info)
+        neuron.accept(var_init_visitor)
+        global_info["States"] = var_init_visitor.states
+        global_info["Parameters"] = var_init_visitor.parameters
+        global_info["Internals"] = var_init_visitor.internals
+
+        return global_info
+
+    @classmethod
+    def extend_variable_list_name_based_restricted(cls, extended_list, appending_list, restrictor_list):
+        """Go through appending_list and append every variable that is not in restrictor_list to extended_list, so
+        that the same variable is not searched twice."""
+        for app_item in appending_list:
+            appendable = True
+            for rest_item in restrictor_list:
+                if rest_item.name == app_item.name:
+                    appendable = False
+                    break
+            if appendable:
+                extended_list.append(app_item)
+
+        return extended_list
+
+    @classmethod
+    def extend_function_call_list_name_based_restricted(cls, extended_list, appending_list, restrictor_list):
+        """Go through appending_list and append every function call that is not in restrictor_list to extended_list,
+        so that the same function is not searched twice."""
+        for app_item in appending_list:
+            appendable = True
+            for rest_item in restrictor_list:
+                if rest_item.callee_name == app_item.callee_name:
+                    appendable = False
+                    break
+            if appendable:
+                extended_list.append(app_item)
+
+        return extended_list
+
+    @classmethod
+    def collect_related_definitions(cls, neuron, global_info):
+        """Collect all parts of the NESTML code on which the previously collected root expressions depend. The
+        search is cut off at the root expressions of other mechanisms."""
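+        # The sweep below is a worklist algorithm: variables and function calls
+        # collected from the update/self-spike blocks seed the search, and every
+        # definition that resolves an item may add new names to the worklists,
+        # until both worklists are exhausted.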
search + is cut at other mechanisms root expressions""" + from pynestml.meta_model.ast_inline_expression import ASTInlineExpression + from pynestml.meta_model.ast_ode_equation import ASTOdeEquation + + variable_collector = ASTVariableCollectorVisitor() + neuron.accept(variable_collector) + global_states = variable_collector.all_states + global_parameters = variable_collector.all_parameters + global_internals = variable_collector.all_internals + + function_collector = ASTFunctionCollectorVisitor() + neuron.accept(function_collector) + global_functions = function_collector.all_functions + + inline_collector = ASTInlineEquationCollectorVisitor() + neuron.accept(inline_collector) + global_inlines = inline_collector.all_inlines + + ode_collector = ASTODEEquationCollectorVisitor() + neuron.accept(ode_collector) + global_odes = ode_collector.all_ode_equations + + kernel_collector = ASTKernelCollectorVisitor() + neuron.accept(kernel_collector) + global_kernels = kernel_collector.all_kernels + + continuous_input_collector = ASTContinuousInputDeclarationVisitor() + neuron.accept(continuous_input_collector) + global_continuous_inputs = continuous_input_collector.ports + + mechanism_states = list() + mechanism_parameters = list() + mechanism_internals = list() + mechanism_functions = list() + mechanism_inlines = list() + mechanism_odes = list() + synapse_kernels = list() + mechanism_continuous_inputs = list() + + search_variables = list() + search_functions = list() + + found_variables = list() + found_functions = list() + + if "SelfSpikesFunction" in global_info and global_info["SelfSpikesFunction"] is not None: + local_variable_collector = ASTVariableCollectorVisitor() + global_info["SelfSpikesFunction"].accept(local_variable_collector) + search_variables_self_spike = local_variable_collector.all_variables + search_variables = cls.extend_variable_list_name_based_restricted(search_variables, + search_variables_self_spike, + search_variables) + + local_function_call_collector = ASTFunctionCallCollectorVisitor() + global_info["SelfSpikesFunction"].accept(local_function_call_collector) + search_functions_self_spike = local_function_call_collector.all_function_calls + search_functions = cls.extend_function_call_list_name_based_restricted(search_functions, + search_functions_self_spike, + search_functions) + + if "UpdateBlock" in global_info and global_info["UpdateBlock"] is not None: + local_variable_collector = ASTVariableCollectorVisitor() + global_info["UpdateBlock"].accept(local_variable_collector) + search_variables_update = local_variable_collector.all_variables + search_variables = cls.extend_variable_list_name_based_restricted(search_variables, + search_variables_update, + search_variables) + + local_function_call_collector = ASTFunctionCallCollectorVisitor() + global_info["UpdateBlock"].accept(local_function_call_collector) + search_functions_update = local_function_call_collector.all_function_calls + search_functions = cls.extend_function_call_list_name_based_restricted(search_functions, + search_functions_update, + search_functions) + + while len(search_functions) > 0 or len(search_variables) > 0: + if len(search_functions) > 0: + function_call = search_functions[0] + for function in global_functions: + if function.name == function_call.callee_name: + mechanism_functions.append(function) + found_functions.append(function_call) + + local_variable_collector = ASTVariableCollectorVisitor() + function.accept(local_variable_collector) + search_variables = 
cls.extend_variable_list_name_based_restricted(search_variables, + local_variable_collector.all_variables, + search_variables + found_variables) + + local_function_call_collector = ASTFunctionCallCollectorVisitor() + function.accept(local_function_call_collector) + search_functions = cls.extend_function_call_list_name_based_restricted(search_functions, + local_function_call_collector.all_function_calls, + search_functions + found_functions) + # IMPLEMENT CATCH NONDEFINED!!! + search_functions.remove(function_call) + + elif len(search_variables) > 0: + variable = search_variables[0] + if not (variable.name == "v_comp" or variable.name in PredefinedUnits.get_units()): + is_dependency = False + for inline in global_inlines: + if variable.name == inline.variable_name: + if isinstance(inline.get_decorators(), list): + if "mechanism" in [e.namespace for e in inline.get_decorators()]: + is_dependency = True + + if not is_dependency: + mechanism_inlines.append(inline) + + local_variable_collector = ASTVariableCollectorVisitor() + inline.accept(local_variable_collector) + search_variables = cls.extend_variable_list_name_based_restricted(search_variables, + local_variable_collector.all_variables, + search_variables + found_variables) + + local_function_call_collector = ASTFunctionCallCollectorVisitor() + inline.accept(local_function_call_collector) + search_functions = cls.extend_function_call_list_name_based_restricted( + search_functions, + local_function_call_collector.all_function_calls, + search_functions + found_functions) + + for ode in global_odes: + if variable.name == ode.lhs.name: + if isinstance(ode.get_decorators(), list): + if "mechanism" in [e.namespace for e in ode.get_decorators()]: + is_dependency = True + + if not is_dependency: + mechanism_odes.append(ode) + + local_variable_collector = ASTVariableCollectorVisitor() + ode.accept(local_variable_collector) + search_variables = cls.extend_variable_list_name_based_restricted(search_variables, + local_variable_collector.all_variables, + search_variables + found_variables) + + local_function_call_collector = ASTFunctionCallCollectorVisitor() + ode.accept(local_function_call_collector) + search_functions = cls.extend_function_call_list_name_based_restricted( + search_functions, + local_function_call_collector.all_function_calls, + search_functions + found_functions) + + for state in global_states: + if variable.name == state.name and not is_dependency: + mechanism_states.append(state) + + for parameter in global_parameters: + if variable.name == parameter.name: + mechanism_parameters.append(parameter) + + for internal in global_internals: + if variable.name == internal.name: + mechanism_internals.append(internal) + + for kernel in global_kernels: + if variable.name == kernel.get_variables()[0].name: + synapse_kernels.append(kernel) + + local_variable_collector = ASTVariableCollectorVisitor() + kernel.accept(local_variable_collector) + search_variables = cls.extend_variable_list_name_based_restricted(search_variables, + local_variable_collector.all_variables, + search_variables + found_variables) + + local_function_call_collector = ASTFunctionCallCollectorVisitor() + kernel.accept(local_function_call_collector) + search_functions = cls.extend_function_call_list_name_based_restricted(search_functions, + local_function_call_collector.all_function_calls, + search_functions + found_functions) + + for input in global_continuous_inputs: + if variable.name == input.name: + mechanism_continuous_inputs.append(input) + 
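+                # each processed variable is moved from the search list into found_variables,
+                # so the search loop above terminates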
search_variables.remove(variable) + found_variables.append(variable) + + global_info["States"] = mechanism_states + global_info["Parameters"] = mechanism_parameters + global_info["Internals"] = mechanism_internals + global_info["Functions"] = mechanism_functions + global_info["SecondaryInlineExpressions"] = mechanism_inlines + global_info["ODEs"] = mechanism_odes + global_info["Continuous"] = mechanism_continuous_inputs + + return global_info + + +class ASTMechanismInformationCollectorVisitor(ASTVisitor): + + def __init__(self): + super(ASTMechanismInformationCollectorVisitor, self).__init__() + self.inEquationsBlock = False + self.inlinesInEquationsBlock = list() + self.odes = list() + + def visit_equations_block(self, node): + self.inEquationsBlock = True + + def endvisit_equations_block(self, node): + self.inEquationsBlock = False + + def visit_inline_expression(self, node): + if self.inEquationsBlock: + self.inlinesInEquationsBlock.append(node) + + def visit_ode_equation(self, node): + self.odes.append(node) + + +class ASTUpdateBlockVisitor(ASTVisitor): + def __init__(self): + super(ASTUpdateBlockVisitor, self).__init__() + self.inside_update_block = False + self.update_block = None + + def visit_update_block(self, node): + self.inside_update_block = True + self.update_block = node.clone() + + def endvisit_update_block(self, node): + self.inside_update_block = False + + +class VariableInitializationVisitor(ASTVisitor): + def __init__(self, channel_info): + super(VariableInitializationVisitor, self).__init__() + self.inside_variable = False + self.inside_declaration = False + self.inside_parameter_block = False + self.inside_state_block = False + self.inside_internal_block = False + self.current_declaration = None + self.states = defaultdict() + self.parameters = defaultdict() + self.internals = defaultdict() + self.channel_info = channel_info + + def visit_declaration(self, node): + self.inside_declaration = True + self.current_declaration = node + + def endvisit_declaration(self, node): + self.inside_declaration = False + self.current_declaration = None + + def visit_block_with_variables(self, node): + if node.is_state: + self.inside_state_block = True + if node.is_parameters: + self.inside_parameter_block = True + if node.is_internals: + self.inside_internal_block = True + + def endvisit_block_with_variables(self, node): + self.inside_state_block = False + self.inside_parameter_block = False + self.inside_internal_block = False + + def visit_variable(self, node): + self.inside_variable = True + if self.inside_state_block and self.inside_declaration: + if any(node.name == variable.name for variable in self.channel_info["States"]): + self.states[node.name] = defaultdict() + self.states[node.name]["ASTVariable"] = node.clone() + self.states[node.name]["rhs_expression"] = self.current_declaration.get_expression() + + if self.inside_parameter_block and self.inside_declaration: + if any(node.name == variable.name for variable in self.channel_info["Parameters"]): + self.parameters[node.name] = defaultdict() + self.parameters[node.name]["ASTVariable"] = node.clone() + self.parameters[node.name]["rhs_expression"] = self.current_declaration.get_expression() + + if self.inside_internal_block and self.inside_declaration: + if any(node.name == variable.name for variable in self.channel_info["Internals"]): + self.internals[node.name] = defaultdict() + self.internals[node.name]["ASTVariable"] = node.clone() + self.internals[node.name]["rhs_expression"] = self.current_declaration.get_expression() + 
+ def endvisit_variable(self, node): + self.inside_variable = False + + +class ASTODEEquationCollectorVisitor(ASTVisitor): + def __init__(self): + super(ASTODEEquationCollectorVisitor, self).__init__() + self.inside_ode_expression = False + self.all_ode_equations = list() + + def visit_ode_equation(self, node): + self.inside_ode_expression = True + self.all_ode_equations.append(node.clone()) + + def endvisit_ode_equation(self, node): + self.inside_ode_expression = False + + +class ASTVariableCollectorVisitor(ASTVisitor): + def __init__(self): + super(ASTVariableCollectorVisitor, self).__init__() + self.inside_variable = False + self.inside_block_with_variables = False + self.all_states = list() + self.all_parameters = list() + self.all_internals = list() + self.inside_states_block = False + self.inside_parameters_block = False + self.inside_internals_block = False + self.all_variables = list() + + def visit_block_with_variables(self, node): + self.inside_block_with_variables = True + if node.is_state: + self.inside_states_block = True + if node.is_parameters: + self.inside_parameters_block = True + if node.is_internals: + self.inside_internals_block = True + + def endvisit_block_with_variables(self, node): + self.inside_states_block = False + self.inside_parameters_block = False + self.inside_internals_block = False + self.inside_block_with_variables = False + + def visit_variable(self, node): + self.inside_variable = True + self.all_variables.append(node.clone()) + if self.inside_states_block: + self.all_states.append(node.clone()) + if self.inside_parameters_block: + self.all_parameters.append(node.clone()) + if self.inside_internals_block: + self.all_internals.append(node.clone()) + + def endvisit_variable(self, node): + self.inside_variable = False + + +class ASTFunctionCollectorVisitor(ASTVisitor): + def __init__(self): + super(ASTFunctionCollectorVisitor, self).__init__() + self.inside_function = False + self.all_functions = list() + + def visit_function(self, node): + self.inside_function = True + self.all_functions.append(node.clone()) + + def endvisit_function(self, node): + self.inside_function = False + + +class ASTInlineEquationCollectorVisitor(ASTVisitor): + def __init__(self): + super(ASTInlineEquationCollectorVisitor, self).__init__() + self.inside_inline_expression = False + self.all_inlines = list() + + def visit_inline_expression(self, node): + self.inside_inline_expression = True + self.all_inlines.append(node.clone()) + + def endvisit_inline_expression(self, node): + self.inside_inline_expression = False + + +class ASTFunctionCallCollectorVisitor(ASTVisitor): + def __init__(self): + super(ASTFunctionCallCollectorVisitor, self).__init__() + self.inside_function_call = False + self.all_function_calls = list() + + def visit_function_call(self, node): + self.inside_function_call = True + self.all_function_calls.append(node.clone()) + + def endvisit_function_call(self, node): + self.inside_function_call = False + + +class ASTKernelCollectorVisitor(ASTVisitor): + def __init__(self): + super(ASTKernelCollectorVisitor, self).__init__() + self.inside_kernel = False + self.all_kernels = list() + + def visit_kernel(self, node): + self.inside_kernel = True + self.all_kernels.append(node.clone()) + + def endvisit_kernel(self, node): + self.inside_kernel = False + + +class ASTContinuousInputDeclarationVisitor(ASTVisitor): + def __init__(self): + super(ASTContinuousInputDeclarationVisitor, self).__init__() + self.inside_port = False + self.current_port = None + self.ports = list() + + 
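+    # only input ports declared as continuous-time inputs are collected here; spiking
+    # ports are picked up by the kernel/receptor information collectors via convolve() calls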
def visit_input_port(self, node):
+        self.inside_port = True
+        self.current_port = node
+        if self.current_port.is_continuous():
+            self.ports.append(node.clone())
+
+    def endvisit_input_port(self, node):
+        self.inside_port = False
+
+
+class ASTOnReceiveBlockCollectorVisitor(ASTVisitor):
+    def __init__(self):
+        super(ASTOnReceiveBlockCollectorVisitor, self).__init__()
+        self.inside_on_receive_block = False
+        self.all_on_receive_blocks = list()
+
+    def visit_on_receive_block(self, node):
+        self.inside_on_receive_block = True
+        self.all_on_receive_blocks.append(node.clone())
+
+    def endvisit_on_receive_block(self, node):
+        self.inside_on_receive_block = False
diff --git a/pynestml/utils/ast_mechanism_information_collector.py b/pynestml/utils/ast_mechanism_information_collector.py
index bf38df32f..8d32da32a 100644
--- a/pynestml/utils/ast_mechanism_information_collector.py
+++ b/pynestml/utils/ast_mechanism_information_collector.py
@@ -18,16 +18,24 @@
 #
 # You should have received a copy of the GNU General Public License
 # along with NEST. If not, see <http://www.gnu.org/licenses/>.
-
+import copy
 from collections import defaultdict
 
 from pynestml.frontend.frontend_configuration import FrontendConfiguration
+from pynestml.meta_model.ast_inline_expression import ASTInlineExpression
+from pynestml.meta_model.ast_kernel import ASTKernel
+from pynestml.symbols.predefined_units import PredefinedUnits
+from pynestml.symbols.predefined_variables import PredefinedVariables
+from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
 from pynestml.visitors.ast_visitor import ASTVisitor
 
 
 class ASTMechanismInformationCollector(object):
-    """This class contains all basic mechanism information collection. Further collectors may be implemented to collect
-    further information for specific mechanism types (example: ASTSynapseInformationCollector)"""
+    """
+    This file is part of the compartmental code generation process.
+
+    This class contains all basic mechanism information collection. Further collectors may be implemented to collect
+    further information for specific mechanism types (example: ASTReceptorInformationCollector).
+    """
     collector_visitor = None
     neuron = None
 
@@ -106,13 +114,20 @@ def extend_variables_with_initialisations(cls, neuron, mechs_info):
         return mechs_info
 
     @classmethod
-    def collect_mechanism_related_definitions(cls, neuron, mechs_info):
+    def collect_mechanism_related_definitions(cls, neuron, mechs_info, global_info, mech_type: str):
         """Collects all parts of the nestml code the root expressions previously collected depend on.
search is cut at other mechanisms root expressions""" from pynestml.meta_model.ast_inline_expression import ASTInlineExpression from pynestml.meta_model.ast_ode_equation import ASTOdeEquation + if "Dependencies" not in global_info: + global_info["Dependencies"] = dict() + for mechanism_name, mechanism_info in mechs_info.items(): + if mech_type not in global_info["Dependencies"]: + global_info["Dependencies"][mech_type] = dict() + global_info["Dependencies"][mech_type][mechanism_name] = list() + variable_collector = ASTVariableCollectorVisitor() neuron.accept(variable_collector) global_states = variable_collector.all_states @@ -152,21 +167,17 @@ def collect_mechanism_related_definitions(cls, neuron, mechs_info): mechanism_dependencies["channels"] = list() mechanism_dependencies["receptors"] = list() mechanism_dependencies["continuous"] = list() - - mechanism_inlines.append(mechs_info[mechanism_name]["root_expression"]) - - search_variables = list() - search_functions = list() + mechanism_dependencies["global"] = list() found_variables = list() found_functions = list() local_variable_collector = ASTVariableCollectorVisitor() - mechanism_inlines[0].accept(local_variable_collector) + mechs_info[mechanism_name]["root_expression"].accept(local_variable_collector) search_variables = local_variable_collector.all_variables local_function_call_collector = ASTFunctionCallCollectorVisitor() - mechanism_inlines[0].accept(local_function_call_collector) + mechs_info[mechanism_name]["root_expression"].accept(local_function_call_collector) search_functions = local_function_call_collector.all_function_calls while len(search_functions) > 0 or len(search_variables) > 0: @@ -193,14 +204,15 @@ def collect_mechanism_related_definitions(cls, neuron, mechs_info): elif len(search_variables) > 0: variable = search_variables[0] - if not variable.name == "v_comp": + if not (variable.name == "v_comp" or variable.name in PredefinedUnits.get_units()): is_dependency = False for inline in global_inlines: if variable.name == inline.variable_name: if isinstance(inline.get_decorators(), list): if "mechanism" in [e.namespace for e in inline.get_decorators()]: is_dependency = True - if not (isinstance(mechanism_info["root_expression"], ASTInlineExpression) and inline.variable_name == mechanism_info["root_expression"].variable_name): + if not (isinstance(mechanism_info["root_expression"], + ASTInlineExpression) and inline.variable_name == mechanism_info["root_expression"].variable_name): if "channel" in [e.name for e in inline.get_decorators()]: if not inline.variable_name in [i.variable_name for i in mechanism_dependencies["channels"]]: @@ -232,13 +244,17 @@ def collect_mechanism_related_definitions(cls, neuron, mechs_info): for ode in global_odes: if variable.name == ode.lhs.name: + if ode.lhs.name in global_info["States"]: + global_info["Dependencies"][mech_type][mechanism_name].append(ode.lhs) + del global_info["States"][ode.lhs.name] + # is_dependency = True if isinstance(ode.get_decorators(), list): if "mechanism" in [e.namespace for e in ode.get_decorators()]: is_dependency = True - if not (isinstance(mechanism_info["root_expression"], ASTOdeEquation) and ode.lhs.name == mechanism_info["root_expression"].lhs.name): + if not (isinstance(mechanism_info["root_expression"], + ASTOdeEquation) and ode.lhs.name == mechanism_info["root_expression"].lhs.name): if "concentration" in [e.name for e in ode.get_decorators()]: - if not ode.lhs.name in [o.lhs.name for o in - mechanism_dependencies["concentrations"]]: + if not 
ode.lhs.name in [o.lhs.name for o in mechanism_dependencies["concentrations"]]: mechanism_dependencies["concentrations"].append(ode) if not is_dependency: @@ -257,17 +273,35 @@ def collect_mechanism_related_definitions(cls, neuron, mechs_info): local_function_call_collector.all_function_calls, search_functions + found_functions) - for state in global_states: - if variable.name == state.name and not is_dependency: - mechanism_states.append(state) + for state_name, state in global_states.items(): + if variable.name == state_name: + if state["ASTVariable"] in global_info["States"]: + mechanism_dependencies["global"].append(state["ASTVariable"]) + global_info["Dependencies"][mech_type][mechanism_name].append(state["ASTVariable"]) + del global_info["States"][state_name] + + if not is_dependency: + mechanism_states.append(state["ASTVariable"]) + + for parameter_name, parameter in global_parameters.items(): + if variable.name == parameter_name: + mechanism_parameters.append(parameter["ASTVariable"]) - for parameter in global_parameters: - if variable.name == parameter.name: - mechanism_parameters.append(parameter) + for internal_name, internal in global_internals.items(): + if variable.name == internal_name: + mechanism_internals.append(internal["ASTVariable"]) - for internal in global_internals: - if variable.name == internal.name: - mechanism_internals.append(internal) + local_variable_collector = ASTVariableCollectorVisitor() + internal["ASTExpression"].accept(local_variable_collector) + search_variables = cls.extend_variable_list_name_based_restricted(search_variables, + local_variable_collector.all_variables, + search_variables + found_variables) + + local_function_call_collector = ASTFunctionCallCollectorVisitor() + internal["ASTExpression"].accept(local_function_call_collector) + search_functions = cls.extend_function_call_list_name_based_restricted(search_functions, + local_function_call_collector.all_function_calls, + search_functions + found_functions) for kernel in global_kernels: if variable.name == kernel.get_variables()[0].name: @@ -304,6 +338,106 @@ def collect_mechanism_related_definitions(cls, neuron, mechs_info): return mechs_info + @classmethod + def collect_kernels(cls, neuron, mechs_info): + """ + Collect internals, kernels, inputs and convolutions associated with the synapse. 
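+        Convolutions are keyed by a generated name of the form "<kernel>__X__<spike port>";
+        e.g. a hypothetical convolve(g_AMPA, spikes_AMPA) is stored under "g_AMPA__X__spikes_AMPA".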
+ """ + for mechanism_name, mechanism_info in mechs_info.items(): + mechanism_info["convolutions"] = defaultdict() + info_collector = ASTKernelInformationCollectorVisitor() + neuron.accept(info_collector) + + inlines = copy.copy(mechanism_info["SecondaryInlineExpressions"]) + if isinstance(mechanism_info["root_expression"], ASTInlineExpression): + inlines.append(mechanism_info["root_expression"]) + for inline in inlines: + kernel_arg_pairs = info_collector.get_extracted_kernel_args_by_name( + inline.get_variable_name()) + for kernel_var, spikes_var in kernel_arg_pairs: + kernel_name = kernel_var.get_name() + spikes_name = spikes_var.get_name() + convolution_name = info_collector.construct_kernel_X_spike_buf_name( + kernel_name, spikes_name, 0) + mechanism_info["convolutions"][convolution_name] = { + "kernel": { + "name": kernel_name, + "ASTKernel": info_collector.get_kernel_by_name(kernel_name), + }, + "spikes": { + "name": spikes_name, + "ASTInputPort": info_collector.get_input_port_by_name(spikes_name), + }, + } + return mechs_info + + @classmethod + def collect_block_dependencies_and_owned(cls, mech_info, blocks, block_type): + block = blocks[0].clone() + blocks.pop(0) + for next_block in blocks: + block.stmts += next_block.stmts + + block.accept(ASTParentVisitor()) + + for mechanism_name, mechanism_info in mech_info.items(): + dependencies = list() + updated_dependencies = list() + owned = list() + updated_owned = mechanism_info["States"] + mechanism_info["Parameters"] + mechanism_info["Internals"] + + loop_counter = 0 + while set([v.get_name() for v in owned]) != set([v.get_name() for v in updated_owned]) or set( + [v.get_name() for v in dependencies]) != set([v.get_name() for v in updated_dependencies]): + owned = updated_owned + dependencies = updated_dependencies + collector = ASTUpdateBlockDependencyAndOwnedExtractor(owned, dependencies) + block.accept(collector) + updated_owned = collector.owned + updated_dependencies = collector.dependencies + loop_counter += 1 + + mechanism_info["Blocks"] = dict() + mechanism_info["Blocks"]["dependencies"] = dependencies + mechanism_info["Blocks"]["owned"] = owned + + @classmethod + def block_reduction(cls, mech_info, block, block_type): + for mechanism_name, mechanism_info in mech_info.items(): + owned = mechanism_info["Blocks"]["owned"] + dependencies = mechanism_info["Blocks"]["dependencies"] + new_update_block = block.clone() + new_update_block.accept(ASTParentVisitor()) + update_block_reductor = ASTUpdateBlockReductor(owned, dependencies) + new_update_block.accept(update_block_reductor) + mechanism_info["Blocks"][block_type] = new_update_block + + @classmethod + def recursive_update_block_reduction(cls, mech_info, eliminated, update_block): + reduced_update_blocks = dict() + reduced_update_blocks["Block"] = update_block + reduced_update_blocks["Reductions"] = dict() + for mechanism_name, mechanism_info in mech_info.items(): + if mechanism_name not in eliminated: + if "UpdateBlockComputation" in mechanism_info: + owned = mechanism_info["UpdateBlockComputation"]["owned"] + exclusive_dependencies = mechanism_info["UpdateBlockComputation"]["dependencies"] + leftovers_depend = False + for comp_mechanism_name, comp_mechanism_info in mech_info.items(): + if mechanism_name != comp_mechanism_name: + if "UpdateBlockComputation" in comp_mechanism_info: + exclusive_dependencies = list(set(exclusive_dependencies) - set( + comp_mechanism_info["UpdateBlockComputation"]["dependencies"])) + leftovers_depend = leftovers_depend and 
len(set(comp_mechanism_info["UpdateBlockComputation"]["dependencies"] + comp_mechanism_info["UpdateBlockComputation"]["owned"]) & set(owned)) > 0 + if not leftovers_depend: + new_update_block = update_block.clone() + update_block_reductor = ASTUpdateBlockReductor(owned, exclusive_dependencies) + new_update_block.accept(update_block_reductor) + reduced_update_blocks["Reductions"][mechanism_name] = cls.recursive_update_block_reduction( + mech_info, eliminated.append(mechanism_name), new_update_block) + + return reduced_update_blocks + class ASTMechanismInformationCollectorVisitor(ASTVisitor): @@ -336,15 +470,37 @@ def __init__(self, channel_info): self.inside_parameter_block = False self.inside_state_block = False self.inside_internal_block = False + self.inside_expression = False + self.current_declaration = None self.states = defaultdict() self.parameters = defaultdict() self.internals = defaultdict() self.channel_info = channel_info + self.search_vars = channel_info["States"] + channel_info["Parameters"] + channel_info["Internals"] + if "Blocks" in channel_info: + self.search_vars += channel_info["Blocks"]["dependencies"] + self.search_vars += channel_info["Blocks"]["owned"] def visit_declaration(self, node): self.inside_declaration = True self.current_declaration = node + for var in node.variables: + if any(var.name == variable.name for variable in self.search_vars): + if self.inside_state_block: + self.states[var.name] = defaultdict() + self.states[var.name]["ASTVariable"] = var.clone() + self.states[var.name]["rhs_expression"] = node.get_expression().clone() + + if self.inside_parameter_block: + self.parameters[var.name] = defaultdict() + self.parameters[var.name]["ASTVariable"] = var.clone() + self.parameters[var.name]["rhs_expression"] = node.get_expression().clone() + + if self.inside_internal_block: + self.internals[var.name] = defaultdict() + self.internals[var.name]["ASTVariable"] = var.clone() + self.internals[var.name]["rhs_expression"] = node.get_expression().clone() def endvisit_declaration(self, node): self.inside_declaration = False @@ -365,27 +521,16 @@ def endvisit_block_with_variables(self, node): def visit_variable(self, node): self.inside_variable = True - if self.inside_state_block and self.inside_declaration: - if any(node.name == variable.name for variable in self.channel_info["States"]): - self.states[node.name] = defaultdict() - self.states[node.name]["ASTVariable"] = node.clone() - self.states[node.name]["rhs_expression"] = self.current_declaration.get_expression() - - if self.inside_parameter_block and self.inside_declaration: - if any(node.name == variable.name for variable in self.channel_info["Parameters"]): - self.parameters[node.name] = defaultdict() - self.parameters[node.name]["ASTVariable"] = node.clone() - self.parameters[node.name]["rhs_expression"] = self.current_declaration.get_expression() - - if self.inside_internal_block and self.inside_declaration: - if any(node.name == variable.name for variable in self.channel_info["Internals"]): - self.internals[node.name] = defaultdict() - self.internals[node.name]["ASTVariable"] = node.clone() - self.internals[node.name]["rhs_expression"] = self.current_declaration.get_expression() def endvisit_variable(self, node): self.inside_variable = False + def visit_expression(self, node): + self.inside_expression = True + + def endvisit_expression(self, node): + self.inside_expression = False + class ASTODEEquationCollectorVisitor(ASTVisitor): def __init__(self): @@ -406,13 +551,37 @@ def __init__(self): 
super(ASTVariableCollectorVisitor, self).__init__() self.inside_variable = False self.inside_block_with_variables = False - self.all_states = list() - self.all_parameters = list() - self.all_internals = list() + self.all_states = dict() + self.all_parameters = dict() + self.all_internals = dict() self.inside_states_block = False self.inside_parameters_block = False self.inside_internals_block = False + self.inside_declaration = False + self.inside_expression_inside_declaration = False + self.expression_recursion = 0 self.all_variables = list() + self.current_declaration_expression = None + + def visit_declaration(self, node): + self.inside_declaration = True + if node.has_expression(): + self.current_declaration_expression = node.get_expression().clone() + + def endvisit_declaration(self, node): + self.inside_declaration = False + self.current_declaration_expression = None + + def visit_expression(self, node): + if self.inside_declaration: + self.inside_expression_inside_declaration = True + self.expression_recursion += 1 + + def endvisit_expression(self, node): + if self.inside_declaration: + self.expression_recursion -= 1 + if self.expression_recursion == 0: + self.inside_expression_inside_declaration = False def visit_block_with_variables(self, node): self.inside_block_with_variables = True @@ -431,13 +600,17 @@ def endvisit_block_with_variables(self, node): def visit_variable(self, node): self.inside_variable = True - self.all_variables.append(node.clone()) - if self.inside_states_block: - self.all_states.append(node.clone()) - if self.inside_parameters_block: - self.all_parameters.append(node.clone()) - if self.inside_internals_block: - self.all_internals.append(node.clone()) + if not self.inside_expression_inside_declaration: + self.all_variables.append(node.clone()) + if self.inside_states_block: + self.all_states[node.get_name()] = {"ASTVariable": node.clone(), + "ASTExpression": self.current_declaration_expression} + if self.inside_parameters_block: + self.all_parameters[node.get_name()] = {"ASTVariable": node.clone(), + "ASTExpression": self.current_declaration_expression} + if self.inside_internals_block: + self.all_internals[node.get_name()] = {"ASTVariable": node.clone(), + "ASTExpression": self.current_declaration_expression} def endvisit_variable(self, node): self.inside_variable = False @@ -514,3 +687,511 @@ def visit_input_port(self, node): def endvisit_input_port(self, node): self.inside_port = False + + +class ASTKernelInformationCollectorVisitor(ASTVisitor): + def __init__(self): + super(ASTKernelInformationCollectorVisitor, self).__init__() + + # various dicts to store collected information + self.kernel_name_to_kernel = defaultdict() + self.inline_expression_to_kernel_args = defaultdict(lambda: set()) + self.inline_expression_to_function_calls = defaultdict(lambda: set()) + self.kernel_to_function_calls = defaultdict(lambda: set()) + self.parameter_name_to_declaration = defaultdict(lambda: None) + self.state_name_to_declaration = defaultdict(lambda: None) + self.variable_name_to_declaration = defaultdict(lambda: None) + self.internal_var_name_to_declaration = defaultdict(lambda: None) + self.inline_expression_to_variables = defaultdict(lambda: set()) + self.kernel_to_rhs_variables = defaultdict(lambda: set()) + self.declaration_to_rhs_variables = defaultdict(lambda: set()) + self.input_port_name_to_input_port = defaultdict() + + # traversal states and nodes + self.inside_parameter_block = False + self.inside_state_block = False + self.inside_internals_block = False + 
self.inside_equations_block = False
+        self.inside_input_block = False
+        self.inside_inline_expression = False
+        self.inside_kernel = False
+        self.inside_kernel_call = False
+        self.inside_declaration = False
+        self.inside_simple_expression = False
+        self.inside_expression = False
+
+        self.current_inline_expression = None
+        self.current_kernel = None
+        self.current_expression = None
+        self.current_simple_expression = None
+        self.current_declaration = None
+
+        self.current_synapse_name = None
+
+    def get_state_declaration(self, variable_name):
+        return self.state_name_to_declaration[variable_name]
+
+    def get_variable_declaration(self, variable_name):
+        return self.variable_name_to_declaration[variable_name]
+
+    def get_kernel_by_name(self, name: str):
+        return self.kernel_name_to_kernel[name]
+
+    def get_inline_expressions_with_kernels(self):
+        return self.inline_expression_to_kernel_args.keys()
+
+    def get_kernel_function_calls(self, kernel: ASTKernel):
+        return self.kernel_to_function_calls[kernel]
+
+    def get_inline_function_calls(self, inline: ASTInlineExpression):
+        return self.inline_expression_to_function_calls[inline]
+
+    def get_variable_names_of_synapse(self, synapse_inline: ASTInlineExpression, exclude_names: set = None,
+                                      exclude_ignorable=True) -> set:
+        """extracts all variables specific to a single synapse
+        (which is defined by the inline expression containing kernels)
+        independently of what block they are declared in
+        it also cascades over all right hand side variables until all
+        variables are included"""
+        # default to a fresh set to avoid sharing a mutable default between calls
+        if exclude_names is None:
+            exclude_names = set()
+        if exclude_ignorable:
+            exclude_names.update(self.get_variable_names_to_ignore())
+
+        # find all variables used in the inline
+        potential_variables = self.inline_expression_to_variables[synapse_inline]
+
+        # find all kernels referenced by the inline
+        # and collect variables used by those kernels
+        kernel_arg_pairs = self.get_extracted_kernel_args(synapse_inline)
+        for kernel_var, spikes_var in kernel_arg_pairs:
+            kernel = self.get_kernel_by_name(kernel_var.get_name())
+            potential_variables.update(self.kernel_to_rhs_variables[kernel])
+
+        # find declarations for all variables and check
+        # what variables their rhs expressions use
+        # for example if we have
+        # a = b * c
+        # then check if b and c are already in potential_variables
+        # if not, add those as well
+        potential_variables_copy = copy.copy(potential_variables)
+
+        potential_variables_prev_count = len(potential_variables)
+        while True:
+            for potential_variable in potential_variables_copy:
+                var_name = potential_variable.get_name()
+                if var_name in exclude_names:
+                    continue
+                declaration = self.get_variable_declaration(var_name)
+                if declaration is None:
+                    continue
+                variables_referenced = self.declaration_to_rhs_variables[var_name]
+                potential_variables.update(variables_referenced)
+            if potential_variables_prev_count == len(potential_variables):
+                break
+            potential_variables_prev_count = len(potential_variables)
+
+        # transform variables into their names and filter
+        # out anything from exclude_names
+        result = set()
+        for potential_variable in potential_variables:
+            var_name = potential_variable.get_name()
+            if var_name not in exclude_names:
+                result.add(var_name)
+
+        return result
+
+    @classmethod
+    def get_variable_names_to_ignore(cls):
+        return set(PredefinedVariables.get_variables().keys()).union({"v_comp"})
+
+    def get_synapse_specific_internal_declarations(self, synapse_inline: ASTInlineExpression) -> defaultdict:
+        synapse_variable_names = self.get_variable_names_of_synapse(
+            synapse_inline)
+
+        # now
match those variable names with + # variable declarations from the internals block + dereferenced = defaultdict() + for potential_internals_name in synapse_variable_names: + if potential_internals_name in self.internal_var_name_to_declaration: + dereferenced[potential_internals_name] = self.internal_var_name_to_declaration[ + potential_internals_name] + return dereferenced + + def get_synapse_specific_state_declarations(self, synapse_inline: ASTInlineExpression) -> defaultdict: + synapse_variable_names = self.get_variable_names_of_synapse( + synapse_inline) + + # now match those variable names with + # variable declarations from the state block + dereferenced = defaultdict() + for potential_state_name in synapse_variable_names: + if potential_state_name in self.state_name_to_declaration: + dereferenced[potential_state_name] = self.state_name_to_declaration[potential_state_name] + return dereferenced + + def get_synapse_specific_parameter_declarations(self, synapse_inline: ASTInlineExpression) -> defaultdict: + synapse_variable_names = self.get_variable_names_of_synapse( + synapse_inline) + + # now match those variable names with + # variable declarations from the parameter block + dereferenced = defaultdict() + for potential_param_name in synapse_variable_names: + if potential_param_name in self.parameter_name_to_declaration: + dereferenced[potential_param_name] = self.parameter_name_to_declaration[potential_param_name] + return dereferenced + + def get_extracted_kernel_args(self, inline_expression: ASTInlineExpression) -> set: + return self.inline_expression_to_kernel_args[inline_expression] + + def get_extracted_kernel_args_by_name(self, inline_name: str) -> set: + inline_expression = [inline for inline in self.inline_expression_to_kernel_args.keys() if + inline.get_variable_name() == inline_name] + if len(inline_expression): + return self.inline_expression_to_kernel_args[inline_expression[0]] + else: + return set() + + def get_basic_kernel_variable_names(self, synapse_inline): + """ + for every occurence of convolve(port, spikes) generate "port__X__spikes" variable + gather those variables for this synapse inline and return their list + + note that those variables will occur as substring in other kernel variables i.e "port__X__spikes__d" or "__P__port__X__spikes__port__X__spikes" + + so we can use the result to identify all the other kernel variables related to the + specific synapse inline declaration + """ + order = 0 + results = [] + for syn_inline, args in self.inline_expression_to_kernel_args.items(): + if synapse_inline.variable_name == syn_inline.variable_name: + for kernel_var, spike_var in args: + kernel_name = kernel_var.get_name() + spike_input_port = self.input_port_name_to_input_port[spike_var.get_name( + )] + kernel_variable_name = self.construct_kernel_X_spike_buf_name( + kernel_name, spike_input_port, order) + results.append(kernel_variable_name) + + return results + + def get_used_kernel_names(self, inline_expression: ASTInlineExpression): + return [kernel_var.get_name() for kernel_var, _ in self.get_extracted_kernel_args(inline_expression)] + + def get_input_port_by_name(self, name): + return self.input_port_name_to_input_port[name] + + def get_used_spike_names(self, inline_expression: ASTInlineExpression): + return [spikes_var.get_name() for _, spikes_var in self.get_extracted_kernel_args(inline_expression)] + + def visit_kernel(self, node): + self.current_kernel = node + self.inside_kernel = True + if self.inside_equations_block: + kernel_name = 
node.get_variables()[0].get_name_of_lhs()
+            self.kernel_name_to_kernel[kernel_name] = node
+
+    def visit_function_call(self, node):
+        if self.inside_equations_block:
+            if self.inside_inline_expression and self.inside_simple_expression:
+                if node.get_name() == "convolve":
+                    self.inside_kernel_call = True
+                    kernel, spikes = node.get_args()
+                    kernel_var = kernel.get_variables()[0]
+                    spikes_var = spikes.get_variables()[0]
+                    self.inline_expression_to_kernel_args[self.current_inline_expression].add((kernel_var, spikes_var))
+                else:
+                    self.inline_expression_to_function_calls[self.current_inline_expression].add(node)
+            if self.inside_kernel and self.inside_simple_expression:
+                self.kernel_to_function_calls[self.current_kernel].add(node)
+
+    def endvisit_function_call(self, node):
+        self.inside_kernel_call = False
+
+    def endvisit_kernel(self, node):
+        self.current_kernel = None
+        self.inside_kernel = False
+
+    def visit_variable(self, node):
+        if self.inside_inline_expression and not self.inside_kernel_call:
+            self.inline_expression_to_variables[self.current_inline_expression].add(node)
+        elif self.inside_kernel and (self.inside_expression or self.inside_simple_expression):
+            self.kernel_to_rhs_variables[self.current_kernel].add(node)
+        elif self.inside_declaration and self.inside_expression:
+            declared_variable = self.current_declaration.get_variables()[0].get_name()
+            self.declaration_to_rhs_variables[declared_variable].add(node)
+
+    def visit_inline_expression(self, node):
+        self.inside_inline_expression = True
+        self.current_inline_expression = node
+
+    def endvisit_inline_expression(self, node):
+        self.inside_inline_expression = False
+        self.current_inline_expression = None
+
+    def visit_equations_block(self, node):
+        self.inside_equations_block = True
+
+    def endvisit_equations_block(self, node):
+        self.inside_equations_block = False
+
+    def visit_input_block(self, node):
+        self.inside_input_block = True
+
+    def visit_input_port(self, node):
+        self.input_port_name_to_input_port[node.get_name()] = node
+
+    def endvisit_input_block(self, node):
+        self.inside_input_block = False
+
+    def visit_block_with_variables(self, node):
+        if node.is_state:
+            self.inside_state_block = True
+        if node.is_parameters:
+            self.inside_parameter_block = True
+        if node.is_internals:
+            self.inside_internals_block = True
+
+    def endvisit_block_with_variables(self, node):
+        if node.is_state:
+            self.inside_state_block = False
+        if node.is_parameters:
+            self.inside_parameter_block = False
+        if node.is_internals:
+            self.inside_internals_block = False
+
+    def visit_simple_expression(self, node):
+        self.inside_simple_expression = True
+        self.current_simple_expression = node
+
+    def endvisit_simple_expression(self, node):
+        self.inside_simple_expression = False
+        self.current_simple_expression = None
+
+    def visit_declaration(self, node):
+        self.inside_declaration = True
+        self.current_declaration = node
+
+        # collect declarations generally
+        variable_name = node.get_variables()[0].get_name()
+        self.variable_name_to_declaration[variable_name] = node
+
+        # collect declarations per block
+        if self.inside_parameter_block:
+            self.parameter_name_to_declaration[variable_name] = node
+        elif self.inside_state_block:
+            self.state_name_to_declaration[variable_name] = node
+        elif self.inside_internals_block:
+            self.internal_var_name_to_declaration[variable_name] = node
+
+    def endvisit_declaration(self, node):
+        self.inside_declaration = False
+        self.current_declaration = None
+
+    def visit_expression(self, node):
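+        # remember the current expression so that visit_variable can attribute
+        # right-hand-side variables to the enclosing declaration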
self.inside_expression = True + self.current_expression = node + + def endvisit_expression(self, node): + self.inside_expression = False + self.current_expression = None + + # this method was copied over from ast_transformer + # in order to avoid a circular dependency + @staticmethod + def construct_kernel_X_spike_buf_name(kernel_var_name: str, spike_input_port, order: int, + diff_order_symbol="__d"): + assert type(kernel_var_name) is str + assert type(order) is int + assert type(diff_order_symbol) is str + return kernel_var_name.replace("$", "__DOLLAR") + "__X__" + str( + spike_input_port) + diff_order_symbol * order + + +class ASTUpdateBlockDependencyAndOwnedExtractor(ASTVisitor): + def __init__(self, init_owned, init_dep): + super(ASTUpdateBlockDependencyAndOwnedExtractor, self).__init__() + + self.inside_block = None + self.dependencies = init_dep + self.owned = init_owned + + self.block_with_dep = False + self.block_with_owned = False + + self.control_depth = 0 + self.current_control_vars = list() + self.inside_ctl_condition = False + + self.expression_depth = 0 + + self.inside_small_stmt = False + self.inside_stmt = False + self.inside_variable = False + self.inside_if_stmt = False + self.inside_while_stmt = False + self.inside_for_stmt = False + + self.inside_if_clause = False + self.inside_elif_clause = False + self.inside_else_clause = False + self.inside_for_clause = False + self.inside_while_clause = False + + def visit_variable(self, node): + self.inside_variable = True + + def endvisit_variable(self, node): + self.inside_variable = False + + def visit_if_stmt(self, node): + self.inside_if_stmt = True + self.control_depth += 1 + self.inside_ctl_condition = True + + def endvisit_if_stmt(self, node): + self.control_depth -= 1 + self.inside_ctl_condition = False + self.current_control_vars.pop() + self.inside_if_stmt = False + + def visit_while_stmt(self, node): + self.inside_while_stmt = True + self.control_depth += 1 + var_collector = ASTVariableCollectorVisitor() + node.condition.accept(var_collector) + self.current_control_vars.append(var_collector.all_variables) + self.inside_ctl_condition = True + + def endvisit_while_stmt(self, node): + self.control_depth -= 1 + self.inside_ctl_condition = False + self.current_control_vars.pop() + self.inside_while_stmt = False + + def visit_for_stmt(self, node): + self.inside_for_stmt = True + self.control_depth += 1 + var_collector = ASTVariableCollectorVisitor() + node.condition.accept(var_collector) + self.current_control_vars.append(var_collector.all_variables) + self.inside_ctl_condition = True + + def endvisit_for_stmt(self, node): + self.control_depth -= 1 + self.inside_ctl_condition = False + self.current_control_vars.pop() + self.inside_for_stmt = False + + def visit_if_clause(self, node): + self.inside_if_clause = True + var_collector = ASTVariableCollectorVisitor() + node.condition.accept(var_collector) + self.current_control_vars.append(var_collector.all_variables) + + def endvisit_if_clause(self, node): + self.inside_if_clause = False + + def visit_elif_clause(self, node): + self.inside_elif_clause = True + var_collector = ASTVariableCollectorVisitor() + node.condition.accept(var_collector) + self.current_control_vars[-1].extend(var_collector.all_variables) + + def endvisit_elif_clause(self, node): + self.inside_elif_clause = False + + def visit_block(self, node): + self.inside_block = True + self.inside_ctl_condition = False + + def endvisit_block(self, node): + self.inside_block = False + self.inside_ctl_condition = True + + 
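+    # an assignment whose lhs is already owned or a dependency pulls its rhs variables
+    # (and the enclosing control-flow condition variables) into the dependencies;
+    # an assignment whose rhs reads an owned variable makes its lhs owned as well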
def visit_assignment(self, node): + self.inside_assignment = True + var_collector = ASTVariableCollectorVisitor() + node.rhs.accept(var_collector) + if node.lhs.get_name() in [n.get_name() for n in (self.dependencies + self.owned)]: + self.dependencies.extend(var_collector.all_variables) + for dep in self.current_control_vars: + self.dependencies.extend(dep) + + if len(set([n.get_name() for n in self.owned]) & set([n.get_name() for n in var_collector.all_variables])): + self.owned.append(node.lhs) + + def endvisit_assignment(self, node): + self.inside_assignment = False + + +class ASTUpdateBlockReductor(ASTVisitor): + def __init__(self, init_owned, init_exclusive_dep): + super(ASTUpdateBlockReductor, self).__init__() + + self.dependencies = init_exclusive_dep + self.owned = init_owned + + self.delete_stmts = list() + self.current_stmt_index = list() + self.block_depth = -1 + + self.inside_stmt = False + self.inside_if_stmt = False + self.inside_while_stmt = False + self.inside_for_stmt = False + + def visit_if_stmt(self, node): + self.inside_if_stmt = True + + def endvisit_if_stmt(self, node): + all_empty = len(node.get_if_clause().get_stmts_body().get_stmts()) == 0 + for block in [n.get_stmts_body() for n in node.get_elif_clauses()]: + all_empty = all_empty and len(block.get_stmts()) == 0 + if node.get_else_clause() is not None: + all_empty = all_empty and len(node.get_else_clause().get_stmts_body().get_stmts()) == 0 + if all_empty: + self.delete_stmts[self.block_depth].append(node.get_parent().get_parent()) + self.inside_if_stmt = False + + def visit_while_stmt(self, node): + self.inside_while_stmt = True + + def endvisit_while_stmt(self, node): + if len(node.get_stmts_body().get_stmts()) == 0: + self.delete_stmts[self.block_depth].append(node.get_parent().get_parent()) + self.inside_while_stmt = False + + def visit_for_stmt(self, node): + self.inside_for_stmt = True + + def endvisit_for_stmt(self, node): + if len(node.get_stmts_body().get_stmts()) == 0: + self.delete_stmts[self.block_depth].append(node.get_parent().get_parent()) + self.inside_for_stmt = False + + def visit_block(self, node): + self.inside_block = True + self.block_depth += 1 + self.delete_stmts.append(list()) + + def endvisit_block(self, node): + for stmt in self.delete_stmts[self.block_depth]: + node.delete_stmt(stmt) + self.delete_stmts.pop() + self.block_depth -= 1 + self.inside_block = False + + def visit_assignment(self, node): + self.inside_assignment = True + var_collector = ASTVariableCollectorVisitor() + node.rhs.accept(var_collector) + if node.lhs.get_name() not in [n.get_name() for n in (self.dependencies + self.owned)]: + self.delete_stmts[self.block_depth].append(node.get_parent().get_parent()) + + def endvisit_assignment(self, node): + self.inside_assignment = False diff --git a/pynestml/utils/ast_receptor_information_collector.py b/pynestml/utils/ast_receptor_information_collector.py new file mode 100644 index 000000000..f579d689a --- /dev/null +++ b/pynestml/utils/ast_receptor_information_collector.py @@ -0,0 +1,349 @@ +# -*- coding: utf-8 -*- +# +# ast_receptor_information_collector.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. 
+#
+# NEST is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with NEST. If not, see <http://www.gnu.org/licenses/>.
+
+import copy
+from collections import defaultdict
+
+from pynestml.meta_model.ast_inline_expression import ASTInlineExpression
+from pynestml.meta_model.ast_kernel import ASTKernel
+from pynestml.symbols.predefined_variables import PredefinedVariables
+from pynestml.visitors.ast_visitor import ASTVisitor
+
+
+class ASTReceptorInformationCollector(ASTVisitor):
+    """
+    This file is part of the compartmental code generation process.
+
+    Collects all receptor-relevant information.
+    """
+
+    def __init__(self):
+        super(ASTReceptorInformationCollector, self).__init__()
+
+        # various dicts to store collected information
+        self.kernel_name_to_kernel = defaultdict()
+        self.inline_expression_to_kernel_args = defaultdict(lambda: set())
+        self.inline_expression_to_function_calls = defaultdict(lambda: set())
+        self.kernel_to_function_calls = defaultdict(lambda: set())
+        self.parameter_name_to_declaration = defaultdict(lambda: None)
+        self.state_name_to_declaration = defaultdict(lambda: None)
+        self.variable_name_to_declaration = defaultdict(lambda: None)
+        self.internal_var_name_to_declaration = defaultdict(lambda: None)
+        self.inline_expression_to_variables = defaultdict(lambda: set())
+        self.kernel_to_rhs_variables = defaultdict(lambda: set())
+        self.declaration_to_rhs_variables = defaultdict(lambda: set())
+        self.input_port_name_to_input_port = defaultdict()
+
+        # traversal states and nodes
+        self.inside_parameter_block = False
+        self.inside_state_block = False
+        self.inside_internals_block = False
+        self.inside_equations_block = False
+        self.inside_input_block = False
+        self.inside_inline_expression = False
+        self.inside_kernel = False
+        self.inside_kernel_call = False
+        self.inside_declaration = False
+        self.inside_simple_expression = False
+        self.inside_expression = False
+
+        self.current_inline_expression = None
+        self.current_kernel = None
+        self.current_expression = None
+        self.current_simple_expression = None
+        self.current_declaration = None
+
+        self.current_synapse_name = None
+
+    def get_state_declaration(self, variable_name):
+        return self.state_name_to_declaration[variable_name]
+
+    def get_variable_declaration(self, variable_name):
+        return self.variable_name_to_declaration[variable_name]
+
+    def get_kernel_by_name(self, name: str):
+        return self.kernel_name_to_kernel[name]
+
+    def get_inline_expressions_with_kernels(self):
+        return self.inline_expression_to_kernel_args.keys()
+
+    def get_kernel_function_calls(self, kernel: ASTKernel):
+        return self.kernel_to_function_calls[kernel]
+
+    def get_inline_function_calls(self, inline: ASTInlineExpression):
+        return self.inline_expression_to_function_calls[inline]
+
+    def get_variable_names_of_synapse(self, synapse_inline: ASTInlineExpression, exclude_names: set = None, exclude_ignorable=True) -> set:
+        """extracts all variables specific to a single synapse
+        (which is defined by the inline expression containing kernels)
+        independently of what block they are declared in
+        it also cascades over all right hand side variables until all
+        variables are included"""
+        # default to a fresh set to avoid sharing a mutable default between calls
+        if exclude_names is None:
+            exclude_names = set()
+        if exclude_ignorable:
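+            # predefined NESTML variables and the compartment voltage v_comp are
+            # never counted as receptor-specific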
exclude_names.update(self.get_variable_names_to_ignore()) + + # find all variables used in the inline + potential_variables = self.inline_expression_to_variables[synapse_inline] + + # find all kernels referenced by the inline + # and collect variables used by those kernels + kernel_arg_pairs = self.get_extracted_kernel_args(synapse_inline) + for kernel_var, spikes_var in kernel_arg_pairs: + kernel = self.get_kernel_by_name(kernel_var.get_name()) + potential_variables.update(self.kernel_to_rhs_variables[kernel]) + + # find declarations for all variables and check + # what variables their rhs expressions use + # for example if we have + # a = b * c + # then check if b and c are already in potential_variables + # if not, add those as well + potential_variables_copy = copy.copy(potential_variables) + + potential_variables_prev_count = len(potential_variables) + while True: + for potential_variable in potential_variables_copy: + var_name = potential_variable.get_name() + if var_name in exclude_names: + continue + declaration = self.get_variable_declaration(var_name) + if declaration is None: + continue + variables_referenced = self.declaration_to_rhs_variables[var_name] + potential_variables.update(variables_referenced) + if potential_variables_prev_count == len(potential_variables): + break + potential_variables_prev_count = len(potential_variables) + + # transform variables into their names and filter + # out anything form exclude_names + result = set() + for potential_variable in potential_variables: + var_name = potential_variable.get_name() + if var_name not in exclude_names: + result.add(var_name) + + return result + + @classmethod + def get_variable_names_to_ignore(cls): + return set(PredefinedVariables.get_variables().keys()).union({"v_comp"}) + + def get_synapse_specific_internal_declarations(self, synapse_inline: ASTInlineExpression) -> defaultdict: + synapse_variable_names = self.get_variable_names_of_synapse( + synapse_inline) + + # now match those variable names with + # variable declarations from the internals block + dereferenced = defaultdict() + for potential_internals_name in synapse_variable_names: + if potential_internals_name in self.internal_var_name_to_declaration: + dereferenced[potential_internals_name] = self.internal_var_name_to_declaration[potential_internals_name] + return dereferenced + + def get_synapse_specific_state_declarations(self, synapse_inline: ASTInlineExpression) -> defaultdict: + synapse_variable_names = self.get_variable_names_of_synapse( + synapse_inline) + + # now match those variable names with + # variable declarations from the state block + dereferenced = defaultdict() + for potential_state_name in synapse_variable_names: + if potential_state_name in self.state_name_to_declaration: + dereferenced[potential_state_name] = self.state_name_to_declaration[potential_state_name] + return dereferenced + + def get_synapse_specific_parameter_declarations(self, synapse_inline: ASTInlineExpression) -> defaultdict: + synapse_variable_names = self.get_variable_names_of_synapse( + synapse_inline) + + # now match those variable names with + # variable declarations from the parameter block + dereferenced = defaultdict() + for potential_param_name in synapse_variable_names: + if potential_param_name in self.parameter_name_to_declaration: + dereferenced[potential_param_name] = self.parameter_name_to_declaration[potential_param_name] + return dereferenced + + def get_extracted_kernel_args(self, inline_expression: ASTInlineExpression) -> set: + return 
self.inline_expression_to_kernel_args[inline_expression]
+
+    def get_basic_kernel_variable_names(self, synapse_inline):
+        """
+        For every occurrence of convolve(port, spikes), generate a "port__X__spikes" variable,
+        gather those variables for this synapse inline and return them as a list.
+
+        Note that those variables will occur as substrings in other kernel variables,
+        e.g. "port__X__spikes__d" or "__P__port__X__spikes__port__X__spikes", so we can
+        use the result to identify all the other kernel variables related to the
+        specific synapse inline declaration.
+        """
+        order = 0
+        results = []
+        for syn_inline, args in self.inline_expression_to_kernel_args.items():
+            if synapse_inline.variable_name == syn_inline.variable_name:
+                for kernel_var, spike_var in args:
+                    kernel_name = kernel_var.get_name()
+                    spike_input_port = self.input_port_name_to_input_port[spike_var.get_name()]
+                    kernel_variable_name = self.construct_kernel_X_spike_buf_name(
+                        kernel_name, spike_input_port, order)
+                    results.append(kernel_variable_name)
+
+        return results
+
+    def get_used_kernel_names(self, inline_expression: ASTInlineExpression):
+        return [kernel_var.get_name() for kernel_var, _ in self.get_extracted_kernel_args(inline_expression)]
+
+    def get_input_port_by_name(self, name):
+        return self.input_port_name_to_input_port[name]
+
+    def get_used_spike_names(self, inline_expression: ASTInlineExpression):
+        return [spikes_var.get_name() for _, spikes_var in self.get_extracted_kernel_args(inline_expression)]
+
+    def visit_kernel(self, node):
+        self.current_kernel = node
+        self.inside_kernel = True
+        if self.inside_equations_block:
+            kernel_name = node.get_variables()[0].get_name_of_lhs()
+            self.kernel_name_to_kernel[kernel_name] = node
+
+    def visit_function_call(self, node):
+        if self.inside_equations_block:
+            if self.inside_inline_expression and self.inside_simple_expression:
+                if node.get_name() == "convolve":
+                    self.inside_kernel_call = True
+                    kernel, spikes = node.get_args()
+                    kernel_var = kernel.get_variables()[0]
+                    spikes_var = spikes.get_variables()[0]
+                    if "mechanism::receptor" in [(e.namespace + "::" + e.name) for e in self.current_inline_expression.get_decorators()]:
+                        self.inline_expression_to_kernel_args[self.current_inline_expression].add(
+                            (kernel_var, spikes_var))
+                else:
+                    self.inline_expression_to_function_calls[self.current_inline_expression].add(
+                        node)
+        if self.inside_kernel and self.inside_simple_expression:
+            self.kernel_to_function_calls[self.current_kernel].add(node)
+
+    def endvisit_function_call(self, node):
+        self.inside_kernel_call = False
+
+    def endvisit_kernel(self, node):
+        self.current_kernel = None
+        self.inside_kernel = False
+
+    def visit_variable(self, node):
+        if self.inside_inline_expression and not self.inside_kernel_call:
+            self.inline_expression_to_variables[self.current_inline_expression].add(
+                node)
+        elif self.inside_kernel and (self.inside_expression or self.inside_simple_expression):
+            self.kernel_to_rhs_variables[self.current_kernel].add(node)
+        elif self.inside_declaration and self.inside_expression:
+            declared_variable = self.current_declaration.get_variables()[0].get_name()
+            self.declaration_to_rhs_variables[declared_variable].add(node)
+
+    def visit_inline_expression(self, node):
+        self.inside_inline_expression = True
+        self.current_inline_expression = node
+
+    def endvisit_inline_expression(self, node):
+        self.inside_inline_expression = False
+        self.current_inline_expression = None
+
+    def visit_equations_block(self, node):
+        self.inside_equations_block = True
+
+    
def endvisit_equations_block(self, node):
+        self.inside_equations_block = False
+
+    def visit_input_block(self, node):
+        self.inside_input_block = True
+
+    def visit_input_port(self, node):
+        self.input_port_name_to_input_port[node.get_name()] = node
+
+    def endvisit_input_block(self, node):
+        self.inside_input_block = False
+
+    def visit_block_with_variables(self, node):
+        if node.is_state:
+            self.inside_state_block = True
+        if node.is_parameters:
+            self.inside_parameter_block = True
+        if node.is_internals:
+            self.inside_internals_block = True
+
+    def endvisit_block_with_variables(self, node):
+        if node.is_state:
+            self.inside_state_block = False
+        if node.is_parameters:
+            self.inside_parameter_block = False
+        if node.is_internals:
+            self.inside_internals_block = False
+
+    def visit_simple_expression(self, node):
+        self.inside_simple_expression = True
+        self.current_simple_expression = node
+
+    def endvisit_simple_expression(self, node):
+        self.inside_simple_expression = False
+        self.current_simple_expression = None
+
+    def visit_declaration(self, node):
+        self.inside_declaration = True
+        self.current_declaration = node
+
+        # collect declarations generally
+        variable_name = node.get_variables()[0].get_name()
+        self.variable_name_to_declaration[variable_name] = node
+
+        # collect declarations per block
+        if self.inside_parameter_block:
+            self.parameter_name_to_declaration[variable_name] = node
+        elif self.inside_state_block:
+            self.state_name_to_declaration[variable_name] = node
+        elif self.inside_internals_block:
+            self.internal_var_name_to_declaration[variable_name] = node
+
+    def endvisit_declaration(self, node):
+        self.inside_declaration = False
+        self.current_declaration = None
+
+    def visit_expression(self, node):
+        self.inside_expression = True
+        self.current_expression = node
+
+    def endvisit_expression(self, node):
+        self.inside_expression = False
+        self.current_expression = None
+
+    # this method was copied over from ast_transformer
+    # in order to avoid a circular dependency
+    @staticmethod
+    def construct_kernel_X_spike_buf_name(kernel_var_name: str, spike_input_port, order: int, diff_order_symbol="__d"):
+        assert type(kernel_var_name) is str
+        assert type(order) is int
+        assert type(diff_order_symbol) is str
+        return kernel_var_name.replace("$", "__DOLLAR") + "__X__" + str(spike_input_port) + diff_order_symbol * order
diff --git a/pynestml/utils/ast_synapse_information_collector.py b/pynestml/utils/ast_synapse_information_collector.py
index f5a6763bc..a8e90656d 100644
--- a/pynestml/utils/ast_synapse_information_collector.py
+++ b/pynestml/utils/ast_synapse_information_collector.py
@@ -18,25 +18,345 @@
 #
 # You should have received a copy of the GNU General Public License
 # along with NEST. If not, see .
-
-from _collections import defaultdict
 import copy
+from collections import defaultdict
 
 from pynestml.meta_model.ast_inline_expression import ASTInlineExpression
 from pynestml.meta_model.ast_kernel import ASTKernel
+from pynestml.symbols.predefined_units import PredefinedUnits
 from pynestml.symbols.predefined_variables import PredefinedVariables
 from pynestml.visitors.ast_visitor import ASTVisitor
 
 
-class ASTSynapseInformationCollector(ASTVisitor):
+class ASTSynapseInformationCollector(object):
     """
-    for each inline expression inside the equations block,
-    collect all synapse relevant information
+    This file is part of the compartmental code generation process.
+
+    Additional parsing of ODE solutions and collection of updated information from the AST.
     
""" + collector_visitor = None + synapse = None + + @classmethod + def __init__(cls, synapse): + cls.synapse = synapse + cls.collector_visitor = ASTMechanismInformationCollectorVisitor() + synapse.accept(cls.collector_visitor) + + @classmethod + def collect_definitions(cls, synapse, syn_info): + # variables + var_collector_visitor = ASTVariableCollectorVisitor() + synapse.accept(var_collector_visitor) + syn_info["States"] = var_collector_visitor.all_states + syn_info["Parameters"] = var_collector_visitor.all_parameters + syn_info["Internals"] = var_collector_visitor.all_internals + + # ODEs + ode_collector_visitor = ASTODEEquationCollectorVisitor() + synapse.accept(ode_collector_visitor) + syn_info["ODEs"] = ode_collector_visitor.all_ode_equations + + # inlines + inline_collector_visitor = ASTInlineEquationCollectorVisitor() + synapse.accept(inline_collector_visitor) + syn_info["Inlines"] = inline_collector_visitor.all_inlines + + # functions + function_collector_visitor = ASTFunctionCollectorVisitor() + synapse.accept(function_collector_visitor) + syn_info["Functions"] = function_collector_visitor.all_functions + + return syn_info + + @classmethod + def collect_on_receive_blocks(cls, synapse, syn_info, pre_port, post_port): + pre_spike_collector_visitor = ASTOnReceiveBlockVisitor(pre_port) + synapse.accept(pre_spike_collector_visitor) + syn_info["PreSpikeFunction"] = pre_spike_collector_visitor.on_receive_block + + post_spike_collector_visitor = ASTOnReceiveBlockVisitor(post_port) + synapse.accept(post_spike_collector_visitor) + syn_info["PostSpikeFunction"] = post_spike_collector_visitor.on_receive_block + + return syn_info + + @classmethod + def collect_update_block(cls, synapse, syn_info): + update_block_collector_visitor = ASTUpdateBlockVisitor() + synapse.accept(update_block_collector_visitor) + syn_info["UpdateBlock"] = update_block_collector_visitor.update_block + return syn_info + + @classmethod + def collect_ports(cls, synapse, syn_info): + port_collector_visitor = ASTPortVisitor() + synapse.accept(port_collector_visitor) + syn_info["SpikingPorts"] = port_collector_visitor.spiking_ports + syn_info["ContinuousPorts"] = port_collector_visitor.continuous_ports + return syn_info + + @classmethod + def collect_potential_dependencies(cls, synapse, syn_info): + non_dec_asmt_visitor = ASTNonDeclaringAssignmentVisitor() + synapse.accept(non_dec_asmt_visitor) + + potential_dependencies = copy.deepcopy(syn_info["States"]) + for state in syn_info["States"]: + for assignment in non_dec_asmt_visitor.non_declaring_assignments: + if state == assignment.get_variable().get_name(): + if state in potential_dependencies: + del potential_dependencies[state] + + syn_info["PotentialDependencies"] = potential_dependencies + + return syn_info + + @classmethod + def extend_variables_with_initialisations(cls, synapse, syn_info): + """collects initialization expressions for all variables and parameters contained in syn_info""" + var_init_visitor = VariableInitializationVisitor(syn_info) + synapse.accept(var_init_visitor) + syn_info["States"] = var_init_visitor.states + syn_info["Parameters"] = var_init_visitor.parameters + syn_info["Internals"] = var_init_visitor.internals + + return syn_info + + @classmethod + def extend_variable_list_name_based_restricted(cls, extended_list, appending_list, restrictor_list): + """go through appending_list and append every variable that is not in restrictor_list to extended_list for the + purpose of not re-searching the same variable""" + for app_item in appending_list: + 
appendable = True
+            for rest_item in restrictor_list:
+                if rest_item.name == app_item.name:
+                    appendable = False
+                    break
+            if appendable:
+                extended_list.append(app_item)
+
+        return extended_list
+
+    @classmethod
+    def extend_function_call_list_name_based_restricted(cls, extended_list, appending_list, restrictor_list):
+        """Go through appending_list and append every function call that is not in restrictor_list to extended_list,
+        for the purpose of not re-searching the same function."""
+        for app_item in appending_list:
+            appendable = True
+            for rest_item in restrictor_list:
+                if rest_item.callee_name == app_item.callee_name:
+                    appendable = False
+                    break
+            if appendable:
+                extended_list.append(app_item)
+
+        return extended_list
+
+    @classmethod
+    def collect_mechanism_related_definitions(cls, neuron, syn_info):
+        """Collects all parts of the NESTML code that the previously collected root expressions depend on.
+        The search is cut off at the root expressions of other mechanisms."""
+        from pynestml.meta_model.ast_inline_expression import ASTInlineExpression
+        from pynestml.meta_model.ast_ode_equation import ASTOdeEquation
+
+        for mechanism_name, mechanism_info in syn_info.items():
+            variable_collector = ASTVariableCollectorVisitor()
+            neuron.accept(variable_collector)
+            global_states = variable_collector.all_states
+            global_parameters = variable_collector.all_parameters
+            global_internals = variable_collector.all_internals
+
+            function_collector = ASTFunctionCollectorVisitor()
+            neuron.accept(function_collector)
+            global_functions = function_collector.all_functions
+
+            inline_collector = ASTInlineEquationCollectorVisitor()
+            neuron.accept(inline_collector)
+            global_inlines = inline_collector.all_inlines
+
+            ode_collector = ASTODEEquationCollectorVisitor()
+            neuron.accept(ode_collector)
+            global_odes = ode_collector.all_ode_equations
+
+            kernel_collector = ASTKernelCollectorVisitor()
+            neuron.accept(kernel_collector)
+            global_kernels = kernel_collector.all_kernels
+
+            continuous_input_collector = ASTContinuousInputDeclarationVisitor()
+            neuron.accept(continuous_input_collector)
+            global_continuous_inputs = continuous_input_collector.ports
+
+            mechanism_states = list()
+            mechanism_parameters = list()
+            mechanism_internals = list()
+            mechanism_functions = list()
+            mechanism_inlines = list()
+            mechanism_odes = list()
+            synapse_kernels = list()
+            mechanism_continuous_inputs = list()
+            mechanism_dependencies = defaultdict()
+            mechanism_dependencies["concentrations"] = list()
+            mechanism_dependencies["channels"] = list()
+            mechanism_dependencies["receptors"] = list()
+            mechanism_dependencies["continuous"] = list()
+
+            mechanism_inlines.append(syn_info[mechanism_name]["root_expression"])
+
+            search_variables = list()
+            search_functions = list()
+
+            found_variables = list()
+            found_functions = list()
+
+            local_variable_collector = ASTVariableCollectorVisitor()
+            mechanism_inlines[0].accept(local_variable_collector)
+            search_variables = local_variable_collector.all_variables
+
+            local_function_call_collector = ASTFunctionCallCollectorVisitor()
+            mechanism_inlines[0].accept(local_function_call_collector)
+            search_functions = local_function_call_collector.all_function_calls
+
+            while len(search_functions) > 0 or len(search_variables) > 0:
+                if len(search_functions) > 0:
+                    function_call = search_functions[0]
+                    for function in global_functions:
+                        if function.name == function_call.callee_name:
+                            mechanism_functions.append(function)
+                            found_functions.append(function_call)
+
+                            local_variable_collector = ASTVariableCollectorVisitor()
+                            
function.accept(local_variable_collector) + search_variables = cls.extend_variable_list_name_based_restricted(search_variables, + local_variable_collector.all_variables, + search_variables + found_variables) + + local_function_call_collector = ASTFunctionCallCollectorVisitor() + function.accept(local_function_call_collector) + search_functions = cls.extend_function_call_list_name_based_restricted(search_functions, + local_function_call_collector.all_function_calls, + search_functions + found_functions) + # IMPLEMENT CATCH NONDEFINED!!! + search_functions.remove(function_call) + + elif len(search_variables) > 0: + variable = search_variables[0] + if not variable.name == "v_comp": + is_dependency = False + for inline in global_inlines: + if variable.name == inline.variable_name: + if isinstance(inline.get_decorators(), list): + if "mechanism" in [e.namespace for e in inline.get_decorators()]: + is_dependency = True + if not (isinstance(mechanism_info["root_expression"], + ASTInlineExpression) and inline.variable_name == mechanism_info["root_expression"].variable_name): + if "channel" in [e.name for e in inline.get_decorators()]: + if not inline.variable_name in [i.variable_name for i in + mechanism_dependencies["channels"]]: + mechanism_dependencies["channels"].append(inline) + if "receptor" in [e.name for e in inline.get_decorators()]: + if not inline.variable_name in [i.variable_name for i in + mechanism_dependencies["receptors"]]: + mechanism_dependencies["receptors"].append(inline) + if "continuous" in [e.name for e in inline.get_decorators()]: + if not inline.variable_name in [i.variable_name for i in + mechanism_dependencies["continuous"]]: + mechanism_dependencies["continuous"].append(inline) + + if not is_dependency: + mechanism_inlines.append(inline) + + local_variable_collector = ASTVariableCollectorVisitor() + inline.accept(local_variable_collector) + search_variables = cls.extend_variable_list_name_based_restricted(search_variables, + local_variable_collector.all_variables, + search_variables + found_variables) + + local_function_call_collector = ASTFunctionCallCollectorVisitor() + inline.accept(local_function_call_collector) + search_functions = cls.extend_function_call_list_name_based_restricted( + search_functions, + local_function_call_collector.all_function_calls, + search_functions + found_functions) + + for ode in global_odes: + if variable.name == ode.lhs.name: + if isinstance(ode.get_decorators(), list): + if "mechanism" in [e.namespace for e in ode.get_decorators()]: + is_dependency = True + if not (isinstance(mechanism_info["root_expression"], + ASTOdeEquation) and ode.lhs.name == mechanism_info["root_expression"].lhs.name): + if "concentration" in [e.name for e in ode.get_decorators()]: + if not ode.lhs.name in [o.lhs.name for o in + mechanism_dependencies["concentrations"]]: + mechanism_dependencies["concentrations"].append(ode) + + if not is_dependency: + mechanism_odes.append(ode) + + local_variable_collector = ASTVariableCollectorVisitor() + ode.accept(local_variable_collector) + search_variables = cls.extend_variable_list_name_based_restricted(search_variables, + local_variable_collector.all_variables, + search_variables + found_variables) + + local_function_call_collector = ASTFunctionCallCollectorVisitor() + ode.accept(local_function_call_collector) + search_functions = cls.extend_function_call_list_name_based_restricted( + search_functions, + local_function_call_collector.all_function_calls, + search_functions + found_functions) + + for state in global_states: 
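+                            # only adopt a state into this mechanism when it is not
+                            # another mechanism's root expression (is_dependency)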
+ if variable.name == state.name and not is_dependency: + mechanism_states.append(state) + + for parameter in global_parameters: + if variable.name == parameter.name: + mechanism_parameters.append(parameter) + + for internal in global_internals: + if variable.name == internal.name: + mechanism_internals.append(internal) + + for kernel in global_kernels: + if variable.name == kernel.get_variables()[0].name: + synapse_kernels.append(kernel) + + local_variable_collector = ASTVariableCollectorVisitor() + kernel.accept(local_variable_collector) + search_variables = cls.extend_variable_list_name_based_restricted(search_variables, + local_variable_collector.all_variables, + search_variables + found_variables) + + local_function_call_collector = ASTFunctionCallCollectorVisitor() + kernel.accept(local_function_call_collector) + search_functions = cls.extend_function_call_list_name_based_restricted(search_functions, + local_function_call_collector.all_function_calls, + search_functions + found_functions) + + for input in global_continuous_inputs: + if variable.name == input.name: + mechanism_continuous_inputs.append(input) + + search_variables.remove(variable) + found_variables.append(variable) + + syn_info[mechanism_name]["States"] = mechanism_states + syn_info[mechanism_name]["Parameters"] = mechanism_parameters + syn_info[mechanism_name]["Internals"] = mechanism_internals + syn_info[mechanism_name]["Functions"] = mechanism_functions + syn_info[mechanism_name]["SecondaryInlineExpressions"] = mechanism_inlines + syn_info[mechanism_name]["ODEs"] = mechanism_odes + syn_info[mechanism_name]["Continuous"] = mechanism_continuous_inputs + syn_info[mechanism_name]["Dependencies"] = mechanism_dependencies + + return syn_info + + +class ASTKernelInformationCollectorVisitor(ASTVisitor): def __init__(self): - super(ASTSynapseInformationCollector, self).__init__() + super(ASTKernelInformationCollectorVisitor, self).__init__() # various dicts to store collected information self.kernel_name_to_kernel = defaultdict() @@ -62,17 +382,14 @@ def __init__(self): self.inside_kernel = False self.inside_kernel_call = False self.inside_declaration = False - # self.inside_variable = False self.inside_simple_expression = False self.inside_expression = False - # self.inside_function_call = False self.current_inline_expression = None self.current_kernel = None self.current_expression = None self.current_simple_expression = None self.current_declaration = None - # self.current_variable = None self.current_synapse_name = None @@ -94,7 +411,8 @@ def get_kernel_function_calls(self, kernel: ASTKernel): def get_inline_function_calls(self, inline: ASTInlineExpression): return self.inline_expression_to_function_calls[inline] - def get_variable_names_of_synapse(self, synapse_inline: ASTInlineExpression, exclude_names: set = set(), exclude_ignorable=True) -> set: + def get_variable_names_of_synapse(self, synapse_inline: ASTInlineExpression, exclude_names: set = set(), + exclude_ignorable=True) -> set: """extracts all variables specific to a single synapse (which is defined by the inline expression containing kernels) independently of what block they are declared in @@ -189,6 +507,12 @@ def get_synapse_specific_parameter_declarations(self, synapse_inline: ASTInlineE def get_extracted_kernel_args(self, inline_expression: ASTInlineExpression) -> set: return self.inline_expression_to_kernel_args[inline_expression] + def get_extracted_kernel_args_by_name(self, inline_name: str) -> set: + inline_expression = [inline for inline in 
self.inline_expression_to_kernel_args.keys() if + inline.get_variable_name() == inline_name] + + return self.inline_expression_to_kernel_args[inline_expression[0]] + def get_basic_kernel_variable_names(self, synapse_inline): """ for every occurence of convolve(port, spikes) generate "port__X__spikes" variable @@ -237,9 +561,8 @@ def visit_function_call(self, node): kernel, spikes = node.get_args() kernel_var = kernel.get_variables()[0] spikes_var = spikes.get_variables()[0] - if "mechanism::receptor" in [(e.namespace + "::" + e.name) for e in self.current_inline_expression.get_decorators()]: - self.inline_expression_to_kernel_args[self.current_inline_expression].add( - (kernel_var, spikes_var)) + self.inline_expression_to_kernel_args[self.current_inline_expression].add( + (kernel_var, spikes_var)) else: self.inline_expression_to_function_calls[self.current_inline_expression].add( node) @@ -342,8 +665,306 @@ def endvisit_expression(self, node): # this method was copied over from ast_transformer # in order to avoid a circular dependency @staticmethod - def construct_kernel_X_spike_buf_name(kernel_var_name: str, spike_input_port, order: int, diff_order_symbol="__d"): + def construct_kernel_X_spike_buf_name(kernel_var_name: str, spike_input_port, order: int, + diff_order_symbol="__d"): assert type(kernel_var_name) is str assert type(order) is int assert type(diff_order_symbol) is str - return kernel_var_name.replace("$", "__DOLLAR") + "__X__" + str(spike_input_port) + diff_order_symbol * order + return kernel_var_name.replace("$", "__DOLLAR") + "__X__" + str( + spike_input_port) + diff_order_symbol * order + + +class ASTMechanismInformationCollectorVisitor(ASTVisitor): + + def __init__(self): + super(ASTMechanismInformationCollectorVisitor, self).__init__() + self.inEquationsBlock = False + self.inlinesInEquationsBlock = list() + self.odes = list() + + def visit_equations_block(self, node): + self.inEquationsBlock = True + + def endvisit_equations_block(self, node): + self.inEquationsBlock = False + + def visit_inline_expression(self, node): + if self.inEquationsBlock: + self.inlinesInEquationsBlock.append(node) + + def visit_ode_equation(self, node): + self.odes.append(node) + + +# Helper collectors: +class VariableInitializationVisitor(ASTVisitor): + def __init__(self, channel_info): + super(VariableInitializationVisitor, self).__init__() + self.inside_variable = False + self.inside_declaration = False + self.inside_parameter_block = False + self.inside_state_block = False + self.inside_internal_block = False + self.current_declaration = None + self.states = defaultdict() + self.parameters = defaultdict() + self.internals = defaultdict() + self.channel_info = channel_info + + def visit_declaration(self, node): + self.inside_declaration = True + self.current_declaration = node + + def endvisit_declaration(self, node): + self.inside_declaration = False + self.current_declaration = None + + def visit_block_with_variables(self, node): + if node.is_state: + self.inside_state_block = True + if node.is_parameters: + self.inside_parameter_block = True + if node.is_internals: + self.inside_internal_block = True + + def endvisit_block_with_variables(self, node): + self.inside_state_block = False + self.inside_parameter_block = False + self.inside_internal_block = False + + def visit_variable(self, node): + self.inside_variable = True + if self.inside_state_block and self.inside_declaration: + if any(node.name == variable.name for variable in self.channel_info["States"]): + self.states[node.name] = 
defaultdict() + self.states[node.name]["ASTVariable"] = node.clone() + self.states[node.name]["rhs_expression"] = self.current_declaration.get_expression() + + if self.inside_parameter_block and self.inside_declaration: + if any(node.name == variable.name for variable in self.channel_info["Parameters"]): + self.parameters[node.name] = defaultdict() + self.parameters[node.name]["ASTVariable"] = node.clone() + self.parameters[node.name]["rhs_expression"] = self.current_declaration.get_expression() + + if self.inside_internal_block and self.inside_declaration: + if any(node.name == variable.name for variable in self.channel_info["Internals"]): + self.internals[node.name] = defaultdict() + self.internals[node.name]["ASTVariable"] = node.clone() + self.internals[node.name]["rhs_expression"] = self.current_declaration.get_expression() + + def endvisit_variable(self, node): + self.inside_variable = False + + +class ASTODEEquationCollectorVisitor(ASTVisitor): + def __init__(self): + super(ASTODEEquationCollectorVisitor, self).__init__() + self.inside_ode_expression = False + self.all_ode_equations = list() + + def visit_ode_equation(self, node): + self.inside_ode_expression = True + self.all_ode_equations.append(node.clone()) + + def endvisit_ode_equation(self, node): + self.inside_ode_expression = False + + +class ASTVariableCollectorVisitor(ASTVisitor): + def __init__(self): + super(ASTVariableCollectorVisitor, self).__init__() + self.inside_variable = False + self.inside_block_with_variables = False + self.all_states = list() + self.all_parameters = list() + self.all_internals = list() + self.inside_states_block = False + self.inside_parameters_block = False + self.inside_internals_block = False + self.all_variables = list() + + def visit_block_with_variables(self, node): + self.inside_block_with_variables = True + if node.is_state: + self.inside_states_block = True + if node.is_parameters: + self.inside_parameters_block = True + if node.is_internals: + self.inside_internals_block = True + + def endvisit_block_with_variables(self, node): + self.inside_states_block = False + self.inside_parameters_block = False + self.inside_internals_block = False + self.inside_block_with_variables = False + + def visit_variable(self, node): + self.inside_variable = True + if not (node.name == "v_comp" or node.name in PredefinedUnits.get_units()): + self.all_variables.append(node.clone()) + if self.inside_states_block: + self.all_states.append(node.clone()) + if self.inside_parameters_block: + self.all_parameters.append(node.clone()) + if self.inside_internals_block: + self.all_internals.append(node.clone()) + + def endvisit_variable(self, node): + self.inside_variable = False + + +class ASTFunctionCollectorVisitor(ASTVisitor): + def __init__(self): + super(ASTFunctionCollectorVisitor, self).__init__() + self.inside_function = False + self.all_functions = list() + + def visit_function(self, node): + self.inside_function = True + self.all_functions.append(node.clone()) + + def endvisit_function(self, node): + self.inside_function = False + + +class ASTInlineEquationCollectorVisitor(ASTVisitor): + def __init__(self): + super(ASTInlineEquationCollectorVisitor, self).__init__() + self.inside_inline_expression = False + self.all_inlines = list() + + def visit_inline_expression(self, node): + self.inside_inline_expression = True + self.all_inlines.append(node.clone()) + + def endvisit_inline_expression(self, node): + self.inside_inline_expression = False + + +class ASTFunctionCallCollectorVisitor(ASTVisitor): + def 
__init__(self):
+        super(ASTFunctionCallCollectorVisitor, self).__init__()
+        self.inside_function_call = False
+        self.all_function_calls = list()
+
+    def visit_function_call(self, node):
+        self.inside_function_call = True
+        self.all_function_calls.append(node.clone())
+
+    def endvisit_function_call(self, node):
+        self.inside_function_call = False
+
+
+class ASTKernelCollectorVisitor(ASTVisitor):
+    def __init__(self):
+        super(ASTKernelCollectorVisitor, self).__init__()
+        self.inside_kernel = False
+        self.all_kernels = list()
+
+    def visit_kernel(self, node):
+        self.inside_kernel = True
+        self.all_kernels.append(node.clone())
+
+    def endvisit_kernel(self, node):
+        self.inside_kernel = False
+
+
+class ASTContinuousInputDeclarationVisitor(ASTVisitor):
+    def __init__(self):
+        super(ASTContinuousInputDeclarationVisitor, self).__init__()
+        self.inside_port = False
+        self.current_port = None
+        self.ports = list()
+
+    def visit_input_port(self, node):
+        self.inside_port = True
+        self.current_port = node
+        if self.current_port.is_continuous():
+            self.ports.append(node.clone())
+
+    def endvisit_input_port(self, node):
+        self.inside_port = False
+
+
+class ASTOnReceiveBlockVisitor(ASTVisitor):
+    def __init__(self, port_name):
+        super(ASTOnReceiveBlockVisitor, self).__init__()
+        self.inside_on_receive = False
+        self.port_name = port_name
+        self.on_receive_block = None
+
+    def visit_on_receive_block(self, node):
+        self.inside_on_receive = True
+        if node.port_name in self.port_name:
+            self.on_receive_block = node.clone()
+
+    def endvisit_on_receive_block(self, node):
+        self.inside_on_receive = False
+
+
+class ASTUpdateBlockVisitor(ASTVisitor):
+    def __init__(self):
+        super(ASTUpdateBlockVisitor, self).__init__()
+        self.inside_update_block = False
+        self.update_block = None
+
+    def visit_update_block(self, node):
+        self.inside_update_block = True
+        self.update_block = node.clone()
+
+    def endvisit_update_block(self, node):
+        self.inside_update_block = False
+
+
+class ASTPortVisitor(ASTVisitor):
+    def __init__(self):
+        super(ASTPortVisitor, self).__init__()
+        self.inside_port = False
+        self.spiking_ports = list()
+        self.continuous_ports = list()
+
+    def visit_input_port(self, node):
+        self.inside_port = True
+        if node.is_spike():
+            self.spiking_ports.append(node.clone())
+        if node.is_continuous():
+            self.continuous_ports.append(node.clone())
+
+    def endvisit_input_port(self, node):
+        self.inside_port = False
+
+
+class ASTNonDeclaringAssignmentVisitor(ASTVisitor):
+    def __init__(self):
+        super(ASTNonDeclaringAssignmentVisitor, self).__init__()
+        self.inside_states_block = False
+        self.inside_parameters_block = False
+        self.inside_internals_block = False
+        self.inside_assignment = False
+        self.non_declaring_assignments = list()
+
+    def visit_states_block(self, node):
+        self.inside_states_block = True
+
+    def endvisit_states_block(self, node):
+        self.inside_states_block = False
+
+    def visit_parameters_block(self, node):
+        self.inside_parameters_block = True
+
+    def endvisit_parameters_block(self, node):
+        self.inside_parameters_block = False
+
+    def visit_internals_block(self, node):
+        self.inside_internals_block = True
+
+    def endvisit_internals_block(self, node):
+        self.inside_internals_block = False
+
+    def visit_assignment(self, node):
+        self.inside_assignment = True
+        # collect every assignment that is not a declaration, i.e. anything assigned
+        # outside the parameters and internals blocks (state-block assignments count)
+        if self.inside_states_block or not (self.inside_parameters_block or self.inside_internals_block):
+            self.non_declaring_assignments.append(node.clone())
+
+    def endvisit_assignment(self, node):
+        self.inside_assignment = False
diff 
--git a/pynestml/utils/ast_vector_parameter_setter_and_printer.py b/pynestml/utils/ast_vector_parameter_setter_and_printer.py
index da0ee5076..986abdbc3 100644
--- a/pynestml/utils/ast_vector_parameter_setter_and_printer.py
+++ b/pynestml/utils/ast_vector_parameter_setter_and_printer.py
@@ -24,6 +24,11 @@
 
 
 class ASTVectorParameterSetterAndPrinter(ASTPrinter):
+    """
+    This file is part of the compartmental code generation process.
+
+    Part of vectorized printing.
+    """
     def __init__(self):
         super(ASTVectorParameterSetterAndPrinter, self).__init__()
         self.inside_variable = False
diff --git a/pynestml/utils/ast_vector_parameter_setter_and_printer_factory.py b/pynestml/utils/ast_vector_parameter_setter_and_printer_factory.py
index 13c3b08d5..451a63376 100644
--- a/pynestml/utils/ast_vector_parameter_setter_and_printer_factory.py
+++ b/pynestml/utils/ast_vector_parameter_setter_and_printer_factory.py
@@ -29,7 +29,11 @@
 
 
 class ASTVectorParameterSetterAndPrinterFactory:
+    """
+    This file is part of the compartmental code generation process.
 
+    Part of vectorized printing.
+    """
     def __init__(self, model, printer):
         self.printer = printer
         self.model = model
diff --git a/pynestml/utils/chan_info_enricher.py b/pynestml/utils/chan_info_enricher.py
index 61ff60f6f..b7609d475 100644
--- a/pynestml/utils/chan_info_enricher.py
+++ b/pynestml/utils/chan_info_enricher.py
@@ -28,6 +28,8 @@
 
 class ChanInfoEnricher(MechsInfoEnricher):
     """
+    This file is part of the compartmental code generation process.
+
     Class extends MechsInfoEnricher by the computation of the inline derivative. This hasn't been done in the
     channel processing because it would cause a circular dependency through the coco checks used by the ModelParser
     which we need to use.
diff --git a/pynestml/utils/channel_processing.py b/pynestml/utils/channel_processing.py
index b35822b87..dc7b2a671 100644
--- a/pynestml/utils/channel_processing.py
+++ b/pynestml/utils/channel_processing.py
@@ -26,8 +26,12 @@
 
 
 class ChannelProcessing(MechanismProcessing):
-    """Extends MechanismProcessing. Searches for Variables that if 0 lead to the root expression always beeing zero so
-    that the computation can be skipped during the simulation"""
+    """
+    This file is part of the compartmental code generation process.
+
+    Extends MechanismProcessing. Searches for variables that, if zero, lead to the root expression always being zero,
+    so that the computation can be skipped during the simulation.
+    """
 
     mechType = "channel"
 
@@ -37,6 +41,7 @@ def __init__(self, params):
     @classmethod
     def collect_information_for_specific_mech_types(cls, neuron, mechs_info):
         mechs_info = cls.write_key_zero_parameters_for_root_inlines(mechs_info)
+        cls.check_all_convolutions_with_self_spikes(mechs_info)
 
         return mechs_info
diff --git a/pynestml/utils/con_in_info_enricher.py b/pynestml/utils/con_in_info_enricher.py
index fa8228ac8..582881d9d 100644
--- a/pynestml/utils/con_in_info_enricher.py
+++ b/pynestml/utils/con_in_info_enricher.py
@@ -27,9 +27,13 @@
 
 
 class ConInInfoEnricher(MechsInfoEnricher):
-    """Class extends MechsInfoEnricher by the computation of the inline derivative. This hasn't been done in the
+    """
+    This file is part of the compartmental code generation process.
+
+    Class extends MechsInfoEnricher by the computation of the inline derivative. This hasn't been done in the
     channel processing because it would cause a circular dependency through the coco checks used by the ModelParser
-    which we need to use."""
+    which we need to use.
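+
+    See ChanInfoEnricher for the channel-side counterpart of this enrichment.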
+    """
 
     def __init__(self, params):
         super(MechsInfoEnricher, self).__init__(params)
diff --git a/pynestml/utils/conc_info_enricher.py b/pynestml/utils/conc_info_enricher.py
index f62d4989e..51bc98014 100644
--- a/pynestml/utils/conc_info_enricher.py
+++ b/pynestml/utils/conc_info_enricher.py
@@ -18,12 +18,46 @@
 #
 # You should have received a copy of the GNU General Public License
 # along with NEST. If not, see .
+from collections import defaultdict
 
 from pynestml.utils.mechs_info_enricher import MechsInfoEnricher
 
 
 class ConcInfoEnricher(MechsInfoEnricher):
-    """Just created for consistency with the rest of the mechanism generation process. No more than the base-class
-    enriching needs to be done"""
+    """
+    This file is part of the compartmental code generation process.
+
+    Class extends MechsInfoEnricher by ode-toolbox processing of the root ODE of the concentration mechanism.
+    This hasn't been done in the concentration processing because it would cause a circular dependency through
+    the coco checks used by the ModelParser which we need to use.
+    """
 
     def __init__(self, params):
         super(MechsInfoEnricher, self).__init__(params)
+
+    @classmethod
+    def enrich_mechanism_specific(cls, neuron, mechs_info):
+        mechs_info = cls.ode_toolbox_processing_for_root_expression(neuron, mechs_info)
+        return mechs_info
+
+    @classmethod
+    def ode_toolbox_processing_for_root_expression(cls, neuron, conc_info):
+        """Applies the ode-toolbox processing to the root_expression, since that was never appended to the list of
+        ODEs in the base processing and thereby also never went through the ode-toolbox processing."""
+        for concentration_name, concentration_info in conc_info.items():
+            # Create fake mechs_info such that it can be processed by the existing ode_toolbox_processing function.
+            fake_conc_info = defaultdict()
+            fake_concentration_info = defaultdict()
+            fake_concentration_info["ODEs"] = list()
+            fake_concentration_info["ODEs"].append(concentration_info["root_expression"])
+            fake_conc_info["fake"] = fake_concentration_info
+
+            fake_conc_info = cls.get_transformed_ode_equations(fake_conc_info)
+            fake_conc_info = cls.ode_toolbox_processing(neuron, fake_conc_info)
+            cls.add_propagators_to_internals(neuron, fake_conc_info)
+            fake_conc_info = cls.transform_ode_solutions(neuron, fake_conc_info)
+
+            conc_info[concentration_name]["ODEs"] = {**conc_info[concentration_name]["ODEs"], **fake_conc_info["fake"]["ODEs"]}
+            if "time_resolution_var" in fake_conc_info["fake"]:
+                conc_info[concentration_name]["time_resolution_var"] = fake_conc_info["fake"]["time_resolution_var"]
+
+        return conc_info
diff --git a/pynestml/utils/concentration_processing.py b/pynestml/utils/concentration_processing.py
index 7671a992b..f51c91957 100644
--- a/pynestml/utils/concentration_processing.py
+++ b/pynestml/utils/concentration_processing.py
@@ -27,8 +27,12 @@
 
 
 class ConcentrationProcessing(MechanismProcessing):
-    """The default Processing ignores the root expression when solving the odes which in case of the concentration
-    mechanism is a ode that needs to be solved. This is added here."""
+    """
+    This file is part of the compartmental code generation process.
+
+    The default Processing ignores the root expression when solving the ODEs, which in the case of the concentration
+    mechanism is an ODE that needs to be solved. This is added here.
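+
+    For example, a concentration mechanism is typically declared through an ODE carrying the
+    @mechanism::concentration decorator (e.g. c_Ca' = ... in the concmech example model), so
+    its root expression itself has to be solved rather than skipped.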
+    """
 
     mechType = "concentration"
 
     def __init__(self, params):
@@ -36,8 +40,8 @@
 
     @classmethod
     def collect_information_for_specific_mech_types(cls, neuron, mechs_info):
-        mechs_info = cls.ode_toolbox_processing_for_root_expression(neuron, mechs_info)
         mechs_info = cls.write_key_zero_parameters_for_root_odes(mechs_info)
+        cls.check_all_convolutions_with_self_spikes(mechs_info)
 
         return mechs_info
diff --git a/pynestml/utils/continuous_input_processing.py b/pynestml/utils/continuous_input_processing.py
index 980217fc1..f081b2df0 100644
--- a/pynestml/utils/continuous_input_processing.py
+++ b/pynestml/utils/continuous_input_processing.py
@@ -27,6 +27,9 @@
 
 
 class ContinuousInputProcessing(MechanismProcessing):
+    """
+    This file is part of the compartmental code generation process.
+    """
     mechType = "continuous_input"
 
     def __init__(self, params):
@@ -40,4 +43,6 @@
         for port in continuous_info["Continuous"]:
             continuous[port.name] = copy.deepcopy(port)
         mechs_info[continuous_name]["Continuous"] = continuous
 
+        cls.check_all_convolutions_with_self_spikes(mechs_info)
+
         return mechs_info
diff --git a/pynestml/utils/global_info_enricher.py b/pynestml/utils/global_info_enricher.py
new file mode 100644
index 000000000..9807dbed2
--- /dev/null
+++ b/pynestml/utils/global_info_enricher.py
@@ -0,0 +1,262 @@
+# -*- coding: utf-8 -*-
+#
+# global_info_enricher.py
+#
+# This file is part of NEST.
+#
+# Copyright (C) 2004 The NEST Initiative
+#
+# NEST is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 2 of the License, or
+# (at your option) any later version.
+#
+# NEST is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with NEST. If not, see .
+
+from pynestml.meta_model.ast_model import ASTModel
+from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
+from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
+from pynestml.utils.ast_utils import ASTUtils
+from pynestml.visitors.ast_visitor import ASTVisitor
+from pynestml.utils.model_parser import ModelParser
+from pynestml.symbols.predefined_functions import PredefinedFunctions
+from pynestml.symbols.symbol import SymbolKind
+
+from collections import defaultdict
+
+
+class GlobalInfoEnricher:
+    """
+    This file is part of the compartmental code generation process.
+
+    Adds information collection that can't be done in the processing class since that is used in the cocos.
+    Here we use the ModelParser which would lead to a cyclic dependency.
+
+    Additionally, we require information about the paired neuron's mechanisms to confirm which dependencies actually exist in the neuron.
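+
+    The entry point is enrich_with_additional_info, which first rewrites the ode-toolbox solutions
+    into AST expressions (transform_ode_solutions) and then collects the declarations made inside
+    the update block and the self-spike function (extract_infunction_declarations).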
+ """ + + @classmethod + def enrich_with_additional_info(cls, neuron: ASTModel, global_info: dict): + global_info = cls.transform_ode_solutions(neuron, global_info) + global_info = cls.extract_infunction_declarations(global_info) + + return global_info + + @classmethod + def transform_ode_solutions(cls, neuron, global_info): + for ode_var_name, ode_info in global_info["ODEs"].items(): + global_info["ODEs"][ode_var_name]["transformed_solutions"] = list() + + for ode_solution_index in range(len(ode_info["ode_toolbox_output"])): + solution_transformed = defaultdict() + solution_transformed["states"] = defaultdict() + solution_transformed["propagators"] = defaultdict() + + for variable_name, rhs_str in ode_info["ode_toolbox_output"][ode_solution_index]["initial_values"].items(): + variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol(variable_name, + SymbolKind.VARIABLE) + + expression = ModelParser.parse_expression(rhs_str) + # pretend that update expressions are in "equations" block, + # which should always be present, as neurons have been + # defined to get here + expression.update_scope(neuron.get_equations_blocks()[0].get_scope()) + expression.accept(ASTSymbolTableVisitor()) + + update_expr_str = ode_info["ode_toolbox_output"][ode_solution_index]["update_expressions"][ + variable_name] + update_expr_ast = ModelParser.parse_expression( + update_expr_str) + # pretend that update expressions are in "equations" block, + # which should always be present, as differential equations + # must have been defined to get here + update_expr_ast.update_scope( + neuron.get_equations_blocks()[0].get_scope()) + update_expr_ast.accept(ASTSymbolTableVisitor()) + + solution_transformed["states"][variable_name] = { + "ASTVariable": variable, + "init_expression": expression, + "update_expression": update_expr_ast, + } + for variable_name, rhs_str in ode_info["ode_toolbox_output"][ode_solution_index]["propagators"].items(): + prop_variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol(variable_name, + SymbolKind.VARIABLE) + if prop_variable is None: + ASTUtils.add_declarations_to_internals( + neuron, ode_info["ode_toolbox_output"][ode_solution_index]["propagators"]) + prop_variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol( + variable_name, + SymbolKind.VARIABLE) + + expression = ModelParser.parse_expression(rhs_str) + # pretend that update expressions are in "equations" block, + # which should always be present, as neurons have been + # defined to get here + expression.update_scope( + neuron.get_equations_blocks()[0].get_scope()) + expression.accept(ASTSymbolTableVisitor()) + + solution_transformed["propagators"][variable_name] = { + "ASTVariable": prop_variable, "init_expression": expression, } + expression_variable_collector = ASTEnricherInfoCollectorVisitor() + expression.accept(expression_variable_collector) + + neuron_internal_declaration_collector = ASTEnricherInfoCollectorVisitor() + neuron.accept(neuron_internal_declaration_collector) + + for variable in expression_variable_collector.all_variables: + for internal_declaration in neuron_internal_declaration_collector.internal_declarations: + if variable.get_name() == internal_declaration.get_variables()[0].get_name() \ + and internal_declaration.get_expression().is_function_call() \ + and internal_declaration.get_expression().get_function_call().callee_name == \ + PredefinedFunctions.TIME_RESOLUTION: + global_info["time_resolution_var"] = variable + + 
global_info["ODEs"][ode_var_name]["transformed_solutions"].append(solution_transformed) + + neuron.accept(ASTParentVisitor()) + + return global_info + + @classmethod + def extract_infunction_declarations(cls, global_info): + declaration_visitor = ASTDeclarationCollectorAndUniqueRenamerVisitor() + if "SelfSpikesFunction" in global_info and global_info["SelfSpikesFunction"] is not None: + self_spike_function = global_info["SelfSpikesFunction"] + self_spike_function.accept(declaration_visitor) + if "UpdateBlock" in global_info and global_info["UpdateBlock"] is not None: + update_block = global_info["UpdateBlock"] + update_block.accept(declaration_visitor) + + declaration_vars = list() + for decl in declaration_visitor.declarations: + for var in decl.get_variables(): + declaration_vars.append(var.get_name()) + + global_info["InFunctionDeclarationsVars"] = declaration_visitor.declarations + return global_info + + @classmethod + def substituteNoneWithEmptyBlocks(cls, global_info): + if (not "UpdateBlock" in global_info) or (global_info["UpdateBlock"] is None): + empty = ModelParser.parse_block("") + global_info["UpdateBlock"] = empty.clone() + if (not "SelfSpikesFunction" in global_info) or (global_info["SelfSpikesFunction"] is None): + empty = ModelParser.parse_block("") + global_info["SelfSpikesFunction"] = empty.clone() + + return global_info + + @classmethod + def compute_update_block_variations(cls, info_collector, mechs_info, global_info): + if global_info["UpdateBlock"] is not None: + info_collector.collect_update_block_dependencies_and_owned(mechs_info, global_info) + global_info["UpdateBlock"] = info_collector.recursive_update_block_reduction(mechs_info, [], + global_info["UpdateBlock"]) + + +class ASTEnricherInfoCollectorVisitor(ASTVisitor): + + def __init__(self): + super(ASTEnricherInfoCollectorVisitor, self).__init__() + self.inside_variable = False + self.inside_block_with_variables = False + self.all_states = list() + self.all_parameters = list() + self.inside_states_block = False + self.inside_parameters_block = False + self.all_variables = list() + self.inside_internals_block = False + self.inside_declaration = False + self.internal_declarations = list() + + def visit_block_with_variables(self, node): + self.inside_block_with_variables = True + if node.is_state: + self.inside_states_block = True + if node.is_parameters: + self.inside_parameters_block = True + if node.is_internals: + self.inside_internals_block = True + + def endvisit_block_with_variables(self, node): + self.inside_states_block = False + self.inside_parameters_block = False + self.inside_block_with_variables = False + self.inside_internals_block = False + + def visit_variable(self, node): + self.inside_variable = True + self.all_variables.append(node.clone()) + if self.inside_states_block: + self.all_states.append(node.clone()) + if self.inside_parameters_block: + self.all_parameters.append(node.clone()) + + def endvisit_variable(self, node): + self.inside_variable = False + + def visit_declaration(self, node): + self.inside_declaration = True + if self.inside_internals_block: + self.internal_declarations.append(node) + + def endvisit_declaration(self, node): + self.inside_declaration = False + + +class ASTDeclarationCollectorAndUniqueRenamerVisitor(ASTVisitor): + def __init__(self): + super(ASTDeclarationCollectorAndUniqueRenamerVisitor, self).__init__() + self.declarations = list() + self.variable_names = dict() + self.inside_declaration = False + self.inside_block = False + self.current_block = None + + def 
visit_block(self, node): + self.inside_block = True + self.current_block = node + + def endvisit_block(self, node): + self.inside_block = False + self.current_block = None + + def visit_declaration(self, node): + self.inside_declaration = True + for variable in node.get_variables(): + if variable.get_name() in self.variable_names: + self.variable_names[variable.get_name()] += 1 + else: + self.variable_names[variable.get_name()] = 0 + new_name = variable.get_name() + '_' + str(self.variable_names[variable.get_name()]) + name_replacer = ASTVariableNameReplacerVisitor(variable.get_name(), new_name) + self.current_block.accept(name_replacer) + node.accept(ASTSymbolTableVisitor()) + self.declarations.append(node.clone()) + + def endvisit_declaration(self, node): + self.inside_declaration = False + + +class ASTVariableNameReplacerVisitor(ASTVisitor): + def __init__(self, old_name, new_name): + super(ASTVariableNameReplacerVisitor, self).__init__() + self.inside_variable = False + self.new_name = new_name + self.old_name = old_name + + def visit_variable(self, node): + self.inside_variable = True + if node.get_name() == self.old_name: + node.set_name(self.new_name) + + def endvisit_variable(self, node): + self.inside_variable = False diff --git a/pynestml/utils/global_processing.py b/pynestml/utils/global_processing.py new file mode 100644 index 000000000..7549f91ad --- /dev/null +++ b/pynestml/utils/global_processing.py @@ -0,0 +1,192 @@ +# -*- coding: utf-8 -*- +# +# global_processing.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +from collections import defaultdict + +import copy + +from pynestml.codegeneration.printers.nestml_printer import NESTMLPrinter +from pynestml.codegeneration.printers.constant_printer import ConstantPrinter +from pynestml.codegeneration.printers.ode_toolbox_expression_printer import ODEToolboxExpressionPrinter +from pynestml.codegeneration.printers.ode_toolbox_function_call_printer import ODEToolboxFunctionCallPrinter +from pynestml.codegeneration.printers.ode_toolbox_variable_printer import ODEToolboxVariablePrinter + +from pynestml.codegeneration.printers.sympy_simple_expression_printer import SympySimpleExpressionPrinter +from pynestml.meta_model.ast_expression import ASTExpression +from pynestml.meta_model.ast_model import ASTModel +from pynestml.meta_model.ast_simple_expression import ASTSimpleExpression +from pynestml.utils.ast_global_information_collector import ASTGlobalInformationCollector +from pynestml.utils.ast_utils import ASTUtils + +from odetoolbox import analysis + + +class GlobalProcessing: + """ + This file is part of the compartmental code generation process. + + Processing of code parts related to the update and OnReceive(self_spikes) blocks. 
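+
+    Typical call sequence (a sketch): check_co_co(neuron) runs once per model during the coco
+    checks and caches the collected dictionary under the neuron's name; get_global_info(neuron)
+    later hands that dictionary to the code generation machinery.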
+    """
+
+    # used to keep track of whether check_co_co was already called
+    # see inside check_co_co
+    first_time_run = defaultdict(lambda: True)
+    # stores the collected global_info from the first call of check_co_co
+    global_info = defaultdict()
+
+    # ODE-toolbox printers
+    _constant_printer = ConstantPrinter()
+    _ode_toolbox_variable_printer = ODEToolboxVariablePrinter(None)
+    _ode_toolbox_function_call_printer = ODEToolboxFunctionCallPrinter(None)
+    _ode_toolbox_printer = ODEToolboxExpressionPrinter(
+        simple_expression_printer=SympySimpleExpressionPrinter(
+            variable_printer=_ode_toolbox_variable_printer,
+            constant_printer=_constant_printer,
+            function_call_printer=_ode_toolbox_function_call_printer))
+
+    _ode_toolbox_variable_printer._expression_printer = _ode_toolbox_printer
+    _ode_toolbox_function_call_printer._expression_printer = _ode_toolbox_printer
+
+    @classmethod
+    def prepare_equations_for_ode_toolbox(cls, synapse, syn_info):
+        """Transforms the collected ODE equations to the required input format of ode-toolbox and adds them to the
+        syn_info dictionary."""
+
+        mechanism_odes = defaultdict()
+        for ode in syn_info["ODEs"]:
+            nestml_printer = NESTMLPrinter()
+            ode_nestml_expression = nestml_printer.print_ode_equation(ode)
+            mechanism_odes[ode.lhs.name] = defaultdict()
+            mechanism_odes[ode.lhs.name]["ASTOdeEquation"] = ode
+            mechanism_odes[ode.lhs.name]["ODENestmlExpression"] = ode_nestml_expression
+        syn_info["ODEs"] = mechanism_odes
+
+        for ode_variable_name, ode_info in syn_info["ODEs"].items():
+            # Expression:
+            odetoolbox_indict = {"dynamics": []}
+            lhs = ASTUtils.to_ode_toolbox_name(ode_info["ASTOdeEquation"].get_lhs().get_complete_name())
+            rhs = cls._ode_toolbox_printer.print(ode_info["ASTOdeEquation"].get_rhs())
+            entry = {"expression": lhs + " = " + rhs, "initial_values": {}}
+
+            # Initial values:
+            symbol_order = ode_info["ASTOdeEquation"].get_lhs().get_differential_order()
+            for order in range(symbol_order):
+                iv_symbol_name = ode_info["ASTOdeEquation"].get_lhs().get_name() + "'" * order
+                initial_value_expr = synapse.get_initial_value(iv_symbol_name)
+                entry["initial_values"][
+                    ASTUtils.to_ode_toolbox_name(iv_symbol_name)] = cls._ode_toolbox_printer.print(
+                    initial_value_expr)
+
+            odetoolbox_indict["dynamics"].append(entry)
+            syn_info["ODEs"][ode_variable_name]["ode_toolbox_input"] = odetoolbox_indict
+
+        return syn_info
+
+    @classmethod
+    def collect_raw_odetoolbox_output(cls, syn_info):
+        """Calls ode-toolbox for each ODE individually and collects the raw output."""
+        for ode_variable_name, ode_info in syn_info["ODEs"].items():
+            solver_result = analysis(ode_info["ode_toolbox_input"], disable_stiffness_check=True)
+            syn_info["ODEs"][ode_variable_name]["ode_toolbox_output"] = solver_result
+
+        return syn_info
+
+    @classmethod
+    def ode_toolbox_processing(cls, neuron, global_info):
+        global_info = cls.prepare_equations_for_ode_toolbox(neuron, global_info)
+        global_info = cls.collect_raw_odetoolbox_output(global_info)
+        return global_info
+
+    @classmethod
+    def get_global_info(cls, neuron):
+        """
+        Returns the previously generated global_info.
+        Note that a direct reference is returned rather than a copy,
+        because external manipulation of the dictionary is intended.
+        :param neuron: a single neuron instance.
+        """
+        return cls.global_info[
+            neuron.get_name()]  # return direct reference with no copy due to intended external manipulation
+
+    @classmethod
+    def check_co_co(cls, neuron: ASTModel):
+        """
+        Checks if mechanism conditions apply for the handed over neuron.
+        :param neuron: a single neuron instance.
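+
+        Only the first call per neuron name performs the actual collection; subsequent
+        calls are no-ops, since by then the AST has been transformed and no kernels or
+        inlines remain.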
+ """ + + # make sure we only run this a single time + # subsequent calls will be after AST has been transformed + # and there would be no kernels or inlines anymore + if cls.first_time_run[neuron.get_name()]: + # collect root expressions and initialize collector + info_collector = ASTGlobalInformationCollector(neuron) + + # collect and process all basic mechanism information + global_info = defaultdict() + + global_info = info_collector.collect_update_block(neuron, global_info) + global_info = info_collector.collect_self_spike_function(neuron, global_info) + + global_info = info_collector.collect_related_definitions(neuron, global_info) + global_info = info_collector.extend_variables_with_initialisations(neuron, global_info) + global_info = cls.ode_toolbox_processing(neuron, global_info) + + cls.global_info[neuron.get_name()] = copy.deepcopy(global_info) + cls.first_time_run[neuron.get_name()] = False + + @classmethod + def print_element(cls, name, element, rec_step): + message = "" + for indent in range(rec_step): + message += "----" + message += name + ": " + if isinstance(element, defaultdict): + message += "\n" + message += cls.print_dictionary(element, rec_step + 1) + else: + if hasattr(element, 'name'): + message += element.name + elif isinstance(element, str): + message += element + elif isinstance(element, dict): + message += "\n" + message += cls.print_dictionary(element, rec_step + 1) + elif isinstance(element, list): + for index in range(len(element)): + message += "\n" + message += cls.print_element(str(index), element[index], rec_step + 1) + elif isinstance(element, ASTExpression) or isinstance(element, ASTSimpleExpression): + message += cls._ode_toolbox_printer.print(element) + + message += "(" + type(element).__name__ + ")" + return message + + @classmethod + def print_dictionary(cls, dictionary, rec_step): + """ + Print the mechanisms info dictionaries. + """ + message = "" + for name, element in dictionary.items(): + message += cls.print_element(name, element, rec_step) + message += "\n" + return message diff --git a/pynestml/utils/mechanism_processing.py b/pynestml/utils/mechanism_processing.py index d45257adf..6e19728e6 100644 --- a/pynestml/utils/mechanism_processing.py +++ b/pynestml/utils/mechanism_processing.py @@ -23,9 +23,17 @@ import copy +from pynestml.utils.logger import Logger, LoggingLevel + +from pynestml.utils.messages import Messages + +from pynestml.frontend.frontend_configuration import FrontendConfiguration + +from pynestml.meta_model.ast_block_with_variables import ASTBlockWithVariables + +from pynestml.meta_model.ast_inline_expression import ASTInlineExpression from pynestml.codegeneration.printers.sympy_simple_expression_printer import SympySimpleExpressionPrinter -from pynestml.codegeneration.printers.cpp_simple_expression_printer import CppSimpleExpressionPrinter from pynestml.codegeneration.printers.nestml_printer import NESTMLPrinter from pynestml.codegeneration.printers.constant_printer import ConstantPrinter from pynestml.codegeneration.printers.ode_toolbox_expression_printer import ODEToolboxExpressionPrinter @@ -41,9 +49,13 @@ class MechanismProcessing: - """Manages the collection of basic information necesary for all types of mechanisms and uses the + """ + This file is part of the compartmental code generation process. 
+
+    Manages the collection of basic information necessary for all types of mechanisms and uses the
     collect_information_for_specific_mech_types interface that needs to be implemented by the specific mechanism type
-    processing classes"""
+    processing classes
+    """

     # used to keep track of whenever check_co_co was already called
     # see inside check_co_co
@@ -138,6 +150,130 @@ def determine_dependencies(cls, mechs_info):
             mechs_info[mechanism_name]["dependencies"] = dependencies
         return mechs_info

+    @classmethod
+    def convolution_ode_toolbox_processing(cls, neuron, mechs_info):
+        for mechanism_name, mechanism_info in mechs_info.items():
+            parameters_block = None
+            if neuron.get_parameters_blocks():
+                parameters_block = neuron.get_parameters_blocks()[0]
+            for convolution_name, convolution_info in mechanism_info["convolutions"].items():
+                kernel_buffer = (convolution_info["kernel"]["ASTKernel"], convolution_info["spikes"]["ASTInputPort"])
+                convolution_solution = cls.ode_solve_convolution(neuron, parameters_block, kernel_buffer)
+                mechanism_info["convolutions"][convolution_name]["analytic_solution"] = convolution_solution
+        return mechs_info
+
+    @classmethod
+    def ode_solve_convolution(cls,
+                              neuron: ASTModel,
+                              parameters_block: ASTBlockWithVariables,
+                              kernel_buffer):
+        odetoolbox_indict = cls.create_ode_indict(
+            neuron, parameters_block, kernel_buffer)
+        full_solver_result = analysis(
+            odetoolbox_indict,
+            disable_stiffness_check=True,
+            log_level=FrontendConfiguration.logging_level)
+        analytic_solver = None
+        analytic_solvers = [
+            x for x in full_solver_result if x["solver"] == "analytical"]
+        assert len(
+            analytic_solvers) <= 1, "More than one analytic solver not presently supported"
+        if len(analytic_solvers) > 0:
+            analytic_solver = analytic_solvers[0]
+
+        return analytic_solver
+
+    @classmethod
+    def create_ode_indict(cls,
+                          neuron: ASTModel,
+                          parameters_block: ASTBlockWithVariables,
+                          kernel_buffer):
+        kernel_buffers = {tuple(kernel_buffer)}
+        odetoolbox_indict = cls.transform_ode_and_kernels_to_json(
+            neuron, parameters_block, kernel_buffers)
+        odetoolbox_indict["options"] = {}
+        odetoolbox_indict["options"]["output_timestep_symbol"] = "__h"
+        return odetoolbox_indict
+
+    @classmethod
+    def transform_ode_and_kernels_to_json(
+            cls,
+            neuron: ASTModel,
+            parameters_block,
+            kernel_buffers):
+        """
+        Converts AST node to a JSON representation suitable for passing to ode-toolbox.
+
+        Each kernel has to be generated for each spike buffer convolve in which it occurs, e.g. if the NESTML model code contains the statements
+
+            convolve(G, ex_spikes)
+            convolve(G, in_spikes)
+
+        then `kernel_buffers` will contain the pairs `(G, ex_spikes)` and `(G, in_spikes)`, from which two ODEs will be generated, with dynamical state (variable) names `G__X__ex_spikes` and `G__X__in_spikes`.
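+
+        For illustration, and based on the construction below (the kernel, buffer
+        and parameter names here are only examples), the returned dictionary takes
+        roughly the form
+
+            {
+                "dynamics": [
+                    {"expression": "G__X__ex_spikes' = -G__X__ex_spikes / tau_syn",
+                     "initial_values": {"G__X__ex_spikes": "e / tau_syn"}}
+                ],
+                "parameters": {"tau_syn": "0.2"}
+            }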
+ + :param parameters_block: ASTBlockWithVariables + :return: Dict + """ + odetoolbox_indict = {"dynamics": []} + + equations_block = neuron.get_equations_blocks()[0] + + for kernel, spike_input_port in kernel_buffers: + if ASTUtils.is_delta_kernel(kernel): + continue + # delta function -- skip passing this to ode-toolbox + + for kernel_var in kernel.get_variables(): + expr = ASTUtils.get_expr_from_kernel_var( + kernel, kernel_var.get_complete_name()) + kernel_order = kernel_var.get_differential_order() + kernel_X_spike_buf_name_ticks = ASTUtils.construct_kernel_X_spike_buf_name( + kernel_var.get_name(), spike_input_port.get_name(), kernel_order, diff_order_symbol="'") + + ASTUtils.replace_rhs_variables(expr, kernel_buffers) + + entry = {"expression": kernel_X_spike_buf_name_ticks + " = " + str(expr), "initial_values": {}} + + # initial values need to be declared for order 1 up to kernel + # order (e.g. none for kernel function f(t) = ...; 1 for kernel + # ODE f'(t) = ...; 2 for f''(t) = ... and so on) + for order in range(kernel_order): + iv_sym_name_ode_toolbox = ASTUtils.construct_kernel_X_spike_buf_name( + kernel_var.get_name(), spike_input_port, order, diff_order_symbol="'") + symbol_name_ = kernel_var.get_name() + "'" * order + symbol = equations_block.get_scope().resolve_to_symbol( + symbol_name_, SymbolKind.VARIABLE) + assert symbol is not None, "Could not find initial value for variable " + symbol_name_ + initial_value_expr = symbol.get_declaring_expression() + assert initial_value_expr is not None, "No initial value found for variable name " + symbol_name_ + entry["initial_values"][iv_sym_name_ode_toolbox] = cls._ode_toolbox_printer.print( + initial_value_expr) + + odetoolbox_indict["dynamics"].append(entry) + + odetoolbox_indict["parameters"] = {} + if parameters_block is not None: + for decl in parameters_block.get_declarations(): + for var in decl.variables: + odetoolbox_indict["parameters"][var.get_complete_name( + )] = cls._ode_toolbox_printer.print(decl.get_expression()) + + return odetoolbox_indict + + @classmethod + def extract_mech_blocks(cls, info_collector, mechs_info, global_info): + block_list = list() + if "UpdateBlock" in global_info and global_info["UpdateBlock"] is not None: + block_list.append(global_info["UpdateBlock"].get_stmts_body()) + if "SelfSpikesFunction" in global_info and global_info["SelfSpikesFunction"] is not None: + block_list.append(global_info["SelfSpikesFunction"].get_stmts_body()) + if len(block_list) > 0: + info_collector.collect_block_dependencies_and_owned(mechs_info, block_list, "UpdateBlock") + if "UpdateBlock" in global_info and global_info["UpdateBlock"] is not None: + info_collector.block_reduction(mechs_info, global_info["UpdateBlock"], "UpdateBlock") + if "SelfSpikesFunction" in global_info and global_info["SelfSpikesFunction"] is not None: + info_collector.block_reduction(mechs_info, global_info["SelfSpikesFunction"], "SelfSpikesFunction") + @classmethod def get_mechs_info(cls, neuron: ASTModel): """ @@ -150,7 +286,7 @@ def get_mechs_info(cls, neuron: ASTModel): return copy.deepcopy(cls.mechs_info[neuron][cls.mechType]) @classmethod - def check_co_co(cls, neuron: ASTModel): + def check_co_co(cls, neuron: ASTModel, global_info): """ Checks if mechanism conditions apply for the handed over neuron. :param neuron: a single neuron instance. 
@@ -165,9 +301,12 @@ def check_co_co(cls, neuron: ASTModel): mechs_info = info_collector.detect_mechs(cls.mechType) # collect and process all basic mechanism information - mechs_info = info_collector.collect_mechanism_related_definitions(neuron, mechs_info) + mechs_info = info_collector.collect_mechanism_related_definitions(neuron, mechs_info, global_info, cls.mechType) + cls.extract_mech_blocks(info_collector, mechs_info, global_info) mechs_info = info_collector.extend_variables_with_initialisations(neuron, mechs_info) - mechs_info = cls.ode_toolbox_processing(neuron, mechs_info) + + mechs_info = info_collector.collect_kernels(neuron, mechs_info) + mechs_info = cls.convolution_ode_toolbox_processing(neuron, mechs_info) # collect and process all mechanism type specific information mechs_info = cls.collect_information_for_specific_mech_types(neuron, mechs_info) @@ -175,6 +314,30 @@ def check_co_co(cls, neuron: ASTModel): cls.mechs_info[neuron][cls.mechType] = mechs_info cls.first_time_run[neuron][cls.mechType] = False + @classmethod + def get_transformed_ode_equations(cls, mechs_info: dict): + from pynestml.utils.mechs_info_enricher import SynsInfoEnricherVisitor + enriched_mechs_info = copy.copy(mechs_info) + for mechanism_name, mechanism_info in mechs_info.items(): + transformed_odes = list() + for ode in mechs_info[mechanism_name]["ODEs"]: + ode_name = ode.lhs.name + transformed_odes.append( + SynsInfoEnricherVisitor.ode_name_to_transformed_ode[ode_name]) + enriched_mechs_info[mechanism_name]["ODEs"] = transformed_odes + + return enriched_mechs_info + + @classmethod + def check_all_convolutions_with_self_spikes(cls, mechs_info): + for mechanism_name, mechanism_info in mechs_info.items(): + for convolution_name, convolution in mechanism_info["convolutions"].items(): + if convolution["spikes"]["name"] != "self_spikes": + code, message = Messages.cm_non_self_spike_convolution_in_mech(mechanism_name, cls.mechType) + Logger.log_message(error_position=None, + code=code, message=message, + log_level=LoggingLevel.ERROR) + @classmethod def print_element(cls, name, element, rec_step): message = "" @@ -189,6 +352,8 @@ def print_element(cls, name, element, rec_step): message += element.name elif isinstance(element, str): message += element + elif isinstance(element, bool): + message += str(element) elif isinstance(element, dict): message += "\n" message += cls.print_dictionary(element, rec_step + 1) @@ -198,6 +363,8 @@ def print_element(cls, name, element, rec_step): message += cls.print_element(str(index), element[index], rec_step + 1) elif isinstance(element, ASTExpression) or isinstance(element, ASTSimpleExpression): message += cls._ode_toolbox_printer.print(element) + elif isinstance(element, ASTInlineExpression): + message += cls._ode_toolbox_printer.print(element.get_expression()) message += "(" + type(element).__name__ + ")" return message diff --git a/pynestml/utils/mechs_info_enricher.py b/pynestml/utils/mechs_info_enricher.py index ea645a02c..3aab1e879 100644 --- a/pynestml/utils/mechs_info_enricher.py +++ b/pynestml/utils/mechs_info_enricher.py @@ -18,9 +18,30 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . 
- +import copy from collections import defaultdict +from odetoolbox import analysis +from pynestml.cocos.co_cos_manager import CoCosManager + +from pynestml.symbol_table.symbol_table import SymbolTable + +from pynestml.codegeneration.printers.sympy_simple_expression_printer import SympySimpleExpressionPrinter +from pynestml.meta_model.ast_simple_expression import ASTSimpleExpression + +from pynestml.meta_model.ast_small_stmt import ASTSmallStmt + +from pynestml.codegeneration.printers.ode_toolbox_expression_printer import ODEToolboxExpressionPrinter + +from pynestml.codegeneration.printers.ode_toolbox_function_call_printer import ODEToolboxFunctionCallPrinter + +from pynestml.codegeneration.printers.ode_toolbox_variable_printer import ODEToolboxVariablePrinter + +from pynestml.codegeneration.printers.constant_printer import ConstantPrinter + +from pynestml.codegeneration.printers.nestml_printer import NESTMLPrinter + +from pynestml.meta_model.ast_inline_expression import ASTInlineExpression from pynestml.meta_model.ast_model import ASTModel from pynestml.symbols.predefined_functions import PredefinedFunctions from pynestml.symbols.symbol import SymbolKind @@ -34,19 +55,126 @@ class MechsInfoEnricher: """ + This file is part of the compartmental code generation process. + Adds information collection that can't be done in the processing class since that is used in the cocos. Here we use the ModelParser which would lead to a cyclic dependency. """ + # ODE-toolbox printers + _constant_printer = ConstantPrinter() + _ode_toolbox_variable_printer = ODEToolboxVariablePrinter(None) + _ode_toolbox_function_call_printer = ODEToolboxFunctionCallPrinter(None) + _ode_toolbox_printer = ODEToolboxExpressionPrinter( + simple_expression_printer=SympySimpleExpressionPrinter( + variable_printer=_ode_toolbox_variable_printer, + constant_printer=_constant_printer, + function_call_printer=_ode_toolbox_function_call_printer)) + + _ode_toolbox_variable_printer._expression_printer = _ode_toolbox_printer + _ode_toolbox_function_call_printer._expression_printer = _ode_toolbox_printer + def __init__(self): pass @classmethod def enrich_with_additional_info(cls, neuron: ASTModel, mechs_info: dict): + neuron.accept(SynsInfoEnricherVisitor()) + mechs_info = cls.get_transformed_ode_equations(mechs_info) + mechs_info = cls.ode_toolbox_processing(neuron, mechs_info) + + cls.add_propagators_to_internals(neuron, mechs_info) + neuron.accept(SynsInfoEnricherVisitor()) + mechs_info = cls.transform_ode_solutions(neuron, mechs_info) + mechs_info = cls.transform_convolutions_analytic_solutions_generall(neuron, mechs_info) mechs_info = cls.enrich_mechanism_specific(neuron, mechs_info) return mechs_info + @classmethod + def get_transformed_ode_equations(cls, mechs_info: dict): + enriched_mechs_info = copy.copy(mechs_info) + for mechanism_name, mechanism_info in mechs_info.items(): + transformed_odes = list() + for ode in mechs_info[mechanism_name]["ODEs"]: + ode_name = ode.lhs.name + transformed_odes.append( + SynsInfoEnricherVisitor.ode_name_to_transformed_ode[ode_name]) + enriched_mechs_info[mechanism_name]["ODEs"] = transformed_odes + + return enriched_mechs_info + + @classmethod + def ode_toolbox_processing(cls, neuron, mechs_info): + mechs_info = cls.prepare_equations_for_ode_toolbox(neuron, mechs_info) + mechs_info = cls.collect_raw_odetoolbox_output(mechs_info) + return mechs_info + + @classmethod + def prepare_equations_for_ode_toolbox(cls, neuron, mechs_info): + """Transforms the collected ode equations to the required input 
format of ode-toolbox and adds it to the + mechs_info dictionary""" + for mechanism_name, mechanism_info in mechs_info.items(): + mechanism_odes = defaultdict() + for ode in mechanism_info["ODEs"]: + nestml_printer = NESTMLPrinter() + ode_nestml_expression = nestml_printer.print_ode_equation(ode) + mechanism_odes[ode.lhs.name] = defaultdict() + mechanism_odes[ode.lhs.name]["ASTOdeEquation"] = ode + mechanism_odes[ode.lhs.name]["ODENestmlExpression"] = ode_nestml_expression + mechs_info[mechanism_name]["ODEs"] = mechanism_odes + + for mechanism_name, mechanism_info in mechs_info.items(): + for ode_variable_name, ode_info in mechanism_info["ODEs"].items(): + # Expression: + odetoolbox_indict = {"dynamics": []} + lhs = ASTUtils.to_ode_toolbox_name(ode_info["ASTOdeEquation"].get_lhs().get_complete_name()) + rhs = cls._ode_toolbox_printer.print(ode_info["ASTOdeEquation"].get_rhs()) + entry = {"expression": lhs + " = " + rhs, "initial_values": {}} + + # Initial values: + symbol_order = ode_info["ASTOdeEquation"].get_lhs().get_differential_order() + for order in range(symbol_order): + iv_symbol_name = ode_info["ASTOdeEquation"].get_lhs().get_name() + "'" * order + initial_value_expr = neuron.get_initial_value(iv_symbol_name) + entry["initial_values"][ + ASTUtils.to_ode_toolbox_name(iv_symbol_name)] = cls._ode_toolbox_printer.print( + initial_value_expr) + + odetoolbox_indict["dynamics"].append(entry) + mechs_info[mechanism_name]["ODEs"][ode_variable_name]["ode_toolbox_input"] = odetoolbox_indict + + return mechs_info + + @classmethod + def collect_raw_odetoolbox_output(cls, mechs_info): + """calls ode-toolbox for each ode individually and collects the raw output""" + for mechanism_name, mechanism_info in mechs_info.items(): + for ode_variable_name, ode_info in mechanism_info["ODEs"].items(): + solver_result = analysis(ode_info["ode_toolbox_input"], disable_stiffness_check=True) + mechs_info[mechanism_name]["ODEs"][ode_variable_name]["ode_toolbox_output"] = solver_result + + return mechs_info + + @classmethod + def add_propagators_to_internals(cls, neuron, mechs_info): + for mechanism_name, mechanism_info in mechs_info.items(): + for ode_var_name, ode_info in mechanism_info["ODEs"].items(): + for ode_solution_index in range(len(ode_info["ode_toolbox_output"])): + for variable_name, rhs_str in ode_info["ode_toolbox_output"][ode_solution_index]["propagators"].items(): + ASTUtils.add_declaration_to_internals(neuron, variable_name, rhs_str) + + if "convolutions" in mechanism_info: + for convolution_name, convolution_info in mechanism_info["convolutions"].items(): + for variable_name, rhs_str in convolution_info["analytic_solution"]["propagators"].items(): + ASTUtils.add_declaration_to_internals(neuron, variable_name, rhs_str) + + SymbolTable.delete_model_scope(neuron.get_name()) + symbol_table_visitor = ASTSymbolTableVisitor() + neuron.accept(symbol_table_visitor) + CoCosManager.check_cocos(neuron, after_ast_rewrite=True) + SymbolTable.add_model_scope(neuron.get_name(), neuron.get_scope()) + @classmethod def transform_ode_solutions(cls, neuron, mechs_info): for mechanism_name, mechanism_info in mechs_info.items(): @@ -58,9 +186,37 @@ def transform_ode_solutions(cls, neuron, mechs_info): solution_transformed["states"] = defaultdict() solution_transformed["propagators"] = defaultdict() + for variable_name, rhs_str in ode_info["ode_toolbox_output"][ode_solution_index]["initial_values"].items(): + variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol(variable_name, + 
SymbolKind.VARIABLE) + + expression = ModelParser.parse_expression(rhs_str) + # pretend that update expressions are in "equations" block, + # which should always be present, as synapses have been + # defined to get here + expression.update_scope(neuron.get_equations_blocks()[0].get_scope()) + expression.accept(ASTSymbolTableVisitor()) + + update_expr_str = ode_info["ode_toolbox_output"][ode_solution_index]["update_expressions"][ + variable_name] + update_expr_ast = ModelParser.parse_expression( + update_expr_str) + # pretend that update expressions are in "equations" block, + # which should always be present, as differential equations + # must have been defined to get here + update_expr_ast.update_scope( + neuron.get_equations_blocks()[0].get_scope()) + update_expr_ast.accept(ASTSymbolTableVisitor()) + + solution_transformed["states"][variable_name] = { + "ASTVariable": variable, + "init_expression": expression, + "update_expression": update_expr_ast, + } for variable_name, rhs_str in ode_info["ode_toolbox_output"][ode_solution_index]["propagators"].items(): prop_variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol(variable_name, SymbolKind.VARIABLE) + if prop_variable is None: ASTUtils.add_declarations_to_internals( neuron, ode_info["ode_toolbox_output"][ode_solution_index]["propagators"]) @@ -92,41 +248,205 @@ def transform_ode_solutions(cls, neuron, mechs_info): PredefinedFunctions.TIME_RESOLUTION: mechanism_info["time_resolution_var"] = variable - for variable_name, rhs_str in ode_info["ode_toolbox_output"][ode_solution_index]["initial_values"].items(): - variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol(variable_name, - SymbolKind.VARIABLE) + mechanism_info["ODEs"][ode_var_name]["transformed_solutions"].append(solution_transformed) - expression = ModelParser.parse_expression(rhs_str) - # pretend that update expressions are in "equations" block, - # which should always be present, as synapses have been - # defined to get here - expression.update_scope(neuron.get_equations_blocks()[0].get_scope()) - expression.accept(ASTSymbolTableVisitor()) + neuron.accept(ASTParentVisitor()) - update_expr_str = ode_info["ode_toolbox_output"][ode_solution_index]["update_expressions"][ - variable_name] - update_expr_ast = ModelParser.parse_expression( - update_expr_str) - # pretend that update expressions are in "equations" block, - # which should always be present, as differential equations - # must have been defined to get here - update_expr_ast.update_scope( - neuron.get_scope()) - update_expr_ast.accept(ASTParentVisitor()) - update_expr_ast.accept(ASTSymbolTableVisitor()) - neuron.accept(ASTSymbolTableVisitor()) + return mechs_info - solution_transformed["states"][variable_name] = { - "ASTVariable": variable, - "init_expression": expression, - "update_expression": update_expr_ast, - } + @classmethod + def transform_convolutions_analytic_solutions_generall(cls, neuron: ASTModel, cm_mechs_info: dict): + enriched_syns_info = copy.copy(cm_mechs_info) + for mechanism_name, mechanism_info in cm_mechs_info.items(): + for convolution_name in mechanism_info["convolutions"].keys(): + analytic_solution = enriched_syns_info[mechanism_name][ + "convolutions"][convolution_name]["analytic_solution"] + analytic_solution_transformed = defaultdict( + lambda: defaultdict()) - mechanism_info["ODEs"][ode_var_name]["transformed_solutions"].append(solution_transformed) + for variable_name, expression_str in analytic_solution["initial_values"].items(): + variable = 
neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol(variable_name, + SymbolKind.VARIABLE) - neuron.accept(ASTParentVisitor()) + expression = ModelParser.parse_expression(expression_str) + # pretend that update expressions are in "equations" block, + # which should always be present, as synapses have been + # defined to get here + expression.update_scope(neuron.get_equations_blocks()[0].get_scope()) + expression.accept(ASTSymbolTableVisitor()) - return mechs_info + update_expr_str = analytic_solution["update_expressions"][variable_name] + update_expr_ast = ModelParser.parse_expression( + update_expr_str) + # pretend that update expressions are in "equations" block, + # which should always be present, as differential equations + # must have been defined to get here + update_expr_ast.update_scope( + neuron.get_equations_blocks()[0].get_scope()) + update_expr_ast.accept(ASTSymbolTableVisitor()) + + analytic_solution_transformed['kernel_states'][variable_name] = { + "ASTVariable": variable, + "init_expression": expression, + "update_expression": update_expr_ast, + } + + mechanism_info = cls.get_time_res_var_conv_declaration(neuron, mechanism_info, expression) + + for variable_name, expression_string in analytic_solution["propagators"].items( + ): + variable = SynsInfoEnricherVisitor.internal_variable_name_to_variable[variable_name] + expression = ModelParser.parse_expression( + expression_string) + # pretend that update expressions are in "equations" block, + # which should always be present, as synapses have been + # defined to get here + expression.update_scope( + neuron.get_equations_blocks()[0].get_scope()) + expression.accept(ASTSymbolTableVisitor()) + analytic_solution_transformed['propagators'][variable_name] = { + "ASTVariable": variable, "init_expression": expression, } + + mechanism_info = cls.get_time_res_var_conv_declaration(neuron, mechanism_info, expression) + + enriched_syns_info[mechanism_name]["convolutions"][convolution_name]["analytic_solution"] = \ + analytic_solution_transformed + + if isinstance(enriched_syns_info[mechanism_name]["root_expression"], ASTInlineExpression): + inline_expression_name = enriched_syns_info[mechanism_name]["root_expression"].variable_name + enriched_syns_info[mechanism_name]["root_expression"] = \ + SynsInfoEnricherVisitor.inline_name_to_transformed_inline[inline_expression_name] + + transformed_inlines = list() + for inline in cm_mechs_info[mechanism_name]["SecondaryInlineExpressions"]: + inline_expression_name = inline.variable_name + transformed_inlines.append( + SynsInfoEnricherVisitor.inline_name_to_transformed_inline[inline_expression_name]) + enriched_syns_info[mechanism_name]["secondary_inline_expressions"] = transformed_inlines + + return enriched_syns_info + + @classmethod + def get_analytic_helper_variable_declarations(cls, single_synapse_info): + variable_names = cls.get_analytic_helper_variable_names( + single_synapse_info) + result = dict() + for variable_name in variable_names: + if variable_name not in SynsInfoEnricherVisitor.internal_variable_name_to_variable: + continue + variable = SynsInfoEnricherVisitor.internal_variable_name_to_variable[variable_name] + expression = SynsInfoEnricherVisitor.variables_to_internal_declarations[variable] + result[variable_name] = { + "ASTVariable": variable, + "init_expression": expression, + } + if expression.is_function_call() and expression.get_function_call( + ).callee_name == PredefinedFunctions.TIME_RESOLUTION: + result[variable_name]["is_time_resolution"] = True + else: + 
result[variable_name]["is_time_resolution"] = False + + return result + + @classmethod + def get_analytic_helper_variable_names(cls, single_synapse_info): + """get new variables that only occur on the right hand side of analytic solution Expressions + but for wich analytic solution does not offer any values + this can isolate out additional variables that suddenly appear such as __h + whose initial values are not inlcuded in the output of analytic solver""" + + analytic_lhs_vars = set() + + for convolution_name, convolution_info in single_synapse_info["convolutions"].items( + ): + analytic_sol = convolution_info["analytic_solution"] + + # get variables representing convolutions by kernel + for kernel_var_name, kernel_info in analytic_sol["kernel_states"].items( + ): + analytic_lhs_vars.add(kernel_var_name) + + # get propagator variable names + for propagator_var_name, propagator_info in analytic_sol["propagators"].items( + ): + analytic_lhs_vars.add(propagator_var_name) + + return cls.get_new_variables_after_transformation( + single_synapse_info).symmetric_difference(analytic_lhs_vars) + + @classmethod + def get_new_variables_after_transformation(cls, single_synapse_info): + return cls.get_all_synapse_variables(single_synapse_info).difference( + single_synapse_info["total_used_declared"]) + + @classmethod + def get_all_synapse_variables(cls, single_synapse_info): + """returns all variable names referenced by the synapse inline + and by the analytical solution + assumes that the model has already been transformed""" + + # get all variables from transformed inline + inline_variables = cls.get_variable_names_used( + single_synapse_info["root_expression"]) + + analytic_solution_vars = set() + # get all variables from transformed analytic solution + for convolution_name, convolution_info in single_synapse_info["convolutions"].items( + ): + analytic_sol = convolution_info["analytic_solution"] + # get variables from init and update expressions + # for each kernel + for kernel_var_name, kernel_info in analytic_sol["kernel_states"].items( + ): + analytic_solution_vars.add(kernel_var_name) + + update_vars = cls.get_variable_names_used( + kernel_info["update_expression"]) + init_vars = cls.get_variable_names_used( + kernel_info["init_expression"]) + + analytic_solution_vars.update(update_vars) + analytic_solution_vars.update(init_vars) + + # get variables from init expressions + # for each propagator + # include propagator variable itself + for propagator_var_name, propagator_info in analytic_sol["propagators"].items( + ): + analytic_solution_vars.add(propagator_var_name) + + init_vars = cls.get_variable_names_used( + propagator_info["init_expression"]) + + analytic_solution_vars.update(init_vars) + + return analytic_solution_vars.union(inline_variables) + + @classmethod + def get_variable_names_used(cls, node) -> set: + variable_names_extractor = ASTUsedVariableNamesExtractor(node) + return variable_names_extractor.variable_names + + @classmethod + def get_time_res_var_conv_declaration(cls, neuron, mechanism_info, expression): + expression_variable_collector = ASTEnricherInfoCollectorVisitor() + expression.accept(expression_variable_collector) + + # now also identify analytic helper variables such as __h + neuron_internal_declaration_collector = ASTEnricherInfoCollectorVisitor() + neuron.accept(neuron_internal_declaration_collector) + + for variable in expression_variable_collector.all_variables: + for internal_declaration in neuron_internal_declaration_collector.internal_declarations: + if 
variable.get_name() == internal_declaration.get_variables()[0].get_name() \ + and (isinstance(internal_declaration.get_expression(), ASTSmallStmt) + or isinstance(internal_declaration.get_expression(), ASTSimpleExpression)) \ + and internal_declaration.get_expression().is_function_call() \ + and internal_declaration.get_expression().get_function_call().callee_name == \ + PredefinedFunctions.TIME_RESOLUTION: + mechanism_info["time_resolution_var"] = variable + + return mechanism_info @classmethod def enrich_mechanism_specific(cls, neuron, mechs_info): @@ -181,3 +501,88 @@ def visit_declaration(self, node): def endvisit_declaration(self, node): self.inside_declaration = False + + +class SynsInfoEnricherVisitor(ASTVisitor): + variables_to_internal_declarations = {} + internal_variable_name_to_variable = {} + inline_name_to_transformed_inline = {} + ode_name_to_transformed_ode = {} + + # assuming depth first traversal + # collect declaratins in the order + # in which they were present in the neuron + declarations_ordered = [] + + def __init__(self): + super(SynsInfoEnricherVisitor, self).__init__() + + self.inside_parameter_block = False + self.inside_state_block = False + self.inside_internals_block = False + self.inside_inline_expression = False + self.inside_inline_expression = False + self.inside_declaration = False + self.inside_simple_expression = False + self.inside_ode_equation = False + + def visit_inline_expression(self, node): + self.inside_inline_expression = True + inline_name = node.variable_name + SynsInfoEnricherVisitor.inline_name_to_transformed_inline[inline_name] = node + + def endvisit_inline_expression(self, node): + self.inside_inline_expression = False + + def visit_ode_equation(self, node): + self.inside_ode_equation = True + ode_name = node.lhs.name + SynsInfoEnricherVisitor.ode_name_to_transformed_ode[ode_name] = node + + def endvisit_ode_equation(self, node): + self.inside_ode_equation = False + + def visit_block_with_variables(self, node): + if node.is_state: + self.inside_state_block = True + if node.is_parameters: + self.inside_parameter_block = True + if node.is_internals: + self.inside_internals_block = True + + def endvisit_block_with_variables(self, node): + if node.is_state: + self.inside_state_block = False + if node.is_parameters: + self.inside_parameter_block = False + if node.is_internals: + self.inside_internals_block = False + + def visit_simple_expression(self, node): + self.inside_simple_expression = True + + def endvisit_simple_expression(self, node): + self.inside_simple_expression = False + + def visit_declaration(self, node): + self.declarations_ordered.append(node) + self.inside_declaration = True + if self.inside_internals_block: + variable = node.get_variables()[0] + expression = node.get_expression() + SynsInfoEnricherVisitor.variables_to_internal_declarations[variable] = expression + SynsInfoEnricherVisitor.internal_variable_name_to_variable[variable.get_name( + )] = variable + + def endvisit_declaration(self, node): + self.inside_declaration = False + + +class ASTUsedVariableNamesExtractor(ASTVisitor): + def __init__(self, node): + super(ASTUsedVariableNamesExtractor, self).__init__() + self.variable_names = set() + node.accept(self) + + def visit_variable(self, node): + self.variable_names.add(node.get_name()) diff --git a/pynestml/utils/messages.py b/pynestml/utils/messages.py index 1930d91e0..855e2e8a4 100644 --- a/pynestml/utils/messages.py +++ b/pynestml/utils/messages.py @@ -144,6 +144,8 @@ class MessageCode(Enum): 
WEIGHT_VARIABLE_NOT_SPECIFIED = 119 DELAY_VARIABLE_NOT_FOUND = 120 WEIGHT_VARIABLE_NOT_FOUND = 121 + CM_VAR_MULTIUSE = 122 + CM_INVALID_CONVOLUTION_BUFFER = 123 class Messages: @@ -389,7 +391,7 @@ def get_code_generated(cls, model_name: str, path: str) -> Tuple[MessageCode, st assert (path is not None and isinstance(path, str)), \ '(PyNestML.Utils.Message) Not a string provided (%s)!' % type(path) message = 'Successfully generated code for the model: \'' + \ - model_name + '\' in: \'' + path + '\' !' + model_name + '\' in: \'' + path + '\' !' return MessageCode.CODE_SUCCESSFULLY_GENERATED, message @classmethod @@ -929,13 +931,13 @@ def get_ode_needs_consistent_units(cls, name, differential_order, lhs_type, rhs_ message = 'ODE definition for \'' if differential_order > 1: message += 'd^' + str(differential_order) + ' ' + \ - name + ' / dt^' + str(differential_order) + '\'' + name + ' / dt^' + str(differential_order) + '\'' if differential_order > 0: message += 'd ' + name + ' / dt\'' else: message += '\'' + str(name) + '\'' message += ' has inconsistent units: expected \'' + \ - lhs_type.print_symbol() + '\', got \'' + rhs_type.print_symbol() + '\'' + lhs_type.print_symbol() + '\', got \'' + rhs_type.print_symbol() + '\'' return MessageCode.ODE_NEEDS_CONSISTENT_UNITS, message @classmethod @@ -944,7 +946,7 @@ def get_ode_function_needs_consistent_units( assert (name is not None and isinstance(name, str)), \ '(PyNestML.Utils.Message) Not a string provided (%s)!' % type(name) message = 'ODE function definition for \'' + name + '\' has inconsistent units: expected \'' + \ - declared_type.print_symbol() + '\', got \'' + expression_type.print_symbol() + '\'' + declared_type.print_symbol() + '\', got \'' + expression_type.print_symbol() + '\'' return MessageCode.ODE_FUNCTION_NEEDS_CONSISTENT_UNITS, message @classmethod @@ -996,9 +998,9 @@ def templated_arg_types_inconsistent(cls, function_name, failing_arg_idx, other_ :rtype: (MessageCode,str) """ message = 'In function \'' + function_name + '\': actual derived type of templated parameter ' + \ - str(failing_arg_idx + 1) + ' is \'' + failing_arg_type_str + '\', which is inconsistent with that of parameter(s) ' + \ - ', '.join([str(_ + 1) for _ in other_args_idx]) + \ - ', which has/have type \'' + other_type_str + '\'' + str(failing_arg_idx + 1) + ' is \'' + failing_arg_type_str + '\', which is inconsistent with that of parameter(s) ' + \ + ', '.join([str(_ + 1) for _ in other_args_idx]) + \ + ', which has/have type \'' + other_type_str + '\'' return MessageCode.TEMPLATED_ARG_TYPES_INCONSISTENT, message @classmethod @@ -1062,8 +1064,7 @@ def get_output_port_type_differs(cls) -> Tuple[MessageCode, str]: def get_kernel_wrong_type(cls, kernel_name: str, differential_order: int, - actual_type: str) -> Tuple[MessageCode, - str]: + actual_type: str) -> Tuple[MessageCode, str]: """ Returns a message indicating that the type of a kernel is wrong. :param kernel_name: the name of the kernel @@ -1085,8 +1086,7 @@ def get_kernel_wrong_type(cls, def get_kernel_iv_wrong_type(cls, iv_name: str, actual_type: str, - expected_type: str) -> Tuple[MessageCode, - str]: + expected_type: str) -> Tuple[MessageCode, str]: """ Returns a message indicating that the type of a kernel initial value is wrong. 
         :param iv_name: the name of the state variable with an initial value
@@ -1116,25 +1116,25 @@ def get_equations_defined_but_integrate_odes_not_called(cls):
     def get_template_root_path_created(cls, templates_root_dir: str):
         message = "Given template root path is not an absolute path. " \
                   "Creating the absolute path with default templates directory '" + \
-            templates_root_dir + "'"
+                  templates_root_dir + "'"
         return MessageCode.TEMPLATE_ROOT_PATH_CREATED, message

     @classmethod
     def get_vector_parameter_wrong_block(cls, var, block):
         message = "The vector parameter '" + var + "' is declared in the wrong block '" + block + "'. " \
-            "The vector parameter can only be declared in parameters or internals block."
+                  "The vector parameter can only be declared in parameters or internals block."
         return MessageCode.VECTOR_PARAMETER_WRONG_BLOCK, message

     @classmethod
     def get_vector_parameter_wrong_type(cls, var):
         message = "The vector parameter '" + var + "' is of the wrong type. " \
-            "The vector parameter can be only of type integer."
+                  "The vector parameter can be only of type integer."
         return MessageCode.VECTOR_PARAMETER_WRONG_TYPE, message

     @classmethod
     def get_vector_parameter_wrong_size(cls, var, value):
         message = "The vector parameter '" + var + "' has value '" + value + "' " \
-            "which is less than or equal to 0."
+                  "which is less than or equal to 0."
         return MessageCode.VECTOR_PARAMETER_WRONG_SIZE, message

     @classmethod
@@ -1167,9 +1167,9 @@ def get_no_gating_variables(
         """
         message = "No gating variables found inside declaration of '" + \
-            cm_inline_expr.variable_name + "', "
+                  cm_inline_expr.variable_name + "', "
         message += "\nmeaning no variable ends with the suffix '_" + \
-            ion_channel_name + "' here. "
+                  ion_channel_name + "' here. "
         message += "This suffix indicates that a variable is a gating variable. "
         message += "At least one gating variable is expected to exist."
@@ -1182,7 +1182,7 @@ def get_cm_inline_expression_variable_used_mulitple_times(
             bad_variable_name: str, ion_channel_name: str):
         message = "Variable name '" + bad_variable_name + \
-            "' seems to be used multiple times"
+                  "' seems to be used multiple times"
         message += "' inside inline expression '" + cm_inline_expr.variable_name + "'. "
         message += "\nVariables are not allowed to occur multiple times here."
@@ -1196,23 +1196,23 @@ def get_expected_cm_function_missing(
             function_name: str):
         message = "Implementation of a function called '" + function_name + "' not found. "
         message += "It is expected because of variable '" + \
-            variable_name + "' in the ion channel '" + ion_channel_name + "'"
+                   variable_name + "' in the ion channel '" + ion_channel_name + "'"
         return MessageCode.CM_FUNCTION_MISSING, message

     @classmethod
     def get_expected_cm_function_wrong_args_count(
             cls, ion_channel_name: str, variable_name, astfun: ASTFunction):
         message = "Function '" + astfun.name + \
-            "' is expected to have exactly one Argument. "
+                  "' is expected to have exactly one Argument. "
         message += "It is related to variable '" + variable_name + \
-            "' in the ion channel '" + ion_channel_name + "'"
+                   "' in the ion channel '" + ion_channel_name + "'"
         return MessageCode.CM_FUNCTION_BAD_NUMBER_ARGS, message

     @classmethod
     def get_expected_cm_function_bad_return_type(
             cls, ion_channel_name: str, astfun: ASTFunction) -> Tuple[MessageCode, str]:
         message = "'" + ion_channel_name + "' channel function '" + \
-            astfun.name + "' must return real. "
+                  astfun.name + "' must return real. "
         return MessageCode.CM_FUNCTION_BAD_RETURN_TYPE, message

     @classmethod
@@ -1223,7 +1223,7 @@ def get_expected_cm_variables_missing_in_blocks(cls,
         for missing_var, proper_location in missing_variable_to_proper_block.items():
             message += "Variable with name '" + missing_var
             message += "' not found but expected to exist inside of " + \
-                proper_location + " because of position "
+                       proper_location + " because of position "
             message += str(
                 expected_variables_to_reason[missing_var].get_source_position()) + "\n"
         return MessageCode.CM_VARIABLES_NOT_DECLARED, message
@@ -1247,7 +1247,7 @@ def get_v_comp_variable_value_missing(cls, neuron_name: str, missing_variable_na
     def get_syns_bad_buffer_count(cls, buffers: set, synapse_name: str) -> Tuple[MessageCode, str]:
         message = "Synapse `\'%s\' uses the following input buffers: %s" % (
             synapse_name, buffers)
-        message += " However exaxtly one spike input buffer per synapse is allowed."
+        message += " However, exactly one spike input buffer (aside from the self_spikes buffer) is allowed per synapse."
         return MessageCode.SYNS_BAD_BUFFER_COUNT, message

     @classmethod
@@ -1291,15 +1291,35 @@ def get_integrate_odes_arg_higher_order(cls, arg: str) -> Tuple[MessageCode, str
         return MessageCode.INTEGRATE_ODES_ARG_HIGHER_ORDER, message

     @classmethod
-    def get_mechs_dictionary_info(cls, chan_info, syns_info, conc_info, con_in_info) -> Tuple[MessageCode, str]:
+    def get_mechs_dictionary_info(cls, chan_info, recs_info, conc_info, con_in_info, syns_info, global_info) -> Tuple[MessageCode, str]:
         message = ""
         message += "chan_info:\n" + chan_info + "\n"
-        message += "syns_info:\n" + syns_info + "\n"
+        message += "recs_info:\n" + recs_info + "\n"
         message += "conc_info:\n" + conc_info + "\n"
         message += "con_in_info:\n" + con_in_info + "\n"
+        message += "syns_info:\n" + syns_info + "\n"
+        message += "global_info:\n" + global_info + "\n"

         return MessageCode.MECHS_DICTIONARY_INFO, message

+    @classmethod
+    def cm_shared_variables_not_allowed(cls, varname: str, mech_names: list):
+        message = "Multiple mechanisms (" + ", ".join(mech_names) + \
+            ") are referencing the same variable: '" + varname + "'"
+
+        return MessageCode.CM_VAR_MULTIUSE, message
+
+    @classmethod
+    def cm_non_self_spike_convolution_in_mech(cls, mech_name: str, mech_type: str):
+        message = ("Mechanisms of type '" + mech_type + "' may only contain convolutions with the "
+                   "self_spikes buffer, but '" + mech_name + "' contains one with a different buffer.")
+
+        return MessageCode.CM_INVALID_CONVOLUTION_BUFFER, message
+
     @classmethod
     def get_fixed_timestep_func_used(cls):
         message = "Model contains a call to fixed-timestep functions (``resolution()`` and/or ``steps()``). This restricts the model to being compatible only with fixed-timestep simulators. Consider eliminating ``resolution()`` and ``steps()`` from the model, and using ``timestep()`` instead."
diff --git a/pynestml/utils/receptor_processing.py b/pynestml/utils/receptor_processing.py
new file mode 100644
index 000000000..1eb397fd6
--- /dev/null
+++ b/pynestml/utils/receptor_processing.py
@@ -0,0 +1,201 @@
+# -*- coding: utf-8 -*-
+#
+# receptor_processing.py
+#
+# This file is part of NEST.
+#
+# Copyright (C) 2004 The NEST Initiative
+#
+# NEST is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 2 of the License, or
+# (at your option) any later version.
+# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +import copy + +from pynestml.frontend.frontend_configuration import FrontendConfiguration +from pynestml.meta_model.ast_block_with_variables import ASTBlockWithVariables +from pynestml.meta_model.ast_model import ASTModel +from pynestml.symbols.symbol import SymbolKind +from pynestml.utils.ast_receptor_information_collector import ASTReceptorInformationCollector +from pynestml.utils.ast_utils import ASTUtils +from pynestml.utils.logger import Logger, LoggingLevel +from pynestml.utils.mechanism_processing import MechanismProcessing +from pynestml.utils.messages import Messages + +from odetoolbox import analysis + + +class ReceptorProcessing(MechanismProcessing): + """ + This file is part of the compartmental code generation process. + + Receptor mechanism specific processing. + """ + mechType = "receptor" + + def __init__(self, params): + super(MechanismProcessing, self).__init__(params) + + @classmethod + def collect_information_for_specific_mech_types(cls, neuron, mechs_info): + mechs_info, add_info_collector = cls.collect_additional_base_infos(neuron, mechs_info) + if len(mechs_info) > 0: + # only do this if any synapses found + # otherwise tests may fail + mechs_info = cls.collect_and_check_inputs_per_synapse(mechs_info) + + return mechs_info + + @classmethod + def collect_additional_base_infos(cls, neuron, syns_info): + """ + Collect internals, kernels, inputs and convolutions associated with the synapse. 
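+        Returns the updated ``syns_info`` dictionary together with the collector
+        instance, so that the caller can reuse the already collected information.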
+ """ + info_collector = ASTReceptorInformationCollector() + neuron.accept(info_collector) + for synapse_name, synapse_info in syns_info.items(): + synapse_inline = syns_info[synapse_name]["root_expression"] + syns_info[synapse_name][ + "internals_used_declared"] = info_collector.get_synapse_specific_internal_declarations(synapse_inline) + syns_info[synapse_name]["total_used_declared"] = info_collector.get_variable_names_of_synapse( + synapse_inline) + + return syns_info, info_collector + + @classmethod + def collect_and_check_inputs_per_synapse( + cls, + syns_info: dict): + new_syns_info = copy.copy(syns_info) + + # collect all buffers used + for synapse_name, synapse_info in syns_info.items(): + new_syns_info[synapse_name]["buffers_used"] = set() + for convolution_name, convolution_info in synapse_info["convolutions"].items( + ): + input_name = convolution_info["spikes"]["name"] + if input_name != "self_spikes": + new_syns_info[synapse_name]["buffers_used"].add(input_name) + + # now make sure each synapse is using exactly one buffer except self_spikes + for synapse_name, synapse_info in syns_info.items(): + buffers = new_syns_info[synapse_name]["buffers_used"] + if len(buffers) != 1: + code, message = Messages.get_syns_bad_buffer_count( + buffers, synapse_name) + causing_object = synapse_info["root_expression"] + Logger.log_message( + code=code, + message=message, + error_position=causing_object.get_source_position(), + log_level=LoggingLevel.ERROR, + node=causing_object) + + return new_syns_info + + @classmethod + def ode_solve_convolution(cls, + neuron: ASTModel, + parameters_block: ASTBlockWithVariables, + kernel_buffer): + odetoolbox_indict = cls.create_ode_indict( + neuron, parameters_block, kernel_buffer) + full_solver_result = analysis( + odetoolbox_indict, + disable_stiffness_check=True, + log_level=FrontendConfiguration.logging_level) + analytic_solver = None + analytic_solvers = [ + x for x in full_solver_result if x["solver"] == "analytical"] + assert len( + analytic_solvers) <= 1, "More than one analytic solver not presently supported" + if len(analytic_solvers) > 0: + analytic_solver = analytic_solvers[0] + + return analytic_solver + + @classmethod + def create_ode_indict(cls, + neuron: ASTModel, + parameters_block: ASTBlockWithVariables, + kernel_buffer): + kernel_buffers = {tuple(kernel_buffer)} + odetoolbox_indict = cls.transform_ode_and_kernels_to_json( + neuron, parameters_block, kernel_buffers) + odetoolbox_indict["options"] = {} + odetoolbox_indict["options"]["output_timestep_symbol"] = "__h" + return odetoolbox_indict + + @classmethod + def transform_ode_and_kernels_to_json( + cls, + neuron: ASTModel, + parameters_block, + kernel_buffers): + """ + Converts AST node to a JSON representation suitable for passing to ode-toolbox. + + Each kernel has to be generated for each spike buffer convolve in which it occurs, e.g. if the NESTML model code contains the statements + + convolve(G, ex_spikes) + convolve(G, in_spikes) + + then `kernel_buffers` will contain the pairs `(G, ex_spikes)` and `(G, in_spikes)`, from which two ODEs will be generated, with dynamical state (variable) names `G__X__ex_spikes` and `G__X__in_spikes`. 
+ + :param parameters_block: ASTBlockWithVariables + :return: Dict + """ + odetoolbox_indict = {"dynamics": []} + + equations_block = neuron.get_equations_blocks()[0] + + for kernel, spike_input_port in kernel_buffers: + if ASTUtils.is_delta_kernel(kernel): + continue + # delta function -- skip passing this to ode-toolbox + + for kernel_var in kernel.get_variables(): + expr = ASTUtils.get_expr_from_kernel_var( + kernel, kernel_var.get_complete_name()) + kernel_order = kernel_var.get_differential_order() + kernel_X_spike_buf_name_ticks = ASTUtils.construct_kernel_X_spike_buf_name( + kernel_var.get_name(), spike_input_port.get_name(), kernel_order, diff_order_symbol="'") + + ASTUtils.replace_rhs_variables(expr, kernel_buffers) + + entry = {"expression": kernel_X_spike_buf_name_ticks + " = " + str(expr), "initial_values": {}} + + # initial values need to be declared for order 1 up to kernel + # order (e.g. none for kernel function f(t) = ...; 1 for kernel + # ODE f'(t) = ...; 2 for f''(t) = ... and so on) + for order in range(kernel_order): + iv_sym_name_ode_toolbox = ASTUtils.construct_kernel_X_spike_buf_name( + kernel_var.get_name(), spike_input_port, order, diff_order_symbol="'") + symbol_name_ = kernel_var.get_name() + "'" * order + symbol = equations_block.get_scope().resolve_to_symbol( + symbol_name_, SymbolKind.VARIABLE) + assert symbol is not None, "Could not find initial value for variable " + symbol_name_ + initial_value_expr = symbol.get_declaring_expression() + assert initial_value_expr is not None, "No initial value found for variable name " + symbol_name_ + entry["initial_values"][iv_sym_name_ode_toolbox] = cls._ode_toolbox_printer.print( + initial_value_expr) + + odetoolbox_indict["dynamics"].append(entry) + + odetoolbox_indict["parameters"] = {} + if parameters_block is not None: + for decl in parameters_block.get_declarations(): + for var in decl.variables: + odetoolbox_indict["parameters"][var.get_complete_name( + )] = cls._ode_toolbox_printer.print(decl.get_expression()) + + return odetoolbox_indict diff --git a/pynestml/utils/recs_info_enricher.py b/pynestml/utils/recs_info_enricher.py new file mode 100644 index 000000000..1af7a31b7 --- /dev/null +++ b/pynestml/utils/recs_info_enricher.py @@ -0,0 +1,290 @@ +# -*- coding: utf-8 -*- +# +# recs_info_enricher.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + + +import copy +import sympy + +from pynestml.meta_model.ast_model import ASTModel +from pynestml.symbols.predefined_functions import PredefinedFunctions +from pynestml.utils.mechs_info_enricher import MechsInfoEnricher +from pynestml.utils.model_parser import ModelParser +from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor +from pynestml.visitors.ast_visitor import ASTVisitor + + +class RecsInfoEnricher(MechsInfoEnricher): + """ + This file is part of the compartmental code generation process. 
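+
+    Receptor-specific enrichment of the mechanism info: it derives the inline
+    expression's derivative with respect to v_comp, resolves the single spike
+    buffer name, and collects analytic helper variables such as __h.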
+
+    Input: a neuron after ODE-toolbox transformations.
+
+    The kernel analysis solves all kernels at the same time; this splits the
+    variables on a per-kernel basis.
+    """
+
+    def __init__(self, params):
+        super(MechsInfoEnricher, self).__init__(params)
+
+    @classmethod
+    def enrich_mechanism_specific(cls, neuron, mechs_info):
+        specific_enricher_visitor = RecsInfoEnricherVisitor()
+        neuron.accept(specific_enricher_visitor)
+        mechs_info = cls.transform_convolutions_analytic_solutions(neuron, mechs_info)
+        mechs_info = cls.compute_expression_derivative(mechs_info)
+        mechs_info = cls.restore_order_internals(neuron, mechs_info)
+        return mechs_info
+
+    @classmethod
+    def transform_convolutions_analytic_solutions(cls, neuron: ASTModel, cm_syns_info: dict):
+        enriched_syns_info = copy.copy(cm_syns_info)
+        for synapse_name, synapse_info in cm_syns_info.items():
+            # only one buffer allowed, so allow direct access
+            # to it instead of a list
+            if "buffer_name" not in enriched_syns_info[synapse_name]:
+                buffers_used = list(
+                    enriched_syns_info[synapse_name]["buffers_used"])
+                del enriched_syns_info[synapse_name]["buffers_used"]
+                enriched_syns_info[synapse_name]["buffer_name"] = buffers_used[0]
+
+            # now also identify analytic helper variables such as __h
+            enriched_syns_info[synapse_name]["analytic_helpers"] = cls.get_analytic_helper_variable_declarations(
+                enriched_syns_info[synapse_name])
+
+        return enriched_syns_info
+
+    @classmethod
+    def restore_order_internals(cls, neuron: ASTModel, cm_syns_info: dict):
+        """Orders user-defined internals back to the order in which they were
+        originally defined. This is important if one such variable uses another;
+        the user needs to have control over the order. Assigns each variable a
+        rank that corresponds to its position in
+        RecsInfoEnricherVisitor.declarations_ordered."""
+        variable_name_to_order = {}
+        for index, declaration in enumerate(
+                RecsInfoEnricherVisitor.declarations_ordered):
+            variable_name = declaration.get_variables()[0].get_name()
+            variable_name_to_order[variable_name] = index
+
+        enriched_syns_info = copy.copy(cm_syns_info)
+        for synapse_name, synapse_info in cm_syns_info.items():
+            user_internals = enriched_syns_info[synapse_name]["internals_used_declared"]
+            user_internals_sorted = sorted(
+                user_internals.items(), key=lambda x: variable_name_to_order[x[0]])
+            enriched_syns_info[synapse_name]["internals_used_declared"] = user_internals_sorted
+
+        return enriched_syns_info
+
+    @classmethod
+    def compute_expression_derivative(cls, chan_info):
+        for ion_channel_name, ion_channel_info in chan_info.items():
+            inline_expression = chan_info[ion_channel_name]["root_expression"]
+            expr_str = str(inline_expression.get_expression())
+            sympy_expr = sympy.parsing.sympy_parser.parse_expr(expr_str)
+            sympy_expr = sympy.diff(sympy_expr, "v_comp")
+
+            ast_expression_d = ModelParser.parse_expression(str(sympy_expr))
+            # copy scope of the original inline_expression into the derivative
+            ast_expression_d.update_scope(inline_expression.get_scope())
+            ast_expression_d.accept(ASTSymbolTableVisitor())
+
+            chan_info[ion_channel_name]["inline_derivative"] = ast_expression_d
+
+        return chan_info
+
+    @classmethod
+    def get_variable_names_used(cls, node) -> set:
+        variable_names_extractor = ASTUsedVariableNamesExtractor(node)
+        return variable_names_extractor.variable_names
+
+    @classmethod
+    def get_all_synapse_variables(cls, single_synapse_info):
+        """Returns all variable names referenced by the synapse inline and by the
+        analytical solution. Assumes that the model has already been transformed."""
+
+        # get all variables from transformed inline
+        inline_variables = cls.get_variable_names_used(
+            single_synapse_info["root_expression"])
+
+        analytic_solution_vars = set()
+        # get all variables from transformed analytic solution
+        for convolution_name, convolution_info in single_synapse_info["convolutions"].items(
+        ):
+            analytic_sol = convolution_info["analytic_solution"]
+            # get variables from init and update expressions
+            # for each kernel
+            for kernel_var_name, kernel_info in analytic_sol["kernel_states"].items(
+            ):
+                analytic_solution_vars.add(kernel_var_name)
+
+                update_vars = cls.get_variable_names_used(
+                    kernel_info["update_expression"])
+                init_vars = cls.get_variable_names_used(
+                    kernel_info["init_expression"])
+
+                analytic_solution_vars.update(update_vars)
+                analytic_solution_vars.update(init_vars)
+
+            # get variables from init expressions
+            # for each propagator
+            # include propagator variable itself
+            for propagator_var_name, propagator_info in analytic_sol["propagators"].items(
+            ):
+                analytic_solution_vars.add(propagator_var_name)
+
+                init_vars = cls.get_variable_names_used(
+                    propagator_info["init_expression"])
+
+                analytic_solution_vars.update(init_vars)
+
+        return analytic_solution_vars.union(inline_variables)
+
+    @classmethod
+    def get_new_variables_after_transformation(cls, single_synapse_info):
+        return cls.get_all_synapse_variables(single_synapse_info).difference(
+            single_synapse_info["total_used_declared"])
+
+    @classmethod
+    def get_analytic_helper_variable_names(cls, single_synapse_info):
+        """Get new variables that only occur on the right-hand side of analytic
+        solution expressions, but for which the analytic solution does not offer
+        any values. This isolates additional variables that suddenly appear, such
+        as __h, whose initial values are not included in the output of the
+        analytic solver."""
+
+        analytic_lhs_vars = set()
+
+        for convolution_name, convolution_info in single_synapse_info["convolutions"].items(
+        ):
+            analytic_sol = convolution_info["analytic_solution"]
+
+            # get variables representing convolutions by kernel
+            for kernel_var_name, kernel_info in analytic_sol["kernel_states"].items(
+            ):
+                analytic_lhs_vars.add(kernel_var_name)
+
+            # get propagator variable names
+            for propagator_var_name, propagator_info in analytic_sol["propagators"].items(
+            ):
+                analytic_lhs_vars.add(propagator_var_name)
+
+        return cls.get_new_variables_after_transformation(
+            single_synapse_info).symmetric_difference(analytic_lhs_vars)
+
+    @classmethod
+    def get_analytic_helper_variable_declarations(cls, single_synapse_info):
+        variable_names = cls.get_analytic_helper_variable_names(
+            single_synapse_info)
+        result = dict()
+        for variable_name in variable_names:
+            if variable_name not in RecsInfoEnricherVisitor.internal_variable_name_to_variable:
+                continue
+            variable = RecsInfoEnricherVisitor.internal_variable_name_to_variable[variable_name]
+            expression = RecsInfoEnricherVisitor.variables_to_internal_declarations[variable]
+            result[variable_name] = {
+                "ASTVariable": variable,
+                "init_expression": expression,
+            }
+            if expression.is_function_call() and expression.get_function_call(
+            ).callee_name == PredefinedFunctions.TIME_RESOLUTION:
+                result[variable_name]["is_time_resolution"] = True
+            else:
+                result[variable_name]["is_time_resolution"] = False
+
+        return result
+
+
+class RecsInfoEnricherVisitor(ASTVisitor):
+    variables_to_internal_declarations = {}
+    internal_variable_name_to_variable = {}
+    inline_name_to_transformed_inline = {}
+
+    # assuming depth-first traversal, collect declarations
+    # in the order in which they were present in the neuron
+    declarations_ordered = []
+
+    def __init__(self):
+        super(RecsInfoEnricherVisitor, self).__init__()
+
+        self.inside_parameter_block = False
+        self.inside_state_block = False
+        self.inside_internals_block = False
+        self.inside_inline_expression = False
+        self.inside_declaration = False
+        self.inside_simple_expression = False
+
+    def visit_inline_expression(self, node):
+        self.inside_inline_expression = True
+        inline_name = node.variable_name
+        RecsInfoEnricherVisitor.inline_name_to_transformed_inline[inline_name] = node
+
+    def endvisit_inline_expression(self, node):
+        self.inside_inline_expression = False
+
+    def visit_block_with_variables(self, node):
+        if node.is_state:
+            self.inside_state_block = True
+        if node.is_parameters:
+            self.inside_parameter_block = True
+        if node.is_internals:
+            self.inside_internals_block = True
+
+    def endvisit_block_with_variables(self, node):
+        if node.is_state:
+            self.inside_state_block = False
+        if node.is_parameters:
+            self.inside_parameter_block = False
+        if node.is_internals:
+            self.inside_internals_block = False
+
+    def visit_simple_expression(self, node):
+        self.inside_simple_expression = True
+
+    def endvisit_simple_expression(self, node):
+        self.inside_simple_expression = False
+
+    def visit_declaration(self, node):
+        self.declarations_ordered.append(node)
+        self.inside_declaration = True
+        if self.inside_internals_block:
+            variable = node.get_variables()[0]
+            expression = node.get_expression()
+            RecsInfoEnricherVisitor.variables_to_internal_declarations[variable] = expression
+            RecsInfoEnricherVisitor.internal_variable_name_to_variable[variable.get_name(
+            )] = variable
+
+    def endvisit_declaration(self, node):
+        self.inside_declaration = False
+
+
+class ASTUsedVariableNamesExtractor(ASTVisitor):
+    def __init__(self, node):
+        super(ASTUsedVariableNamesExtractor, self).__init__()
+        self.variable_names = set()
+        node.accept(self)
+
+    def visit_variable(self, node):
+        self.variable_names.add(node.get_name())
diff --git a/pynestml/utils/synapse_processing.py b/pynestml/utils/synapse_processing.py
index 464abd269..69b498be6 100644
@@ -19,117 +19,205 @@
 # You should have received a copy of the GNU General Public License
 # along with NEST. If not, see <http://www.gnu.org/licenses/>.
-import copy
 from collections import defaultdict
+import copy
+
+from pynestml.codegeneration.printers.nestml_printer import NESTMLPrinter
+from pynestml.codegeneration.printers.constant_printer import ConstantPrinter
+from pynestml.codegeneration.printers.ode_toolbox_expression_printer import ODEToolboxExpressionPrinter
+from pynestml.codegeneration.printers.ode_toolbox_function_call_printer import ODEToolboxFunctionCallPrinter
+from pynestml.codegeneration.printers.ode_toolbox_variable_printer import ODEToolboxVariablePrinter
+
+from pynestml.codegeneration.printers.sympy_simple_expression_printer import SympySimpleExpressionPrinter
 from pynestml.frontend.frontend_configuration import FrontendConfiguration
 from pynestml.meta_model.ast_block_with_variables import ASTBlockWithVariables
+from pynestml.meta_model.ast_expression import ASTExpression
 from pynestml.meta_model.ast_model import ASTModel
+from pynestml.meta_model.ast_simple_expression import ASTSimpleExpression
 from pynestml.symbols.symbol import SymbolKind
-from pynestml.utils.ast_synapse_information_collector import ASTSynapseInformationCollector
+from pynestml.utils.ast_synapse_information_collector import ASTSynapseInformationCollector, \
+    ASTKernelInformationCollectorVisitor
 from pynestml.utils.ast_utils import ASTUtils
+
+from odetoolbox import analysis
+
 from pynestml.utils.logger import Logger, LoggingLevel
-from pynestml.utils.mechanism_processing import MechanismProcessing
 from pynestml.utils.messages import Messages

-from odetoolbox import analysis

+class SynapseProcessing:
+    """
+    This file is part of the compartmental code generation process.
+
+    Synapse information processing.
+    """
+
+    # used to keep track of whether check_co_co was already called
+    # see inside check_co_co
+    first_time_run = defaultdict(lambda: True)
+    # stores the synapse info from the first call of check_co_co
+    syn_info = defaultdict()
+
+    # ODE-toolbox printers
+    _constant_printer = ConstantPrinter()
+    _ode_toolbox_variable_printer = ODEToolboxVariablePrinter(None)
+    _ode_toolbox_function_call_printer = ODEToolboxFunctionCallPrinter(None)
+    _ode_toolbox_printer = ODEToolboxExpressionPrinter(
+        simple_expression_printer=SympySimpleExpressionPrinter(
+            variable_printer=_ode_toolbox_variable_printer,
+            constant_printer=_constant_printer,
+            function_call_printer=_ode_toolbox_function_call_printer))
+
+    _ode_toolbox_variable_printer._expression_printer = _ode_toolbox_printer
+    _ode_toolbox_function_call_printer._expression_printer = _ode_toolbox_printer

-class SynapseProcessing(MechanismProcessing):
-    mechType = "receptor"
+    @classmethod
+    def prepare_equations_for_ode_toolbox(cls, synapse, syn_info):
+        """Transforms the collected ODE equations into the input format required by
+        ODE-toolbox and adds them to the syn_info dictionary."""
+
+        mechanism_odes = defaultdict()
+        for ode in syn_info["ODEs"]:
+            nestml_printer = NESTMLPrinter()
+            ode_nestml_expression = nestml_printer.print_ode_equation(ode)
+            mechanism_odes[ode.lhs.name] = defaultdict()
+            mechanism_odes[ode.lhs.name]["ASTOdeEquation"] = ode
+            mechanism_odes[ode.lhs.name]["ODENestmlExpression"] = ode_nestml_expression
+        syn_info["ODEs"] = mechanism_odes
+
+        for ode_variable_name, ode_info in syn_info["ODEs"].items():
+            # Expression:
+            odetoolbox_indict = {"dynamics": []}
+            lhs = ASTUtils.to_ode_toolbox_name(ode_info["ASTOdeEquation"].get_lhs().get_complete_name())
+            rhs = cls._ode_toolbox_printer.print(ode_info["ASTOdeEquation"].get_rhs())
+            entry = {"expression": lhs + " = " + rhs, "initial_values": {}}
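+            # NB: this mirrors the ODE-toolbox JSON input format; for a
+            # (hypothetical) trace ODE pre_trace' = -pre_trace / tau_tr_pre,
+            # the entry would look roughly like:
+            #     {"expression": "pre_trace' = -pre_trace / tau_tr_pre",
+            #      "initial_values": {"pre_trace": "0."}}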
+ + # Initial values: + symbol_order = ode_info["ASTOdeEquation"].get_lhs().get_differential_order() + for order in range(symbol_order): + iv_symbol_name = ode_info["ASTOdeEquation"].get_lhs().get_name() + "'" * order + initial_value_expr = synapse.get_initial_value(iv_symbol_name) + entry["initial_values"][ + ASTUtils.to_ode_toolbox_name(iv_symbol_name)] = cls._ode_toolbox_printer.print( + initial_value_expr) + + odetoolbox_indict["dynamics"].append(entry) + syn_info["ODEs"][ode_variable_name]["ode_toolbox_input"] = odetoolbox_indict + + return syn_info + + @classmethod + def collect_raw_odetoolbox_output(cls, syn_info): + """calls ode-toolbox for each ode individually and collects the raw output""" + for ode_variable_name, ode_info in syn_info["ODEs"].items(): + solver_result = analysis(ode_info["ode_toolbox_input"], disable_stiffness_check=True) + syn_info["ODEs"][ode_variable_name]["ode_toolbox_output"] = solver_result + + return syn_info + + @classmethod + def ode_toolbox_processing(cls, synapse, syn_info): + syn_info = cls.prepare_equations_for_ode_toolbox(synapse, syn_info) + syn_info = cls.collect_raw_odetoolbox_output(syn_info) + return syn_info - def __init__(self, params): - super(MechanismProcessing, self).__init__(params) + @classmethod + def collect_information_for_specific_mech_types(cls, synapse, syn_info): + # to be implemented for specific mechanisms by child class (concentration, synapse, channel) + pass @classmethod - def collect_information_for_specific_mech_types(cls, neuron, mechs_info): - mechs_info, add_info_collector = cls.collect_additional_base_infos(neuron, mechs_info) - if len(mechs_info) > 0: - # only do this if any synapses found - # otherwise tests may fail - mechs_info = cls.collect_and_check_inputs_per_synapse(mechs_info) + def determine_dependencies(cls, syn_info): + for mechanism_name, mechanism_info in syn_info.items(): + dependencies = list() + for inline in mechanism_info["Inlines"]: + if isinstance(inline.get_decorators(), list): + if "mechanism" in [e.namespace for e in inline.get_decorators()]: + dependencies.append(inline) + for ode in mechanism_info["ODEs"]: + if isinstance(ode.get_decorators(), list): + if "mechanism" in [e.namespace for e in ode.get_decorators()]: + dependencies.append(ode) + syn_info[mechanism_name]["dependencies"] = dependencies + return syn_info - mechs_info = cls.convolution_ode_toolbox_processing(neuron, mechs_info) + @classmethod + def get_port_names(cls, syn_info): + spiking_port_names = list() + continuous_port_names = list() + for port in syn_info["SpikingPorts"]: + spiking_port_names.append(port.get_name()) + for port in syn_info["ContinuousPorts"]: + continuous_port_names.append(port.get_name()) - return mechs_info + return spiking_port_names, continuous_port_names @classmethod - def collect_additional_base_infos(cls, neuron, syns_info): + def collect_kernels(cls, neuron, syn_info, neuron_synapse_pairs): """ Collect internals, kernels, inputs and convolutions associated with the synapse. 
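+        Convolution entries are keyed by a name constructed from the kernel and
+        the spike buffer; e.g. a (hypothetical) call ``convolve(g_AMPA, spikes_AMPA)``
+        yields a key of the form ``g_AMPA__X__spikes_AMPA``. Convolutions on the
+        ``self_spikes`` buffer are deliberately not collected here.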
""" - info_collector = ASTSynapseInformationCollector() + syn_info["convolutions"] = defaultdict() + info_collector = ASTKernelInformationCollectorVisitor() neuron.accept(info_collector) - for synapse_name, synapse_info in syns_info.items(): - synapse_inline = syns_info[synapse_name]["root_expression"] - syns_info[synapse_name][ + for inline in syn_info["Inlines"]: + synapse_inline = inline + syn_info[ "internals_used_declared"] = info_collector.get_synapse_specific_internal_declarations(synapse_inline) - syns_info[synapse_name]["total_used_declared"] = info_collector.get_variable_names_of_synapse( + syn_info["total_used_declared"] = info_collector.get_variable_names_of_synapse( synapse_inline) - syns_info[synapse_name]["convolutions"] = defaultdict() - kernel_arg_pairs = info_collector.get_extracted_kernel_args( - synapse_inline) + kernel_arg_pairs = info_collector.get_extracted_kernel_args_by_name( + inline.get_variable_name()) for kernel_var, spikes_var in kernel_arg_pairs: kernel_name = kernel_var.get_name() spikes_name = spikes_var.get_name() - convolution_name = info_collector.construct_kernel_X_spike_buf_name( - kernel_name, spikes_name, 0) - syns_info[synapse_name]["convolutions"][convolution_name] = { - "kernel": { - "name": kernel_name, - "ASTKernel": info_collector.get_kernel_by_name(kernel_name), - }, - "spikes": { - "name": spikes_name, - "ASTInputPort": info_collector.get_input_port_by_name(spikes_name), - }, - } - return syns_info, info_collector + if spikes_name != "self_spikes": + convolution_name = info_collector.construct_kernel_X_spike_buf_name( + kernel_name, spikes_name, 0) + syn_info["convolutions"][convolution_name] = { + "kernel": { + "name": kernel_name, + "ASTKernel": info_collector.get_kernel_by_name(kernel_name), + }, + "spikes": { + "name": spikes_name, + "ASTInputPort": info_collector.get_input_port_by_name(spikes_name), + }, + "post_port": (len([dict for dict in neuron_synapse_pairs if + dict["synapse"] + "_nestml" == neuron.name and spikes_name in dict[ + "post_ports"]]) > 0), + } + return syn_info @classmethod def collect_and_check_inputs_per_synapse( cls, - syns_info: dict): - new_syns_info = copy.copy(syns_info) + syn_info: dict): + new_syn_info = copy.copy(syn_info) # collect all buffers used - for synapse_name, synapse_info in syns_info.items(): - new_syns_info[synapse_name]["buffers_used"] = set() - for convolution_name, convolution_info in synapse_info["convolutions"].items( - ): - input_name = convolution_info["spikes"]["name"] - new_syns_info[synapse_name]["buffers_used"].add(input_name) - - # now make sure each synapse is using exactly one buffer - for synapse_name, synapse_info in syns_info.items(): - buffers = new_syns_info[synapse_name]["buffers_used"] - if len(buffers) != 1: - code, message = Messages.get_syns_bad_buffer_count( - buffers, synapse_name) - causing_object = synapse_info["inline_expression"] - Logger.log_message( - code=code, - message=message, - error_position=causing_object.get_source_position(), - log_level=LoggingLevel.ERROR, - node=causing_object) - - return new_syns_info - - @classmethod - def convolution_ode_toolbox_processing(cls, neuron, syns_info): + new_syn_info["buffers_used"] = set() + for convolution_name, convolution_info in syn_info["convolutions"].items( + ): + input_name = convolution_info["spikes"]["name"] + new_syn_info["buffers_used"].add(input_name) + + return new_syn_info + + @classmethod + def convolution_ode_toolbox_processing(cls, neuron, syn_info): if not neuron.get_parameters_blocks(): - return 
syns_info
+            return syn_info

         parameters_block = neuron.get_parameters_blocks()[0]
-        for synapse_name, synapse_info in syns_info.items():
-            for convolution_name, convolution_info in synapse_info["convolutions"].items():
-                kernel_buffer = (convolution_info["kernel"]["ASTKernel"], convolution_info["spikes"]["ASTInputPort"])
-                convolution_solution = cls.ode_solve_convolution(neuron, parameters_block, kernel_buffer)
-                syns_info[synapse_name]["convolutions"][convolution_name]["analytic_solution"] = convolution_solution
-        return syns_info
+        for convolution_name, convolution_info in syn_info["convolutions"].items():
+            kernel_buffer = (convolution_info["kernel"]["ASTKernel"], convolution_info["spikes"]["ASTInputPort"])
+            convolution_solution = cls.ode_solve_convolution(neuron, parameters_block, kernel_buffer)
+            syn_info["convolutions"][convolution_name]["analytic_solution"] = convolution_solution
+        return syn_info

     @classmethod
     def ode_solve_convolution(cls,
@@ -228,3 +316,96 @@ def transform_ode_and_kernels_to_json(
 )] = cls._ode_toolbox_printer.print(decl.get_expression())

     return odetoolbox_indict
+
+    @classmethod
+    def get_syn_info(cls, synapse: ASTModel):
+        """
+        Returns the previously generated syn_info
+        as a deep copy, so it can't be changed externally
+        via object references.
+        :param synapse: a single synapse instance.
+        """
+        return copy.deepcopy(cls.syn_info)
+
+    @classmethod
+    def process(cls, synapse: ASTModel, neuron_synapse_pairs):
+        """
+        Checks if mechanism conditions apply to the handed-over synapse.
+        :param synapse: a single synapse instance.
+        """
+
+        # make sure we only run this a single time
+        # subsequent calls will be after the AST has been transformed
+        # and there would be no kernels or inlines any more
+        if cls.first_time_run[synapse.get_name()]:
+            # collect root expressions and initialize collector
+            info_collector = ASTSynapseInformationCollector(synapse)
+
+            # collect and process all basic mechanism information
+            syn_info = defaultdict()
+            syn_info = info_collector.collect_definitions(synapse, syn_info)
+            syn_info = info_collector.extend_variables_with_initialisations(synapse, syn_info)
+            syn_info = cls.ode_toolbox_processing(synapse, syn_info)
+
+            # collect all spiking ports
+            syn_info = info_collector.collect_ports(synapse, syn_info)
+
+            # collect the onReceive functions of pre- and post-spikes
+            spiking_port_names, continuous_port_names = cls.get_port_names(syn_info)
+            post_ports = FrontendConfiguration.get_codegen_opts()["neuron_synapse_pairs"][0]["post_ports"]
+            pre_ports = list(set(spiking_port_names) - set(post_ports))
+            syn_info = info_collector.collect_on_receive_blocks(synapse, syn_info, pre_ports, post_ports)
+
+            # get the corresponding delay variable
+            syn_info["DelayVariable"] = FrontendConfiguration.get_codegen_opts()["delay_variable"][synapse.get_name().removesuffix("_nestml")]
+
+            # collect the update block
+            syn_info = info_collector.collect_update_block(synapse, syn_info)
+
+            # collect dependencies (mechanisms defined in the neuron with no LHS appearance in the synapse)
+            syn_info = info_collector.collect_potential_dependencies(synapse, syn_info)
+
+            syn_info = cls.collect_kernels(synapse, syn_info, neuron_synapse_pairs)
+
+            syn_info = cls.convolution_ode_toolbox_processing(synapse, syn_info)
+
+            cls.syn_info[synapse.get_name()] = syn_info
+            cls.first_time_run[synapse.get_name()] = False
+
+    @classmethod
+    def print_element(cls, name, element, rec_step):
+        message = ""
+        for indent in range(rec_step):
+            message += "----"
+        message += name + ": "
+        if isinstance(element, defaultdict):
+            message += "\n"
+            message += cls.print_dictionary(element, rec_step + 1)
+        else:
+            if hasattr(element, 'name'):
+                message += element.name
+            elif isinstance(element, str):
+                message += element
+            elif isinstance(element, dict):
+                message += "\n"
+                message += cls.print_dictionary(element, rec_step + 1)
+            elif isinstance(element, list):
+                for index in range(len(element)):
+                    message += "\n"
+                    message += cls.print_element(str(index), element[index], rec_step + 1)
+            elif isinstance(element, ASTExpression) or isinstance(element, ASTSimpleExpression):
+                message += cls._ode_toolbox_printer.print(element)
+
+        message += "(" + type(element).__name__ + ")"
+        return message
+
+    @classmethod
+    def print_dictionary(cls, dictionary, rec_step):
+        """
+        Print the mechanisms info dictionaries.
+        """
+        message = ""
+        for name, element in dictionary.items():
+            message += cls.print_element(name, element, rec_step)
+            message += "\n"
+        return message
diff --git a/pynestml/utils/syns_info_enricher.py b/pynestml/utils/syns_info_enricher.py
index 5c4b639ec..3a7b31b72 100644
--- a/pynestml/utils/syns_info_enricher.py
+++ b/pynestml/utils/syns_info_enricher.py
@@ -18,139 +18,284 @@
 #
 # You should have received a copy of the GNU General Public License
 # along with NEST. If not, see <http://www.gnu.org/licenses/>.
-
-from _collections import defaultdict
-
 import copy
+
 import sympy

+from pynestml.cocos.co_cos_manager import CoCosManager
+
+from pynestml.symbol_table.symbol_table import SymbolTable
 from pynestml.meta_model.ast_expression import ASTExpression
 from pynestml.meta_model.ast_inline_expression import ASTInlineExpression
 from pynestml.meta_model.ast_model import ASTModel
-from pynestml.symbols.predefined_functions import PredefinedFunctions
-from pynestml.symbols.symbol import SymbolKind
-from pynestml.utils.mechs_info_enricher import MechsInfoEnricher
-from pynestml.utils.model_parser import ModelParser
+from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
 from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
+from pynestml.utils.ast_utils import ASTUtils
 from pynestml.visitors.ast_visitor import ASTVisitor
+from pynestml.utils.model_parser import ModelParser
+from pynestml.symbols.predefined_functions import PredefinedFunctions
+from pynestml.symbols.symbol import SymbolKind
+from collections import defaultdict

-class SynsInfoEnricher(MechsInfoEnricher):
-    """
-    input: a neuron after ODE-toolbox transformations
-    the kernel analysis solves all kernels at the same time
-    this splits the variables on per kernel basis
+class SynsInfoEnricher:
     """
+    This file is part of the compartmental code generation process.

-    def __init__(self, params):
-        super(MechsInfoEnricher, self).__init__(params)
+    Adds information that cannot be collected in the processing class, since the
+    processing class is used in the cocos: here we use the ModelParser, which would
+    lead to a cyclic dependency.
+
+    Additionally, we require information about the paired synapse's mechanisms to
+    confirm which dependencies actually exist in the synapse.
+ """ @classmethod - def enrich_mechanism_specific(cls, neuron, mechs_info): + def enrich_with_additional_info(cls, synapse: ASTModel, syns_info: dict, chan_info: dict, recs_info: dict, + conc_info: dict, con_in_info: dict): specific_enricher_visitor = SynsInfoEnricherVisitor() - neuron.accept(specific_enricher_visitor) - mechs_info = cls.transform_convolutions_analytic_solutions(neuron, mechs_info) - mechs_info = cls.restore_order_internals(neuron, mechs_info) - return mechs_info + + cls.add_propagators_to_internals(synapse, syns_info) + synapse.accept(specific_enricher_visitor) + + synapse_info = syns_info[synapse.get_name()] + synapse_info = cls.transform_ode_solutions(synapse, synapse_info) + synapse_info = cls.confirm_dependencies(synapse_info, chan_info, recs_info, conc_info, con_in_info) + synapse_info = cls.extract_infunction_declarations(synapse_info) + + synapse_info = cls.transform_convolutions_analytic_solutions(synapse, synapse_info) + syns_info[synapse.get_name()] = synapse_info + + return syns_info @classmethod - def transform_convolutions_analytic_solutions(cls, neuron: ASTModel, cm_syns_info: dict): + def add_propagators_to_internals(cls, neuron, mechs_info): + for mechanism_name, mechanism_info in mechs_info.items(): + for ode_var_name, ode_info in mechanism_info["ODEs"].items(): + for ode_solution_index in range(len(ode_info["ode_toolbox_output"])): + for variable_name, rhs_str in ode_info["ode_toolbox_output"][ode_solution_index]["propagators"].items(): + ASTUtils.add_declaration_to_internals(neuron, variable_name, rhs_str) + + if "convolutions" in mechanism_info: + for convolution_name, convolution_info in mechanism_info["convolutions"].items(): + for variable_name, rhs_str in convolution_info["analytic_solution"]["propagators"].items(): + ASTUtils.add_declaration_to_internals(neuron, variable_name, rhs_str) + + SymbolTable.delete_model_scope(neuron.get_name()) + symbol_table_visitor = ASTSymbolTableVisitor() + neuron.accept(symbol_table_visitor) + CoCosManager.check_cocos(neuron, after_ast_rewrite=True, syn_model=True) + SymbolTable.add_model_scope(neuron.get_name(), neuron.get_scope()) - enriched_syns_info = copy.copy(cm_syns_info) - for synapse_name, synapse_info in cm_syns_info.items(): - for convolution_name in synapse_info["convolutions"].keys(): - analytic_solution = enriched_syns_info[synapse_name][ - "convolutions"][convolution_name]["analytic_solution"] - analytic_solution_transformed = defaultdict( - lambda: defaultdict()) - - for variable_name, expression_str in analytic_solution["initial_values"].items(): - variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol(variable_name, - SymbolKind.VARIABLE) - - expression = ModelParser.parse_expression(expression_str) + @classmethod + def transform_ode_solutions(cls, synapse, syns_info): + for ode_var_name, ode_info in syns_info["ODEs"].items(): + syns_info["ODEs"][ode_var_name]["transformed_solutions"] = list() + + for ode_solution_index in range(len(ode_info["ode_toolbox_output"])): + solution_transformed = defaultdict() + solution_transformed["states"] = defaultdict() + solution_transformed["propagators"] = defaultdict() + + for variable_name, rhs_str in ode_info["ode_toolbox_output"][ode_solution_index]["initial_values"].items(): + variable = synapse.get_equations_blocks()[0].get_scope().resolve_to_symbol(variable_name, + SymbolKind.VARIABLE) + + expression = ModelParser.parse_expression(rhs_str) # pretend that update expressions are in "equations" block, # which should always be present, as 
synapses have been # defined to get here - expression.update_scope(neuron.get_equations_blocks()[0].get_scope()) + expression.update_scope(synapse.get_equations_blocks()[0].get_scope()) expression.accept(ASTSymbolTableVisitor()) - update_expr_str = analytic_solution["update_expressions"][variable_name] + update_expr_str = ode_info["ode_toolbox_output"][ode_solution_index]["update_expressions"][ + variable_name] update_expr_ast = ModelParser.parse_expression( update_expr_str) # pretend that update expressions are in "equations" block, # which should always be present, as differential equations # must have been defined to get here update_expr_ast.update_scope( - neuron.get_equations_blocks()[0].get_scope()) + synapse.get_equations_blocks()[0].get_scope()) update_expr_ast.accept(ASTSymbolTableVisitor()) - analytic_solution_transformed['kernel_states'][variable_name] = { + solution_transformed["states"][variable_name] = { "ASTVariable": variable, "init_expression": expression, "update_expression": update_expr_ast, } - - for variable_name, expression_string in analytic_solution["propagators"].items( - ): - variable = SynsInfoEnricherVisitor.internal_variable_name_to_variable[variable_name] - expression = ModelParser.parse_expression( - expression_string) + for variable_name, rhs_str in ode_info["ode_toolbox_output"][ode_solution_index]["propagators"].items(): + prop_variable = synapse.get_internals_blocks()[0].get_scope().resolve_to_symbol(variable_name, SymbolKind.VARIABLE) + if prop_variable is None: + ASTUtils.add_declarations_to_internals( + synapse, ode_info["ode_toolbox_output"][ode_solution_index]["propagators"]) + prop_variable = synapse.get_internals_blocks()[0].get_scope().resolve_to_symbol( + variable_name, + SymbolKind.VARIABLE) + + expression = ModelParser.parse_expression(rhs_str) # pretend that update expressions are in "equations" block, # which should always be present, as synapses have been # defined to get here expression.update_scope( - neuron.get_equations_blocks()[0].get_scope()) + synapse.get_equations_blocks()[0].get_scope()) expression.accept(ASTSymbolTableVisitor()) - analytic_solution_transformed['propagators'][variable_name] = { - "ASTVariable": variable, "init_expression": expression, } - - enriched_syns_info[synapse_name]["convolutions"][convolution_name]["analytic_solution"] = \ - analytic_solution_transformed - - # only one buffer allowed, so allow direct access - # to it instead of a list - if "buffer_name" not in enriched_syns_info[synapse_name]: - buffers_used = list( - enriched_syns_info[synapse_name]["buffers_used"]) - del enriched_syns_info[synapse_name]["buffers_used"] - enriched_syns_info[synapse_name]["buffer_name"] = buffers_used[0] - - inline_expression_name = enriched_syns_info[synapse_name]["root_expression"].variable_name - enriched_syns_info[synapse_name]["root_expression"] = \ - SynsInfoEnricherVisitor.inline_name_to_transformed_inline[inline_expression_name] - enriched_syns_info[synapse_name]["inline_expression_d"] = \ - cls.compute_expression_derivative( - enriched_syns_info[synapse_name]["root_expression"]) - # now also identify analytic helper variables such as __h - enriched_syns_info[synapse_name]["analytic_helpers"] = cls.get_analytic_helper_variable_declarations( - enriched_syns_info[synapse_name]) + solution_transformed["propagators"][variable_name] = { + "ASTVariable": prop_variable, "init_expression": expression, } + expression_variable_collector = ASTEnricherInfoCollectorVisitor() + expression.accept(expression_variable_collector) - 
return enriched_syns_info + synapse_internal_declaration_collector = ASTEnricherInfoCollectorVisitor() + synapse.accept(synapse_internal_declaration_collector) + + for variable in expression_variable_collector.all_variables: + for internal_declaration in synapse_internal_declaration_collector.internal_declarations: + if variable.get_name() == internal_declaration.get_variables()[0].get_name() \ + and internal_declaration.get_expression().is_function_call() \ + and internal_declaration.get_expression().get_function_call().callee_name == \ + PredefinedFunctions.TIME_RESOLUTION: + syns_info["time_resolution_var"] = variable + + syns_info["ODEs"][ode_var_name]["transformed_solutions"].append(solution_transformed) + + synapse.accept(ASTParentVisitor()) + + return syns_info @classmethod - def restore_order_internals(cls, neuron: ASTModel, cm_syns_info: dict): - """orders user defined internals - back to the order they were originally defined - this is important if one such variable uses another - user needs to have control over the order - assign each variable a rank - that corresponds to the order in - SynsInfoEnricher.declarations_ordered""" - variable_name_to_order = {} - for index, declaration in enumerate( - SynsInfoEnricherVisitor.declarations_ordered): - variable_name = declaration.get_variables()[0].get_name() - variable_name_to_order[variable_name] = index + def confirm_dependencies(cls, syns_info: dict, chan_info: dict, recs_info: dict, conc_info: dict, + con_in_info: dict): + actual_dependencies = dict() + chan_deps = list() + rec_deps = list() + conc_deps = list() + con_in_deps = list() + for pot_dep, dep_info in syns_info["PotentialDependencies"].items(): + for channel_name, channel_info in chan_info.items(): + if pot_dep == channel_name: + chan_deps.append(channel_info["root_expression"]) + for receptor_name, receptor_info in recs_info.items(): + if pot_dep == receptor_name: + rec_deps.append(receptor_info["root_expression"]) + for concentration_name, concentration_info in conc_info.items(): + if pot_dep == concentration_name: + conc_deps.append(concentration_info["root_expression"]) + for continuous_name, continuous_info in con_in_info.items(): + if pot_dep == continuous_name: + con_in_deps.append(continuous_info["root_expression"]) + + actual_dependencies["channels"] = chan_deps + actual_dependencies["receptors"] = rec_deps + actual_dependencies["concentrations"] = conc_deps + actual_dependencies["continuous"] = con_in_deps + syns_info["Dependencies"] = actual_dependencies + return syns_info + + @classmethod + def extract_infunction_declarations(cls, syn_info): + pre_spike_function = syn_info["PreSpikeFunction"] + post_spike_function = syn_info["PostSpikeFunction"] + update_block = syn_info["UpdateBlock"] + # general_functions = syn_info["Functions"] + declaration_visitor = ASTDeclarationCollectorAndUniqueRenamerVisitor() + if pre_spike_function is not None: + pre_spike_function.accept(declaration_visitor) + if post_spike_function is not None: + post_spike_function.accept(declaration_visitor) + if update_block is not None: + update_block.accept(declaration_visitor) + + declaration_vars = list() + for decl in declaration_visitor.declarations: + for var in decl.get_variables(): + declaration_vars.append(var.get_name()) + + syn_info["InFunctionDeclarationsVars"] = declaration_visitor.declarations # list(declaration_vars) + return syn_info + + @classmethod + def transform_convolutions_analytic_solutions(cls, neuron: ASTModel, cm_syns_info: dict): enriched_syns_info = 
copy.copy(cm_syns_info) - for synapse_name, synapse_info in cm_syns_info.items(): - user_internals = enriched_syns_info[synapse_name]["internals_used_declared"] - user_internals_sorted = sorted( - user_internals.items(), key=lambda x: variable_name_to_order[x[0]]) - enriched_syns_info[synapse_name]["internals_used_declared"] = user_internals_sorted + for convolution_name in cm_syns_info["convolutions"].keys(): + analytic_solution = enriched_syns_info[ + "convolutions"][convolution_name]["analytic_solution"] + analytic_solution_transformed = defaultdict( + lambda: defaultdict()) + + for variable_name, expression_str in analytic_solution["initial_values"].items(): + variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol(variable_name, + SymbolKind.VARIABLE) + if variable is None: + ASTUtils.add_declarations_to_internals( + neuron, analytic_solution["initial_values"]) + variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol( + variable_name, + SymbolKind.VARIABLE) + + expression = ModelParser.parse_expression(expression_str) + # pretend that update expressions are in "equations" block, + # which should always be present, as synapses have been + # defined to get here + expression.update_scope(neuron.get_equations_blocks()[0].get_scope()) + expression.accept(ASTSymbolTableVisitor()) + + update_expr_str = analytic_solution["update_expressions"][variable_name] + update_expr_ast = ModelParser.parse_expression( + update_expr_str) + # pretend that update expressions are in "equations" block, + # which should always be present, as differential equations + # must have been defined to get here + update_expr_ast.update_scope( + neuron.get_equations_blocks()[0].get_scope()) + update_expr_ast.accept(ASTSymbolTableVisitor()) + + analytic_solution_transformed['kernel_states'][variable_name] = { + "ASTVariable": variable, + "init_expression": expression, + "update_expression": update_expr_ast, + } + + for variable_name, expression_string in analytic_solution["propagators"].items( + ): + variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol(variable_name, + SymbolKind.VARIABLE) + if variable is None: + ASTUtils.add_declarations_to_internals( + neuron, analytic_solution["propagators"]) + variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol( + variable_name, + SymbolKind.VARIABLE) + + expression = ModelParser.parse_expression( + expression_string) + # pretend that update expressions are in "equations" block, + # which should always be present, as synapses have been + # defined to get here + expression.update_scope( + neuron.get_equations_blocks()[0].get_scope()) + expression.accept(ASTSymbolTableVisitor()) + analytic_solution_transformed['propagators'][variable_name] = { + "ASTVariable": variable, "init_expression": expression, } + + enriched_syns_info["convolutions"][convolution_name]["analytic_solution"] = \ + analytic_solution_transformed + + transformed_inlines = dict() + for inline in enriched_syns_info["Inlines"]: + transformed_inlines[inline.get_variable_name()] = dict() + transformed_inlines[inline.get_variable_name()]["inline_expression"] = \ + SynsInfoEnricherVisitor.inline_name_to_transformed_inline[inline.get_variable_name()] + transformed_inlines[inline.get_variable_name()]["inline_expression_d"] = \ + cls.compute_expression_derivative( + transformed_inlines[inline.get_variable_name()]["inline_expression"]) + enriched_syns_info["Inlines"] = transformed_inlines + + # now also identify analytic helper variables such as __h 
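+        # (__h is the integration-step helper that ODE-toolbox introduces for the
+        # propagators; it is typically initialised via resolution(), which is why
+        # get_analytic_helper_variable_declarations flags it with is_time_resolution)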
+ enriched_syns_info["analytic_helpers"] = cls.get_analytic_helper_variable_declarations( + enriched_syns_info) + + neuron.accept(ASTParentVisitor()) return enriched_syns_info @@ -169,9 +314,60 @@ def compute_expression_derivative( return ast_expression_d @classmethod - def get_variable_names_used(cls, node) -> set: - variable_names_extractor = ASTUsedVariableNamesExtractor(node) - return variable_names_extractor.variable_names + def get_analytic_helper_variable_declarations(cls, single_synapse_info): + variable_names = cls.get_analytic_helper_variable_names( + single_synapse_info) + result = dict() + for variable_name in variable_names: + if variable_name not in SynsInfoEnricherVisitor.internal_variable_name_to_variable: + continue + variable = SynsInfoEnricherVisitor.internal_variable_name_to_variable[variable_name] + expression = SynsInfoEnricherVisitor.variables_to_internal_declarations[variable] + result[variable_name] = { + "ASTVariable": variable, + "init_expression": expression, + } + if expression.is_function_call() and expression.get_function_call( + ).callee_name == PredefinedFunctions.TIME_RESOLUTION: + result[variable_name]["is_time_resolution"] = True + else: + result[variable_name]["is_time_resolution"] = False + + return result + + @classmethod + def get_analytic_helper_variable_names(cls, single_synapse_info): + """get new variables that only occur on the right hand side of analytic solution Expressions + but for wich analytic solution does not offer any values + this can isolate out additional variables that suddenly appear such as __h + whose initial values are not inlcuded in the output of analytic solver""" + + analytic_lhs_vars = set() + + for convolution_name, convolution_info in single_synapse_info["convolutions"].items( + ): + analytic_sol = convolution_info["analytic_solution"] + + # get variables representing convolutions by kernel + for kernel_var_name, kernel_info in analytic_sol["kernel_states"].items( + ): + analytic_lhs_vars.add(kernel_var_name) + + # get propagator variable names + for propagator_var_name, propagator_info in analytic_sol["propagators"].items( + ): + analytic_lhs_vars.add(propagator_var_name) + + return cls.get_new_variables_after_transformation( + single_synapse_info).symmetric_difference(analytic_lhs_vars) + + @classmethod + def get_new_variables_after_transformation(cls, single_synapse_info): + total = set() + if "total_used_declared" in single_synapse_info: + total = single_synapse_info["total_used_declared"] + return cls.get_all_synapse_variables(single_synapse_info).difference( + total) @classmethod def get_all_synapse_variables(cls, single_synapse_info): @@ -179,9 +375,9 @@ def get_all_synapse_variables(cls, single_synapse_info): and by the analytical solution assumes that the model has already been transformed""" - # get all variables from transformed inline - inline_variables = cls.get_variable_names_used( - single_synapse_info["root_expression"]) + inline_variables = set() + for inline_name, inline in single_synapse_info["Inlines"].items(): + inline_variables = cls.get_variable_names_used(inline["inline_expression"]) analytic_solution_vars = set() # get all variables from transformed analytic solution @@ -217,57 +413,109 @@ def get_all_synapse_variables(cls, single_synapse_info): return analytic_solution_vars.union(inline_variables) @classmethod - def get_new_variables_after_transformation(cls, single_synapse_info): - return cls.get_all_synapse_variables(single_synapse_info).difference( - single_synapse_info["total_used_declared"]) 
+ def get_variable_names_used(cls, node) -> set: + variable_names_extractor = ASTUsedVariableNamesExtractor(node) + return variable_names_extractor.variable_names - @classmethod - def get_analytic_helper_variable_names(cls, single_synapse_info): - """get new variables that only occur on the right hand side of analytic solution Expressions - but for wich analytic solution does not offer any values - this can isolate out additional variables that suddenly appear such as __h - whose initial values are not inlcuded in the output of analytic solver""" - analytic_lhs_vars = set() +class ASTEnricherInfoCollectorVisitor(ASTVisitor): - for convolution_name, convolution_info in single_synapse_info["convolutions"].items( - ): - analytic_sol = convolution_info["analytic_solution"] + def __init__(self): + super(ASTEnricherInfoCollectorVisitor, self).__init__() + self.inside_variable = False + self.inside_block_with_variables = False + self.all_states = list() + self.all_parameters = list() + self.inside_states_block = False + self.inside_parameters_block = False + self.all_variables = list() + self.inside_internals_block = False + self.inside_declaration = False + self.internal_declarations = list() - # get variables representing convolutions by kernel - for kernel_var_name, kernel_info in analytic_sol["kernel_states"].items( - ): - analytic_lhs_vars.add(kernel_var_name) + def visit_block_with_variables(self, node): + self.inside_block_with_variables = True + if node.is_state: + self.inside_states_block = True + if node.is_parameters: + self.inside_parameters_block = True + if node.is_internals: + self.inside_internals_block = True - # get propagator variable names - for propagator_var_name, propagator_info in analytic_sol["propagators"].items( - ): - analytic_lhs_vars.add(propagator_var_name) + def endvisit_block_with_variables(self, node): + self.inside_states_block = False + self.inside_parameters_block = False + self.inside_block_with_variables = False + self.inside_internals_block = False - return cls.get_new_variables_after_transformation( - single_synapse_info).symmetric_difference(analytic_lhs_vars) + def visit_variable(self, node): + self.inside_variable = True + self.all_variables.append(node.clone()) + if self.inside_states_block: + self.all_states.append(node.clone()) + if self.inside_parameters_block: + self.all_parameters.append(node.clone()) - @classmethod - def get_analytic_helper_variable_declarations(cls, single_synapse_info): - variable_names = cls.get_analytic_helper_variable_names( - single_synapse_info) - result = dict() - for variable_name in variable_names: - if variable_name not in SynsInfoEnricherVisitor.internal_variable_name_to_variable: - continue - variable = SynsInfoEnricherVisitor.internal_variable_name_to_variable[variable_name] - expression = SynsInfoEnricherVisitor.variables_to_internal_declarations[variable] - result[variable_name] = { - "ASTVariable": variable, - "init_expression": expression, - } - if expression.is_function_call() and expression.get_function_call( - ).callee_name == PredefinedFunctions.TIME_RESOLUTION: - result[variable_name]["is_time_resolution"] = True + def endvisit_variable(self, node): + self.inside_variable = False + + def visit_declaration(self, node): + self.inside_declaration = True + if self.inside_internals_block: + self.internal_declarations.append(node) + + def endvisit_declaration(self, node): + self.inside_declaration = False + + +class ASTDeclarationCollectorAndUniqueRenamerVisitor(ASTVisitor): + def __init__(self): + 
super(ASTDeclarationCollectorAndUniqueRenamerVisitor, self).__init__() + self.declarations = list() + self.variable_names = dict() + self.inside_declaration = False + self.inside_block = False + self.current_block = None + + def visit_block(self, node): + self.inside_block = True + self.current_block = node + + def endvisit_block(self, node): + self.inside_block = False + self.current_block = None + + def visit_declaration(self, node): + self.inside_declaration = True + for variable in node.get_variables(): + if variable.get_name() in self.variable_names: + self.variable_names[variable.get_name()] += 1 else: - result[variable_name]["is_time_resolution"] = False + self.variable_names[variable.get_name()] = 0 + new_name = variable.get_name() + '_' + str(self.variable_names[variable.get_name()]) + name_replacer = ASTVariableNameReplacerVisitor(variable.get_name(), new_name) + self.current_block.accept(name_replacer) + node.accept(ASTSymbolTableVisitor()) + self.declarations.append(node.clone()) - return result + def endvisit_declaration(self, node): + self.inside_declaration = False + + +class ASTVariableNameReplacerVisitor(ASTVisitor): + def __init__(self, old_name, new_name): + super(ASTVariableNameReplacerVisitor, self).__init__() + self.inside_variable = False + self.new_name = new_name + self.old_name = old_name + + def visit_variable(self, node): + self.inside_variable = True + if node.get_name() == self.old_name: + node.set_name(self.new_name) + + def endvisit_variable(self, node): + self.inside_variable = False class SynsInfoEnricherVisitor(ASTVisitor): diff --git a/setup.py b/setup.py index 8925d7b80..2802baa69 100755 --- a/setup.py +++ b/setup.py @@ -56,6 +56,7 @@ "codegeneration/resources_nest/point_neuron/setup/common/*.jinja2", "codegeneration/resources_nest_compartmental/cm_neuron/*.jinja2", "codegeneration/resources_nest_compartmental/cm_neuron/directives_cpp/*.jinja2", + "codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/*.jinja2", "codegeneration/resources_nest_compartmental/cm_neuron/setup/*.jinja2", "codegeneration/resources_nest_compartmental/cm_neuron/setup/common/*.jinja2", "codegeneration/resources_python_standalone/point_neuron/*.jinja2", diff --git a/tests/nest_compartmental_tests/resources/cm_default.nestml b/tests/nest_compartmental_tests/resources/cm_default.nestml index 4736bc227..aac92001b 100644 --- a/tests/nest_compartmental_tests/resources/cm_default.nestml +++ b/tests/nest_compartmental_tests/resources/cm_default.nestml @@ -1,39 +1,35 @@ # Example compartmental model for NESTML -# +# # Description # +++++++++++ # Corresponds to standard compartmental model implemented in NEST. -# -# +# +# # Copyright statement # +++++++++++++++++++ -# +# # This file is part of NEST. -# +# # Copyright (C) 2004 The NEST Initiative -# +# # NEST is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 2 of the License, or # (at your option) any later version. -# +# # NEST is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with NEST. If not, see . 
-
+#
 #
 model cm_default:

     state:
-        # compartmental voltage variable,
-        # rhs value is irrelevant but the state must exist so that the nestml parser doesn't complain
-        v_comp real = 0
-
         ### ion channels ###
         # initial values state variables sodium channel
         m_Na real = 0.01696863
@@ -42,6 +38,10 @@ model cm_default:
         # initial values state variables potassium channel
         n_K real = 0.00014943

+        # compartmental voltage variable,
+        # rhs value is irrelevant but the state must exist so that the nestml parser doesn't complain
+        v_comp real = 0
+
     parameters:

         ### ion channels ###
diff --git a/tests/nest_compartmental_tests/resources/cm_iaf_psc_exp_dend_neuron.nestml b/tests/nest_compartmental_tests/resources/cm_iaf_psc_exp_dend_neuron.nestml
new file mode 100644
index 000000000..3248a8f82
--- /dev/null
+++ b/tests/nest_compartmental_tests/resources/cm_iaf_psc_exp_dend_neuron.nestml
@@ -0,0 +1,89 @@
+# iaf_psc_exp_dend - Leaky integrate-and-fire neuron model with exponential PSCs
+# #########################################################################
+#
+# Description
+# +++++++++++
+#
+# iaf_psc_exp_cm_dend is an implementation of a leaky integrate-and-fire model
+# with exponential-kernel postsynaptic currents (PSCs) according to [1]_.
+# Thus, postsynaptic currents have an infinitely short rise time.
+#
+# The threshold crossing is followed by an absolute refractory period (refr_T)
+# during which the membrane potential is clamped to the resting potential
+# and spiking is prohibited.
+#
+# .. note::
+#    If tau_m is very close to tau_syn_ex or tau_syn_in, numerical problems
+#    may arise due to singularities in the propagator matrices. If this is
+#    the case, replace equal-valued parameters by a single parameter.
+#
+# For details, please see ``IAF_neurons_singularity.ipynb`` in
+# the NEST source code (``docs/model_details``).
+#
+#
+# References
+# ++++++++++
+#
+# .. [1] Tsodyks M, Uziel A, Markram H (2000). Synchrony generation in recurrent
+#        networks with frequency-dependent synapses. The Journal of Neuroscience,
+#        20,RC50:1-5. URL: https://infoscience.epfl.ch/record/183402
+#
+#
+# See also
+# ++++++++
+#
+# iaf_cond_exp
+#
+#
model iaf_psc_exp_cm_dend:
+
+    state:
+        v_comp real = 0      # Membrane potential
+        refr_t ms = 0 ms     # Refractory period timer
+
+        is_refr real = 0.0
+
+
+    equations:
+        kernel I_kernel_inh = exp(-t/tau_syn_inh)
+        kernel I_kernel_exc = exp(-t/tau_syn_exc)
+
+        inline leak real = (E_l - v_comp) * C_m / tau_m @mechanism::channel
+        inline syn_exc real = convolve(I_kernel_exc, exc_spikes) @mechanism::receptor
+        inline syn_inh real = convolve(I_kernel_inh, inh_spikes) @mechanism::receptor
+        inline refr real = G_refr * is_refr * (V_reset - v_comp) @mechanism::channel
+
+    parameters:
+        C_m pF = 250 pF          # Capacity of the membrane
+        tau_m ms = 10 ms         # Membrane time constant
+        tau_syn_inh ms = 2 ms    # Time constant of inhibitory synaptic current
+        tau_syn_exc ms = 2 ms    # Time constant of excitatory synaptic current
+        refr_T ms = 5 ms         # Duration of refractory period
+        E_l mV = -70 mV          # Resting potential
+        V_reset mV = -70 mV      # Reset potential of the membrane
+        V_th mV = -55 mV         # Spike threshold potential
+
+        # constant external input current
+        I_e pA = 0 pA
+
+        G_refr real = 0.
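+        # NB: G_refr is the gain of the "refr" clamp channel declared above; while
+        # is_refr == 1, the clamp pulls v_comp towards V_reset (with the default
+        # value of 0. the clamp is inactive)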
+ + input: + exc_spikes <- excitatory spike + inh_spikes <- inhibitory spike + I_stim pA <- continuous + + output: + spike + + update: + if refr_t > resolution() / 2: + # neuron is absolute refractory, do not evolve V_m + refr_t -= resolution() + else: + refr_t = 0 ms + is_refr = 0 + + onReceive(self_spikes): + is_refr = 1 + refr_t = refr_T diff --git a/tests/nest_compartmental_tests/resources/concmech.nestml b/tests/nest_compartmental_tests/resources/concmech.nestml index d9944188e..b1f1acbda 100644 --- a/tests/nest_compartmental_tests/resources/concmech.nestml +++ b/tests/nest_compartmental_tests/resources/concmech.nestml @@ -1,23 +1,23 @@ # Copyright statement # +++++++++++++++++++ -# +# # This file is part of NEST. -# +# # Copyright (C) 2004 The NEST Initiative -# +# # NEST is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 2 of the License, or # (at your option) any later version. -# +# # NEST is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with NEST. If not, see . - +# # model multichannel_test_model: parameters: diff --git a/tests/nest_compartmental_tests/resources/continuous_test.nestml b/tests/nest_compartmental_tests/resources/continuous_test.nestml index 30d397422..03c6bb1f3 100644 --- a/tests/nest_compartmental_tests/resources/continuous_test.nestml +++ b/tests/nest_compartmental_tests/resources/continuous_test.nestml @@ -1,23 +1,23 @@ # Copyright statement # +++++++++++++++++++ -# +# # This file is part of NEST. -# +# # Copyright (C) 2004 The NEST Initiative -# +# # NEST is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 2 of the License, or # (at your option) any later version. -# +# # NEST is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with NEST. If not, see . - +# # model continuous_test_model: state: diff --git a/tests/nest_compartmental_tests/resources/invalid/CoCoCmMechSharedCode.nestml b/tests/nest_compartmental_tests/resources/invalid/CoCoCmMechSharedCode.nestml new file mode 100644 index 000000000..a5cf1dcfc --- /dev/null +++ b/tests/nest_compartmental_tests/resources/invalid/CoCoCmMechSharedCode.nestml @@ -0,0 +1,60 @@ +# CoCoCmMechSharedCode.nestml +# ########################### +# +# +# Description +# +++++++++++ +# +# This model is used to test whether incorrectly shared variables +# between mechanisms lead to an appropriate error. +# +# Negative case. +# +# +# Copyright statement +# +++++++++++++++++++ +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. 
+# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . +# +# +model CoCoCmSharedCode: + state: + v_comp real = 0. + + shared_state real = 0. + exclusive_state1 real = 0. + exclusive_state2 real = 0. + + shared_and_global_state real = 0. + + parameters: + shared_param real = 0. + exclusive_param1 real = 0. + exclusive_param2 real = 0. + + internals: + shared_int real = 0. + exclusive_int1 real = 0. + exclusive_int2 real = 0. + + equations: + inline channel1 real = (shared_state + shared_param + shared_int) * (exclusive_state1 + exclusive_param1 + exclusive_int1) * shared_and_global_state @mechanism::channel + inline channel2 real = (shared_state + shared_param + shared_int) * (exclusive_state2 + exclusive_param2 + exclusive_int2) * shared_and_global_state @mechanism::channel + + update: + shared_and_global_state = shared_and_global_state + 1. \ No newline at end of file diff --git a/tests/nest_compartmental_tests/resources/invalid/CoCoCmVariableHasRhs.nestml b/tests/nest_compartmental_tests/resources/invalid/CoCoCmVariableHasRhs.nestml index d2a1e0470..ad6a10df7 100644 --- a/tests/nest_compartmental_tests/resources/invalid/CoCoCmVariableHasRhs.nestml +++ b/tests/nest_compartmental_tests/resources/invalid/CoCoCmVariableHasRhs.nestml @@ -30,7 +30,7 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . - +# # model cm_model_four_invalid: diff --git a/tests/nest_compartmental_tests/resources/invalid/CoCoCmVariableMultiUse.nestml b/tests/nest_compartmental_tests/resources/invalid/CoCoCmVariableMultiUse.nestml index 09e3687f4..9f4ff0747 100644 --- a/tests/nest_compartmental_tests/resources/invalid/CoCoCmVariableMultiUse.nestml +++ b/tests/nest_compartmental_tests/resources/invalid/CoCoCmVariableMultiUse.nestml @@ -30,7 +30,7 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . - +# # model cm_model_five_invalid: diff --git a/tests/nest_compartmental_tests/resources/invalid/CoCoCmVariablesDeclared.nestml b/tests/nest_compartmental_tests/resources/invalid/CoCoCmVariablesDeclared.nestml index e67dec20a..02fab5b90 100644 --- a/tests/nest_compartmental_tests/resources/invalid/CoCoCmVariablesDeclared.nestml +++ b/tests/nest_compartmental_tests/resources/invalid/CoCoCmVariablesDeclared.nestml @@ -30,7 +30,7 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . - +# # model cm_model_seven_invalid: diff --git a/tests/nest_compartmental_tests/resources/invalid/CoCoCmVcompExists.nestml b/tests/nest_compartmental_tests/resources/invalid/CoCoCmVcompExists.nestml index 6eba3dd2b..347e83126 100644 --- a/tests/nest_compartmental_tests/resources/invalid/CoCoCmVcompExists.nestml +++ b/tests/nest_compartmental_tests/resources/invalid/CoCoCmVcompExists.nestml @@ -30,14 +30,15 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . 
- +# # model cm_model_eight_invalid: state: # compartmental voltage variable, # rhs value is irrelevant but the state must exist so that the nestml parser doesn't complain - m_Na real = 0.0 + m_Na real = 0.0 + h_Na real = 0.0 #sodium function m_inf_Na(v_comp real) real: @@ -53,7 +54,8 @@ model cm_model_eight_invalid: return 0.3115264797507788/((-0.0091000000000000004*v_comp - 0.68261830000000012)/(1.0 - 3277527.8765015295*exp(0.20000000000000001*v_comp)) + (0.024*v_comp + 1.200312)/(1.0 - 4.5282043263959816e-5*exp(-0.20000000000000001*v_comp))) equations: - inline Na real = m_Na**3 * h_Na**1 + inline Na real = gbar_Na * m_Na**3 * h_Na * (e_Na - v_comp) @mechanism::channel parameters: - foo real = 1. + gbar_Na real = 0. + e_Na real = 50. diff --git a/tests/nest_compartmental_tests/resources/recordable_inline_test.nestml b/tests/nest_compartmental_tests/resources/recordable_inline_test.nestml new file mode 100644 index 000000000..54fcd4fab --- /dev/null +++ b/tests/nest_compartmental_tests/resources/recordable_inline_test.nestml @@ -0,0 +1,78 @@ +# recordable inlines test model +# +# Description +# +++++++++++ +# Corresponds to standard compartmental model implemented in NEST. +# +# +# Copyright statement +# +++++++++++++++++++ +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . 
+# +# +model cm_default: + + state: + m_Na real = 0.01696863 + h_Na real = 0.83381407 + + v_comp real = 0 + + parameters: + e_Na real = 50.0 + gbar_Na real = 0.0 + + e_AMPA real = 0 mV + tau_r_AMPA real = 0.2 ms + tau_d_AMPA real = 3.0 ms + + equations: + h_Na' = (h_inf_Na(v_comp) - h_Na) / (tau_h_Na(v_comp) * 1 s) + m_Na' = (m_inf_Na(v_comp) - m_Na) / (tau_m_Na(v_comp) * 1 s) + + recordable inline open_probability real = m_Na**3 * h_Na + recordable inline equilibrium_distance real = (e_Na - v_comp) + + inline Na real = gbar_Na * open_probability * equilibrium_distance @mechanism::channel + + kernel g_AMPA = g_norm_AMPA * ( - exp(-t / tau_r_AMPA) + exp(-t / tau_d_AMPA) ) + inline AMPA real = convolve(g_AMPA, spikes_AMPA) * (e_AMPA - v_comp) @mechanism::receptor + + function m_inf_Na (v_comp real) real: + return (1.0 - 0.020438532058318*exp(-0.111111111111111*v_comp))**(-1)*((1.0 - 0.020438532058318*exp(-0.111111111111111*v_comp))**(-1)*(6.372366 + 0.182*v_comp) + (1.0 - 48.9271928701465*exp(0.111111111111111*v_comp))**(-1)*(-4.341612 - 0.124*v_comp))**(-1)*(6.372366 + 0.182*v_comp) + + function tau_m_Na (v_comp real) real: + return 0.311526479750779*((1.0 - 0.020438532058318*exp(-0.111111111111111*v_comp))**(-1)*(6.372366 + 0.182*v_comp) + (1.0 - 48.9271928701465*exp(0.111111111111111*v_comp))**(-1)*(-4.341612 - 0.124*v_comp))**(-1) + + function h_inf_Na (v_comp real) real: + return 1.0*(1.0 + 35734.4671267926*exp(0.161290322580645*v_comp))**(-1) + + function tau_h_Na (v_comp real) real: + return 0.311526479750779*((1.0 - 4.52820432639598e-5*exp(-0.2*v_comp))**(-1)*(1.200312 + 0.024*v_comp) + (1.0 - 3277527.87650153*exp(0.2*v_comp))**(-1)*(-0.6826183 - 0.0091*v_comp))**(-1) + + + internals: + tp_AMPA real = (tau_r_AMPA * tau_d_AMPA) / (tau_d_AMPA - tau_r_AMPA) * ln( tau_d_AMPA / tau_r_AMPA ) + g_norm_AMPA real = 1. / ( -exp( -tp_AMPA / tau_r_AMPA ) + exp( -tp_AMPA / tau_d_AMPA ) ) + + input: + spikes_AMPA <- spike + + output: + spike diff --git a/tests/nest_compartmental_tests/resources/self_spike_convolutions.nestml b/tests/nest_compartmental_tests/resources/self_spike_convolutions.nestml new file mode 100644 index 000000000..23bfac6a8 --- /dev/null +++ b/tests/nest_compartmental_tests/resources/self_spike_convolutions.nestml @@ -0,0 +1,78 @@ +# self_spike_convolutions.nestml +# ########################### +# +# +# Description +# +++++++++++ +# +# This model is used to test convolutions with the +# self_spikes variable. +# +# +# Copyright statement +# +++++++++++++++++++ +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . +# +# +model self_spikes_convolutions: + state: + v_comp real = 0.0 + + concentration real = 0.0 + + parameters: + w_bap real = 0.0 + e_bap real = 0.0 + + g_norm_bap real = 1.0 + tau_r_bap real = .2 + tau_d_bap real = 3. 
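+        # rise and decay time constants of the double-exponential BAP kernel in the
+        # equations block below (declared real here; effectively interpreted in ms)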
+
+    equations:
+        # double exponential conductance profile to implement the backpropagating action potential
+        kernel g_bap = g_norm_bap * ( - exp(-t / tau_r_bap) + exp(-t / tau_d_bap) )
+        inline bap real = w_bap * convolve(g_bap, self_spikes) * (e_bap - v_comp) @mechanism::channel
+
+        kernel kern = exp(-t / 10.0)
+        inline secondary real = convolve(kern, self_spikes)
+
+        inline chan_primary real = convolve(kern, self_spikes) @mechanism::channel
+
+        inline chan_secondary real = secondary @mechanism::channel
+
+        inline rec_primary real = convolve(kern, in_spikes1) + convolve(kern, self_spikes) @mechanism::receptor
+
+        inline rec_secondary real = convolve(kern, in_spikes2) + secondary @mechanism::receptor
+
+        inline con_in_primary real = in_continuous1*convolve(kern, self_spikes) @mechanism::continuous_input
+
+        inline con_in_secondary real = in_continuous2*secondary @mechanism::continuous_input
+
+        concentration' = (secondary-(concentration/2)) / 10 ms @mechanism::concentration
+
+    input:
+        self_spikes <- spike
+
+        in_spikes1 <- spike
+        in_spikes2 <- spike
+        in_continuous1 real <- continuous
+        in_continuous2 real <- continuous
+
+    output:
+        spike
\ No newline at end of file
diff --git a/tests/nest_compartmental_tests/resources/stdp_synapse.nestml b/tests/nest_compartmental_tests/resources/stdp_synapse.nestml
new file mode 100644
index 000000000..ee52f73c7
--- /dev/null
+++ b/tests/nest_compartmental_tests/resources/stdp_synapse.nestml
@@ -0,0 +1,83 @@
+# stdp - Synapse model for spike-timing dependent plasticity
+# #########################################################
+#
+# Description
+# +++++++++++
+#
+# stdp_synapse is a synapse with spike-timing dependent plasticity (as defined in [1]_). Here the weight dependence exponent can be set separately for potentiation and depression.
+# This compartmental modification removes the call to integrate_odes() since, in the compartmental case, integration is
+# done implicitly at every timestep because the synapse is merged with the postsynaptic receptor.
+#
+# Examples:
+#
+# =================== ==== =============================
+# Multiplicative STDP [2]_ mu_plus = mu_minus = 1
+# Additive STDP       [3]_ mu_plus = mu_minus = 0
+# Guetig STDP         [1]_ mu_plus, mu_minus in [0, 1]
+# Van Rossum STDP     [4]_ mu_plus = 0 mu_minus = 1
+# =================== ==== =============================
+#
+#
+# References
+# ++++++++++
+#
+# .. [1] Guetig et al. (2003) Learning Input Correlations through Nonlinear
+#        Temporally Asymmetric Hebbian Plasticity. Journal of Neuroscience
+#
+# .. [2] Rubin, J., Lee, D. and Sompolinsky, H. (2001). Equilibrium
+#        properties of temporally asymmetric Hebbian plasticity, PRL
+#        86,364-367
+#
+# .. [3] Song, S., Miller, K. D. and Abbott, L. F. (2000). Competitive
+#        Hebbian learning through spike-timing-dependent synaptic
+#        plasticity, Nature Neuroscience 3:9, 919--926
+#
+# .. [4] van Rossum, M. C. W., Bi, G-Q and Turrigiano, G. G. (2000).
+#        Stable Hebbian learning from spike timing-dependent
+#        plasticity, Journal of Neuroscience, 20:23, 8812--8821
+#
+#
+model stdp_synapse:
+    state:
+        w real = 1.    # Synaptic weight
+        pre_trace real = 0.
+        post_trace real = 0.
+
+    parameters:
+        d ms = 1 ms    # Synaptic transmission delay
+        lambda real = .01
+        tau_tr_pre ms = 20 ms
+        tau_tr_post ms = 20 ms
+        alpha real = 1
+        mu_plus real = 1
+        mu_minus real = 1
+        Wmax real = 100.
+        Wmin real = 0.
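+        # Wmin/Wmax bound the weight updates in the onReceive blocks below; the
+        # exponents mu_plus/mu_minus select the STDP variant (see header table)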
+
+    equations:
+        pre_trace' = -pre_trace / tau_tr_pre
+        post_trace' = -post_trace / tau_tr_post
+
+    input:
+        pre_spikes <- spike
+        post_spikes <- spike
+
+    output:
+        spike(weight real, delay ms)
+
+    onReceive(post_spikes):
+        post_trace += 1
+
+        # potentiate synapse
+        w_ real = Wmax * ( w / Wmax + (lambda * ( 1. - ( w / Wmax ) )**mu_plus * pre_trace ))
+        w = min(Wmax, w_)
+
+    onReceive(pre_spikes):
+        pre_trace += 1
+
+        # depress synapse
+        w_ real = Wmax * ( w / Wmax - ( alpha * lambda * ( w / Wmax )**mu_minus * post_trace ))
+        w = max(Wmin, w_)
+
+        # deliver spike to postsynaptic partner
+        emit_spike(w, d)
diff --git a/tests/nest_compartmental_tests/resources/third_factor_stdp_synapse.nestml b/tests/nest_compartmental_tests/resources/third_factor_stdp_synapse.nestml
new file mode 100644
index 000000000..c7518d075
--- /dev/null
+++ b/tests/nest_compartmental_tests/resources/third_factor_stdp_synapse.nestml
@@ -0,0 +1,93 @@
+# third_factor_stdp_synapse - Synapse model for spike-timing-dependent plasticity with postsynaptic third-factor modulation
+# #########################################################################################################################
+#
+# Description
+# +++++++++++
+#
+# third_factor_stdp_synapse is a synapse with spike-timing-dependent plasticity (as defined in [1]). Here the weight dependence exponent can be set separately for potentiation and depression. Examples::
+#
+#     Multiplicative STDP [2] mu_plus = mu_minus = 1
+#     Additive STDP       [3] mu_plus = mu_minus = 0
+#     Guetig STDP         [1] mu_plus, mu_minus in [0, 1]
+#     Van Rossum STDP     [4] mu_plus = 0 mu_minus = 1
+#
+# The weight changes are modulated by a "third factor", in this case the postsynaptic dendritic current ``I_post_dend``.
+#
+# ``I_post_dend`` "gates" the weight update, so that if the current is 0, the weight is constant, whereas for a current of 1 pA, the weight change is maximal.
+#
+# Do not use values of ``I_post_dend`` larger than 1 pA!
+#
+# References
+# ++++++++++
+#
+# [1] Guetig et al. (2003) Learning Input Correlations through Nonlinear
+#     Temporally Asymmetric Hebbian Plasticity. Journal of Neuroscience
+#
+# [2] Rubin, J., Lee, D. and Sompolinsky, H. (2001). Equilibrium
+#     properties of temporally asymmetric Hebbian plasticity, PRL
+#     86, 364--367
+#
+# [3] Song, S., Miller, K. D. and Abbott, L. F. (2000). Competitive
+#     Hebbian learning through spike-timing-dependent synaptic
+#     plasticity, Nature Neuroscience 3:9, 919--926
+#
+# [4] van Rossum, M. C. W., Bi, G-Q and Turrigiano, G. G. (2000).
+#     Stable Hebbian learning from spike timing-dependent
+#     plasticity, Journal of Neuroscience, 20:23, 8812--8821
+#
+#
+model third_factor_stdp_synapse:
+    state:
+        w real = 1.    # Synaptic weight
+        I_post_dend pA = 0 pA
+        AMPA pA = 0 pA
+        Ca_HVA pA = 0 pA
+        Ca_LVAst pA = 0 pA
+        NaTa_t pA = 0 pA
+        SK_E2 pA = 0 pA
+
+    parameters:
+        d ms = 1 ms    # Synaptic transmission delay
+        lambda real = .01
+        tau_tr_pre ms = 20 ms
+        tau_tr_post ms = 20 ms
+        alpha real = 1.
+        mu_plus real = 1.
+        mu_minus real = 1.
+        Wmax real = 100.
+        Wmin real = 0.
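+
+        # unlike stdp_synapse above, the pre/post traces are expressed below
+        # as convolutions over the spike buffers rather than as explicitly
+        # integrated state variables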
+ + equations: + kernel pre_trace_kernel = exp(-t / tau_tr_pre) + inline pre_trace real = convolve(pre_trace_kernel, pre_spikes) + + # all-to-all trace of postsynaptic neuron + kernel post_trace_kernel = exp(-t / tau_tr_post) + inline post_trace real = convolve(post_trace_kernel, post_spikes) + + input: + pre_spikes <- spike + post_spikes <- spike + + output: + spike(weight real, delay ms) + + onReceive(post_spikes): + # potentiate synapse + w_ real = Wmax * ( w / Wmax + (lambda * ( 1. - ( w / Wmax ) )**mu_plus * pre_trace )) + if I_post_dend <= 1 pA: + w_ = (I_post_dend / pA) * w_ + (1 - I_post_dend / pA) * w # "gating" of the weight update + w = min(Wmax, w_) + + onReceive(pre_spikes): + # depress synapse + w_ real = Wmax * ( w / Wmax - ( alpha * lambda * ( w / Wmax )**mu_minus * post_trace )) + if I_post_dend <= 1 pA: + w_ = (I_post_dend / pA) * w_ + (1 - I_post_dend / pA) * w # "gating" of the weight update + w = max(Wmin, w_) + + # deliver spike to postsynaptic partner + emit_spike(w, d) + + update: + I_post_dend = AMPA + Ca_HVA + Ca_LVAst + NaTa_t + SK_E2 diff --git a/tests/nest_compartmental_tests/resources/valid/CoCoCmVariableHasRhs.nestml b/tests/nest_compartmental_tests/resources/valid/CoCoCmVariableHasRhs.nestml index 0fa4f9af9..5477a12b8 100644 --- a/tests/nest_compartmental_tests/resources/valid/CoCoCmVariableHasRhs.nestml +++ b/tests/nest_compartmental_tests/resources/valid/CoCoCmVariableHasRhs.nestml @@ -30,7 +30,7 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . - +# # model cm_model_four: diff --git a/tests/nest_compartmental_tests/resources/valid/CoCoCmVariableMultiUse.nestml b/tests/nest_compartmental_tests/resources/valid/CoCoCmVariableMultiUse.nestml index 58c6908bd..5f8b708a4 100644 --- a/tests/nest_compartmental_tests/resources/valid/CoCoCmVariableMultiUse.nestml +++ b/tests/nest_compartmental_tests/resources/valid/CoCoCmVariableMultiUse.nestml @@ -30,7 +30,7 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . - +# # model cm_model_five: diff --git a/tests/nest_compartmental_tests/resources/valid/CoCoCmVariablesDeclared.nestml b/tests/nest_compartmental_tests/resources/valid/CoCoCmVariablesDeclared.nestml index b38c1a1bc..f790da42a 100644 --- a/tests/nest_compartmental_tests/resources/valid/CoCoCmVariablesDeclared.nestml +++ b/tests/nest_compartmental_tests/resources/valid/CoCoCmVariablesDeclared.nestml @@ -30,7 +30,7 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . - +# # model cm_model_seven: diff --git a/tests/nest_compartmental_tests/resources/valid/CoCoCmVcompExists.nestml b/tests/nest_compartmental_tests/resources/valid/CoCoCmVcompExists.nestml index 633be6ea1..225f154b1 100644 --- a/tests/nest_compartmental_tests/resources/valid/CoCoCmVcompExists.nestml +++ b/tests/nest_compartmental_tests/resources/valid/CoCoCmVcompExists.nestml @@ -30,7 +30,7 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . - +# # model cm_model_eight_valid: diff --git a/tests/nest_compartmental_tests/test__cm_iaf_psc_exp_dend_neuron.py b/tests/nest_compartmental_tests/test__cm_iaf_psc_exp_dend_neuron.py new file mode 100644 index 000000000..4ea3a91f0 --- /dev/null +++ b/tests/nest_compartmental_tests/test__cm_iaf_psc_exp_dend_neuron.py @@ -0,0 +1,124 @@ +# -*- coding: utf-8 -*- +# +# test__cm_iaf_psc_exp_dend_neuron.py +# +# This file is part of NEST. 
+#
+# Copyright (C) 2004 The NEST Initiative
+#
+# NEST is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 2 of the License, or
+# (at your option) any later version.
+#
+# NEST is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with NEST. If not, see <http://www.gnu.org/licenses/>.
+
+import os
+
+import pytest
+
+import nest
+
+from pynestml.frontend.pynestml_frontend import generate_nest_compartmental_target
+
+# set to `True` to plot simulation traces
+TEST_PLOTS = True
+try:
+    import matplotlib
+    import matplotlib.pyplot as plt
+except BaseException:
+    # always set TEST_PLOTS to False if matplotlib can not be imported
+    TEST_PLOTS = False
+
+
+class TestCompartmentalIAF:
+    @pytest.fixture(scope="module", autouse=True)
+    def setup(self):
+        tests_path = os.path.realpath(os.path.dirname(__file__))
+        input_path = os.path.join(
+            tests_path,
+            "resources/cm_iaf_psc_exp_dend_neuron.nestml"
+        )
+        target_path = os.path.join(
+            tests_path,
+            "target/"
+        )
+
+        if not os.path.exists(target_path):
+            os.makedirs(target_path)
+
+        print(f"Compiling and installing the NESTML model into: {target_path}")
+
+        nest.ResetKernel()
+        nest.SetKernelStatus(dict(resolution=.1))
+
+        if True:
+            generate_nest_compartmental_target(
+                input_path=input_path,
+                target_path=target_path,
+                module_name="iaf_psc_exp_dend_neuron_compartmental_module",
+                suffix="_nestml",
+                logging_level="DEBUG"
+            )
+
+        nest.Install("iaf_psc_exp_dend_neuron_compartmental_module.so")
+
+    def test_iaf(self):
+        """Stimulate the compartmental iaf_psc_exp neuron with a short spike train
+        and compare the emitted spike times against previously recorded reference
+        values."""
+        cm = nest.Create('iaf_psc_exp_cm_dend_nestml')
+
+        params = {"G_refr": 1000.}
+
+        cm.compartments = [
+            {"parent_idx": -1, "params": params}
+        ]
+
+        cm.receptors = [
+            {"comp_idx": 0, "receptor_type": "syn_exc"}
+        ]
+
+        sg1 = nest.Create('spike_generator', 1, {'spike_times': [1., 2., 3., 4.]})
+
+        nest.Connect(sg1, cm, syn_spec={'synapse_model': 'static_synapse', 'weight': 2000.0, 'delay': 0.5, 'receptor_type': 0})
+
+        mm = nest.Create('multimeter', 1, {'record_from': ['v_comp0', 'leak0', 'refr0'], 'interval': .1})
+
+        nest.Connect(mm, cm)
+
+        sr = nest.Create('spike_recorder')
+
+        nest.Connect(cm, sr)
+
+        nest.Simulate(10.)
+
+        res = nest.GetStatus(mm, 'events')[0]
+
+        fig, axs = plt.subplots(3)
+
+        axs[0].plot(res['times'], res['v_comp0'], c='r', label='V_m_0')
+        axs[1].plot(res['times'], res['leak0'], c='y', label='leak0')
+        axs[2].plot(res['times'], res['refr0'], c='b', label='refr0')
+
+        axs[0].set_title('V_m_0')
+        axs[1].set_title('leak0')
+        axs[2].set_title('refr0')
+
+        axs[0].legend()
+        axs[1].legend()
+        axs[2].legend()
+
+        plt.savefig("cm_iaf_test.png")
+
+        events_dist = nest.GetStatus(sr)[0]['events']
+
+        assert list(events_dist["times"]) == [1.5, 6.7], "Spike times are not as expected!"
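+
+        # note (editorial comment): the reference spike times [1.5, 6.7] were
+        # obtained with this exact stimulus and a resolution of 0.1 ms;
+        # changing either requires re-recording the reference values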
diff --git a/tests/nest_compartmental_tests/test__cocos.py b/tests/nest_compartmental_tests/test__cocos.py index 7ee55f8a1..74e18e5e3 100644 --- a/tests/nest_compartmental_tests/test__cocos.py +++ b/tests/nest_compartmental_tests/test__cocos.py @@ -110,7 +110,7 @@ def test_invalid_cm_v_comp_exists(self): 'invalid')), 'CoCoCmVcompExists.nestml')) assert len(Logger.get_all_messages_of_level_and_or_node( - model, LoggingLevel.ERROR)) == 4 + model, LoggingLevel.ERROR)) == 6 def test_valid_cm_v_comp_exists(self): Logger.set_logging_level(LoggingLevel.INFO) @@ -130,7 +130,7 @@ def _parse_and_validate_model(self, fname: str) -> Optional[str]: Logger.init_logger(LoggingLevel.DEBUG) try: - generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG") + generate_target(input_path=fname, target_platform="NEST_COMPARTMENTAL", logging_level="DEBUG") except BaseException: return None @@ -142,3 +142,14 @@ def _parse_and_validate_model(self, fname: str) -> Optional[str]: model_name = model.get_name() return model_name + + def test_invalid_cm_mech_shared_code(self, setUp): + model = self._parse_and_validate_model( + os.path.join( + os.path.realpath( + os.path.join( + os.path.dirname(__file__), 'resources', + 'invalid')), + 'CoCoCmMechSharedCode.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node( + model, LoggingLevel.ERROR)) == 4 diff --git a/tests/nest_compartmental_tests/test__compartmental_model.py b/tests/nest_compartmental_tests/test__compartmental_model.py index 95e9ff9b7..31250334b 100644 --- a/tests/nest_compartmental_tests/test__compartmental_model.py +++ b/tests/nest_compartmental_tests/test__compartmental_model.py @@ -100,13 +100,14 @@ def install_nestml_model(self): f"Compiled nestml model 'cm_main_cm_default_nestml' not found, installing in:" f" {target_path}") - generate_nest_compartmental_target( - input_path=input_path, - target_path=target_path, - module_name="cm_defaultmodule", - suffix="_nestml", - logging_level="ERROR" - ) + if True: + generate_nest_compartmental_target( + input_path=input_path, + target_path=target_path, + module_name="cm_defaultmodule", + suffix="_nestml", + logging_level="ERROR" + ) def get_model(self, reinstall_flag=True): if self.nestml_flag: diff --git a/tests/nest_compartmental_tests/test__compartmental_stdp.py b/tests/nest_compartmental_tests/test__compartmental_stdp.py new file mode 100644 index 000000000..9108af051 --- /dev/null +++ b/tests/nest_compartmental_tests/test__compartmental_stdp.py @@ -0,0 +1,249 @@ +# -*- coding: utf-8 -*- +# +# test__compartmental_stdp.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . 
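+
+# This test compares the weight computed by the co-generated compartmental
+# STDP receptor ("nestml" case) against NEST's built-in stdp_synapse ("nest"
+# case) across a range of pre/post spike-time offsets.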
+
+import os
+import unittest
+
+import pytest
+
+import nest
+
+from pynestml.frontend.pynestml_frontend import generate_nest_compartmental_target
+
+# set to `True` to plot simulation traces
+TEST_PLOTS = True
+try:
+    import matplotlib
+    import matplotlib.pyplot as plt
+except BaseException:
+    # always set TEST_PLOTS to False if matplotlib can not be imported
+    TEST_PLOTS = False
+
+
+class TestCompartmentalSTDP(unittest.TestCase):
+    @pytest.fixture(scope="module", autouse=True)
+    def setup(self):
+        tests_path = os.path.realpath(os.path.dirname(__file__))
+        neuron_input_path = os.path.join(
+            tests_path,
+            "resources",
+            "concmech.nestml"
+        )
+        synapse_input_path = os.path.join(
+            tests_path,
+            "resources",
+            "stdp_synapse.nestml"
+        )
+        target_path = os.path.join(
+            tests_path,
+            "target/"
+        )
+
+        if not os.path.exists(target_path):
+            os.makedirs(target_path)
+
+        print(f"Compiling and installing the NESTML models into: {target_path}")
+
+        nest.ResetKernel()
+        nest.SetKernelStatus(dict(resolution=.1))
+        if True:
+            generate_nest_compartmental_target(
+                input_path=[neuron_input_path, synapse_input_path],
+                target_path=target_path,
+                module_name="cm_stdp_module",
+                suffix="_nestml",
+                logging_level="DEBUG",
+                codegen_opts={"neuron_synapse_pairs": [{"neuron": "multichannel_test_model",
+                                                        "synapse": "stdp_synapse",
+                                                        "post_ports": ["post_spikes"]}],
+                              "delay_variable": {"stdp_synapse": "d"},
+                              "weight_variable": {"stdp_synapse": "w"}
+                              }
+            )
+
+        nest.Install("cm_stdp_module.so")
+
+    def run_model(self, model_case, pre_spike, post_spike, sim_time):
+        """
+        Run a single pre/post spike pairing through either the NEST reference
+        synapse or the NESTML compartmental STDP receptor.
+
+        This function sets up a simulation environment using NEST Simulator with pre-defined spike times for the pre- and post-synaptic neurons. It creates the neuron models, assigns parameters, sets up connections, records data from the simulation and returns the recorded traces; plotting and comparison are done by the calling test.
+
+        Simulation Procedure:
+        1. Define pre- and post-synaptic spike timings.
+        2. Set up neuron models:
+           a. `spike_generator` to provide external spike input.
+           b. `parrot_neuron` for relaying spikes.
+           c. Custom `multichannel_test_model_nestml` neuron for the postsynaptic side, with compartments and receptor configurations specified.
+        3. Create a `multimeter` to record voltage, synaptic weight, currents, and traces.
+        4. Establish connections:
+           a. Connect the spike generator to the pre-neuron with a static synapse.
+           b. Connect the pre-neuron to the post-neuron using the configured STDP synapse.
+           c. Connect the recording devices to the respective neurons.
+        5. Simulate the network for the specified time duration.
+        6. Retrieve and return the recorded data.
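+
+        Parameters (summary): ``model_case`` selects the "nest" reference or
+        the "nestml" compartmental implementation; ``pre_spike`` and
+        ``post_spike`` are spike times in ms; ``sim_time`` is the total
+        simulation duration in ms.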
+ """ + nest.ResetKernel() + nest.SetKernelStatus(dict(resolution=.1)) + nest.Install("cm_stdp_module.so") + + measuring_spike = sim_time - 1 + #pre_spike_times = [pre_spike] + #if model_case == "nest": + if measuring_spike > pre_spike: + pre_spike_times = [pre_spike, measuring_spike] + else: + pre_spike_times = [measuring_spike, pre_spike] + post_spike_times = [post_spike] + + external_input_pre = nest.Create("spike_generator", params={"spike_times": pre_spike_times}) + external_input_post = nest.Create("spike_generator", params={"spike_times": post_spike_times}) + pre_neuron = nest.Create("parrot_neuron") + post_neuron = nest.Create('multichannel_test_model_nestml') + + params = {'C_m': 10.0, 'g_C': 0.0, 'g_L': 1.5, 'e_L': -70.0} + post_neuron.compartments = [ + {"parent_idx": -1, "params": params} + ] + + if model_case == "nestml": + post_neuron.receptors = [ + #{"comp_idx": 0, "receptor_type": "AMPA"}, + {"comp_idx": 0, "receptor_type": "AMPA_stdp_synapse_nestml", "params": {'w': 10.0, "d": 0.1, "tau_tr_pre": 40, "tau_tr_post": 40}} + ] + mm = nest.Create('multimeter', 1, { + 'record_from': ['v_comp0', 'w0', 'AMPA_stdp_synapse_nestml0', 'pre_trace0', #'AMPA0', + 'post_trace0'], 'interval': .1}) + elif model_case == "nest": + post_neuron.receptors = [ + #{"comp_idx": 0, "receptor_type": "AMPA"}, + {"comp_idx": 0, "receptor_type": "AMPA", "params": {}} + ] + mm = nest.Create('multimeter', 1, { + 'record_from': ['v_comp0', 'AMPA0'], 'interval': .1}) + + nest.Connect(external_input_pre, pre_neuron, "one_to_one", + syn_spec={'synapse_model': 'static_synapse', 'weight': 2.0, 'delay': 0.1}) + # nest.Connect(external_input_post, post_neuron, "one_to_one", + # syn_spec={'synapse_model': 'static_synapse', 'weight': 5.0, 'delay': 0.1, 'receptor_type': 0}) + if model_case == "nestml": + nest.Connect(pre_neuron, post_neuron, "one_to_one", + syn_spec={'synapse_model': 'static_synapse', 'weight': 1.0, 'delay': 0.1, 'receptor_type': 0}) + elif model_case == "nest": + wr = nest.Create("weight_recorder") + nest.CopyModel( + "stdp_synapse", + "stdp_synapse_rec", + {"weight_recorder": wr[0], "receptor_type": 0, 'weight': 1.0}, + ) + nest.Connect( + pre_neuron, + post_neuron, + "all_to_all", + syn_spec={ + "synapse_model": "stdp_synapse_rec", + "delay": 0.1, + "weight": 10.0, + "receptor_type": 0 + }, + ) + nest.Connect(mm, post_neuron) + + nest.Simulate(sim_time) + + res = nest.GetStatus(mm, 'events')[0] + recorded = dict() + if model_case == "nest": + recorded["weight"] = nest.GetStatus(wr, "events")[0]["weights"] + recorded["w_times"] = nest.GetStatus(wr, "events")[0]["times"] + #print(recorded["weight"]) + elif model_case == "nestml": + recorded["weight"] = res['w0'] + recorded["pre_trace"] = res['pre_trace0'] + recorded["post_trace"] = res['post_trace0'] + + recorded["times"] = res['times'] + recorded["v_comp"] = res['v_comp0'] + + return recorded + + def test__compartmental_stdp(self): + rec_nest_runs = list() + rec_nestml_runs = list() + + sim_time = 40 + resolution = 20 + sim_time = int(sim_time / resolution) * resolution + + sp_td = [] + for i in range(1, resolution): + pre_spike = i * sim_time / resolution + post_spike = sim_time / 2 + sp_td.append(pre_spike - post_spike) + rec_nest_runs.append(self.run_model("nest", pre_spike, post_spike, sim_time)) + rec_nestml_runs.append(self.run_model("nestml", pre_spike, post_spike, sim_time)) + + fig, axs = plt.subplots(2) + + for i in range(len(rec_nest_runs)): + if i == 0: + nest_l = "nest" + nestml_l = "nestml" + else: + nest_l = None + nestml_l = None 
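+
+            # one marker per run: the final weight plotted against the
+            # pre-post spike-time offset of that run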
+
+            rec_nest_raw = rec_nest_runs[i]
+            rec_nestml_raw = rec_nestml_runs[i]
+            axs[0].plot([sp_td[i]], [rec_nest_raw['weight'][-1]], c='grey', marker='o', label=nest_l, markersize=7)
+            axs[0].plot([sp_td[i]], [rec_nestml_raw['weight'][-1]], c='orange', marker='X', label=nestml_l, markersize=5)
+
+        nest_values = [rec_nest_runs[i]['weight'][-1] for i in range(len(rec_nest_runs))]
+        nestml_values = [rec_nestml_runs[i]['weight'][-1] for i in range(len(rec_nestml_runs))]
+        diff_values = [nestml_values[i] - nest_values[i] for i in range(len(rec_nest_runs))]
+
+        axs[1].vlines(sp_td, 0, diff_values, color='red', label='diff', linewidth=3)
+
+        axs[0].set_title('resulting weights')
+        axs[1].set_title('weight difference')
+
+        axs[0].legend()
+        axs[1].legend()
+
+        plt.tight_layout()
+
+        plt.savefig("compartmental_stdp.png")
+        plt.show()
+
+        max_abs_diff = max(abs(d) for d in diff_values)
+        assert max_abs_diff <= 0.005, ("the maximum weight difference is too large! (" + str(max_abs_diff) + " > 0.005)")
diff --git a/tests/nest_compartmental_tests/test__concmech_model.py b/tests/nest_compartmental_tests/test__concmech_model.py
index 7c4add105..1e44edb08 100644
--- a/tests/nest_compartmental_tests/test__concmech_model.py
+++ b/tests/nest_compartmental_tests/test__concmech_model.py
@@ -63,13 +63,14 @@ def setup(self):
         nest.ResetKernel()
         nest.SetKernelStatus(dict(resolution=.1))
 
-        generate_nest_compartmental_target(
-            input_path=input_path,
-            target_path=target_path,
-            module_name="concmech_mockup_module",
-            suffix="_nestml",
-            logging_level="DEBUG"
-        )
+        if True:
+            generate_nest_compartmental_target(
+                input_path=input_path,
+                target_path=target_path,
+                module_name="concmech_mockup_module",
+                suffix="_nestml",
+                logging_level="DEBUG"
+            )
 
         nest.Install("concmech_mockup_module.so")
 
@@ -92,7 +93,7 @@ def test_concmech(self):
 
         nest.Connect(sg1, cm, syn_spec={'synapse_model': 'static_synapse', 'weight': 4.0, 'delay': 0.5, 'receptor_type': 0})
 
-        mm = nest.Create('multimeter', 1, {'record_from': ['v_comp0', 'c_Ca0', 'i_tot_Ca_LVAst0', 'i_tot_Ca_HVA0', 'i_tot_SK_E20', 'm_Ca_HVA0', 'h_Ca_HVA0'], 'interval': .1})
+        mm = nest.Create('multimeter', 1, {'record_from': ['v_comp0', 'c_Ca0', 'Ca_LVAst0', 'Ca_HVA0', 'SK_E20', 'm_Ca_HVA0', 'h_Ca_HVA0'], 'interval': .1})
 
         nest.Connect(mm, cm)
 
@@ -103,21 +104,21 @@ def test_concmech(self):
         step_time_delta = res['times'][1] - res['times'][0]
         data_array_index = int(200 / step_time_delta)
 
-        expected_conc = 0.03559438228347359
+        expected_conc = 0.03351908393663761
 
         fig, axs = plt.subplots(5)
 
         axs[0].plot(res['times'], res['v_comp0'], c='r', label='V_m_0')
         axs[1].plot(res['times'], res['c_Ca0'], c='y', label='c_Ca_0')
-        axs[2].plot(res['times'], res['i_tot_Ca_HVA0'], c='b', label='i_tot_Ca_HVA0')
-        axs[3].plot(res['times'], res['i_tot_SK_E20'], c='b', label='i_tot_SK_E20')
+        axs[2].plot(res['times'], res['SK_E20'], c='b', label='SK_E20')
+        axs[3].plot(res['times'], res['Ca_HVA0'], c='b', label='Ca_HVA0')
         axs[4].plot(res['times'], res['m_Ca_HVA0'], c='g', label='gating var m')
         axs[4].plot(res['times'], res['h_Ca_HVA0'], c='r', label='gating var h')
 
         axs[0].set_title('V_m_0')
         axs[1].set_title('c_Ca_0')
-        axs[2].set_title('i_Ca_HVA_0')
-        axs[3].set_title('i_tot_SK_E20')
+        axs[2].set_title('SK_E20')
+        axs[3].set_title('Ca_HVA0')
         axs[4].set_title('gating vars')
 
         axs[0].legend()
@@ -128,4 +129,4 @@ def test_concmech(self):
 
         plt.savefig("concmech test.png")
 
-        assert res['c_Ca0'][data_array_index] == expected_conc, ("the concentration (left) is not as expected (right). (" + str(res['c_Ca0'][data_array_index]) + "!=" + str(expected_conc) + ")")
(" + str(res['c_Ca0'][data_array_index]) + "!=" + str(expected_conc) + ")") + assert abs(res['c_Ca0'][data_array_index] - expected_conc) <= 0.0000001, ("the concentration (left) is not as expected (right). (" + str(res['c_Ca0'][data_array_index]) + "!=" + str(expected_conc) + ")") diff --git a/tests/nest_compartmental_tests/test__consistency_between_sim_calls.py b/tests/nest_compartmental_tests/test__consistency_between_sim_calls.py new file mode 100644 index 000000000..1b3a95314 --- /dev/null +++ b/tests/nest_compartmental_tests/test__consistency_between_sim_calls.py @@ -0,0 +1,188 @@ +# -*- coding: utf-8 -*- +# +# test__consistency_between_sim_calls.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +import os +import unittest + +import numpy as np +import pytest + +import nest + +from pynestml.frontend.pynestml_frontend import generate_nest_compartmental_target + +# set to `True` to plot simulation traces +TEST_PLOTS = True +try: + import matplotlib + import matplotlib.pyplot as plt +except BaseException as e: + # always set TEST_PLOTS to False if matplotlib can not be imported + TEST_PLOTS = False + + +class TestConsistencyBetweenSimCalls(unittest.TestCase): + @pytest.fixture(scope="module", autouse=True) + def setup(self): + tests_path = os.path.realpath(os.path.dirname(__file__)) + neuron_input_path = os.path.join( + tests_path, + "resources", + "cm_default.nestml" + ) + target_path = os.path.join( + tests_path, + "target/" + ) + + if not os.path.exists(target_path): + os.makedirs(target_path) + + print( + f"Compiled nestml model 'cm_main_cm_default_nestml' not found, installing in:" + f" {target_path}" + ) + + nest.ResetKernel() + nest.SetKernelStatus(dict(resolution=.1)) + if True: + generate_nest_compartmental_target( + input_path=[neuron_input_path], + target_path=target_path, + module_name="cm_module", + suffix="_nestml", + logging_level="DEBUG", + codegen_opts={} + ) + + nest.Install("cm_module.so") + + def test_cm_stdp(self): + """ + Test the interaction between the pre- and post-synaptic spikes using STDP (Spike-Timing-Dependent Plasticity). + + This function sets up a simulation environment using NEST Simulator to demonstrate synaptic dynamics with pre-defined spike times for pre- and post-synaptic neurons. The function creates neuron models, assigns parameters, sets up connections, and records data from the simulation. It then plots the results for voltage, synaptic weight, spike timing, and pre- and post-synaptic traces. + + Simulation Procedure: + 1. Define pre- and post-synaptic spike timings and calculate simulation duration. + 2. Set up neuron models: + a. `spike_generator` to provide external spike input. + b. `parrot_neuron` for relaying spikes. + c. Custom `multichannel_test_model_nestml` neuron for the postsynaptic side, with compartments and receptor configurations specified. + 3. Create recording devices: + a. 
+        spike_times = [10, 20]
+
+        sim_time = 100
+        repeats = 10
+
+        spike_times_tm1 = spike_times
+        for i in range(repeats):
+            spike_times_t = [n + sim_time for n in spike_times_tm1]
+            spike_times_tm1 = spike_times_t
+
+            spike_times = spike_times + spike_times_t
+
+        external_input_pre = nest.Create("spike_generator", params={"spike_times": spike_times})
+
+        neuron = nest.Create('cm_default_nestml')
+
+        params = {'C_m': 10.0, 'g_C': 0.0, 'g_L': 1.5, 'e_L': -70.0, 'gbar_Na': 1.0}
+        neuron.compartments = [
+            {"parent_idx": -1, "params": params},
+            {"parent_idx": 0, "params": {}},
+            {"parent_idx": 1, "params": params}
+        ]
+
+        neuron.receptors = [
+            {"comp_idx": 0, "receptor_type": "AMPA"},
+        ]
+
+        mm = nest.Create('multimeter', 1, {
+            'record_from': ['v_comp0', 'Na0', 'AMPA0'], 'interval': .1})
+
+        nest.Connect(external_input_pre, neuron, "one_to_one",
+                     syn_spec={'synapse_model': 'static_synapse', 'weight': 5.0, 'delay': 0.1})
+        nest.Connect(mm, neuron)
+
+        nest.Simulate(sim_time)
+        for i in range(repeats):
+            nest.SetStatus(neuron,
+                           {'v_comp0': -70.0, 'm_Na0': 0.01696863, 'h_Na0': 0.83381407, 'Na0': 0.0, 'AMPA0': 0.0,
+                            'g_AMPA0': 0.0})
+            nest.Simulate(sim_time)
+        res = nest.GetStatus(mm, 'events')[0]
+
+        run_len = list(res['times']).index(sim_time - 0.1) + 1
+        res['times'] = np.insert(res['times'], 0, 0.0)
+        res['v_comp0'] = np.insert(res['v_comp0'], 0, res['v_comp0'][run_len])
+        res['Na0'] = np.insert(res['Na0'], 0, res['Na0'][run_len])
+        res['AMPA0'] = np.insert(res['AMPA0'], 0, res['AMPA0'][run_len])
+
+        run_len += 1
+        max_deviation = 0.0
+        deviations = []
+
+        for i in range(repeats + 1):
+            for ii in range(run_len):
+                deviation = abs(res['v_comp0'][ii + (i * run_len)] - res['v_comp0'][ii])
+                deviations.append(deviation)
+                if deviation > max_deviation:
+                    max_deviation = deviation
+
+        print("max_deviation", max_deviation)
+
+        fig, axs = plt.subplots(4)
+
+        axs[0].plot(res['times'], res['v_comp0'], c='r', label='V_m')
+        axs[1].plot(res['times'], res['Na0'], c='b', label="Na")
+        axs[2].plot(res['times'], res['AMPA0'], c='b', label="AMPA")
+        axs[3].plot(list(res['times']), deviations, c='orange', label="dev")
+
+        axs[0].set_title('V_m')
+        axs[1].set_title('Na')
+        axs[2].set_title('AMPA')
+        axs[3].set_title('dev')
+
+        axs[0].legend()
+        axs[1].legend()
+        axs[2].legend()
+        axs[3].legend()
+
+        plt.savefig("consistency sim calls test.png")
+
+        assert max_deviation < 0.0001, ("There should be no deviation between simulation calls! The maximum deviation in this run is (" + str(max_deviation) + ").")
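+
+        # note (editorial comment): the SetStatus reset above must list every
+        # state variable of the model; a state that is added later but not
+        # reset there would show up here as a deviation between simulation
+        # calls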
diff --git a/tests/nest_compartmental_tests/test__continuous_input.py b/tests/nest_compartmental_tests/test__continuous_input.py
index 6f8a60060..d8d6a9e2d 100644
--- a/tests/nest_compartmental_tests/test__continuous_input.py
+++ b/tests/nest_compartmental_tests/test__continuous_input.py
@@ -96,7 +96,7 @@ def test_continuous_input(self):
 
         nest.Connect(sg1, cm, syn_spec={'synapse_model': 'static_synapse', 'weight': 3.0, 'delay': 0.5, 'receptor_type': 1})
 
-        mm = nest.Create('multimeter', 1, {'record_from': ['v_comp0', 'i_tot_con_in0', 'i_tot_AMPA0'], 'interval': .1})
+        mm = nest.Create('multimeter', 1, {'record_from': ['v_comp0', 'con_in0', 'AMPA1'], 'interval': .1})
 
         nest.Connect(mm, cm)
 
@@ -107,8 +107,8 @@ def test_continuous_input(self):
         fig, axs = plt.subplots(2)
 
         axs[0].plot(res['times'], res['v_comp0'], c='b', label='V_m_0')
-        axs[1].plot(res['times'], res['i_tot_con_in0'], c='r', label='continuous')
-        axs[1].plot(res['times'], res['i_tot_AMPA0'], c='g', label='synapse')
+        axs[1].plot(res['times'], res['con_in0'], c='r', label='continuous')
+        axs[1].plot(res['times'], res['AMPA1'], c='g', label='synapse')
 
         axs[0].set_title('V_m_0')
         axs[1].set_title('inputs')
@@ -121,4 +121,4 @@ def test_continuous_input(self):
         step_time_delta = res['times'][1] - res['times'][0]
         data_array_index = int(212 / step_time_delta)
 
-        assert 19.9 < res['i_tot_con_in0'][data_array_index] < 20.1, ("the current (left) is not close enough to expected (right). (" + str(res['i_tot_con_in0'][data_array_index]) + " != " + "20.0 +- 0.1" + ")")
+        assert 19.9 < res['con_in0'][data_array_index] < 20.1, ("the current (left) is not close enough to expected (right). (" + str(res['con_in0'][data_array_index]) + " != " + "20.0 +- 0.1" + ")")
diff --git a/tests/nest_compartmental_tests/test__interaction_with_disabled_mechanism.py b/tests/nest_compartmental_tests/test__interaction_with_disabled_mechanism.py
index 8834c5c7f..d6c8b26de 100644
--- a/tests/nest_compartmental_tests/test__interaction_with_disabled_mechanism.py
+++ b/tests/nest_compartmental_tests/test__interaction_with_disabled_mechanism.py
@@ -91,9 +91,11 @@ def test_interaction_with_disabled(self):
 
         sg1 = nest.Create('spike_generator', 1, {'spike_times': [100.]})
 
-        nest.Connect(sg1, cm, syn_spec={'synapse_model': 'static_synapse', 'weight': 4.0, 'delay': 0.5, 'receptor_type': 0})
+        nest.Connect(sg1, cm,
+                     syn_spec={'synapse_model': 'static_synapse', 'weight': 4.0, 'delay': 0.5, 'receptor_type': 0})
 
-        mm = nest.Create('multimeter', 1, {'record_from': ['v_comp0', 'c_Ca0', 'i_tot_Ca_LVAst0', 'i_tot_Ca_HVA0', 'i_tot_SK_E20'], 'interval': .1})
+        mm = nest.Create('multimeter', 1,
+                         {'record_from': ['v_comp0', 'c_Ca0', 'Ca_LVAst0', 'Ca_HVA0', 'SK_E20'], 'interval': .1})
 
         nest.Connect(mm, cm)
 
@@ -104,14 +106,14 @@ def test_interaction_with_disabled(self):
         step_time_delta = res['times'][1] - res['times'][0]
         data_array_index = int(200 / step_time_delta)
 
-        expected_conc = 2.8159902294145262e-05
+        expected_conc = 0.0001
 
         fig, axs = plt.subplots(4)
 
         axs[0].plot(res['times'], res['v_comp0'], c='r', label='V_m_0')
         axs[1].plot(res['times'], res['c_Ca0'], c='y', label='c_Ca_0')
-        axs[2].plot(res['times'], res['i_tot_Ca_HVA0'], c='b', label='i_tot_Ca_HVA0')
-        axs[3].plot(res['times'], res['i_tot_SK_E20'], c='b', label='i_tot_SK_E20')
+        axs[2].plot(res['times'], res['Ca_HVA0'], c='b', label='Ca_HVA0')
+        axs[3].plot(res['times'], res['SK_E20'], c='b', label='SK_E20')
 
         axs[0].set_title('V_m_0')
         axs[1].set_title('c_Ca_0')
@@ -125,4 +127,4 @@ def test_interaction_with_disabled(self):
 
         plt.savefig("interaction with disabled mechanism test.png")
 
-        assert res['c_Ca0'][data_array_index] == expected_conc, ("the concentration (left) is not as expected (right). (" + str(res['c_Ca0'][data_array_index]) + "!=" + str(expected_conc) + ")")
+        assert abs(res['c_Ca0'][data_array_index] - expected_conc) <= 0.0000001, ("the concentration (left) is not as expected (right). (" + str(res['c_Ca0'][data_array_index]) + "!=" + str(expected_conc) + ")")
diff --git a/tests/nest_compartmental_tests/test__recordable_inlines.py b/tests/nest_compartmental_tests/test__recordable_inlines.py
new file mode 100644
index 000000000..cd24b62d8
--- /dev/null
+++ b/tests/nest_compartmental_tests/test__recordable_inlines.py
@@ -0,0 +1,124 @@
+# -*- coding: utf-8 -*-
+#
+# test__recordable_inlines.py
+#
+# This file is part of NEST.
+#
+# Copyright (C) 2004 The NEST Initiative
+#
+# NEST is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 2 of the License, or
+# (at your option) any later version.
+#
+# NEST is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with NEST. If not, see <http://www.gnu.org/licenses/>.
+
+import os
+
+import pytest
+
+import nest
+
+from pynestml.frontend.pynestml_frontend import generate_nest_compartmental_target
+
+# set to `True` to plot simulation traces
+TEST_PLOTS = True
+try:
+    import matplotlib
+    import matplotlib.pyplot as plt
+except BaseException:
+    # always set TEST_PLOTS to False if matplotlib can not be imported
+    TEST_PLOTS = False
+
+
+class TestRecordableInlines:
+    @pytest.fixture(scope="module", autouse=True)
+    def setup(self):
+        tests_path = os.path.realpath(os.path.dirname(__file__))
+        input_path = os.path.join(
+            tests_path,
+            "resources",
+            "recordable_inline_test.nestml"
+        )
+        target_path = os.path.join(
+            tests_path,
+            "target/"
+        )
+
+        if not os.path.exists(target_path):
+            os.makedirs(target_path)
+
+        print(f"Compiling and installing the NESTML model into: {target_path}")
+
+        nest.ResetKernel()
+        nest.SetKernelStatus(dict(resolution=.1))
+
+        if True:
+            generate_nest_compartmental_target(
+                input_path=input_path,
+                target_path=target_path,
+                module_name="rec_inline_test_module",
+                suffix="_nestml",
+                logging_level="DEBUG"
+            )
+
+        nest.Install("rec_inline_test_module.so")
+
+    def test_recordable_inlines(self):
+        """Record the recordable inline variables (open_probability and
+        equilibrium_distance) alongside v_comp and the Na current, and check
+        that they carry non-trivial values."""
+        cm = nest.Create('cm_default_nestml')
+
+        params = {'C_m': 10.0, 'g_C': 0.0, 'g_L': 1.5, 'e_L': -70.0, 'gbar_Na': 1.0}
+
+        cm.compartments = [
+            {"parent_idx": -1, "params": params}
+        ]
+
+        cm.receptors = [
+            {"comp_idx": 0, "receptor_type": "AMPA"}
+        ]
+
+        sg1 = nest.Create('spike_generator', 1, {'spike_times': [100.]})
+
+        nest.Connect(sg1, cm, syn_spec={'synapse_model': 'static_synapse', 'weight': 4.0, 'delay': 0.5, 'receptor_type': 0})
+
+        mm = nest.Create('multimeter', 1, {'record_from': ['v_comp0', 'Na0', 'open_probability0', 'equilibrium_distance0'],
'interval': .1}) + + nest.Connect(mm, cm) + + nest.Simulate(1000.) + + res = nest.GetStatus(mm, 'events')[0] + + fig, axs = plt.subplots(4) + + axs[0].plot(res['times'], res['v_comp0'], c='r', label='V_m') + axs[1].plot(res['times'], res['Na0'], c='y', label='I_Na') + axs[2].plot(res['times'], res['open_probability0'], c='b', label='open probability') + axs[3].plot(res['times'], res['equilibrium_distance0'], c='b', label='equilibrium distance') + + axs[0].set_title('V_m') + axs[1].set_title('I_Na') + axs[2].set_title('open probability') + axs[3].set_title('equilibrium distance') + + axs[0].legend() + axs[1].legend() + axs[2].legend() + axs[3].legend() + + plt.savefig("rec inline test.png") + plt.show() + + assert res['open_probability0'][1100] != 0, "the recordable inlines could not be recorded correctly" diff --git a/tests/nest_compartmental_tests/test__self_spike_convolutions.py b/tests/nest_compartmental_tests/test__self_spike_convolutions.py new file mode 100644 index 000000000..2000ad239 --- /dev/null +++ b/tests/nest_compartmental_tests/test__self_spike_convolutions.py @@ -0,0 +1,149 @@ +# -*- coding: utf-8 -*- +# +# test__self_spike_convolutions.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . 
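+
+# This test checks that convolutions over the neuron's own spike output (the
+# self_spikes input port) work in every mechanism type: channel, receptor,
+# continuous input and concentration.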
+
+import os
+
+import pytest
+
+import nest
+
+from pynestml.frontend.pynestml_frontend import generate_nest_compartmental_target
+
+# set to `True` to plot simulation traces
+TEST_PLOTS = True
+try:
+    import matplotlib
+    import matplotlib.pyplot as plt
+except BaseException:
+    # always set TEST_PLOTS to False if matplotlib can not be imported
+    TEST_PLOTS = False
+
+
+class TestSelfSpikeConvolutions:
+    @pytest.fixture(scope="module", autouse=True)
+    def setup(self):
+        tests_path = os.path.realpath(os.path.dirname(__file__))
+        input_path = os.path.join(
+            tests_path,
+            "resources",
+            "self_spike_convolutions.nestml"
+        )
+        target_path = os.path.join(
+            tests_path,
+            "target/"
+        )
+
+        if not os.path.exists(target_path):
+            os.makedirs(target_path)
+
+        print(f"Compiling and installing the NESTML model into: {target_path}")
+
+        nest.ResetKernel()
+        nest.SetKernelStatus(dict(resolution=.1))
+
+        if True:
+            generate_nest_compartmental_target(
+                input_path=input_path,
+                target_path=target_path,
+                module_name="self_spike_convolutions_module",
+                suffix="_nestml",
+                logging_level="DEBUG"
+            )
+
+        nest.Install("self_spike_convolutions_module.so")
+
+    def test_self_spike_convolutions(self):
+        """Check that every mechanism type (channel, receptor, continuous input
+        and concentration) can convolve over the neuron's own spike output via
+        the self_spikes port."""
+        cm = nest.Create('self_spikes_convolutions_nestml')
+
+        cm.compartments = [
+            {"parent_idx": -1, "params": {"w_bap": 1.}}
+        ]
+
+        cm.receptors = [
+            {"comp_idx": 0, "receptor_type": "rec_primary"},
+            {"comp_idx": 0, "receptor_type": "rec_secondary"},
+            {"comp_idx": 0, "receptor_type": "con_in_primary"},
+            {"comp_idx": 0, "receptor_type": "con_in_secondary"},
+        ]
+
+        sg1 = nest.Create('spike_generator', 1, {'spike_times': [50.]})
+        dcg = nest.Create("dc_generator", {"amplitude": 2.0, "start": 40, "stop": 60})
+
+        nest.Connect(sg1, cm,
+                     syn_spec={'synapse_model': 'static_synapse', 'weight': 4.0, 'delay': 0.5, 'receptor_type': 0})
+        nest.Connect(sg1, cm,
+                     syn_spec={'synapse_model': 'static_synapse', 'weight': 4.0, 'delay': 0.5, 'receptor_type': 1})
+        nest.Connect(dcg, cm,
+                     syn_spec={'synapse_model': 'static_synapse', 'weight': 1.0, 'delay': 0.1, 'receptor_type': 2})
+        nest.Connect(dcg, cm,
+                     syn_spec={'synapse_model': 'static_synapse', 'weight': 1.0, 'delay': 0.1, 'receptor_type': 3})
+
+        mm = nest.Create('multimeter', 1, {'record_from': ['v_comp0', 'chan_primary0', 'chan_secondary0',
+                                                           'rec_primary0', 'rec_secondary1',
+                                                           'con_in_primary2', 'con_in_secondary3',
+                                                           'concentration0'], 'interval': .1})
+
+        nest.Connect(mm, cm)
+
+        spikedet = nest.Create("spike_recorder")
+        nest.Connect(cm, spikedet)
+
+        nest.Simulate(200.)
+
+        res = nest.GetStatus(mm, 'events')[0]
+        # read the spike recorder only after the simulation; reading it before
+        # nest.Simulate() would return an empty event buffer
+        spikes_rec = nest.GetStatus(spikedet, 'events')[0]
+
+        fig, axs = plt.subplots(8)
+
+        axs[0].plot(res['times'], res['v_comp0'], c='r', label='V_m_0')
+        axs[1].plot(res['times'], res['chan_primary0'], c='g', label='chan_primary')
+        axs[2].plot(res['times'], res['chan_secondary0'], c='g', label='chan_secondary')
+        axs[3].plot(res['times'], res['rec_primary0'], c='orange', label='rec_primary')
+        axs[4].plot(res['times'], res['rec_secondary1'], c='orange', label='rec_secondary')
+        axs[5].plot(res['times'], res['con_in_primary2'], c='orange', label='con_in_primary')
+        axs[6].plot(res['times'], res['con_in_secondary3'], c='orange', label='con_in_secondary')
+        axs[7].plot(res['times'], res['concentration0'], c='b', label='concentration')
+
+        label_set = False
+        for spike in spikes_rec['times']:
+            for ax in axs:
+                if label_set:
+                    ax.axvline(x=spike, color='purple', linestyle='--', linewidth=1)
+                else:
+                    ax.axvline(x=spike, color='purple', linestyle='--', linewidth=1, label="self_spikes")
+                    label_set = True
+
+        axs[0].legend()
+        axs[1].legend()
+        axs[2].legend()
+        axs[3].legend()
+        axs[4].legend()
+        axs[5].legend()
+        axs[6].legend()
+        axs[7].legend()
+
+        plt.savefig("self_spike_convolutions.png")