Skip to content

Commit

Permalink
Add ndarray testcases and fixes for issues found
Browse files Browse the repository at this point in the history
Issues were: incorrect comparison between ndarrays and lists, no support for multi-dimensional arrays, and no detection of empty ndarrays.
  • Loading branch information
TomHKeysight committed Nov 4, 2022
1 parent 4217c3d commit b5e6614
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 26 deletions.
136 changes: 112 additions & 24 deletions tests/test_parameter.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
from io import BytesIO
from unittest import TestCase

from numpy import ndarray, int16, array, int32, int64, single, double, uint8, int8, uint16, bool8

from trsfile.traceparameter import BooleanArrayParameter, ByteArrayParameter, DoubleArrayParameter, FloatArrayParameter, \
IntegerArrayParameter, ShortArrayParameter, LongArrayParameter, StringParameter


class TestParameter(TestCase):
def test_bool_parameter(self):
serialized_param = b'\x01\x00\x01'
param = BooleanArrayParameter([True, False, True])
self.assertEqual(serialized_param, param.serialize())
self.assertEqual(BooleanArrayParameter.deserialize(BytesIO(serialized_param), 3), param)
param1 = BooleanArrayParameter([True, False, True])
self.assertEqual(serialized_param, param1.serialize())
self.assertEqual(BooleanArrayParameter.deserialize(BytesIO(serialized_param), 3), param1)
param2 = BooleanArrayParameter(ndarray(shape=[3], dtype=bool8,
buffer=array([bool8(val) for val in [True, False, True]])))
self.assertEqual(param1, param2)

with self.assertRaises(TypeError):
BooleanArrayParameter(True)
Expand All @@ -27,6 +32,11 @@ def test_byte_parameter(self):
self.assertEqual(serialized_param, param1.serialize())
self.assertEqual(ByteArrayParameter.deserialize(BytesIO(serialized_param), 16), param1)

with self.assertWarns(UserWarning):
param2 = ByteArrayParameter(ndarray(shape=[2, 2, 4], dtype=uint8,
buffer=array([uint8(val) for val in int_data])))
self.assertEqual(param1, param2)

param2 = ByteArrayParameter(bytearray(int_data))
self.assertEqual(param1, param2)

Expand All @@ -37,8 +47,14 @@ def test_byte_parameter(self):
ByteArrayParameter([0, '1'])
with self.assertRaises(TypeError):
ByteArrayParameter([bytes([0, 1, 2, 3]), bytes([4, 5, 6, 7])])
with self.assertRaises(TypeError):
ByteArrayParameter(ndarray(shape=[16], dtype=int8, buffer=array([int8(val) for val in int_data])))
with self.assertRaises(TypeError):
ByteArrayParameter(ndarray(shape=[16], dtype=uint16, buffer=array([uint16(val) for val in int_data])))
with self.assertRaises(ValueError):
ByteArrayParameter([])
with self.assertRaises(ValueError):
ByteArrayParameter(ndarray(shape=[0], dtype=uint8, buffer=array([])))
with self.assertRaises(TypeError):
ByteArrayParameter([0, 1, 2, -1])
with self.assertRaises(TypeError):
Expand All @@ -47,13 +63,26 @@ def test_byte_parameter(self):
def test_double_parameter(self):
serialized_param = b'\x00\x00\x00\x00\x00\x00\xe0\xbf\x00\x00\x00\x00\x00\x00\xe0\x3f' \
b'\x00\x00\x00\x00\x80\x84\x2e\x41'
param = DoubleArrayParameter([-0.5, 0.5, 1e6])
self.assertEqual(serialized_param, param.serialize())
self.assertEqual(DoubleArrayParameter.deserialize(BytesIO(serialized_param), 3), param)
param1 = DoubleArrayParameter([-0.5, 0.5, 1e6])
self.assertEqual(serialized_param, param1.serialize())
self.assertEqual(DoubleArrayParameter.deserialize(BytesIO(serialized_param), 3), param1)

param2 = DoubleArrayParameter(ndarray(shape=[3], dtype=double, buffer=array([-0.5, 0.5, 1e6])))
self.assertEqual(param1, param2)

# an array of only integers is still a valid value of a DoubleArrayParameter
param1 = DoubleArrayParameter([1, 2, 1000000])
param2 = DoubleArrayParameter([1, 2.0, 1e6])
param1 = DoubleArrayParameter([-1, 2, 1000000])
param2 = DoubleArrayParameter([-1, 2.0, 1e6])
self.assertEqual(param1, param2)
with self.assertWarns(UserWarning):
param1 = DoubleArrayParameter(ndarray(shape=[1, 3], dtype=int32,
buffer=array([-1, 2, 1000000])))
self.assertEqual(param1, param2)

with self.assertWarns(UserWarning):
param1 = DoubleArrayParameter(ndarray(shape=[1, 3], dtype=int64,
buffer=array([int64(val) for val in [-1, 2, 10000000000]])))
param2 = DoubleArrayParameter([-1, 2.0, 1e10])
self.assertEqual(param1, param2)

# a float array parameter is not the same as a double array parameter
Expand All @@ -63,42 +92,82 @@ def test_double_parameter(self):

with self.assertRaises(TypeError):
DoubleArrayParameter([0.5, -0.5, 'NaN'])
with self.assertRaises(TypeError):
DoubleArrayParameter(ndarray(shape=[3], dtype=single,
buffer=array([single(val) for val in [-0.5, 0.5, 1e6]])))
with self.assertRaises(TypeError):
DoubleArrayParameter(0.5)
with self.assertRaises(ValueError):
DoubleArrayParameter([])
with self.assertRaises(ValueError):
IntegerArrayParameter(ndarray(shape=[0], dtype=double, buffer=array([])))

def test_float_parameter(self):
serialized_param = b'\x00\x00\x00\xbf\x00\x00\x00\x3f\x00\x24\x74\x49'
param = FloatArrayParameter([-0.5, 0.5, 1e6])
self.assertEqual(serialized_param, param.serialize())
self.assertEqual(FloatArrayParameter.deserialize(BytesIO(serialized_param), 3), param)
param1 = FloatArrayParameter([-0.5, 0.5, 1e6])
self.assertEqual(serialized_param, param1.serialize())
self.assertEqual(FloatArrayParameter.deserialize(BytesIO(serialized_param), 3), param1)

param2 = FloatArrayParameter(ndarray(shape=[3], dtype=single,
buffer=array([single(val) for val in [-0.5, 0.5, 1e6]])))
self.assertEqual(param1, param2)

# an array of only integers is still a valid value of a FloatArrayParameter
param1 = FloatArrayParameter([1, 2, 1000000])
param2 = FloatArrayParameter([1, 2.0, 1e6])
param1 = FloatArrayParameter([-1, 2, 1000000])
param2 = FloatArrayParameter([-1, 2.0, 1e6])
self.assertEqual(param1, param2)

with self.assertWarns(UserWarning):
param1 = FloatArrayParameter(ndarray(shape=[1, 3], dtype=int32,
buffer=array([-1, 2, 1000000])))
self.assertEqual(param1, param2)

with self.assertWarns(UserWarning):
param1 = FloatArrayParameter(ndarray(shape=[1, 3], dtype=int64,
buffer=array([int64(val) for val in [-1, 2, 10000000000]])))
param2 = FloatArrayParameter([-1, 2.0, 1e10])
self.assertEqual(param1, param2)

with self.assertRaises(TypeError):
FloatArrayParameter([0.5, -0.5, 'NaN'])
with self.assertRaises(TypeError):
FloatArrayParameter(ndarray(shape=[3], dtype=double,
buffer=array([double(val) for val in [-0.5, 0.5, 1e6]])))
with self.assertRaises(TypeError):
FloatArrayParameter(0.5)
with self.assertRaises(ValueError):
FloatArrayParameter([])
with self.assertRaises(ValueError):
IntegerArrayParameter(ndarray(shape=[0], dtype=single, buffer=array([])))

def test_integer_parameter(self):
serialized_param = b'\xff\xff\xff\xff\x01\x00\x00\x00\xff\xff\xff\x7f\x00\x00\x00\x80'
param = IntegerArrayParameter([-1, 1, 0x7fffffff, -0x80000000])
self.assertEqual(serialized_param, param.serialize())
self.assertEqual(IntegerArrayParameter.deserialize(BytesIO(serialized_param), 4), param)
param1 = IntegerArrayParameter([-1, 1, 0x7fffffff, -0x80000000])
self.assertEqual(serialized_param, param1.serialize())
self.assertEqual(IntegerArrayParameter.deserialize(BytesIO(serialized_param), 4), param1)

with self.assertWarns(UserWarning):
param2 = IntegerArrayParameter(ndarray(shape=[2, 2], dtype=int32,
buffer=array([-1, 1, 0x7fffffff, -0x80000000])))
self.assertEqual(param1, param2)

# a short array parameter is not the same as an int array parameter
param1 = ShortArrayParameter([0, 1, -1, 255, 256, -32768, 32767])
param2 = IntegerArrayParameter([0, 1, -1, 255, 256, -32768, 32767])
self.assertNotEqual(param1, param2)

# verify that an integer array parameter based on a ndarray filled with int16s works
param1 = IntegerArrayParameter(ndarray(shape=[7], dtype=int16,
buffer=array([int16(val) for val in [0, 1, -1, 255, 256, -32768, 32767]])))
self.assertEqual(param1, param2)

with self.assertRaises(TypeError):
IntegerArrayParameter([1, 256, 1.0])
with self.assertRaises(TypeError):
IntegerArrayParameter(ndarray(shape=[4], dtype=int64,
buffer=array([-1, 1, 0x7fffffffffffffff, -0x8000000000000000])))
with self.assertRaises(ValueError):
IntegerArrayParameter(ndarray(shape=[0], dtype=int32, buffer=array([])))
with self.assertRaises(TypeError):
IntegerArrayParameter(1)
with self.assertRaises(ValueError):
Expand All @@ -111,30 +180,51 @@ def test_integer_parameter(self):
def test_long_parameter(self):
serialized_param = b'\xff\xff\xff\xff\xff\xff\xff\xff\x01\x00\x00\x00\x00\x00\x00\x00' \
b'\xff\xff\xff\xff\xff\xff\xff\x7f\x00\x00\x00\x00\x00\x00\x00\x80'
param = LongArrayParameter([-1, 1, 0x7fffffffffffffff, -0x8000000000000000])
self.assertEqual(serialized_param, param.serialize())
self.assertEqual(LongArrayParameter.deserialize(BytesIO(serialized_param), 4), param)
param1= LongArrayParameter([-1, 1, 0x7fffffffffffffff, -0x8000000000000000])
self.assertEqual(serialized_param, param1.serialize())
self.assertEqual(LongArrayParameter.deserialize(BytesIO(serialized_param), 4), param1)

with self.assertWarns(UserWarning):
param2 = LongArrayParameter(ndarray(shape=[2, 2], dtype=int64,
buffer=array([-1, 1, 0x7fffffffffffffff, -0x8000000000000000])))
self.assertEqual(param1, param2)

# an int array parameter is not the same as a long array parameter
param1 = IntegerArrayParameter([-1, 1, 0x7fffffff, -0x80000000])
param2 = LongArrayParameter([-1, 1, 0x7fffffff, -0x80000000])
self.assertNotEqual(param1, param2)

# verify that a long array parameter based on a ndarray filled with int32s works
with self.assertWarns(UserWarning):
param1 = LongArrayParameter(ndarray(shape=[1, 4], dtype=int32,
buffer=array([-1, 1, 0x7fffffff, -0x80000000])))
self.assertEqual(param1, param2)

with self.assertRaises(TypeError):
LongArrayParameter([1, 256, 1.0])
with self.assertRaises(TypeError):
LongArrayParameter(1)
with self.assertRaises(ValueError):
LongArrayParameter([])
with self.assertRaises(ValueError):
LongArrayParameter(ndarray(shape=[0], dtype=int64, buffer=array([])))

def test_short_parameter(self):
serialized_param = b'\x00\x00\x01\x00\xff\xff\xff\x00\x00\x01\x00\x80\xff\x7f'
param = ShortArrayParameter([0, 1, -1, 255, 256, -32768, 32767])
self.assertEqual(serialized_param, param.serialize())
self.assertEqual(ShortArrayParameter.deserialize(BytesIO(serialized_param), 7), param)
param1 = ShortArrayParameter([0, 1, -1, 255, 256, -32768, 32767])
self.assertEqual(serialized_param, param1.serialize())
self.assertEqual(ShortArrayParameter.deserialize(BytesIO(serialized_param), 7), param1)

param2 = ShortArrayParameter(ndarray(shape=[7], dtype=int16,
buffer=array([int16(val) for val in [0, 1, -1, 255, 256, -32768, 32767]])))
self.assertEqual(param1, param2)

with self.assertRaises(TypeError):
ShortArrayParameter([1, 256, 1.0])
with self.assertRaises(TypeError):
ShortArrayParameter(ndarray(shape=[4], dtype=int32, buffer=array([-1, 1, 0x7fffffff, -0x80000000])))
with self.assertRaises(ValueError):
ShortArrayParameter(ndarray(shape=[0], dtype=int16, buffer=array([])))
with self.assertRaises(TypeError):
ShortArrayParameter(1)
with self.assertRaises(ValueError):
Expand All @@ -159,5 +249,3 @@ def test_string_parameter(self):
StringParameter(['The', 'quick', 'brown', 'fox', 'jumped', 'over', 'the', 'lazy', 'dog'])
with self.assertRaises(ValueError):
StringParameter(None)


15 changes: 13 additions & 2 deletions trsfile/traceparameter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import struct
import warnings
from abc import ABC, abstractmethod
from enum import Enum
from io import BytesIO
Expand Down Expand Up @@ -37,7 +38,11 @@ def _has_expected_type(value: Any) -> bool:
pass

def __init__(self, value):
if type(value) is not str and (value is None or len(value) <= 0):
if type(value) is ndarray and len(value.shape) > 1:
warnings.warn("Flatting multi-dimensional ndarray before adding it to trace parameter.\n"
"Information about dimensions of this ndarray will be lost.")
value = value.flatten()
if value is None or ((type(value) is list or type(value) is ndarray) and len(value) <= 0):
raise ValueError('The value for a TraceParameter cannot be empty')
if not type(self)._has_expected_type(value):
raise TypeError(f'A {type(self).__name__} must have a value of type "{type(self)._expected_type_string}"'
Expand All @@ -48,7 +53,13 @@ def __len__(self):
return len(self.value)

def __eq__(self, other):
return isinstance(other, type(self)) and self.value == other.value
if not isinstance(other, type(self)):
return False
if (type(self.value) == list or type(self.value) == ndarray) and \
(type(other.value) == list or type(other.value) == ndarray):
return all(this_val == that_val for (this_val, that_val) in zip(self.value, other.value))
else:
return self.value == other.value

def __str__(self):
return str(self.value)
Expand Down

0 comments on commit b5e6614

Please sign in to comment.