/// <summary>
/// Creates a Nelder-Mead sweeper after validating the simplex coefficients and the swept parameters.
/// </summary>
/// <param name="env">Host environment used for argument validation and for creating components.</param>
/// <param name="args">
/// Sweeper settings. The coefficients must satisfy
/// -1 &lt; DeltaInsideContraction &lt; 0 &lt; DeltaOutsideContraction &lt; DeltaReflection &lt; DeltaExpansion
/// and 0 &lt; GammaShrink &lt; 1. <c>FirstBatchSweeper</c> must be non-null, and at least two numeric
/// parameters must be swept.
/// </param>
public NelderMeadSweeper(IHostEnvironment env, Arguments args)
{
    Contracts.CheckValue(env, nameof(env));
    // Without this check a null args would surface as a NullReferenceException on the next line.
    env.CheckValue(args, nameof(args));

    env.CheckUserArg(-1 < args.DeltaInsideContraction, nameof(args.DeltaInsideContraction), "Must be greater than -1");
    env.CheckUserArg(args.DeltaInsideContraction < 0, nameof(args.DeltaInsideContraction), "Must be less than 0");
    env.CheckUserArg(0 < args.DeltaOutsideContraction, nameof(args.DeltaOutsideContraction), "Must be greater than 0");
    env.CheckUserArg(args.DeltaReflection > args.DeltaOutsideContraction, nameof(args.DeltaReflection), "Must be greater than " + nameof(args.DeltaOutsideContraction));
    env.CheckUserArg(args.DeltaExpansion > args.DeltaReflection, nameof(args.DeltaExpansion), "Must be greater than " + nameof(args.DeltaReflection));
    env.CheckUserArg(0 < args.GammaShrink && args.GammaShrink < 1, nameof(args.GammaShrink), "Must be between 0 and 1");
    env.CheckValue(args.FirstBatchSweeper, nameof(args.FirstBatchSweeper), "First Batch Sweeper Contains Null Value");
    // Guard the foreach below; an empty list is rejected by the _dim check further down.
    env.CheckUserArg(args.SweptParameters != null, nameof(args.SweptParameters), "Nelder-Mead sweeper needs at least two parameters to sweep over.");

    _args = args;

    _sweepParameters = new List<IValueGenerator>();
    foreach (var sweptParameter in args.SweptParameters)
    {
        var parameter = sweptParameter.CreateComponent(env);
        // REVIEW: ideas about how to support discrete values:
        // 1. assign each discrete value a random number (1-n) to make mirroring possible
        // 2. each time we need to mirror a discrete value, sample from the remaining value
        // 2.1. make the sampling non-uniform by learning "weights" for the different discrete values based on
        // the metric values that we get when using them. (For example, if, for a given discrete value, we get a bad result,
        // we lower its weight, but if we get a good result we increase its weight).
        var parameterNumeric = parameter as INumericValueGenerator;
        env.CheckUserArg(parameterNumeric != null, nameof(args.SweptParameters), "Nelder-Mead sweeper can only sweep over numeric parameters");
        _sweepParameters.Add(parameterNumeric);
    }

    // Validate the dimension before doing any further component construction.
    _dim = _sweepParameters.Count;
    env.CheckUserArg(_dim > 1, nameof(args.SweptParameters), "Nelder-Mead sweeper needs at least two parameters to sweep over.");

    // The first-batch sweeper proposes the initial _dim + 1 simplex vertices.
    _initSweeper = args.FirstBatchSweeper.CreateComponent(env, _sweepParameters.ToArray());

    _simplexVertices = new SortedList<IRunResult, Float[]>(new SimplexVertexComparer());
    _stage = OptimizationStage.NeedReflectionPoint;
    _pendingSweeps = new List<KeyValuePair<ParameterSet, Float[]>>();
    _pendingSweepsNotSubmitted = new Queue<KeyValuePair<ParameterSet, Float[]>>();
}
/// <summary>
/// Proposes the next parameter sets to evaluate.
/// </summary>
/// <remarks>
/// Until the simplex holds <c>_dim + 1</c> vertices, proposals come from the first-batch sweeper. After that,
/// a state machine over <see cref="OptimizationStage"/> performs the Nelder-Mead steps: reflection,
/// expansion, outer/inner contraction, and reduction.
/// </remarks>
/// <param name="maxSweeps">Upper bound on the number of parameter sets returned.</param>
/// <param name="previousRuns">Results of the runs completed so far, or <c>null</c> if there are none. Elements must be non-null.</param>
/// <returns>The parameter sets to evaluate next, or <c>null</c> once the search has stopped.</returns>
public ParameterSet[] ProposeSweeps(int maxSweeps, IEnumerable<IRunResult> previousRuns = null)
{
    int numSweeps = Math.Min(maxSweeps, _dim + 1 - _simplexVertices.Count);

    if (previousRuns == null)
    {
        return _initSweeper.ProposeSweeps(numSweeps, previousRuns);
    }

    // Materialize once: the runs are consulted several times below and by the helpers. A lazy sequence
    // would otherwise be re-evaluated (and could yield different elements) on every pass.
    var runs = previousRuns as IRunResult[] ?? previousRuns.ToArray();
    foreach (var run in runs)
    {
        Contracts.Check(run != null);
    }

    // Fill the initial simplex from completed runs; compute best/worst once it is complete.
    foreach (var run in runs)
    {
        if (_simplexVertices.Count == _dim + 1)
        {
            break;
        }
        if (!_simplexVertices.ContainsKey(run))
        {
            _simplexVertices.Add(run, ParameterSetAsFloatArray(run.ParameterSet));
        }
        if (_simplexVertices.Count == _dim + 1)
        {
            ComputeExtremes();
        }
    }

    if (_simplexVertices.Count < _dim + 1)
    {
        numSweeps = Math.Min(maxSweeps, _dim + 1 - _simplexVertices.Count);
        return _initSweeper.ProposeSweeps(numSweeps, runs);
    }

    Float[] nextPoint;
    switch (_stage)
    {
        case OptimizationStage.NeedReflectionPoint:
            nextPoint = GetNewPoint(_centroid, _worst.Value, _args.DeltaReflection);
            if (OutOfBounds(nextPoint) && _args.ProjectInbounds)
            {
                // If the reflection point is out of bounds, get the inner contraction point instead.
                nextPoint = GetNewPoint(_centroid, _worst.Value, _args.DeltaInsideContraction);
                _stage = OptimizationStage.WaitingForInnerContractionResult;
            }
            else
            {
                _stage = OptimizationStage.WaitingForReflectionResult;
            }
            return ProposeCandidateOrReduce(nextPoint, maxSweeps, runs);

        case OptimizationStage.WaitingForReflectionResult:
            Contracts.Assert(_pendingSweeps.Count == 1);
            _lastReflectionResult = FindRunResult(runs)[0];
            if (_secondWorst.Key.CompareTo(_lastReflectionResult.Key) < 0 && _lastReflectionResult.Key.CompareTo(_best.Key) <= 0)
            {
                // The reflection result is better than the second worst, but not better than the best: accept it.
                UpdateSimplex(_lastReflectionResult.Key, _lastReflectionResult.Value);
                goto case OptimizationStage.NeedReflectionPoint;
            }

            if (_lastReflectionResult.Key.CompareTo(_best.Key) > 0)
            {
                // The reflection result is the best so far: try expanding further in the same direction.
                nextPoint = GetNewPoint(_centroid, _worst.Value, _args.DeltaExpansion);
                if (OutOfBounds(nextPoint) && _args.ProjectInbounds)
                {
                    // The expansion point is infeasible. Keep the reflection point, which is the best point seen
                    // so far; it is in bounds because ProjectInbounds was already enforced when proposing it.
                    // (Contracting here would throw that point away.)
                    UpdateSimplex(_lastReflectionResult.Key, _lastReflectionResult.Value);
                    goto case OptimizationStage.NeedReflectionPoint;
                }
                _stage = OptimizationStage.WaitingForExpansionResult;
            }
            else if (_lastReflectionResult.Key.CompareTo(_worst.Key) > 0)
            {
                // Worse than the second worst but better than the worst: try the outer contraction point.
                nextPoint = GetNewPoint(_centroid, _worst.Value, _args.DeltaOutsideContraction);
                _stage = OptimizationStage.WaitingForOuterContractionResult;
            }
            else
            {
                // Not better than the worst: try the inner contraction point.
                nextPoint = GetNewPoint(_centroid, _worst.Value, _args.DeltaInsideContraction);
                _stage = OptimizationStage.WaitingForInnerContractionResult;
            }
            return ProposeCandidateOrReduce(nextPoint, maxSweeps, runs);

        case OptimizationStage.WaitingForExpansionResult:
            Contracts.Assert(_pendingSweeps.Count == 1);
            var expansionResult = FindRunResult(runs)[0].Key;
            if (expansionResult.CompareTo(_lastReflectionResult.Key) > 0)
            {
                // The expansion point is better than the reflection point.
                UpdateSimplex(expansionResult, _pendingSweeps[0].Value);
            }
            else
            {
                // The reflection point is at least as good as the expansion point.
                UpdateSimplex(_lastReflectionResult.Key, _lastReflectionResult.Value);
            }
            goto case OptimizationStage.NeedReflectionPoint;

        case OptimizationStage.WaitingForOuterContractionResult:
            Contracts.Assert(_pendingSweeps.Count == 1);
            var outerContractionResult = FindRunResult(runs)[0].Key;
            if (outerContractionResult.CompareTo(_lastReflectionResult.Key) > 0)
            {
                // The outer contraction point is better than the reflection point.
                UpdateSimplex(outerContractionResult, _pendingSweeps[0].Value);
                goto case OptimizationStage.NeedReflectionPoint;
            }
            return StartReduction(maxSweeps, runs);

        case OptimizationStage.WaitingForInnerContractionResult:
            Contracts.Assert(_pendingSweeps.Count == 1);
            var innerContractionResult = FindRunResult(runs)[0].Key;
            if (innerContractionResult.CompareTo(_worst.Key) > 0)
            {
                // The inner contraction point is better than the worst point.
                UpdateSimplex(innerContractionResult, _pendingSweeps[0].Value);
                goto case OptimizationStage.NeedReflectionPoint;
            }
            return StartReduction(maxSweeps, runs);

        case OptimizationStage.WaitingForReductionResult:
            Contracts.Assert(_pendingSweeps.Count + _pendingSweepsNotSubmitted.Count == _dim);
            if (_pendingSweeps.Count < _dim)
            {
                return SubmitMoreReductionPoints(maxSweeps);
            }
            ReplaceSimplexVertices(runs);
            // If the diameter of the new simplex has become too small, stop sweeping. Record the terminal
            // state so every later call returns null, as the other termination path already does.
            if (SimplexDiameter() < _args.StoppingSimplexDiameter)
            {
                _stage = OptimizationStage.Done;
                return null;
            }
            goto case OptimizationStage.NeedReflectionPoint;

        case OptimizationStage.Done:
        default:
            return null;
    }
}

/// <summary>
/// Makes <paramref name="nextPoint"/> the single pending sweep and proposes it.
/// If an identical parameter set was already evaluated, a reduction is started instead.
/// </summary>
private ParameterSet[] ProposeCandidateOrReduce(Float[] nextPoint, int maxSweeps, IEnumerable<IRunResult> previousRuns)
{
    _pendingSweeps.Clear();
    _pendingSweeps.Add(new KeyValuePair<ParameterSet, Float[]>(FloatArrayAsParameterSet(nextPoint), nextPoint));
    var candidate = _pendingSweeps[0].Key;
    if (previousRuns.Any(runResult => runResult.ParameterSet.Equals(candidate)))
    {
        return StartReduction(maxSweeps, previousRuns);
    }
    return new ParameterSet[] { candidate };
}

/// <summary>
/// Enters the reduction stage and returns the first batch of reduction points.
/// Returns <c>null</c> (and marks the search as done) when no reduction points can be generated.
/// </summary>
private ParameterSet[] StartReduction(int maxSweeps, IEnumerable<IRunResult> previousRuns)
{
    _stage = OptimizationStage.WaitingForReductionResult;
    _pendingSweeps.Clear();
    if (!TryGetReductionPoints(maxSweeps, previousRuns))
    {
        _stage = OptimizationStage.Done;
        return null;
    }
    return _pendingSweeps.Select(kvp => kvp.Key).ToArray();
}