From 1cbf52e6d15fc39d99ef272a741641b99cdc31e3 Mon Sep 17 00:00:00 2001 From: Tom Hogervorst Date: Fri, 4 Nov 2022 15:54:24 +0100 Subject: [PATCH 1/8] Added "add" alias to add_parameters method Also made that method work with parameter names instead of only parameter identifiers --- tests/test_parametermap.py | 12 +++++++++ trsfile/compatibility.py | 48 +++++++++++++++++++++++++++++++++++ trsfile/parametermap.py | 10 ++++++-- trsfile/standardparameters.py | 4 +-- 4 files changed, 70 insertions(+), 4 deletions(-) create mode 100644 trsfile/compatibility.py diff --git a/tests/test_parametermap.py b/tests/test_parametermap.py index 676c980..ba0862e 100644 --- a/tests/test_parametermap.py +++ b/tests/test_parametermap.py @@ -76,9 +76,15 @@ def test_add_standard_parameter(self): param_map1 = TraceSetParameterMap() param_map1.add_standard_parameter(StandardTraceSetParameters.KEY, bytes.fromhex('cafebabedeadbeef0102030405060708')) + param_map1.add_standard_parameter(StandardTraceSetParameters.TVLA_CIPHER, "AES") param_map2 = TraceSetParameterMap() param_map2.add_parameter('KEY', bytes.fromhex('cafebabedeadbeef0102030405060708')) + param_map2.add_parameter('TVLA:CIPHER', "AES") + param_map3 = TraceSetParameterMap() + param_map3.add('key', bytes.fromhex('cafebabedeadbeef0102030405060708')) + param_map3.add('tvla_cipher', "AES") self.assertDictEqual(param_map1, param_map2) + self.assertDictEqual(param_map1, param_map3) # Verify that standard trace set parameters enforce a specific type with self.assertRaises(TypeError): @@ -199,9 +205,15 @@ def test_add_standard_parameter(self): param_map1 = TraceParameterMap() param_map1.add_standard_parameter(StandardTraceParameters.INPUT, bytes.fromhex('cafebabedeadbeef0102030405060708')) + param_map1.add_standard_parameter(StandardTraceParameters.TVLA_SET_INDEX, 1) param_map2 = TraceParameterMap() param_map2.add_parameter('INPUT', bytes.fromhex('cafebabedeadbeef0102030405060708')) + param_map2.add_parameter('TVLA:SET_INDEX', 1) + param_map3 = TraceParameterMap() + param_map3.add("input", bytes.fromhex('cafebabedeadbeef0102030405060708')) + param_map3.add("tvla_set_index", 1) self.assertDictEqual(param_map1, param_map2) + self.assertDictEqual(param_map1, param_map3) # Verify that standard trace parameters enforce a specific type with self.assertRaises(TypeError): diff --git a/trsfile/compatibility.py b/trsfile/compatibility.py new file mode 100644 index 0000000..72490bd --- /dev/null +++ b/trsfile/compatibility.py @@ -0,0 +1,48 @@ +import functools + + +class alias: + """ + A decorator for implementing method aliases. + """ + + def __init__(self, *aliases): + self.aliases = set(aliases) + + def __call__(self, obj): + if type(obj) == property: + obj.fget._aliases = self.aliases + else: + obj._aliases = self.aliases + + return obj + + +def aliased(aliased_class): + """ + A decorator for enabling method aliases. + """ + def wrapper(func, name): + @functools.wraps(func) + def inner(*args, **kwargs): + return func(*args, **kwargs) + return inner + + aliased_class_dict = aliased_class.__dict__.copy() + aliased_class_set = set(aliased_class_dict) + + for name, method in aliased_class_dict.items(): + aliases = None + + if (type(method) == property) and hasattr(method.fget, '_aliases'): + aliases = method.fget._aliases + elif hasattr(method, '_aliases'): + aliases = method._aliases + + if aliases: + for method_alias in aliases - aliased_class_set: + wrapped_method = wrapper(method, method_alias) + wrapped_method.__doc__ = str(f"{method_alias} is an alias of {name}.") + setattr(aliased_class, method_alias, wrapped_method) + + return aliased_class \ No newline at end of file diff --git a/trsfile/parametermap.py b/trsfile/parametermap.py index 41ad127..8fb41a4 100644 --- a/trsfile/parametermap.py +++ b/trsfile/parametermap.py @@ -4,6 +4,7 @@ import warnings from typing import Any, Union, List, Dict +from trsfile.compatibility import alias, aliased from trsfile.standardparameters import StandardTraceSetParameters, StandardTraceParameters from trsfile.traceparameter import TraceSetParameter, TraceParameter, TraceParameterDefinition, ParameterType, \ BooleanArrayParameter, ByteArrayParameter, StringParameter, DoubleArrayParameter, IntegerArrayParameter, \ @@ -16,6 +17,7 @@ INT_MIN = -2**31 INT_MAX = 2**31-1 + class ParameterMapUtil: # A placeholder for integers that are actually shorts class ShortType(numbers.Rational): @@ -175,6 +177,7 @@ def lock_content(self): self._is_locked = True +@aliased class TraceSetParameterMap(LockableDict): default_values = { StandardTraceSetParameters.DISPLAY_HINT_X_LABEL: "", @@ -195,6 +198,7 @@ def __setitem__(self, key: str, value: Union[TraceParameter, TraceSetParameter]) self._stop_if_locked() super().__setitem__(key, value) + @alias("add") def add_parameter(self, name: str, value: ParameterMapUtil.ParameterValueType) -> None: """Add a trace set parameter with a given name and value. If the name matches the identifier of a standard trace set parameter, @@ -208,7 +212,7 @@ def add_parameter(self, name: str, value: ParameterMapUtil.ParameterValueType) - try: std_param = StandardTraceSetParameters.from_identifier(name) typed_param = std_param.parameter_type.param_class - self[name] = typed_param(ParameterMapUtil.to_list_if_listable(value)) + self[std_param.identifier] = typed_param(ParameterMapUtil.to_list_if_listable(value)) # if no std_param can be found, a ValueError is raised, # if adding the trace parameter to the map fails, a TypeError is raised except (ValueError, TypeError) as e: @@ -326,6 +330,7 @@ def from_trace_parameter_map(trace_parameters: TraceParameterMap) -> TraceParame return result +@aliased class TraceParameterMap(StringKeyOrderedDict): def __setitem__(self, key: str, value: TraceParameter): if not isinstance(value, TraceParameter): @@ -333,6 +338,7 @@ def __setitem__(self, key: str, value: TraceParameter): ' of TraceParameter (e.g. ByteArrayParameter).') super().__setitem__(key, value) + @alias("add") def add_parameter(self, name: str, value: ParameterMapUtil.ParameterValueType) -> None: """Add a trace parameter with a given name and value If the name matches the identifier of a standard trace parameter, @@ -346,7 +352,7 @@ def add_parameter(self, name: str, value: ParameterMapUtil.ParameterValueType) - try: std_param = StandardTraceParameters.from_identifier(name) typed_param = std_param.parameter_type.param_class - self[name] = typed_param(ParameterMapUtil.to_list_if_listable(value)) + self[std_param.identifier] = typed_param(ParameterMapUtil.to_list_if_listable(value)) # if no std_param can be found, a ValueError is raised, # if adding the trace parameter to the map fails, a TypeError is raised except (TypeError, ValueError) as e: diff --git a/trsfile/standardparameters.py b/trsfile/standardparameters.py index b58d266..cfba128 100644 --- a/trsfile/standardparameters.py +++ b/trsfile/standardparameters.py @@ -19,7 +19,7 @@ def __new__(cls, tag: int, identifier: str, parameter_type: ParameterType): @staticmethod def from_identifier(identifier: str) -> StandardTraceSetParameters: for val in StandardTraceSetParameters: - if identifier == val.identifier: + if identifier.lower() == val.identifier.lower() or identifier.lower() == val.name.lower(): return val raise ValueError(f'{identifier} is not an identifier of a StandardTraceSetParameter') @@ -79,7 +79,7 @@ def __new__(cls, tag: int, identifier: str, parameter_type: ParameterType): @staticmethod def from_identifier(identifier: str) -> StandardTraceParameters: for val in StandardTraceParameters: - if identifier == val.identifier: + if identifier.lower() == val.identifier.lower() or identifier.lower() == val.name.lower(): return val raise ValueError('{} is not a name of a StandardTraceParameter'.format(identifier)) From f6f941dc5e85936c228699141436cf9dc94f7f75 Mon Sep 17 00:00:00 2001 From: Tom Hogervorst Date: Fri, 4 Nov 2022 15:58:56 +0100 Subject: [PATCH 2/8] Move trace parameter check to correct place The check should be performed on all traces that are added, and it should always be done --- tests/test_creation.py | 69 ++++++++++++++++++++++++++++++++++++++++++ trsfile/engine/trs.py | 14 ++++++--- 2 files changed, 78 insertions(+), 5 deletions(-) diff --git a/tests/test_creation.py b/tests/test_creation.py index 7805afe..c6eee8c 100644 --- a/tests/test_creation.py +++ b/tests/test_creation.py @@ -196,6 +196,75 @@ def test_write_closed(self): with self.assertRaises(ValueError): print(trs_traces) + def test_write_different_trace_sizes(self): + trace_count = 100 + sample_count = 1000 + + with trsfile.open(self.tmp_path, 'w', padding_mode=TracePadding.AUTO) as trs_traces: + trs_traces.extend([ + Trace( + SampleCoding.FLOAT, + [0] * sample_count, + TraceParameterMap({'LEGACY_DATA': ByteArrayParameter(i.to_bytes(8, byteorder='big'))}) + ) + for i in range(0, trace_count)] + ) + with self.assertRaises(TypeError): + # The length is incorrect + # Should raise a Type error: The parameters of trace #0 do not match the trace set's definitions. + trs_traces.extend([ + Trace( + SampleCoding.FLOAT, + [0] * sample_count, + TraceParameterMap({'LEGACY_DATA': ByteArrayParameter(bytes.fromhex('cafebabedeadbeef0102030405060708'))}) + )] + ) + with self.assertRaises(TypeError): + # The name is incorrect + # Should raise a Type error: The parameters of trace #1 do not match the trace set's definitions. + trs_traces.extend([ + Trace( + SampleCoding.FLOAT, + [0] * sample_count, + TraceParameterMap({'LEGACY_DATA': ByteArrayParameter(bytes.fromhex('0102030405060708'))}) + ), + Trace( + SampleCoding.FLOAT, + [0] * sample_count, + TraceParameterMap({'NEW_DATA': ByteArrayParameter(bytes.fromhex('0102030405060708'))}) + )] + ) + with self.assertRaises(TypeError): + # The type is incorrect + # Should raise a Type error: The parameters of trace #0 do not match the trace set's definitions. + trs_traces.extend([ + Trace( + SampleCoding.FLOAT, + [0] * sample_count, + TraceParameterMap({'LEGACY_DATA': IntegerArrayParameter([42, 74])}) + )] + ) + + with trsfile.open(self.tmp_path, 'w', padding_mode=TracePadding.AUTO) as trs_traces: + trs_traces.extend([ + Trace( + SampleCoding.FLOAT, + [0] * sample_count, + TraceParameterMap() + ) + for i in range(0, trace_count)] + ) + with self.assertRaises(TypeError): + # The length, data and name are incorrect + # Should raise a Type error: The parameters of trace #0 do not match the trace set's definitions. + trs_traces.extend([ + Trace( + SampleCoding.FLOAT, + [0] * sample_count, + TraceParameterMap({'LEGACY_DATA': ByteArrayParameter(bytes.fromhex('cafebabedeadbeef0102030405060708'))}) + )] + ) + def test_read(self): trace_count = 100 sample_count = 1000 diff --git a/trsfile/engine/trs.py b/trsfile/engine/trs.py index 611948d..fcfe0cc 100644 --- a/trsfile/engine/trs.py +++ b/trsfile/engine/trs.py @@ -177,14 +177,18 @@ def update_headers_with_traces_metadata(self, traces: List[Trace]) -> None: # Add a TraceParameterDefinitionMap if none is present, and verify its validity if one is present if Header.TRACE_PARAMETER_DEFINITIONS not in self.headers: - if data_length > 0: - headers_updates[Header.TRACE_PARAMETER_DEFINITIONS] = \ - TraceParameterDefinitionMap.from_trace_parameter_map(traces[0].parameters) - elif not traces[0].parameters.matches(self.headers[Header.TRACE_PARAMETER_DEFINITIONS]): - raise TypeError("The traces' parameters do not match the trace set's definitions") + headers_updates[Header.TRACE_PARAMETER_DEFINITIONS] = \ + TraceParameterDefinitionMap.from_trace_parameter_map(traces[0].parameters) headers_updates[Header.LENGTH_DATA] = data_length + # Verify that each trace confirms to the traceset's TraceParameterDefinitionMap + for index, trace in enumerate(traces): + if not trace.parameters.matches(self.headers[Header.TRACE_PARAMETER_DEFINITIONS]): + raise TypeError(f"The parameters of trace #{index} do not match the trace set's definitions.\n" + f"Please make sure the trace parameters match those of the other traces in type, " + f"size and name.") + if self.headers[Header.SAMPLE_CODING] is None: if len(set([trace.sample_coding for trace in traces])) > 1: raise TypeError('Traces have different sample coding, this is not supported in TRS files') From b3408f208204f750174d6e5047b18ac2cda4d7fb Mon Sep 17 00:00:00 2001 From: Tom Hogervorst Date: Fri, 4 Nov 2022 15:59:28 +0100 Subject: [PATCH 3/8] Add methods to insert and append (in)to the TraceParameterDefinitionMap --- tests/test_header.py | 4 +++ tests/test_parametermap.py | 36 ++++++++++++++++++++ trsfile/engine/trs.py | 25 ++++++-------- trsfile/parametermap.py | 70 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 121 insertions(+), 14 deletions(-) diff --git a/tests/test_header.py b/tests/test_header.py index 5cf5834..7e567cd 100644 --- a/tests/test_header.py +++ b/tests/test_header.py @@ -90,6 +90,10 @@ def test_trace_param_defs_append_errors(self): trace_parameter_definitions.popitem() with self.assertRaises(TypeError): trace_parameter_definitions.clear() + with self.assertRaises(TypeError): + trace_parameter_definitions.append('input', ParameterType.BYTE, 16) + with self.assertRaises(TypeError): + trace_parameter_definitions.insert('output', ParameterType.BYTE, 16, 0) # Shallow copies still share references to the same trace set parameters as the original, # and should therefore not be modifiable if the original isn't diff --git a/tests/test_parametermap.py b/tests/test_parametermap.py index ba0862e..1f5a02b 100644 --- a/tests/test_parametermap.py +++ b/tests/test_parametermap.py @@ -109,6 +109,14 @@ def create_parameterdefinitionmap() -> TraceParameterDefinitionMap: param_map['中文'] = TraceParameterDefinition(ParameterType.STRING, 15, 29) return param_map + @staticmethod + def create_std_parameterdefinitionmap() -> TraceParameterDefinitionMap: + param_map = TraceParameterDefinitionMap() + param_map['INPUT'] = TraceParameterDefinition(ParameterType.BYTE, 16, 0) + param_map['OUTPUT'] = TraceParameterDefinition(ParameterType.BYTE, 16, 16) + param_map['KEY'] = TraceParameterDefinition(ParameterType.BYTE, 16, 32) + return param_map + @staticmethod def create_traceparametermap() -> TraceParameterMap: param_map = TraceParameterMap() @@ -134,6 +142,34 @@ def test_from_trace_params(self): map_from_trace_params = TraceParameterDefinitionMap.from_trace_parameter_map(param_map) self.assertDictEqual(self.create_parameterdefinitionmap(), map_from_trace_params) + def test_append(self): + map_from_append = TraceParameterDefinitionMap() + map_from_append.append('IN', ParameterType.BYTE, 16) + map_from_append.append('TITLE', ParameterType.STRING, 13) + map_from_append.append('中文', ParameterType.STRING, 15) + self.assertDictEqual(self.create_parameterdefinitionmap(), map_from_append) + + map_from_std_append = TraceParameterDefinitionMap() + map_from_std_append.append_std('INPUT', 16) + map_from_std_append.append_std('OUTPUT', 16) + map_from_std_append.append_std('KEY', 16) + self.assertDictEqual(self.create_std_parameterdefinitionmap(), map_from_std_append) + + def test_insert(self): + map_from_insert = TraceParameterDefinitionMap() + map_from_insert.insert('TITLE', ParameterType.STRING, 13, 0) + with self.assertWarns(UserWarning): + map_from_insert.insert('中文', ParameterType.STRING, 15, 10) + map_from_insert.insert('IN', ParameterType.BYTE, 16, 0) + self.assertDictEqual(self.create_parameterdefinitionmap(), map_from_insert) + + map_from_std_insert = TraceParameterDefinitionMap() + map_from_std_insert.insert_std('INPUT', 16, 0) + with self.assertWarns(UserWarning): + map_from_std_insert.insert_std('KEY', 16, 9) + map_from_std_insert.insert_std('OUTPUT', 16, 16) + self.assertDictEqual(self.create_std_parameterdefinitionmap(), map_from_std_insert) + class TestTraceParameterMap(TestCase): CAFEBABE = bytes.fromhex('cafebabedeadbeef0102030405060708') diff --git a/trsfile/engine/trs.py b/trsfile/engine/trs.py index fcfe0cc..80e3b9f 100644 --- a/trsfile/engine/trs.py +++ b/trsfile/engine/trs.py @@ -173,21 +173,18 @@ def update_headers_with_traces_metadata(self, traces: List[Trace]) -> None: if len(set([len(trace.parameters.serialize()) for trace in traces])) > 1: raise TypeError('Traces have different data length, this is not supported in TRS files') - data_length = len(traces[0].parameters.serialize()) + headers_updates[Header.LENGTH_DATA] = len(traces[0].parameters.serialize()) - # Add a TraceParameterDefinitionMap if none is present, and verify its validity if one is present - if Header.TRACE_PARAMETER_DEFINITIONS not in self.headers: - headers_updates[Header.TRACE_PARAMETER_DEFINITIONS] = \ - TraceParameterDefinitionMap.from_trace_parameter_map(traces[0].parameters) - - headers_updates[Header.LENGTH_DATA] = data_length - - # Verify that each trace confirms to the traceset's TraceParameterDefinitionMap - for index, trace in enumerate(traces): - if not trace.parameters.matches(self.headers[Header.TRACE_PARAMETER_DEFINITIONS]): - raise TypeError(f"The parameters of trace #{index} do not match the trace set's definitions.\n" - f"Please make sure the trace parameters match those of the other traces in type, " - f"size and name.") + # Add a TraceParameterDefinitionMap if none is present, and verify its validity if one is present + if Header.TRACE_PARAMETER_DEFINITIONS not in self.headers: + headers_updates[Header.TRACE_PARAMETER_DEFINITIONS] = \ + TraceParameterDefinitionMap.from_trace_parameter_map(traces[0].parameters) + else: + for index, trace in enumerate(traces): + if not trace.parameters.matches(self.headers[Header.TRACE_PARAMETER_DEFINITIONS]): + raise TypeError(f"The parameters of trace #{index} do not match the trace set's definitions.\n" + f"Please make sure the trace parameters match those of the other traces in type, " + f"size and name.") if self.headers[Header.SAMPLE_CODING] is None: if len(set([trace.sample_coding for trace in traces])) > 1: diff --git a/trsfile/parametermap.py b/trsfile/parametermap.py index 8fb41a4..792b9be 100644 --- a/trsfile/parametermap.py +++ b/trsfile/parametermap.py @@ -293,6 +293,76 @@ def __setitem__(self, key: str, value: TraceParameterDefinition): self._stop_if_locked() super().__setitem__(key, value) + def insert_std(self, name: str, size: int, offset: int): + """Insert a trace parameter definition of a StandardTraceParameter into this map in a specified location. If + the given offset would put the new TraceParameter in the middle of a parameter already present in the map, the + offset is increased to put add the new parameter after that existing one instead. Any parameters already present + in the map that have a greater or equal offset than the new parameter, have their offset increased to make space + for the new parameter. + + :param name: The name of the TraceParameter for which to add a definition. This name must match that of a + StandardTraceParameter + :param size: The size of the TraceParameter, in number of values of its type + :param offset: The offset of the TraceParameter, in bytes + """ + try: + type = StandardTraceParameters.from_identifier(name).parameter_type + self.insert(name, type, size, offset) + except ValueError: + raise ValueError(f"No StandardTraceParameter found with name '{name}'. Either specify a type in this " + f"insert or correct the name to match a standard trace parameter.") + + def insert(self, name: str, type: ParameterType, size: int, offset: int): + """Insert a trace parameter definition into this map in a specified location. If the given offset would put the + new TraceParameter in the middle of a parameter already present in the map, the offset is increased to put add + the new parameter after that existing one instead. Any parameters already present in the map that have a greater + or equal offset than the new parameter, have their offset increased to make space for the new parameter. + + :param name: The name of the TraceParameter for which to add a definition + :param type: The type of the TraceParameter for which to add a definition + :param size: The size of the TraceParameter, in number of values of its type + :param offset: The offset of the TraceParameter, in bytes + """ + self._stop_if_locked() + params_to_move_back = [] + for key, param in self.items(): + if param.offset >= offset: + param.offset += size * type.byte_size + params_to_move_back.append(key) + elif param.offset + param.length * param.param_type.byte_size > offset: + offset = param.offset + param.length * param.param_type.byte_size + warnings.warn("Given offset would put a parameter inside another trace parameter.\n" + f"Increased the offset of the inserted parameter definition to {offset} to prevent this.") + + new_definition = TraceParameterDefinition(type, size, offset) + self.__setitem__(name, new_definition) + for param in params_to_move_back: + self.move_to_end(param) + + def append_std(self, name: str, size: int): + """Append a trace parameter definition of a StandardTraceParameter to this map. The parameter wil be added after + all parameter definitions already in the map. + + :param name: The name of the TraceParameter for which to add a definition. This name must match that of a + StandardTraceParameter + :param size: The size of the TraceParameter, in number of values of its type""" + try: + type = StandardTraceParameters.from_identifier(name).parameter_type + self.append(name, type, size) + except ValueError: + raise ValueError(f"No StandardTraceParameter found with name '{name}'. Either specify a type in this " + f"append or correct the name to match a standard trace parameter.") + + def append(self, name: str, type: ParameterType, size: int): + """Append a trace parameter definition to this map. The parameter wil be added after all parameter definitions + already in the map. + + :param name: The name of the TraceParameter for which to add a definition + :param type: The type of the TraceParameter for which to add a definition + :param size: The size of the TraceParameter, in number of values of its type""" + new_definition = TraceParameterDefinition(type, size, self.get_total_size()) + self.__setitem__(name, new_definition) + @staticmethod def deserialize(raw: BytesIO) -> TraceParameterDefinitionMap: result = TraceParameterDefinitionMap() From 482776680667277e701ef1b0e02c217236f3d7b9 Mon Sep 17 00:00:00 2001 From: Tom Hogervorst Date: Fri, 4 Nov 2022 16:00:47 +0100 Subject: [PATCH 4/8] Have all methods that add additional parameters return the map after modification This allows for constructions like `new Trace(SampleCoding.FLOAT, [0] * sample_count, TraceParameterMap().add('input', bytes.fromhex('0102030405060708')))` --- trsfile/parametermap.py | 70 ++++++++++++++++++++++++++--------------- 1 file changed, 44 insertions(+), 26 deletions(-) diff --git a/trsfile/parametermap.py b/trsfile/parametermap.py index 792b9be..3c4223d 100644 --- a/trsfile/parametermap.py +++ b/trsfile/parametermap.py @@ -199,7 +199,7 @@ def __setitem__(self, key: str, value: Union[TraceParameter, TraceSetParameter]) super().__setitem__(key, value) @alias("add") - def add_parameter(self, name: str, value: ParameterMapUtil.ParameterValueType) -> None: + def add_parameter(self, name: str, value: ParameterMapUtil.ParameterValueType) -> TraceSetParameterMap: """Add a trace set parameter with a given name and value. If the name matches the identifier of a standard trace set parameter, then the value's type should be the type that standard trace set parameter expects. @@ -208,11 +208,13 @@ def add_parameter(self, name: str, value: ParameterMapUtil.ParameterValueType) - :param name: The name of the parameter that will be added :param value: The value of the parameter. If the name matches that of a standard trace set parameter, it is recommended that the type of the value matches that of standard trace set parameter. Otherwise, - valid types are: int, float, bool, List[int], List[float], List[bool], bytes, bytearray or str""" + valid types are: int, float, bool, List[int], List[float], List[bool], bytes, bytearray or str + :return: This TraceSetParameterMap after adding the new parameter""" try: std_param = StandardTraceSetParameters.from_identifier(name) typed_param = std_param.parameter_type.param_class self[std_param.identifier] = typed_param(ParameterMapUtil.to_list_if_listable(value)) + return self # if no std_param can be found, a ValueError is raised, # if adding the trace parameter to the map fails, a TypeError is raised except (ValueError, TypeError) as e: @@ -223,29 +225,34 @@ def add_parameter(self, name: str, value: ParameterMapUtil.ParameterValueType) - typed_param = ParameterMapUtil.get_typed_parameter(value) self[name] = typed_param(ParameterMapUtil.to_list_if_listable(value)) - def add_standard_parameter(self, std_trace_set_param: StandardTraceSetParameters, value: ParameterMapUtil.ParameterValueType) -> None: + def add_standard_parameter(self, std_trace_set_param: StandardTraceSetParameters, + value: ParameterMapUtil.ParameterValueType) -> TraceSetParameterMap: """Add a standard trace set parameter with a given value. If the parameter already exists within the map, it will be overwritten. :param std_trace_set_param: The standard trace set parameter that will be added :param value: The value of the parameter. The type this value must have depends on - the standard trace set parameter.""" + the standard trace set parameter. + :return: This TraceSetParameterMap after adding the new standard parameter""" typed_param = std_trace_set_param.parameter_type.param_class self[std_trace_set_param.identifier] = typed_param(ParameterMapUtil.to_list_if_listable(value)) + return self - def fill_from_headers(self, headers: Dict['Header', Any]) -> None: + def fill_from_headers(self, headers: Dict['Header', Any]) -> TraceSetParameterMap: """Add to this trace set parameter map all data that is in the header and for which standard trace set parameters exist. Data that already exists in the map will not be overwritten. - :param headers: The headers dictionary from which data will be copied into the trace set parameter map""" + :param headers: The headers dictionary from which data will be copied into the trace set parameter map + :return: This TraceSetParameterMap after adding the parameters based on the headers""" for header_tag, value in headers.items(): std_param = header_tag.equivalent_std_param if std_param is not None and std_param.identifier not in self: self.add_standard_parameter(std_param, value) - def add_defaults(self) -> None: - """If specific standard trace set parameters don't exist yet in the map, add them with default values""" + def add_defaults(self) -> TraceSetParameterMap: + """If specific standard trace set parameters don't exist yet in the map, add them with default values + :return: This TraceSetParameterMap after adding the default parameters""" for key, value in TraceSetParameterMap.default_values.items(): if key.identifier not in self: self.add_standard_parameter(key, value) @@ -293,35 +300,37 @@ def __setitem__(self, key: str, value: TraceParameterDefinition): self._stop_if_locked() super().__setitem__(key, value) - def insert_std(self, name: str, size: int, offset: int): + def insert_std(self, name: str, size: int, offset: int) -> TraceParameterDefinitionMap: """Insert a trace parameter definition of a StandardTraceParameter into this map in a specified location. If the given offset would put the new TraceParameter in the middle of a parameter already present in the map, the offset is increased to put add the new parameter after that existing one instead. Any parameters already present in the map that have a greater or equal offset than the new parameter, have their offset increased to make space for the new parameter. - :param name: The name of the TraceParameter for which to add a definition. This name must match that of a - StandardTraceParameter - :param size: The size of the TraceParameter, in number of values of its type + :param name: The name of the TraceParameter for which to add a definition. This name must match that of a + StandardTraceParameter + :param size: The size of the TraceParameter, in number of values of its type :param offset: The offset of the TraceParameter, in bytes + :return: This TraceParameterDefinition map after adding the new definition """ try: type = StandardTraceParameters.from_identifier(name).parameter_type - self.insert(name, type, size, offset) + return self.insert(name, type, size, offset) except ValueError: raise ValueError(f"No StandardTraceParameter found with name '{name}'. Either specify a type in this " f"insert or correct the name to match a standard trace parameter.") - def insert(self, name: str, type: ParameterType, size: int, offset: int): + def insert(self, name: str, type: ParameterType, size: int, offset: int) -> TraceParameterDefinitionMap: """Insert a trace parameter definition into this map in a specified location. If the given offset would put the new TraceParameter in the middle of a parameter already present in the map, the offset is increased to put add the new parameter after that existing one instead. Any parameters already present in the map that have a greater or equal offset than the new parameter, have their offset increased to make space for the new parameter. - :param name: The name of the TraceParameter for which to add a definition - :param type: The type of the TraceParameter for which to add a definition - :param size: The size of the TraceParameter, in number of values of its type + :param name: The name of the TraceParameter for which to add a definition + :param type: The type of the TraceParameter for which to add a definition + :param size: The size of the TraceParameter, in number of values of its type :param offset: The offset of the TraceParameter, in bytes + :return: This TraceParameterDefinition map after adding the new definition """ self._stop_if_locked() params_to_move_back = [] @@ -338,30 +347,34 @@ def insert(self, name: str, type: ParameterType, size: int, offset: int): self.__setitem__(name, new_definition) for param in params_to_move_back: self.move_to_end(param) + return self - def append_std(self, name: str, size: int): + def append_std(self, name: str, size: int) -> TraceParameterDefinitionMap: """Append a trace parameter definition of a StandardTraceParameter to this map. The parameter wil be added after all parameter definitions already in the map. :param name: The name of the TraceParameter for which to add a definition. This name must match that of a StandardTraceParameter - :param size: The size of the TraceParameter, in number of values of its type""" + :param size: The size of the TraceParameter, in number of values of its type + :return: This TraceParameterDefinition map after adding the new definition""" try: type = StandardTraceParameters.from_identifier(name).parameter_type - self.append(name, type, size) + return self.append(name, type, size) except ValueError: raise ValueError(f"No StandardTraceParameter found with name '{name}'. Either specify a type in this " f"append or correct the name to match a standard trace parameter.") - def append(self, name: str, type: ParameterType, size: int): + def append(self, name: str, type: ParameterType, size: int) -> TraceParameterDefinitionMap: """Append a trace parameter definition to this map. The parameter wil be added after all parameter definitions already in the map. :param name: The name of the TraceParameter for which to add a definition :param type: The type of the TraceParameter for which to add a definition - :param size: The size of the TraceParameter, in number of values of its type""" + :param size: The size of the TraceParameter, in number of values of its type + :return: This TraceParameterDefinition map after adding the new definition""" new_definition = TraceParameterDefinition(type, size, self.get_total_size()) self.__setitem__(name, new_definition) + return self @staticmethod def deserialize(raw: BytesIO) -> TraceParameterDefinitionMap: @@ -409,7 +422,7 @@ def __setitem__(self, key: str, value: TraceParameter): super().__setitem__(key, value) @alias("add") - def add_parameter(self, name: str, value: ParameterMapUtil.ParameterValueType) -> None: + def add_parameter(self, name: str, value: ParameterMapUtil.ParameterValueType) -> TraceParameterMap: """Add a trace parameter with a given name and value If the name matches the identifier of a standard trace parameter, then the value's type should be the type that standard trace parameter expects. @@ -418,7 +431,8 @@ def add_parameter(self, name: str, value: ParameterMapUtil.ParameterValueType) - :param name: The name of the parameter that will be added :param value: The value of the parameter. If the name matches that of a standard trace parameter, it is recommended that the type of the value matches that of standard trace set parameter. Otherwise, - valid types are: int, float, bool, List[int], List[float], List[bool], bytes, bytearray or str""" + valid types are: int, float, bool, List[int], List[float], List[bool], bytes, bytearray or str + :return: This map after adding the parameter""" try: std_param = StandardTraceParameters.from_identifier(name) typed_param = std_param.parameter_type.param_class @@ -432,16 +446,20 @@ def add_parameter(self, name: str, value: ParameterMapUtil.ParameterValueType) - "behavior when displaying this traceset or processing this trace in Inspector") typed_param = ParameterMapUtil.get_typed_parameter(value) self[name] = typed_param(ParameterMapUtil.to_list_if_listable(value)) + return self - def add_standard_parameter(self, std_trace_param: StandardTraceParameters, value: ParameterMapUtil.ParameterValueType) -> None: + def add_standard_parameter(self, std_trace_param: StandardTraceParameters, + value: ParameterMapUtil.ParameterValueType) -> TraceParameterMap: """Add a standard trace parameter with a given value. If the parameter already exists within the map, it will be overwritten. :param std_trace_param: The standard trace parameter that will be added :param value: The value of the parameter. The type this value must have depends on - the standard trace parameter.""" + the standard trace parameter. + :return: This map after adding the parameter""" typed_param = std_trace_param.parameter_type.param_class self[std_trace_param.identifier] = typed_param(ParameterMapUtil.to_list_if_listable(value)) + return self @staticmethod def deserialize(raw: bytes, definitions: TraceParameterDefinitionMap) -> TraceParameterMap: From 4217c3d803ff01bef7eac6ab19f2eaf0e8b09608 Mon Sep 17 00:00:00 2001 From: Tom Hogervorst Date: Fri, 4 Nov 2022 16:01:14 +0100 Subject: [PATCH 5/8] Add option to create Traces with raw trace data --- tests/test_creation.py | 18 ++++++++++++++++-- tests/test_parametermap.py | 28 +++++++++++++++++++++++++++- trsfile/parametermap.py | 24 ++++++++++++++++++++++++ trsfile/trace.py | 33 +++++++++++++++++++++++++++++---- 4 files changed, 96 insertions(+), 7 deletions(-) diff --git a/tests/test_creation.py b/tests/test_creation.py index c6eee8c..ff0769a 100644 --- a/tests/test_creation.py +++ b/tests/test_creation.py @@ -7,7 +7,7 @@ import shutil from trsfile import Trace, SampleCoding, Header, TracePadding -from trsfile.parametermap import TraceParameterMap, TraceParameterDefinitionMap, TraceSetParameterMap +from trsfile.parametermap import TraceParameterMap, TraceParameterDefinitionMap, TraceSetParameterMap, RawTraceData from trsfile.standardparameters import StandardTraceSetParameters from trsfile.traceparameter import ByteArrayParameter, TraceParameterDefinition, ParameterType, StringParameter, \ IntegerArrayParameter, BooleanArrayParameter, FloatArrayParameter @@ -178,6 +178,20 @@ def test_header_to_trace_set_params(self): except Exception as e: self.fail('Exception occurred: ' + str(e)) + def test_write_different_trace_sizes(self): + trace_count = 100 + sample_count = 1000 + + with trsfile.open(self.tmp_path, 'w', padding_mode=TracePadding.AUTO) as trs_traces: + trs_traces.extend([ + Trace( + SampleCoding.FLOAT, + [0] * sample_count, + RawTraceData(i.to_bytes(8, byteorder='big')) + ) + for i in range(0, trace_count)] + ) + def test_write_closed(self): trace_count = 100 sample_count = 1000 @@ -339,7 +353,7 @@ def test_append(self): Trace( SampleCoding.FLOAT, [0] * sample_count, - TraceParameterMap({'LEGACY_DATA': ByteArrayParameter(i.to_bytes(8, byteorder='big'))}) + raw_data=i.to_bytes(8, byteorder='big') ) for i in range(0, trace_count)] ) diff --git a/tests/test_parametermap.py b/tests/test_parametermap.py index 1f5a02b..6e184c5 100644 --- a/tests/test_parametermap.py +++ b/tests/test_parametermap.py @@ -1,6 +1,6 @@ from unittest import TestCase -from trsfile.parametermap import TraceSetParameterMap, TraceParameterDefinitionMap, TraceParameterMap +from trsfile.parametermap import TraceSetParameterMap, TraceParameterDefinitionMap, TraceParameterMap, RawTraceData from trsfile.standardparameters import StandardTraceSetParameters, StandardTraceParameters from trsfile.traceparameter import * @@ -258,3 +258,29 @@ def test_add_standard_parameter(self): # However, this type check only produces a warning with self.assertWarns(UserWarning): param_map1.add_parameter('INPUT', 'cafebabedeadbeef0102030405060708') + + def test_raw_trace_data(self): + raw_data = RawTraceData(bytes.fromhex('cafebabedeadbeef0102030405060708')) + assert raw_data.serialize() == bytes.fromhex('cafebabedeadbeef0102030405060708') + + # Verify that nothing can be added into a raw data TraceParameterMap + with self.assertRaises(KeyError): + raw_data['INPUT'] = ByteArrayParameter(bytes.fromhex('cafebabedeadbeef0102030405060708')) + with self.assertRaises(KeyError): + raw_data.add('input', bytes.fromhex('cafebabedeadbeef0102030405060708')) + + # Verify that raw data can match any traceparameterdefinition map, as long as the length is correct + traceParamDefs = TraceParameterDefinitionMap() + traceParamDefs.append("INPUT", ParameterType.BYTE, 16) + assert raw_data.matches(traceParamDefs) + + traceParamDefs = TraceParameterDefinitionMap() + traceParamDefs.append("INPUT", ParameterType.BYTE, 8) + traceParamDefs.append("OUTPUT", ParameterType.BYTE, 8) + assert raw_data.matches(traceParamDefs) + traceParamDefs.append("KEY", ParameterType.BYTE, 8) + assert not raw_data.matches(traceParamDefs) + + with self.assertWarns(UserWarning): + TraceParameterDefinitionMap.from_trace_parameter_map(raw_data) + diff --git a/trsfile/parametermap.py b/trsfile/parametermap.py index 3c4223d..5076279 100644 --- a/trsfile/parametermap.py +++ b/trsfile/parametermap.py @@ -403,6 +403,12 @@ def from_trace_parameter_map(trace_parameters: TraceParameterMap) -> TraceParame :param trace_parameters: The trace parameter map from which the definitions will be deduced :return: A parameter definition map that described the metadata of the input trace parameter map""" + if isinstance(trace_parameters, RawTraceData): + warnings.warn("Creating a trace parameter definition map from raw trace data.\nThis is not recommended, " + "as it will not add any meta information about the trace data.\nEither manually define a " + "TraceParameterDefinitionMap for the traceset or make sure the first trace you add to the " + "traceset has a proper TraceParameterMap") + offset = 0 result = TraceParameterDefinitionMap() for key, trace_param in trace_parameters.items(): @@ -503,3 +509,21 @@ def matches(self, definitions: TraceParameterDefinitionMap) -> bool: break match &= matched_keys == list(definitions.keys()) return match + + +class RawTraceData(TraceParameterMap): + def __init__(self, data: bytes): + super().__init__() + super().__setitem__("LEGACY_DATA", ByteArrayParameter(data)) + + def __setitem__(self, key: str, value: TraceParameter): + raise KeyError("Adding Trace Parameters into raw trace data is not allowed") + + def matches(self, definitions: TraceParameterDefinitionMap) -> bool: + """Test whether this RawTraceData could be interpreted by given definitions + + :param definitions: The trace parameter definition map of the trs file to which the trace with this trace + raw trace data will be added + + :return: A boolean that is true if the trace parameter definitions can interpret this raw trace data""" + return definitions.get_total_size() == len(self["LEGACY_DATA"]) diff --git a/trsfile/trace.py b/trsfile/trace.py index 57e8a14..4169ba8 100644 --- a/trsfile/trace.py +++ b/trsfile/trace.py @@ -1,8 +1,9 @@ import numpy -from trsfile.parametermap import TraceParameterMap +from trsfile.parametermap import TraceParameterMap, RawTraceData from trsfile.common import Header, SampleCoding + class Trace: """The :py:obj:`Trace` class behaves like a :py:obj:`list` object were each item in the list is a sample of the trace. @@ -12,11 +13,35 @@ class Trace: provided :py:obj:`sample_coding`. """ - def __init__(self, sample_coding, samples, parameters=TraceParameterMap(), title='trace', headers={}): + def __init__(self, sample_coding, samples, parameters=None, title='trace', headers=None, raw_data: bytes = bytes()): + """ Create a new Trace. + :param sample_coding: The encoding of all samples in the trace + :param samples: The array of samples of the trace + :param parameters: The trace parameter map that contains the trace's data and its meta information. Do not use + in combination with raw_data + :param title: The title of the trace + :param headers: The headers of the trs file to which this trace will be added. This is an optional parameter; + information from the headers may define the locations of the input, output and key data in the trace data, + but it is recommended to store that information in the trace parameter map now. + :param raw_data: A byte array with the raw trace data. Do not use in combination with parameters. If used, it is + recommended that a TraceParameterDefinitionMap is added to the headers of the trsfile that defines the meta + information of this raw trace data. + """ + if parameters is None: + if len(raw_data) > 0: + parameters = RawTraceData(raw_data) + else: + parameters = TraceParameterMap() + else: + if len(raw_data) > 0: + raise Warning("Parameter map and raw data were both defined, but cannot both be used at the same time.\n" + "Only the parameter map will be used, raw data will be discarded.") + if headers is None: + headers = {} self.title = title self.parameters = parameters - if not type(self.parameters) is TraceParameterMap: - raise TypeError('Trace parameter data must be supplied as a TraceParameterMap') + if not isinstance(self.parameters, TraceParameterMap): + raise TypeError('Trace data must be supplied as a TraceParameterMap') # Obtain sample coding if not isinstance(sample_coding, SampleCoding): From b5e66141191a9dbe2ada8f92433fe019f960d484 Mon Sep 17 00:00:00 2001 From: Tom Hogervorst Date: Fri, 4 Nov 2022 16:01:58 +0100 Subject: [PATCH 6/8] Add ndarray testcases and fixes for issues found Issues were: incorrect comparison between ndarrays and lists, no support for multi-dimensional arrays, and no detection of empty ndarrays. --- tests/test_parameter.py | 136 +++++++++++++++++++++++++++++++------- trsfile/traceparameter.py | 15 ++++- 2 files changed, 125 insertions(+), 26 deletions(-) diff --git a/tests/test_parameter.py b/tests/test_parameter.py index fff1782..fc5e94b 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -1,6 +1,8 @@ from io import BytesIO from unittest import TestCase +from numpy import ndarray, int16, array, int32, int64, single, double, uint8, int8, uint16, bool8 + from trsfile.traceparameter import BooleanArrayParameter, ByteArrayParameter, DoubleArrayParameter, FloatArrayParameter, \ IntegerArrayParameter, ShortArrayParameter, LongArrayParameter, StringParameter @@ -8,9 +10,12 @@ class TestParameter(TestCase): def test_bool_parameter(self): serialized_param = b'\x01\x00\x01' - param = BooleanArrayParameter([True, False, True]) - self.assertEqual(serialized_param, param.serialize()) - self.assertEqual(BooleanArrayParameter.deserialize(BytesIO(serialized_param), 3), param) + param1 = BooleanArrayParameter([True, False, True]) + self.assertEqual(serialized_param, param1.serialize()) + self.assertEqual(BooleanArrayParameter.deserialize(BytesIO(serialized_param), 3), param1) + param2 = BooleanArrayParameter(ndarray(shape=[3], dtype=bool8, + buffer=array([bool8(val) for val in [True, False, True]]))) + self.assertEqual(param1, param2) with self.assertRaises(TypeError): BooleanArrayParameter(True) @@ -27,6 +32,11 @@ def test_byte_parameter(self): self.assertEqual(serialized_param, param1.serialize()) self.assertEqual(ByteArrayParameter.deserialize(BytesIO(serialized_param), 16), param1) + with self.assertWarns(UserWarning): + param2 = ByteArrayParameter(ndarray(shape=[2, 2, 4], dtype=uint8, + buffer=array([uint8(val) for val in int_data]))) + self.assertEqual(param1, param2) + param2 = ByteArrayParameter(bytearray(int_data)) self.assertEqual(param1, param2) @@ -37,8 +47,14 @@ def test_byte_parameter(self): ByteArrayParameter([0, '1']) with self.assertRaises(TypeError): ByteArrayParameter([bytes([0, 1, 2, 3]), bytes([4, 5, 6, 7])]) + with self.assertRaises(TypeError): + ByteArrayParameter(ndarray(shape=[16], dtype=int8, buffer=array([int8(val) for val in int_data]))) + with self.assertRaises(TypeError): + ByteArrayParameter(ndarray(shape=[16], dtype=uint16, buffer=array([uint16(val) for val in int_data]))) with self.assertRaises(ValueError): ByteArrayParameter([]) + with self.assertRaises(ValueError): + ByteArrayParameter(ndarray(shape=[0], dtype=uint8, buffer=array([]))) with self.assertRaises(TypeError): ByteArrayParameter([0, 1, 2, -1]) with self.assertRaises(TypeError): @@ -47,13 +63,26 @@ def test_byte_parameter(self): def test_double_parameter(self): serialized_param = b'\x00\x00\x00\x00\x00\x00\xe0\xbf\x00\x00\x00\x00\x00\x00\xe0\x3f' \ b'\x00\x00\x00\x00\x80\x84\x2e\x41' - param = DoubleArrayParameter([-0.5, 0.5, 1e6]) - self.assertEqual(serialized_param, param.serialize()) - self.assertEqual(DoubleArrayParameter.deserialize(BytesIO(serialized_param), 3), param) + param1 = DoubleArrayParameter([-0.5, 0.5, 1e6]) + self.assertEqual(serialized_param, param1.serialize()) + self.assertEqual(DoubleArrayParameter.deserialize(BytesIO(serialized_param), 3), param1) + + param2 = DoubleArrayParameter(ndarray(shape=[3], dtype=double, buffer=array([-0.5, 0.5, 1e6]))) + self.assertEqual(param1, param2) # an array of only integers is still a valid value of a DoubleArrayParameter - param1 = DoubleArrayParameter([1, 2, 1000000]) - param2 = DoubleArrayParameter([1, 2.0, 1e6]) + param1 = DoubleArrayParameter([-1, 2, 1000000]) + param2 = DoubleArrayParameter([-1, 2.0, 1e6]) + self.assertEqual(param1, param2) + with self.assertWarns(UserWarning): + param1 = DoubleArrayParameter(ndarray(shape=[1, 3], dtype=int32, + buffer=array([-1, 2, 1000000]))) + self.assertEqual(param1, param2) + + with self.assertWarns(UserWarning): + param1 = DoubleArrayParameter(ndarray(shape=[1, 3], dtype=int64, + buffer=array([int64(val) for val in [-1, 2, 10000000000]]))) + param2 = DoubleArrayParameter([-1, 2.0, 1e10]) self.assertEqual(param1, param2) # a float array parameter is not the same as a double array parameter @@ -63,42 +92,82 @@ def test_double_parameter(self): with self.assertRaises(TypeError): DoubleArrayParameter([0.5, -0.5, 'NaN']) + with self.assertRaises(TypeError): + DoubleArrayParameter(ndarray(shape=[3], dtype=single, + buffer=array([single(val) for val in [-0.5, 0.5, 1e6]]))) with self.assertRaises(TypeError): DoubleArrayParameter(0.5) with self.assertRaises(ValueError): DoubleArrayParameter([]) + with self.assertRaises(ValueError): + IntegerArrayParameter(ndarray(shape=[0], dtype=double, buffer=array([]))) def test_float_parameter(self): serialized_param = b'\x00\x00\x00\xbf\x00\x00\x00\x3f\x00\x24\x74\x49' - param = FloatArrayParameter([-0.5, 0.5, 1e6]) - self.assertEqual(serialized_param, param.serialize()) - self.assertEqual(FloatArrayParameter.deserialize(BytesIO(serialized_param), 3), param) + param1 = FloatArrayParameter([-0.5, 0.5, 1e6]) + self.assertEqual(serialized_param, param1.serialize()) + self.assertEqual(FloatArrayParameter.deserialize(BytesIO(serialized_param), 3), param1) + + param2 = FloatArrayParameter(ndarray(shape=[3], dtype=single, + buffer=array([single(val) for val in [-0.5, 0.5, 1e6]]))) + self.assertEqual(param1, param2) # an array of only integers is still a valid value of a FloatArrayParameter - param1 = FloatArrayParameter([1, 2, 1000000]) - param2 = FloatArrayParameter([1, 2.0, 1e6]) + param1 = FloatArrayParameter([-1, 2, 1000000]) + param2 = FloatArrayParameter([-1, 2.0, 1e6]) + self.assertEqual(param1, param2) + + with self.assertWarns(UserWarning): + param1 = FloatArrayParameter(ndarray(shape=[1, 3], dtype=int32, + buffer=array([-1, 2, 1000000]))) + self.assertEqual(param1, param2) + + with self.assertWarns(UserWarning): + param1 = FloatArrayParameter(ndarray(shape=[1, 3], dtype=int64, + buffer=array([int64(val) for val in [-1, 2, 10000000000]]))) + param2 = FloatArrayParameter([-1, 2.0, 1e10]) self.assertEqual(param1, param2) with self.assertRaises(TypeError): FloatArrayParameter([0.5, -0.5, 'NaN']) + with self.assertRaises(TypeError): + FloatArrayParameter(ndarray(shape=[3], dtype=double, + buffer=array([double(val) for val in [-0.5, 0.5, 1e6]]))) with self.assertRaises(TypeError): FloatArrayParameter(0.5) with self.assertRaises(ValueError): FloatArrayParameter([]) + with self.assertRaises(ValueError): + IntegerArrayParameter(ndarray(shape=[0], dtype=single, buffer=array([]))) def test_integer_parameter(self): serialized_param = b'\xff\xff\xff\xff\x01\x00\x00\x00\xff\xff\xff\x7f\x00\x00\x00\x80' - param = IntegerArrayParameter([-1, 1, 0x7fffffff, -0x80000000]) - self.assertEqual(serialized_param, param.serialize()) - self.assertEqual(IntegerArrayParameter.deserialize(BytesIO(serialized_param), 4), param) + param1 = IntegerArrayParameter([-1, 1, 0x7fffffff, -0x80000000]) + self.assertEqual(serialized_param, param1.serialize()) + self.assertEqual(IntegerArrayParameter.deserialize(BytesIO(serialized_param), 4), param1) + + with self.assertWarns(UserWarning): + param2 = IntegerArrayParameter(ndarray(shape=[2, 2], dtype=int32, + buffer=array([-1, 1, 0x7fffffff, -0x80000000]))) + self.assertEqual(param1, param2) # a short array parameter is not the same as an int array parameter param1 = ShortArrayParameter([0, 1, -1, 255, 256, -32768, 32767]) param2 = IntegerArrayParameter([0, 1, -1, 255, 256, -32768, 32767]) self.assertNotEqual(param1, param2) + # verify that an integer array parameter based on a ndarray filled with int16s works + param1 = IntegerArrayParameter(ndarray(shape=[7], dtype=int16, + buffer=array([int16(val) for val in [0, 1, -1, 255, 256, -32768, 32767]]))) + self.assertEqual(param1, param2) + with self.assertRaises(TypeError): IntegerArrayParameter([1, 256, 1.0]) + with self.assertRaises(TypeError): + IntegerArrayParameter(ndarray(shape=[4], dtype=int64, + buffer=array([-1, 1, 0x7fffffffffffffff, -0x8000000000000000]))) + with self.assertRaises(ValueError): + IntegerArrayParameter(ndarray(shape=[0], dtype=int32, buffer=array([]))) with self.assertRaises(TypeError): IntegerArrayParameter(1) with self.assertRaises(ValueError): @@ -111,30 +180,51 @@ def test_integer_parameter(self): def test_long_parameter(self): serialized_param = b'\xff\xff\xff\xff\xff\xff\xff\xff\x01\x00\x00\x00\x00\x00\x00\x00' \ b'\xff\xff\xff\xff\xff\xff\xff\x7f\x00\x00\x00\x00\x00\x00\x00\x80' - param = LongArrayParameter([-1, 1, 0x7fffffffffffffff, -0x8000000000000000]) - self.assertEqual(serialized_param, param.serialize()) - self.assertEqual(LongArrayParameter.deserialize(BytesIO(serialized_param), 4), param) + param1= LongArrayParameter([-1, 1, 0x7fffffffffffffff, -0x8000000000000000]) + self.assertEqual(serialized_param, param1.serialize()) + self.assertEqual(LongArrayParameter.deserialize(BytesIO(serialized_param), 4), param1) + + with self.assertWarns(UserWarning): + param2 = LongArrayParameter(ndarray(shape=[2, 2], dtype=int64, + buffer=array([-1, 1, 0x7fffffffffffffff, -0x8000000000000000]))) + self.assertEqual(param1, param2) # an int array parameter is not the same as a long array parameter param1 = IntegerArrayParameter([-1, 1, 0x7fffffff, -0x80000000]) param2 = LongArrayParameter([-1, 1, 0x7fffffff, -0x80000000]) self.assertNotEqual(param1, param2) + # verify that a long array parameter based on a ndarray filled with int32s works + with self.assertWarns(UserWarning): + param1 = LongArrayParameter(ndarray(shape=[1, 4], dtype=int32, + buffer=array([-1, 1, 0x7fffffff, -0x80000000]))) + self.assertEqual(param1, param2) + with self.assertRaises(TypeError): LongArrayParameter([1, 256, 1.0]) with self.assertRaises(TypeError): LongArrayParameter(1) with self.assertRaises(ValueError): LongArrayParameter([]) + with self.assertRaises(ValueError): + LongArrayParameter(ndarray(shape=[0], dtype=int64, buffer=array([]))) def test_short_parameter(self): serialized_param = b'\x00\x00\x01\x00\xff\xff\xff\x00\x00\x01\x00\x80\xff\x7f' - param = ShortArrayParameter([0, 1, -1, 255, 256, -32768, 32767]) - self.assertEqual(serialized_param, param.serialize()) - self.assertEqual(ShortArrayParameter.deserialize(BytesIO(serialized_param), 7), param) + param1 = ShortArrayParameter([0, 1, -1, 255, 256, -32768, 32767]) + self.assertEqual(serialized_param, param1.serialize()) + self.assertEqual(ShortArrayParameter.deserialize(BytesIO(serialized_param), 7), param1) + + param2 = ShortArrayParameter(ndarray(shape=[7], dtype=int16, + buffer=array([int16(val) for val in [0, 1, -1, 255, 256, -32768, 32767]]))) + self.assertEqual(param1, param2) with self.assertRaises(TypeError): ShortArrayParameter([1, 256, 1.0]) + with self.assertRaises(TypeError): + ShortArrayParameter(ndarray(shape=[4], dtype=int32, buffer=array([-1, 1, 0x7fffffff, -0x80000000]))) + with self.assertRaises(ValueError): + ShortArrayParameter(ndarray(shape=[0], dtype=int16, buffer=array([]))) with self.assertRaises(TypeError): ShortArrayParameter(1) with self.assertRaises(ValueError): @@ -159,5 +249,3 @@ def test_string_parameter(self): StringParameter(['The', 'quick', 'brown', 'fox', 'jumped', 'over', 'the', 'lazy', 'dog']) with self.assertRaises(ValueError): StringParameter(None) - - diff --git a/trsfile/traceparameter.py b/trsfile/traceparameter.py index 469f0ad..29237e0 100644 --- a/trsfile/traceparameter.py +++ b/trsfile/traceparameter.py @@ -1,6 +1,7 @@ from __future__ import annotations import struct +import warnings from abc import ABC, abstractmethod from enum import Enum from io import BytesIO @@ -37,7 +38,11 @@ def _has_expected_type(value: Any) -> bool: pass def __init__(self, value): - if type(value) is not str and (value is None or len(value) <= 0): + if type(value) is ndarray and len(value.shape) > 1: + warnings.warn("Flatting multi-dimensional ndarray before adding it to trace parameter.\n" + "Information about dimensions of this ndarray will be lost.") + value = value.flatten() + if value is None or ((type(value) is list or type(value) is ndarray) and len(value) <= 0): raise ValueError('The value for a TraceParameter cannot be empty') if not type(self)._has_expected_type(value): raise TypeError(f'A {type(self).__name__} must have a value of type "{type(self)._expected_type_string}"' @@ -48,7 +53,13 @@ def __len__(self): return len(self.value) def __eq__(self, other): - return isinstance(other, type(self)) and self.value == other.value + if not isinstance(other, type(self)): + return False + if (type(self.value) == list or type(self.value) == ndarray) and \ + (type(other.value) == list or type(other.value) == ndarray): + return all(this_val == that_val for (this_val, that_val) in zip(self.value, other.value)) + else: + return self.value == other.value def __str__(self): return str(self.value) From 76da8102c4abe339f0ec90bbdcca69a561287e02 Mon Sep 17 00:00:00 2001 From: Tom Hogervorst Date: Tue, 8 Nov 2022 12:56:42 +0100 Subject: [PATCH 7/8] Fix failing testcases More explicitly set types of ndarray input values --- tests/test_parameter.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/test_parameter.py b/tests/test_parameter.py index fc5e94b..45636a9 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -76,7 +76,7 @@ def test_double_parameter(self): self.assertEqual(param1, param2) with self.assertWarns(UserWarning): param1 = DoubleArrayParameter(ndarray(shape=[1, 3], dtype=int32, - buffer=array([-1, 2, 1000000]))) + buffer=array([int32(val) for val in [-1, 2, 1000000]]))) self.assertEqual(param1, param2) with self.assertWarns(UserWarning): @@ -119,7 +119,7 @@ def test_float_parameter(self): with self.assertWarns(UserWarning): param1 = FloatArrayParameter(ndarray(shape=[1, 3], dtype=int32, - buffer=array([-1, 2, 1000000]))) + buffer=array([int32(val) for val in [-1, 2, 1000000]]))) self.assertEqual(param1, param2) with self.assertWarns(UserWarning): @@ -148,7 +148,7 @@ def test_integer_parameter(self): with self.assertWarns(UserWarning): param2 = IntegerArrayParameter(ndarray(shape=[2, 2], dtype=int32, - buffer=array([-1, 1, 0x7fffffff, -0x80000000]))) + buffer=array([int32(val) for val in [-1, 1, 0x7fffffff, -0x80000000]]))) self.assertEqual(param1, param2) # a short array parameter is not the same as an int array parameter @@ -165,7 +165,7 @@ def test_integer_parameter(self): IntegerArrayParameter([1, 256, 1.0]) with self.assertRaises(TypeError): IntegerArrayParameter(ndarray(shape=[4], dtype=int64, - buffer=array([-1, 1, 0x7fffffffffffffff, -0x8000000000000000]))) + buffer=array([int64(val) for val in [-1, 1, 0x7fffffffffffffff, -0x8000000000000000]]))) with self.assertRaises(ValueError): IntegerArrayParameter(ndarray(shape=[0], dtype=int32, buffer=array([]))) with self.assertRaises(TypeError): @@ -186,7 +186,7 @@ def test_long_parameter(self): with self.assertWarns(UserWarning): param2 = LongArrayParameter(ndarray(shape=[2, 2], dtype=int64, - buffer=array([-1, 1, 0x7fffffffffffffff, -0x8000000000000000]))) + buffer=array([int64(val) for val in [-1, 1, 0x7fffffffffffffff, -0x8000000000000000]]))) self.assertEqual(param1, param2) # an int array parameter is not the same as a long array parameter @@ -197,7 +197,7 @@ def test_long_parameter(self): # verify that a long array parameter based on a ndarray filled with int32s works with self.assertWarns(UserWarning): param1 = LongArrayParameter(ndarray(shape=[1, 4], dtype=int32, - buffer=array([-1, 1, 0x7fffffff, -0x80000000]))) + buffer=array([int32(val) for val in [-1, 1, 0x7fffffff, -0x80000000]]))) self.assertEqual(param1, param2) with self.assertRaises(TypeError): @@ -222,7 +222,8 @@ def test_short_parameter(self): with self.assertRaises(TypeError): ShortArrayParameter([1, 256, 1.0]) with self.assertRaises(TypeError): - ShortArrayParameter(ndarray(shape=[4], dtype=int32, buffer=array([-1, 1, 0x7fffffff, -0x80000000]))) + ShortArrayParameter(ndarray(shape=[4], dtype=int32, + buffer=array([int32(val) for val in [-1, 1, 0x7fffffff, -0x80000000]]))) with self.assertRaises(ValueError): ShortArrayParameter(ndarray(shape=[0], dtype=int16, buffer=array([]))) with self.assertRaises(TypeError): From a422c0b82586356de79c036ba61a576227f09b23 Mon Sep 17 00:00:00 2001 From: Tom Hogervorst Date: Thu, 17 Nov 2022 09:38:35 +0100 Subject: [PATCH 8/8] Address review comments --- trsfile/parametermap.py | 25 ++++++++++++------------- trsfile/traceparameter.py | 11 ++++++----- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/trsfile/parametermap.py b/trsfile/parametermap.py index 5076279..a0b9412 100644 --- a/trsfile/parametermap.py +++ b/trsfile/parametermap.py @@ -303,13 +303,13 @@ def __setitem__(self, key: str, value: TraceParameterDefinition): def insert_std(self, name: str, size: int, offset: int) -> TraceParameterDefinitionMap: """Insert a trace parameter definition of a StandardTraceParameter into this map in a specified location. If the given offset would put the new TraceParameter in the middle of a parameter already present in the map, the - offset is increased to put add the new parameter after that existing one instead. Any parameters already present + offset is increased to put the new parameter after that existing one instead. Any parameters already present in the map that have a greater or equal offset than the new parameter, have their offset increased to make space for the new parameter. :param name: The name of the TraceParameter for which to add a definition. This name must match that of a StandardTraceParameter - :param size: The size of the TraceParameter, in number of values of its type + :param size: The size of the TraceParameter, in number of elements :param offset: The offset of the TraceParameter, in bytes :return: This TraceParameterDefinition map after adding the new definition """ @@ -317,18 +317,18 @@ def insert_std(self, name: str, size: int, offset: int) -> TraceParameterDefinit type = StandardTraceParameters.from_identifier(name).parameter_type return self.insert(name, type, size, offset) except ValueError: - raise ValueError(f"No StandardTraceParameter found with name '{name}'. Either specify a type in this " - f"insert or correct the name to match a standard trace parameter.") + raise ValueError(f"No StandardTraceParameter found with name '{name}'. Either use the 'insert' method or " + f"correct the name to match a standard trace parameter.") def insert(self, name: str, type: ParameterType, size: int, offset: int) -> TraceParameterDefinitionMap: """Insert a trace parameter definition into this map in a specified location. If the given offset would put the - new TraceParameter in the middle of a parameter already present in the map, the offset is increased to put add - the new parameter after that existing one instead. Any parameters already present in the map that have a greater + new TraceParameter in the middle of a parameter already present in the map, the offset is increased to put the + new parameter after that existing one instead. Any parameters already present in the map that have a greater or equal offset than the new parameter, have their offset increased to make space for the new parameter. :param name: The name of the TraceParameter for which to add a definition :param type: The type of the TraceParameter for which to add a definition - :param size: The size of the TraceParameter, in number of values of its type + :param size: The size of the TraceParameter, in number of elements :param offset: The offset of the TraceParameter, in bytes :return: This TraceParameterDefinition map after adding the new definition """ @@ -355,14 +355,14 @@ def append_std(self, name: str, size: int) -> TraceParameterDefinitionMap: :param name: The name of the TraceParameter for which to add a definition. This name must match that of a StandardTraceParameter - :param size: The size of the TraceParameter, in number of values of its type + :param size: The size of the TraceParameter, in number of elements :return: This TraceParameterDefinition map after adding the new definition""" try: type = StandardTraceParameters.from_identifier(name).parameter_type return self.append(name, type, size) except ValueError: - raise ValueError(f"No StandardTraceParameter found with name '{name}'. Either specify a type in this " - f"append or correct the name to match a standard trace parameter.") + raise ValueError(f"No StandardTraceParameter found with name '{name}'. Either use the 'append' method or " + f"correct the name to match a standard trace parameter.") def append(self, name: str, type: ParameterType, size: int) -> TraceParameterDefinitionMap: """Append a trace parameter definition to this map. The parameter wil be added after all parameter definitions @@ -370,7 +370,7 @@ def append(self, name: str, type: ParameterType, size: int) -> TraceParameterDef :param name: The name of the TraceParameter for which to add a definition :param type: The type of the TraceParameter for which to add a definition - :param size: The size of the TraceParameter, in number of values of its type + :param size: The size of the TraceParameter, in number of elements :return: This TraceParameterDefinition map after adding the new definition""" new_definition = TraceParameterDefinition(type, size, self.get_total_size()) self.__setitem__(name, new_definition) @@ -522,8 +522,7 @@ def __setitem__(self, key: str, value: TraceParameter): def matches(self, definitions: TraceParameterDefinitionMap) -> bool: """Test whether this RawTraceData could be interpreted by given definitions - :param definitions: The trace parameter definition map of the trs file to which the trace with this trace - raw trace data will be added + :param definitions: The trace parameter definition map to check this raw trace data against :return: A boolean that is true if the trace parameter definitions can interpret this raw trace data""" return definitions.get_total_size() == len(self["LEGACY_DATA"]) diff --git a/trsfile/traceparameter.py b/trsfile/traceparameter.py index 29237e0..494115d 100644 --- a/trsfile/traceparameter.py +++ b/trsfile/traceparameter.py @@ -55,11 +55,12 @@ def __len__(self): def __eq__(self, other): if not isinstance(other, type(self)): return False - if (type(self.value) == list or type(self.value) == ndarray) and \ - (type(other.value) == list or type(other.value) == ndarray): - return all(this_val == that_val for (this_val, that_val) in zip(self.value, other.value)) - else: - return self.value == other.value + + if len(self.value) != len(other.value): + return False + # return true only if both parameter value arrays contain the same elements in the same order, + # regardless of whether it is an ndarray or a list + return all(this_val == that_val for (this_val, that_val) in zip(self.value, other.value)) def __str__(self): return str(self.value)