Skip to content

Commit

Permalink
Fixed the bug. Test file also updated.
Browse files Browse the repository at this point in the history
  • Loading branch information
hantek committed Nov 17, 2014
1 parent ce6381f commit ad161c0
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
9 changes: 6 additions & 3 deletions pylearn2/datasets/dense_design_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,7 +908,10 @@ def get_batch_topo(self, batch_size, include_labels=False):
else:
batch_design = self.get_batch_design(batch_size)

rval = self.view_converter.design_mat_to_topo_view(batch_design)
if hasattr(self, 'view_converter'):
rval = self.view_converter.design_mat_to_topo_view(batch_design)
else:
rval = batch_design

if include_labels:
return rval, labels
Expand Down Expand Up @@ -1484,8 +1487,8 @@ def from_dataset(dataset, num_examples):
Returns
-------
sub_dataset : DenseDesignMatrix
A new dataset containing `num_examples` examples randomly
drawn (without replacement) from `dataset`
A new dataset containing `num_examples` examples. It is a random subset
of continuous 'num_examples' examples drawn from `dataset`.
"""
try:

Expand Down
15 changes: 12 additions & 3 deletions pylearn2/datasets/tests/test_dense_design_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pylearn2.datasets.dense_design_matrix import DenseDesignMatrix
from pylearn2.datasets.dense_design_matrix import DenseDesignMatrixPyTables
from pylearn2.datasets.dense_design_matrix import DefaultViewConverter
from pylearn2.datasets.dense_design_matrix import from_dataset
from pylearn2.utils import serial


Expand Down Expand Up @@ -79,12 +80,12 @@ def test_pytables():

def test_from_dataset():
"""
Tests whether it supports integer labels.
Tests whether it supports integer labels.
"""
rng = np.random.RandomState([1, 2, 3])
topo_view = rng.randn(12, 2, 2, 3)
topo_view = rng.randn(12, 2, 3, 3)
y = rng.randint(0, 5, 12)

# without y:
d1 = DenseDesignMatrix(topo_view=topo_view)
slice_d = from_dataset(d1, 5)
Expand All @@ -97,3 +98,11 @@ def test_from_dataset():
assert slice_d.X.shape[1] == d2.X.shape[1]
assert slice_d.X.shape[0] == 5
assert slice_d.y.shape[0] == 5

# without topo_view:
x = topo_view.reshape(12, 18)
d3 = DenseDesignMatrix(X=x, y=y)
slice_d = from_dataset(d3, 5)
assert slice_d.X.shape[1] == d3.X.shape[1]
assert slice_d.X.shape[0] == 5
assert slice_d.y.shape[0] == 5

0 comments on commit ad161c0

Please sign in to comment.