/// <summary>
/// Converts a single sweepable hyperparameter attribute into the matching <see cref="IValueGenerator"/>.
/// </summary>
/// <param name="attr">The sweepable parameter attribute to convert. Must not be null.</param>
/// <returns>
/// A <see cref="LongValueGenerator"/>, <see cref="FloatValueGenerator"/> or <see cref="DiscreteValueGenerator"/>
/// configured from the fields of <paramref name="attr"/>.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="attr"/> is null.</exception>
/// <exception cref="NotSupportedException">
/// <paramref name="attr"/> is not a long, float or discrete sweepable parameter attribute.
/// </exception>
internal static IValueGenerator ToIValueGenerator(TlcModule.SweepableParamAttribute attr)
{
    // Without this guard a null attribute falls through every type check and surfaces
    // as a NullReferenceException from attr.GetType() in the error path.
    if (attr == null)
    {
        throw new ArgumentNullException(nameof(attr));
    }

    switch (attr)
    {
        case TlcModule.SweepableLongParamAttribute longAttr:
        {
            var args = new LongParamArguments
            {
                Min = longAttr.Min,
                Max = longAttr.Max,
                LogBase = longAttr.IsLogScale,
                Name = longAttr.Name,
                StepSize = longAttr.StepSize
            };
            // NumSteps is optional on the attribute; only override the argument object's value when it is set.
            if (longAttr.NumSteps.HasValue)
            {
                args.NumSteps = longAttr.NumSteps.Value;
            }
            return new LongValueGenerator(args);
        }
        case TlcModule.SweepableFloatParamAttribute floatAttr:
        {
            var args = new FloatParamArguments
            {
                Min = floatAttr.Min,
                Max = floatAttr.Max,
                LogBase = floatAttr.IsLogScale,
                Name = floatAttr.Name,
                StepSize = floatAttr.StepSize
            };
            // NumSteps is optional on the attribute; only override the argument object's value when it is set.
            if (floatAttr.NumSteps.HasValue)
            {
                args.NumSteps = floatAttr.NumSteps.Value;
            }
            return new FloatValueGenerator(args);
        }
        case TlcModule.SweepableDiscreteParamAttribute discreteAttr:
        {
            // Discrete options are passed to the generator as their string representations.
            var args = new DiscreteParamArguments
            {
                Name = discreteAttr.Name,
                Values = discreteAttr.Options.Select(o => o.ToString()).ToArray()
            };
            return new DiscreteValueGenerator(args);
        }
        default:
            // NotSupportedException derives from Exception, so existing catch sites keep working.
            throw new NotSupportedException($"Sweeping only supported for Discrete, Long, and Float parameter types. Unrecognized type {attr.GetType()}");
    }
}
/// <summary>
/// Converts a set of sweepable hyperparameters into <see cref="IComponentFactory"/> instances used
/// by the current smart hyperparameter sweepers.
/// </summary>
/// <param name="hps">The sweepable hyperparameter attributes to convert. Must not be null or contain null entries.</param>
/// <returns>
/// One factory per entry of <paramref name="hps"/>, in the same order. Each factory creates a new
/// value generator every time it is invoked.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="hps"/> is null.</exception>
/// <exception cref="NotSupportedException">
/// An entry is null or is not a discrete, float or long sweepable parameter attribute.
/// </exception>
internal static IComponentFactory<IValueGenerator>[] ConvertToComponentFactories(TlcModule.SweepableParamAttribute[] hps)
{
    if (hps == null)
    {
        throw new ArgumentNullException(nameof(hps));
    }

    var results = new IComponentFactory<IValueGenerator>[hps.Length];
    for (int i = 0; i < hps.Length; i++)
    {
        // Fresh local per iteration, so each lambda below captures its own attribute.
        var hp = hps[i];
        switch (hp)
        {
            case TlcModule.SweepableDiscreteParamAttribute _:
            case TlcModule.SweepableFloatParamAttribute _:
            case TlcModule.SweepableLongParamAttribute _:
                // Reuse the single-attribute conversion so both code paths build identical arguments.
                results[i] = ComponentFactoryUtils.CreateFromFunction(env => ToIValueGenerator(hp));
                break;
            default:
                // Fail here rather than leaving a null factory in the array to be
                // dereferenced later by the sweeper.
                throw new NotSupportedException(
                    $"Sweeping only supported for Discrete, Long, and Float parameter types. Unrecognized type {(hp == null ? "null" : hp.GetType().ToString())} at index {i}");
        }
    }
    return results;
}