/// <summary>
/// Verifies that <see cref="RandomGridSweeper"/> walks a 3 x 3 grid ("foo" in {10, 15, 20},
/// "bar" in {100, 1000, 10000}) without repeating a grid point. It checks three proposal rounds:
/// 5 proposals with empty history; then the 4 remaining points once the first 5 are reported as runs;
/// then, with null history and a budget of 10, all 9 points.
/// </summary>
public void TestRandomGridSweeper()
{
    var env = new MLContext(42);
    var args = new RandomGridSweeper.Options()
    {
        SweptParameters = new[]
        {
            ComponentFactoryUtils.CreateFromFunction(
                environ => new LongValueGenerator(new LongParamOptions() { Name = "foo", Min = 10, Max = 20, NumSteps = 3 })),
            ComponentFactoryUtils.CreateFromFunction(
                environ => new LongValueGenerator(new LongParamOptions() { Name = "bar", Min = 100, Max = 10000, LogBase = true, StepSize = 10 }))
        }
    };

    var sweeper = new RandomGridSweeper(env, args);

    // Expected on-grid values per parameter; a value's position in its array is its grid coordinate.
    var fooValues = new long[] { 10, 15, 20 };
    var barValues = new long[] { 100, 1000, 10000 };

    var gridPoint = NewGrid();
    var initialList = sweeper.ProposeSweeps(5, new List<RunResult>());
    Assert.Equal(5, initialList.Length);
    MarkGridPoints(initialList, gridPoint);

    // After the first 5 proposals are reported as runs, only the 4 unvisited points may come back.
    var nextList = sweeper.ProposeSweeps(5, initialList.Select(p => new RunResult(p)));
    Assert.Equal(4, nextList.Length);
    MarkGridPoints(nextList, gridPoint);

    // With no history and a budget (10) larger than the grid (9), every point is proposed exactly once.
    gridPoint = NewGrid();
    var lastList = sweeper.ProposeSweeps(10, null);
    Assert.Equal(9, lastList.Length);
    MarkGridPoints(lastList, gridPoint);
    Assert.True(gridPoint.All(bArray => bArray.All(b => b)));

    // Fresh 3 x 3 "visited" matrix, indexed as [foo coordinate][bar coordinate].
    bool[][] NewGrid() => new bool[3][] { new bool[3], new bool[3], new bool[3] };

    // For each proposed set, checks that it names only "foo"/"bar", that both are present and on-grid,
    // and that its grid point was not already visited; then marks that point as visited.
    void MarkGridPoints(IEnumerable<ParameterSet> parameterSets, bool[][] grid)
    {
        foreach (var parameterSet in parameterSets)
        {
            // Reset per set so a missing parameter cannot silently reuse the previous set's coordinate.
            int fooIndex = -1;
            int barIndex = -1;

            foreach (var parameterValue in parameterSet)
            {
                if (parameterValue.Name == "foo")
                {
                    fooIndex = GridIndex(parameterValue.ValueText, fooValues);
                }
                else if (parameterValue.Name == "bar")
                {
                    barIndex = GridIndex(parameterValue.ValueText, barValues);
                }
                else
                {
                    Assert.True(false, "Wrong parameter");
                }
            }

            Assert.True(fooIndex >= 0 && barIndex >= 0, "Parameter set is missing 'foo' or 'bar'");
            Assert.False(grid[fooIndex][barIndex]);
            grid[fooIndex][barIndex] = true;
        }
    }

    // Parses a proposed value (invariant culture, since ValueText is machine-produced text) and
    // returns its grid coordinate, failing the test if the value is not one of the allowed grid values.
    int GridIndex(string valueText, long[] allowed)
    {
        var val = long.Parse(valueText, System.Globalization.CultureInfo.InvariantCulture);
        int index = Array.IndexOf(allowed, val);
        Assert.True(index >= 0, $"Unexpected grid value {val}");
        return index;
    }
}
/// <summary>
/// Exercises <see cref="SimpleAsyncSweeper"/> in two configurations and checks each one with
/// <c>CheckAsyncSweeperResult</c> after collecting 100 proposals:
/// (1) a sweeper built from <c>SweeperBase.OptionsBase</c>, where every proposal is fed back via
/// <c>Update</c> with a random metric;
/// (2) a sweeper built from <c>RandomGridSweeper.Options</c> that is consumed without ever calling <c>Update</c>.
/// Every proposal task is expected to be already completed when returned.
/// </summary>
public void TestSimpleSweeperAsync()
{
    var random = new Random(42);
    var env = new MLContext(42);
    const int sweeps = 100;

    var sweeperOptions = new SweeperBase.OptionsBase
    {
        SweptParameters = new IComponentFactory<IValueGenerator>[]
        {
            ComponentFactoryUtils.CreateFromFunction(
                environ => new FloatValueGenerator(new FloatParamOptions() { Name = "foo", Min = 1, Max = 5 })),
            ComponentFactoryUtils.CreateFromFunction(
                environ => new LongValueGenerator(new LongParamOptions() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
        }
    };
    var sweeper = new SimpleAsyncSweeper(env, sweeperOptions);

    var paramSets = new List<ParameterSet>();
    CollectProposals(sweeper, paramSets, reportResults: true);
    Assert.Equal(sweeps, paramSets.Count);
    CheckAsyncSweeperResult(paramSets);

    // Test consumption without ever calling Update.
    var gridArgs = new RandomGridSweeper.Options
    {
        SweptParameters = new IComponentFactory<INumericValueGenerator>[]
        {
            ComponentFactoryUtils.CreateFromFunction(
                environ => new FloatValueGenerator(new FloatParamOptions() { Name = "foo", Min = 1, Max = 5 })),
            ComponentFactoryUtils.CreateFromFunction(
                environ => new LongValueGenerator(new LongParamOptions() { Name = "bar", Min = 1, Max = 100, LogBase = true }))
        }
    };
    var gridSweeper = new SimpleAsyncSweeper(env, gridArgs);
    paramSets.Clear();
    CollectProposals(gridSweeper, paramSets, reportResults: false);
    Assert.Equal(sweeps, paramSets.Count);
    CheckAsyncSweeperResult(paramSets);

    // Pulls `sweeps` proposals from `target` into `sink`, asserting each task is already complete.
    // When `reportResults` is true, each proposal is reported back with a metric drawn from `random`.
    void CollectProposals(SimpleAsyncSweeper target, List<ParameterSet> sink, bool reportResults)
    {
        for (int round = 0; round < sweeps; round++)
        {
            var pending = target.Propose();
            Assert.True(pending.IsCompleted);
            var proposal = pending.Result;
            sink.Add(proposal.ParameterSet);

            if (!reportResults)
            {
                continue;
            }

            var outcome = new RunResult(proposal.ParameterSet, random.NextDouble(), true);
            target.Update(proposal.Id, outcome);
        }
    }
}