From 4f18fc0fd68fa43d7d745c9e9bddc3de02fa21e2 Mon Sep 17 00:00:00 2001 From: gbanjac Date: Sat, 5 Oct 2019 16:53:39 +0200 Subject: [PATCH] Added OSQPSettings as argument to init_linsys_solver() --- docs/contributing/index.rst | 4 ++-- include/auxil.h | 3 ++- include/lin_sys.h | 16 +++++++--------- lin_sys/direct/pardiso/pardiso_interface.c | 16 ++++++++-------- lin_sys/direct/pardiso/pardiso_interface.h | 14 +++++++------- lin_sys/direct/qdldl/qdldl_interface.c | 17 +++++++++-------- lin_sys/direct/qdldl/qdldl_interface.h | 14 +++++++------- src/lin_sys.c | 22 +++++++++++----------- src/osqp_api.c | 6 ++---- src/polish.c | 3 +-- tests/solve_linsys/test_solve_linsys.h | 14 ++++++++------ 11 files changed, 64 insertions(+), 65 deletions(-) diff --git a/docs/contributing/index.rst b/docs/contributing/index.rst index 39b5ad277..6bc6ebc1b 100644 --- a/docs/contributing/index.rst +++ b/docs/contributing/index.rst @@ -75,10 +75,10 @@ The linear system solver object is defined in :code:`mysolver.h` as follows }; // Initialize mysolver solver - c_int init_linsys_solver_mysolver(mysolver_solver ** s, const csc * P, const csc * A, c_float sigma, c_float * rho_vec, c_int polish); + c_int init_linsys_solver_mysolver(mysolver_solver ** s, const csc * P, const csc * A, c_float * rho_vec, OSQPSettings *settings, c_int polish); // Solve linear system and store result in b - c_int solve_linsys_mysolver(mysolver_solver * s, c_float * b); + c_int solve_linsys_mysolver(mysolver_solver * s, c_float * b, c_int admm_iter); // Update linear system solver matrices c_int update_linsys_solver_matrices_mysolver(mysolver_solver * s, const csc *P, const csc *A); diff --git a/include/auxil.h b/include/auxil.h index 8f5a555ca..443f014b1 100644 --- a/include/auxil.h +++ b/include/auxil.h @@ -57,7 +57,8 @@ void swap_vectors(OSQPVectorf **a, /** * Update x_tilde and z_tilde variable (first ADMM step) - * @param solver Solver + * @param solver Solver + * @param admm_iter Current ADMM iteration */ void update_xz_tilde(OSQPSolver *solver, c_int admm_iter); diff --git a/include/lin_sys.h b/include/lin_sys.h index 734ce4f57..d3daec97a 100644 --- a/include/lin_sys.h +++ b/include/lin_sys.h @@ -32,20 +32,18 @@ c_int unload_linsys_solver(enum linsys_solver_type linsys_solver); * @param s Pointer to linear system solver structure * @param P Cost function matrix * @param A Constraint matrix - * @param sigma Algorithm parameter * @param rho_vec Algorithm parameter - * @param linsys_solver Linear system solver + * @param settings Solver settings * @param polish 0/1 depending whether we are allocating for *polishing or not * @return Exitflag for error (0 if no errors) */ -c_int init_linsys_solver(LinSysSolver **s, - const OSQPMatrix *P, - const OSQPMatrix *A, - c_float sigma, - const OSQPVectorf *rho_vec, - enum linsys_solver_type linsys_solver, - c_int polish); +c_int init_linsys_solver(LinSysSolver **s, + const OSQPMatrix *P, + const OSQPMatrix *A, + const OSQPVectorf *rho_vec, + OSQPSettings *settings, + c_int polish); # ifdef __cplusplus } diff --git a/lin_sys/direct/pardiso/pardiso_interface.c b/lin_sys/direct/pardiso/pardiso_interface.c index 21f0c88cb..f04c987a9 100644 --- a/lin_sys/direct/pardiso/pardiso_interface.c +++ b/lin_sys/direct/pardiso/pardiso_interface.c @@ -73,12 +73,12 @@ void free_linsys_solver_pardiso(pardiso_solver *s) { // Initialize factorization structure -c_int init_linsys_solver_pardiso(pardiso_solver ** sp, - const OSQPMatrix * P, - const OSQPMatrix * A, - c_float sigma, - const OSQPVectorf * rho_vec, - c_int polish){ +c_int init_linsys_solver_pardiso(pardiso_solver **sp, + const OSQPMatrix *P, + const OSQPMatrix *A, + const OSQPVectorf *rho_vec, + OSQPSettings *settings, + c_int polish) { c_int i; // loop counter c_int nnzKKT; // Number of nonzeros in KKT @@ -86,10 +86,10 @@ c_int init_linsys_solver_pardiso(pardiso_solver ** sp, c_int n_plus_m; // n_plus_m dimension c_float* rhov; //used for direct access to rho_vec data when polish=false + c_float sigma = settings->sigma; // Allocate private structure to store KKT factorization - pardiso_solver *s; - s = c_calloc(1, sizeof(pardiso_solver)); + pardiso_solver *s = c_calloc(1, sizeof(pardiso_solver)); *sp = s; // Size of KKT diff --git a/lin_sys/direct/pardiso/pardiso_interface.h b/lin_sys/direct/pardiso/pardiso_interface.h index ab7d15566..70d42e037 100644 --- a/lin_sys/direct/pardiso/pardiso_interface.h +++ b/lin_sys/direct/pardiso/pardiso_interface.h @@ -81,17 +81,17 @@ struct pardiso { * @param s Pointer to a private structure * @param P Cost function matrix (upper triangular form) * @param A Constraints matrix - * @param sigma Algorithm parameter. If polish, then sigma = delta. * @param rho_vec Algorithm parameter. If polish, then rho_vec = OSQP_NULL. + * @param settings Solver settings * @param polish Flag whether we are initializing for polish or not * @return Exitflag for error (0 if no errors) */ -c_int init_linsys_solver_pardiso(pardiso_solver ** sp, - const OSQPMatrix * P, - const OSQPMatrix * A, - c_float sigma, - const OSQPVectorf * rho_vec, - c_int polish); +c_int init_linsys_solver_pardiso(pardiso_solver **sp, + const OSQPMatrix *P, + const OSQPMatrix *A, + const OSQPVectorf *rho_vec, + OSQPSettings *settings, + c_int polish); /** diff --git a/lin_sys/direct/qdldl/qdldl_interface.c b/lin_sys/direct/qdldl/qdldl_interface.c index 4fe33f58c..e27cbfcd6 100644 --- a/lin_sys/direct/qdldl/qdldl_interface.c +++ b/lin_sys/direct/qdldl/qdldl_interface.c @@ -167,12 +167,12 @@ static c_int permute_KKT(csc ** KKT, qdldl_solver * p, c_int Pnz, c_int Anz, c_i // Initialize LDL Factorization structure -c_int init_linsys_solver_qdldl(qdldl_solver ** sp, - const OSQPMatrix* P, - const OSQPMatrix* A, - c_float sigma, - const OSQPVectorf* rho_vec, - c_int polish){ +c_int init_linsys_solver_qdldl(qdldl_solver **sp, + const OSQPMatrix *P, + const OSQPMatrix *A, + const OSQPVectorf *rho_vec, + OSQPSettings *settings, + c_int polish) { // Define Variables csc * KKT_temp; // Temporary KKT pointer @@ -180,9 +180,10 @@ c_int init_linsys_solver_qdldl(qdldl_solver ** sp, c_int n_plus_m; // Define n_plus_m dimension c_float* rhov; //used for direct access to rho_vec data when polish=false + c_float sigma = settings->sigma; + // Allocate private structure to store KKT factorization - qdldl_solver *s; - s = c_calloc(1, sizeof(qdldl_solver)); + qdldl_solver *s = c_calloc(1, sizeof(qdldl_solver)); *sp = s; // Size of KKT diff --git a/lin_sys/direct/qdldl/qdldl_interface.h b/lin_sys/direct/qdldl/qdldl_interface.h index 6e99c0654..ef28be683 100644 --- a/lin_sys/direct/qdldl/qdldl_interface.h +++ b/lin_sys/direct/qdldl/qdldl_interface.h @@ -85,17 +85,17 @@ struct qdldl { * @param s Pointer to a private structure * @param P Cost function matrix (upper triangular form) * @param A Constraints matrix - * @param sigma Algorithm parameter. If polish, then sigma = delta. * @param rho_vec Algorithm parameter. If polish, then rho_vec = OSQP_NULL. + * @param settings Solver settings * @param polish Flag whether we are initializing for polish or not * @return Exitflag for error (0 if no errors) */ -c_int init_linsys_solver_qdldl(qdldl_solver ** sp, - const OSQPMatrix * P, - const OSQPMatrix * A, - c_float sigma, - const OSQPVectorf* rho_vec, - c_int polish); +c_int init_linsys_solver_qdldl(qdldl_solver **sp, + const OSQPMatrix *P, + const OSQPMatrix *A, + const OSQPVectorf *rho_vec, + OSQPSettings *settings, + c_int polish); /** * Solve linear system and store result in b diff --git a/src/lin_sys.c b/src/lin_sys.c index 68bd07e18..fff88cfe0 100644 --- a/src/lin_sys.c +++ b/src/lin_sys.c @@ -53,23 +53,23 @@ c_int unload_linsys_solver(enum linsys_solver_type linsys_solver) { // Initialize linear system solver structure // NB: Only the upper triangular part of P is filled -c_int init_linsys_solver(LinSysSolver **s, - const OSQPMatrix *P, - const OSQPMatrix *A, - c_float sigma, - const OSQPVectorf *rho_vec, - enum linsys_solver_type linsys_solver, - c_int polish) { - switch (linsys_solver) { +c_int init_linsys_solver(LinSysSolver **s, + const OSQPMatrix *P, + const OSQPMatrix *A, + const OSQPVectorf *rho_vec, + OSQPSettings *settings, + c_int polish) { + + switch (settings->linsys_solver) { case QDLDL_SOLVER: - return init_linsys_solver_qdldl((qdldl_solver **)s, P, A, sigma, rho_vec, polish); + return init_linsys_solver_qdldl((qdldl_solver **)s, P, A, rho_vec, settings, polish); # ifdef ENABLE_MKL_PARDISO case MKL_PARDISO_SOLVER: - return init_linsys_solver_pardiso((pardiso_solver **)s, P, A, sigma, rho_vec, polish); + return init_linsys_solver_pardiso((pardiso_solver **)s, P, A, rho_vec, settings, polish); # endif /* ifdef ENABLE_MKL_PARDISO */ default: // QDLDL - return init_linsys_solver_qdldl((qdldl_solver **)s, P, A, sigma, rho_vec, polish); + return init_linsys_solver_qdldl((qdldl_solver **)s, P, A, rho_vec, settings, polish); } } diff --git a/src/osqp_api.c b/src/osqp_api.c index 51f4c4933..231fc4af6 100644 --- a/src/osqp_api.c +++ b/src/osqp_api.c @@ -241,10 +241,8 @@ c_int osqp_setup(OSQPSolver** solverp, if (load_linsys_solver(settings->linsys_solver)) return osqp_error(OSQP_LINSYS_SOLVER_LOAD_ERROR); // Initialize linear system solver structure - exitflag = init_linsys_solver(&(work->linsys_solver), work->data->P, work->data->A, - settings->sigma, - work->rho_vec, - settings->linsys_solver, 0); + exitflag = init_linsys_solver(&(work->linsys_solver), work->data->P, work->data->A, + work->rho_vec, solver->settings, 0); if (exitflag) { return osqp_error(exitflag); diff --git a/src/polish.c b/src/polish.c index 469836197..6327ee2a0 100644 --- a/src/polish.c +++ b/src/polish.c @@ -229,8 +229,7 @@ c_int polish(OSQPSolver *solver) { // Form and factorize reduced KKT exitflag = init_linsys_solver(&plsh, work->data->P, work->pol->Ared, - settings->delta, OSQP_NULL, - settings->linsys_solver, 1); + OSQP_NULL, settings, 1); if (exitflag) { // Polishing failed diff --git a/tests/solve_linsys/test_solve_linsys.h b/tests/solve_linsys/test_solve_linsys.h index 6d19ad09d..54cbc40a3 100644 --- a/tests/solve_linsys/test_solve_linsys.h +++ b/tests/solve_linsys/test_solve_linsys.h @@ -18,8 +18,9 @@ static const char* test_solveKKT() { solve_linsys_sols_data *data = generate_problem_solve_linsys_sols_data(); // Settings - settings->rho = data->test_solve_KKT_rho; - settings->sigma = data->test_solve_KKT_sigma; + settings->rho = data->test_solve_KKT_rho; + settings->sigma = data->test_solve_KKT_sigma; + settings->linsys_solver = LINSYS_SOLVER; // Set rho_vec m = data->test_solve_KKT_A->m; @@ -32,7 +33,7 @@ static const char* test_solveKKT() { A = OSQPMatrix_new_from_csc(data->test_solve_KKT_A, 0); // Form and factorize KKT matrix - exitflag = init_linsys_solver(&s, Pu, A, settings->sigma, rho_vec, LINSYS_SOLVER, 0); + exitflag = init_linsys_solver(&s, Pu, A, rho_vec, settings, 0); // Solve KKT x = b via LDL given factorization rhs = OSQPVectorf_new(data->test_solve_KKT_rhs, m+n); @@ -68,8 +69,9 @@ static char* test_solveKKT_pardiso() { solve_linsys_sols_data *data = generate_problem_solve_linsys_sols_data(); // Settings - settings->rho = data->test_solve_KKT_rho; - settings->sigma = data->test_solve_KKT_sigma; + settings->rho = data->test_solve_KKT_rho; + settings->sigma = data->test_solve_KKT_sigma; + settings->linsys_solver = MKL_PARDISO_SOLVER; // Set rho_vec m = data->test_solve_KKT_A->m; @@ -87,7 +89,7 @@ static char* test_solveKKT_pardiso() { exitflag == 0); // Form and factorize KKT matrix - exitflag = init_linsys_solver(&s, Pu, A, settings->sigma, rho_vec, MKL_PARDISO_SOLVER, 0); + exitflag = init_linsys_solver(&s, Pu, A, rho_vec, settings, 0); // Solve KKT x = b via LDL given factorization rhs = OSQPVectorf_new(data->test_solve_KKT_rhs, m+n);