Skip to content

Commit

Permalink
solve_gesv() function (untested)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbouilla committed Sep 27, 2023
1 parent 0bd5b41 commit e60de3a
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 5 deletions.
1 change: 1 addition & 0 deletions src/spasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ spasm * spasm_kernel_from_rref(const spasm *R, const int *qinv);

/* spasm_solve.c */
bool spasm_solve(const spasm_lu *fact, const spasm_ZZp *b, spasm_ZZp *x);
spasm * spasm_solve_gesv(const spasm_lu *fact, const spasm *B);

/* spasm_certificate.c */
spasm_rowspan_certificate * spasm_certificate_rowspan_create(const spasm *A, const spasm_lu *fact, u64 seed);
Expand Down
2 changes: 2 additions & 0 deletions src/spasm_echelonize.c
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,8 @@ static void echelonize_dense(const spasm *A, const int *p, int n, const int *p_i
update_U_after_rref(rr, Sm, S, datatype, Sqinv, q, fact);
}

// TODO: test completion and allow early abort

/* move on to the next chunk */
round += 1;
processed += Sn;
Expand Down
43 changes: 38 additions & 5 deletions src/spasm_solve.c
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,43 @@ bool spasm_solve(const spasm_lu *fact, const spasm_ZZp *b, spasm_ZZp *x)
return ok;
}


spasm * spasm_solve_gesv(const spasm_lu *fact, const spasm *b)
/* solve XA == B (returns garbage if a solution does not exist) */
spasm * spasm_solve_gesv(const spasm_lu *fact, const spasm *B)
{
(void) fact;
(void) b;
return NULL;
i64 prime = B->field->p;
assert(prime == fact->L->field->p);
assert(fact->L != NULL);
int n = B->n;
int m = B->m;
int Xm = fact->L->n;
spasm_triplet *X = spasm_triplet_alloc(n, Xm, (i64) Xm * n, prime, true);
int *Xi = X->i;
int *Xj = X->j;
spasm_ZZp *Xx = X->x;

#pragma omp parallel
{
spasm_ZZp *b = spasm_malloc(m * sizeof(*b));
spasm_ZZp *x = spasm_malloc(Xm * sizeof(*x));
#pragma omp for schedule(dynamic)
for (int i = 0; i < n; i++) {
for (int j = 0; j < m; j++)
b[j] = 0;
spasm_scatter(B, i, 1, b);
spasm_solve(fact, b, x);
for (int j = 0; j < Xm; j++)
if (x[j] != 0) {
i64 xnz;
#pragma omp atomic capture
{ xnz = X->nz; X->nz += 1; }
Xi[xnz] = i;
Xj[xnz] = j;
Xx[xnz] = x[j];
}
}
free(b);
}
spasm *XX = spasm_compress(X);
spasm_triplet_free(X);
return XX;
}

0 comments on commit e60de3a

Please sign in to comment.