From 4f18fc0fd68fa43d7d745c9e9bddc3de02fa21e2 Mon Sep 17 00:00:00 2001
From: gbanjac
Date: Sat, 5 Oct 2019 16:53:39 +0200
Subject: [PATCH] Added OSQPSettings as argument to init_linsys_solver()
---
docs/contributing/index.rst | 4 ++--
include/auxil.h | 3 ++-
include/lin_sys.h | 16 +++++++---------
lin_sys/direct/pardiso/pardiso_interface.c | 16 ++++++++--------
lin_sys/direct/pardiso/pardiso_interface.h | 14 +++++++-------
lin_sys/direct/qdldl/qdldl_interface.c | 17 +++++++++--------
lin_sys/direct/qdldl/qdldl_interface.h | 14 +++++++-------
src/lin_sys.c | 22 +++++++++++-----------
src/osqp_api.c | 6 ++----
src/polish.c | 3 +--
tests/solve_linsys/test_solve_linsys.h | 14 ++++++++------
11 files changed, 64 insertions(+), 65 deletions(-)
diff --git a/docs/contributing/index.rst b/docs/contributing/index.rst
index 39b5ad277..6bc6ebc1b 100644
--- a/docs/contributing/index.rst
+++ b/docs/contributing/index.rst
@@ -75,10 +75,10 @@ The linear system solver object is defined in :code:`mysolver.h` as follows
};
// Initialize mysolver solver
- c_int init_linsys_solver_mysolver(mysolver_solver ** s, const csc * P, const csc * A, c_float sigma, c_float * rho_vec, c_int polish);
+ c_int init_linsys_solver_mysolver(mysolver_solver ** s, const csc * P, const csc * A, c_float * rho_vec, OSQPSettings *settings, c_int polish);
// Solve linear system and store result in b
- c_int solve_linsys_mysolver(mysolver_solver * s, c_float * b);
+ c_int solve_linsys_mysolver(mysolver_solver * s, c_float * b, c_int admm_iter);
// Update linear system solver matrices
c_int update_linsys_solver_matrices_mysolver(mysolver_solver * s, const csc *P, const csc *A);
diff --git a/include/auxil.h b/include/auxil.h
index 8f5a555ca..443f014b1 100644
--- a/include/auxil.h
+++ b/include/auxil.h
@@ -57,7 +57,8 @@ void swap_vectors(OSQPVectorf **a,
/**
* Update x_tilde and z_tilde variable (first ADMM step)
- * @param solver Solver
+ * @param solver Solver
+ * @param admm_iter Current ADMM iteration
*/
void update_xz_tilde(OSQPSolver *solver, c_int admm_iter);
diff --git a/include/lin_sys.h b/include/lin_sys.h
index 734ce4f57..d3daec97a 100644
--- a/include/lin_sys.h
+++ b/include/lin_sys.h
@@ -32,20 +32,18 @@ c_int unload_linsys_solver(enum linsys_solver_type linsys_solver);
* @param s Pointer to linear system solver structure
* @param P Cost function matrix
* @param A Constraint matrix
- * @param sigma Algorithm parameter
* @param rho_vec Algorithm parameter
- * @param linsys_solver Linear system solver
+ * @param settings Solver settings
* @param polish 0/1 depending whether we are allocating for
*polishing or not
* @return Exitflag for error (0 if no errors)
*/
-c_int init_linsys_solver(LinSysSolver **s,
- const OSQPMatrix *P,
- const OSQPMatrix *A,
- c_float sigma,
- const OSQPVectorf *rho_vec,
- enum linsys_solver_type linsys_solver,
- c_int polish);
+c_int init_linsys_solver(LinSysSolver **s,
+ const OSQPMatrix *P,
+ const OSQPMatrix *A,
+ const OSQPVectorf *rho_vec,
+ OSQPSettings *settings,
+ c_int polish);
# ifdef __cplusplus
}
diff --git a/lin_sys/direct/pardiso/pardiso_interface.c b/lin_sys/direct/pardiso/pardiso_interface.c
index 21f0c88cb..f04c987a9 100644
--- a/lin_sys/direct/pardiso/pardiso_interface.c
+++ b/lin_sys/direct/pardiso/pardiso_interface.c
@@ -73,12 +73,12 @@ void free_linsys_solver_pardiso(pardiso_solver *s) {
// Initialize factorization structure
-c_int init_linsys_solver_pardiso(pardiso_solver ** sp,
- const OSQPMatrix * P,
- const OSQPMatrix * A,
- c_float sigma,
- const OSQPVectorf * rho_vec,
- c_int polish){
+c_int init_linsys_solver_pardiso(pardiso_solver **sp,
+ const OSQPMatrix *P,
+ const OSQPMatrix *A,
+ const OSQPVectorf *rho_vec,
+ OSQPSettings *settings,
+ c_int polish) {
c_int i; // loop counter
c_int nnzKKT; // Number of nonzeros in KKT
@@ -86,10 +86,10 @@ c_int init_linsys_solver_pardiso(pardiso_solver ** sp,
c_int n_plus_m; // n_plus_m dimension
c_float* rhov; //used for direct access to rho_vec data when polish=false
+ c_float sigma = settings->sigma;
// Allocate private structure to store KKT factorization
- pardiso_solver *s;
- s = c_calloc(1, sizeof(pardiso_solver));
+ pardiso_solver *s = c_calloc(1, sizeof(pardiso_solver));
*sp = s;
// Size of KKT
diff --git a/lin_sys/direct/pardiso/pardiso_interface.h b/lin_sys/direct/pardiso/pardiso_interface.h
index ab7d15566..70d42e037 100644
--- a/lin_sys/direct/pardiso/pardiso_interface.h
+++ b/lin_sys/direct/pardiso/pardiso_interface.h
@@ -81,17 +81,17 @@ struct pardiso {
* @param s Pointer to a private structure
* @param P Cost function matrix (upper triangular form)
* @param A Constraints matrix
- * @param sigma Algorithm parameter. If polish, then sigma = delta.
* @param rho_vec Algorithm parameter. If polish, then rho_vec = OSQP_NULL.
+ * @param settings Solver settings
* @param polish Flag whether we are initializing for polish or not
* @return Exitflag for error (0 if no errors)
*/
-c_int init_linsys_solver_pardiso(pardiso_solver ** sp,
- const OSQPMatrix * P,
- const OSQPMatrix * A,
- c_float sigma,
- const OSQPVectorf * rho_vec,
- c_int polish);
+c_int init_linsys_solver_pardiso(pardiso_solver **sp,
+ const OSQPMatrix *P,
+ const OSQPMatrix *A,
+ const OSQPVectorf *rho_vec,
+ OSQPSettings *settings,
+ c_int polish);
/**
diff --git a/lin_sys/direct/qdldl/qdldl_interface.c b/lin_sys/direct/qdldl/qdldl_interface.c
index 4fe33f58c..e27cbfcd6 100644
--- a/lin_sys/direct/qdldl/qdldl_interface.c
+++ b/lin_sys/direct/qdldl/qdldl_interface.c
@@ -167,12 +167,12 @@ static c_int permute_KKT(csc ** KKT, qdldl_solver * p, c_int Pnz, c_int Anz, c_i
// Initialize LDL Factorization structure
-c_int init_linsys_solver_qdldl(qdldl_solver ** sp,
- const OSQPMatrix* P,
- const OSQPMatrix* A,
- c_float sigma,
- const OSQPVectorf* rho_vec,
- c_int polish){
+c_int init_linsys_solver_qdldl(qdldl_solver **sp,
+ const OSQPMatrix *P,
+ const OSQPMatrix *A,
+ const OSQPVectorf *rho_vec,
+ OSQPSettings *settings,
+ c_int polish) {
// Define Variables
csc * KKT_temp; // Temporary KKT pointer
@@ -180,9 +180,10 @@ c_int init_linsys_solver_qdldl(qdldl_solver ** sp,
c_int n_plus_m; // Define n_plus_m dimension
c_float* rhov; //used for direct access to rho_vec data when polish=false
+ c_float sigma = settings->sigma;
+
// Allocate private structure to store KKT factorization
- qdldl_solver *s;
- s = c_calloc(1, sizeof(qdldl_solver));
+ qdldl_solver *s = c_calloc(1, sizeof(qdldl_solver));
*sp = s;
// Size of KKT
diff --git a/lin_sys/direct/qdldl/qdldl_interface.h b/lin_sys/direct/qdldl/qdldl_interface.h
index 6e99c0654..ef28be683 100644
--- a/lin_sys/direct/qdldl/qdldl_interface.h
+++ b/lin_sys/direct/qdldl/qdldl_interface.h
@@ -85,17 +85,17 @@ struct qdldl {
* @param s Pointer to a private structure
* @param P Cost function matrix (upper triangular form)
* @param A Constraints matrix
- * @param sigma Algorithm parameter. If polish, then sigma = delta.
* @param rho_vec Algorithm parameter. If polish, then rho_vec = OSQP_NULL.
+ * @param settings Solver settings
* @param polish Flag whether we are initializing for polish or not
* @return Exitflag for error (0 if no errors)
*/
-c_int init_linsys_solver_qdldl(qdldl_solver ** sp,
- const OSQPMatrix * P,
- const OSQPMatrix * A,
- c_float sigma,
- const OSQPVectorf* rho_vec,
- c_int polish);
+c_int init_linsys_solver_qdldl(qdldl_solver **sp,
+ const OSQPMatrix *P,
+ const OSQPMatrix *A,
+ const OSQPVectorf *rho_vec,
+ OSQPSettings *settings,
+ c_int polish);
/**
* Solve linear system and store result in b
diff --git a/src/lin_sys.c b/src/lin_sys.c
index 68bd07e18..fff88cfe0 100644
--- a/src/lin_sys.c
+++ b/src/lin_sys.c
@@ -53,23 +53,23 @@ c_int unload_linsys_solver(enum linsys_solver_type linsys_solver) {
// Initialize linear system solver structure
// NB: Only the upper triangular part of P is filled
-c_int init_linsys_solver(LinSysSolver **s,
- const OSQPMatrix *P,
- const OSQPMatrix *A,
- c_float sigma,
- const OSQPVectorf *rho_vec,
- enum linsys_solver_type linsys_solver,
- c_int polish) {
- switch (linsys_solver) {
+c_int init_linsys_solver(LinSysSolver **s,
+ const OSQPMatrix *P,
+ const OSQPMatrix *A,
+ const OSQPVectorf *rho_vec,
+ OSQPSettings *settings,
+ c_int polish) {
+
+ switch (settings->linsys_solver) {
case QDLDL_SOLVER:
- return init_linsys_solver_qdldl((qdldl_solver **)s, P, A, sigma, rho_vec, polish);
+ return init_linsys_solver_qdldl((qdldl_solver **)s, P, A, rho_vec, settings, polish);
# ifdef ENABLE_MKL_PARDISO
case MKL_PARDISO_SOLVER:
- return init_linsys_solver_pardiso((pardiso_solver **)s, P, A, sigma, rho_vec, polish);
+ return init_linsys_solver_pardiso((pardiso_solver **)s, P, A, rho_vec, settings, polish);
# endif /* ifdef ENABLE_MKL_PARDISO */
default: // QDLDL
- return init_linsys_solver_qdldl((qdldl_solver **)s, P, A, sigma, rho_vec, polish);
+ return init_linsys_solver_qdldl((qdldl_solver **)s, P, A, rho_vec, settings, polish);
}
}
diff --git a/src/osqp_api.c b/src/osqp_api.c
index 51f4c4933..231fc4af6 100644
--- a/src/osqp_api.c
+++ b/src/osqp_api.c
@@ -241,10 +241,8 @@ c_int osqp_setup(OSQPSolver** solverp,
if (load_linsys_solver(settings->linsys_solver)) return osqp_error(OSQP_LINSYS_SOLVER_LOAD_ERROR);
// Initialize linear system solver structure
- exitflag = init_linsys_solver(&(work->linsys_solver), work->data->P, work->data->A,
- settings->sigma,
- work->rho_vec,
- settings->linsys_solver, 0);
+ exitflag = init_linsys_solver(&(work->linsys_solver), work->data->P, work->data->A,
+ work->rho_vec, solver->settings, 0);
if (exitflag) {
return osqp_error(exitflag);
diff --git a/src/polish.c b/src/polish.c
index 469836197..6327ee2a0 100644
--- a/src/polish.c
+++ b/src/polish.c
@@ -229,8 +229,7 @@ c_int polish(OSQPSolver *solver) {
// Form and factorize reduced KKT
exitflag = init_linsys_solver(&plsh, work->data->P, work->pol->Ared,
- settings->delta, OSQP_NULL,
- settings->linsys_solver, 1);
+ OSQP_NULL, settings, 1);
if (exitflag) {
// Polishing failed
diff --git a/tests/solve_linsys/test_solve_linsys.h b/tests/solve_linsys/test_solve_linsys.h
index 6d19ad09d..54cbc40a3 100644
--- a/tests/solve_linsys/test_solve_linsys.h
+++ b/tests/solve_linsys/test_solve_linsys.h
@@ -18,8 +18,9 @@ static const char* test_solveKKT() {
solve_linsys_sols_data *data = generate_problem_solve_linsys_sols_data();
// Settings
- settings->rho = data->test_solve_KKT_rho;
- settings->sigma = data->test_solve_KKT_sigma;
+ settings->rho = data->test_solve_KKT_rho;
+ settings->sigma = data->test_solve_KKT_sigma;
+ settings->linsys_solver = LINSYS_SOLVER;
// Set rho_vec
m = data->test_solve_KKT_A->m;
@@ -32,7 +33,7 @@ static const char* test_solveKKT() {
A = OSQPMatrix_new_from_csc(data->test_solve_KKT_A, 0);
// Form and factorize KKT matrix
- exitflag = init_linsys_solver(&s, Pu, A, settings->sigma, rho_vec, LINSYS_SOLVER, 0);
+ exitflag = init_linsys_solver(&s, Pu, A, rho_vec, settings, 0);
// Solve KKT x = b via LDL given factorization
rhs = OSQPVectorf_new(data->test_solve_KKT_rhs, m+n);
@@ -68,8 +69,9 @@ static char* test_solveKKT_pardiso() {
solve_linsys_sols_data *data = generate_problem_solve_linsys_sols_data();
// Settings
- settings->rho = data->test_solve_KKT_rho;
- settings->sigma = data->test_solve_KKT_sigma;
+ settings->rho = data->test_solve_KKT_rho;
+ settings->sigma = data->test_solve_KKT_sigma;
+ settings->linsys_solver = MKL_PARDISO_SOLVER;
// Set rho_vec
m = data->test_solve_KKT_A->m;
@@ -87,7 +89,7 @@ static char* test_solveKKT_pardiso() {
exitflag == 0);
// Form and factorize KKT matrix
- exitflag = init_linsys_solver(&s, Pu, A, settings->sigma, rho_vec, MKL_PARDISO_SOLVER, 0);
+ exitflag = init_linsys_solver(&s, Pu, A, rho_vec, settings, 0);
// Solve KKT x = b via LDL given factorization
rhs = OSQPVectorf_new(data->test_solve_KKT_rhs, m+n);