/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.math.optimization.general;

import org.apache.commons.math.ConvergenceException;
import org.apache.commons.math.FunctionEvaluationException;
import org.apache.commons.math.analysis.UnivariateRealFunction;
import org.apache.commons.math.analysis.solvers.BrentSolver;
import org.apache.commons.math.analysis.solvers.UnivariateRealSolver;
import org.apache.commons.math.exception.util.LocalizedFormats;
import org.apache.commons.math.optimization.GoalType;
import org.apache.commons.math.optimization.OptimizationException;
import org.apache.commons.math.optimization.RealPointValuePair;
import org.apache.commons.math.util.FastMath;

/**
 * Non-linear conjugate gradient optimizer.
 * <p>
 * This class supports both the Fletcher-Reeves and the Polak-Ribière
 * update formulas for the conjugate search directions. It also supports
 * optional preconditioning.
 * </p>
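 * <p>
 * A minimal usage sketch (assuming a user-supplied {@link
 * org.apache.commons.math.analysis.DifferentiableMultivariateRealFunction}
 * instance {@code f} and a two-dimensional start point):
 * </p>
 * <pre>
 *   NonLinearConjugateGradientOptimizer optimizer =
 *       new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.POLAK_RIBIERE);
 *   RealPointValuePair optimum =
 *       optimizer.optimize(f, GoalType.MINIMIZE, new double[] { 1.0, 1.0 });
 * </pre>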
 *
 * @version $Revision: 1070725 $ $Date: 2011-02-15 02:31:12 +0100 (mar. 15 févr. 2011) $
 * @since 2.0
 */
public class NonLinearConjugateGradientOptimizer
    extends AbstractScalarDifferentiableOptimizer {

    /** Update formula for the beta parameter. */
    private final ConjugateGradientFormula updateFormula;

    /** Preconditioner (may be null). */
    private Preconditioner preconditioner;

    /** Solver to use in the line search (may be null). */
    private UnivariateRealSolver solver;

    /** Initial step used to bracket the optimum in line search. */
    private double initialStep;

    /** Simple constructor with default settings.
     * <p>The convergence check is set to a {@link
     * org.apache.commons.math.optimization.SimpleScalarValueChecker}
     * and the maximal number of iterations is set to
     * {@link AbstractScalarDifferentiableOptimizer#DEFAULT_MAX_ITERATIONS}.</p>
     * @param updateFormula formula to use for updating the β parameter,
     * must be one of {@link ConjugateGradientFormula#FLETCHER_REEVES} or {@link
     * ConjugateGradientFormula#POLAK_RIBIERE}
     */
    public NonLinearConjugateGradientOptimizer(final ConjugateGradientFormula updateFormula) {
        this.updateFormula = updateFormula;
        preconditioner = null;
        solver = null;
        initialStep = 1.0;
    }

    /**
     * Set the preconditioner.
     * @param preconditioner preconditioner to use for next optimization,
     * may be null to remove an already registered preconditioner
     */
    public void setPreconditioner(final Preconditioner preconditioner) {
        this.preconditioner = preconditioner;
    }

    /**
     * Set the solver to use during line search.
     * @param lineSearchSolver solver to use during line search, may be null
     * to remove an already registered solver and fall back to the
     * default {@link BrentSolver Brent solver}.
     */
    public void setLineSearchSolver(final UnivariateRealSolver lineSearchSolver) {
        this.solver = lineSearchSolver;
    }

    /**
     * Set the initial step used to bracket the optimum in line search.
     * <p>
     * The initial step is a factor with respect to the search direction,
     * which itself is roughly related to the gradient of the function.
     * </p>
     * @param initialStep initial step used to bracket the optimum in line search;
     * if a non-positive value is used, the initial step is reset to its
     * default value of 1.0
     */
    public void setInitialStep(final double initialStep) {
        if (initialStep <= 0) {
            this.initialStep = 1.0;
        } else {
            this.initialStep = initialStep;
        }
    }

    /** {@inheritDoc} */
    @Override
    protected RealPointValuePair doOptimize()
        throws FunctionEvaluationException, OptimizationException, IllegalArgumentException {
        try {

            // initialization
            if (preconditioner == null) {
                preconditioner = new IdentityPreconditioner();
            }
            if (solver == null) {
                solver = new BrentSolver();
            }
            final int n = point.length;
            double[] r = computeObjectiveGradient(point);
            if (goal == GoalType.MINIMIZE) {
                for (int i = 0; i < n; ++i) {
                    r[i] = -r[i];
                }
            }

            // initial search direction
            double[] steepestDescent = preconditioner.precondition(point, r);
            double[] searchDirection = steepestDescent.clone();

            double delta = 0;
            for (int i = 0; i < n; ++i) {
                delta += r[i] * searchDirection[i];
            }

            RealPointValuePair current = null;
            while (true) {

                final double objective = computeObjectiveValue(point);
                RealPointValuePair previous = current;
                current = new RealPointValuePair(point, objective);
                if (previous != null) {
                    if (checker.converged(getIterations(), previous, current)) {
                        // we have found an optimum
                        return current;
                    }
                }

                incrementIterationsCounter();

                double dTd = 0;
                for (final double di : searchDirection) {
                    dTd += di * di;
                }

                // find the optimal step in the search direction
                final UnivariateRealFunction lsf = new LineSearchFunction(searchDirection);
                final double step = solver.solve(lsf, 0, findUpperBound(lsf, 0, initialStep));

                // take the step along the search direction
                for (int i = 0; i < point.length; ++i) {
                    point[i] += step * searchDirection[i];
                }
                r = computeObjectiveGradient(point);
                if (goal == GoalType.MINIMIZE) {
                    for (int i = 0; i < n; ++i) {
                        r[i] = -r[i];
                    }
                }

                // compute beta
                final double deltaOld = delta;
                final double[] newSteepestDescent = preconditioner.precondition(point, r);
                delta = 0;
                for (int i = 0; i < n; ++i) {
                    delta += r[i] * newSteepestDescent[i];
                }
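                // beta update: with r the (sign-adjusted) gradient and s = P r its
                // preconditioned counterpart (the "steepest descent" vectors above),
                //   Fletcher-Reeves: beta = r_new . s_new / (r_old . s_old)
                //   Polak-Ribiere:   beta = r_new . (s_new - s_old) / (r_old . s_old)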
                final double beta;
                if (updateFormula == ConjugateGradientFormula.FLETCHER_REEVES) {
                    beta = delta / deltaOld;
                } else {
                    double deltaMid = 0;
                    for (int i = 0; i < r.length; ++i) {
                        deltaMid += r[i] * steepestDescent[i];
                    }
                    beta = (delta - deltaMid) / deltaOld;
                }
                steepestDescent = newSteepestDescent;

                // compute conjugate search direction
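                // (restart with the plain steepest descent direction every n
                // iterations, or whenever beta is negative, a common safeguard
                // in non-linear conjugate gradient methods)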
                if ((getIterations() % n == 0) || (beta < 0)) {
                    // break conjugation: reset search direction
                    searchDirection = steepestDescent.clone();
                } else {
                    // compute new conjugate search direction
                    for (int i = 0; i < n; ++i) {
                        searchDirection[i] = steepestDescent[i] + beta * searchDirection[i];
                    }
                }

            }

        } catch (ConvergenceException ce) {
            throw new OptimizationException(ce);
        }
    }

    /**
     * Find the upper bound b ensuring bracketing of a root between a and b.
     * @param f function whose root must be bracketed
     * @param a lower bound of the interval
     * @param h initial step to try
     * @return b such that f(a) and f(b) have opposite signs
     * @exception FunctionEvaluationException if the function cannot be computed
     * @exception OptimizationException if no bracket can be found
     */
    private double findUpperBound(final UnivariateRealFunction f,
                                  final double a, final double h)
        throws FunctionEvaluationException, OptimizationException {
        final double yA = f.value(a);
        double yB = yA;
        for (double step = h; step < Double.MAX_VALUE; step *= FastMath.max(2, yA / yB)) {
            final double b = a + step;
            yB = f.value(b);
            if (yA * yB <= 0) {
                return b;
            }
        }
        throw new OptimizationException(LocalizedFormats.UNABLE_TO_BRACKET_OPTIMUM_IN_LINE_SEARCH);
    }

    /** Default identity preconditioner. */
    private static class IdentityPreconditioner implements Preconditioner {

        /** {@inheritDoc} */
        public double[] precondition(double[] variables, double[] r) {
            return r.clone();
        }

    }

    /** Internal class for line search.
     * <p>
     * The function represented by this class is the dot product of
     * the objective function gradient and the search direction. Its
     * value is zero when the gradient is orthogonal to the search
     * direction, i.e. when the objective function value is a local
     * extremum along the search direction.
     * </p>
     */
    private class LineSearchFunction implements UnivariateRealFunction {

        /** Search direction. */
        private final double[] searchDirection;

        /** Simple constructor.
         * @param searchDirection search direction
         */
        public LineSearchFunction(final double[] searchDirection) {
            this.searchDirection = searchDirection;
        }

        /** {@inheritDoc} */
        public double value(double x) throws FunctionEvaluationException {

            // current point in the search direction
            final double[] shiftedPoint = point.clone();
            for (int i = 0; i < shiftedPoint.length; ++i) {
                shiftedPoint[i] += x * searchDirection[i];
            }

            // gradient of the objective function
            final double[] gradient = computeObjectiveGradient(shiftedPoint);

            // dot product with the search direction
            double dotProduct = 0;
            for (int i = 0; i < gradient.length; ++i) {
                dotProduct += gradient[i] * searchDirection[i];
            }

            return dotProduct;

        }

    }

}