/// <summary>
/// Produces the pipeline node for this trainer, recording which input columns supply the
/// matrix column index (user id) and the matrix row index (item id).
/// </summary>
/// <param name="sweepParams">Hyperparameters forwarded to <c>TrainerExtensionUtil.BuildPipelineNode</c>.</param>
/// <param name="columnInfo">Supplies the label, user id and item id column names.</param>
/// <returns>The node built by <c>TrainerExtensionUtil.BuildPipelineNode</c>.</returns>
public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
{
    // Users index matrix columns and items index matrix rows.
    var matrixIndexColumns = new Dictionary<string, object>
    {
        { nameof(MatrixFactorizationTrainer.Options.MatrixColumnIndexColumnName), columnInfo.UserIdColumnName },
        { nameof(MatrixFactorizationTrainer.Options.MatrixRowIndexColumnName), columnInfo.ItemIdColumnName },
    };

    var trainerName = TrainerExtensionCatalog.GetTrainerName(this);
    return TrainerExtensionUtil.BuildPipelineNode(
        trainerName,
        sweepParams,
        columnInfo.LabelColumnName,
        additionalProperties: matrixIndexColumns);
}
/// <summary>
/// Produces the pipeline node for this ranking trainer, recording the row-group column
/// taken from <paramref name="columnInfo"/>.
/// </summary>
/// <param name="sweepParams">Hyperparameters forwarded to <c>TrainerExtensionUtil.BuildPipelineNode</c>.</param>
/// <param name="columnInfo">Supplies the label and group id column names.</param>
/// <returns>The node built by <c>TrainerExtensionUtil.BuildPipelineNode</c>.</returns>
public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
{
    var groupingProperties = new Dictionary<string, object>
    {
        { nameof(FastTreeRankingTrainer.Options.RowGroupColumnName), columnInfo.GroupIdColumnName },
    };

    var trainerName = TrainerExtensionCatalog.GetTrainerName(this);
    return TrainerExtensionUtil.BuildPipelineNode(
        trainerName,
        sweepParams,
        columnInfo.LabelColumnName,
        additionalProperties: groupingProperties);
}
/// <summary>
/// Creates a suggestion for the trainer described by <paramref name="trainerExtension"/>.
/// </summary>
/// <param name="mlContext">Context stored for later use by this instance.</param>
/// <param name="trainerExtension">Extension that supplies the sweep ranges and the trainer name.</param>
/// <param name="columnInfo">Column information stored for later use by this instance.</param>
/// <param name="hyperParamSet">Optional hyperparameter values handed to <c>SetHyperparamValues</c>; defaults to <c>null</c>.</param>
internal SuggestedTrainer(MLContext mlContext, ITrainerExtension trainerExtension, ColumnInformation columnInfo, ParameterSet hyperParamSet = null)
{
    _mlContext = mlContext;
    _trainerExtension = trainerExtension;
    _columnInfo = columnInfo;

    // The parameter and the field refer to the same extension at this point.
    SweepParams = trainerExtension.GetHyperparamSweepRanges();
    TrainerName = TrainerExtensionCatalog.GetTrainerName(trainerExtension);

    // Runs last so that SweepParams is already populated.
    SetHyperparamValues(hyperParamSet);
}
/// <summary>
/// Builds one <see cref="SuggestedTrainer"/> for every trainer extension that
/// <c>TrainerExtensionCatalog.GetTrainers</c> returns for the given task, whitelist and columns.
/// Each suggestion is created without hyperparameter values; its sweep ranges come from its extension.
/// </summary>
/// <param name="mlContext">Context passed to every created suggestion.</param>
/// <param name="task">Task kind used to select candidate trainers.</param>
/// <param name="columnInfo">Column information used for selection and passed to every suggestion.</param>
/// <param name="trainerWhitelist">Trainer names forwarded to the catalog to restrict the candidates.</param>
/// <returns>An array containing one suggestion per selected trainer extension.</returns>
public static IEnumerable<SuggestedTrainer> AllowedTrainers(MLContext mlContext, TaskKind task, ColumnInformation columnInfo, IEnumerable<TrainerName> trainerWhitelist)
{
    var suggestions = new List<SuggestedTrainer>();

    foreach (var extension in TrainerExtensionCatalog.GetTrainers(task, trainerWhitelist, columnInfo))
    {
        suggestions.Add(new SuggestedTrainer(mlContext, extension, columnInfo));
    }

    // Materialised eagerly so that callers always see the same SuggestedTrainer instances.
    return suggestions.ToArray();
}
/// <summary>
/// Produces the pipeline node for this trainer. When no "NumberOfIterations" value is supplied,
/// <c>DefaultNumIterations</c> is recorded as an additional property.
/// </summary>
/// <param name="sweepParams">Hyperparameters forwarded to <c>TrainerExtensionUtil.BuildPipelineNode</c>; may be <c>null</c>.</param>
/// <param name="columnInfo">Supplies the label column name.</param>
/// <returns>The node built by <c>TrainerExtensionUtil.BuildPipelineNode</c>.</returns>
public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
{
    Dictionary<string, object> additionalProperties = null;

    // Fill in the default iteration count only when the caller is not supplying one.
    // The previous predicate (p.Name != "NumberOfIterations") was inverted. It injected the default
    // exactly when NumberOfIterations was the only swept parameter, which could clash with or
    // override that value depending on how BuildPipelineNode merges properties. It also skipped
    // the default whenever other parameters were swept but NumberOfIterations was absent.
    if (sweepParams == null || !sweepParams.Any(p => p.Name == "NumberOfIterations"))
    {
        additionalProperties = new Dictionary<string, object>()
        {
            { "NumberOfIterations", DefaultNumIterations }
        };
    }

    return TrainerExtensionUtil.BuildPipelineNode(
        TrainerExtensionCatalog.GetTrainerName(this),
        sweepParams,
        columnInfo.LabelColumnName,
        additionalProperties: additionalProperties);
}
/// <summary>
/// Rebuilds a <see cref="SuggestedPipeline"/> from the nodes of <paramref name="pipeline"/>.
/// Trainer nodes become a <see cref="SuggestedTrainer"/>; if several are present, the last one wins.
/// Transform nodes are instantiated through their estimator extension and placed before or after
/// the trainer according to where they appear relative to the first trainer node.
/// Nodes of any other type are ignored.
/// </summary>
/// <param name="context">Context used to instantiate estimators and passed to the result.</param>
/// <param name="pipeline">Pipeline whose nodes and CacheBeforeTrainer flag are read.</param>
/// <returns>The reconstructed suggested pipeline; its trainer is <c>null</c> if no trainer node exists.</returns>
/// <exception cref="ArgumentException">
/// Thrown by <see cref="Enum.Parse(Type, string)"/> when a node name does not match a
/// <c>TrainerName</c> or <c>EstimatorName</c> member.
/// </exception>
public static SuggestedPipeline FromPipeline(MLContext context, Pipeline pipeline)
{
    var preTrainerTransforms = new List<SuggestedTransform>();
    var postTrainerTransforms = new List<SuggestedTransform>();
    SuggestedTrainer suggestedTrainer = null;
    bool seenTrainer = false;

    foreach (var node in pipeline.Nodes)
    {
        if (node.NodeType == PipelineNodeType.Trainer)
        {
            var parsedTrainerName = (TrainerName)Enum.Parse(typeof(TrainerName), node.Name);
            var extension = TrainerExtensionCatalog.GetTrainerExtension(parsedTrainerName);
            var parameterSet = TrainerExtensionUtil.BuildParameterSet(parsedTrainerName, node.Properties);
            var nodeColumns = TrainerExtensionUtil.BuildColumnInfo(node.Properties);
            suggestedTrainer = new SuggestedTrainer(context, extension, nodeColumns, parameterSet);
            seenTrainer = true;
            continue;
        }

        if (node.NodeType != PipelineNodeType.Transform)
        {
            continue;
        }

        var parsedEstimatorName = (EstimatorName)Enum.Parse(typeof(EstimatorName), node.Name);
        var estimatorExtension = EstimatorExtensionCatalog.GetExtension(parsedEstimatorName);
        var estimatorInstance = estimatorExtension.CreateInstance(context, node);
        var suggestedTransform = new SuggestedTransform(node, estimatorInstance);

        var destination = seenTrainer ? postTrainerTransforms : preTrainerTransforms;
        destination.Add(suggestedTransform);
    }

    return new SuggestedPipeline(preTrainerTransforms, postTrainerTransforms, suggestedTrainer, context, pipeline.CacheBeforeTrainer);
}
/// <summary>
/// Produces the LightGBM pipeline node for this trainer from the label, example-weight and
/// group id columns in <paramref name="columnInfo"/>.
/// </summary>
/// <param name="sweepParams">Hyperparameters forwarded to <c>TrainerExtensionUtil.BuildLightGbmPipelineNode</c>.</param>
/// <param name="columnInfo">Supplies the label, example-weight and group id column names.</param>
/// <returns>The node built by <c>TrainerExtensionUtil.BuildLightGbmPipelineNode</c>.</returns>
public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
{
    var trainerName = TrainerExtensionCatalog.GetTrainerName(this);
    var labelColumn = columnInfo.LabelColumnName;
    var weightColumn = columnInfo.ExampleWeightColumnName;
    var groupColumn = columnInfo.GroupIdColumnName;

    return TrainerExtensionUtil.BuildLightGbmPipelineNode(trainerName, sweepParams, labelColumn, weightColumn, groupColumn);
}
/// <summary>
/// Produces the pipeline node for this trainer using only the label column. An explicit
/// <c>null</c> is passed as the fourth argument of <c>TrainerExtensionUtil.BuildPipelineNode</c>.
/// </summary>
/// <param name="sweepParams">Hyperparameters forwarded to <c>TrainerExtensionUtil.BuildPipelineNode</c>.</param>
/// <param name="columnInfo">Supplies the label column name.</param>
/// <returns>The node built by <c>TrainerExtensionUtil.BuildPipelineNode</c>.</returns>
public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
{
    var trainerName = TrainerExtensionCatalog.GetTrainerName(this);
    var labelColumn = columnInfo.LabelColumnName;

    // NOTE(review): the fourth positional argument appears to be a weight column; this trainer
    // passes null instead of columnInfo.ExampleWeightColumnName — confirm that is intentional.
    return TrainerExtensionUtil.BuildPipelineNode(trainerName, sweepParams, labelColumn, null);
}