Ejemplo n.º 1
0
        /// <summary>
        /// Converts this object's sweep parameters into a trainer <see cref="PipelineNode"/>.
        /// </summary>
        /// <returns>
        /// A trainer node named after <c>TrainerName</c>, reading the "Features" column and writing the
        /// "Score" column. Its properties map the name of each non-null sweep parameter whose
        /// <c>RawValue</c> is non-null to that parameter's <c>ProcessedValue()</c>.
        /// </returns>
        public PipelineNode ToPipelineNode()
        {
            var assignedParams = SweepParams.Where(candidate => candidate != null && candidate.RawValue != null);
            var properties = new Dictionary<string, object>();

            // Indexer assignment (not Add): if two parameters share a name, the later one wins.
            foreach (var sweepParam in assignedParams)
            {
                properties[sweepParam.Name] = sweepParam.ProcessedValue();
            }

            string[] inputColumns  = { "Features" };
            string[] outputColumns = { "Score" };

            return new PipelineNode(TrainerName.ToString(), PipelineNodeType.Trainer,
                                    inputColumns, outputColumns, properties);
        }
        /// <summary>
        /// Builds the sweeper <see cref="ParameterSet"/> that corresponds to a trainer node's properties.
        /// </summary>
        /// <param name="trainerName">Trainer whose properties are being converted.</param>
        /// <param name="props">Node properties; entries keyed by the label or weight column key are ignored.</param>
        /// <returns>
        /// For the LightGBM trainers, the result of <c>BuildLightGbmParameterSet</c>. For all other trainers,
        /// a set holding one <see cref="StringParameterValue"/> per remaining property, built from the value's
        /// <c>ToString()</c>.
        /// </returns>
        public static ParameterSet BuildParameterSet(TrainerName trainerName, IDictionary <string, object> props)
        {
            // The label/weight entries configure columns rather than hyperparameters, so they are dropped.
            IDictionary<string, object> hyperParams = props
                .Where(entry => entry.Key != LabelColumn && entry.Key != WeightColumn)
                .ToDictionary(entry => entry.Key, entry => entry.Value);

            bool isLightGbm = trainerName == TrainerName.LightGbmBinary
                              || trainerName == TrainerName.LightGbmMulti
                              || trainerName == TrainerName.LightGbmRegression
                              || trainerName == TrainerName.LightGbmRanking;
            if (isLightGbm)
            {
                return BuildLightGbmParameterSet(hyperParams);
            }

            var stringValues = hyperParams.Select(entry => new StringParameterValue(entry.Key, entry.Value.ToString()));
            return new ParameterSet(stringValues);
        }
        /// <summary>
        /// Creates a trainer <see cref="PipelineNode"/> from sweepable parameters and column settings.
        /// </summary>
        /// <param name="trainerName">Trainer the node represents; its <c>ToString()</c> becomes the node name.</param>
        /// <param name="sweepParams">Parameters passed to <c>BuildBasePipelineNodeProps</c>.</param>
        /// <param name="labelColumn">Label column passed to <c>BuildBasePipelineNodeProps</c>.</param>
        /// <param name="weightColumn">Optional weight column passed to <c>BuildBasePipelineNodeProps</c>.</param>
        /// <param name="additionalProperties">
        /// Optional extra properties copied onto the node after the base properties. An entry whose key
        /// matches a base property overwrites it.
        /// </param>
        /// <returns>A trainer node using <c>DefaultColumnNames.Features</c> as input and <c>DefaultColumnNames.Score</c> as output.</returns>
        public static PipelineNode BuildPipelineNode(TrainerName trainerName, IEnumerable <SweepableParam> sweepParams,
                                                     string labelColumn, string weightColumn = null, IDictionary <string, object> additionalProperties = null)
        {
            var nodeProps = BuildBasePipelineNodeProps(sweepParams, labelColumn, weightColumn);

            // A missing dictionary is treated as "no extras"; extras are applied last so they take precedence.
            var extras = additionalProperties ?? Enumerable.Empty<KeyValuePair<string, object>>();
            foreach (var extra in extras)
            {
                nodeProps[extra.Key] = extra.Value;
            }

            return new PipelineNode(trainerName.ToString(), PipelineNodeType.Trainer,
                                    DefaultColumnNames.Features, DefaultColumnNames.Score, nodeProps);
        }
 /// <summary>
 /// Creates a LightGBM trainer <see cref="PipelineNode"/> whose properties come from
 /// <c>BuildLightGbmPipelineNodeProps</c>.
 /// </summary>
 /// <param name="trainerName">Trainer the node represents; its <c>ToString()</c> becomes the node name.</param>
 /// <param name="sweepParams">Parameters passed to <c>BuildLightGbmPipelineNodeProps</c>.</param>
 /// <param name="labelColumn">Label column passed to <c>BuildLightGbmPipelineNodeProps</c>.</param>
 /// <param name="weightColumn">Weight column passed to <c>BuildLightGbmPipelineNodeProps</c>.</param>
 /// <param name="groupColumn">Group column passed to <c>BuildLightGbmPipelineNodeProps</c>.</param>
 /// <returns>A trainer node using <c>DefaultColumnNames.Features</c> as input and <c>DefaultColumnNames.Score</c> as output.</returns>
 public static PipelineNode BuildLightGbmPipelineNode(TrainerName trainerName, IEnumerable <SweepableParam> sweepParams,
                                                      string labelColumn, string weightColumn, string groupColumn)
 {
     var nodeName = trainerName.ToString();
     var lightGbmProps = BuildLightGbmPipelineNodeProps(sweepParams, labelColumn, weightColumn, groupColumn);

     return new PipelineNode(nodeName, PipelineNodeType.Trainer,
                             DefaultColumnNames.Features, DefaultColumnNames.Score, lightGbmProps);
 }
        /// <summary>
        /// Creates a new instance of the trainer extension type registered for <paramref name="trainerName"/>.
        /// </summary>
        /// <param name="trainerName">Trainer whose extension is requested.</param>
        /// <returns>
        /// A new extension object created with <see cref="Activator.CreateInstance(Type)"/>. This requires the
        /// registered type to have a public parameterless constructor.
        /// </returns>
        /// <exception cref="KeyNotFoundException">
        /// No extension type is registered for <paramref name="trainerName"/>.
        /// </exception>
        public static ITrainerExtension GetTrainerExtension(TrainerName trainerName)
        {
            Type extensionType;

            // TryGetValue lets the error name the missing trainer. On older runtimes, the bare indexer's
            // KeyNotFoundException message does not include the key. The exception type is unchanged for callers.
            if (!_trainerNamesToExtensionTypes.TryGetValue(trainerName, out extensionType))
            {
                throw new KeyNotFoundException($"No trainer extension is registered for trainer '{trainerName}'.");
            }

            return (ITrainerExtension)Activator.CreateInstance(extensionType);
        }