Skip to content

Commit

Permalink
backend: fix likelihood datatype to be float, added test for trace lo…
Browse files Browse the repository at this point in the history
…ading
  • Loading branch information
hvasbath committed Mar 20, 2019
1 parent cefc3f4 commit 022c809
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 10 deletions.
3 changes: 1 addition & 2 deletions beat/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,8 +578,7 @@ def construct_data_structure(self):

# creating data type as float
data_types = ['f8'] * len(self.varnames)
# last must be integer
data_types[-1] = 'i4'

# get the size of each array within varnames
data_size = ["{}".format(
len(self.flat_names[name])) for name in self.varnames]
Expand Down
Binary file removed test/PT_TEST/chain-1.bin
Binary file not shown.
3 changes: 0 additions & 3 deletions test/PT_TEST/chain-1.csv

This file was deleted.

21 changes: 16 additions & 5 deletions test/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@ def two_gaussians(x):

# create data.
chain_data = num.arange(number_of_parameters).astype(num.float)
chain_like = num.array([10])
chain_like = num.array([10.]).astype(num.float)
self.lpoint = [chain_data, chain_like]
self.data_size = 2
self.sample_size = 5
self.data = []
self.expected_chain_data = []
self.expected_chain_like = []
for i in range(self.data_size):
for i in range(self.sample_size):
self.data.append(self.lpoint)
self.expected_chain_data.append(chain_data)
self.expected_chain_like.append(chain_like)
Expand Down Expand Up @@ -87,7 +87,7 @@ def test_text_chain(self):
self.assertEqual(chain_at[self.data_keys[0]].all(), self.expected_chain_data[data_index].all())
self.assertEqual(chain_at[self.data_keys[1]].all(), self.expected_chain_like[data_index].all())

def test_text_chain_bin(self):
def test_chain_bin(self):

numpy_chain = NumpyChain(dir_path=self.test_dir_path, model=self.PT_test)
numpy_chain.setup(10, 1, overwrite=True)
Expand All @@ -106,7 +106,7 @@ def test_text_chain_bin(self):

chain_data = numpy_chain.get_values(self.data_keys[0])
chain_like = numpy_chain.get_values(self.data_keys[1])
# print("Data: ", chain_data)
print("Data: ", chain_data)
# print("Var shapes: ", numpy_chain.var_shapes)
# print("flat names: ", numpy_chain.flat_names)
# print("Var names: ", numpy_chain.varnames)
Expand All @@ -117,3 +117,14 @@ def test_text_chain_bin(self):
self.assertEqual(chain_like.all(), self.expected_chain_like.all())
self.assertEqual(chain_at[self.data_keys[0]].all(), self.expected_chain_data[data_index].all())
self.assertEqual(chain_at[self.data_keys[1]].all(), self.expected_chain_like[data_index].all())

def test_load_bin_chain(self):
numpy_chain = NumpyChain(dir_path=self.test_dir_path, model=self.PT_test)
numpy_chain.setup(5, 1)
chain_at = numpy_chain.point(1)
print(chain_at)


#def tearDown(self):
# import shutil
# shutil.rmtree(self.test_dir_path)

0 comments on commit 022c809

Please sign in to comment.