/// <summary>
/// Deserialization constructor: reconstructs the ensemble from <paramref name="ctx"/> —
/// the per-model weights, each inner predictor with its selected-feature bit mask and
/// metrics, and finally the output combiner.
/// </summary>
/// <param name="env">Host environment used for checks and sub-model loading.</param>
/// <param name="name">Registration name passed through to the base class.</param>
/// <param name="ctx">The model load context positioned at this predictor's payload.</param>
protected EnsembleModelParametersBase(IHostEnvironment env, string name, ModelLoadContext ctx)
    : base(env, name, ctx)
{
    // *** Binary format ***
    // int: model count
    // int: weight count (0 or model count)
    // Float[]: weights
    // for each model:
    //     int: number of SelectedFeatures (in bits)
    //     byte[]: selected features (as many as needed for number of bits == (numSelectedFeatures + 7) / 8)
    //     int: number of Metric values
    //     for each Metric:
    //         Float: metric value
    //         int: metric name (id of the metric name in the string table)
    //         in version 0x0001x0002:
    //             bool: is the metric averaged

    int count = ctx.Reader.ReadInt32();
    Host.CheckDecode(count > 0);
    int weightCount = ctx.Reader.ReadInt32();
    // Weights are either absent entirely or present for every model.
    Host.CheckDecode(weightCount == 0 || weightCount == count);
    Weights = ctx.Reader.ReadFloatArray(weightCount);
    Models = new FeatureSubsetModel<TPredictor>[count];
    var ver = ctx.Header.ModelVerWritten;
    for (int i = 0; i < count; i++)
    {
        ctx.LoadModel<IPredictor, SignatureLoadModel>(Host, out IPredictor p, string.Format(SubPredictorFmt, i));
        var predictor = p as TPredictor;
        // BUGFIX: previously this checked "p != null", which is always true after a
        // successful LoadModel call, so an inner predictor of an incompatible type was
        // never detected and a null entry would be stored. Check the cast result instead.
        Host.Check(predictor != null, "Inner predictor type not compatible with the ensemble type.");
        var features = ctx.Reader.ReadBitArray();
        int numMetrics = ctx.Reader.ReadInt32();
        Host.CheckDecode(numMetrics >= 0);
        var metrics = new KeyValuePair<string, double>[numMetrics];
        for (int j = 0; j < numMetrics; j++)
        {
            var metricValue = ctx.Reader.ReadFloat();
            var metricName = ctx.LoadStringOrNull();
            // Models written with the old version also serialized an "is averaged"
            // flag per metric; read and discard it to stay aligned with the stream.
            if (ver == VerOld)
            {
                ctx.Reader.ReadBoolByte();
            }
            metrics[j] = new KeyValuePair<string, double>(metricName, metricValue);
        }
        Models[i] = new FeatureSubsetModel<TPredictor>(predictor, features, metrics);
    }
    ctx.LoadModel<IOutputCombiner<TOutput>, SignatureLoadModel>(Host, out Combiner, @"Combiner");
}
/// <summary>
/// Re-wraps each model in a <see cref="FeatureSubsetModel{T}"/> of the target predictor
/// type <typeparamref name="T"/>, preserving its selected features and metrics.
/// </summary>
private protected static FeatureSubsetModel<T>[] CreateModels<T>(List<FeatureSubsetModel<IPredictorProducing<TOutput>>> models) where T : IPredictor
{
    var result = new FeatureSubsetModel<T>[models.Count];
    int index = 0;
    foreach (var model in models)
    {
        // The predictor is cast, not converted — callers are expected to pass models
        // whose inner predictor already is a T.
        result[index++] = new FeatureSubsetModel<T>(
            (T)model.Predictor,
            model.SelectedFeatures,
            model.Metrics);
    }
    return result;
}
/// <summary>
/// Trains the ensemble: splits the data into batches and feature subsets, trains one
/// inner learner per subset (optionally in parallel), prunes the results via
/// <c>SubModelSelector</c>, optionally fits a stacking combiner, and finally builds
/// the combined predictor from all surviving models.
/// </summary>
/// <param name="ch">Channel used for progress logging and warnings.</param>
/// <param name="data">The role-mapped training data.</param>
/// <returns>The ensemble predictor produced by <c>CreatePredictor</c>.</returns>
private TPredictor TrainCore(IChannel ch, RoleMappedData data)
{
    Host.AssertValue(ch);
    ch.AssertValue(data);

    // 1. Subset Selection
    var stackingTrainer = Combiner as IStackingTrainer<TOutput>;

    //REVIEW: Implement stacking for Batch mode.
    ch.CheckUserArg(stackingTrainer == null || Args.BatchSize <= 0, nameof(Args.BatchSize), "Stacking works only with Non-batch mode");

    // Stacking needs its own held-out data; reserve the larger of the two requested
    // validation proportions so both consumers are satisfied.
    var validationDataSetProportion = SubModelSelector.ValidationDatasetProportion;
    if (stackingTrainer != null)
    {
        validationDataSetProportion = Math.Max(validationDataSetProportion, stackingTrainer.ValidationDatasetProportion);
    }

    // Metrics are computed only when they will actually be consumed (display or
    // weighted averaging).
    var needMetrics = Args.ShowMetrics || Combiner is IWeightedAverager;
    var models = new List<FeatureSubsetModel<IPredictorProducing<TOutput>>>();

    _subsetSelector.Initialize(data, NumModels, Args.BatchSize, validationDataSetProportion);
    int batchNumber = 1;
    foreach (var batch in _subsetSelector.GetBatches(Host.Rand))
    {
        // 2. Core train
        ch.Info("Training {0} learners for the batch {1}", Trainers.Length, batchNumber++);
        var batchModels = new FeatureSubsetModel<IPredictorProducing<TOutput>>[Trainers.Length];

        // Each parallel iteration writes only its own slot batchModels[index], so no
        // locking is needed on the array. TrainParallel == false forces sequential
        // execution via MaxDegreeOfParallelism = 1.
        Parallel.ForEach(_subsetSelector.GetSubsets(batch, Host.Rand),
            new ParallelOptions() { MaxDegreeOfParallelism = Args.TrainParallel ? -1 : 1 },
            (subset, state, index) =>
            {
                ch.Info("Beginning training model {0} of {1}", index + 1, Trainers.Length);
                Stopwatch sw = Stopwatch.StartNew();
                try
                {
                    if (EnsureMinimumFeaturesSelected(subset))
                    {
                        var model = new FeatureSubsetModel<IPredictorProducing<TOutput>>(
                            Trainers[(int)index].Train(subset.Data),
                            subset.SelectedFeatures,
                            null);
                        SubModelSelector.CalculateMetrics(model, _subsetSelector, subset, batch, needMetrics);
                        batchModels[(int)index] = model;
                    }
                }
                catch (Exception ex)
                {
                    // A failed trainer is logged and skipped rather than aborting the
                    // whole ensemble; its slot stays null and is filtered out below.
                    ch.Assert(batchModels[(int)index] == null);
                    ch.Warning(ex.Sensitivity(), "Trainer {0} of {1} was not learned properly due to the exception '{2}' and will not be added to models.",
                        index + 1, Trainers.Length, ex.Message);
                }
                ch.Info("Trainer {0} of {1} finished in {2}", index + 1, Trainers.Length, sw.Elapsed);
            });

        // Drop failed (null) slots, then let the sub-model selector prune further.
        var modelsList = batchModels.Where(m => m != null).ToList();
        if (Args.ShowMetrics)
        {
            PrintMetrics(ch, modelsList);
        }

        modelsList = SubModelSelector.Prune(modelsList).ToList();

        if (stackingTrainer != null)
        {
            stackingTrainer.Train(modelsList, _subsetSelector.GetTestData(null, batch), Host);
        }

        models.AddRange(modelsList);
        // NOTE(review): modelSize counts models accumulated across ALL batches so far,
        // while Trainers.Length is the per-batch trainer count — for multi-batch runs
        // this failure warning can under-report. Confirm whether that is intended.
        int modelSize = Utils.Size(models);
        if (modelSize < Utils.Size(Trainers))
        {
            ch.Warning("{0} of {1} trainings failed.", Utils.Size(Trainers) - modelSize, Utils.Size(Trainers));
        }
        ch.Check(modelSize > 0, "Ensemble training resulted in no valid models.");
    }
    return (CreatePredictor(models));
}