コード例 #1
0
        /// <summary>
        /// Returns the test instances of <paramref name="batch"/>, restricted to the features
        /// selected by <paramref name="subset"/> when it carries a feature selection.
        /// </summary>
        /// <param name="subset">Subset whose <c>SelectedFeatures</c> is applied; may be null.</param>
        /// <param name="batch">Batch supplying the test instances; its <c>TestInstances</c> must not be null.</param>
        /// <returns>
        /// The batch's test instances unchanged when <paramref name="subset"/> is null or has no
        /// selected features; otherwise the result of <c>EnsembleUtils.SelectFeatures</c> on them.
        /// </returns>
        public virtual RoleMappedData GetTestData(Subset subset, Batch batch)
        {
            Host.CheckValueOrNull(subset);
            Host.CheckValue(batch.TestInstances, nameof(batch), "Batch does not have test data");

            // A null subset and a subset without a selection are treated the same way.
            var selection = subset?.SelectedFeatures;
            if (selection != null)
            {
                return EnsembleUtils.SelectFeatures(Host, batch.TestInstances, selection);
            }

            return batch.TestInstances;
        }
コード例 #2
0
        /// <summary>
        /// Builds a random feature subset of <paramref name="data"/>: each slot of the feature
        /// vector is kept independently when a draw from <paramref name="rand"/> is below
        /// <c>_args.FeaturesSelectionProportion</c>.
        /// </summary>
        /// <param name="data">Data whose feature column is sampled; must not be null.</param>
        /// <param name="rand">Source of the random draws; must not be null.</param>
        /// <returns>
        /// A <see cref="Subset"/> built from the feature-selected data and the selection mask.
        /// </returns>
        public Subset SelectFeatures(RoleMappedData data, Random rand)
        {
            _host.CheckValue(data, nameof(data));
            // Fail fast with an argument error rather than a NullReferenceException
            // raised from inside the sampling loop.
            _host.CheckValue(rand, nameof(rand));
            data.CheckFeatureFloatVector();

            var type = data.Schema.Feature.Value.Type;
            int len = type.GetVectorSize();
            var features = new BitArray(len);

            // The proportion is not modified while sampling, so read it once.
            var proportion = _args.FeaturesSelectionProportion;
            for (int j = 0; j < len; j++)
            {
                features[j] = rand.NextDouble() < proportion;
            }

            var dataNew = EnsembleUtils.SelectFeatures(_host, data, features);
            return new Subset(dataNew, features);
        }
コード例 #3
0
        /// <summary>
        /// Trains the stacking meta-predictor (<c>Meta</c>) on the outputs of the base models.
        /// For every row of <paramref name="data"/>, each base model's value mapper produces a
        /// prediction (on the model's selected features when it has a selection, otherwise on the
        /// full feature vector). <c>FillFeatureBuffer</c> turns the row's predictions into a
        /// feature vector that is stored together with the row's label. The collected rows are
        /// exposed as an in-memory data view and passed to a trainer created from
        /// <c>BasePredictorType</c>.
        /// </summary>
        /// <param name="models">Base models whose predictions feed the meta-predictor; must not be null.</param>
        /// <param name="data">Labeled data the base models are run over; must not be null.</param>
        /// <param name="env">Host environment used to register the stacking host; must not be null.</param>
        public void Train(List <FeatureSubsetModel <IPredictorProducing <TOutput> > > models, RoleMappedData data, IHostEnvironment env)
        {
            // Starting size of the label/feature buffers; they grow on demand below.
            const int initialCapacity = 100;

            Contracts.CheckValue(env, nameof(env));
            var host = env.Register(Stacking.LoadName);

            host.CheckValue(models, nameof(models));
            host.CheckValue(data, nameof(data));

            using (var ch = host.Start("Training stacked model"))
            {
                ch.Check(Meta == null, "Train called multiple times");
                ch.Check(BasePredictorType != null);

                // Obtain one value mapper per base model.
                var mappers = new ValueMapper<VBuffer<Single>, TOutput>[models.Count];
                for (int k = 0; k < mappers.Length; k++)
                {
                    var predictor = models[k].Predictor;
                    Contracts.Assert(predictor is IValueMapper);
                    mappers[k] = ((IValueMapper)predictor).GetMapper<VBuffer<Single>, TOutput>();
                }

                // REVIEW: Should implement this better....
                var labelBuffer = new Single[initialCapacity];
                var featureBuffer = new VBuffer<Single>[initialCapacity];
                int rowCount = 0;
                // REVIEW: Should this include bad values or filter them?
                using (var cursor = new FloatLabelCursor(data, CursOpt.AllFeatures | CursOpt.AllLabels))
                {
                    // Per-model output slots and scratch buffers, reused for every row.
                    var outputs = new TOutput[mappers.Length];
                    var scratch = new VBuffer<Single>[mappers.Length];
                    while (cursor.MoveNext())
                    {
                        // Each iteration writes only to its own index of outputs/scratch.
                        Parallel.For(0, mappers.Length, idx =>
                        {
                            var current = models[idx];
                            if (current.SelectedFeatures == null)
                            {
                                mappers[idx](ref cursor.Features, ref outputs[idx]);
                                return;
                            }

                            EnsembleUtils.SelectFeatures(ref cursor.Features, current.SelectedFeatures, current.Cardinality, ref scratch[idx]);
                            mappers[idx](ref scratch[idx], ref outputs[idx]);
                        });

                        int needed = rowCount + 1;
                        Utils.EnsureSize(ref labelBuffer, needed);
                        Utils.EnsureSize(ref featureBuffer, needed);
                        labelBuffer[rowCount] = cursor.Label;
                        FillFeatureBuffer(outputs, ref featureBuffer[rowCount]);
                        rowCount = needed;
                    }
                }

                ch.Info("The number of instances used for stacking trainer is {0}", rowCount);

                var builder = new ArrayDataViewBuilder(host);
                // Trim the buffers to the number of rows actually collected.
                Array.Resize(ref labelBuffer, rowCount);
                Array.Resize(ref featureBuffer, rowCount);
                builder.AddColumn(DefaultColumnNames.Label, NumberType.Float, labelBuffer);
                builder.AddColumn(DefaultColumnNames.Features, NumberType.Float, featureBuffer);

                var stackedData = new RoleMappedData(builder.GetDataView(), DefaultColumnNames.Label, DefaultColumnNames.Features);

                var metaTrainer = BasePredictorType.CreateComponent(host);
                if (metaTrainer.Info.NeedNormalization)
                {
                    ch.Warning("The trainer specified for stacking wants normalization, but we do not currently allow this.");
                }

                Meta = metaTrainer.Train(stackedData);
                CheckMeta();
            }
        }