Ejemplo n.º 1
0
        /// <summary>
        /// Deserialization constructor: reads the ensemble (weights, sub-models with their
        /// selected features and metrics, and the output combiner) from <paramref name="ctx"/>.
        /// </summary>
        /// <param name="env">Host environment, forwarded to the base constructor.</param>
        /// <param name="name">Registration name, forwarded to the base constructor.</param>
        /// <param name="ctx">Model load context the ensemble is read from.</param>
        protected EnsembleModelParametersBase(IHostEnvironment env, string name, ModelLoadContext ctx)
            : base(env, name, ctx)
        {
            // *** Binary format ***
            // int: model count
            // int: weight count (0 or model count)
            // Float[]: weights
            // for each model:
            //   int: number of SelectedFeatures (in bits)
            //   byte[]: selected features (as many as needed for number of bits == (numSelectedFeatures + 7) / 8)
            //   int: number of Metric values
            //   for each Metric:
            //     Float: metric value
            //     int: metric name (id of the metric name in the string table)
            //     only when the model was written with version VerOld:
            //       bool: is the metric averaged (read and discarded)
            //
            // Each sub-predictor and the combiner are loaded as sub-models through ctx.LoadModel.

            int count = ctx.Reader.ReadInt32();
            Host.CheckDecode(count > 0);

            int weightCount = ctx.Reader.ReadInt32();
            Host.CheckDecode(weightCount == 0 || weightCount == count);
            Weights = ctx.Reader.ReadFloatArray(weightCount);

            Models = new FeatureSubsetModel<TPredictor>[count];
            var ver = ctx.Header.ModelVerWritten;

            for (int i = 0; i < count; i++)
            {
                ctx.LoadModel<IPredictor, SignatureLoadModel>(Host, out IPredictor p, string.Format(SubPredictorFmt, i));

                // Validate the result of the cast, not the loaded object: a non-null predictor of the
                // wrong type yields a null 'predictor' and must be rejected here rather than stored.
                var predictor = p as TPredictor;
                Host.Check(predictor != null, "Inner predictor type not compatible with the ensemble type.");

                var features = ctx.Reader.ReadBitArray();
                int numMetrics = ctx.Reader.ReadInt32();
                Host.CheckDecode(numMetrics >= 0);

                var metrics = new KeyValuePair<string, double>[numMetrics];
                for (int j = 0; j < numMetrics; j++)
                {
                    var metricValue = ctx.Reader.ReadFloat();
                    var metricName = ctx.LoadStringOrNull();
                    if (ver == VerOld)
                    {
                        // The old format stores an extra flag per metric; consume it to keep
                        // the reader aligned, but the value itself is not used.
                        ctx.Reader.ReadBoolByte();
                    }
                    metrics[j] = new KeyValuePair<string, double>(metricName, metricValue);
                }

                Models[i] = new FeatureSubsetModel<TPredictor>(predictor, features, metrics);
            }

            ctx.LoadModel<IOutputCombiner<TOutput>, SignatureLoadModel>(Host, out Combiner, @"Combiner");
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Builds an array of <c>FeatureSubsetModel&lt;T&gt;</c> from <paramref name="models"/>,
        /// casting each predictor to <typeparamref name="T"/> and carrying over the selected
        /// features and metrics of the corresponding source model.
        /// </summary>
        /// <typeparam name="T">Predictor type each source predictor is cast to.</typeparam>
        /// <param name="models">Source models; the result has one entry per element, in the same order.</param>
        /// <returns>A new array with <c>models.Count</c> entries.</returns>
        private protected static FeatureSubsetModel<T>[] CreateModels<T>(List<FeatureSubsetModel<IPredictorProducing<TOutput>>> models) where T : IPredictor
        {
            int total = models.Count;
            var converted = new FeatureSubsetModel<T>[total];

            for (int idx = 0; idx < total; idx++)
            {
                var source = models[idx];

                // Direct cast (not 'as'): an incompatible predictor surfaces as an exception here.
                converted[idx] = new FeatureSubsetModel<T>((T)source.Predictor, source.SelectedFeatures, source.Metrics);
            }

            return converted;
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Trains one sub-model per trainer for every batch produced by the subset selector,
        /// prunes each batch's models through the sub-model selector, trains the stacking
        /// combiner when the combiner is one, and builds the predictor from all retained models.
        /// </summary>
        /// <param name="ch">Channel used for assertions, checks and info/warning messages.</param>
        /// <param name="data">Training data handed to the subset selector.</param>
        /// <returns>The result of <c>CreatePredictor</c> over the accumulated models.</returns>
        private TPredictor TrainCore(IChannel ch, RoleMappedData data)
        {
            Host.AssertValue(ch);
            ch.AssertValue(data);

            // 1. Subset Selection
            var stacker = Combiner as IStackingTrainer<TOutput>;

            // REVIEW: Implement stacking for Batch mode.
            ch.CheckUserArg(stacker == null || Args.BatchSize <= 0, nameof(Args.BatchSize), "Stacking works only with Non-batch mode");

            // Use the larger of the proportions requested by the sub-model selector and the stacker.
            var holdoutProportion = SubModelSelector.ValidationDatasetProportion;
            if (stacker != null)
            {
                holdoutProportion = Math.Max(holdoutProportion, stacker.ValidationDatasetProportion);
            }

            var metricsRequired = Args.ShowMetrics || Combiner is IWeightedAverager;
            var trainedModels = new List<FeatureSubsetModel<IPredictorProducing<TOutput>>>();

            _subsetSelector.Initialize(data, NumModels, Args.BatchSize, holdoutProportion);

            int batchOrdinal = 1;
            foreach (var batch in _subsetSelector.GetBatches(Host.Rand))
            {
                // 2. Core train
                ch.Info("Training {0} learners for the batch {1}", Trainers.Length, batchOrdinal);
                batchOrdinal++;

                // One slot per trainer; a slot stays null when its training is skipped or throws.
                var perTrainerModels = new FeatureSubsetModel<IPredictorProducing<TOutput>>[Trainers.Length];

                var subsets = _subsetSelector.GetSubsets(batch, Host.Rand);
                var parallelOptions = new ParallelOptions();
                if (Args.TrainParallel)
                {
                    parallelOptions.MaxDegreeOfParallelism = -1;
                }
                else
                {
                    parallelOptions.MaxDegreeOfParallelism = 1;
                }

                Parallel.ForEach(subsets, parallelOptions, (subset, loopState, index) =>
                {
                    int slot = (int)index;
                    long ordinal = index + 1;

                    ch.Info("Beginning training model {0} of {1}", ordinal, Trainers.Length);
                    var timer = Stopwatch.StartNew();
                    try
                    {
                        if (EnsureMinimumFeaturesSelected(subset))
                        {
                            var trained = Trainers[slot].Train(subset.Data);
                            var model = new FeatureSubsetModel<IPredictorProducing<TOutput>>(trained, subset.SelectedFeatures, null);
                            SubModelSelector.CalculateMetrics(model, _subsetSelector, subset, batch, metricsRequired);
                            perTrainerModels[slot] = model;
                        }
                    }
                    catch (Exception ex)
                    {
                        // A failed trainer is reported and skipped rather than aborting the ensemble.
                        ch.Assert(perTrainerModels[slot] == null);
                        ch.Warning(ex.Sensitivity(), "Trainer {0} of {1} was not learned properly due to the exception '{2}' and will not be added to models.",
                                   ordinal, Trainers.Length, ex.Message);
                    }
                    ch.Info("Trainer {0} of {1} finished in {2}", ordinal, Trainers.Length, timer.Elapsed);
                });

                var survivors = perTrainerModels.Where(candidate => candidate != null).ToList();
                if (Args.ShowMetrics)
                {
                    PrintMetrics(ch, survivors);
                }

                survivors = SubModelSelector.Prune(survivors).ToList();

                stacker?.Train(survivors, _subsetSelector.GetTestData(null, batch), Host);

                trainedModels.AddRange(survivors);

                // NOTE(review): the count below is cumulative across batches and also reflects pruning,
                // so the "failed" figure may not equal the number of trainers that threw — confirm intent.
                int totalModels = Utils.Size(trainedModels);
                int trainerCount = Utils.Size(Trainers);
                if (totalModels < trainerCount)
                {
                    ch.Warning("{0} of {1} trainings failed.", trainerCount - totalModels, trainerCount);
                }
                ch.Check(totalModels > 0, "Ensemble training resulted in no valid models.");
            }

            return CreatePredictor(trainedModels);
        }