Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 committed Apr 27, 2022
1 parent 7877ac5 commit d42214e
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion tests/pytorch/test_datasets_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_poly
else:
assert target['boxes'].ndim == 2 and target['boxes'].shape[1:] == (4,)
assert np.all(np.logical_and(target['boxes'] <= 1, target['boxes'] >= 0))
assert len(target['labels']) == len(target['boxes'])
if class_indices:
assert isinstance(target['labels'], np.ndarray) and target['labels'].dtype == np.int64
else:
assert isinstance(target['labels'], list) and all(isinstance(s, str) for s in target['labels'])
assert len(target['labels']) == len(target['boxes'])

# Check batching
loader = DataLoader(
Expand Down
2 changes: 1 addition & 1 deletion tests/tensorflow/test_datasets_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_poly
else:
assert target['boxes'].ndim == 2 and target['boxes'].shape[1:] == (4,)
assert np.all(np.logical_and(target['boxes'] <= 1, target['boxes'] >= 0))
assert len(target['labels']) == len(target['boxes'])
if class_indices:
assert isinstance(target['labels'], np.ndarray) and target['labels'].dtype == np.int64
else:
assert isinstance(target['labels'], list) and all(isinstance(s, str) for s in target['labels'])
assert len(target['labels']) == len(target['boxes'])

# Check batching
loader = DataLoader(ds, batch_size=batch_size)
Expand Down

0 comments on commit d42214e

Please sign in to comment.