/// <summary>
/// Builds the pipeline node for the matrix factorization trainer.
/// </summary>
/// <param name="sweepParams">Hyperparameters to record on the node.</param>
/// <param name="columnInfo">Supplies the label, user-id and item-id column names.</param>
/// <returns>The node produced by <c>TrainerExtensionUtil.BuildPipelineNode</c>.</returns>
public PipelineNode CreatePipelineNode(IEnumerable <SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            // The trainer addresses its matrix by two index columns:
            // the user-id column indexes matrix columns, the item-id column indexes matrix rows.
            var matrixIndexColumns = new Dictionary<string, object>
            {
                { nameof(MatrixFactorizationTrainer.Options.MatrixColumnIndexColumnName), columnInfo.UserIdColumnName },
                { nameof(MatrixFactorizationTrainer.Options.MatrixRowIndexColumnName), columnInfo.ItemIdColumnName },
            };

            var trainerName = TrainerExtensionCatalog.GetTrainerName(this);
            return TrainerExtensionUtil.BuildPipelineNode(trainerName, sweepParams, columnInfo.LabelColumnName,
                                                          additionalProperties: matrixIndexColumns);
        }
Пример #2
0
        /// <summary>
        /// Builds the pipeline node for the FastTree ranking trainer. The group-id column
        /// is forwarded as the trainer's row-group column.
        /// </summary>
        /// <param name="sweepParams">Hyperparameters to record on the node.</param>
        /// <param name="columnInfo">Supplies the label and group-id column names.</param>
        /// <returns>The node produced by <c>TrainerExtensionUtil.BuildPipelineNode</c>.</returns>
        public PipelineNode CreatePipelineNode(IEnumerable <SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            var groupingProperties = new Dictionary<string, object>
            {
                { nameof(FastTreeRankingTrainer.Options.RowGroupColumnName), columnInfo.GroupIdColumnName },
            };

            var trainerName = TrainerExtensionCatalog.GetTrainerName(this);
            return TrainerExtensionUtil.BuildPipelineNode(trainerName, sweepParams, columnInfo.LabelColumnName,
                                                          additionalProperties: groupingProperties);
        }
Пример #3
0
 /// <summary>
 /// Wraps a trainer extension together with its hyperparameter sweep ranges, its trainer
 /// name and an optional initial set of hyperparameter values.
 /// </summary>
 /// <param name="mlContext">Context stored for later use by this trainer.</param>
 /// <param name="trainerExtension">Extension that supplies the sweep ranges and the trainer name.</param>
 /// <param name="columnInfo">Column information stored for later use by this trainer.</param>
 /// <param name="hyperParamSet">Initial hyperparameter values; defaults to null.</param>
 internal SuggestedTrainer(MLContext mlContext, ITrainerExtension trainerExtension,
                           ColumnInformation columnInfo,
                           ParameterSet hyperParamSet = null)
 {
     _mlContext = mlContext;
     _columnInfo = columnInfo;
     _trainerExtension = trainerExtension;

     // Both of these are derived from the extension itself.
     SweepParams = trainerExtension.GetHyperparamSweepRanges();
     TrainerName = TrainerExtensionCatalog.GetTrainerName(trainerExtension);

     // Done last, after SweepParams and TrainerName are populated.
     SetHyperparamValues(hyperParamSet);
 }
        /// <summary>
        /// Creates one <see cref="SuggestedTrainer"/> for every trainer extension that
        /// <c>TrainerExtensionCatalog.GetTrainers</c> returns for the given task, whitelist and columns.
        /// </summary>
        /// <param name="mlContext">Context handed to each suggested trainer.</param>
        /// <param name="task">Task kind used by the catalog to select trainers.</param>
        /// <param name="columnInfo">Passed to the catalog and to each suggested trainer.</param>
        /// <param name="trainerWhitelist">Forwarded unchanged to the catalog.</param>
        /// <returns>
        /// An array of suggested trainers in catalog order. Each is constructed without an
        /// explicit hyperparameter set.
        /// </returns>
        public static IEnumerable <SuggestedTrainer> AllowedTrainers(MLContext mlContext, TaskKind task,
                                                                     ColumnInformation columnInfo, IEnumerable <TrainerName> trainerWhitelist)
        {
            var suggestions = new List<SuggestedTrainer>();

            foreach (var extension in TrainerExtensionCatalog.GetTrainers(task, trainerWhitelist, columnInfo))
            {
                suggestions.Add(new SuggestedTrainer(mlContext, extension, columnInfo));
            }

            return suggestions.ToArray();
        }
        /// <summary>
        /// Builds the pipeline node for this trainer. NumberOfIterations is pinned to
        /// <c>DefaultNumIterations</c> whenever it is not one of the swept parameters.
        /// </summary>
        /// <param name="sweepParams">Swept hyperparameters; may be null.</param>
        /// <param name="columnInfo">Supplies the label column name.</param>
        /// <returns>The node produced by <c>TrainerExtensionUtil.BuildPipelineNode</c>.</returns>
        public PipelineNode CreatePipelineNode(IEnumerable <SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            const string numberOfIterationsName = "NumberOfIterations";
            Dictionary <string, object> additionalProperties = null;

            // Supply the default only when NumberOfIterations is not already being swept.
            // Adding it alongside a swept value would give the node the property twice.
            if (sweepParams == null || !sweepParams.Any(p => p.Name == numberOfIterationsName))
            {
                additionalProperties = new Dictionary <string, object>()
                {
                    { numberOfIterationsName, DefaultNumIterations }
                };
            }

            return(TrainerExtensionUtil.BuildPipelineNode(TrainerExtensionCatalog.GetTrainerName(this), sweepParams,
                                                          columnInfo.LabelColumnName, additionalProperties: additionalProperties));
        }
        /// <summary>
        /// Reconstructs a <see cref="SuggestedPipeline"/> from a <see cref="Pipeline"/> description.
        /// </summary>
        /// <remarks>
        /// Transform nodes seen before the first trainer node go to the pre-trainer list.
        /// Transform nodes seen after it go to the post-trainer list. Nodes of any other type
        /// are skipped. If more than one trainer node is present, the last one is kept.
        /// </remarks>
        /// <param name="context">Context used to instantiate trainers and estimators.</param>
        /// <param name="pipeline">The pipeline description to rebuild.</param>
        /// <returns>The rebuilt pipeline. Its trainer is null if no trainer node was found.</returns>
        /// <exception cref="ArgumentException">
        /// Thrown by <c>Enum.Parse</c> when a node name is not a known trainer or estimator name.
        /// </exception>
        public static SuggestedPipeline FromPipeline(MLContext context, Pipeline pipeline)
        {
            var preTrainer = new List <SuggestedTransform>();
            var postTrainer = new List <SuggestedTransform>();
            SuggestedTrainer trainer = null;
            var seenTrainer = false;

            foreach (var node in pipeline.Nodes)
            {
                switch (node.NodeType)
                {
                    case PipelineNodeType.Trainer:
                    {
                        var trainerName = (TrainerName)Enum.Parse(typeof(TrainerName), node.Name);
                        var extension = TrainerExtensionCatalog.GetTrainerExtension(trainerName);
                        var parameters = TrainerExtensionUtil.BuildParameterSet(trainerName, node.Properties);
                        var columns = TrainerExtensionUtil.BuildColumnInfo(node.Properties);
                        trainer = new SuggestedTrainer(context, extension, columns, parameters);
                        seenTrainer = true;
                        break;
                    }
                    case PipelineNodeType.Transform:
                    {
                        var estimatorName = (EstimatorName)Enum.Parse(typeof(EstimatorName), node.Name);
                        var extension = EstimatorExtensionCatalog.GetExtension(estimatorName);
                        var estimator = extension.CreateInstance(context, node);
                        var transform = new SuggestedTransform(node, estimator);

                        // The flag is never reset, so every transform after the first trainer node
                        // lands in the post-trainer list.
                        (seenTrainer ? postTrainer : preTrainer).Add(transform);
                        break;
                    }
                }
            }

            return new SuggestedPipeline(preTrainer, postTrainer, trainer, context, pipeline.CacheBeforeTrainer);
        }
Пример #7
0
 /// <summary>
 /// Builds the LightGBM pipeline node. The label, example-weight and group-id column
 /// names are forwarded to the builder.
 /// </summary>
 /// <param name="sweepParams">Hyperparameters to record on the node.</param>
 /// <param name="columnInfo">Supplies the label, example-weight and group-id column names.</param>
 /// <returns>The node produced by <c>TrainerExtensionUtil.BuildLightGbmPipelineNode</c>.</returns>
 public PipelineNode CreatePipelineNode(IEnumerable <SweepableParam> sweepParams, ColumnInformation columnInfo)
 {
     var trainerName = TrainerExtensionCatalog.GetTrainerName(this);
     var labelColumn = columnInfo.LabelColumnName;
     var weightColumn = columnInfo.ExampleWeightColumnName;
     var groupIdColumn = columnInfo.GroupIdColumnName;

     return TrainerExtensionUtil.BuildLightGbmPipelineNode(trainerName, sweepParams,
                                                           labelColumn, weightColumn, groupIdColumn);
 }
Пример #8
0
 /// <summary>
 /// Builds the pipeline node for this trainer from the swept parameters and the label column.
 /// </summary>
 /// <param name="sweepParams">Hyperparameters to record on the node.</param>
 /// <param name="columnInfo">Supplies the label column name.</param>
 /// <returns>The node produced by <c>TrainerExtensionUtil.BuildPipelineNode</c>.</returns>
 public PipelineNode CreatePipelineNode(IEnumerable <SweepableParam> sweepParams, ColumnInformation columnInfo)
 {
     var trainerName = TrainerExtensionCatalog.GetTrainerName(this);
     var labelColumn = columnInfo.LabelColumnName;

     // NOTE(review): the fourth argument is passed positionally as null. Confirm which
     // BuildPipelineNode parameter this slot binds to (e.g. a weight column) and that null is intended.
     return TrainerExtensionUtil.BuildPipelineNode(trainerName, sweepParams, labelColumn, null);
 }