示例#1
0
 /// <summary>
 /// Returns the <c>Weight</c> value exposed by the supplied cursor.
 /// </summary>
 /// <param name="cursor">Cursor whose <c>Weight</c> is read.</param>
 /// <returns>The value of <c>cursor.Weight</c>, unchanged.</returns>
 protected override Float GetInstanceWeight(FloatLabelCursor cursor) => cursor.Weight;
示例#2
0
        /// <summary>
        /// Produces an initial weight vector by running <c>SgdOptimizer</c> until the
        /// termination check defined below reports convergence.
        /// </summary>
        /// <param name="ch">Channel that receives the informational messages.</param>
        /// <param name="cursorFactory">Factory used to create a fresh cursor whenever the current one is missing or exhausted.</param>
        /// <returns>The buffer that <c>SgdOptimizer.Minimize</c> fills in as its result.</returns>
        protected virtual VBuffer<float> InitializeWeightsSgd(IChannel ch, FloatLabelCursor.Factory cursorFactory)
        {
            if (!Quiet)
            {
                ch.Info("Running SGD initialization with tolerance {0}", SgdInitializationTolerance);
            }

            // Every vector handled here has one slot per bias plus one per weight.
            int vectorSize = BiasCount + WeightCount;

            int examplesSeen = 0;
            var previousWeights = VBufferUtils.CreateEmpty<float>(vectorSize);

            // Termination check: counts every invocation, but only evaluates convergence
            // on each 1000th call.
            DTerminate shouldStop =
                (ref VBuffer<float> current) =>
                {
                    examplesSeen++;
                    if (examplesSeen % 1000 != 0)
                    {
                        return false;
                    }

                    // Presumably leaves (previous - current) in previousWeights, so the norm below
                    // measures how far the weights moved since the last check -- confirm against VectorUtils.AddMult.
                    VectorUtils.AddMult(ref current, -1, ref previousWeights);
                    float deltaNorm = VectorUtils.Norm(previousWeights);

                    // Remember the current weights for the next check.
                    current.CopyTo(ref previousWeights);

                    // #if OLD_TRACING // REVIEW: How should this be ported?
                    if (!Quiet)
                    {
                        Console.Write(".");
                        if (examplesSeen % 50000 == 0)
                        {
                            Console.WriteLine("\t{0}\t{1}", examplesSeen, deltaNorm);
                        }
                    }
                    // #endif

                    return deltaNorm < SgdInitializationTolerance;
                };

            VBuffer<float> initialized = default(VBuffer<float>);
            FloatLabelCursor cursor = null;

            try
            {
                float[] workBuffer = null;

                // Gradient callback: one training example per invocation.
                SgdOptimizer.DStochasticGradient stochasticGradient =
                    (ref VBuffer<float> current, ref VBuffer<float> gradient) =>
                    {
                        // Reset the gradient by rebuilding it with a count of zero over the same arrays.
                        gradient = new VBuffer<float>(gradient.Length, 0, gradient.Values, gradient.Indices);
                        EnsureBiases(ref gradient);

                        bool advanced = cursor != null && cursor.MoveNext();
                        if (!advanced)
                        {
                            // No cursor yet, or the current one is exhausted: start over with a new one.
                            cursor?.Dispose();
                            cursor = cursorFactory.Create();
                            if (!cursor.MoveNext())
                            {
                                // The fresh cursor has no rows; leave the zeroed gradient as is.
                                return;
                            }
                        }

                        AccumulateOneGradient(ref cursor.Features, cursor.Label, cursor.Weight, ref current, ref gradient, ref workBuffer);
                    };

                VBuffer<float> startingPoint = DenseOptimizer
                    ? VBufferUtils.CreateDense<float>(vectorSize)
                    : VBufferUtils.CreateEmpty<float>(vectorSize);

                var optimizer = new SgdOptimizer(shouldStop);
                optimizer.Minimize(stochasticGradient, ref startingPoint, ref initialized);

                // #if OLD_TRACING // REVIEW: How should this be ported?
                if (!Quiet)
                {
                    Console.WriteLine();
                }
                // #endif

                ch.Info("SGD initialization done in {0} rounds", examplesSeen);
            }
            finally
            {
                // Release whichever cursor was live when optimization ended or threw.
                cursor?.Dispose();
            }

            return initialized;
        }
        /// <summary>
        /// Trains the stacking meta-predictor. Each row of <paramref name="data"/> is scored by
        /// every model in <paramref name="models"/>; the predictions for a row are turned into a
        /// feature vector by <c>FillFeatureBuffer</c> and paired with the row's label. A trainer
        /// created from <c>BasePredictorType</c> is then trained on that data and its predictor
        /// is stored in <c>Meta</c>.
        /// </summary>
        /// <param name="models">Base models; each predictor is cast to <c>IValueMapper</c>.</param>
        /// <param name="data">Data the base models are scored on.</param>
        /// <param name="env">Host environment used to register the host for this operation.</param>
        public void Train(List<FeatureSubsetModel<IPredictorProducing<TOutput>>> models, RoleMappedData data, IHostEnvironment env)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register(Stacking.LoadName);

            host.CheckValue(models, nameof(models));
            host.CheckValue(data, nameof(data));

            using (var ch = host.Start("Training stacked model"))
            {
                ch.Check(Meta == null, "Train called multiple times");
                ch.Check(BasePredictorType != null);

                // One value mapper per base model.
                int modelCount = models.Count;
                var mappers = new ValueMapper<VBuffer<Single>, TOutput>[modelCount];
                for (int idx = 0; idx < modelCount; idx++)
                {
                    var predictor = models[idx].Predictor;
                    Contracts.Assert(predictor is IValueMapper);
                    mappers[idx] = ((IValueMapper)predictor).GetMapper<VBuffer<Single>, TOutput>();
                }

                // REVIEW: Should implement this better....
                var labelValues = new Single[100];
                var stackedFeatures = new VBuffer<Single>[100];
                int rowCount = 0;

                // REVIEW: Should this include bad values or filter them?
                using (var cursor = new FloatLabelCursor(data, CursOpt.AllFeatures | CursOpt.AllLabels))
                {
                    var outputs = new TOutput[modelCount];
                    var selected = new VBuffer<Single>[modelCount];

                    while (cursor.MoveNext())
                    {
                        // Score the current row with every base model; each iteration touches only its own slot.
                        Parallel.For(0, modelCount, i =>
                        {
                            var model = models[i];
                            if (model.SelectedFeatures == null)
                            {
                                // No feature subset: map the row's features directly.
                                mappers[i](ref cursor.Features, ref outputs[i]);
                                return;
                            }

                            EnsembleUtils.SelectFeatures(ref cursor.Features, model.SelectedFeatures, model.Cardinality, ref selected[i]);
                            mappers[i](ref selected[i], ref outputs[i]);
                        });

                        // Grow the accumulators if needed, then record this row.
                        int needed = rowCount + 1;
                        Utils.EnsureSize(ref labelValues, needed);
                        Utils.EnsureSize(ref stackedFeatures, needed);
                        labelValues[rowCount] = cursor.Label;
                        FillFeatureBuffer(outputs, ref stackedFeatures[rowCount]);
                        rowCount = needed;
                    }
                }

                ch.Info("The number of instances used for stacking trainer is {0}", rowCount);

                var builder = new ArrayDataViewBuilder(host);

                // Trim the accumulators to the number of rows actually collected.
                Array.Resize(ref labelValues, rowCount);
                Array.Resize(ref stackedFeatures, rowCount);
                builder.AddColumn(DefaultColumnNames.Label, NumberType.Float, labelValues);
                builder.AddColumn(DefaultColumnNames.Features, NumberType.Float, stackedFeatures);

                var stackedData = new RoleMappedData(builder.GetDataView(), DefaultColumnNames.Label, DefaultColumnNames.Features);

                var trainer = BasePredictorType.CreateInstance(host);
                if (trainer is ITrainerEx trainerEx && trainerEx.NeedNormalization)
                {
                    ch.Warning("The trainer specified for stacking wants normalization, but we do not currently allow this.");
                }

                trainer.Train(stackedData);
                Meta = trainer.CreatePredictor();
                CheckMeta();

                ch.Done();
            }
        }