Exemplo n.º 1
0
        /// <summary>
        /// Builds a single OVA predictor model out of the predictors carried by
        /// <c>input.ModelArray</c>.
        /// </summary>
        /// <param name="env">Host environment; checked for null.</param>
        /// <param name="input">
        /// Entry-point arguments: the model array (must be non-empty), the training data,
        /// the label/feature/weight column names and the use-probabilities flag.
        /// </param>
        /// <returns>
        /// An output whose <c>PredictorModel</c> wraps the result of <c>OvaModelParameters.Create</c>.
        /// </returns>
        public static PredictorModelOutput CombineOvaModels(IHostEnvironment env, CombineOvaPredictorModelsInput input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("CombineOvaModels");

            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);
            host.CheckNonEmpty(input.ModelArray, nameof(input.ModelArray));

            // Normalization arguably belongs in macro expansion, but at that point only a subgraph
            // (not the learner) is available, which makes it hard to decide whether a normalization
            // node is needed. Elsewhere in the code that responsibility is left to TransformModel,
            // so the first model's transform is applied to the training data here.
            var firstModel = input.ModelArray[0];
            var transformedData = firstModel.TransformModel.Apply(host, input.TrainingData);

            using (var channel = host.Start("CombineOvaModels"))
            {
                var viewSchema = transformedData.Schema;

                // Resolve each role's column from the user-supplied name, falling back to the default name.
                var labelName = TrainUtils.MatchNameOrDefaultOrNull(
                    channel, viewSchema, nameof(input.LabelColumn), input.LabelColumn, DefaultColumnNames.Label);
                var featureName = TrainUtils.MatchNameOrDefaultOrNull(
                    channel, viewSchema, nameof(input.FeatureColumn), input.FeatureColumn, DefaultColumnNames.Features);
                var weightName = TrainUtils.MatchNameOrDefaultOrNull(
                    channel, viewSchema, nameof(input.WeightColumn), input.WeightColumn, DefaultColumnNames.Weight);

                // The fourth argument (between feature and weight) is passed as null.
                var roleMapped = new RoleMappedData(transformedData, labelName, featureName, null, weightName);

                // NOTE(review): 'as' yields null for any predictor that is not an
                // IPredictorProducing<float>; confirm OvaModelParameters.Create rejects null entries.
                var binaryPredictors = input.ModelArray
                    .Select(p => p.Predictor as IPredictorProducing<float>)
                    .ToArray();

                var ovaParameters = OvaModelParameters.Create(host, input.UseProbabilities, binaryPredictors);
                var combinedModel = new PredictorModelImpl(env, roleMapped, input.TrainingData, ovaParameters);

                return new PredictorModelOutput
                {
                    PredictorModel = combinedModel
                };
            }
        }
Exemplo n.º 2
0
        /// <summary>
        /// Creates the OVA model parameters from the trained ensemble by wrapping one
        /// calibrated binary predictor per class.
        /// </summary>
        /// <returns>
        /// OVA parameters using the softmax output formula when the "objective" parameter is
        /// "multiclass"; otherwise OVA parameters created without an explicit output formula.
        /// </returns>
        private protected override OvaModelParameters CreatePredictor()
        {
            Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete.");
            Host.Assert(_numClass > 1, "Must know the number of classes before creating a predictor.");
            Host.Assert(TrainedEnsemble.NumTrees % _numClass == 0, "Number of trees should be a multiple of number of classes.");

            // Joined once up front; the same value is passed to every binary predictor.
            var joinedParameters = LightGbmInterfaceUtils.JoinParameters(Options);

            var perClassPredictors = new IPredictorProducing<float>[_tlcNumClass];
            for (int classIndex = 0; classIndex < _tlcNumClass; classIndex++)
            {
                // Each class gets its own binary predictor paired with a Platt calibrator
                // constructed with the fixed arguments (-0.5, 0).
                perClassPredictors[classIndex] = new FeatureWeightsCalibratedPredictor(
                    Host,
                    CreateBinaryPredictor(classIndex, joinedParameters),
                    new PlattCalibrator(Host, -0.5, 0));
            }

            var objective = (string)GetGbmParameters()["objective"];
            if (objective != "multiclass")
            {
                return OvaModelParameters.Create(Host, perClassPredictors);
            }

            return OvaModelParameters.Create(Host, OvaModelParameters.OutputFormula.Softmax, perClassPredictors);
        }
Exemplo n.º 3
0
        /// <summary>
        /// Fits a LightGBM multiclass pipeline on the iris training file, verifies that the
        /// onFit callback delivers the trained model, and checks that the evaluation metrics
        /// LogLoss and TopKAccuracy are positive.
        /// </summary>
        [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only
        public void MultiClassLightGBM()
        {
            var mlContext = new MLContext(seed: 0);
            var source = new MultiFileSource(GetDataPath(TestDatasets.iris.trainFilename));

            var multiclass = new MulticlassClassificationContext(mlContext);

            // Column 0 is loaded as the text label, columns 1 through 4 as float features.
            var loader = TextLoaderStatic.CreateReader(
                mlContext,
                c => (label: c.LoadText(0), features: c.LoadFloat(1, 4)));

            // Filled in by the onFit callback once the trainer has been fitted.
            OvaModelParameters trainedParameters = null;

            var estimator = loader.MakeNewEstimator()
                .Append(r => (label: r.label.ToKey(), r.features))
                .Append(r => (
                    r.label,
                    preds: multiclass.Trainers.LightGbm(
                        r.label,
                        r.features,
                        onFit: p => trainedParameters = p)));

            var pipeline = loader.Append(estimator);

            // Building the pipeline alone must not invoke the callback.
            Assert.Null(trainedParameters);

            var fitted = pipeline.Fit(source);
            Assert.NotNull(trainedParameters);

            var scored = fitted.Read(source);

            // Dump the name and type of every output column to the console.
            var outputSchema = scored.AsDynamic.Schema;
            for (int col = 0; col < outputSchema.Count; col++)
            {
                var column = outputSchema[col];
                Console.WriteLine($"{column.Name}, {column.Type}");
            }

            var evaluation = multiclass.Evaluate(scored, r => r.label, r => r.preds, 2);

            Assert.True(evaluation.LogLoss > 0);
            Assert.True(evaluation.TopKAccuracy > 0);
        }
Exemplo n.º 4
0
 /// <summary>
 /// Wraps the given OVA model in a multiclass prediction transformer, using this
 /// instance's host, feature column name and label column name.
 /// </summary>
 /// <param name="model">The OVA model parameters to wrap.</param>
 /// <param name="trainSchema">The schema passed through to the transformer.</param>
 /// <returns>A new multiclass prediction transformer over <paramref name="model"/>.</returns>
 protected override MulticlassPredictionTransformer<OvaModelParameters> MakeTransformer(OvaModelParameters model, Schema trainSchema)
 {
     var featureColumnName = FeatureColumn.Name;
     var labelColumnName = LabelColumn.Name;

     return new MulticlassPredictionTransformer<OvaModelParameters>(
         Host, model, trainSchema, featureColumnName, labelColumnName);
 }