-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
251 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,34 @@ | ||
#include <catch2/catch_test_macros.hpp> | ||
|
||
#include <cstddef> | ||
#include <dmt/fdmt.hpp> | ||
|
||
TEST_CASE("make_vector", "[make_vector]") { | ||
REQUIRE(make_vector<int>(0) == std::vector<int>{}); | ||
REQUIRE(make_vector<int>(5) == std::vector<int>{0, 1, 2, 3, 4}); | ||
REQUIRE(make_vector<short>(3) == std::vector<short>{0, 1, 2}); | ||
REQUIRE(make_vector<std::size_t>(4) == std::vector<std::size_t>{0, 1, 2, 3}); | ||
TEST_CASE("FDMT class tests", "[fdmt]") { | ||
SECTION("Test case 1: Constructor and getter methods") { | ||
FDMT fdmt(1000.0F, 1500.0F, 500, 1024, 0.001F, 512, 1, 0); | ||
REQUIRE(fdmt.get_df() == 1.0F); | ||
REQUIRE(fdmt.get_correction() == 0.5F); | ||
REQUIRE(fdmt.get_dt_grid_final().size() == 513); | ||
REQUIRE(fdmt.get_niters() == 9); | ||
REQUIRE(fdmt.get_plan().df_top.size() == 10); | ||
REQUIRE(fdmt.get_plan().df_bot.size() == 10); | ||
REQUIRE(fdmt.get_plan().state_shape.size() == 10); | ||
REQUIRE(fdmt.get_plan().sub_plan.size() == 10); | ||
} | ||
SECTION("Test case 2: initialise method") { | ||
FDMT fdmt(1000.0F, 1500.0F, 500, 1024, 0.001F, 512, 1, 0); | ||
std::vector<float> waterfall(500 * 1024, 1.0f); | ||
const size_t dt_init_size = fdmt.get_dt_grid_init().size(); | ||
std::vector<float> state(500 * 1024 * dt_init_size, 0.0f); | ||
REQUIRE_NOTHROW(fdmt.initialise(waterfall.data(), state.data())); | ||
} | ||
|
||
SECTION("Test case 3: execute method") { | ||
FDMT fdmt(1000.0F, 1500.0F, 500, 1024, 0.001F, 512, 1, 0); | ||
std::vector<float> waterfall(500 * 1024, 1.0f); | ||
const size_t dt_final_size = fdmt.get_dt_grid_final().size(); | ||
std::vector<float> dmt(dt_final_size * 1024, 0.0f); | ||
REQUIRE_NOTHROW(fdmt.execute(waterfall.data(), waterfall.size(), | ||
dmt.data(), dmt.size())); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
#include <catch2/catch_test_macros.hpp> | ||
|
||
#include <dmt/fdmt_utils.hpp> | ||
|
||
TEST_CASE("cff", "[fdmt_utils]") { | ||
REQUIRE(fdmt::cff(1000.0F, 1500.0F, 1000.0F, 1500.0F) == 1.0F); | ||
REQUIRE(fdmt::cff(1500.0F, 1000.0F, 1500.0F, 1000.0F) == 1.0F); | ||
REQUIRE(fdmt::cff(1000.0F, 1000.0F, 1000.0F, 1500.0F) == 0.0F); | ||
} | ||
|
||
TEST_CASE("calculate_dt_sub", "[fdmt_utils]") { | ||
REQUIRE(fdmt::calculate_dt_sub(1000.0F, 1500.0F, 1000.0F, 1500.0F, 100) | ||
== 100); | ||
REQUIRE(fdmt::calculate_dt_sub(1000.0F, 1500.0F, 1000.0F, 1500.0F, 0) == 0); | ||
} | ||
|
||
TEST_CASE("calculate_dt_grid_sub", "[fdmt_utils]") { | ||
SECTION("Test case 1: only dt_max") { | ||
const size_t dt_max = 512; | ||
const size_t dt_step = 1; | ||
const size_t dt_min = 0; | ||
auto dt_grid = fdmt::calculate_dt_grid_sub( | ||
1000.0F, 1500.0F, 1000.0F, 1500.0F, dt_max, dt_step, dt_min); | ||
REQUIRE(dt_grid.size() == dt_max - dt_min + 1); | ||
REQUIRE(dt_grid[0] == dt_min); | ||
REQUIRE(dt_grid[512] == dt_max); | ||
} | ||
SECTION("Test case 2: dt_max and dt_min") { | ||
const size_t dt_max = 512; | ||
const size_t dt_step = 1; | ||
const size_t dt_min = 100; | ||
auto dt_grid = fdmt::calculate_dt_grid_sub( | ||
1000.0F, 1500.0F, 1000.0F, 1500.0F, dt_max, dt_step, dt_min); | ||
REQUIRE(dt_grid.size() == dt_max - dt_min + 1); | ||
REQUIRE(dt_grid[0] == dt_min); | ||
REQUIRE(dt_grid[412] == dt_max); | ||
} | ||
} | ||
|
||
TEST_CASE("add_offset_kernel", "[fdmt_utils]") { | ||
SECTION("Test case 1: Valid input and output vectors") { | ||
std::vector<float> arr1 = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; | ||
std::vector<float> arr2 = {6.0f, 7.0f, 8.0f}; | ||
std::vector<float> arr_out(8, 0.0f); | ||
size_t offset = 2; | ||
REQUIRE_NOTHROW(fdmt::add_offset_kernel(arr1.data(), arr1.size(), | ||
arr2.data(), arr_out.data(), | ||
arr_out.size(), offset)); | ||
std::vector<float> expected_output | ||
= {1.0f, 2.0f, 9.0f, 11.0f, 13.0f, 0.0f, 0.0f, 0.0f}; | ||
REQUIRE(arr_out == expected_output); | ||
} | ||
SECTION("Test case 2: Output size less than input size") { | ||
std::vector<float> arr1 = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; | ||
std::vector<float> arr2 = {6.0f, 7.0f, 8.0f}; | ||
std::vector<float> arr_out(4, 0.0f); | ||
size_t offset = 2; | ||
REQUIRE_THROWS_AS(fdmt::add_offset_kernel(arr1.data(), arr1.size(), | ||
arr2.data(), arr_out.data(), | ||
arr_out.size(), offset), | ||
std::runtime_error); | ||
} | ||
|
||
SECTION("Test case 3: Offset greater than input size") { | ||
std::vector<float> arr1 = {1.0f, 2.0f, 3.0f}; | ||
std::vector<float> arr2 = {4.0f, 5.0f}; | ||
std::vector<float> arr_out(5, 0.0f); | ||
size_t offset = 4; | ||
REQUIRE_THROWS_AS(fdmt::add_offset_kernel(arr1.data(), arr1.size(), | ||
arr2.data(), arr_out.data(), | ||
arr_out.size(), offset), | ||
std::runtime_error); | ||
} | ||
SECTION("Test case 4: Empty input vectors") { | ||
std::vector<float> arr1; | ||
std::vector<float> arr2; | ||
std::vector<float> arr_out(3, 0.0f); | ||
size_t offset = 0; | ||
REQUIRE_THROWS_AS(fdmt::add_offset_kernel(arr1.data(), arr1.size(), | ||
arr2.data(), arr_out.data(), | ||
arr_out.size(), offset), | ||
std::runtime_error); | ||
} | ||
} | ||
|
||
TEST_CASE("copy_kernel", "[fdmt_utils]") { | ||
SECTION("Test case 1: Valid input and output vectors") { | ||
std::vector<float> arr1 = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; | ||
std::vector<float> arr_out(10, 0.0f); | ||
; | ||
REQUIRE_NOTHROW(fdmt::copy_kernel(arr1.data(), arr1.size(), | ||
arr_out.data(), arr_out.size())); | ||
for (size_t i = 0; i < arr1.size(); ++i) { | ||
REQUIRE(arr_out[i] == arr1[i]); | ||
} | ||
for (size_t i = arr1.size(); i < arr_out.size(); ++i) { | ||
REQUIRE(arr_out[i] == 0.0f); | ||
} | ||
} | ||
SECTION("Test case 2: Output size less than input size") { | ||
std::vector<float> arr1 = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; | ||
std::vector<float> arr_out(3, 0.0f); | ||
REQUIRE_THROWS_AS(fdmt::copy_kernel(arr1.data(), arr1.size(), | ||
arr_out.data(), arr_out.size()), | ||
std::runtime_error); | ||
} | ||
SECTION("Test case 4: Empty input vector") { | ||
std::vector<float> arr1; | ||
std::vector<float> arr_out(5, 0.0f); | ||
REQUIRE_NOTHROW(fdmt::copy_kernel(arr1.data(), arr1.size(), | ||
arr_out.data(), arr_out.size())); | ||
for (size_t i = 0; i < arr_out.size(); ++i) { | ||
REQUIRE(arr_out[i] == 0.0f); | ||
} | ||
} | ||
} | ||
|
||
TEST_CASE("find_closest_index", "[fdmt_utils]") { | ||
SECTION("Test case 1: Empty array") { | ||
std::vector<size_t> arr_sorted; | ||
REQUIRE_THROWS_AS(fdmt::find_closest_index(arr_sorted, 10), | ||
std::runtime_error); | ||
} | ||
|
||
SECTION("Test case 2: Array with one element - exact match") { | ||
std::vector<size_t> arr_sorted{10}; | ||
size_t val = 10; | ||
size_t expected = 0; | ||
size_t result = fdmt::find_closest_index(arr_sorted, val); | ||
REQUIRE(result == expected); | ||
} | ||
|
||
SECTION("Test case 3: Array with one element - closest match") { | ||
std::vector<size_t> arr_sorted{10}; | ||
size_t val = 15; | ||
size_t expected = 0; | ||
size_t result = fdmt::find_closest_index(arr_sorted, val); | ||
REQUIRE(result == expected); | ||
} | ||
|
||
SECTION("Test case 4: Array with multiple elements - exact match") { | ||
std::vector<size_t> arr_sorted{10, 20, 30, 40, 50}; | ||
size_t val = 30; | ||
size_t expected = 2; | ||
size_t result = fdmt::find_closest_index(arr_sorted, val); | ||
REQUIRE(result == expected); | ||
} | ||
|
||
SECTION( | ||
"Test case 5: Array with multiple elements - closest match (lower)") { | ||
std::vector<size_t> arr_sorted{10, 20, 30, 40, 50}; | ||
size_t val = 24; | ||
size_t expected = 1; | ||
size_t result = fdmt::find_closest_index(arr_sorted, val); | ||
REQUIRE(result == expected); | ||
} | ||
|
||
SECTION( | ||
"Test case 6: Array with multiple elements - closest match (upper)") { | ||
std::vector<size_t> arr_sorted{10, 20, 30, 40, 50}; | ||
size_t val = 26; | ||
size_t expected = 2; | ||
size_t result = fdmt::find_closest_index(arr_sorted, val); | ||
REQUIRE(result == expected); | ||
} | ||
|
||
SECTION("Test case 7: Array with multiple elements - value smaller than " | ||
"all elements") { | ||
std::vector<size_t> arr_sorted{10, 20, 30, 40, 50}; | ||
size_t val = 5; | ||
size_t expected = 0; | ||
size_t result = fdmt::find_closest_index(arr_sorted, val); | ||
REQUIRE(result == expected); | ||
} | ||
|
||
SECTION("Test case 8: Array with multiple elements - value larger than all " | ||
"elements") { | ||
std::vector<size_t> arr_sorted{10, 20, 30, 40, 50}; | ||
size_t val = 60; | ||
size_t expected = 4; | ||
size_t result = fdmt::find_closest_index(arr_sorted, val); | ||
REQUIRE(result == expected); | ||
} | ||
|
||
SECTION("Test case 9: Array with multiple elements - duplicate values") { | ||
std::vector<size_t> arr_sorted{10, 20, 20, 30, 40, 50}; | ||
size_t val = 20; | ||
size_t expected = 1; | ||
size_t result = fdmt::find_closest_index(arr_sorted, val); | ||
REQUIRE(result == expected); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import numpy as np | ||
from dmt import libdmt | ||
|
||
|
||
class TestFDMT: | ||
def test_initialise_ones(self) -> None: | ||
nchans = 500 | ||
nsamples = 1024 | ||
dt_max = 512 | ||
thefdmt = libdmt.FDMT(1000, 1500, nchans, nsamples, 0.001, dt_max) | ||
waterfall = np.ones((nchans, nsamples), dtype=np.float32) | ||
thefdmt_init = thefdmt.initialise(waterfall) | ||
np.testing.assert_equal( | ||
thefdmt_init.shape, | ||
(nchans, thefdmt.dt_grid_init.size, nsamples), | ||
) | ||
""" | ||
np.testing.assert_equal( | ||
thefdmt_init, | ||
np.ones((nchans, thefdmt.dt_grid_init.size, nsamples), dtype=np.float32), | ||
) | ||
""" |