Skip to content

Commit

Permalink
fix lu
Browse files Browse the repository at this point in the history
  • Loading branch information
cbouilla committed Sep 20, 2023
1 parent 0f299af commit a4271ef
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 13 deletions.
11 changes: 6 additions & 5 deletions src/spasm_echelonize.c
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ bool spasm_echelonize_test_completion(const spasm *A, const int *p, int n, spasm


/* not dry w.r.t. spasm_LU() */
void spasm_echelonize_GPLU(const spasm *A, const int *p, int n, spasm_lu *fact, struct echelonize_opts *opts)
void spasm_echelonize_GPLU(const spasm *A, const int *p, int n, spasm_lu *fact, int *p_in, struct echelonize_opts *opts)
{
(void) opts;
assert(p != NULL);
Expand Down Expand Up @@ -123,6 +123,7 @@ void spasm_echelonize_GPLU(const spasm *A, const int *p, int n, spasm_lu *fact,

/* Triangular solve: x * U = A[i] */
int inew = p[i];
int i_orig = (p_in != NULL) ? p_in[inew] : inew;
int top = spasm_sparse_triangular_solve(U, A, inew, xj, x, Uqinv);

/* Find pivot column; current poor strategy= choose leftmost */
Expand All @@ -137,7 +138,7 @@ void spasm_echelonize_GPLU(const spasm *A, const int *p, int n, spasm_lu *fact,
jpiv = j;
} else if (L != NULL) {
/* everything under pivotal columns goes into L */
Li[lnz] = inew;
Li[lnz] = i_orig;
Lj[lnz] = Uqinv[j];
Lx[lnz] = x[j];
lnz += 1;
Expand All @@ -150,8 +151,8 @@ void spasm_echelonize_GPLU(const spasm *A, const int *p, int n, spasm_lu *fact,
/* add entry entry in L for the pivot */
if (L != NULL) {
assert(x[jpiv] != 0);
Lqinv[U->n] = inew;
Li[lnz] = inew;
Lqinv[U->n] = i_orig;
Li[lnz] = i_orig;
Lj[lnz] = U->n;
Lx[lnz] = x[jpiv];
lnz += 1;
Expand Down Expand Up @@ -477,7 +478,7 @@ spasm_lu * spasm_echelonize(const spasm *A, struct echelonize_opts *opts)
else if (opts->enable_dense && density > opts->sparsity_threshold)
spasm_echelonize_dense(A, p + npiv, n - npiv, U, Uqinv, opts);
else if (opts->enable_GPLU)
spasm_echelonize_GPLU(A, p + npiv, n - npiv, fact, opts);
spasm_echelonize_GPLU(A, p + npiv, n - npiv, fact, p_in, opts);
else
fprintf(stderr, "[echelonize] Cannot finish (no valid method enabled). Incomplete echelonization returned\n");

Expand Down
8 changes: 3 additions & 5 deletions src/spasm_schur.c
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,6 @@ spasm *spasm_schur(const spasm *A, const int *p, int n, const spasm_lu *fact,
if (est_density < 0)
est_density = spasm_schur_estimate_density(A, p, n, fact->U, qinv, 100);
long long size = (est_density * n) * m;
if (size > 2147483648)
errx(1, "Matrix too large (more than 2^31 entries)");
i64 prime = spasm_get_prime(A);
spasm *S = spasm_csr_alloc(n, m, size, prime, SPASM_WITH_NUMERICAL_VALUES);
i64 *Sp = S->p;
Expand Down Expand Up @@ -151,9 +149,9 @@ spasm *spasm_schur(const spasm *A, const int *p, int n, const spasm_lu *fact,
}

/* write the new row in L / S */
int i_out = (p_in != NULL) ? p_in[inew] : inew;
int i_orig = (p_in != NULL) ? p_in[inew] : inew;
if (p_out != NULL)
p_out[local_i] = i_out;
p_out[local_i] = i_orig;

for (int px = top; px < m; px++) {
int j = xj[px];
Expand All @@ -164,7 +162,7 @@ spasm *spasm_schur(const spasm *A, const int *p, int n, const spasm_lu *fact,
Sx[local_snz] = x[j];
local_snz += 1;
} else if (L != NULL) {
Li[local_lnz] = i_out;
Li[local_lnz] = i_orig;
Lj[local_lnz] = qinv[j];
Lx[local_lnz] = x[j];
// fprintf(stderr, "Adding L[%d, %d] = %d\n", i_out, qinv[j], x[j]);
Expand Down
6 changes: 3 additions & 3 deletions tests/lu.c
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ int main(int argc, char **argv)
struct echelonize_opts opts;
spasm_echelonize_init_opts(&opts);
opts.L = 1;
opts.enable_greedy_pivot_search = 0;
spasm_lu *fact = spasm_echelonize(A, &opts);
int r = fact->U->n;
assert(fact->L != NULL);
Expand Down Expand Up @@ -79,7 +78,8 @@ int main(int argc, char **argv)
y[j] = 0;
v[j] = 0;
}
// printf("###################### i=%d\n", i);
printf("\ri=%d / %d\n", i, n);
fflush(stdout);
x[i] = 1;

spasm_xApy(x, A, y); // y <- x*A
Expand All @@ -89,7 +89,7 @@ int main(int argc, char **argv)
for (int j = 0; j < m; j++) {
// printf("# x*A[%4d] = %8d VS x*L[%4d] = %8d VS x*LU[%4d] = %8d\n", j, y[j], j, u[j], j, v[j]);
if (y[j] != v[j])
printf("mismatch on row %d (pivotal=%d), column %d (pivotal=%d)\n",
printf("\nmismatch on row %d (pivotal=%d), column %d (pivotal=%d)\n",
i, pivotal_row[i], j, pivotal_col[j]);
assert(y[j] == v[j]);
}
Expand Down

0 comments on commit a4271ef

Please sign in to comment.