diff --git a/tests/test_creation.py b/tests/test_creation.py index 7805afe..ff0769a 100644 --- a/tests/test_creation.py +++ b/tests/test_creation.py @@ -7,7 +7,7 @@ import shutil from trsfile import Trace, SampleCoding, Header, TracePadding -from trsfile.parametermap import TraceParameterMap, TraceParameterDefinitionMap, TraceSetParameterMap +from trsfile.parametermap import TraceParameterMap, TraceParameterDefinitionMap, TraceSetParameterMap, RawTraceData from trsfile.standardparameters import StandardTraceSetParameters from trsfile.traceparameter import ByteArrayParameter, TraceParameterDefinition, ParameterType, StringParameter, \ IntegerArrayParameter, BooleanArrayParameter, FloatArrayParameter @@ -178,6 +178,20 @@ def test_header_to_trace_set_params(self): except Exception as e: self.fail('Exception occurred: ' + str(e)) + def test_write_different_trace_sizes(self): + trace_count = 100 + sample_count = 1000 + + with trsfile.open(self.tmp_path, 'w', padding_mode=TracePadding.AUTO) as trs_traces: + trs_traces.extend([ + Trace( + SampleCoding.FLOAT, + [0] * sample_count, + RawTraceData(i.to_bytes(8, byteorder='big')) + ) + for i in range(0, trace_count)] + ) + def test_write_closed(self): trace_count = 100 sample_count = 1000 @@ -196,6 +210,75 @@ def test_write_closed(self): with self.assertRaises(ValueError): print(trs_traces) + def test_write_different_trace_sizes(self): + trace_count = 100 + sample_count = 1000 + + with trsfile.open(self.tmp_path, 'w', padding_mode=TracePadding.AUTO) as trs_traces: + trs_traces.extend([ + Trace( + SampleCoding.FLOAT, + [0] * sample_count, + TraceParameterMap({'LEGACY_DATA': ByteArrayParameter(i.to_bytes(8, byteorder='big'))}) + ) + for i in range(0, trace_count)] + ) + with self.assertRaises(TypeError): + # The length is incorrect + # Should raise a Type error: The parameters of trace #0 do not match the trace set's definitions. + trs_traces.extend([ + Trace( + SampleCoding.FLOAT, + [0] * sample_count, + TraceParameterMap({'LEGACY_DATA': ByteArrayParameter(bytes.fromhex('cafebabedeadbeef0102030405060708'))}) + )] + ) + with self.assertRaises(TypeError): + # The name is incorrect + # Should raise a Type error: The parameters of trace #1 do not match the trace set's definitions. + trs_traces.extend([ + Trace( + SampleCoding.FLOAT, + [0] * sample_count, + TraceParameterMap({'LEGACY_DATA': ByteArrayParameter(bytes.fromhex('0102030405060708'))}) + ), + Trace( + SampleCoding.FLOAT, + [0] * sample_count, + TraceParameterMap({'NEW_DATA': ByteArrayParameter(bytes.fromhex('0102030405060708'))}) + )] + ) + with self.assertRaises(TypeError): + # The type is incorrect + # Should raise a Type error: The parameters of trace #0 do not match the trace set's definitions. + trs_traces.extend([ + Trace( + SampleCoding.FLOAT, + [0] * sample_count, + TraceParameterMap({'LEGACY_DATA': IntegerArrayParameter([42, 74])}) + )] + ) + + with trsfile.open(self.tmp_path, 'w', padding_mode=TracePadding.AUTO) as trs_traces: + trs_traces.extend([ + Trace( + SampleCoding.FLOAT, + [0] * sample_count, + TraceParameterMap() + ) + for i in range(0, trace_count)] + ) + with self.assertRaises(TypeError): + # The length, data and name are incorrect + # Should raise a Type error: The parameters of trace #0 do not match the trace set's definitions. + trs_traces.extend([ + Trace( + SampleCoding.FLOAT, + [0] * sample_count, + TraceParameterMap({'LEGACY_DATA': ByteArrayParameter(bytes.fromhex('cafebabedeadbeef0102030405060708'))}) + )] + ) + def test_read(self): trace_count = 100 sample_count = 1000 @@ -270,7 +353,7 @@ def test_append(self): Trace( SampleCoding.FLOAT, [0] * sample_count, - TraceParameterMap({'LEGACY_DATA': ByteArrayParameter(i.to_bytes(8, byteorder='big'))}) + raw_data=i.to_bytes(8, byteorder='big') ) for i in range(0, trace_count)] ) diff --git a/tests/test_header.py b/tests/test_header.py index 5cf5834..7e567cd 100644 --- a/tests/test_header.py +++ b/tests/test_header.py @@ -90,6 +90,10 @@ def test_trace_param_defs_append_errors(self): trace_parameter_definitions.popitem() with self.assertRaises(TypeError): trace_parameter_definitions.clear() + with self.assertRaises(TypeError): + trace_parameter_definitions.append('input', ParameterType.BYTE, 16) + with self.assertRaises(TypeError): + trace_parameter_definitions.insert('output', ParameterType.BYTE, 16, 0) # Shallow copies still share references to the same trace set parameters as the original, # and should therefore not be modifiable if the original isn't diff --git a/tests/test_parameter.py b/tests/test_parameter.py index fff1782..45636a9 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -1,6 +1,8 @@ from io import BytesIO from unittest import TestCase +from numpy import ndarray, int16, array, int32, int64, single, double, uint8, int8, uint16, bool8 + from trsfile.traceparameter import BooleanArrayParameter, ByteArrayParameter, DoubleArrayParameter, FloatArrayParameter, \ IntegerArrayParameter, ShortArrayParameter, LongArrayParameter, StringParameter @@ -8,9 +10,12 @@ class TestParameter(TestCase): def test_bool_parameter(self): serialized_param = b'\x01\x00\x01' - param = BooleanArrayParameter([True, False, True]) - self.assertEqual(serialized_param, param.serialize()) - self.assertEqual(BooleanArrayParameter.deserialize(BytesIO(serialized_param), 3), param) + param1 = BooleanArrayParameter([True, False, True]) + self.assertEqual(serialized_param, param1.serialize()) + self.assertEqual(BooleanArrayParameter.deserialize(BytesIO(serialized_param), 3), param1) + param2 = BooleanArrayParameter(ndarray(shape=[3], dtype=bool8, + buffer=array([bool8(val) for val in [True, False, True]]))) + self.assertEqual(param1, param2) with self.assertRaises(TypeError): BooleanArrayParameter(True) @@ -27,6 +32,11 @@ def test_byte_parameter(self): self.assertEqual(serialized_param, param1.serialize()) self.assertEqual(ByteArrayParameter.deserialize(BytesIO(serialized_param), 16), param1) + with self.assertWarns(UserWarning): + param2 = ByteArrayParameter(ndarray(shape=[2, 2, 4], dtype=uint8, + buffer=array([uint8(val) for val in int_data]))) + self.assertEqual(param1, param2) + param2 = ByteArrayParameter(bytearray(int_data)) self.assertEqual(param1, param2) @@ -37,8 +47,14 @@ def test_byte_parameter(self): ByteArrayParameter([0, '1']) with self.assertRaises(TypeError): ByteArrayParameter([bytes([0, 1, 2, 3]), bytes([4, 5, 6, 7])]) + with self.assertRaises(TypeError): + ByteArrayParameter(ndarray(shape=[16], dtype=int8, buffer=array([int8(val) for val in int_data]))) + with self.assertRaises(TypeError): + ByteArrayParameter(ndarray(shape=[16], dtype=uint16, buffer=array([uint16(val) for val in int_data]))) with self.assertRaises(ValueError): ByteArrayParameter([]) + with self.assertRaises(ValueError): + ByteArrayParameter(ndarray(shape=[0], dtype=uint8, buffer=array([]))) with self.assertRaises(TypeError): ByteArrayParameter([0, 1, 2, -1]) with self.assertRaises(TypeError): @@ -47,13 +63,26 @@ def test_byte_parameter(self): def test_double_parameter(self): serialized_param = b'\x00\x00\x00\x00\x00\x00\xe0\xbf\x00\x00\x00\x00\x00\x00\xe0\x3f' \ b'\x00\x00\x00\x00\x80\x84\x2e\x41' - param = DoubleArrayParameter([-0.5, 0.5, 1e6]) - self.assertEqual(serialized_param, param.serialize()) - self.assertEqual(DoubleArrayParameter.deserialize(BytesIO(serialized_param), 3), param) + param1 = DoubleArrayParameter([-0.5, 0.5, 1e6]) + self.assertEqual(serialized_param, param1.serialize()) + self.assertEqual(DoubleArrayParameter.deserialize(BytesIO(serialized_param), 3), param1) + + param2 = DoubleArrayParameter(ndarray(shape=[3], dtype=double, buffer=array([-0.5, 0.5, 1e6]))) + self.assertEqual(param1, param2) # an array of only integers is still a valid value of a DoubleArrayParameter - param1 = DoubleArrayParameter([1, 2, 1000000]) - param2 = DoubleArrayParameter([1, 2.0, 1e6]) + param1 = DoubleArrayParameter([-1, 2, 1000000]) + param2 = DoubleArrayParameter([-1, 2.0, 1e6]) + self.assertEqual(param1, param2) + with self.assertWarns(UserWarning): + param1 = DoubleArrayParameter(ndarray(shape=[1, 3], dtype=int32, + buffer=array([int32(val) for val in [-1, 2, 1000000]]))) + self.assertEqual(param1, param2) + + with self.assertWarns(UserWarning): + param1 = DoubleArrayParameter(ndarray(shape=[1, 3], dtype=int64, + buffer=array([int64(val) for val in [-1, 2, 10000000000]]))) + param2 = DoubleArrayParameter([-1, 2.0, 1e10]) self.assertEqual(param1, param2) # a float array parameter is not the same as a double array parameter @@ -63,42 +92,82 @@ def test_double_parameter(self): with self.assertRaises(TypeError): DoubleArrayParameter([0.5, -0.5, 'NaN']) + with self.assertRaises(TypeError): + DoubleArrayParameter(ndarray(shape=[3], dtype=single, + buffer=array([single(val) for val in [-0.5, 0.5, 1e6]]))) with self.assertRaises(TypeError): DoubleArrayParameter(0.5) with self.assertRaises(ValueError): DoubleArrayParameter([]) + with self.assertRaises(ValueError): + IntegerArrayParameter(ndarray(shape=[0], dtype=double, buffer=array([]))) def test_float_parameter(self): serialized_param = b'\x00\x00\x00\xbf\x00\x00\x00\x3f\x00\x24\x74\x49' - param = FloatArrayParameter([-0.5, 0.5, 1e6]) - self.assertEqual(serialized_param, param.serialize()) - self.assertEqual(FloatArrayParameter.deserialize(BytesIO(serialized_param), 3), param) + param1 = FloatArrayParameter([-0.5, 0.5, 1e6]) + self.assertEqual(serialized_param, param1.serialize()) + self.assertEqual(FloatArrayParameter.deserialize(BytesIO(serialized_param), 3), param1) + + param2 = FloatArrayParameter(ndarray(shape=[3], dtype=single, + buffer=array([single(val) for val in [-0.5, 0.5, 1e6]]))) + self.assertEqual(param1, param2) # an array of only integers is still a valid value of a FloatArrayParameter - param1 = FloatArrayParameter([1, 2, 1000000]) - param2 = FloatArrayParameter([1, 2.0, 1e6]) + param1 = FloatArrayParameter([-1, 2, 1000000]) + param2 = FloatArrayParameter([-1, 2.0, 1e6]) + self.assertEqual(param1, param2) + + with self.assertWarns(UserWarning): + param1 = FloatArrayParameter(ndarray(shape=[1, 3], dtype=int32, + buffer=array([int32(val) for val in [-1, 2, 1000000]]))) + self.assertEqual(param1, param2) + + with self.assertWarns(UserWarning): + param1 = FloatArrayParameter(ndarray(shape=[1, 3], dtype=int64, + buffer=array([int64(val) for val in [-1, 2, 10000000000]]))) + param2 = FloatArrayParameter([-1, 2.0, 1e10]) self.assertEqual(param1, param2) with self.assertRaises(TypeError): FloatArrayParameter([0.5, -0.5, 'NaN']) + with self.assertRaises(TypeError): + FloatArrayParameter(ndarray(shape=[3], dtype=double, + buffer=array([double(val) for val in [-0.5, 0.5, 1e6]]))) with self.assertRaises(TypeError): FloatArrayParameter(0.5) with self.assertRaises(ValueError): FloatArrayParameter([]) + with self.assertRaises(ValueError): + IntegerArrayParameter(ndarray(shape=[0], dtype=single, buffer=array([]))) def test_integer_parameter(self): serialized_param = b'\xff\xff\xff\xff\x01\x00\x00\x00\xff\xff\xff\x7f\x00\x00\x00\x80' - param = IntegerArrayParameter([-1, 1, 0x7fffffff, -0x80000000]) - self.assertEqual(serialized_param, param.serialize()) - self.assertEqual(IntegerArrayParameter.deserialize(BytesIO(serialized_param), 4), param) + param1 = IntegerArrayParameter([-1, 1, 0x7fffffff, -0x80000000]) + self.assertEqual(serialized_param, param1.serialize()) + self.assertEqual(IntegerArrayParameter.deserialize(BytesIO(serialized_param), 4), param1) + + with self.assertWarns(UserWarning): + param2 = IntegerArrayParameter(ndarray(shape=[2, 2], dtype=int32, + buffer=array([int32(val) for val in [-1, 1, 0x7fffffff, -0x80000000]]))) + self.assertEqual(param1, param2) # a short array parameter is not the same as an int array parameter param1 = ShortArrayParameter([0, 1, -1, 255, 256, -32768, 32767]) param2 = IntegerArrayParameter([0, 1, -1, 255, 256, -32768, 32767]) self.assertNotEqual(param1, param2) + # verify that an integer array parameter based on a ndarray filled with int16s works + param1 = IntegerArrayParameter(ndarray(shape=[7], dtype=int16, + buffer=array([int16(val) for val in [0, 1, -1, 255, 256, -32768, 32767]]))) + self.assertEqual(param1, param2) + with self.assertRaises(TypeError): IntegerArrayParameter([1, 256, 1.0]) + with self.assertRaises(TypeError): + IntegerArrayParameter(ndarray(shape=[4], dtype=int64, + buffer=array([int64(val) for val in [-1, 1, 0x7fffffffffffffff, -0x8000000000000000]]))) + with self.assertRaises(ValueError): + IntegerArrayParameter(ndarray(shape=[0], dtype=int32, buffer=array([]))) with self.assertRaises(TypeError): IntegerArrayParameter(1) with self.assertRaises(ValueError): @@ -111,30 +180,52 @@ def test_integer_parameter(self): def test_long_parameter(self): serialized_param = b'\xff\xff\xff\xff\xff\xff\xff\xff\x01\x00\x00\x00\x00\x00\x00\x00' \ b'\xff\xff\xff\xff\xff\xff\xff\x7f\x00\x00\x00\x00\x00\x00\x00\x80' - param = LongArrayParameter([-1, 1, 0x7fffffffffffffff, -0x8000000000000000]) - self.assertEqual(serialized_param, param.serialize()) - self.assertEqual(LongArrayParameter.deserialize(BytesIO(serialized_param), 4), param) + param1= LongArrayParameter([-1, 1, 0x7fffffffffffffff, -0x8000000000000000]) + self.assertEqual(serialized_param, param1.serialize()) + self.assertEqual(LongArrayParameter.deserialize(BytesIO(serialized_param), 4), param1) + + with self.assertWarns(UserWarning): + param2 = LongArrayParameter(ndarray(shape=[2, 2], dtype=int64, + buffer=array([int64(val) for val in [-1, 1, 0x7fffffffffffffff, -0x8000000000000000]]))) + self.assertEqual(param1, param2) # an int array parameter is not the same as a long array parameter param1 = IntegerArrayParameter([-1, 1, 0x7fffffff, -0x80000000]) param2 = LongArrayParameter([-1, 1, 0x7fffffff, -0x80000000]) self.assertNotEqual(param1, param2) + # verify that a long array parameter based on a ndarray filled with int32s works + with self.assertWarns(UserWarning): + param1 = LongArrayParameter(ndarray(shape=[1, 4], dtype=int32, + buffer=array([int32(val) for val in [-1, 1, 0x7fffffff, -0x80000000]]))) + self.assertEqual(param1, param2) + with self.assertRaises(TypeError): LongArrayParameter([1, 256, 1.0]) with self.assertRaises(TypeError): LongArrayParameter(1) with self.assertRaises(ValueError): LongArrayParameter([]) + with self.assertRaises(ValueError): + LongArrayParameter(ndarray(shape=[0], dtype=int64, buffer=array([]))) def test_short_parameter(self): serialized_param = b'\x00\x00\x01\x00\xff\xff\xff\x00\x00\x01\x00\x80\xff\x7f' - param = ShortArrayParameter([0, 1, -1, 255, 256, -32768, 32767]) - self.assertEqual(serialized_param, param.serialize()) - self.assertEqual(ShortArrayParameter.deserialize(BytesIO(serialized_param), 7), param) + param1 = ShortArrayParameter([0, 1, -1, 255, 256, -32768, 32767]) + self.assertEqual(serialized_param, param1.serialize()) + self.assertEqual(ShortArrayParameter.deserialize(BytesIO(serialized_param), 7), param1) + + param2 = ShortArrayParameter(ndarray(shape=[7], dtype=int16, + buffer=array([int16(val) for val in [0, 1, -1, 255, 256, -32768, 32767]]))) + self.assertEqual(param1, param2) with self.assertRaises(TypeError): ShortArrayParameter([1, 256, 1.0]) + with self.assertRaises(TypeError): + ShortArrayParameter(ndarray(shape=[4], dtype=int32, + buffer=array([int32(val) for val in [-1, 1, 0x7fffffff, -0x80000000]]))) + with self.assertRaises(ValueError): + ShortArrayParameter(ndarray(shape=[0], dtype=int16, buffer=array([]))) with self.assertRaises(TypeError): ShortArrayParameter(1) with self.assertRaises(ValueError): @@ -159,5 +250,3 @@ def test_string_parameter(self): StringParameter(['The', 'quick', 'brown', 'fox', 'jumped', 'over', 'the', 'lazy', 'dog']) with self.assertRaises(ValueError): StringParameter(None) - - diff --git a/tests/test_parametermap.py b/tests/test_parametermap.py index 676c980..6e184c5 100644 --- a/tests/test_parametermap.py +++ b/tests/test_parametermap.py @@ -1,6 +1,6 @@ from unittest import TestCase -from trsfile.parametermap import TraceSetParameterMap, TraceParameterDefinitionMap, TraceParameterMap +from trsfile.parametermap import TraceSetParameterMap, TraceParameterDefinitionMap, TraceParameterMap, RawTraceData from trsfile.standardparameters import StandardTraceSetParameters, StandardTraceParameters from trsfile.traceparameter import * @@ -76,9 +76,15 @@ def test_add_standard_parameter(self): param_map1 = TraceSetParameterMap() param_map1.add_standard_parameter(StandardTraceSetParameters.KEY, bytes.fromhex('cafebabedeadbeef0102030405060708')) + param_map1.add_standard_parameter(StandardTraceSetParameters.TVLA_CIPHER, "AES") param_map2 = TraceSetParameterMap() param_map2.add_parameter('KEY', bytes.fromhex('cafebabedeadbeef0102030405060708')) + param_map2.add_parameter('TVLA:CIPHER', "AES") + param_map3 = TraceSetParameterMap() + param_map3.add('key', bytes.fromhex('cafebabedeadbeef0102030405060708')) + param_map3.add('tvla_cipher', "AES") self.assertDictEqual(param_map1, param_map2) + self.assertDictEqual(param_map1, param_map3) # Verify that standard trace set parameters enforce a specific type with self.assertRaises(TypeError): @@ -103,6 +109,14 @@ def create_parameterdefinitionmap() -> TraceParameterDefinitionMap: param_map['中文'] = TraceParameterDefinition(ParameterType.STRING, 15, 29) return param_map + @staticmethod + def create_std_parameterdefinitionmap() -> TraceParameterDefinitionMap: + param_map = TraceParameterDefinitionMap() + param_map['INPUT'] = TraceParameterDefinition(ParameterType.BYTE, 16, 0) + param_map['OUTPUT'] = TraceParameterDefinition(ParameterType.BYTE, 16, 16) + param_map['KEY'] = TraceParameterDefinition(ParameterType.BYTE, 16, 32) + return param_map + @staticmethod def create_traceparametermap() -> TraceParameterMap: param_map = TraceParameterMap() @@ -128,6 +142,34 @@ def test_from_trace_params(self): map_from_trace_params = TraceParameterDefinitionMap.from_trace_parameter_map(param_map) self.assertDictEqual(self.create_parameterdefinitionmap(), map_from_trace_params) + def test_append(self): + map_from_append = TraceParameterDefinitionMap() + map_from_append.append('IN', ParameterType.BYTE, 16) + map_from_append.append('TITLE', ParameterType.STRING, 13) + map_from_append.append('中文', ParameterType.STRING, 15) + self.assertDictEqual(self.create_parameterdefinitionmap(), map_from_append) + + map_from_std_append = TraceParameterDefinitionMap() + map_from_std_append.append_std('INPUT', 16) + map_from_std_append.append_std('OUTPUT', 16) + map_from_std_append.append_std('KEY', 16) + self.assertDictEqual(self.create_std_parameterdefinitionmap(), map_from_std_append) + + def test_insert(self): + map_from_insert = TraceParameterDefinitionMap() + map_from_insert.insert('TITLE', ParameterType.STRING, 13, 0) + with self.assertWarns(UserWarning): + map_from_insert.insert('中文', ParameterType.STRING, 15, 10) + map_from_insert.insert('IN', ParameterType.BYTE, 16, 0) + self.assertDictEqual(self.create_parameterdefinitionmap(), map_from_insert) + + map_from_std_insert = TraceParameterDefinitionMap() + map_from_std_insert.insert_std('INPUT', 16, 0) + with self.assertWarns(UserWarning): + map_from_std_insert.insert_std('KEY', 16, 9) + map_from_std_insert.insert_std('OUTPUT', 16, 16) + self.assertDictEqual(self.create_std_parameterdefinitionmap(), map_from_std_insert) + class TestTraceParameterMap(TestCase): CAFEBABE = bytes.fromhex('cafebabedeadbeef0102030405060708') @@ -199,9 +241,15 @@ def test_add_standard_parameter(self): param_map1 = TraceParameterMap() param_map1.add_standard_parameter(StandardTraceParameters.INPUT, bytes.fromhex('cafebabedeadbeef0102030405060708')) + param_map1.add_standard_parameter(StandardTraceParameters.TVLA_SET_INDEX, 1) param_map2 = TraceParameterMap() param_map2.add_parameter('INPUT', bytes.fromhex('cafebabedeadbeef0102030405060708')) + param_map2.add_parameter('TVLA:SET_INDEX', 1) + param_map3 = TraceParameterMap() + param_map3.add("input", bytes.fromhex('cafebabedeadbeef0102030405060708')) + param_map3.add("tvla_set_index", 1) self.assertDictEqual(param_map1, param_map2) + self.assertDictEqual(param_map1, param_map3) # Verify that standard trace parameters enforce a specific type with self.assertRaises(TypeError): @@ -210,3 +258,29 @@ def test_add_standard_parameter(self): # However, this type check only produces a warning with self.assertWarns(UserWarning): param_map1.add_parameter('INPUT', 'cafebabedeadbeef0102030405060708') + + def test_raw_trace_data(self): + raw_data = RawTraceData(bytes.fromhex('cafebabedeadbeef0102030405060708')) + assert raw_data.serialize() == bytes.fromhex('cafebabedeadbeef0102030405060708') + + # Verify that nothing can be added into a raw data TraceParameterMap + with self.assertRaises(KeyError): + raw_data['INPUT'] = ByteArrayParameter(bytes.fromhex('cafebabedeadbeef0102030405060708')) + with self.assertRaises(KeyError): + raw_data.add('input', bytes.fromhex('cafebabedeadbeef0102030405060708')) + + # Verify that raw data can match any traceparameterdefinition map, as long as the length is correct + traceParamDefs = TraceParameterDefinitionMap() + traceParamDefs.append("INPUT", ParameterType.BYTE, 16) + assert raw_data.matches(traceParamDefs) + + traceParamDefs = TraceParameterDefinitionMap() + traceParamDefs.append("INPUT", ParameterType.BYTE, 8) + traceParamDefs.append("OUTPUT", ParameterType.BYTE, 8) + assert raw_data.matches(traceParamDefs) + traceParamDefs.append("KEY", ParameterType.BYTE, 8) + assert not raw_data.matches(traceParamDefs) + + with self.assertWarns(UserWarning): + TraceParameterDefinitionMap.from_trace_parameter_map(raw_data) + diff --git a/trsfile/compatibility.py b/trsfile/compatibility.py new file mode 100644 index 0000000..72490bd --- /dev/null +++ b/trsfile/compatibility.py @@ -0,0 +1,48 @@ +import functools + + +class alias: + """ + A decorator for implementing method aliases. + """ + + def __init__(self, *aliases): + self.aliases = set(aliases) + + def __call__(self, obj): + if type(obj) == property: + obj.fget._aliases = self.aliases + else: + obj._aliases = self.aliases + + return obj + + +def aliased(aliased_class): + """ + A decorator for enabling method aliases. + """ + def wrapper(func, name): + @functools.wraps(func) + def inner(*args, **kwargs): + return func(*args, **kwargs) + return inner + + aliased_class_dict = aliased_class.__dict__.copy() + aliased_class_set = set(aliased_class_dict) + + for name, method in aliased_class_dict.items(): + aliases = None + + if (type(method) == property) and hasattr(method.fget, '_aliases'): + aliases = method.fget._aliases + elif hasattr(method, '_aliases'): + aliases = method._aliases + + if aliases: + for method_alias in aliases - aliased_class_set: + wrapped_method = wrapper(method, method_alias) + wrapped_method.__doc__ = str(f"{method_alias} is an alias of {name}.") + setattr(aliased_class, method_alias, wrapped_method) + + return aliased_class \ No newline at end of file diff --git a/trsfile/engine/trs.py b/trsfile/engine/trs.py index 611948d..80e3b9f 100644 --- a/trsfile/engine/trs.py +++ b/trsfile/engine/trs.py @@ -173,17 +173,18 @@ def update_headers_with_traces_metadata(self, traces: List[Trace]) -> None: if len(set([len(trace.parameters.serialize()) for trace in traces])) > 1: raise TypeError('Traces have different data length, this is not supported in TRS files') - data_length = len(traces[0].parameters.serialize()) + headers_updates[Header.LENGTH_DATA] = len(traces[0].parameters.serialize()) - # Add a TraceParameterDefinitionMap if none is present, and verify its validity if one is present - if Header.TRACE_PARAMETER_DEFINITIONS not in self.headers: - if data_length > 0: - headers_updates[Header.TRACE_PARAMETER_DEFINITIONS] = \ - TraceParameterDefinitionMap.from_trace_parameter_map(traces[0].parameters) - elif not traces[0].parameters.matches(self.headers[Header.TRACE_PARAMETER_DEFINITIONS]): - raise TypeError("The traces' parameters do not match the trace set's definitions") - - headers_updates[Header.LENGTH_DATA] = data_length + # Add a TraceParameterDefinitionMap if none is present, and verify its validity if one is present + if Header.TRACE_PARAMETER_DEFINITIONS not in self.headers: + headers_updates[Header.TRACE_PARAMETER_DEFINITIONS] = \ + TraceParameterDefinitionMap.from_trace_parameter_map(traces[0].parameters) + else: + for index, trace in enumerate(traces): + if not trace.parameters.matches(self.headers[Header.TRACE_PARAMETER_DEFINITIONS]): + raise TypeError(f"The parameters of trace #{index} do not match the trace set's definitions.\n" + f"Please make sure the trace parameters match those of the other traces in type, " + f"size and name.") if self.headers[Header.SAMPLE_CODING] is None: if len(set([trace.sample_coding for trace in traces])) > 1: diff --git a/trsfile/parametermap.py b/trsfile/parametermap.py index 41ad127..a0b9412 100644 --- a/trsfile/parametermap.py +++ b/trsfile/parametermap.py @@ -4,6 +4,7 @@ import warnings from typing import Any, Union, List, Dict +from trsfile.compatibility import alias, aliased from trsfile.standardparameters import StandardTraceSetParameters, StandardTraceParameters from trsfile.traceparameter import TraceSetParameter, TraceParameter, TraceParameterDefinition, ParameterType, \ BooleanArrayParameter, ByteArrayParameter, StringParameter, DoubleArrayParameter, IntegerArrayParameter, \ @@ -16,6 +17,7 @@ INT_MIN = -2**31 INT_MAX = 2**31-1 + class ParameterMapUtil: # A placeholder for integers that are actually shorts class ShortType(numbers.Rational): @@ -175,6 +177,7 @@ def lock_content(self): self._is_locked = True +@aliased class TraceSetParameterMap(LockableDict): default_values = { StandardTraceSetParameters.DISPLAY_HINT_X_LABEL: "", @@ -195,7 +198,8 @@ def __setitem__(self, key: str, value: Union[TraceParameter, TraceSetParameter]) self._stop_if_locked() super().__setitem__(key, value) - def add_parameter(self, name: str, value: ParameterMapUtil.ParameterValueType) -> None: + @alias("add") + def add_parameter(self, name: str, value: ParameterMapUtil.ParameterValueType) -> TraceSetParameterMap: """Add a trace set parameter with a given name and value. If the name matches the identifier of a standard trace set parameter, then the value's type should be the type that standard trace set parameter expects. @@ -204,11 +208,13 @@ def add_parameter(self, name: str, value: ParameterMapUtil.ParameterValueType) - :param name: The name of the parameter that will be added :param value: The value of the parameter. If the name matches that of a standard trace set parameter, it is recommended that the type of the value matches that of standard trace set parameter. Otherwise, - valid types are: int, float, bool, List[int], List[float], List[bool], bytes, bytearray or str""" + valid types are: int, float, bool, List[int], List[float], List[bool], bytes, bytearray or str + :return: This TraceSetParameterMap after adding the new parameter""" try: std_param = StandardTraceSetParameters.from_identifier(name) typed_param = std_param.parameter_type.param_class - self[name] = typed_param(ParameterMapUtil.to_list_if_listable(value)) + self[std_param.identifier] = typed_param(ParameterMapUtil.to_list_if_listable(value)) + return self # if no std_param can be found, a ValueError is raised, # if adding the trace parameter to the map fails, a TypeError is raised except (ValueError, TypeError) as e: @@ -219,29 +225,34 @@ def add_parameter(self, name: str, value: ParameterMapUtil.ParameterValueType) - typed_param = ParameterMapUtil.get_typed_parameter(value) self[name] = typed_param(ParameterMapUtil.to_list_if_listable(value)) - def add_standard_parameter(self, std_trace_set_param: StandardTraceSetParameters, value: ParameterMapUtil.ParameterValueType) -> None: + def add_standard_parameter(self, std_trace_set_param: StandardTraceSetParameters, + value: ParameterMapUtil.ParameterValueType) -> TraceSetParameterMap: """Add a standard trace set parameter with a given value. If the parameter already exists within the map, it will be overwritten. :param std_trace_set_param: The standard trace set parameter that will be added :param value: The value of the parameter. The type this value must have depends on - the standard trace set parameter.""" + the standard trace set parameter. + :return: This TraceSetParameterMap after adding the new standard parameter""" typed_param = std_trace_set_param.parameter_type.param_class self[std_trace_set_param.identifier] = typed_param(ParameterMapUtil.to_list_if_listable(value)) + return self - def fill_from_headers(self, headers: Dict['Header', Any]) -> None: + def fill_from_headers(self, headers: Dict['Header', Any]) -> TraceSetParameterMap: """Add to this trace set parameter map all data that is in the header and for which standard trace set parameters exist. Data that already exists in the map will not be overwritten. - :param headers: The headers dictionary from which data will be copied into the trace set parameter map""" + :param headers: The headers dictionary from which data will be copied into the trace set parameter map + :return: This TraceSetParameterMap after adding the parameters based on the headers""" for header_tag, value in headers.items(): std_param = header_tag.equivalent_std_param if std_param is not None and std_param.identifier not in self: self.add_standard_parameter(std_param, value) - def add_defaults(self) -> None: - """If specific standard trace set parameters don't exist yet in the map, add them with default values""" + def add_defaults(self) -> TraceSetParameterMap: + """If specific standard trace set parameters don't exist yet in the map, add them with default values + :return: This TraceSetParameterMap after adding the default parameters""" for key, value in TraceSetParameterMap.default_values.items(): if key.identifier not in self: self.add_standard_parameter(key, value) @@ -289,6 +300,82 @@ def __setitem__(self, key: str, value: TraceParameterDefinition): self._stop_if_locked() super().__setitem__(key, value) + def insert_std(self, name: str, size: int, offset: int) -> TraceParameterDefinitionMap: + """Insert a trace parameter definition of a StandardTraceParameter into this map in a specified location. If + the given offset would put the new TraceParameter in the middle of a parameter already present in the map, the + offset is increased to put the new parameter after that existing one instead. Any parameters already present + in the map that have a greater or equal offset than the new parameter, have their offset increased to make space + for the new parameter. + + :param name: The name of the TraceParameter for which to add a definition. This name must match that of a + StandardTraceParameter + :param size: The size of the TraceParameter, in number of elements + :param offset: The offset of the TraceParameter, in bytes + :return: This TraceParameterDefinition map after adding the new definition + """ + try: + type = StandardTraceParameters.from_identifier(name).parameter_type + return self.insert(name, type, size, offset) + except ValueError: + raise ValueError(f"No StandardTraceParameter found with name '{name}'. Either use the 'insert' method or " + f"correct the name to match a standard trace parameter.") + + def insert(self, name: str, type: ParameterType, size: int, offset: int) -> TraceParameterDefinitionMap: + """Insert a trace parameter definition into this map in a specified location. If the given offset would put the + new TraceParameter in the middle of a parameter already present in the map, the offset is increased to put the + new parameter after that existing one instead. Any parameters already present in the map that have a greater + or equal offset than the new parameter, have their offset increased to make space for the new parameter. + + :param name: The name of the TraceParameter for which to add a definition + :param type: The type of the TraceParameter for which to add a definition + :param size: The size of the TraceParameter, in number of elements + :param offset: The offset of the TraceParameter, in bytes + :return: This TraceParameterDefinition map after adding the new definition + """ + self._stop_if_locked() + params_to_move_back = [] + for key, param in self.items(): + if param.offset >= offset: + param.offset += size * type.byte_size + params_to_move_back.append(key) + elif param.offset + param.length * param.param_type.byte_size > offset: + offset = param.offset + param.length * param.param_type.byte_size + warnings.warn("Given offset would put a parameter inside another trace parameter.\n" + f"Increased the offset of the inserted parameter definition to {offset} to prevent this.") + + new_definition = TraceParameterDefinition(type, size, offset) + self.__setitem__(name, new_definition) + for param in params_to_move_back: + self.move_to_end(param) + return self + + def append_std(self, name: str, size: int) -> TraceParameterDefinitionMap: + """Append a trace parameter definition of a StandardTraceParameter to this map. The parameter wil be added after + all parameter definitions already in the map. + + :param name: The name of the TraceParameter for which to add a definition. This name must match that of a + StandardTraceParameter + :param size: The size of the TraceParameter, in number of elements + :return: This TraceParameterDefinition map after adding the new definition""" + try: + type = StandardTraceParameters.from_identifier(name).parameter_type + return self.append(name, type, size) + except ValueError: + raise ValueError(f"No StandardTraceParameter found with name '{name}'. Either use the 'append' method or " + f"correct the name to match a standard trace parameter.") + + def append(self, name: str, type: ParameterType, size: int) -> TraceParameterDefinitionMap: + """Append a trace parameter definition to this map. The parameter wil be added after all parameter definitions + already in the map. + + :param name: The name of the TraceParameter for which to add a definition + :param type: The type of the TraceParameter for which to add a definition + :param size: The size of the TraceParameter, in number of elements + :return: This TraceParameterDefinition map after adding the new definition""" + new_definition = TraceParameterDefinition(type, size, self.get_total_size()) + self.__setitem__(name, new_definition) + return self + @staticmethod def deserialize(raw: BytesIO) -> TraceParameterDefinitionMap: result = TraceParameterDefinitionMap() @@ -316,6 +403,12 @@ def from_trace_parameter_map(trace_parameters: TraceParameterMap) -> TraceParame :param trace_parameters: The trace parameter map from which the definitions will be deduced :return: A parameter definition map that described the metadata of the input trace parameter map""" + if isinstance(trace_parameters, RawTraceData): + warnings.warn("Creating a trace parameter definition map from raw trace data.\nThis is not recommended, " + "as it will not add any meta information about the trace data.\nEither manually define a " + "TraceParameterDefinitionMap for the traceset or make sure the first trace you add to the " + "traceset has a proper TraceParameterMap") + offset = 0 result = TraceParameterDefinitionMap() for key, trace_param in trace_parameters.items(): @@ -326,6 +419,7 @@ def from_trace_parameter_map(trace_parameters: TraceParameterMap) -> TraceParame return result +@aliased class TraceParameterMap(StringKeyOrderedDict): def __setitem__(self, key: str, value: TraceParameter): if not isinstance(value, TraceParameter): @@ -333,7 +427,8 @@ def __setitem__(self, key: str, value: TraceParameter): ' of TraceParameter (e.g. ByteArrayParameter).') super().__setitem__(key, value) - def add_parameter(self, name: str, value: ParameterMapUtil.ParameterValueType) -> None: + @alias("add") + def add_parameter(self, name: str, value: ParameterMapUtil.ParameterValueType) -> TraceParameterMap: """Add a trace parameter with a given name and value If the name matches the identifier of a standard trace parameter, then the value's type should be the type that standard trace parameter expects. @@ -342,11 +437,12 @@ def add_parameter(self, name: str, value: ParameterMapUtil.ParameterValueType) - :param name: The name of the parameter that will be added :param value: The value of the parameter. If the name matches that of a standard trace parameter, it is recommended that the type of the value matches that of standard trace set parameter. Otherwise, - valid types are: int, float, bool, List[int], List[float], List[bool], bytes, bytearray or str""" + valid types are: int, float, bool, List[int], List[float], List[bool], bytes, bytearray or str + :return: This map after adding the parameter""" try: std_param = StandardTraceParameters.from_identifier(name) typed_param = std_param.parameter_type.param_class - self[name] = typed_param(ParameterMapUtil.to_list_if_listable(value)) + self[std_param.identifier] = typed_param(ParameterMapUtil.to_list_if_listable(value)) # if no std_param can be found, a ValueError is raised, # if adding the trace parameter to the map fails, a TypeError is raised except (TypeError, ValueError) as e: @@ -356,16 +452,20 @@ def add_parameter(self, name: str, value: ParameterMapUtil.ParameterValueType) - "behavior when displaying this traceset or processing this trace in Inspector") typed_param = ParameterMapUtil.get_typed_parameter(value) self[name] = typed_param(ParameterMapUtil.to_list_if_listable(value)) + return self - def add_standard_parameter(self, std_trace_param: StandardTraceParameters, value: ParameterMapUtil.ParameterValueType) -> None: + def add_standard_parameter(self, std_trace_param: StandardTraceParameters, + value: ParameterMapUtil.ParameterValueType) -> TraceParameterMap: """Add a standard trace parameter with a given value. If the parameter already exists within the map, it will be overwritten. :param std_trace_param: The standard trace parameter that will be added :param value: The value of the parameter. The type this value must have depends on - the standard trace parameter.""" + the standard trace parameter. + :return: This map after adding the parameter""" typed_param = std_trace_param.parameter_type.param_class self[std_trace_param.identifier] = typed_param(ParameterMapUtil.to_list_if_listable(value)) + return self @staticmethod def deserialize(raw: bytes, definitions: TraceParameterDefinitionMap) -> TraceParameterMap: @@ -409,3 +509,20 @@ def matches(self, definitions: TraceParameterDefinitionMap) -> bool: break match &= matched_keys == list(definitions.keys()) return match + + +class RawTraceData(TraceParameterMap): + def __init__(self, data: bytes): + super().__init__() + super().__setitem__("LEGACY_DATA", ByteArrayParameter(data)) + + def __setitem__(self, key: str, value: TraceParameter): + raise KeyError("Adding Trace Parameters into raw trace data is not allowed") + + def matches(self, definitions: TraceParameterDefinitionMap) -> bool: + """Test whether this RawTraceData could be interpreted by given definitions + + :param definitions: The trace parameter definition map to check this raw trace data against + + :return: A boolean that is true if the trace parameter definitions can interpret this raw trace data""" + return definitions.get_total_size() == len(self["LEGACY_DATA"]) diff --git a/trsfile/standardparameters.py b/trsfile/standardparameters.py index b58d266..cfba128 100644 --- a/trsfile/standardparameters.py +++ b/trsfile/standardparameters.py @@ -19,7 +19,7 @@ def __new__(cls, tag: int, identifier: str, parameter_type: ParameterType): @staticmethod def from_identifier(identifier: str) -> StandardTraceSetParameters: for val in StandardTraceSetParameters: - if identifier == val.identifier: + if identifier.lower() == val.identifier.lower() or identifier.lower() == val.name.lower(): return val raise ValueError(f'{identifier} is not an identifier of a StandardTraceSetParameter') @@ -79,7 +79,7 @@ def __new__(cls, tag: int, identifier: str, parameter_type: ParameterType): @staticmethod def from_identifier(identifier: str) -> StandardTraceParameters: for val in StandardTraceParameters: - if identifier == val.identifier: + if identifier.lower() == val.identifier.lower() or identifier.lower() == val.name.lower(): return val raise ValueError('{} is not a name of a StandardTraceParameter'.format(identifier)) diff --git a/trsfile/trace.py b/trsfile/trace.py index 57e8a14..4169ba8 100644 --- a/trsfile/trace.py +++ b/trsfile/trace.py @@ -1,8 +1,9 @@ import numpy -from trsfile.parametermap import TraceParameterMap +from trsfile.parametermap import TraceParameterMap, RawTraceData from trsfile.common import Header, SampleCoding + class Trace: """The :py:obj:`Trace` class behaves like a :py:obj:`list` object were each item in the list is a sample of the trace. @@ -12,11 +13,35 @@ class Trace: provided :py:obj:`sample_coding`. """ - def __init__(self, sample_coding, samples, parameters=TraceParameterMap(), title='trace', headers={}): + def __init__(self, sample_coding, samples, parameters=None, title='trace', headers=None, raw_data: bytes = bytes()): + """ Create a new Trace. + :param sample_coding: The encoding of all samples in the trace + :param samples: The array of samples of the trace + :param parameters: The trace parameter map that contains the trace's data and its meta information. Do not use + in combination with raw_data + :param title: The title of the trace + :param headers: The headers of the trs file to which this trace will be added. This is an optional parameter; + information from the headers may define the locations of the input, output and key data in the trace data, + but it is recommended to store that information in the trace parameter map now. + :param raw_data: A byte array with the raw trace data. Do not use in combination with parameters. If used, it is + recommended that a TraceParameterDefinitionMap is added to the headers of the trsfile that defines the meta + information of this raw trace data. + """ + if parameters is None: + if len(raw_data) > 0: + parameters = RawTraceData(raw_data) + else: + parameters = TraceParameterMap() + else: + if len(raw_data) > 0: + raise Warning("Parameter map and raw data were both defined, but cannot both be used at the same time.\n" + "Only the parameter map will be used, raw data will be discarded.") + if headers is None: + headers = {} self.title = title self.parameters = parameters - if not type(self.parameters) is TraceParameterMap: - raise TypeError('Trace parameter data must be supplied as a TraceParameterMap') + if not isinstance(self.parameters, TraceParameterMap): + raise TypeError('Trace data must be supplied as a TraceParameterMap') # Obtain sample coding if not isinstance(sample_coding, SampleCoding): diff --git a/trsfile/traceparameter.py b/trsfile/traceparameter.py index 469f0ad..494115d 100644 --- a/trsfile/traceparameter.py +++ b/trsfile/traceparameter.py @@ -1,6 +1,7 @@ from __future__ import annotations import struct +import warnings from abc import ABC, abstractmethod from enum import Enum from io import BytesIO @@ -37,7 +38,11 @@ def _has_expected_type(value: Any) -> bool: pass def __init__(self, value): - if type(value) is not str and (value is None or len(value) <= 0): + if type(value) is ndarray and len(value.shape) > 1: + warnings.warn("Flatting multi-dimensional ndarray before adding it to trace parameter.\n" + "Information about dimensions of this ndarray will be lost.") + value = value.flatten() + if value is None or ((type(value) is list or type(value) is ndarray) and len(value) <= 0): raise ValueError('The value for a TraceParameter cannot be empty') if not type(self)._has_expected_type(value): raise TypeError(f'A {type(self).__name__} must have a value of type "{type(self)._expected_type_string}"' @@ -48,7 +53,14 @@ def __len__(self): return len(self.value) def __eq__(self, other): - return isinstance(other, type(self)) and self.value == other.value + if not isinstance(other, type(self)): + return False + + if len(self.value) != len(other.value): + return False + # return true only if both parameter value arrays contain the same elements in the same order, + # regardless of whether it is an ndarray or a list + return all(this_val == that_val for (this_val, that_val) in zip(self.value, other.value)) def __str__(self): return str(self.value)