Skip to content

Commit

Permalink
testing spasm_gesv
Browse files Browse the repository at this point in the history
  • Loading branch information
cbouilla committed Oct 2, 2023
1 parent 4e95397 commit 5204982
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/spasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ struct spasm_csr * spasm_kernel_from_rref(const struct spasm_csr *R, const int *

/* spasm_solve.c */
bool spasm_solve(const struct spasm_lu *fact, const spasm_ZZp *b, spasm_ZZp *x);
struct spasm_csr * spasm_solve_gesv(const struct spasm_lu *fact, const struct spasm_csr *B);
struct spasm_csr * spasm_gesv(const struct spasm_lu *fact, const struct spasm_csr *B, bool *ok);

/* spasm_certificate.c */
struct spasm_rank_certificate * spasm_certificate_rank_create(const struct spasm_csr *A, const struct spasm_lu *fact, u64 seed);
Expand Down
2 changes: 2 additions & 0 deletions src/spasm_schur.c
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,8 @@ static void * row_pointer(void *A, i64 ldA, spasm_datatype datatype, i64 i)
* q must be preallocated of size at least (m - U->n).
* on output, q sends columns of S to non-pivotal columns of A
* p_out must be of size n, p_int of size A->n
*
* TODO: detect empty rows ; push them to the end.
*/
void spasm_schur_dense(const struct spasm_csr *A, const int *p, int n, const int *p_in,
struct spasm_lu *fact, void *S, spasm_datatype datatype,int *q, int *p_out)
Expand Down
11 changes: 8 additions & 3 deletions src/spasm_solve.c
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@ bool spasm_solve(const struct spasm_lu *fact, const spasm_ZZp *b, spasm_ZZp *x)
return ok;
}

/* solve XA == B (returns garbage if a solution does not exist) */
struct spasm_csr * spasm_solve_gesv(const struct spasm_lu *fact, const struct spasm_csr *B)
/* Solve XA == B (returns garbage if a solution does not exist).
* If ok != NULL, then sets ok[i] == 1 iff xA == B[i] has a solution
*/
struct spasm_csr * spasm_gesv(const struct spasm_lu *fact, const struct spasm_csr *B, bool *ok)
{
i64 prime = B->field->p;
assert(prime == fact->L->field->p);
Expand All @@ -69,7 +71,9 @@ struct spasm_csr * spasm_solve_gesv(const struct spasm_lu *fact, const struct sp
for (int j = 0; j < m; j++)
b[j] = 0;
spasm_scatter(B, i, 1, b);
spasm_solve(fact, b, x);
bool res = spasm_solve(fact, b, x);
if (ok)
ok[i] = res;
for (int j = 0; j < Xm; j++)
if (x[j] != 0) {
i64 xnz;
Expand All @@ -81,6 +85,7 @@ struct spasm_csr * spasm_solve_gesv(const struct spasm_lu *fact, const struct sp
}
}
free(b);
free(x);
}
struct spasm_csr *XX = spasm_compress(X);
spasm_triplet_free(X);
Expand Down
2 changes: 2 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,11 @@ spasm_run_tests_mod(kernel "${ALL_TEST_MATRICES}")

spasm_declare_test(lu)
spasm_declare_test(solve)
spasm_declare_test(gesv)

spasm_run_tests_mod(lu "${ALL_TEST_MATRICES}")
spasm_run_tests_mod(solve "${ALL_TEST_MATRICES}")
spasm_run_tests_mod(gesv "${ALL_TEST_MATRICES}")

########## certificates

Expand Down
79 changes: 79 additions & 0 deletions tests/gesv.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <getopt.h>
#include <err.h>

#include "spasm.h"
#include "test_tools.h"

i64 prime = 42013;

void parse_command_line_options(int argc, char **argv)
{
struct option longopts[] = {
{"modulus", required_argument, NULL, 'p'},
{NULL, 0, NULL, 0}
};
char ch;
while ((ch = getopt_long(argc, argv, "", longopts, NULL)) != -1) {
switch (ch) {
case 'p':
prime = atoll(optarg);
break;
default:
errx(1, "Unknown option\n");
}
}
}

int main(int argc, char **argv)
{
struct spasm_triplet *T = spasm_triplet_load(stdin, prime, NULL);
struct spasm_csr *A = spasm_compress(T);
spasm_triplet_free(T);
int n = A->n;
int m = A->m;

/* compute LU factorization with L matrix */
struct echelonize_opts opts;
spasm_echelonize_init_opts(&opts);
opts.L = 1;
struct spasm_lu *fact = spasm_echelonize(A, &opts);
assert(spasm_factorization_verify(A, fact, 42));

bool *ok = spasm_malloc(n * sizeof(*ok));
struct spasm_csr *X = spasm_gesv(fact, A, ok);
assert(X->n == n);
assert(X->m == n);
for (int i = 0; i < n; i++)
printf("ok[%d] = %d\n", i, ok[i]);

/* check XA == A */
spasm_ZZp *x = spasm_malloc(n * sizeof(*x));
spasm_ZZp *y = spasm_malloc(n * sizeof(*y));
spasm_ZZp *z = spasm_malloc(m * sizeof(*y));
spasm_ZZp *b = spasm_malloc(m * sizeof(*b));

spasm_prng_seed(prime, 0);
for (int i = 0; i < n; i++) {
x[i] = spasm_ZZp_init(A->field, spasm_prng_next());
y[i] = 0;
}
for (int j = 0; j < m; j++) {
z[j] = 0;
b[j] = 0;
}

spasm_xApy(x, A, b);
spasm_xApy(x, X, y);
spasm_xApy(y, A, z);

for (int j = 0; j < m; j++)
if (b[j] != z[j]) {
printf("not ok - gesv solver [incorrect solution found]\n");
exit(1);
}
printf("ok\n");
return 0;
}
1 change: 0 additions & 1 deletion tests/solve.c
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ int main(int argc, char **argv)
struct echelonize_opts opts;
spasm_echelonize_init_opts(&opts);
opts.L = 1;
opts.max_round = 0;
struct spasm_lu *fact = spasm_echelonize(A, &opts);
assert(spasm_factorization_verify(A, fact, 42));

Expand Down
2 changes: 1 addition & 1 deletion tests/sparse_lu_usolve.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ int main(int argc, char **argv)
struct spasm_csr *A = spasm_compress(T);
spasm_triplet_free(T);
int m = A->m;
if (m == 0) {
if (m <= 0) {
printf("SKIP --- empty matrix / useless\n");
exit(EXIT_SUCCESS);
}
Expand Down

0 comments on commit 5204982

Please sign in to comment.