public void TestNelderMeadSweeperWithDefaultFirstBatchSweeper()
{
    // Exercises a NelderMeadSweeper built without an explicit FirstBatchSweeper.
    // Asserts that the initial proposal (5 requested) contains exactly 3 parameter sets.
    // Then repeatedly feeds back random metrics and checks that every proposed value
    // stays inside its declared range.
    var random = new Random(42);
    var env = new MLContext(42);
    var param = new IComponentFactory<INumericValueGenerator>[]
    {
        ComponentFactoryUtils.CreateFromFunction(
            environ => new FloatValueGenerator(new FloatParamOptions() { Name = "foo", Min = 1, Max = 5 })),
        ComponentFactoryUtils.CreateFromFunction(
            environ => new LongValueGenerator(new LongParamOptions() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
    };

    var args = new NelderMeadSweeper.Options();
    args.SweptParameters = param;
    var sweeper = new NelderMeadSweeper(env, args);
    var sweeps = sweeper.ProposeSweeps(5, new List<RunResult>());
    Assert.Equal(3, sweeps.Length);

    var results = new List<IRunResult>();
    for (int i = 1; i < 10; i++)
    {
        // The final assertion accepts a null proposal, so stop iterating here instead of
        // letting the foreach below throw a NullReferenceException.
        if (sweeps == null)
        {
            break;
        }

        foreach (var parameterSet in sweeps)
        {
            foreach (var parameterValue in parameterSet)
            {
                if (parameterValue.Name == "foo")
                {
                    var val = float.Parse(parameterValue.ValueText, CultureInfo.InvariantCulture);
                    Assert.InRange(val, 1, 5);
                }
                else if (parameterValue.Name == "bar")
                {
                    // Parse with the invariant culture, consistent with "foo" above.
                    var val = long.Parse(parameterValue.ValueText, CultureInfo.InvariantCulture);
                    Assert.InRange(val, 1, 1000);
                }
                else
                {
                    Assert.True(false, "Wrong parameter");
                }
            }

            // Record a random metric for this set; the accumulated results are passed
            // to the next ProposeSweeps call.
            results.Add(new RunResult(parameterSet, random.NextDouble(), true));
        }

        sweeps = sweeper.ProposeSweeps(5, results);
    }

    Assert.True(sweeps == null || sweeps.Length <= 5);
}
public async Task TestNelderMeadSweeperAsync()
{
    // Drives a DeterministicSweeperAsync that wraps a NelderMeadSweeper whose first batch
    // comes from a RandomGridSweeper. Requests up to `sweeps` proposals, reports a
    // precomputed random metric for each one, then validates every parameter set collected.
    var random = new Random(42);
    var env = new MLContext(42);
    const int batchSize = 5;
    const int sweeps = 40;
    var paramSets = new List<ParameterSet>();

    var args = new DeterministicSweeperAsync.Options();
    args.BatchSize = batchSize;
    args.Relaxation = 0;
    args.Sweeper = ComponentFactoryUtils.CreateFromFunction(
        environ =>
        {
            var param = new IComponentFactory<INumericValueGenerator>[]
            {
                ComponentFactoryUtils.CreateFromFunction(
                    innerEnviron => new FloatValueGenerator(new FloatParamOptions() { Name = "foo", Min = 1, Max = 5 })),
                ComponentFactoryUtils.CreateFromFunction(
                    innerEnviron => new LongValueGenerator(new LongParamOptions() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
            };

            var nelderMeadSweeperArgs = new NelderMeadSweeper.Options()
            {
                SweptParameters = param,
                FirstBatchSweeper = ComponentFactoryUtils.CreateFromFunction<IValueGenerator[], ISweeper>(
                    (firstBatchSweeperEnviron, firstBatchSweeperArgs) =>
                        new RandomGridSweeper(environ, new RandomGridSweeper.Options() { SweptParameters = param }))
            };

            return new NelderMeadSweeper(environ, nelderMeadSweeperArgs);
        });

    var sweeper = new DeterministicSweeperAsync(env, args);

    // One metric per proposal slot, drawn up front from the seeded RNG.
    double[] metrics = new double[sweeps];
    for (int i = 0; i < metrics.Length; i++)
    {
        metrics[i] = random.NextDouble();
    }

    for (int i = 0; i < sweeps; i++)
    {
        var paramWithId = await sweeper.Propose();
        if (paramWithId == null)
        {
            // No more proposals. Leave the loop but still run the assertions below
            // on the parameter sets collected so far.
            break;
        }

        var result = new RunResult(paramWithId.ParameterSet, metrics[i], true);
        sweeper.Update(paramWithId.Id, result);

        // Proposals are awaited one at a time, so this list is never touched concurrently.
        paramSets.Add(paramWithId.ParameterSet);
    }

    Assert.True(paramSets.Count <= sweeps);
    CheckAsyncSweeperResult(paramSets);
}