Skip to content

Commit

Permalink
Merge pull request #38 from Riscure/ease-of-use-improvements
Browse files Browse the repository at this point in the history
Ease of use improvements
  • Loading branch information
TomHogervorst authored Nov 17, 2022
2 parents d27cf28 + a422c0b commit 4f6fa8d
Show file tree
Hide file tree
Showing 10 changed files with 512 additions and 59 deletions.
87 changes: 85 additions & 2 deletions tests/test_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import shutil

from trsfile import Trace, SampleCoding, Header, TracePadding
from trsfile.parametermap import TraceParameterMap, TraceParameterDefinitionMap, TraceSetParameterMap
from trsfile.parametermap import TraceParameterMap, TraceParameterDefinitionMap, TraceSetParameterMap, RawTraceData
from trsfile.standardparameters import StandardTraceSetParameters
from trsfile.traceparameter import ByteArrayParameter, TraceParameterDefinition, ParameterType, StringParameter, \
IntegerArrayParameter, BooleanArrayParameter, FloatArrayParameter
Expand Down Expand Up @@ -178,6 +178,20 @@ def test_header_to_trace_set_params(self):
except Exception as e:
self.fail('Exception occurred: ' + str(e))

def test_write_different_trace_sizes(self):
trace_count = 100
sample_count = 1000

with trsfile.open(self.tmp_path, 'w', padding_mode=TracePadding.AUTO) as trs_traces:
trs_traces.extend([
Trace(
SampleCoding.FLOAT,
[0] * sample_count,
RawTraceData(i.to_bytes(8, byteorder='big'))
)
for i in range(0, trace_count)]
)

def test_write_closed(self):
trace_count = 100
sample_count = 1000
Expand All @@ -196,6 +210,75 @@ def test_write_closed(self):
with self.assertRaises(ValueError):
print(trs_traces)

def test_write_different_trace_sizes(self):
trace_count = 100
sample_count = 1000

with trsfile.open(self.tmp_path, 'w', padding_mode=TracePadding.AUTO) as trs_traces:
trs_traces.extend([
Trace(
SampleCoding.FLOAT,
[0] * sample_count,
TraceParameterMap({'LEGACY_DATA': ByteArrayParameter(i.to_bytes(8, byteorder='big'))})
)
for i in range(0, trace_count)]
)
with self.assertRaises(TypeError):
# The length is incorrect
# Should raise a Type error: The parameters of trace #0 do not match the trace set's definitions.
trs_traces.extend([
Trace(
SampleCoding.FLOAT,
[0] * sample_count,
TraceParameterMap({'LEGACY_DATA': ByteArrayParameter(bytes.fromhex('cafebabedeadbeef0102030405060708'))})
)]
)
with self.assertRaises(TypeError):
# The name is incorrect
# Should raise a Type error: The parameters of trace #1 do not match the trace set's definitions.
trs_traces.extend([
Trace(
SampleCoding.FLOAT,
[0] * sample_count,
TraceParameterMap({'LEGACY_DATA': ByteArrayParameter(bytes.fromhex('0102030405060708'))})
),
Trace(
SampleCoding.FLOAT,
[0] * sample_count,
TraceParameterMap({'NEW_DATA': ByteArrayParameter(bytes.fromhex('0102030405060708'))})
)]
)
with self.assertRaises(TypeError):
# The type is incorrect
# Should raise a Type error: The parameters of trace #0 do not match the trace set's definitions.
trs_traces.extend([
Trace(
SampleCoding.FLOAT,
[0] * sample_count,
TraceParameterMap({'LEGACY_DATA': IntegerArrayParameter([42, 74])})
)]
)

with trsfile.open(self.tmp_path, 'w', padding_mode=TracePadding.AUTO) as trs_traces:
trs_traces.extend([
Trace(
SampleCoding.FLOAT,
[0] * sample_count,
TraceParameterMap()
)
for i in range(0, trace_count)]
)
with self.assertRaises(TypeError):
# The length, data and name are incorrect
# Should raise a Type error: The parameters of trace #0 do not match the trace set's definitions.
trs_traces.extend([
Trace(
SampleCoding.FLOAT,
[0] * sample_count,
TraceParameterMap({'LEGACY_DATA': ByteArrayParameter(bytes.fromhex('cafebabedeadbeef0102030405060708'))})
)]
)

def test_read(self):
trace_count = 100
sample_count = 1000
Expand Down Expand Up @@ -270,7 +353,7 @@ def test_append(self):
Trace(
SampleCoding.FLOAT,
[0] * sample_count,
TraceParameterMap({'LEGACY_DATA': ByteArrayParameter(i.to_bytes(8, byteorder='big'))})
raw_data=i.to_bytes(8, byteorder='big')
)
for i in range(0, trace_count)]
)
Expand Down
4 changes: 4 additions & 0 deletions tests/test_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ def test_trace_param_defs_append_errors(self):
trace_parameter_definitions.popitem()
with self.assertRaises(TypeError):
trace_parameter_definitions.clear()
with self.assertRaises(TypeError):
trace_parameter_definitions.append('input', ParameterType.BYTE, 16)
with self.assertRaises(TypeError):
trace_parameter_definitions.insert('output', ParameterType.BYTE, 16, 0)

# Shallow copies still share references to the same trace set parameters as the original,
# and should therefore not be modifiable if the original isn't
Expand Down
137 changes: 113 additions & 24 deletions tests/test_parameter.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
from io import BytesIO
from unittest import TestCase

from numpy import ndarray, int16, array, int32, int64, single, double, uint8, int8, uint16, bool8

from trsfile.traceparameter import BooleanArrayParameter, ByteArrayParameter, DoubleArrayParameter, FloatArrayParameter, \
IntegerArrayParameter, ShortArrayParameter, LongArrayParameter, StringParameter


class TestParameter(TestCase):
def test_bool_parameter(self):
serialized_param = b'\x01\x00\x01'
param = BooleanArrayParameter([True, False, True])
self.assertEqual(serialized_param, param.serialize())
self.assertEqual(BooleanArrayParameter.deserialize(BytesIO(serialized_param), 3), param)
param1 = BooleanArrayParameter([True, False, True])
self.assertEqual(serialized_param, param1.serialize())
self.assertEqual(BooleanArrayParameter.deserialize(BytesIO(serialized_param), 3), param1)
param2 = BooleanArrayParameter(ndarray(shape=[3], dtype=bool8,
buffer=array([bool8(val) for val in [True, False, True]])))
self.assertEqual(param1, param2)

with self.assertRaises(TypeError):
BooleanArrayParameter(True)
Expand All @@ -27,6 +32,11 @@ def test_byte_parameter(self):
self.assertEqual(serialized_param, param1.serialize())
self.assertEqual(ByteArrayParameter.deserialize(BytesIO(serialized_param), 16), param1)

with self.assertWarns(UserWarning):
param2 = ByteArrayParameter(ndarray(shape=[2, 2, 4], dtype=uint8,
buffer=array([uint8(val) for val in int_data])))
self.assertEqual(param1, param2)

param2 = ByteArrayParameter(bytearray(int_data))
self.assertEqual(param1, param2)

Expand All @@ -37,8 +47,14 @@ def test_byte_parameter(self):
ByteArrayParameter([0, '1'])
with self.assertRaises(TypeError):
ByteArrayParameter([bytes([0, 1, 2, 3]), bytes([4, 5, 6, 7])])
with self.assertRaises(TypeError):
ByteArrayParameter(ndarray(shape=[16], dtype=int8, buffer=array([int8(val) for val in int_data])))
with self.assertRaises(TypeError):
ByteArrayParameter(ndarray(shape=[16], dtype=uint16, buffer=array([uint16(val) for val in int_data])))
with self.assertRaises(ValueError):
ByteArrayParameter([])
with self.assertRaises(ValueError):
ByteArrayParameter(ndarray(shape=[0], dtype=uint8, buffer=array([])))
with self.assertRaises(TypeError):
ByteArrayParameter([0, 1, 2, -1])
with self.assertRaises(TypeError):
Expand All @@ -47,13 +63,26 @@ def test_byte_parameter(self):
def test_double_parameter(self):
serialized_param = b'\x00\x00\x00\x00\x00\x00\xe0\xbf\x00\x00\x00\x00\x00\x00\xe0\x3f' \
b'\x00\x00\x00\x00\x80\x84\x2e\x41'
param = DoubleArrayParameter([-0.5, 0.5, 1e6])
self.assertEqual(serialized_param, param.serialize())
self.assertEqual(DoubleArrayParameter.deserialize(BytesIO(serialized_param), 3), param)
param1 = DoubleArrayParameter([-0.5, 0.5, 1e6])
self.assertEqual(serialized_param, param1.serialize())
self.assertEqual(DoubleArrayParameter.deserialize(BytesIO(serialized_param), 3), param1)

param2 = DoubleArrayParameter(ndarray(shape=[3], dtype=double, buffer=array([-0.5, 0.5, 1e6])))
self.assertEqual(param1, param2)

# an array of only integers is still a valid value of a DoubleArrayParameter
param1 = DoubleArrayParameter([1, 2, 1000000])
param2 = DoubleArrayParameter([1, 2.0, 1e6])
param1 = DoubleArrayParameter([-1, 2, 1000000])
param2 = DoubleArrayParameter([-1, 2.0, 1e6])
self.assertEqual(param1, param2)
with self.assertWarns(UserWarning):
param1 = DoubleArrayParameter(ndarray(shape=[1, 3], dtype=int32,
buffer=array([int32(val) for val in [-1, 2, 1000000]])))
self.assertEqual(param1, param2)

with self.assertWarns(UserWarning):
param1 = DoubleArrayParameter(ndarray(shape=[1, 3], dtype=int64,
buffer=array([int64(val) for val in [-1, 2, 10000000000]])))
param2 = DoubleArrayParameter([-1, 2.0, 1e10])
self.assertEqual(param1, param2)

# a float array parameter is not the same as a double array parameter
Expand All @@ -63,42 +92,82 @@ def test_double_parameter(self):

with self.assertRaises(TypeError):
DoubleArrayParameter([0.5, -0.5, 'NaN'])
with self.assertRaises(TypeError):
DoubleArrayParameter(ndarray(shape=[3], dtype=single,
buffer=array([single(val) for val in [-0.5, 0.5, 1e6]])))
with self.assertRaises(TypeError):
DoubleArrayParameter(0.5)
with self.assertRaises(ValueError):
DoubleArrayParameter([])
with self.assertRaises(ValueError):
IntegerArrayParameter(ndarray(shape=[0], dtype=double, buffer=array([])))

def test_float_parameter(self):
serialized_param = b'\x00\x00\x00\xbf\x00\x00\x00\x3f\x00\x24\x74\x49'
param = FloatArrayParameter([-0.5, 0.5, 1e6])
self.assertEqual(serialized_param, param.serialize())
self.assertEqual(FloatArrayParameter.deserialize(BytesIO(serialized_param), 3), param)
param1 = FloatArrayParameter([-0.5, 0.5, 1e6])
self.assertEqual(serialized_param, param1.serialize())
self.assertEqual(FloatArrayParameter.deserialize(BytesIO(serialized_param), 3), param1)

param2 = FloatArrayParameter(ndarray(shape=[3], dtype=single,
buffer=array([single(val) for val in [-0.5, 0.5, 1e6]])))
self.assertEqual(param1, param2)

# an array of only integers is still a valid value of a FloatArrayParameter
param1 = FloatArrayParameter([1, 2, 1000000])
param2 = FloatArrayParameter([1, 2.0, 1e6])
param1 = FloatArrayParameter([-1, 2, 1000000])
param2 = FloatArrayParameter([-1, 2.0, 1e6])
self.assertEqual(param1, param2)

with self.assertWarns(UserWarning):
param1 = FloatArrayParameter(ndarray(shape=[1, 3], dtype=int32,
buffer=array([int32(val) for val in [-1, 2, 1000000]])))
self.assertEqual(param1, param2)

with self.assertWarns(UserWarning):
param1 = FloatArrayParameter(ndarray(shape=[1, 3], dtype=int64,
buffer=array([int64(val) for val in [-1, 2, 10000000000]])))
param2 = FloatArrayParameter([-1, 2.0, 1e10])
self.assertEqual(param1, param2)

with self.assertRaises(TypeError):
FloatArrayParameter([0.5, -0.5, 'NaN'])
with self.assertRaises(TypeError):
FloatArrayParameter(ndarray(shape=[3], dtype=double,
buffer=array([double(val) for val in [-0.5, 0.5, 1e6]])))
with self.assertRaises(TypeError):
FloatArrayParameter(0.5)
with self.assertRaises(ValueError):
FloatArrayParameter([])
with self.assertRaises(ValueError):
IntegerArrayParameter(ndarray(shape=[0], dtype=single, buffer=array([])))

def test_integer_parameter(self):
serialized_param = b'\xff\xff\xff\xff\x01\x00\x00\x00\xff\xff\xff\x7f\x00\x00\x00\x80'
param = IntegerArrayParameter([-1, 1, 0x7fffffff, -0x80000000])
self.assertEqual(serialized_param, param.serialize())
self.assertEqual(IntegerArrayParameter.deserialize(BytesIO(serialized_param), 4), param)
param1 = IntegerArrayParameter([-1, 1, 0x7fffffff, -0x80000000])
self.assertEqual(serialized_param, param1.serialize())
self.assertEqual(IntegerArrayParameter.deserialize(BytesIO(serialized_param), 4), param1)

with self.assertWarns(UserWarning):
param2 = IntegerArrayParameter(ndarray(shape=[2, 2], dtype=int32,
buffer=array([int32(val) for val in [-1, 1, 0x7fffffff, -0x80000000]])))
self.assertEqual(param1, param2)

# a short array parameter is not the same as an int array parameter
param1 = ShortArrayParameter([0, 1, -1, 255, 256, -32768, 32767])
param2 = IntegerArrayParameter([0, 1, -1, 255, 256, -32768, 32767])
self.assertNotEqual(param1, param2)

# verify that an integer array parameter based on a ndarray filled with int16s works
param1 = IntegerArrayParameter(ndarray(shape=[7], dtype=int16,
buffer=array([int16(val) for val in [0, 1, -1, 255, 256, -32768, 32767]])))
self.assertEqual(param1, param2)

with self.assertRaises(TypeError):
IntegerArrayParameter([1, 256, 1.0])
with self.assertRaises(TypeError):
IntegerArrayParameter(ndarray(shape=[4], dtype=int64,
buffer=array([int64(val) for val in [-1, 1, 0x7fffffffffffffff, -0x8000000000000000]])))
with self.assertRaises(ValueError):
IntegerArrayParameter(ndarray(shape=[0], dtype=int32, buffer=array([])))
with self.assertRaises(TypeError):
IntegerArrayParameter(1)
with self.assertRaises(ValueError):
Expand All @@ -111,30 +180,52 @@ def test_integer_parameter(self):
def test_long_parameter(self):
serialized_param = b'\xff\xff\xff\xff\xff\xff\xff\xff\x01\x00\x00\x00\x00\x00\x00\x00' \
b'\xff\xff\xff\xff\xff\xff\xff\x7f\x00\x00\x00\x00\x00\x00\x00\x80'
param = LongArrayParameter([-1, 1, 0x7fffffffffffffff, -0x8000000000000000])
self.assertEqual(serialized_param, param.serialize())
self.assertEqual(LongArrayParameter.deserialize(BytesIO(serialized_param), 4), param)
param1= LongArrayParameter([-1, 1, 0x7fffffffffffffff, -0x8000000000000000])
self.assertEqual(serialized_param, param1.serialize())
self.assertEqual(LongArrayParameter.deserialize(BytesIO(serialized_param), 4), param1)

with self.assertWarns(UserWarning):
param2 = LongArrayParameter(ndarray(shape=[2, 2], dtype=int64,
buffer=array([int64(val) for val in [-1, 1, 0x7fffffffffffffff, -0x8000000000000000]])))
self.assertEqual(param1, param2)

# an int array parameter is not the same as a long array parameter
param1 = IntegerArrayParameter([-1, 1, 0x7fffffff, -0x80000000])
param2 = LongArrayParameter([-1, 1, 0x7fffffff, -0x80000000])
self.assertNotEqual(param1, param2)

# verify that a long array parameter based on a ndarray filled with int32s works
with self.assertWarns(UserWarning):
param1 = LongArrayParameter(ndarray(shape=[1, 4], dtype=int32,
buffer=array([int32(val) for val in [-1, 1, 0x7fffffff, -0x80000000]])))
self.assertEqual(param1, param2)

with self.assertRaises(TypeError):
LongArrayParameter([1, 256, 1.0])
with self.assertRaises(TypeError):
LongArrayParameter(1)
with self.assertRaises(ValueError):
LongArrayParameter([])
with self.assertRaises(ValueError):
LongArrayParameter(ndarray(shape=[0], dtype=int64, buffer=array([])))

def test_short_parameter(self):
serialized_param = b'\x00\x00\x01\x00\xff\xff\xff\x00\x00\x01\x00\x80\xff\x7f'
param = ShortArrayParameter([0, 1, -1, 255, 256, -32768, 32767])
self.assertEqual(serialized_param, param.serialize())
self.assertEqual(ShortArrayParameter.deserialize(BytesIO(serialized_param), 7), param)
param1 = ShortArrayParameter([0, 1, -1, 255, 256, -32768, 32767])
self.assertEqual(serialized_param, param1.serialize())
self.assertEqual(ShortArrayParameter.deserialize(BytesIO(serialized_param), 7), param1)

param2 = ShortArrayParameter(ndarray(shape=[7], dtype=int16,
buffer=array([int16(val) for val in [0, 1, -1, 255, 256, -32768, 32767]])))
self.assertEqual(param1, param2)

with self.assertRaises(TypeError):
ShortArrayParameter([1, 256, 1.0])
with self.assertRaises(TypeError):
ShortArrayParameter(ndarray(shape=[4], dtype=int32,
buffer=array([int32(val) for val in [-1, 1, 0x7fffffff, -0x80000000]])))
with self.assertRaises(ValueError):
ShortArrayParameter(ndarray(shape=[0], dtype=int16, buffer=array([])))
with self.assertRaises(TypeError):
ShortArrayParameter(1)
with self.assertRaises(ValueError):
Expand All @@ -159,5 +250,3 @@ def test_string_parameter(self):
StringParameter(['The', 'quick', 'brown', 'fox', 'jumped', 'over', 'the', 'lazy', 'dog'])
with self.assertRaises(ValueError):
StringParameter(None)


Loading

0 comments on commit 4f6fa8d

Please sign in to comment.