Skip to content

Commit

Permalink
refactored Lqinv --> p
Browse files Browse the repository at this point in the history
  • Loading branch information
cbouilla committed Oct 2, 2023
1 parent d191d8e commit e8ee3f6
Show file tree
Hide file tree
Showing 10 changed files with 29 additions and 80 deletions.
2 changes: 1 addition & 1 deletion src/spasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ struct spasm_lu { /* a PLUQ factorisation */
struct spasm_csr *L;
struct spasm_csr *U;
int *Uqinv; /* locate pivots in U (on column j, row Uqinv[j]) */
int *Lqinv; /* locate pivots in L (on column j, row Lqinv[j]) */
int *p; /* locate pivots in L (on column j, row p[j]) */
struct spasm_triplet *Ltmp; /* for internal use during the factorization */
};

Expand Down
6 changes: 3 additions & 3 deletions src/spasm_certificate.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ struct spasm_rank_certificate * spasm_certificate_rank_create(const struct spasm

/* write i / j indices (positions of pivots) */
for (int k = 0; k < r; k++)
ii[k] = fact->Lqinv[k];
ii[k] = fact->p[k];
int k = 0;
for (int j = 0; j < m; j++)
if (fact->Uqinv[j] >= 0) {
Expand Down Expand Up @@ -149,7 +149,7 @@ bool spasm_factorization_verify(const struct spasm_csr *A, const struct spasm_lu
const struct spasm_csr *U = fact->U;
const struct spasm_csr *L = fact->L;
const int *Uqinv = fact->Uqinv;
const int *Lqinv = fact->Lqinv;
const int *Lp = fact->p;

int n = A->n;
int m = A->m;
Expand All @@ -165,7 +165,7 @@ bool spasm_factorization_verify(const struct spasm_csr *A, const struct spasm_lu
for (int i = 0; i < n; i++)
pivotal_row[i] = 0;
for (int j = 0; j < r; j++) {
int i = Lqinv[j];
int i = Lp[j];
assert(i >= 0);
pivotal_row[i] = 1;
}
Expand Down
18 changes: 9 additions & 9 deletions src/spasm_echelonize.c
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ static void echelonize_GPLU(const struct spasm_csr *A, const int *p, int n, cons
i64 *Up = U->p;
i64 unz = spasm_nnz(U);
i64 lnz = (L != NULL) ? L->nz : 0;
int *Lqinv = fact->Lqinv;
int *Lp = fact->p;

/* initialize early abort */
int rows_since_last_pivot = 0;
Expand Down Expand Up @@ -137,7 +137,7 @@ static void echelonize_GPLU(const struct spasm_csr *A, const int *p, int n, cons
/* add entry entry in L for the pivot */
if (L != NULL) {
assert(x[jpiv] != 0);
Lqinv[U->n] = i_orig;
Lp[U->n] = i_orig;
Li[lnz] = i_orig;
Lj[lnz] = U->n;
Lx[lnz] = x[jpiv];
Expand Down Expand Up @@ -233,7 +233,7 @@ static void update_fact_after_LU(int n, int Sm, int r, const void *S, spasm_data
struct spasm_csr *U = fact->U;
struct spasm_triplet *L = fact->Ltmp;
int *Uqinv = fact->Uqinv;
int *Lqinv = fact->Lqinv;
int *Lp = fact->p;
i64 extra_unz = ((i64) (1 + 2*Sm - r)) * r; /* maximum size increase */
i64 extra_lnz = ((i64) (2*n - r + 1)) * r / 2;
i64 unz = spasm_nnz(U);
Expand Down Expand Up @@ -284,7 +284,7 @@ static void update_fact_after_LU(int n, int Sm, int r, const void *S, spasm_data
lnz += 1;
}
if (i < r) /* register pivot */
Lqinv[U->n + i] = iorig;
Lp[U->n + i] = iorig;
}
L->nz = lnz;

Expand Down Expand Up @@ -495,18 +495,18 @@ struct spasm_lu * spasm_echelonize(const struct spasm_csr *A, struct echelonize_
Uqinv[j] = -1;

struct spasm_triplet *L = NULL;
int *Lqinv = NULL;
int *Lp = NULL;
if (opts->L) {
L = spasm_triplet_alloc(n, n, spasm_nnz(A), prime, true);
Lqinv = spasm_malloc(n * sizeof(*Lqinv));
Lp = spasm_malloc(n * sizeof(*Lp));
for (int j = 0; j < n; j++)
Lqinv[j] = -1;
Lp[j] = -1;
assert(L->x != NULL);
}

struct spasm_lu *fact = spasm_malloc(sizeof(*fact));
fact->L = NULL;
fact->Lqinv = Lqinv;
fact->p = Lp;
fact->U = U;
fact->Uqinv = Uqinv;
fact->Ltmp = L;
Expand Down Expand Up @@ -600,7 +600,7 @@ struct spasm_lu * spasm_echelonize(const struct spasm_csr *A, struct echelonize_
spasm_csr_realloc(U, -1);
if (opts->L) {
L->m = U->n;
fact->Lqinv = spasm_realloc(Lqinv, U->n * sizeof(*Lqinv));
fact->p = spasm_realloc(Lp, U->n * sizeof(*Lp));
fact->L = spasm_compress(L);
spasm_triplet_free(L);
fact->Ltmp = NULL;
Expand Down
4 changes: 2 additions & 2 deletions src/spasm_pivots.c
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ int spasm_pivots_extract_structural(const struct spasm_csr *A, const int *p_in,
struct spasm_csr *U = fact->U;
struct spasm_triplet *L = fact->Ltmp;
int *Uqinv = fact->Uqinv;
int *Lqinv = fact->Lqinv;
int *Lp = fact->p;
i64 pivot_nnz = 0;
for (int k = 0; k < npiv; k++) {
int i = p[k];
Expand Down Expand Up @@ -418,7 +418,7 @@ int spasm_pivots_extract_structural(const struct spasm_csr *A, const int *p_in,
int i_out = (p_in != NULL) ? p_in[i] : i;
spasm_add_entry(L, i_out, U->n, pivot);
// fprintf(stderr, "Adding L[%d, %d] = %d\n", i_out, U->n, pivot);
Lqinv[U->n] = i_out;
Lp[U->n] = i_out;
}

/* make pivot unitary and add it first */
Expand Down
2 changes: 1 addition & 1 deletion src/spasm_solve.c
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ bool spasm_solve(const struct spasm_lu *fact, const spasm_ZZp *b, spasm_ZZp *x)
bool ok = spasm_dense_forward_solve(U, y, z, Uq);

/* y.LU = b */
spasm_dense_back_solve(L, z, x, fact->Lqinv);
spasm_dense_back_solve(L, z, x, fact->p);

free(y);
free(z);
Expand Down
2 changes: 1 addition & 1 deletion src/spasm_util.c
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ void spasm_dm_free(struct spasm_dm *P)
void spasm_lu_free(struct spasm_lu *N)
{
free(N->Uqinv);
free(N->Lqinv);
free(N->p);
spasm_csr_free(N->U);
spasm_csr_free(N->L);
free(N);
Expand Down
4 changes: 2 additions & 2 deletions tests/dense_lu_ffpack.c
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,10 @@ int main(int argc, char **argv)
fact.U = U;
fact.L = spasm_compress(L);
fact.Uqinv = spasm_malloc(m * sizeof(int));
fact.Lqinv = spasm_malloc(n * sizeof(int));
fact.p = spasm_malloc(n * sizeof(int));

for (int j = 0; j < n; j++)
fact.Lqinv[j] = P[j];
fact.p[j] = P[j];
for (int j = 0; j < m; j++)
fact.Uqinv[j] = Qinv[j];

Expand Down
4 changes: 2 additions & 2 deletions tests/lu.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ int main(int argc, char **argv)
struct spasm_lu *fact = spasm_echelonize(A, &opts);
int r = fact->r;
assert(fact->L != NULL);
assert(r == 0 || fact->Lqinv != NULL);
assert(r == 0 || fact->p != NULL);
assert(fact->Ltmp == NULL);
assert(r == fact->L->m);
assert(r == fact->U->n);
Expand All @@ -58,7 +58,7 @@ int main(int argc, char **argv)
for (int j = 0; j < m; j++)
pivotal_col[j] = 0;
for (int k = 0; k < r; k++) {
int i = fact->Lqinv[k];
int i = fact->p[k];
pivotal_row[i] = 1;
}
for (int j = 0; j < m; j++) {
Expand Down
2 changes: 1 addition & 1 deletion tests/schur.c
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ int main(int argc, char **argv)
fact.Uqinv = qinv;
fact.L = NULL;
fact.Ltmp = NULL;
fact.Lqinv = NULL;
fact.p = NULL;

int npiv = spasm_pivots_extract_structural(A, NULL, &fact, p, &opts);
struct spasm_csr *S = spasm_schur(A, p + npiv, n - npiv, &fact, -1, NULL, NULL, NULL);
Expand Down
65 changes: 7 additions & 58 deletions tests/schur_dense.c
Original file line number Diff line number Diff line change
Expand Up @@ -39,45 +39,28 @@ int main(int argc, char **argv)
struct echelonize_opts opts;
spasm_echelonize_init_opts(&opts);

struct spasm_lu fact;
int *p = spasm_malloc(n * sizeof(*p));
int *Uqinv = spasm_malloc(m * sizeof(*Uqinv));
int *Lqinv = spasm_malloc(n * sizeof(*Lqinv));
struct spasm_csr *U = spasm_csr_alloc(n, m, spasm_nnz(A), prime, true);
U->n = 0;
for (int j = 0; j < m; j++)
Uqinv[j] = -1;
fact.p = spasm_malloc(n * sizeof(*fact.p));
for (int i = 0; i < n; i++)
Lqinv[i] = -1;
fact.p[i] = -1;

struct spasm_lu fact;
fact.U = U;
fact.Uqinv = Uqinv;
fact.L = NULL;
fact.Lqinv = Lqinv;
fact.Ltmp = spasm_triplet_alloc(n, n, spasm_nnz(A), prime, true);

/* find pivots, copy to U, update L */
int npiv = spasm_pivots_extract_structural(A, NULL, &fact, p, &opts);

/* dump pivots */
spasm_ZZp *y = spasm_malloc(m * sizeof(*y));
// for (int j = 0; j < m; j++) {
// int i = Uqinv[j];
// if (i >= 0) {
// fprintf(stderr, "U[%d] eliminates column %d\n", i, j);
// for (int j = 0; j < m; j++)
// y[j] = 0;
// spasm_scatter(U, i, 1, y);
//
// fprintf(stderr, "U[%d] == (", i);
// for (int j = 0; j < m; j++)
// fprintf(stderr, "%s%d", (j == 0) ? "" : ", ", y[j]);
// fprintf(stderr, ")\n");
// }
// }
// for (int k = 0; k < npiv; k++)
// fprintf(stderr, "pivot on row A[%d]\n", p[k]);


/* compute dense schur complement w.r.t said pivots */
int Sm = m - npiv;
int Sn = n - npiv;
Expand All @@ -89,32 +72,17 @@ int main(int argc, char **argv)
size_t *Sqinv = spasm_malloc(Sm * sizeof(*Sqinv)); /* for FFPACK */
size_t *Sp = spasm_malloc(Sn * sizeof(*Sp)); /* for FFPACK */
spasm_schur_dense(A, p + npiv, Sn, NULL, &fact, S, datatype, q, p_out);
// for (int i = 0; i < fact.Ltmp->nz; i++)
// printf("Ltmp : (%d, %d, %d)\n", fact.Ltmp->i[i], fact.Ltmp->j[i], fact.Ltmp->x[i]);
struct spasm_csr *L = spasm_compress(fact.Ltmp);
i64 *Lp = L->p;
int *Lj = L->j;
spasm_ZZp *Lx = L->x;
assert(L->n == n);
// printf("========================== L\n");
// spasm_save_csr(stdout, L);
// printf("========================== L\n");


/* verify result */
for (int k = 0; k < Sn; k++) {
int i = p[npiv + k];
assert(p_out[k] == i);

/* display A[i] */
// printf("**************** processing A[%d]\n", i);
// for (int j = 0; j < m; j++)
// y[j] = 0;
// spasm_scatter(A, i, 1, y);
// fprintf(stderr, "A[%d] == (", i);
// for (int j = 0; j < m; j++)
// fprintf(stderr, "%s%d", (j == 0) ? "" : ", ", y[j]);
// fprintf(stderr, ")\n");


/* start from dense row */
for (int j = 0; j < m; j++)
y[j] = 0;
Expand All @@ -123,29 +91,10 @@ int main(int argc, char **argv)
y[j] = spasm_datatype_read(S, k*Sm + l, datatype);
}

// fprintf(stderr, "A[%d] == (", i);
// for (int j = 0; j < m; j++)
// fprintf(stderr, "%s%d", (j == 0) ? "" : ", ", y[j]);
// fprintf(stderr, ")\n");

/* add contribution from L */
for (int px = Lp[i]; px < Lp[i + 1]; px++) {
// fprintf(stderr, "Adding %d x U[%d]\n", Lx[px], Lj[px]);
for (int px = Lp[i]; px < Lp[i + 1]; px++)
spasm_scatter(U, Lj[px], Lx[px], y);

// fprintf(stderr, "A[%d] == (", i);
// for (int j = 0; j < m; j++)
// fprintf(stderr, "%s%d", (j == 0) ? "" : ", ", y[j]);
// fprintf(stderr, ")\n");
}

/* this should be A[i] */
// fprintf(stderr, "final check\n");
// fprintf(stderr, "A[%d] == (", i);
// for (int j = 0; j < m; j++)
// fprintf(stderr, "%s%d", (j == 0) ? "" : ", ", y[j]);
// fprintf(stderr, ")\n");

spasm_scatter(A, i, -1, y);
for (int j = 0; j < m; j++)
assert(y[j] == 0);
Expand Down

0 comments on commit e8ee3f6

Please sign in to comment.