1 // Ceres Solver - A fast non-linear least squares minimizer
2 // Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
3 // http://code.google.com/p/ceres-solver/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are met:
7 //
8 // * Redistributions of source code must retain the above copyright notice,
9 // this list of conditions and the following disclaimer.
10 // * Redistributions in binary form must reproduce the above copyright notice,
11 // this list of conditions and the following disclaimer in the documentation
12 // and/or other materials provided with the distribution.
13 // * Neither the name of Google Inc. nor the names of its contributors may be
14 // used to endorse or promote products derived from this software without
15 // specific prior written permission.
16 //
17 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18 // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21 // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22 // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23 // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24 // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25 // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26 // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27 // POSSIBILITY OF SUCH DAMAGE.
28 //
29 // Author: sameeragarwal@google.com (Sameer Agarwal)
30 //
31 // A preconditioned conjugate gradients solver
32 // (ConjugateGradientsSolver) for positive semidefinite linear
33 // systems.
34 //
35 // We have also augmented the termination criterion used by this
36 // solver to support not just residual based termination but also
37 // termination based on decrease in the value of the quadratic model
38 // that CG optimizes.
39
40 #include "ceres/conjugate_gradients_solver.h"
41
42 #include <cmath>
43 #include <cstddef>
44 #include "ceres/fpclassify.h"
45 #include "ceres/internal/eigen.h"
46 #include "ceres/linear_operator.h"
47 #include "ceres/stringprintf.h"
48 #include "ceres/types.h"
49 #include "glog/logging.h"
50
51 namespace ceres {
52 namespace internal {
53 namespace {
54
IsZeroOrInfinity(double x)55 bool IsZeroOrInfinity(double x) {
56 return ((x == 0.0) || (IsInfinite(x)));
57 }
58
59 } // namespace
60
ConjugateGradientsSolver(const LinearSolver::Options & options)61 ConjugateGradientsSolver::ConjugateGradientsSolver(
62 const LinearSolver::Options& options)
63 : options_(options) {
64 }
65
Solve(LinearOperator * A,const double * b,const LinearSolver::PerSolveOptions & per_solve_options,double * x)66 LinearSolver::Summary ConjugateGradientsSolver::Solve(
67 LinearOperator* A,
68 const double* b,
69 const LinearSolver::PerSolveOptions& per_solve_options,
70 double* x) {
71 CHECK_NOTNULL(A);
72 CHECK_NOTNULL(x);
73 CHECK_NOTNULL(b);
74 CHECK_EQ(A->num_rows(), A->num_cols());
75
76 LinearSolver::Summary summary;
77 summary.termination_type = LINEAR_SOLVER_NO_CONVERGENCE;
78 summary.message = "Maximum number of iterations reached.";
79 summary.num_iterations = 0;
80
81 const int num_cols = A->num_cols();
82 VectorRef xref(x, num_cols);
83 ConstVectorRef bref(b, num_cols);
84
85 const double norm_b = bref.norm();
86 if (norm_b == 0.0) {
87 xref.setZero();
88 summary.termination_type = LINEAR_SOLVER_SUCCESS;
89 summary.message = "Convergence. |b| = 0.";
90 return summary;
91 }
92
93 Vector r(num_cols);
94 Vector p(num_cols);
95 Vector z(num_cols);
96 Vector tmp(num_cols);
97
98 const double tol_r = per_solve_options.r_tolerance * norm_b;
99
100 tmp.setZero();
101 A->RightMultiply(x, tmp.data());
102 r = bref - tmp;
103 double norm_r = r.norm();
104 if (norm_r <= tol_r) {
105 summary.termination_type = LINEAR_SOLVER_SUCCESS;
106 summary.message =
107 StringPrintf("Convergence. |r| = %e <= %e.", norm_r, tol_r);
108 return summary;
109 }
110
111 double rho = 1.0;
112
113 // Initial value of the quadratic model Q = x'Ax - 2 * b'x.
114 double Q0 = -1.0 * xref.dot(bref + r);
115
116 for (summary.num_iterations = 1;
117 summary.num_iterations < options_.max_num_iterations;
118 ++summary.num_iterations) {
119 // Apply preconditioner
120 if (per_solve_options.preconditioner != NULL) {
121 z.setZero();
122 per_solve_options.preconditioner->RightMultiply(r.data(), z.data());
123 } else {
124 z = r;
125 }
126
127 double last_rho = rho;
128 rho = r.dot(z);
129 if (IsZeroOrInfinity(rho)) {
130 summary.termination_type = LINEAR_SOLVER_FAILURE;
131 summary.message = StringPrintf("Numerical failure. rho = r'z = %e.", rho);
132 break;
133 };
134
135 if (summary.num_iterations == 1) {
136 p = z;
137 } else {
138 double beta = rho / last_rho;
139 if (IsZeroOrInfinity(beta)) {
140 summary.termination_type = LINEAR_SOLVER_FAILURE;
141 summary.message = StringPrintf(
142 "Numerical failure. beta = rho_n / rho_{n-1} = %e.", beta);
143 break;
144 }
145 p = z + beta * p;
146 }
147
148 Vector& q = z;
149 q.setZero();
150 A->RightMultiply(p.data(), q.data());
151 const double pq = p.dot(q);
152 if ((pq <= 0) || IsInfinite(pq)) {
153 summary.termination_type = LINEAR_SOLVER_FAILURE;
154 summary.message = StringPrintf("Numerical failure. p'q = %e.", pq);
155 break;
156 }
157
158 const double alpha = rho / pq;
159 if (IsInfinite(alpha)) {
160 summary.termination_type = LINEAR_SOLVER_FAILURE;
161 summary.message =
162 StringPrintf("Numerical failure. alpha = rho / pq = %e", alpha);
163 break;
164 }
165
166 xref = xref + alpha * p;
167
168 // Ideally we would just use the update r = r - alpha*q to keep
169 // track of the residual vector. However this estimate tends to
170 // drift over time due to round off errors. Thus every
171 // residual_reset_period iterations, we calculate the residual as
172 // r = b - Ax. We do not do this every iteration because this
173 // requires an additional matrix vector multiply which would
174 // double the complexity of the CG algorithm.
175 if (summary.num_iterations % options_.residual_reset_period == 0) {
176 tmp.setZero();
177 A->RightMultiply(x, tmp.data());
178 r = bref - tmp;
179 } else {
180 r = r - alpha * q;
181 }
182
183 // Quadratic model based termination.
184 // Q1 = x'Ax - 2 * b' x.
185 const double Q1 = -1.0 * xref.dot(bref + r);
186
187 // For PSD matrices A, let
188 //
189 // Q(x) = x'Ax - 2b'x
190 //
191 // be the cost of the quadratic function defined by A and b. Then,
192 // the solver terminates at iteration i if
193 //
194 // i * (Q(x_i) - Q(x_i-1)) / Q(x_i) < q_tolerance.
195 //
196 // This termination criterion is more useful when using CG to
197 // solve the Newton step. This particular convergence test comes
198 // from Stephen Nash's work on truncated Newton
199 // methods. References:
200 //
201 // 1. Stephen G. Nash & Ariela Sofer, Assessing A Search
202 // Direction Within A Truncated Newton Method, Operation
203 // Research Letters 9(1990) 219-221.
204 //
205 // 2. Stephen G. Nash, A Survey of Truncated Newton Methods,
206 // Journal of Computational and Applied Mathematics,
207 // 124(1-2), 45-59, 2000.
208 //
209 const double zeta = summary.num_iterations * (Q1 - Q0) / Q1;
210 if (zeta < per_solve_options.q_tolerance) {
211 summary.termination_type = LINEAR_SOLVER_SUCCESS;
212 summary.message =
213 StringPrintf("Convergence: zeta = %e < %e",
214 zeta,
215 per_solve_options.q_tolerance);
216 break;
217 }
218 Q0 = Q1;
219
220 // Residual based termination.
221 norm_r = r. norm();
222 if (norm_r <= tol_r) {
223 summary.termination_type = LINEAR_SOLVER_SUCCESS;
224 summary.message =
225 StringPrintf("Convergence. |r| = %e <= %e.", norm_r, tol_r);
226 break;
227 }
228 }
229
230 return summary;
231 };
232
233 } // namespace internal
234 } // namespace ceres
235