1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements.  See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to You under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License.  You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 package org.apache.commons.math.optimization.general;
19 
20 import org.apache.commons.math.ConvergenceException;
21 import org.apache.commons.math.FunctionEvaluationException;
22 import org.apache.commons.math.analysis.UnivariateRealFunction;
23 import org.apache.commons.math.analysis.solvers.BrentSolver;
24 import org.apache.commons.math.analysis.solvers.UnivariateRealSolver;
25 import org.apache.commons.math.exception.util.LocalizedFormats;
26 import org.apache.commons.math.optimization.GoalType;
27 import org.apache.commons.math.optimization.OptimizationException;
28 import org.apache.commons.math.optimization.RealPointValuePair;
29 import org.apache.commons.math.util.FastMath;
30 
31 /**
32  * Non-linear conjugate gradient optimizer.
33  * <p>
34  * This class supports both the Fletcher-Reeves and the Polak-Ribi&egrave;re
35  * update formulas for the conjugate search directions. It also supports
36  * optional preconditioning.
37  * </p>
38  *
39  * @version $Revision: 1070725 $ $Date: 2011-02-15 02:31:12 +0100 (mar. 15 févr. 2011) $
40  * @since 2.0
41  *
42  */
43 
public class NonLinearConjugateGradientOptimizer
    extends AbstractScalarDifferentiableOptimizer {

    /** Update formula for the beta parameter (Fletcher-Reeves or Polak-Ribiere),
     * fixed at construction time. */
    private final ConjugateGradientFormula updateFormula;

    /** Preconditioner applied to the raw gradient before computing the search
     * direction (may be null; replaced by an identity preconditioner in doOptimize). */
    private Preconditioner preconditioner;

    /** Solver to use in the line search (may be null; replaced by a default
     * Brent solver in doOptimize). */
    private UnivariateRealSolver solver;

    /** Initial step used to bracket the optimum in line search; defaults to 1.0. */
    private double initialStep;
58 
59     /** Simple constructor with default settings.
60      * <p>The convergence check is set to a {@link
61      * org.apache.commons.math.optimization.SimpleVectorialValueChecker}
62      * and the maximal number of iterations is set to
63      * {@link AbstractScalarDifferentiableOptimizer#DEFAULT_MAX_ITERATIONS}.
64      * @param updateFormula formula to use for updating the &beta; parameter,
65      * must be one of {@link ConjugateGradientFormula#FLETCHER_REEVES} or {@link
66      * ConjugateGradientFormula#POLAK_RIBIERE}
67      */
NonLinearConjugateGradientOptimizer(final ConjugateGradientFormula updateFormula)68     public NonLinearConjugateGradientOptimizer(final ConjugateGradientFormula updateFormula) {
69         this.updateFormula = updateFormula;
70         preconditioner     = null;
71         solver             = null;
72         initialStep        = 1.0;
73     }
74 
75     /**
76      * Set the preconditioner.
77      * @param preconditioner preconditioner to use for next optimization,
78      * may be null to remove an already registered preconditioner
79      */
setPreconditioner(final Preconditioner preconditioner)80     public void setPreconditioner(final Preconditioner preconditioner) {
81         this.preconditioner = preconditioner;
82     }
83 
84     /**
85      * Set the solver to use during line search.
86      * @param lineSearchSolver solver to use during line search, may be null
87      * to remove an already registered solver and fall back to the
88      * default {@link BrentSolver Brent solver}.
89      */
setLineSearchSolver(final UnivariateRealSolver lineSearchSolver)90     public void setLineSearchSolver(final UnivariateRealSolver lineSearchSolver) {
91         this.solver = lineSearchSolver;
92     }
93 
94     /**
95      * Set the initial step used to bracket the optimum in line search.
96      * <p>
97      * The initial step is a factor with respect to the search direction,
98      * which itself is roughly related to the gradient of the function
99      * </p>
100      * @param initialStep initial step used to bracket the optimum in line search,
101      * if a non-positive value is used, the initial step is reset to its
102      * default value of 1.0
103      */
setInitialStep(final double initialStep)104     public void setInitialStep(final double initialStep) {
105         if (initialStep <= 0) {
106             this.initialStep = 1.0;
107         } else {
108             this.initialStep = initialStep;
109         }
110     }
111 
    /** {@inheritDoc} */
    @Override
    protected RealPointValuePair doOptimize()
        throws FunctionEvaluationException, OptimizationException, IllegalArgumentException {
        try {

            // initialization: install defaults for any component not set by the user
            if (preconditioner == null) {
                preconditioner = new IdentityPreconditioner();
            }
            if (solver == null) {
                solver = new BrentSolver();
            }
            final int n = point.length;
            // r holds the (possibly sign-flipped) gradient at the current point
            double[] r = computeObjectiveGradient(point);
            if (goal == GoalType.MINIMIZE) {
                // negate the gradient so the search always proceeds in the
                // direction that improves the objective (ascent on -f)
                for (int i = 0; i < n; ++i) {
                    r[i] = -r[i];
                }
            }

            // initial search direction: preconditioned steepest descent
            double[] steepestDescent = preconditioner.precondition(point, r);
            double[] searchDirection = steepestDescent.clone();

            // delta = r . searchDirection, reused across iterations for the beta update
            double delta = 0;
            for (int i = 0; i < n; ++i) {
                delta += r[i] * searchDirection[i];
            }

            RealPointValuePair current = null;
            while (true) {

                // convergence is checked on successive objective values/points
                final double objective = computeObjectiveValue(point);
                RealPointValuePair previous = current;
                current = new RealPointValuePair(point, objective);
                if (previous != null) {
                    if (checker.converged(getIterations(), previous, current)) {
                        // we have found an optimum
                        return current;
                    }
                }

                incrementIterationsCounter();

                // NOTE(review): dTd is accumulated but not read afterwards in
                // this method — appears to be dead code; kept for behavior parity
                double dTd = 0;
                for (final double di : searchDirection) {
                    dTd += di * di;
                }

                // find the optimal step in the search direction: the line search
                // function is the directional derivative, whose root along the
                // direction marks a local extremum of the objective
                final UnivariateRealFunction lsf = new LineSearchFunction(searchDirection);
                final double step = solver.solve(lsf, 0, findUpperBound(lsf, 0, initialStep));

                // validate new point: advance by the optimal step
                for (int i = 0; i < point.length; ++i) {
                    point[i] += step * searchDirection[i];
                }
                // recompute the (sign-adjusted) gradient at the new point
                r = computeObjectiveGradient(point);
                if (goal == GoalType.MINIMIZE) {
                    for (int i = 0; i < n; ++i) {
                        r[i] = -r[i];
                    }
                }

                // compute beta: delta is r . (preconditioned r) at the new point,
                // deltaOld the same quantity at the previous point
                final double deltaOld = delta;
                final double[] newSteepestDescent = preconditioner.precondition(point, r);
                delta = 0;
                for (int i = 0; i < n; ++i) {
                    delta += r[i] * newSteepestDescent[i];
                }

                final double beta;
                if (updateFormula == ConjugateGradientFormula.FLETCHER_REEVES) {
                    // Fletcher-Reeves: beta = delta_new / delta_old
                    beta = delta / deltaOld;
                } else {
                    // Polak-Ribiere: beta = (delta_new - r_new . d_old) / delta_old
                    double deltaMid = 0;
                    for (int i = 0; i < r.length; ++i) {
                        deltaMid += r[i] * steepestDescent[i];
                    }
                    beta = (delta - deltaMid) / deltaOld;
                }
                steepestDescent = newSteepestDescent;

                // compute conjugate search direction
                if ((getIterations() % n == 0) || (beta < 0)) {
                    // break conjugation: restart from steepest descent every n
                    // iterations, or when beta turns negative (lost conjugacy)
                    searchDirection = steepestDescent.clone();
                } else {
                    // compute new conjugate search direction
                    for (int i = 0; i < n; ++i) {
                        searchDirection[i] = steepestDescent[i] + beta * searchDirection[i];
                    }
                }

            }

        } catch (ConvergenceException ce) {
            // wrap line-search convergence failures in the optimizer's exception type
            throw new OptimizationException(ce);
        }
    }
214 
215     /**
216      * Find the upper bound b ensuring bracketing of a root between a and b
217      * @param f function whose root must be bracketed
218      * @param a lower bound of the interval
219      * @param h initial step to try
220      * @return b such that f(a) and f(b) have opposite signs
221      * @exception FunctionEvaluationException if the function cannot be computed
222      * @exception OptimizationException if no bracket can be found
223      */
findUpperBound(final UnivariateRealFunction f, final double a, final double h)224     private double findUpperBound(final UnivariateRealFunction f,
225                                   final double a, final double h)
226         throws FunctionEvaluationException, OptimizationException {
227         final double yA = f.value(a);
228         double yB = yA;
229         for (double step = h; step < Double.MAX_VALUE; step *= FastMath.max(2, yA / yB)) {
230             final double b = a + step;
231             yB = f.value(b);
232             if (yA * yB <= 0) {
233                 return b;
234             }
235         }
236         throw new OptimizationException(LocalizedFormats.UNABLE_TO_BRACKET_OPTIMUM_IN_LINE_SEARCH);
237     }
238 
239     /** Default identity preconditioner. */
240     private static class IdentityPreconditioner implements Preconditioner {
241 
242         /** {@inheritDoc} */
precondition(double[] variables, double[] r)243         public double[] precondition(double[] variables, double[] r) {
244             return r.clone();
245         }
246 
247     }
248 
249     /** Internal class for line search.
250      * <p>
251      * The function represented by this class is the dot product of
252      * the objective function gradient and the search direction. Its
253      * value is zero when the gradient is orthogonal to the search
254      * direction, i.e. when the objective function value is a local
255      * extremum along the search direction.
256      * </p>
257      */
258     private class LineSearchFunction implements UnivariateRealFunction {
259         /** Search direction. */
260         private final double[] searchDirection;
261 
262         /** Simple constructor.
263          * @param searchDirection search direction
264          */
LineSearchFunction(final double[] searchDirection)265         public LineSearchFunction(final double[] searchDirection) {
266             this.searchDirection = searchDirection;
267         }
268 
269         /** {@inheritDoc} */
value(double x)270         public double value(double x) throws FunctionEvaluationException {
271 
272             // current point in the search direction
273             final double[] shiftedPoint = point.clone();
274             for (int i = 0; i < shiftedPoint.length; ++i) {
275                 shiftedPoint[i] += x * searchDirection[i];
276             }
277 
278             // gradient of the objective function
279             final double[] gradient;
280             gradient = computeObjectiveGradient(shiftedPoint);
281 
282             // dot product with the search direction
283             double dotProduct = 0;
284             for (int i = 0; i < gradient.length; ++i) {
285                 dotProduct += gradient[i] * searchDirection[i];
286             }
287 
288             return dotProduct;
289 
290         }
291 
292     }
293 
294 }
295