// Overrides the default termination criterion (MeanRelativeImprovementCriterion) with MeanImprovementCriterion.
// NOTE(review): the method's braces and its return statement are not visible in this excerpt — confirm against the full file.
        private protected override Optimizer InitializeOptimizer(IChannel ch, FloatLabelCursor.Factory cursorFactory,
                                                                 out VBuffer <float> init, out ITerminationCriterion terminationCriterion)
            // Delegate to the base implementation for the optimizer instance and both out values.
            var opt = base.InitializeOptimizer(ch, cursorFactory, out init, out terminationCriterion);

            // MeanImprovementCriterion:
            //   Terminates when the geometrically-weighted average improvement falls below the tolerance
            // This assignment discards the criterion the base call produced.
            terminationCriterion = new MeanImprovementCriterion(OptTol, 0.25f, MaxIterations);

        // Chooses an upper bound on training threads: half the processor count, clamped to the range [1, 8],
        // then lowered to Host.ConcurrencyFactor when that factor is positive and smaller than the bound.
        // NOTE(review): cursorFactory is not read in the lines visible here and no return statement is visible — confirm.
        protected virtual int ComputeNumThreads(FloatLabelCursor.Factory cursorFactory)
            int maxThreads = Math.Min(8, Math.Max(1, Environment.ProcessorCount / 2));

            // A non-positive ConcurrencyFactor leaves the processor-based bound untouched.
            if (0 < Host.ConcurrencyFactor && Host.ConcurrencyFactor < maxThreads)
                maxThreads = Host.ConcurrencyFactor;

        // Builds the cursor factory over the given data. Labels and features are always requested;
        // the group column is requested only for ranking, and weights only when the schema has a weight column.
        private FloatLabelCursor.Factory CreateCursorFactory(RoleMappedData data)
            var loadFlags = CursOpt.AllLabels | CursOpt.Features;
            // Ranking tasks additionally need the group column.
            if (PredictionKind == PredictionKind.Ranking)
                loadFlags |= CursOpt.Group;

            // Only ask for weights when the data actually declares a weight column.
            if (data.Schema.Weight.HasValue)
                loadFlags |= CursOpt.AllWeights;

            var factory = new FloatLabelCursor.Factory(data, loadFlags);
            return factory;
        // Overrides the default termination criterion (MeanRelativeImprovementCriterion) with MeanImprovementCriterion.
        // NOTE(review): the method's braces and its return statement are not visible in this excerpt — confirm against the full file.
        protected override Optimizer InitializeOptimizer(IChannel ch, FloatLabelCursor.Factory cursorFactory,
                                                         out VBuffer <Float> init, out ITerminationCriterion terminationCriterion)
            // Delegate to the base implementation for the optimizer instance and both out values.
            var opt = base.InitializeOptimizer(ch, cursorFactory, out init, out terminationCriterion);

            // MeanImprovementCriterion:
            //   Terminates when the geometrically-weighted average improvement falls below the tolerance
            // Disabled alternative that wrapped the criterion in a GradientCheckingMonitor:
            //terminationCriterion = new GradientCheckingMonitor(new MeanImprovementCriterion(CmdArgs.optTol, 0.25, MaxIterations),2);
            // This assignment discards the criterion the base call produced.
            terminationCriterion = new MeanImprovementCriterion(OptTol, (Float)0.25, MaxIterations);

        // Builds the cursor factory over the given data. Labels, weights and features are always requested
        // (weights unconditionally, without consulting the schema); the group column is added for ranking.
        // NOTE(review): no return statement is visible in this excerpt — confirm against the full file.
        private FloatLabelCursor.Factory CreateCursorFactory(RoleMappedData data)
            var loadFlags = CursOpt.AllLabels | CursOpt.AllWeights | CursOpt.Features;

            // Ranking tasks additionally need the group column.
            if (_predictionKind == PredictionKind.Ranking)
                loadFlags |= CursOpt.Group;

            var factory = new FloatLabelCursor.Factory(data, loadFlags);

 /// <summary>
 /// Estimates the density of the feature data from at most <paramref name="numRows"/> rows taken
 /// from the start of the cursor. Density is the summed length of <c>Features.GetValues()</c>
 /// divided by the summed <c>Features.Length</c> over the inspected rows.
 /// </summary>
 /// <param name="factory">Factory used to create the cursor over the data.</param>
 /// <param name="numRows">Maximum number of rows to inspect; defaults to 1000.</param>
 /// <returns>
 /// The ratio described above. When no feature slots were seen (no rows, or a non-positive
 /// <paramref name="numRows"/>) the result is NaN, because 0 / 0 in double arithmetic is NaN.
 /// </returns>
 private static double DetectDensity(FloatLabelCursor.Factory factory, int numRows = 1000)
 {
     int nonZeroCount = 0;
     int totalCount = 0;
     using (var cursor = factory.Create())
     {
         // Test the remaining budget first so the cursor is not advanced once the budget is spent.
         while (numRows > 0 && cursor.MoveNext())
         {
             nonZeroCount += cursor.Features.GetValues().Length;
             totalCount += cursor.Features.Length;
             // Spend one unit of the row budget; without this the cap would never be reached.
             --numRows;
         }
     }
     return (double)nonZeroCount / totalCount;
 }
        // REVIEW: No extra benefits from using more threads in training.
        // Chooses the thread bound for this trainer. When Host.ConcurrencyFactor is below 1, the bound is
        // half the processor count clamped to the range [1, 2].
        // NOTE(review): the second assignment presumably belongs to an else branch that is not visible in this
        // excerpt; cursorFactory is not read here and no return statement is visible — confirm against the full file.
        protected override int ComputeNumThreads(FloatLabelCursor.Factory cursorFactory)
            int maxThreads;

            if (Host.ConcurrencyFactor < 1)
                maxThreads = Math.Min(2, Math.Max(1, Environment.ProcessorCount / 2));
                maxThreads = Host.ConcurrencyFactor;

            // Sets up the manager's state: two ArrayManager instances (int indices, float values), an empty
            // list of instance properties, and a freshly created cursor that is advanced once immediately.
            public InputDataManager(SymSgdClassificationTrainer trainer, FloatLabelCursor.Factory cursorFactory, IChannel ch)
                _instIndices        = new ArrayManager <int>(trainer, ch);
                _instValues         = new ArrayManager <float>(trainer, ch);
                _instanceProperties = new List <InstanceProperties>();

                _cursorFactory  = cursorFactory;
                _ch             = ch;
                _cursor         = cursorFactory.Create();
                // Remember whether the first MoveNext succeeded, i.e. whether the cursor is positioned on a row.
                _cursorMoveNext = _cursor.MoveNext();
                // Starts out true. NOTE(review): presumably cleared elsewhere when loading stops early — confirm.
                _isFullyLoaded  = true;

                _instanceIndex = 0;

                _trainer = trainer;
        // Creates the optimizer, the default termination criterion and the initial weight vector.
        // The optimizer is an L1Optimizer when L1Weight > 0 and a plain Optimizer otherwise.
        // Initial weights come from, in priority order: an existing predictor (_srcPredictor), random values
        // scaled by InitWtsDiameter, or InitializeWeightsSgd when SgdInitializationTolerance > 0.
        // NOTE(review): the last assignment (an empty buffer) presumably sits in a final else branch; braces,
        // that else and the return statement are not visible in this excerpt — confirm against the full file.
        protected virtual Optimizer InitializeOptimizer(IChannel ch, FloatLabelCursor.Factory cursorFactory,
                                                        out VBuffer <float> init, out ITerminationCriterion terminationCriterion)
            // MeanRelativeImprovementCriterion:
            //   Stops optimization when the average objective improvement over the last
            //   n iterations, normalized by the function value, is small enough.
            terminationCriterion = new MeanRelativeImprovementCriterion(OptTol, 5, MaxIterations);

            // The L1 weight handed to the optimizer is divided by the number of good rows.
            Optimizer opt = (L1Weight > 0)
                ? new L1Optimizer(Host, BiasCount, L1Weight / NumGoodRows, MemorySize, DenseOptimizer, null, EnforceNonNegativity)
                : new Optimizer(Host, MemorySize, DenseOptimizer, null, EnforceNonNegativity);

            opt.Quiet = Quiet;

            if (_srcPredictor != null)
                init = InitializeWeightsFromPredictor(_srcPredictor);
            else if (InitWtsDiameter > 0)
                // Each weight is InitWtsDiameter * (r - 0.5f), where r is the next single from Host.Rand.
                float[] initWeights = new float[BiasCount + WeightCount];
                for (int j = 0; j < initWeights.Length; j++)
                    initWeights[j] = InitWtsDiameter * (Host.Rand.NextSingle() - 0.5f);
                init = new VBuffer <float>(initWeights.Length, initWeights);
            else if (SgdInitializationTolerance > 0)
                init = InitializeWeightsSgd(ch, cursorFactory);
                init = VBufferUtils.CreateEmpty <float>(BiasCount + WeightCount);

        // Validates the training set (feature and label columns present, label of type Single, features a
        // known-size vector of Single), builds a cursor factory and hands off to TrainCore with the feature size.
        private protected override OlsModelParameters TrainModelCore(TrainContext context)
            using (var ch = Host.Start("Training"))
                ch.CheckValue(context, nameof(context));
                var examples = context.TrainingSet;
                ch.CheckParam(examples.Schema.Feature.HasValue, nameof(examples), "Need a feature column");
                ch.CheckParam(examples.Schema.Label.HasValue, nameof(examples), "Need a labelColumn column");

                // The label column type must be Single; any other type is rejected below.
                var typeLab = examples.Schema.Label.Value.Type;
                if (typeLab != NumberDataViewType.Single)
                    throw ch.Except("Incompatible labelColumn column type {0}, must be {1}", typeLab, NumberDataViewType.Single);

                // The feature type must be a known-size vector of Single.
                // The 'as' cast yields null for non-vector types, which the first check rejects.
                var typeFeat = examples.Schema.Feature.Value.Type as VectorDataViewType;
                if (typeFeat == null || !typeFeat.IsKnownSize)
                    throw ch.Except("Incompatible feature column type {0}, must be known sized vector of {1}", typeFeat, NumberDataViewType.Single);
                if (typeFeat.ItemType != NumberDataViewType.Single)
                    throw ch.Except("Incompatible feature column type {0}, must be vector of {1}", typeFeat, NumberDataViewType.Single);

                // Weights are requested from the cursor only when the schema has a weight column.
                CursOpt cursorOpt = CursOpt.Label | CursOpt.Features;
                if (examples.Schema.Weight.HasValue)
                    cursorOpt |= CursOpt.Weight;

                var cursorFactory = new FloatLabelCursor.Factory(examples, cursorOpt);

                return(TrainCore(ch, cursorFactory, typeFeat.Size));
        // Runs the training passes: while state.Iteration is below Args.NumIterations, cursors over the data
        // (shuffled when requested and supported) and feeds every row to state.ProcessDataInstance.
        // Rows reported by the cursor as having bad features are tallied across passes.
        // NOTE(review): braces, the per-pass iteration bookkeeping and the call that consumes the final format
        // string (presumably ch.Warning) are not visible in this excerpt — confirm against the full file.
        private void TrainCore(IChannel ch, RoleMappedData data, TrainStateBase state)
            bool shuffle = Args.Shuffle;

            // Shuffling is silently downgraded (with a warning) when the data view cannot shuffle.
            if (shuffle && !data.Data.CanShuffle)
                ch.Warning("Training data does not support shuffling, so ignoring request to shuffle");
                shuffle = false;

            // A null Random is passed to the cursor factory when shuffling is off.
            var  rand          = shuffle ? Host.Rand : null;
            var  cursorFactory = new FloatLabelCursor.Factory(data, CursOpt.Label | CursOpt.Features | CursOpt.Weight);
            long numBad        = 0;

            while (state.Iteration < Args.NumIterations)

                using (var cursor = cursorFactory.Create(rand))
                    while (cursor.MoveNext())
                        state.ProcessDataInstance(ch, ref cursor.Features, cursor.Label, cursor.Weight);
                    numBad += cursor.BadFeaturesRowCount;


            // Report the total skipped count and the per-iteration figure (integer division).
            if (numBad > 0)
                    "Skipped {0} instances with missing features during training (over {1} iterations; {2} inst/iter)",
                    numBad, Args.NumIterations, numBad / Args.NumIterations);
        /// <summary>
        /// Loads the rows produced by the cursor factory into the dataset, pushing them in batches to
        /// reduce peak memory cost. When the detected density is at least 0.5 rows are pushed from a dense
        /// row-major float buffer; the remaining code pushes them in CSR form (indptr / indices / values).
        /// </summary>
        // NOTE(review): braces, the else that presumably separates the dense and CSR paths (both declare a
        // local named 'features'), and the increments of totalRowCount / curRowCount are not visible in this
        // excerpt — confirm against the full file.
        private void LoadDataset(IChannel ch, FloatLabelCursor.Factory factory, Dataset dataset, int numRow, int batchSize, CategoricalMetaData catMetaData)
            ch.Assert(dataset.GetNumRows() == numRow);
            ch.Assert(dataset.GetNumCols() == catMetaData.NumCol);
            var rand = Host.Rand;

            // To avoid array resizing, the batch size should be at least the size of one row.
            batchSize = Math.Max(batchSize, catMetaData.NumCol);
            double density       = DetectDensity(factory);
            // numElem: elements written to the current batch; curRowCount: rows in the current batch;
            // totalRowCount: rows seen so far (used to compute the start row of each push).
            int    numElem       = 0;
            int    totalRowCount = 0;
            int    curRowCount   = 0;

            if (density >= 0.5)
                // Dense path: number of whole rows that fit in one batch, clamped to [1, numRow].
                int batchRow = batchSize / catMetaData.NumCol;
                batchRow = Math.Max(1, batchRow);
                if (batchRow > numRow)
                    batchRow = numRow;

                // This can only happen if the size of ONE example(row) exceeds the max array size. This looks like a very unlikely case.
                if ((long)catMetaData.NumCol * batchRow > Utils.ArrayMaxSize)
                    throw ch.Except("Size of array exceeded the " + nameof(Utils.ArrayMaxSize));

                float[] features = new float[catMetaData.NumCol * batchRow];

                using (var cursor = factory.Create())
                    while (cursor.MoveNext())
                        ch.Assert(totalRowCount < numRow);
                        CopyToArray(ch, cursor, features, catMetaData, rand, ref numElem);
                        // Batch is full: push it and reset the batch counters.
                        if (batchRow == curRowCount)
                            ch.Assert(numElem == curRowCount * catMetaData.NumCol);
                            // PushRows is run by multi-threading inside, so lock here.
                            lock (LightGbmShared.LockForMultiThreadingInside)
                                dataset.PushRows(features, curRowCount, catMetaData.NumCol, totalRowCount - curRowCount);
                            curRowCount = 0;
                            numElem     = 0;
                    ch.Assert(totalRowCount == numRow);
                    // Push the final, partially filled batch.
                    if (curRowCount > 0)
                        ch.Assert(numElem == curRowCount * catMetaData.NumCol);
                        // PushRows is run by multi-threading inside, so lock here.
                        lock (LightGbmShared.LockForMultiThreadingInside)
                            dataset.PushRows(features, curRowCount, catMetaData.NumCol, totalRowCount - curRowCount);
                // CSR path: estimate how many rows fit in a batch from the detected density, at least 1.
                int esimateBatchRow = (int)(batchSize / (catMetaData.NumCol * density));
                esimateBatchRow = Math.Max(1, esimateBatchRow);
                float[] features = new float[batchSize];
                int[]   indices  = new int[batchSize];
                int[]   indptr   = new int[esimateBatchRow + 1];

                using (var cursor = factory.Create())
                    while (cursor.MoveNext())
                        ch.Assert(totalRowCount < numRow);
                        // Need push rows to LightGBM.
                        // The next row would overflow the value buffer, so flush the current batch first.
                        if (numElem + cursor.Features.Count > features.Length)
                            // Mini batch size is greater than size of one row.
                            // So, at least we have the data of one row.
                            ch.Assert(curRowCount > 0);
                            // Close the batch by writing the end offset of the last row.
                            Utils.EnsureSize(ref indptr, curRowCount + 1);
                            indptr[curRowCount] = numElem;
                            // PushRows is run by multi-threading inside, so lock here.
                            lock (LightGbmShared.LockForMultiThreadingInside)
                                dataset.PushRows(indptr, indices, features,
                                                 curRowCount + 1, numElem, catMetaData.NumCol, totalRowCount - curRowCount);
                            curRowCount = 0;
                            numElem     = 0;
                        // Record the start offset of the current row, growing indptr if the estimate was too small.
                        Utils.EnsureSize(ref indptr, curRowCount + 1);
                        indptr[curRowCount] = numElem;
                        CopyToCsr(ch, cursor, indices, features, catMetaData, rand, ref numElem);
                    ch.Assert(totalRowCount == numRow);
                    // Push the final, partially filled batch.
                    if (curRowCount > 0)
                        Utils.EnsureSize(ref indptr, curRowCount + 1);
                        indptr[curRowCount] = numElem;
                        // PushRows is run by multi-threading inside, so lock here.
                        lock (LightGbmShared.LockForMultiThreadingInside)
                            dataset.PushRows(indptr, indices, features, curRowCount + 1,
                                             numElem, catMetaData.NumCol, totalRowCount - curRowCount);
        /// <summary>
        /// Creates a dataset from a sample of the rows. Rows are visited with randomly chosen strides,
        /// and for every column the non-zero values and their sample indices are collected and passed to
        /// the <c>Dataset</c> constructor together with the labels, weights and groups.
        /// </summary>
        // NOTE(review): braces, the else separating the dense and sparse branches, the statements guarded by
        // 'if (fv == 0)' and by the stop condition, and the increment of sampleIdx are not visible in this
        // excerpt — confirm against the full file.
        private void CreateDatasetFromSamplingData(IChannel ch, FloatLabelCursor.Factory factory,
                                                   int numRow, string param, float[] labels, float[] weights, int[] groups, CategoricalMetaData catMetaData,
                                                   out Dataset dataset)

            int numSampleRow = GetNumSampleRow(numRow, FeatureCount);

            var    rand        = Host.Rand;
            // Average number of rows to advance per sample so that the samples span the whole data set.
            double averageStep = (double)numRow / numSampleRow;
            int    totalIdx    = 0;
            int    sampleIdx   = 0;
            double density     = DetectDensity(factory);

            // Per-column buffers: sampled non-zero values, the sample index of each value, and a fill count.
            double[][] sampleValuePerColumn   = new double[catMetaData.NumCol][];
            int[][]    sampleIndicesPerColumn = new int[catMetaData.NumCol][];
            int[]      nonZeroCntPerColumn    = new int[catMetaData.NumCol];
            // Initial capacity guess per column, derived from the detected density; at least 1.
            int        estimateNonZeroCnt     = (int)(numSampleRow * density);

            estimateNonZeroCnt = Math.Max(1, estimateNonZeroCnt);
            for (int i = 0; i < catMetaData.NumCol; i++)
                nonZeroCntPerColumn[i]    = 0;
                sampleValuePerColumn[i]   = new double[estimateNonZeroCnt];
                sampleIndicesPerColumn[i] = new int[estimateNonZeroCnt];
            using (var cursor = factory.Create())
                // Stride is 1 unless averageStep > 1, in which case it is drawn from [1, (int)(2 * averageStep - 1)].
                int step = 1;
                if (averageStep > 1)
                    step = rand.Next((int)(2 * averageStep - 1)) + 1;
                while (MoveMany(cursor, step))
                    if (cursor.Features.IsDense)
                        GetFeatureValueDense(ch, cursor, catMetaData, rand, out float[] featureValues);
                        for (int i = 0; i < catMetaData.NumCol; ++i)
                            float fv = featureValues[i];
                            if (fv == 0)
                            // Append the value to column i, growing the column buffers on demand.
                            int curNonZeroCnt = nonZeroCntPerColumn[i];
                            Utils.EnsureSize(ref sampleValuePerColumn[i], curNonZeroCnt + 1);
                            Utils.EnsureSize(ref sampleIndicesPerColumn[i], curNonZeroCnt + 1);
                            sampleValuePerColumn[i][curNonZeroCnt]   = fv;
                            sampleIndicesPerColumn[i][curNonZeroCnt] = sampleIdx;
                            nonZeroCntPerColumn[i] = curNonZeroCnt + 1;
                        GetFeatureValueSparse(ch, cursor, catMetaData, rand, out int[] featureIndices, out float[] featureValues, out int cnt);
                        for (int i = 0; i < cnt; ++i)
                            int   colIdx = featureIndices[i];
                            float fv     = featureValues[i];
                            if (fv == 0)
                            // Append the value to column colIdx, growing the column buffers on demand.
                            int curNonZeroCnt = nonZeroCntPerColumn[colIdx];
                            Utils.EnsureSize(ref sampleValuePerColumn[colIdx], curNonZeroCnt + 1);
                            Utils.EnsureSize(ref sampleIndicesPerColumn[colIdx], curNonZeroCnt + 1);
                            sampleValuePerColumn[colIdx][curNonZeroCnt]   = fv;
                            sampleIndicesPerColumn[colIdx][curNonZeroCnt] = sampleIdx;
                            nonZeroCntPerColumn[colIdx] = curNonZeroCnt + 1;
                    totalIdx += step;
                    // Stop condition: enough samples collected, or every row has been consumed.
                    if (numSampleRow == sampleIdx || numRow == totalIdx)
                    // Re-derive the stride from the rows and samples that remain.
                    averageStep = (double)(numRow - totalIdx) / (numSampleRow - sampleIdx);
                    step        = 1;
                    if (averageStep > 1)
                        step = rand.Next((int)(2 * averageStep - 1)) + 1;
            dataset = new Dataset(sampleValuePerColumn, sampleIndicesPerColumn, catMetaData.NumCol, nonZeroCntPerColumn, sampleIdx, numRow, param, labels, weights, groups);
        /// <summary>
        /// Computes the row count, labels, weights and group counts of the dataset.
        /// <paramref name="weights"/> is null when the schema has no weight column and
        /// <paramref name="groups"/> is null unless the prediction kind is ranking.
        /// </summary>
        // NOTE(review): braces and the statements that append to labelList / weightList / cursorGroups and
        // that build groupList are not visible in this excerpt — confirm against the full file.
        private void GetMetainfo(IChannel ch, FloatLabelCursor.Factory factory,
                                 out int numRow, out float[] labels, out float[] weights, out int[] groups)
            ch.Check(factory.Data.Schema.Label != null, "The data should have label.");
            List <float> labelList  = new List <float>();
            bool         hasWeights = factory.Data.Schema.Weight != null;
            bool         hasGroup   = false;

            if (_predictionKind == PredictionKind.Ranking)
                // NOTE(review): this checks the schema object itself, not a group column, although the
                // message talks about the group field — verify whether Schema.Group was intended.
                ch.Check(factory.Data.Schema != null, "The data for ranking task should have group field.");
                hasGroup = true;
            // The optional lists are allocated only when the corresponding column is in use.
            List <float> weightList   = hasWeights ? new List <float>() : null;
            List <ulong> cursorGroups = hasGroup ? new List <ulong>() : null;

            using (var cursor = factory.Create())
                while (cursor.MoveNext())
                    // The labels end up in a single array, so the row count is bounded by the maximum array size.
                    if (labelList.Count == Utils.ArrayMaxSize)
                        throw ch.Except($"Dataset row count exceeded the maximum count of {Utils.ArrayMaxSize}");
                    if (hasWeights)
                        // Default weight = 1.
                        if (float.IsNaN(cursor.Weight))
                    if (hasGroup)
            labels = labelList.ToArray();
            ConvertNaNLabels(ch, factory.Data, labels);
            numRow = labels.Length;
            ch.Check(numRow > 0, "Cannot use empty dataset.");
            weights = hasWeights ? weightList.ToArray() : null;
            groups  = null;
            if (hasGroup)
                List <int> groupList = new List <int>();
                int        lastGroup = -1;
                // A new group starts at row 0 and wherever the group id differs from the previous row's.
                for (int i = 0; i < numRow; ++i)
                    if (i == 0 || cursorGroups[i] != cursorGroups[i - 1])
                groups = groupList.ToArray();
        // Trains the model: decides the thread count, makes one pass over the data to gather statistics
        // (and, when multi-threaded, to cache every row in memory), partitions the cached rows into per-thread
        // chunks, then runs the optimizer and reports optional statistics.
        // NOTE(review): braces, the else that introduces the streaming single-threaded case, the try keyword
        // before opt.Minimize, the progress counter update and the tail of the ForEachDefined lambda are not
        // visible in this excerpt — confirm against the full file.
        protected virtual void TrainCore(IChannel ch, RoleMappedData data)

            // Compute the number of threads to use. The ctor should have verified that this will
            // produce a positive value.
            int numThreads = !UseThreads ? 1 : (NumThreads ?? Environment.ProcessorCount);

            // The host's concurrency factor, when positive, caps the requested thread count.
            if (Host.ConcurrencyFactor > 0 && numThreads > Host.ConcurrencyFactor)
                numThreads = Host.ConcurrencyFactor;
                ch.Warning("The number of threads specified in trainer arguments is larger than the concurrency factor "
                           + "setting of the environment. Using {0} training threads instead.", numThreads);

            ch.Assert(numThreads > 0);

            NumGoodRows = 0;
            WeightSum   = 0;

            // The in-memory row caches stay null unless multi-threading is used.
            _features = null;
            _labels   = null;
            _weights  = null;
            if (numThreads > 1)
                ch.Info("LBFGS multi-threading will attempt to load dataset into memory. In case of out-of-memory " +
                        "issues, add 'numThreads=1' to the trainer arguments and 'cache=-' to the command line " +
                        "arguments to turn off multi-threading.");
                // Initial capacity of 1000 rows; the arrays are grown with Utils.EnsureSize during the pass.
                _features = new VBuffer <float> [1000];
                _labels   = new float[1000];
                if (data.Schema.Weight != null)
                    _weights = new float[1000];

            var cursorFactory = new FloatLabelCursor.Factory(data, CursOpt.Features | CursOpt.Label | CursOpt.Weight);

            long numBad;

            // REVIEW: This pass seems overly expensive for the benefit when multi-threading is off....
            using (var cursor = cursorFactory.Create())
                using (var pch = Host.StartProgressChannel("LBFGS data prep"))
                    // REVIEW: maybe it makes sense for the factory to capture the good row count after
                    // the first successful cursoring?
                    // An unknown row count is represented as NaN for progress reporting.
                    Double totalCount = data.Data.GetRowCount(true) ?? Double.NaN;

                    long exCount = 0;
                    pch.SetHeader(new ProgressHeader(null, new[] { "examples" }),
                                  e => e.SetProgress(0, exCount, totalCount));
                    while (cursor.MoveNext())
                        WeightSum += cursor.Weight;
                        if (ShowTrainingStats)
                            ProcessPriorDistribution(cursor.Label, cursor.Weight);

                        PreTrainingProcessInstance(cursor.Label, ref cursor.Features, cursor.Weight);
                        if (_features != null)
                            // Cache this row at position KeptRowCount - 1, growing the arrays as needed.
                            ch.Assert(cursor.KeptRowCount <= int.MaxValue);
                            int index = (int)cursor.KeptRowCount - 1;
                            Utils.EnsureSize(ref _features, index + 1);
                            Utils.EnsureSize(ref _labels, index + 1);
                            if (_weights != null)
                                Utils.EnsureSize(ref _weights, index + 1);
                                _weights[index] = cursor.Weight;
                            // Swap rather than copy: the cache slot takes the cursor's feature buffer.
                            Utils.Swap(ref _features[index], ref cursor.Features);
                            _labels[index] = cursor.Label;

                            if (cursor.KeptRowCount >= int.MaxValue)
                                ch.Warning("Limiting data size for multi-threading");
                    NumGoodRows = cursor.KeptRowCount;
                    numBad      = cursor.SkippedRowCount;
            ch.Check(NumGoodRows > 0, NoTrainingInstancesMessage);
            if (numBad > 0)
                ch.Warning("Skipped {0} instances with missing features/label/weight during training", numBad);

            if (_features != null)
                ch.Assert(numThreads > 1);

                // If there are so many threads that each only gets a small number (less than 10) of instances, trim
                // the number of threads so each gets a more reasonable number (100 or so). These numbers are pretty arbitrary,
                // but avoid the possibility of having no instances on some threads.
                if (numThreads > 1 && NumGoodRows / numThreads < 10)
                    int numNew = Math.Max(1, (int)NumGoodRows / 100);
                    ch.Warning("Too few instances to use {0} threads, decreasing to {1} thread(s)", numThreads, numNew);
                    numThreads = numNew;
                ch.Assert(numThreads > 0);

                // Divide up the instances among the threads.
                // _ranges[k] .. _ranges[k + 1] delimits chunk k; each chunk takes the ceiling of the
                // remaining instances divided by the remaining chunks.
                _numChunks = numThreads;
                _ranges    = new int[_numChunks + 1];
                int cinstTot = (int)NumGoodRows;
                for (int ichk = 0, iinstMin = 0; ichk < numThreads; ichk++)
                    int cchkLeft = numThreads - ichk;                                // Number of chunks left to fill.
                    ch.Assert(0 < cchkLeft && cchkLeft <= numThreads);
                    int cinstThis = (cinstTot - iinstMin + cchkLeft - 1) / cchkLeft; // Size of this chunk.
                    ch.Assert(0 < cinstThis && cinstThis <= cinstTot - iinstMin);
                    iinstMin         += cinstThis;
                    _ranges[ichk + 1] = iinstMin;

                // One loss slot per thread, and one gradient buffer per thread except one.
                _localLosses    = new float[numThreads];
                _localGradients = new VBuffer <float> [numThreads - 1];
                int size = BiasCount + WeightCount;
                for (int i = 0; i < _localGradients.Length; i++)
                    _localGradients[i] = VBufferUtils.CreateEmpty <float>(size);

                ch.Assert(_numChunks > 0 && _data == null);
                // Streaming, single-threaded case.
                _data          = data;
                _cursorFactory = cursorFactory;
                ch.Assert(_numChunks == 0 && _data != null);

            VBuffer <float>       initWeights;
            ITerminationCriterion terminationCriterion;
            Optimizer             opt = InitializeOptimizer(ch, cursorFactory, out initWeights, out terminationCriterion);

            opt.Quiet = Quiet;

            float loss;

                opt.Minimize(DifferentiableFunction, ref initWeights, terminationCriterion, ref CurrentWeights, out loss);
            // On premature convergence, fall back to the state carried by the exception.
            catch (Optimizer.PrematureConvergenceException e)
                if (!Quiet)
                    ch.Warning("Premature convergence occurred. The OptimizationTolerance may be set too small. {0}", e.Message);
                CurrentWeights = e.State.X;
                loss           = e.State.Value;

            ch.Assert(CurrentWeights.Length == BiasCount + WeightCount);

            // Parameter count starts at the bias count; the lambda below visits defined weights at or after
            // index BiasCount that are non-zero.
            int numParams = BiasCount;

            if ((L1Weight > 0 && !Quiet) || ShowTrainingStats)
                VBufferUtils.ForEachDefined(ref CurrentWeights, (index, value) => { if (index >= BiasCount && value != 0)
                if (L1Weight > 0 && !Quiet)
                    ch.Info("L1 regularization selected {0} of {1} weights.", numParams, BiasCount + WeightCount);

            if (ShowTrainingStats)
                ComputeTrainingStatistics(ch, cursorFactory, loss, numParams);
        /// <inheritdoc/>
        // Makes one pass over the data to compute the primal and dual losses for the multi-class model,
        // adds the regularization terms, fills the metrics array, decides convergence from the relative
        // duality gap, and snapshots the weights and biases whenever the primal loss improves on the best so far.
        // NOTE(review): braces, the statements guarded by 'if (iClass == label)' and 'if (pch != null)', the
        // increment of 'row' and the return statement are not visible in this excerpt — confirm against the full file.
        private protected override bool CheckConvergence(
            IProgressChannel pch,
            int iter,
            FloatLabelCursor.Factory cursorFactory,
            DualsTableBase duals,
            IdToIdxLookup idToIdx,
            VBuffer <float>[] weights,
            VBuffer <float>[] bestWeights,
            float[] biasUnreg,
            float[] bestBiasUnreg,
            float[] biasReg,
            float[] bestBiasReg,
            long count,
            Double[] metrics,
            ref Double bestPrimalLoss,
            ref int bestIter)
            int numClasses = weights.Length;

            Contracts.Assert(duals.Length >= numClasses * count);
            Contracts.Assert(Utils.Size(weights) == numClasses);
            Contracts.Assert(Utils.Size(biasReg) == numClasses);
            Contracts.Assert(Utils.Size(biasUnreg) == numClasses);
            Contracts.Assert(Utils.Size(metrics) == 6);
            // One slot per metric plus a trailing slot that carries the iteration number.
            var reportedValues = new Double?[metrics.Length + 1];

            reportedValues[metrics.Length] = iter;
            var lossSum     = new CompensatedSum();
            var dualLossSum = new CompensatedSum();
            int numFeatures = weights[0].Length;

            using (var cursor = cursorFactory.Create())
                long row = 0;
                Func <DataViewRowId, long, long> getIndexFromIdAndRow = GetIndexFromIdAndRowGetter(idToIdx, biasReg.Length);
                // Iterates through data to compute loss function.
                while (cursor.MoveNext())
                    var    instanceWeight = GetInstanceWeight(cursor);
                    var    features       = cursor.Features;
                    var    label          = (int)cursor.Label;
                    // Output of the label class, with both bias terms added.
                    var    labelOutput    = WDot(in features, in weights[label], biasReg[label] + biasUnreg[label]);
                    Double subLoss        = 0;
                    Double subDualLoss    = 0;
                    long   idx            = getIndexFromIdAndRow(cursor.Id, row);
                    // Duals are laid out with numClasses consecutive entries per example.
                    long   dualIndex      = idx * numClasses;
                    for (int iClass = 0; iClass < numClasses; iClass++)
                        if (iClass == label)

                        // Loss on the margin between the label class and this class, plus the matching dual loss.
                        var currentClassOutput = WDot(in features, in weights[iClass], biasReg[iClass] + biasUnreg[iClass]);
                        subLoss += _loss.Loss(labelOutput - currentClassOutput, 1);
                        Contracts.Assert(dualIndex == iClass + idx * numClasses);
                        var dual = duals[dualIndex++];
                        subDualLoss += _loss.DualLoss(1, dual);

                    lossSum.Add(subLoss * instanceWeight);
                    dualLossSum.Add(subDualLoss * instanceWeight);

                Host.Assert(idToIdx == null || row * numClasses == duals.Length);

            Double l2Const     = SdcaTrainerOptions.L2Regularization.Value;
            // NOTE(review): l1Threshold is not read in the lines visible here; the option is re-read directly below.
            Double l1Threshold = SdcaTrainerOptions.L1Threshold.Value;

            Double weightsL1Norm                = 0;
            Double weightsL2NormSquared         = 0;
            Double biasRegularizationAdjustment = 0;

            // Accumulate the norms over all classes; the regularized bias counts as an extra weight.
            for (int iClass = 0; iClass < numClasses; iClass++)
                weightsL1Norm                += VectorUtils.L1Norm(in weights[iClass]) + Math.Abs(biasReg[iClass]);
                weightsL2NormSquared         += VectorUtils.NormSquared(weights[iClass]) + biasReg[iClass] * biasReg[iClass];
                biasRegularizationAdjustment += biasReg[iClass] * biasUnreg[iClass];

            Double l1Regularizer = SdcaTrainerOptions.L1Threshold.Value * l2Const * weightsL1Norm;
            var    l2Regularizer = l2Const * weightsL2NormSquared * 0.5;

            // Primal loss = mean loss + both regularizers; dual loss subtracts the L2 term and the bias adjustment.
            var newLoss     = lossSum.Sum / count + l2Regularizer + l1Regularizer;
            var newDualLoss = dualLossSum.Sum / count - l2Regularizer - l2Const * biasRegularizationAdjustment;
            var dualityGap  = newLoss - newDualLoss;

            metrics[(int)MetricKind.Loss]       = newLoss;
            metrics[(int)MetricKind.DualLoss]   = newDualLoss;
            metrics[(int)MetricKind.DualityGap] = dualityGap;
            // Only the first class's biases are reported.
            metrics[(int)MetricKind.BiasUnreg]  = biasUnreg[0];
            metrics[(int)MetricKind.BiasReg]    = biasReg[0];
            // NOTE(review): the quotient below appears to be an integer division (an int count divided by an
            // int product), which would truncate the sparsity ratio — verify whether a floating-point
            // division was intended.
            metrics[(int)MetricKind.L1Sparsity] = SdcaTrainerOptions.L1Threshold == 0 ? 1 : weights.Sum(
                weight => weight.GetValues().Count(w => w != 0)) / (numClasses * numFeatures);

            // Converged when the duality gap relative to the primal loss drops below the tolerance.
            bool converged = dualityGap / newLoss < SdcaTrainerOptions.ConvergenceTolerance;

            if (metrics[(int)MetricKind.Loss] < bestPrimalLoss)
                for (int iClass = 0; iClass < numClasses; iClass++)
                    // Maintain a copy of weights and bias with best primal loss thus far.
                    // This is some extra work and uses extra memory, but it seems worth doing it.
                    // REVIEW: Sparsify bestWeights?
                    weights[iClass].CopyTo(ref bestWeights[iClass]);
                    bestBiasReg[iClass]   = biasReg[iClass];
                    bestBiasUnreg[iClass] = biasUnreg[iClass];

                bestPrimalLoss = metrics[(int)MetricKind.Loss];
                bestIter       = iter;

            // Copy the metrics into the reporting array (the last slot already holds the iteration).
            for (int i = 0; i < metrics.Length; i++)
                reportedValues[i] = metrics[i];
            if (pch != null)

        /// <inheritdoc/>
        /// <remarks>
        /// Walks the rows produced by <paramref name="cursorFactory"/> (created with <paramref name="rand"/> when
        /// <c>SdcaTrainerOptions.Shuffle</c> is set). For each row, the dual variable of each non-label class is
        /// updated through a compare-and-swap on <paramref name="duals"/>, and that class's weights and biases are
        /// adjusted. The opposite-signed totals are accumulated for the label class, which is written last.
        /// </remarks>
        private protected override void TrainWithoutLock(IProgressChannelProvider progress, FloatLabelCursor.Factory cursorFactory, Random rand,
                                                         IdToIdxLookup idToIdx, int numThreads, DualsTableBase duals, float[] biasReg, float[] invariants, float lambdaNInv,
                                                         VBuffer <float>[] weights, float[] biasUnreg, VBuffer <float>[] l1IntermediateWeights, float[] l1IntermediateBias, float[] featureNormSquared)
            // One weight vector (and one regularized/unregularized bias) per class.
            int numClasses = Utils.Size(weights);

            Contracts.Assert(Utils.Size(biasReg) == numClasses);
            Contracts.Assert(Utils.Size(biasUnreg) == numClasses);

            // Upper bound on compare-and-swap attempts per dual variable (see the trial loop below).
            int  maxUpdateTrials = 2 * numThreads;
            var  l1Threshold     = SdcaTrainerOptions.L1Threshold.Value;
            bool l1ThresholdZero = l1Threshold == 0;
            // Factor applied to the bias when computing the over-relaxation adjustment below.
            var  lr = SdcaTrainerOptions.BiasLearningRate * SdcaTrainerOptions.L2Regularization.Value;

            // Progress reporting is optional; pch stays null when no provider was given.
            var pch = progress != null?progress.StartProgressChannel("Dual update") : null;

            using (pch)
                using (var cursor = SdcaTrainerOptions.Shuffle ? cursorFactory.Create(rand) : cursorFactory.Create())
                    long rowCount = 0;
                    if (pch != null)
                        pch.SetHeader(new ProgressHeader("examples"), e => e.SetProgress(0, rowCount));

                    // Maps a row id to the index used for the duals table and the cached per-row arrays.
                    Func <DataViewRowId, long> getIndexFromId = GetIndexFromIdGetter(idToIdx, biasReg.Length);
                    while (cursor.MoveNext())
                        long  idx = getIndexFromId(cursor.Id);
                        // Each row owns numClasses consecutive entries in the duals table.
                        long  dualIndexInitPos = idx * numClasses;
                        var   features         = cursor.Features;
                        var   label            = (int)cursor.Label;
                        float invariant;
                        float normSquared;
                        if (invariants != null)
                            // Cached per-row values, indexed by the row's idx.
                            invariant = invariants[idx];
                            normSquared = featureNormSquared[idx];
                            // Computed from the current row's features.
                            normSquared = VectorUtils.NormSquared(in features);
                            if (SdcaTrainerOptions.BiasLearningRate == 0)
                                normSquared += 1;

                            invariant = _loss.ComputeDualUpdateInvariant(2 * normSquared * lambdaNInv * GetInstanceWeight(cursor));

                        // The output for the label class using current weights and bias.
                        var labelOutput    = WDot(in features, in weights[label], biasReg[label] + biasUnreg[label]);
                        var instanceWeight = GetInstanceWeight(cursor);

                        // This will be the new dual variable corresponding to the label class.
                        float labelDual = 0;

                        // This will be used to update the weights and regularized bias corresponding to the label class.
                        float labelPrimalUpdate = 0;

                        // This will be used to update the unregularized bias corresponding to the label class.
                        float labelAdjustment = 0;

                        // Iterates through all classes.
                        for (int iClass = 0; iClass < numClasses; iClass++)
                            // Skip the dual/weights/bias update for label class. Will be taken care of at the end.
                            if (iClass == label)

                            var weightsEditor = VBufferEditor.CreateFromBuffer(ref weights[iClass]);
                            var l1IntermediateWeightsEditor =
                                !l1ThresholdZero?VBufferEditor.CreateFromBuffer(ref l1IntermediateWeights[iClass]) :

                            // Loop trials for compare-and-swap updates of duals.
                            // In general, concurrent update conflict to the same dual variable is rare
                            // if data is shuffled.
                            for (int numTrials = 0; numTrials < maxUpdateTrials; numTrials++)
                                long dualIndex  = iClass + dualIndexInitPos;
                                var  dual       = duals[dualIndex];
                                // Label-class output (plus the updates accumulated so far for this row) minus this class's output.
                                var  output     = labelOutput + labelPrimalUpdate * normSquared - WDot(in features, in weights[iClass], biasReg[iClass] + biasUnreg[iClass]);
                                var  dualUpdate = _loss.DualUpdate(output, 1, dual, invariant, numThreads);

                                // The successive over-relaxation approach to adjust the sum of dual variables (biasReg) to zero.
                                // Reference to details: http://stat.rutgers.edu/home/tzhang/papers/ml02_dual.pdf, pp. 16-17.
                                var adjustment = l1ThresholdZero ? lr * biasReg[iClass] : lr * l1IntermediateBias[iClass];
                                dualUpdate -= adjustment;
                                bool success = false;
                                // Commit the new dual only if the stored value still equals the one read above.
                                duals.ApplyAt(dualIndex, (long index, ref float value) =>
                                              success = Interlocked.CompareExchange(ref value, dual + dualUpdate, dual) == dual);

                                if (success)
                                    // Note: dualConstraint[iClass] = lambdaNInv * (sum of duals[iClass])
                                    var primalUpdate = dualUpdate * lambdaNInv * instanceWeight;
                                    // The label class receives the negated totals of everything applied to the other classes.
                                    labelDual         -= dual + dualUpdate;
                                    labelPrimalUpdate += primalUpdate;
                                    biasUnreg[iClass] += adjustment * lambdaNInv * instanceWeight;
                                    labelAdjustment   -= adjustment;

                                    if (l1ThresholdZero)
                                        VectorUtils.AddMult(in features, weightsEditor.Values, -primalUpdate);
                                        biasReg[iClass] -= primalUpdate;
                                        //Iterative shrinkage-thresholding (aka. soft-thresholding)
                                        //Update v=denseWeights as if there's no L1
                                        //Thresholding: if |v[j]| < threshold, turn off weights[j]
                                        //If not, shrink: w[j] = v[i] - sign(v[j]) * threshold
                                        l1IntermediateBias[iClass] -= primalUpdate;
                                        if (SdcaTrainerOptions.BiasLearningRate == 0)
                                            biasReg[iClass] = Math.Abs(l1IntermediateBias[iClass]) - l1Threshold > 0.0
                                        ? l1IntermediateBias[iClass] - Math.Sign(l1IntermediateBias[iClass]) * l1Threshold
                                        : 0;

                                        var featureValues = features.GetValues();
                                        if (features.IsDense)
                                            CpuMathUtils.SdcaL1UpdateDense(-primalUpdate, featureValues.Length, featureValues, l1Threshold, l1IntermediateWeightsEditor.Values, weightsEditor.Values);
                                        else if (featureValues.Length > 0)
                                            CpuMathUtils.SdcaL1UpdateSparse(-primalUpdate, featureValues.Length, featureValues, features.GetIndices(), l1Threshold, l1IntermediateWeightsEditor.Values, weightsEditor.Values);


                        // Updating with label class weights and dual variable.
                        duals[label + dualIndexInitPos] = labelDual;
                        biasUnreg[label] += labelAdjustment * lambdaNInv * instanceWeight;
                        if (l1ThresholdZero)
                            var weightsEditor = VBufferEditor.CreateFromBuffer(ref weights[label]);
                            VectorUtils.AddMult(in features, weightsEditor.Values, labelPrimalUpdate);
                            biasReg[label] += labelPrimalUpdate;
                            // Same soft-thresholding as above, applied to the label class with the accumulated update.
                            l1IntermediateBias[label] += labelPrimalUpdate;
                            var intermediateBias = l1IntermediateBias[label];
                            biasReg[label] = Math.Abs(intermediateBias) - l1Threshold > 0.0
                            ? intermediateBias - Math.Sign(intermediateBias) * l1Threshold
                            : 0;

                            var weightsEditor = VBufferEditor.CreateFromBuffer(ref weights[label]);
                            var l1IntermediateWeightsEditor = VBufferEditor.CreateFromBuffer(ref l1IntermediateWeights[label]);
                            var featureValues = features.GetValues();
                            if (features.IsDense)
                                CpuMathUtils.SdcaL1UpdateDense(labelPrimalUpdate, featureValues.Length, featureValues, l1Threshold, l1IntermediateWeightsEditor.Values, weightsEditor.Values);
                            else if (featureValues.Length > 0)
                                CpuMathUtils.SdcaL1UpdateSparse(labelPrimalUpdate, featureValues.Length, featureValues, features.GetIndices(), l1Threshold, l1IntermediateWeightsEditor.Values, weightsEditor.Values);

 /// <summary>
 /// Picks the default number of threads: half of <see cref="Environment.ProcessorCount"/>,
 /// raised to at least 1 and capped at 8.
 /// </summary>
 /// <param name="cursorFactory">Not consulted by this default implementation.</param>
 /// <returns>A thread count between 1 and 8 inclusive.</returns>
 private protected virtual int ComputeNumThreads(FloatLabelCursor.Factory cursorFactory)
 {
     int halfOfProcessors = Environment.ProcessorCount / 2;
     int atLeastOne = Math.Max(1, halfOfProcessors);
     return Math.Min(8, atLeastOne);
 }
 /// <summary>
 /// Hook that derived trainers implement to compute statistics about the trained model.
 /// </summary>
 /// <param name="ch">Channel available for messages.</param>
 /// <param name="cursorFactory">Factory for cursors over the training data.</param>
 /// <param name="loss">Loss value supplied by the caller.</param>
 /// <param name="numParams">Number of model parameters supplied by the caller.</param>
 protected abstract void ComputeTrainingStatistics(IChannel ch, FloatLabelCursor.Factory cursorFactory, float loss, int numParams);
        /// <summary>
        /// Initialize weights by running SGD up to specified tolerance.
        /// </summary>
        /// <param name="ch">Channel that receives the informational messages.</param>
        /// <param name="cursorFactory">Creates the cursors that feed examples to the stochastic gradient.</param>
        protected virtual VBuffer <float> InitializeWeightsSgd(IChannel ch, FloatLabelCursor.Factory cursorFactory)
            if (!Quiet)
                ch.Info("Running SGD initialization with tolerance {0}", SgdInitializationTolerance);

            // Counts invocations of the termination delegate; also reported in the final message.
            int        numExamples  = 0;
            // Snapshot of the weights taken at the previous convergence check.
            var        oldWeights   = VBufferUtils.CreateEmpty <float>(BiasCount + WeightCount);
            // Termination test handed to the SGD optimizer: compares the norm of the weight change
            // against SgdInitializationTolerance.
            DTerminate terminateSgd =
                (in VBuffer <float> x) =>
                // The check below is tied to every 1000th call.
                if (++numExamples % 1000 != 0)
                // Presumably oldWeights -= x, so the norm is the change since the last snapshot (confirm AddMult semantics).
                VectorUtils.AddMult(in x, -1, ref oldWeights);
                float normDiff = VectorUtils.Norm(oldWeights);
                // Keep the current weights as the snapshot for the next check.
                x.CopyTo(ref oldWeights);
                // #if OLD_TRACING // REVIEW: How should this be ported?
                if (!Quiet)
                    if (numExamples % 50000 == 0)
                        Console.WriteLine("\t{0}\t{1}", numExamples, normDiff);
                // #endif
                return(normDiff < SgdInitializationTolerance);

            VBuffer <float>  result = default(VBuffer <float>);
            // Created lazily by the gradient delegate below and replaced once exhausted.
            FloatLabelCursor cursor = null;

                float[] scratch = null;

                // Stochastic gradient: each call accumulates the gradient of one example taken from the cursor.
                SgdOptimizer.DStochasticGradient lossSgd =
                    (in VBuffer <float> x, ref VBuffer <float> grad) =>
                    // Zero out the gradient by sparsifying.
                    grad = new VBuffer <float>(grad.Length, 0, grad.Values, grad.Indices);
                    EnsureBiases(ref grad);

                    // No cursor yet, or the current one has no more rows: start a new one.
                    if (cursor == null || !cursor.MoveNext())
                        if (cursor != null)
                        cursor = cursorFactory.Create();
                        if (!cursor.MoveNext())
                    AccumulateOneGradient(in cursor.Features, cursor.Label, cursor.Weight, in x, ref grad, ref scratch);

                // Starting point for SGD, sized to hold the biases plus the weights.
                VBuffer <float> sgdWeights;
                if (DenseOptimizer)
                    sgdWeights = VBufferUtils.CreateDense <float>(BiasCount + WeightCount);
                    sgdWeights = VBufferUtils.CreateEmpty <float>(BiasCount + WeightCount);
                SgdOptimizer sgdo = new SgdOptimizer(terminateSgd);
                // The optimizer's output is received in result.
                sgdo.Minimize(lossSgd, ref sgdWeights, ref result);
                // #if OLD_TRACING // REVIEW: How should this be ported?
                if (!Quiet)
                // #endif
                ch.Info("SGD initialization done in {0} rounds", numExamples);
                if (cursor != null)

        /// <summary>
        /// Trains a linear model by handing the loaded data to the native learner (<c>Native.LearnAll</c>) and
        /// wraps the resulting weights and bias with <c>CreatePredictor</c>.
        /// </summary>
        /// <param name="ch">Channel used for argument checks and informational messages.</param>
        /// <param name="data">Training data; its feature column determines the weight vector size.</param>
        /// <param name="predictor">Optional model to start from; when non-null its feature weights and bias seed training.</param>
        /// <param name="weightSetCount">Not referenced in the visible body.</param>
        /// <returns>The predictor built from the learned weights and bias.</returns>
        private TPredictor TrainCore(IChannel ch, RoleMappedData data, LinearModelParameters predictor, int weightSetCount)
            int numFeatures   = data.Schema.Feature.Value.Type.GetVectorSize();
            var cursorFactory = new FloatLabelCursor.Factory(data, CursOpt.Label | CursOpt.Features);
            // Fixed at one in the visible code, so the check below always passes and the
            // numThreads > 1 map-back near the end is not taken.
            int numThreads    = 1;

            ch.CheckUserArg(numThreads > 0, nameof(_options.NumberOfThreads),
                            "The number of threads must be either null or a positive integer.");

            // NOTE(review): positiveInstanceWeight is not referenced in the visible body; piw below reads the same option.
            var             positiveInstanceWeight = _options.PositiveInstanceWeight;
            VBuffer <float> weights = default;
            float           bias    = 0.0f;

            if (predictor != null)
                // Start from the supplied model: densified weights plus its bias.
                predictor.GetFeatureWeights(ref weights);
                VBufferUtils.Densify(ref weights);
                bias = predictor.Bias;
                // Start from a dense weight vector with one slot per feature.
                weights = VBufferUtils.CreateDense <float>(numFeatures);

            // Editor whose Values span is passed to the native learner below.
            var weightsEditor = VBufferEditor.CreateFromBuffer(ref weights);

            // Reference: Parasail. SymSGD.
            // An unset learning rate / update frequency requests tuning and starts from 1.
            bool tuneLR = _options.LearningRate == null;
            var  lr     = _options.LearningRate ?? 1.0f;

            bool tuneNumLocIter = (_options.UpdateFrequency == null);
            var  numLocIter     = _options.UpdateFrequency ?? 1;

            var l2Const = _options.L2Regularization;
            var piw     = _options.PositiveInstanceWeight;

            // This is state of the learner that is shared with the native code.
            State    state         = new State();
            GCHandle stateGCHandle = default;

                // Pinned handle to the state object; it is passed to the native calls below.
                stateGCHandle = GCHandle.Alloc(state, GCHandleType.Pinned);

                state.TotalInstancesProcessed = 0;
                using (InputDataManager inputDataManager = new InputDataManager(this, cursorFactory, ch))
                    // True only for the first LearnAll call; cleared after each call.
                    bool shouldInitialize = true;
                    using (var pch = Host.StartProgressChannel("Preprocessing"))

                    // Number of completed passes in the batched path below.
                    int iter = 0;
                    if (inputDataManager.IsFullyLoaded)
                        ch.Info("Data fully loaded into memory.");
                    using (var pch = Host.StartProgressChannel("Training"))
                        if (inputDataManager.IsFullyLoaded)
                            pch.SetHeader(new ProgressHeader(new[] { "iterations" }),
                                          entry => entry.SetProgress(0, state.PassIteration, _options.NumberOfIterations));
                            // If fully loaded, call the SymSGDNative and do not come back until learned for all iterations.
                            Native.LearnAll(inputDataManager, tuneLR, ref lr, l2Const, piw, weightsEditor.Values, ref bias, numFeatures,
                                            _options.NumberOfIterations, numThreads, tuneNumLocIter, ref numLocIter, _options.Tolerance, _options.Shuffle, shouldInitialize,
                                            stateGCHandle, ch.Info);
                            shouldInitialize = false;
                            pch.SetHeader(new ProgressHeader(new[] { "iterations" }),
                                          entry => entry.SetProgress(0, iter, _options.NumberOfIterations));

                            // Since we loaded data in batch sizes, multiple passes over the loaded data is feasible.
                            // Passes per batch scale with inputDataManager.Count (integer division by 10000).
                            int numPassesForABatch = inputDataManager.Count / 10000;
                            while (iter < _options.NumberOfIterations)
                                // We want to train on the final passes thoroughly (without learning on the same batch multiple times)
                                // This is for fine tuning the AUC. Experimentally, we found that 1 or 2 passes is enough
                                int numFinalPassesToTrainThoroughly = 2;
                                // We also do not want to learn for more passes than what the user asked
                                int numPassesForThisBatch = Math.Min(numPassesForABatch, _options.NumberOfIterations - iter - numFinalPassesToTrainThoroughly);
                                // If all of this leaves us with 0 passes, then set numPassesForThisBatch to 1
                                numPassesForThisBatch = Math.Max(1, numPassesForThisBatch);
                                state.PassIteration   = iter;
                                Native.LearnAll(inputDataManager, tuneLR, ref lr, l2Const, piw, weightsEditor.Values, ref bias, numFeatures,
                                                numPassesForThisBatch, numThreads, tuneNumLocIter, ref numLocIter, _options.Tolerance, _options.Shuffle, shouldInitialize,
                                                stateGCHandle, ch.Info);
                                shouldInitialize = false;

                                // Check if we are done with going through the data
                                if (inputDataManager.FinishedTheLoad)
                                    iter += numPassesForThisBatch;
                                    // Check if more passes are left
                                    if (iter < _options.NumberOfIterations)
                                        inputDataManager.RestartLoading(_options.Shuffle, Host);

                                // If more passes are left, load as much as possible
                                if (iter < _options.NumberOfIterations)

                        // Maps back the dense features that are mislocated
                        if (numThreads > 1)
                            Native.MapBackWeightVector(weightsEditor.Values, stateGCHandle);
                if (stateGCHandle.IsAllocated)
            return(CreatePredictor(weights, bias));
        /// <summary>
        /// Create a dataset from the sampling data.
        /// </summary>
        /// <remarks>
        /// Steps through the rows with a randomized stride and records, per column, the non-zero feature values
        /// together with the sample index at which each was seen. The collected arrays are passed to the
        /// <c>Dataset</c> constructor along with <paramref name="labels"/>, <paramref name="weights"/> and
        /// <paramref name="groups"/>.
        /// </remarks>
        /// <param name="ch">Channel forwarded to the feature-value helpers.</param>
        /// <param name="factory">Creates the cursor over the rows to sample from.</param>
        /// <param name="numRow">Total number of rows in the original data.</param>
        /// <param name="param">Passed through to the <c>Dataset</c> constructor.</param>
        /// <param name="labels">Passed through to the <c>Dataset</c> constructor.</param>
        /// <param name="weights">Passed through to the <c>Dataset</c> constructor.</param>
        /// <param name="groups">Passed through to the <c>Dataset</c> constructor.</param>
        /// <param name="catMetaData">Supplies the column count (<c>NumCol</c>) and is forwarded to the feature-value helpers.</param>
        /// <param name="dataset">Receives the dataset built from the sampled values.</param>
        private void CreateDatasetFromSamplingData(IChannel ch, FloatLabelCursor.Factory factory,
                                                   int numRow, string param, float[] labels, float[] weights, int[] groups, CategoricalMetaData catMetaData,
                                                   out Dataset dataset)

            int numSampleRow = GetNumSampleRow(numRow, FeatureCount);

            var    rand        = Host.Rand;
            // Average number of original rows per sampled row.
            double averageStep = (double)numRow / numSampleRow;
            int    totalIdx    = 0;
            int    sampleIdx   = 0;
            double density     = DetectDensity(factory);

            double[][] sampleValuePerColumn   = new double[catMetaData.NumCol][];
            int[][]    sampleIndicesPerColumn = new int[catMetaData.NumCol][];
            int[]      nonZeroCntPerColumn    = new int[catMetaData.NumCol];
            // Initial per-column buffer capacity estimated from the detected density;
            // the buffers are grown with Utils.EnsureSize while sampling.
            int        estimateNonZeroCnt     = (int)(numSampleRow * density);

            estimateNonZeroCnt = Math.Max(1, estimateNonZeroCnt);
            for (int i = 0; i < catMetaData.NumCol; i++)
                nonZeroCntPerColumn[i]    = 0;
                sampleValuePerColumn[i]   = new double[estimateNonZeroCnt];
                sampleIndicesPerColumn[i] = new int[estimateNonZeroCnt];
            using (var cursor = factory.Create())
                // Randomized stride of at least 1, with its upper bound derived from averageStep.
                int step = 1;
                if (averageStep > 1)
                    step = rand.Next((int)(2 * averageStep - 1)) + 1;
                while (MoveMany(cursor, step))
                    if (cursor.Features.IsDense)
                        GetFeatureValueDense(ch, cursor, catMetaData, rand, out ReadOnlySpan <float> featureValues);
                        for (int i = 0; i < catMetaData.NumCol; ++i)
                            float fv = featureValues[i];
                            if (fv == 0)
                            int curNonZeroCnt = nonZeroCntPerColumn[i];
                            Utils.EnsureSize(ref sampleValuePerColumn[i], curNonZeroCnt + 1);
                            Utils.EnsureSize(ref sampleIndicesPerColumn[i], curNonZeroCnt + 1);
                            // sampleValuePerColumn[i] is a vector whose j-th element is added when j-th non-zero value
                            // at the i-th feature is found as scanning the training data.
                            // In other words, sampleValuePerColumn[i][j] is the j-th non-zero i-th feature in the data set.
                            // when we scan the data matrix example-by-example.
                            sampleValuePerColumn[i][curNonZeroCnt] = fv;
                            // If the data set is dense, sampleValuePerColumn[i][j] would be the i-th feature at the j-th example.
                            // If the data set is not dense, sampleValuePerColumn[i][j] would be the i-th feature at the
                            // sampleIndicesPerColumn[i][j]-th example.
                            sampleIndicesPerColumn[i][curNonZeroCnt] = sampleIdx;
                            // The number of non-zero values at the i-th feature is nonZeroCntPerColumn[i].
                            nonZeroCntPerColumn[i] = curNonZeroCnt + 1;
                        // Sparse row: same bookkeeping as above, keyed by the column index of each stored entry.
                        GetFeatureValueSparse(ch, cursor, catMetaData, rand, out ReadOnlySpan <int> featureIndices, out ReadOnlySpan <float> featureValues, out int cnt);
                        for (int i = 0; i < cnt; ++i)
                            int   colIdx = featureIndices[i];
                            float fv     = featureValues[i];
                            if (fv == 0)
                            int curNonZeroCnt = nonZeroCntPerColumn[colIdx];
                            Utils.EnsureSize(ref sampleValuePerColumn[colIdx], curNonZeroCnt + 1);
                            Utils.EnsureSize(ref sampleIndicesPerColumn[colIdx], curNonZeroCnt + 1);
                            sampleValuePerColumn[colIdx][curNonZeroCnt]   = fv;
                            sampleIndicesPerColumn[colIdx][curNonZeroCnt] = sampleIdx;
                            nonZeroCntPerColumn[colIdx] = curNonZeroCnt + 1;
                    // Actual row indexed sampled from the original data set
                    totalIdx += step;
                    // Row index in the sub-sampled data created in this loop.
                    if (numSampleRow == sampleIdx || numRow == totalIdx)
                    // Re-estimate the stride from the rows and samples that are still remaining.
                    averageStep = (double)(numRow - totalIdx) / (numSampleRow - sampleIdx);
                    step        = 1;
                    if (averageStep > 1)
                        step = rand.Next((int)(2 * averageStep - 1)) + 1;
            dataset = new Dataset(sampleValuePerColumn, sampleIndicesPerColumn, catMetaData.NumCol, nonZeroCntPerColumn, sampleIdx, numRow, param, labels, weights, groups);
        /// <summary>
        /// Logs the residual deviance, null deviance and AIC of the trained model, then builds the packed
        /// lower-triangular Hessian over the training data, inverts it through a Cholesky factorization and
        /// stores the resulting coefficient standard errors in <c>_stats</c>.
        /// </summary>
        /// <remarks>
        /// When the packed Hessian would need more than <see cref="int.MaxValue"/> entries, a warning is logged
        /// and <c>_stats</c> is created without standard errors.
        /// </remarks>
        /// <param name="ch">Channel used for the informational messages, warnings and assertions.</param>
        /// <param name="cursorFactory">Creates the cursor over the training examples and exposes the feature schema.</param>
        /// <param name="loss">Final training loss; asserted to be non-negative.</param>
        /// <param name="numParams">Number of parameters counted for the statistics; asserted to be at least <c>BiasCount</c>.</param>
        protected override void ComputeTrainingStatistics(IChannel ch, FloatLabelCursor.Factory cursorFactory, Float loss, int numParams)
            Contracts.Assert(NumGoodRows > 0);
            Contracts.Assert(WeightSum > 0);
            Contracts.Assert(BiasCount == 1);
            Contracts.Assert(loss >= 0);
            Contracts.Assert(numParams >= BiasCount);

            ch.Info("Model trained with {0} training examples.", NumGoodRows);

            // Compute deviance: start with loss function.
            Float deviance = (Float)(2 * loss * WeightSum);

            if (L2Weight > 0)
                // Need to subtract L2 regularization loss.
                // The bias term is not regularized.
                var regLoss = VectorUtils.NormSquared(CurrentWeights.Values, 1, CurrentWeights.Length - 1) * L2Weight;
                deviance -= regLoss;

            if (L1Weight > 0)
                // Need to subtract L1 regularization loss.
                // The bias term is not regularized.
                Double regLoss = 0;
                VBufferUtils.ForEachDefined(ref CurrentWeights, (ind, value) => { if (ind >= BiasCount)
                                                                                      regLoss += Math.Abs(value);
                deviance -= (Float)regLoss * L1Weight * 2;

            ch.Info("Residual Deviance: \t{0} (on {1} degrees of freedom)", deviance, Math.Max(NumGoodRows - numParams, 0));

            // Compute null deviance, i.e., the deviance of null hypothesis.
            // Cap the prior positive rate at 1e-15.
            Double priorPosRate = _posWeight / WeightSum;

            Contracts.Assert(0 <= priorPosRate && priorPosRate <= 1);
            Float nullDeviance = (priorPosRate <= 1e-15 || 1 - priorPosRate <= 1e-15) ?
                                 0f : (Float)(2 * WeightSum * MathUtils.Entropy(priorPosRate, true));

            ch.Info("Null Deviance:     \t{0} (on {1} degrees of freedom)", nullDeviance, NumGoodRows - 1);

            // Compute AIC.
            ch.Info("AIC:               \t{0}", 2 * numParams + deviance);

            // Show the coefficients statistics table.
            var featureColIdx = cursorFactory.Data.Schema.Feature.Index;
            var schema        = cursorFactory.Data.Data.Schema;
            var featureLength = CurrentWeights.Length - BiasCount;
            var namesSpans    = VBufferUtils.CreateEmpty <DvText>(featureLength);

            if (schema.HasSlotNames(featureColIdx, featureLength))
                schema.GetMetadata(MetadataUtils.Kinds.SlotNames, featureColIdx, ref namesSpans);
            Host.Assert(namesSpans.Length == featureLength);

            // Inverse mapping of non-zero weight slots.
            Dictionary <int, int> weightIndicesInvMap = null;

            // Indices of bias and non-zero weight slots.
            int[] weightIndices = null;

            // Whether all weights are non-zero.
            bool denseWeight = numParams == CurrentWeights.Length;

            // Extract non-zero indices of weight.
            if (!denseWeight)
                weightIndices          = new int[numParams];
                weightIndicesInvMap    = new Dictionary <int, int>(numParams);
                // Slot 0 is the bias and is always kept.
                weightIndices[0]       = 0;
                weightIndicesInvMap[0] = 0;
                int j = 1;
                for (int i = 1; i < CurrentWeights.Length; i++)
                    if (CurrentWeights.Values[i] != 0)
                        weightIndices[j]       = i;
                        weightIndicesInvMap[i] = j++;

                Contracts.Assert(j == numParams);

            // Compute the standard error of coefficients.
            // Number of entries in the packed lower-triangular matrix; computed in long to avoid int overflow.
            long hessianDimension = (long)numParams * (numParams + 1) / 2;

            if (hessianDimension > int.MaxValue)
                // Each fragment ends with a space so the concatenated message reads as separate words.
                ch.Warning("The number of parameter is too large. Cannot hold the variance-covariance matrix in memory. " +
                           "Skipping computation of standard errors and z-statistics of coefficients. Consider choosing a larger L1 regularizer " +
                           "to reduce the number of parameters.");
                _stats = new LinearModelStatistics(Host, NumGoodRows, numParams, deviance, nullDeviance);

            // Building the variance-covariance matrix for parameters.
            // The layout of this algorithm is a packed row-major lower triangular matrix.
            // E.g., layout of indices for 4-by-4:
            // 0
            // 1 2
            // 3 4 5
            // 6 7 8 9
            var hessian = new Double[hessianDimension];

            // Initialize diagonal elements with L2 regularizers except for the first entry (index 0)
            // Since bias is not regularized.
            if (L2Weight > 0)
                // i is the array index of the diagonal entry at iRow-th row and iRow-th column.
                // iRow is one-based.
                int i = 0;
                for (int iRow = 2; iRow <= numParams; iRow++)
                    i         += iRow;
                    hessian[i] = L2Weight;

                Contracts.Assert(i == hessian.Length - 1);

            // Initialize the remaining entries.
            var bias = CurrentWeights.Values[0];

            using (var cursor = cursorFactory.Create())
                while (cursor.MoveNext())
                    var label  = cursor.Label;
                    var weight = cursor.Weight;
                    var score  = bias + VectorUtils.DotProductWithOffset(ref CurrentWeights, 1, ref cursor.Features);
                    // Compute Bernoulli variance n_i * p_i * (1 - p_i) for the i-th training example.
                    var variance = weight / (2 + 2 * Math.Cosh(score));

                    // Increment the first entry of hessian.
                    hessian[0] += variance;

                    var values = cursor.Features.Values;
                    if (cursor.Features.IsDense)
                        int ioff = 1;

                        // Increment remaining entries of hessian.
                        for (int i = 1; i < numParams; i++)
                            ch.Assert(ioff == i * (i + 1) / 2);
                            // Map the parameter index to its feature slot (weight index minus the bias slot).
                            int wi = weightIndices == null ? i - 1 : weightIndices[i] - 1;
                            Contracts.Assert(0 <= wi && wi < cursor.Features.Length);
                            var val = values[wi] * variance;
                            // Add the implicit first bias term to X'X
                            hessian[ioff++] += val;
                            // Add the remainder of X'X
                            for (int j = 0; j < i; j++)
                                int wj = weightIndices == null ? j : weightIndices[j + 1] - 1;
                                Contracts.Assert(0 <= wj && wj < cursor.Features.Length);
                                hessian[ioff++] += val * values[wj];
                        ch.Assert(ioff == hessian.Length);
                        // Sparse row: only the stored entries contribute.
                        var indices = cursor.Features.Indices;
                        for (int ii = 0; ii < cursor.Features.Count; ++ii)
                            int i  = indices[ii];
                            int wi = i + 1;
                            if (weightIndicesInvMap != null && !weightIndicesInvMap.TryGetValue(i + 1, out wi))

                            Contracts.Assert(0 < wi && wi <= cursor.Features.Length);
                            // Offset of the first entry of row wi in the packed layout.
                            int ioff = wi * (wi + 1) / 2;
                            var val  = values[ii] * variance;
                            // Add the implicit first bias term to X'X
                            hessian[ioff] += val;
                            // Add the remainder of X'X
                            for (int jj = 0; jj <= ii; jj++)
                                int j  = indices[jj];
                                int wj = j + 1;
                                if (weightIndicesInvMap != null && !weightIndicesInvMap.TryGetValue(j + 1, out wj))

                                Contracts.Assert(0 < wj && wj <= cursor.Features.Length);
                                hessian[ioff + wj] += val * values[jj];

            // Apply Cholesky Decomposition to find the inverse of the Hessian.
            Double[] invHessian = null;
                // First, find the Cholesky decomposition LL' of the Hessian.
                Mkl.Pptrf(Mkl.Layout.RowMajor, Mkl.UpLo.Lo, numParams, hessian);
                // Note that hessian is already modified at this point. It is no longer the original Hessian,
                // but instead represents the Cholesky decomposition L.
                // Also note that the following routine is supposed to consume the Cholesky decomposition L instead
                // of the original information matrix.
                Mkl.Pptri(Mkl.Layout.RowMajor, Mkl.UpLo.Lo, numParams, hessian);
                // At this point, hessian should contain the inverse of the original Hessian matrix.
                // Swap hessian with invHessian to avoid confusion in the following context.
                Utils.Swap(ref hessian, ref invHessian);
                Contracts.Assert(hessian == null);
            catch (DllNotFoundException)
                throw ch.ExceptNotSupp("The MKL library (Microsoft.ML.MklImports.dll) or one of its dependencies is missing.");

            Float[] stdErrorValues = new Float[numParams];
            stdErrorValues[0] = (Float)Math.Sqrt(invHessian[0]);

            for (int i = 1; i < numParams; i++)
                // Initialize with inverse Hessian.
                stdErrorValues[i] = (Single)invHessian[i * (i + 1) / 2 + i];

            if (L2Weight > 0)
                // Iterate through all entries of inverse Hessian to make adjustment to variance.
                // A discussion on ridge regularized LR coefficient covariance matrix can be found here:
                // http://www.ncbi.nlm.nih.gov/pmc/articles/PMC3228544/
                // http://www.inf.unibz.it/dis/teaching/DWDM/project2010/LogisticRegression.pdf
                int ioffset = 1;
                for (int iRow = 1; iRow < numParams; iRow++)
                    for (int iCol = 0; iCol <= iRow; iCol++)
                        var entry      = (Single)invHessian[ioffset];
                        var adjustment = -L2Weight * entry * entry;
                        stdErrorValues[iRow] -= adjustment;
                        if (0 < iCol && iCol < iRow)
                            stdErrorValues[iCol] -= adjustment;

                Contracts.Assert(ioffset == invHessian.Length);

            // Slot 0 already holds a square root; convert the remaining variances to standard errors.
            for (int i = 1; i < numParams; i++)
                stdErrorValues[i] = (Float)Math.Sqrt(stdErrorValues[i]);

            VBuffer <Float> stdErrors = new VBuffer <Float>(CurrentWeights.Length, numParams, stdErrorValues, weightIndices);

            _stats = new LinearModelStatistics(Host, NumGoodRows, numParams, deviance, nullDeviance, ref stdErrors);
 // Override of the training-statistics hook that intentionally computes nothing.
 // The parameters (channel, cursor factory, final loss, parameter count) are accepted
 // only to match the overridden signature; none of them is used here.
 protected override void ComputeTrainingStatistics(IChannel ch, FloatLabelCursor.Factory factory, float loss, int numParams)
     // No-op by design.
        // Fits a linear regression (ordinary least squares, or ridge regression when _l2Weight is positive)
        // by solving the normal equations (X'X) beta = X'y with a packed Cholesky factorization.
        //
        // Steps:
        //   1. One pass over the data accumulates X'X (packed, row-major, lower triangular) and X'y,
        //      with an implicit constant-1 bias column in slot 0.
        //   2. Mkl.Pptrf factorizes X'X in place and Mkl.Pptrs solves for beta in place of X'y.
        //   3. A second pass computes the residual and total sums of squares for R^2.
        //   4. Optionally (when _perParameterSignificance is set and there are more examples than
        //      parameters) Mkl.Pptri inverts X'X to derive standard errors, t-values and p-values.
        //
        // Parameters:
        //   ch            - channel used for assertions, checks, info/warning messages and exceptions.
        //   cursorFactory - creates cursors over the training rows; used for the accumulation pass and
        //                   again for the residual pass.
        //   featureCount  - number of features; the system solved has featureCount + 1 parameters.
        //
        // Returns: an OlsLinearRegressionPredictor holding the weights, the bias, and whichever statistics
        // could be computed (the statistic arrays are null when they were skipped).
        //
        // Throws (via ch): when the packed matrix cannot fit in an int-sized array, when there are no
        // examples, when there are fewer examples than parameters without L2 regularization, when the
        // MKL library is missing, or when the solution / standard errors are not finite.
        //
        // NOTE(review): block delimiters and some short statements (for example else/try keywords and
        // counter increments) are absent in this copy of the source; the comments below describe the
        // structure implied by the indentation. Verify against the canonical file.
        private OlsLinearRegressionPredictor TrainCore(IChannel ch, FloatLabelCursor.Factory cursorFactory, int featureCount)

            // m is the number of parameters: slot 0 is the bias, slots 1..featureCount are the features.
            int m = featureCount + 1;

            // Check for memory conditions first.
            if ((long)m * (m + 1) / 2 > int.MaxValue)
                throw ch.Except("Cannot hold covariance matrix in memory with {0} features", m - 1);

            // Track the number of examples.
            long n = 0;
            // Since we are accumulating over many values, we use Double even for the single precision build.
            var xty = new Double[m];
            // The layout of this algorithm is a packed row-major lower triangular matrix.
            // NOTE(review): unlike the guard above, this size is evaluated in int arithmetic, so
            // m * (m + 1) can overflow before the division even when the halved value fits in an int.
            // Consider reusing the long-based computation. The same applies to the i * (i + 1) / 2
            // offsets computed further down.
            var xtx = new Double[m * (m + 1) / 2];

            // Build X'X (lower triangular) and X'y incrementally (X'X+=X'X_i; X'y+=X'y_i):
            using (var cursor = cursorFactory.Create())
                // NOTE(review): n is expected to be advanced once per row consumed by this loop
                // (it is checked and used as the example count below); confirm the increment is present.
                while (cursor.MoveNext())
                    var yi = cursor.Label;
                    // Increment first element of X'y
                    xty[0] += yi;
                    // Increment first element of lower triangular X'X
                    xtx[0] += 1;
                    var values = cursor.Features.GetValues();

                    if (cursor.Features.IsDense)
                        // Dense row: ioff walks the packed array sequentially, row by row.
                        int ioff = 1;
                        ch.Assert(values.Length + 1 == m);
                        // Increment rest of first column of lower triangular X'X
                        for (int i = 1; i < m; i++)
                            ch.Assert(ioff == i * (i + 1) / 2);
                            var val = values[i - 1];
                            // Add the implicit first bias term to X'X
                            xtx[ioff++] += val;
                            // Add the remainder of X'X
                            for (int j = 0; j < i; j++)
                                xtx[ioff++] += val * values[j];
                            // X'y
                            xty[i] += val * yi;
                        ch.Assert(ioff == xtx.Length);
                        // Sparse row: only the explicitly stored (index, value) pairs are visited.
                        var fIndices = cursor.Features.GetIndices();
                        for (int ii = 0; ii < values.Length; ++ii)
                            // Parameter slot of this feature (shifted by one for the bias).
                            int i    = fIndices[ii] + 1;
                            // Offset of the first entry (the bias column) of packed row i.
                            int ioff = i * (i + 1) / 2;
                            var val  = values[ii];
                            // Add the implicit first bias term to X'X
                            xtx[ioff++] += val;
                            // Add the remainder of X'X
                            // ioff now points one past the bias column, so ioff + fIndices[jj] addresses
                            // column fIndices[jj] + 1 of row i. Restricting jj to 0..ii stays within the
                            // lower triangle only if the indices are increasing.
                            // assumes fIndices is sorted ascending - TODO confirm
                            for (int jj = 0; jj <= ii; jj++)
                                xtx[ioff + fIndices[jj]] += val * values[jj];
                            // X'y
                            xty[i] += val * yi;
                ch.Check(n > 0, "No training examples in dataset.");
                // NOTE(review): the condition tests BadFeaturesRowCount but the message reports
                // SkippedRowCount; confirm that this difference is intended.
                if (cursor.BadFeaturesRowCount > 0)
                    ch.Warning("Skipped {0} instances with missing features/label during training", cursor.SkippedRowCount);

                if (_l2Weight > 0)
                    // Skip the bias term for regularization, in the ridge regression case.
                    // So start at [1,1] instead of [0,0].

                    // REVIEW: There are two ways to view this, firstly, it is more
                    // user friendly to make this scaling factor behave similarly regardless
                    // of data size, so that if you have the same parameters, you get the same
                    // model if you feed in your data than if you duplicate your data 10 times.
                    // This is what I have now. The alternate point of view is to view this
                    // L2 regularization parameter as providing some sort of prior, in which
                    // case duplication 10 times should in fact be treated differently! (That
                    // is, we should not multiply by n below.) Both interpretations seem
                    // correct, in their way.
                    Double squared = _l2Weight * _l2Weight * n;
                    int    ioff    = 0;
                    // In the packed layout, the diagonal entry of row i lies i + 1 slots after the
                    // diagonal entry of row i - 1, so "ioff += i + 1" hops from diagonal to diagonal.
                    for (int i = 1; i < m; ++i)
                        xtx[ioff += i + 1] += squared;
                    ch.Assert(ioff == xtx.Length - 1);

            // Without regularization the system needs at least as many examples as parameters.
            if (!(_l2Weight > 0) && n < m)
                throw ch.Except("Ordinary least squares requires more examples than parameters. There are {0} parameters, but {1} examples. To enable training, use a positive L2 weight so this behaves as ridge regression.", m, n);

            // xty[0] is the plain sum of labels, so this is the mean label. It must be captured
            // here because the solver below overwrites xty.
            Double yMean = n == 0 ? 0 : xty[0] / n;

            ch.Info("Trainer solving for {0} parameters across {1} examples", m, n);
            // Cholesky Decomposition of X'X into LL'
                Mkl.Pptrf(Mkl.Layout.RowMajor, Mkl.UpLo.Lo, m, xtx);
            catch (DllNotFoundException)
                // REVIEW: Is there no better way?
                throw ch.ExceptNotSupp("The MKL library (libMklImports) or one of its dependencies is missing.");
            // Solve for beta in (LL')beta = X'y:
            Mkl.Pptrs(Mkl.Layout.RowMajor, Mkl.UpLo.Lo, m, 1, xtx, xty, 1);
            // Note that the solver overwrote xty so it contains the solution. To be more clear,
            // we effectively change its name (through reassignment) so we don't get confused that
            // this is somehow xty in the remaining calculation.
            var beta = xty;

            xty = null;
            // Check that the solution is valid.
            for (int i = 0; i < beta.Length; ++i)
                ch.Check(FloatUtils.IsFinite(beta[i]), "Non-finite values detected in OLS solution");

            // Split the solution: beta[0] is the bias, beta[1..] are the feature weights
            // (narrowed from double to float).
            var weights = VBufferUtils.CreateDense <float>(beta.Length - 1);

            for (int i = 1; i < beta.Length; ++i)
                weights.Values[i - 1] = (float)beta[i];
            var bias = (float)beta[0];

            if (!(_l2Weight > 0) && m == n)
                // We would expect the solution to the problem to be exact in this case.
                // The predictor is returned with no per-parameter statistics, R^2 = 1 and NaN adjusted R^2.
                ch.Info("Number of examples equals number of parameters, solution is exact but no statistics can be derived");
                return(new OlsLinearRegressionPredictor(Host, in weights, bias, null, null, null, 1, float.NaN));

            Double rss = 0; // residual sum of squares
            Double tss = 0; // total sum of squares

            // Second pass over the data: score every row with the fitted model to accumulate rss and tss.
            using (var cursor = cursorFactory.Create())
                var   lrPredictor = new LinearRegressionPredictor(Host, in weights, bias);
                var   lrMap       = lrPredictor.GetMapper <VBuffer <float>, float>();
                float yh          = default;
                while (cursor.MoveNext())
                    var features = cursor.Features;
                    lrMap(in features, ref yh);
                    var e = cursor.Label - yh;
                    rss += e * e;
                    var ydm = cursor.Label - yMean;
                    tss += ydm * ydm;
            // NOTE(review): tss is 0 when every label equals yMean, which makes rss / tss non-finite;
            // confirm that ProbClamp handles that input as desired.
            var rSquared = ProbClamp(1 - (rss / tss));
            // R^2 adjusted differs from the normal formula on account of the bias term, by Said's reckoning.
            double rSquaredAdjusted;

            // Adjusted R^2 is only computed (and logged) when there are more examples than parameters;
            // the NaN assignment below is the fallback for the remaining case.
            if (n > m)
                rSquaredAdjusted = ProbClamp(1 - (1 - rSquared) * (n - 1) / (n - m));
                ch.Info("Coefficient of determination R2 = {0:g}, or {1:g} (adjusted)",
                        rSquared, rSquaredAdjusted);
                rSquaredAdjusted = Double.NaN;

            // The per parameter significance is compute intensive and may not be required for all practitioners.
            // Also we can't estimate it, unless we can estimate the variance, which requires more examples than
            // parameters.
            if (!_perParameterSignificance || m >= n)
                return(new OlsLinearRegressionPredictor(Host, in weights, bias, null, null, null, rSquared, rSquaredAdjusted));

            var standardErrors = new Double[m];
            var tValues        = new Double[m];
            var pValues        = new Double[m];

            // Invert X'X:
            // (xtx holds the result of the Mkl.Pptrf call above and is overwritten again here.)
            Mkl.Pptri(Mkl.Layout.RowMajor, Mkl.UpLo.Lo, m, xtx);
            var s2 = rss / (n - m); // estimate of variance of y

            for (int i = 0; i < m; i++)
                // Initialize with inverse Hessian.
                // i * (i + 1) / 2 + i is the packed index of the diagonal entry of row i.
                // NOTE(review): the (Single) cast narrows the value before it is stored in a Double
                // array; confirm that the loss of precision is intended.
                standardErrors[i] = (Single)xtx[i * (i + 1) / 2 + i];

            if (_l2Weight > 0)
                // Iterate through all entries of inverse Hessian to make adjustment to variance.
                // The walk starts at packed index 1 (row 1), so row 0 (the bias) is not visited, and the
                // assert after the loop expects ioffset to have advanced by one per visited entry.
                int   ioffset = 1;
                float reg     = _l2Weight * _l2Weight * n;
                for (int iRow = 1; iRow < m; iRow++)
                    for (int iCol = 0; iCol <= iRow; iCol++)
                        var entry      = (Single)xtx[ioffset];
                        // NOTE(review): adjustment is negative and is then subtracted, so the net effect
                        // is to ADD reg * entry^2. The formula quoted in the loop below,
                        // inverse(X'X + reg * I) * X'X * inverse(X'X + reg * I), appears to call for a
                        // subtraction; confirm the sign.
                        var adjustment = -reg * entry * entry;
                        standardErrors[iRow] -= adjustment;
                        // An off-diagonal entry also belongs to the symmetric position (iCol, iRow),
                        // so it contributes to the column's parameter too (the bias column is excluded).
                        if (0 < iCol && iCol < iRow)
                            standardErrors[iCol] -= adjustment;

                Contracts.Assert(ioffset == xtx.Length);

            for (int i = 0; i < m; i++)
                // sqrt of diagonal entries of s2 * inverse(X'X + reg * I) * X'X * inverse(X'X + reg * I).
                standardErrors[i] = Math.Sqrt(s2 * standardErrors[i]);
                ch.Check(FloatUtils.IsFinite(standardErrors[i]), "Non-finite standard error detected from OLS solution");
                // t statistic of the coefficient, and its p-value with n - m degrees of freedom.
                tValues[i] = beta[i] / standardErrors[i];
                // NOTE(review): the (float) cast narrows the p-value before it is stored in a Double array.
                pValues[i] = (float)MathUtils.TStatisticToPValue(tValues[i], n - m);
                ch.Check(0 <= pValues[i] && pValues[i] <= 1, "p-Value calculated outside expected [0,1] range");

            return(new OlsLinearRegressionPredictor(Host, in weights, bias, standardErrors, tValues, pValues, rSquared, rSquaredAdjusted));
        // Computes and logs summary statistics for the trained model and stores them in _stats.
        //
        // Logged values: number of training examples, residual deviance, null deviance and AIC.
        // The method then accumulates the Hessian over the training rows in a packed row-major
        // lower-triangular array, restricted to the bias and the non-zero weights.
        //
        // Parameters:
        //   ch            - channel used for info/warning messages and assertions.
        //   cursorFactory - supplies the schema (feature column, slot names) and a cursor over the rows.
        //   loss          - final training loss (asserted non-negative); 2 * loss * WeightSum is the
        //                   starting point of the deviance.
        //   numParams     - number of parameters counted in the statistics (asserted >= BiasCount);
        //                   when it equals CurrentWeights.Length all weights are treated as non-zero.
        //
        // NOTE(review): in the code visible here, neither the accumulated hessian nor namesSpans is read
        // after being filled, and _stats is created without standard errors. Confirm whether a consumer
        // of the Hessian is expected to follow.
        // NOTE(review): block delimiters and some short statements (for example return/continue and
        // closing tokens) are absent in this copy of the source; the comments below describe the
        // structure implied by the indentation. Verify against the canonical file.
        protected override void ComputeTrainingStatistics(IChannel ch, FloatLabelCursor.Factory cursorFactory, float loss, int numParams)
            Contracts.Assert(NumGoodRows > 0);
            Contracts.Assert(WeightSum > 0);
            Contracts.Assert(BiasCount == 1);
            Contracts.Assert(loss >= 0);
            Contracts.Assert(numParams >= BiasCount);

            ch.Info("Model trained with {0} training examples.", NumGoodRows);

            // Compute deviance: start with loss function.
            float deviance = (float)(2 * loss * WeightSum);

            if (L2Weight > 0)
                // Need to subtract L2 regularization loss.
                // The bias term is not regularized.
                // (The squared norm is taken over the weights starting at index 1, skipping slot 0.)
                var regLoss = VectorUtils.NormSquared(CurrentWeights.Values, 1, CurrentWeights.Length - 1) * L2Weight;
                deviance -= regLoss;

            if (L1Weight > 0)
                // Need to subtract L1 regularization loss.
                // The bias term is not regularized.
                // The lambda sums the absolute values of the defined weights whose index is >= BiasCount.
                Double regLoss = 0;
                VBufferUtils.ForEachDefined(ref CurrentWeights, (ind, value) => { if (ind >= BiasCount)
                                                                                      regLoss += Math.Abs(value);
                deviance -= (float)regLoss * L1Weight * 2;

            // The reported degrees of freedom are clamped at zero.
            ch.Info("Residual Deviance: \t{0} (on {1} degrees of freedom)", deviance, Math.Max(NumGoodRows - numParams, 0));

            // Compute null deviance, i.e., the deviance of null hypothesis.
            // Cap the prior positive rate at 1e-15.
            Double priorPosRate = _posWeight / WeightSum;

            Contracts.Assert(0 <= priorPosRate && priorPosRate <= 1);
            // When the prior rate is within 1e-15 of 0 or 1, the null deviance is reported as 0
            // instead of evaluating the entropy.
            float nullDeviance = (priorPosRate <= 1e-15 || 1 - priorPosRate <= 1e-15) ?
                                 0f : (float)(2 * WeightSum * MathUtils.Entropy(priorPosRate, true));

            ch.Info("Null Deviance:     \t{0} (on {1} degrees of freedom)", nullDeviance, NumGoodRows - 1);

            // Compute AIC.
            ch.Info("AIC:               \t{0}", 2 * numParams + deviance);

            // Show the coefficients statistics table.
            // Slot names of the feature column are loaded when the schema provides them; otherwise
            // namesSpans keeps the empty buffer of length featureLength created here.
            var featureColIdx = cursorFactory.Data.Schema.Feature.Index;
            var schema        = cursorFactory.Data.Data.Schema;
            var featureLength = CurrentWeights.Length - BiasCount;
            var namesSpans    = VBufferUtils.CreateEmpty <ReadOnlyMemory <char> >(featureLength);

            if (schema.HasSlotNames(featureColIdx, featureLength))
                schema.GetMetadata(MetadataUtils.Kinds.SlotNames, featureColIdx, ref namesSpans);
            Host.Assert(namesSpans.Length == featureLength);

            // Inverse mapping of non-zero weight slots.
            // (Maps an index into CurrentWeights to its position among the retained parameters.)
            Dictionary <int, int> weightIndicesInvMap = null;

            // Indices of bias and non-zero weight slots.
            // (Maps a retained parameter position back to its index in CurrentWeights.)
            int[] weightIndices = null;

            // Whether all weights are non-zero.
            bool denseWeight = numParams == CurrentWeights.Length;

            // Extract non-zero indices of weight.
            // Both maps stay null in the dense case, which the loops below treat as the identity mapping.
            if (!denseWeight)
                weightIndices          = new int[numParams];
                weightIndicesInvMap    = new Dictionary <int, int>(numParams);
                // Slot 0 (the bias) is always retained.
                weightIndices[0]       = 0;
                weightIndicesInvMap[0] = 0;
                int j = 1;
                for (int i = 1; i < CurrentWeights.Length; i++)
                    if (CurrentWeights.Values[i] != 0)
                        weightIndices[j]       = i;
                        weightIndicesInvMap[i] = j++;

                Contracts.Assert(j == numParams);

            // Compute the standard error of coefficients.
            // The packed triangle size is computed in long arithmetic so that the check below cannot overflow.
            long hessianDimension = (long)numParams * (numParams + 1) / 2;

            if (hessianDimension > int.MaxValue)
                // NOTE(review): the second and third string literals are concatenated without a separating
                // space, so the message reads "...L1 regularizerto reduce...".
                // NOTE(review): an early exit is expected after _stats is set here, since the array
                // allocation below cannot succeed with this dimension; confirm it is present.
                ch.Warning("The number of parameter is too large. Cannot hold the variance-covariance matrix in memory. " +
                           "Skipping computation of standard errors and z-statistics of coefficients. Consider choosing a larger L1 regularizer" +
                           "to reduce the number of parameters.");
                _stats = new LinearModelStatistics(Host, NumGoodRows, numParams, deviance, nullDeviance);

            // Building the variance-covariance matrix for parameters.
            // The layout of this algorithm is a packed row-major lower triangular matrix.
            // E.g., layout of indices for 4-by-4:
            // 0
            // 1 2
            // 3 4 5
            // 6 7 8 9
            var hessian = new Double[hessianDimension];

            // Initialize diagonal elements with L2 regularizers except for the first entry (index 0)
            // Since bias is not regularized.
            if (L2Weight > 0)
                // i is the array index of the diagonal entry at iRow-th row and iRow-th column.
                // iRow is one-based.
                // Consecutive diagonal entries are iRow slots apart, hence "i += iRow". Plain assignment
                // is sufficient because the array was just allocated.
                int i = 0;
                for (int iRow = 2; iRow <= numParams; iRow++)
                    i         += iRow;
                    hessian[i] = L2Weight;

                Contracts.Assert(i == hessian.Length - 1);

            // Initialize the remaining entries.
            var bias = CurrentWeights.Values[0];

            using (var cursor = cursorFactory.Create())
                while (cursor.MoveNext())
                    var label  = cursor.Label;
                    var weight = cursor.Weight;
                    // Linear score of the row: bias plus the dot product of the weights (from index 1) with the features.
                    var score  = bias + VectorUtils.DotProductWithOffset(ref CurrentWeights, 1, ref cursor.Features);
                    // Compute Bernoulli variance n_i * p_i * (1 - p_i) for the i-th training example.
                    // With p = 1 / (1 + exp(-score)), p * (1 - p) equals 1 / (2 + 2 * cosh(score)).
                    var variance = weight / (2 + 2 * Math.Cosh(score));

                    // Increment the first entry of hessian.
                    hessian[0] += variance;

                    var values = cursor.Features.Values;
                    if (cursor.Features.IsDense)
                        // Dense row: ioff walks the packed array sequentially, row by row.
                        int ioff = 1;

                        // Increment remaining entries of hessian.
                        for (int i = 1; i < numParams; i++)
                            ch.Assert(ioff == i * (i + 1) / 2);
                            // wi is the feature index of retained parameter i (identity shift when all weights are kept).
                            int wi = weightIndices == null ? i - 1 : weightIndices[i] - 1;
                            Contracts.Assert(0 <= wi && wi < cursor.Features.Length);
                            var val = values[wi] * variance;
                            // Add the implicit first bias term to X'X
                            hessian[ioff++] += val;
                            // Add the remainder of X'X
                            for (int j = 0; j < i; j++)
                                // wj is the feature index of retained parameter j + 1.
                                int wj = weightIndices == null ? j : weightIndices[j + 1] - 1;
                                Contracts.Assert(0 <= wj && wj < cursor.Features.Length);
                                hessian[ioff++] += val * values[wj];
                        ch.Assert(ioff == hessian.Length);
                        // Sparse row: only the explicitly stored (index, value) pairs are visited.
                        var indices = cursor.Features.Indices;
                        for (int ii = 0; ii < cursor.Features.Count; ++ii)
                            int i  = indices[ii];
                            // wi is the retained-parameter position of feature i: i + 1 when all weights are
                            // kept, otherwise looked up in the inverse map.
                            int wi = i + 1;
                            // NOTE(review): a feature whose weight was not retained is presumably skipped
                            // when the lookup fails; confirm the statement guarded by this condition.
                            if (weightIndicesInvMap != null && !weightIndicesInvMap.TryGetValue(i + 1, out wi))

                            Contracts.Assert(0 < wi && wi <= cursor.Features.Length);
                            // Offset of the first entry (the bias column) of packed row wi.
                            int ioff = wi * (wi + 1) / 2;
                            var val  = values[ii] * variance;
                            // Add the implicit first bias term to X'X
                            hessian[ioff] += val;
                            // Add the remainder of X'X
                            // ioff + wj addresses column wj of row wi. Restricting jj to 0..ii stays within
                            // the lower triangle only if the indices are increasing.
                            // assumes indices is sorted ascending - TODO confirm
                            for (int jj = 0; jj <= ii; jj++)
                                int j  = indices[jj];
                                int wj = j + 1;
                                // NOTE(review): same lookup as for wi above; confirm the guarded statement.
                                if (weightIndicesInvMap != null && !weightIndicesInvMap.TryGetValue(j + 1, out wj))

                                Contracts.Assert(0 < wj && wj <= cursor.Features.Length);
                                hessian[ioff + wj] += val * values[jj];

            // Publish the statistics (without standard errors).
            _stats = new LinearModelStatistics(Host, NumGoodRows, numParams, deviance, nullDeviance);