/// <inheritdoc/>
        /// <remarks>
        /// Makes a single pass over the rows produced by <paramref name="cursorFactory"/>. For every row it
        /// visits each non-label class, proposes a change to that class's dual variable, tries to commit it with
        /// a compare-and-swap (retrying up to <c>2 * numThreads</c> times), and on success applies the matching
        /// update to that class's weights and biases. The label class is updated last, using the totals
        /// accumulated from the other classes. When the L1 regularization value is non-zero, the raw updates
        /// go to the intermediate buffers and the actual weights/bias are obtained by soft-thresholding.
        /// </remarks>
        private protected override void TrainWithoutLock(IProgressChannelProvider progress, FloatLabelCursor.Factory cursorFactory, Random rand,
                                                         IdToIdxLookup idToIdx, int numThreads, DualsTableBase duals, float[] biasReg, float[] invariants, float lambdaNInv,
                                                         VBuffer <float>[] weights, float[] biasUnreg, VBuffer <float>[] l1IntermediateWeights, float[] l1IntermediateBias, float[] featureNormSquared)
        {
            // progress, idToIdx, invariants and featureNormSquared are all allowed to be null.
            Contracts.AssertValueOrNull(progress);
            Contracts.Assert(SdcaTrainerOptions.L1Regularization.HasValue);
            Contracts.AssertValueOrNull(idToIdx);
            Contracts.AssertValueOrNull(invariants);
            Contracts.AssertValueOrNull(featureNormSquared);
            int numClasses = Utils.Size(weights);

            // There is one regularized and one unregularized bias per class.
            Contracts.Assert(Utils.Size(biasReg) == numClasses);
            Contracts.Assert(Utils.Size(biasUnreg) == numClasses);

            // Upper bound on compare-and-swap attempts for a single dual variable.
            int  maxUpdateTrials = 2 * numThreads;
            var  l1Threshold     = SdcaTrainerOptions.L1Regularization.Value;
            // Selects between the plain update path (true) and the soft-thresholding path (false) below.
            bool l1ThresholdZero = l1Threshold == 0;
            // Scale factor applied to the bias when computing the per-class adjustment in the trial loop.
            var  lr = SdcaTrainerOptions.BiasLearningRate * SdcaTrainerOptions.L2Regularization.Value;

            // A progress channel is opened only when a provider was supplied; "using (null)" is a no-op.
            var pch = progress != null?progress.StartProgressChannel("Dual update") : null;

            using (pch)
                // Rows are visited in shuffled order (driven by rand) when the Shuffle option is set.
                using (var cursor = SdcaTrainerOptions.Shuffle ? cursorFactory.Create(rand) : cursorFactory.Create())
                {
                    long rowCount = 0;
                    if (pch != null)
                    {
                        // The callback reads the captured rowCount, so it reports the number of rows processed so far.
                        pch.SetHeader(new ProgressHeader("examples"), e => e.SetProgress(0, rowCount));
                    }

                    // Maps a row id to the index used for the duals table and the per-row caches.
                    Func <DataViewRowId, long> getIndexFromId = GetIndexFromIdGetter(idToIdx, biasReg.Length);
                    while (cursor.MoveNext())
                    {
                        Host.CheckAlive();
                        long  idx = getIndexFromId(cursor.Id);
                        // Each row owns numClasses consecutive dual slots starting at this position.
                        long  dualIndexInitPos = idx * numClasses;
                        var   features         = cursor.Features;
                        var   label            = (int)cursor.Label;
                        float invariant;
                        float normSquared;
                        if (invariants != null)
                        {
                            // Use the per-row values supplied by the caller.
                            invariant = invariants[idx];
                            Contracts.AssertValue(featureNormSquared);
                            normSquared = featureNormSquared[idx];
                        }
                        else
                        {
                            // No cached values: compute the squared norm and the invariant for this row now.
                            normSquared = VectorUtils.NormSquared(in features);
                            if (SdcaTrainerOptions.BiasLearningRate == 0)
                            {
                                // NOTE(review): presumably accounts for the bias acting as an extra constant
                                // feature when no separate bias learning rate is used - confirm.
                                normSquared += 1;
                            }

                            invariant = _loss.ComputeDualUpdateInvariant(2 * normSquared * lambdaNInv * GetInstanceWeight(cursor));
                        }

                        // The output for the label class using current weights and bias.
                        var labelOutput    = WDot(in features, in weights[label], biasReg[label] + biasUnreg[label]);
                        var instanceWeight = GetInstanceWeight(cursor);

                        // This will be the new dual variable corresponding to the label class.
                        float labelDual = 0;

                        // This will be used to update the weights and regularized bias corresponding to the label class.
                        float labelPrimalUpdate = 0;

                        // This will be used to update the unregularized bias corresponding to the label class.
                        float labelAdjustment = 0;

                        // Iterates through all classes.
                        for (int iClass = 0; iClass < numClasses; iClass++)
                        {
                            // Skip the dual/weights/bias update for label class. Will be taken care of at the end.
                            if (iClass == label)
                            {
                                continue;
                            }

                            var weightsEditor = VBufferEditor.CreateFromBuffer(ref weights[iClass]);
                            // The intermediate-weights editor is only needed (and only created) on the L1 path.
                            var l1IntermediateWeightsEditor =
                                !l1ThresholdZero?VBufferEditor.CreateFromBuffer(ref l1IntermediateWeights[iClass]) :
                                    default;

                            // Loop trials for compare-and-swap updates of duals.
                            // In general, concurrent update conflict to the same dual variable is rare
                            // if data is shuffled.
                            // If every trial fails, this class is left unchanged for the current row.
                            for (int numTrials = 0; numTrials < maxUpdateTrials; numTrials++)
                            {
                                long dualIndex  = iClass + dualIndexInitPos;
                                var  dual       = duals[dualIndex];
                                // Margin between the label class and this class. The label output is corrected by the
                                // label update accumulated so far, which has not yet been written to weights[label].
                                var  output     = labelOutput + labelPrimalUpdate * normSquared - WDot(in features, in weights[iClass], biasReg[iClass] + biasUnreg[iClass]);
                                var  dualUpdate = _loss.DualUpdate(output, 1, dual, invariant, numThreads);

                                // The successive over-relaxation approach to adjust the sum of dual variables (biasReg) to zero.
                                // Reference to details: http://stat.rutgers.edu/home/tzhang/papers/ml02_dual.pdf, pp. 16-17.
                                var adjustment = l1ThresholdZero ? lr * biasReg[iClass] : lr * l1IntermediateBias[iClass];
                                dualUpdate -= adjustment;
                                bool success = false;
                                // Commit the new dual only if the slot still holds the value read above.
                                duals.ApplyAt(dualIndex, (long index, ref float value) =>
                                              success = Interlocked.CompareExchange(ref value, dual + dualUpdate, dual) == dual);

                                if (success)
                                {
                                    // Note: dualConstraint[iClass] = lambdaNInv * (sum of duals[iClass])
                                    var primalUpdate = dualUpdate * lambdaNInv * instanceWeight;
                                    // Accumulate the totals that are applied to the label class after the class loop.
                                    labelDual         -= dual + dualUpdate;
                                    labelPrimalUpdate += primalUpdate;
                                    biasUnreg[iClass] += adjustment * lambdaNInv * instanceWeight;
                                    labelAdjustment   -= adjustment;

                                    if (l1ThresholdZero)
                                    {
                                        // No L1: subtract the scaled features and the update directly from this class.
                                        VectorUtils.AddMult(in features, weightsEditor.Values, -primalUpdate);
                                        biasReg[iClass] -= primalUpdate;
                                    }
                                    else
                                    {
                                        // Iterative shrinkage-thresholding (a.k.a. soft-thresholding):
                                        // update v = the intermediate (dense) weights as if there were no L1, then
                                        // threshold: if |v[j]| < threshold, turn off weights[j];
                                        // otherwise shrink: w[j] = v[j] - sign(v[j]) * threshold.
                                        l1IntermediateBias[iClass] -= primalUpdate;
                                        // The regularized bias is only soft-thresholded here when the bias learning rate is zero.
                                        if (SdcaTrainerOptions.BiasLearningRate == 0)
                                        {
                                            biasReg[iClass] = Math.Abs(l1IntermediateBias[iClass]) - l1Threshold > 0.0
                                        ? l1IntermediateBias[iClass] - Math.Sign(l1IntermediateBias[iClass]) * l1Threshold
                                        : 0;
                                        }

                                        // Dense and sparse feature vectors use different update helpers;
                                        // a sparse vector with no explicit values needs no update at all.
                                        var featureValues = features.GetValues();
                                        if (features.IsDense)
                                        {
                                            CpuMathUtils.SdcaL1UpdateDense(-primalUpdate, featureValues.Length, featureValues, l1Threshold, l1IntermediateWeightsEditor.Values, weightsEditor.Values);
                                        }
                                        else if (featureValues.Length > 0)
                                        {
                                            CpuMathUtils.SdcaL1UpdateSparse(-primalUpdate, featureValues.Length, featureValues, features.GetIndices(), l1Threshold, l1IntermediateWeightsEditor.Values, weightsEditor.Values);
                                        }
                                    }

                                    // The dual was committed, so no further trials are needed for this class.
                                    break;
                                }
                            }
                        }

                        // Updating with label class weights and dual variable.
                        // The label dual is the negated sum of the committed duals of the other classes.
                        duals[label + dualIndexInitPos] = labelDual;
                        biasUnreg[label] += labelAdjustment * lambdaNInv * instanceWeight;
                        if (l1ThresholdZero)
                        {
                            var weightsEditor = VBufferEditor.CreateFromBuffer(ref weights[label]);
                            VectorUtils.AddMult(in features, weightsEditor.Values, labelPrimalUpdate);
                            biasReg[label] += labelPrimalUpdate;
                        }
                        else
                        {
                            // L1 path for the label class: update the intermediate bias, then soft-threshold it
                            // (unlike the non-label classes, this is done regardless of the bias learning rate).
                            l1IntermediateBias[label] += labelPrimalUpdate;
                            var intermediateBias = l1IntermediateBias[label];
                            biasReg[label] = Math.Abs(intermediateBias) - l1Threshold > 0.0
                            ? intermediateBias - Math.Sign(intermediateBias) * l1Threshold
                            : 0;

                            var weightsEditor = VBufferEditor.CreateFromBuffer(ref weights[label]);
                            var l1IntermediateWeightsEditor = VBufferEditor.CreateFromBuffer(ref l1IntermediateWeights[label]);
                            var featureValues = features.GetValues();
                            if (features.IsDense)
                            {
                                CpuMathUtils.SdcaL1UpdateDense(labelPrimalUpdate, featureValues.Length, featureValues, l1Threshold, l1IntermediateWeightsEditor.Values, weightsEditor.Values);
                            }
                            else if (featureValues.Length > 0)
                            {
                                CpuMathUtils.SdcaL1UpdateSparse(labelPrimalUpdate, featureValues.Length, featureValues, features.GetIndices(), l1Threshold, l1IntermediateWeightsEditor.Values, weightsEditor.Values);
                            }
                        }

                        // Counts processed rows; read by the progress callback registered above.
                        rowCount++;
                    }
                }
        }
예제 #2
0
        /// <inheritdoc/>
        /// <remarks>
        /// Scans every row once to accumulate the instance-weighted primal loss (summed over all non-label
        /// classes) and the matching dual loss, adds the L1/L2 regularization terms, and writes the resulting
        /// loss, dual loss, duality gap, first-class biases and L1 sparsity into <paramref name="metrics"/>.
        /// When the new primal loss is lower than <paramref name="bestPrimalLoss"/>, the current weights and
        /// biases are copied into the "best" buffers and <paramref name="bestPrimalLoss"/> /
        /// <paramref name="bestIter"/> are updated. The metrics plus the iteration number are sent to
        /// <paramref name="pch"/> when it is not null.
        /// Returns true when the duality gap relative to the primal loss is below the convergence tolerance.
        /// </remarks>
        private protected override bool CheckConvergence(
            IProgressChannel pch,
            int iter,
            FloatLabelCursor.Factory cursorFactory,
            DualsTableBase duals,
            IdToIdxLookup idToIdx,
            VBuffer<Float>[] weights,
            VBuffer<Float>[] bestWeights,
            Float[] biasUnreg,
            Float[] bestBiasUnreg,
            Float[] biasReg,
            Float[] bestBiasReg,
            long count,
            Double[] metrics,
            ref Double bestPrimalLoss,
            ref int bestIter)
        {
            Contracts.AssertValue(weights);
            Contracts.AssertValue(duals);
            int numClasses = weights.Length;

            Contracts.Assert(duals.Length >= numClasses * count);
            Contracts.AssertValueOrNull(idToIdx);
            Contracts.Assert(Utils.Size(weights) == numClasses);
            Contracts.Assert(Utils.Size(biasReg) == numClasses);
            Contracts.Assert(Utils.Size(biasUnreg) == numClasses);
            Contracts.Assert(Utils.Size(metrics) == 6);

            // The reported values are the metrics followed by one extra slot holding the iteration number.
            var reportedValues = new Double?[metrics.Length + 1];
            reportedValues[metrics.Length] = iter;

            var lossSum     = new CompensatedSum();
            var dualLossSum = new CompensatedSum();
            int numFeatures = weights[0].Length;

            using (var cursor = cursorFactory.Create())
            {
                long row = 0;
                Func<DataViewRowId, long, long> getIndexFromIdAndRow = GetIndexFromIdAndRowGetter(idToIdx, biasReg.Length);

                // Iterates through data to compute loss function.
                while (cursor.MoveNext())
                {
                    var    instanceWeight = GetInstanceWeight(cursor);
                    var    features       = cursor.Features;
                    var    label          = (int)cursor.Label;
                    var    labelOutput    = WDot(in features, in weights[label], biasReg[label] + biasUnreg[label]);
                    Double subLoss        = 0;
                    Double subDualLoss    = 0;
                    long   idx            = getIndexFromIdAndRow(cursor.Id, row);

                    // Each row owns numClasses consecutive dual slots; dualIndex walks them in step with iClass.
                    long   dualIndex      = idx * numClasses;
                    for (int iClass = 0; iClass < numClasses; iClass++)
                    {
                        if (iClass == label)
                        {
                            // The label class contributes no loss term; just step over its dual slot.
                            dualIndex++;
                            continue;
                        }

                        var currentClassOutput = WDot(in features, in weights[iClass], biasReg[iClass] + biasUnreg[iClass]);
                        subLoss += _loss.Loss(labelOutput - currentClassOutput, 1);
                        Contracts.Assert(dualIndex == iClass + idx * numClasses);
                        var dual = duals[dualIndex++];
                        subDualLoss += _loss.DualLoss(1, dual);
                    }

                    lossSum.Add(subLoss * instanceWeight);
                    dualLossSum.Add(subDualLoss * instanceWeight);

                    row++;
                }

                Host.Assert(idToIdx == null || row * numClasses == duals.Length);
            }

            Contracts.Assert(Args.L2Const.HasValue);
            Contracts.Assert(Args.L1Threshold.HasValue);
            Double l2Const     = Args.L2Const.Value;
            Double l1Threshold = Args.L1Threshold.Value;

            Double weightsL1Norm                = 0;
            Double weightsL2NormSquared         = 0;
            Double biasRegularizationAdjustment = 0;

            // The regularized bias is counted as part of the weights for both norms.
            for (int iClass = 0; iClass < numClasses; iClass++)
            {
                weightsL1Norm                += VectorUtils.L1Norm(in weights[iClass]) + Math.Abs(biasReg[iClass]);
                weightsL2NormSquared         += VectorUtils.NormSquared(weights[iClass]) + biasReg[iClass] * biasReg[iClass];
                biasRegularizationAdjustment += biasReg[iClass] * biasUnreg[iClass];
            }

            Double l1Regularizer = l1Threshold * l2Const * weightsL1Norm;
            var    l2Regularizer = l2Const * weightsL2NormSquared * 0.5;

            var newLoss     = lossSum.Sum / count + l2Regularizer + l1Regularizer;
            var newDualLoss = dualLossSum.Sum / count - l2Regularizer - l2Const * biasRegularizationAdjustment;
            var dualityGap  = newLoss - newDualLoss;

            metrics[(int)MetricKind.Loss]       = newLoss;
            metrics[(int)MetricKind.DualLoss]   = newDualLoss;
            metrics[(int)MetricKind.DualityGap] = dualityGap;
            metrics[(int)MetricKind.BiasUnreg]  = biasUnreg[0];
            metrics[(int)MetricKind.BiasReg]    = biasReg[0];

            // Fraction of non-zero weights across all classes. The division is carried out in floating point:
            // dividing the integer count by the integer total would truncate the fraction to 0 or 1, and the
            // product numClasses * numFeatures could overflow int for large models.
            metrics[(int)MetricKind.L1Sparsity] = l1Threshold == 0
                ? 1
                : (Double)weights.Sum(weight => weight.GetValues().Count(w => w != 0)) / ((Double)numClasses * numFeatures);

            bool converged = dualityGap / newLoss < Args.ConvergenceTolerance;

            if (metrics[(int)MetricKind.Loss] < bestPrimalLoss)
            {
                for (int iClass = 0; iClass < numClasses; iClass++)
                {
                    // Maintain a copy of weights and bias with best primal loss thus far.
                    // This is some extra work and uses extra memory, but it seems worth doing it.
                    // REVIEW: Sparsify bestWeights?
                    weights[iClass].CopyTo(ref bestWeights[iClass]);
                    bestBiasReg[iClass]   = biasReg[iClass];
                    bestBiasUnreg[iClass] = biasUnreg[iClass];
                }

                bestPrimalLoss = metrics[(int)MetricKind.Loss];
                bestIter       = iter;
            }

            for (int i = 0; i < metrics.Length; i++)
            {
                reportedValues[i] = metrics[i];
            }

            if (pch != null)
            {
                pch.Checkpoint(reportedValues);
            }

            return converged;
        }