示例#1
0
        /// <summary>
        /// Checks <c>CpuMathUtils.SdcaL1UpdateSparse</c> against a scalar reference computation.
        /// The check itself runs inside <c>RemoteExecutor.RemoteInvoke</c>, configured with the
        /// supplied environment variables.
        /// </summary>
        /// <param name="mode">Forwarded to <c>CheckProperFlag</c> inside the remote invocation.</param>
        /// <param name="test">Index into <c>_testArrays</c>, formatted as an invariant-culture integer.</param>
        /// <param name="scale">Value used both as the update scale and as the threshold, formatted as an invariant-culture float.</param>
        /// <param name="environmentVariables">Environment variables passed to <c>RemoteInvokeOptions</c>.</param>
        public void SdcaL1UpdateSUTest(string mode, string test, string scale, Dictionary<string, string> environmentVariables)
        {
            // The lambda only uses its own string arguments and instance members; it does not
            // capture any local of this method.
            RemoteExecutor.RemoteInvoke((arg0, arg1, arg2) =>
            {
                CheckProperFlag(arg0);

                // Both arguments are machine-generated strings, so both are parsed with the
                // invariant culture instead of whatever culture the remote process runs under.
                float defaultScale = float.Parse(arg2, CultureInfo.InvariantCulture);
                int testIndex      = int.Parse(arg1, CultureInfo.InvariantCulture);

                float[] src        = (float[])_testArrays[testIndex].Clone();
                float[] v          = (float[])src.Clone();
                float[] w          = (float[])src.Clone();
                int[] idx          = _testIndexArray;
                float[] expected   = (float[])w.Clone();

                // Reference result: add the scaled source value to v at each listed index, then
                // shrink the sum towards zero by defaultScale (sums within the threshold become 0).
                for (int i = 0; i < idx.Length; i++)
                {
                    int index       = idx[i];
                    float value     = v[index] + src[i] * defaultScale;
                    expected[index] = Math.Abs(value) > defaultScale ? (value > 0 ? value - defaultScale : value + defaultScale) : 0;
                }

                // w is updated in place by the call and compared against the reference.
                CpuMathUtils.SdcaL1UpdateSparse(defaultScale, idx.Length, src, idx, defaultScale, v, w);
                Assert.Equal(expected, w, _comparer);
                return RemoteExecutor.SuccessExitCode;
            }, mode, test, scale, new RemoteInvokeOptions(environmentVariables));
        }
示例#2
0
        /// <summary>
        /// Verifies <c>CpuMathUtils.SdcaL1UpdateSparse</c> for the test array selected by
        /// <paramref name="test"/> by comparing its output with a plain scalar computation.
        /// </summary>
        /// <param name="test">Index into <c>testArrays</c>.</param>
        public void SdcaL1UpdateSUTest(int test)
        {
            float[] source       = (float[])testArrays[test].Clone();
            float[] intermediate = (float[])source.Clone();
            float[] weights      = (float[])source.Clone();
            int[]   indices      = testIndexArray;
            float[] expected     = (float[])weights.Clone();

            // Build the reference result: only the slots named by the index array change.
            for (int position = 0; position < indices.Length; position++)
            {
                int   target = indices[position];
                float value  = intermediate[target] + source[position] * DEFAULT_SCALE;

                if (Math.Abs(value) > DEFAULT_SCALE)
                {
                    // Outside the threshold band: shrink towards zero by the threshold.
                    expected[target] = value > 0 ? value - DEFAULT_SCALE : value + DEFAULT_SCALE;
                }
                else
                {
                    // Inside the band (or not comparable): the weight is switched off.
                    expected[target] = 0;
                }
            }

            CpuMathUtils.SdcaL1UpdateSparse(DEFAULT_SCALE, source.Length, source, indices, indices.Length, DEFAULT_SCALE, intermediate, weights);

            // The weights buffer is written in place by the call above.
            Assert.Equal(expected, weights, comparer);
        }
        /// <summary>
        /// Runs <c>CpuMathUtils.SdcaL1UpdateSparse</c> on a copy of <c>_testArrays[test]</c> and
        /// asserts that the updated weights equal a scalar reference computation.
        /// </summary>
        /// <param name="test">Index into <c>_testArrays</c>.</param>
        public void SdcaL1UpdateSUTest(int test)
        {
            float[] input          = (float[])_testArrays[test].Clone();
            float[] unthresholded  = (float[])input.Clone();
            float[] thresholded    = (float[])input.Clone();
            int[]   slots          = _testIndexArray;
            float[] expected       = (float[])thresholded.Clone();

            // Reference computation, one listed slot at a time.
            for (int k = 0; k < slots.Length; k++)
            {
                int   slot  = slots[k];
                float value = unthresholded[slot] + input[k] * DefaultScale;

                // Default is zero; only values strictly outside the threshold survive, shrunk by it.
                float shrunk = 0;
                if (Math.Abs(value) > DefaultScale)
                {
                    shrunk = value > 0 ? value - DefaultScale : value + DefaultScale;
                }

                expected[slot] = shrunk;
            }

            CpuMathUtils.SdcaL1UpdateSparse(DefaultScale, slots.Length, input, slots, DefaultScale, unthresholded, thresholded);

            // The last array argument is updated in place and holds the actual result.
            Assert.Equal(expected, thresholded, _comparer);
        }
示例#4
0
        /// <inheritdoc/>
        /// <remarks>
        /// Makes one pass over the rows produced by <paramref name="cursorFactory"/>. For each row, every
        /// class other than the row's label gets a dual update that is published with a compare-and-swap
        /// (retried at most <c>2 * numThreads</c> times); on success the weights and biases of that class
        /// are updated. The accumulated opposite update is then applied to the label class.
        /// When the L1 threshold is zero, weights are updated directly; otherwise the update goes to the
        /// intermediate (un-thresholded) weights and the real weights are derived by soft-thresholding.
        /// </remarks>
        private protected override void TrainWithoutLock(IProgressChannelProvider progress, FloatLabelCursor.Factory cursorFactory, Random rand,
                                                         IdToIdxLookup idToIdx, int numThreads, DualsTableBase duals, float[] biasReg, float[] invariants, float lambdaNInv,
                                                         VBuffer <float>[] weights, float[] biasUnreg, VBuffer <float>[] l1IntermediateWeights, float[] l1IntermediateBias, float[] featureNormSquared)
        {
            // progress, idToIdx, invariants and featureNormSquared are allowed to be null;
            // the L1 threshold option must have a value by the time this runs.
            Contracts.AssertValueOrNull(progress);
            Contracts.Assert(SdcaTrainerOptions.L1Threshold.HasValue);
            Contracts.AssertValueOrNull(idToIdx);
            Contracts.AssertValueOrNull(invariants);
            Contracts.AssertValueOrNull(featureNormSquared);
            // One weight vector per class; both bias arrays must have the same number of entries.
            int numClasses = Utils.Size(weights);

            Contracts.Assert(Utils.Size(biasReg) == numClasses);
            Contracts.Assert(Utils.Size(biasUnreg) == numClasses);

            // Upper bound on compare-and-swap attempts per (row, class) dual variable.
            int  maxUpdateTrials = 2 * numThreads;
            var  l1Threshold     = SdcaTrainerOptions.L1Threshold.Value;
            // Selects between the direct weight update and the soft-thresholding path below.
            bool l1ThresholdZero = l1Threshold == 0;
            // Factor applied to the bias when computing the per-class adjustment of the dual update.
            var  lr = SdcaTrainerOptions.BiasLearningRate * SdcaTrainerOptions.L2Regularization.Value;

            // The progress channel is optional; pch stays null when no provider was given.
            var pch = progress != null?progress.StartProgressChannel("Dual update") : null;

            using (pch)
                using (var cursor = SdcaTrainerOptions.Shuffle ? cursorFactory.Create(rand) : cursorFactory.Create())
                {
                    // Number of rows processed so far; read by the progress callback registered below.
                    long rowCount = 0;
                    if (pch != null)
                    {
                        pch.SetHeader(new ProgressHeader("examples"), e => e.SetProgress(0, rowCount));
                    }

                    Func <DataViewRowId, long> getIndexFromId = GetIndexFromIdGetter(idToIdx, biasReg.Length);
                    while (cursor.MoveNext())
                    {
                        // Row position derived from the row id. The duals of this row are addressed as
                        // idx * numClasses + iClass.
                        long  idx = getIndexFromId(cursor.Id);
                        long  dualIndexInitPos = idx * numClasses;
                        var   features         = cursor.Features;
                        var   label            = (int)cursor.Label;
                        float invariant;
                        float normSquared;
                        if (invariants != null)
                        {
                            // Per-row values were supplied by the caller; use them as-is.
                            invariant = invariants[idx];
                            Contracts.AssertValue(featureNormSquared);
                            normSquared = featureNormSquared[idx];
                        }
                        else
                        {
                            // No cached values: compute the squared norm and the invariant for this row.
                            normSquared = VectorUtils.NormSquared(in features);
                            // NOTE(review): presumably the +1 accounts for the bias acting as an extra
                            // constant feature when it has no separate learning rate -- confirm.
                            if (SdcaTrainerOptions.BiasLearningRate == 0)
                            {
                                normSquared += 1;
                            }

                            invariant = _loss.ComputeDualUpdateInvariant(2 * normSquared * lambdaNInv * GetInstanceWeight(cursor));
                        }

                        // The output for the label class using current weights and bias.
                        var labelOutput    = WDot(in features, in weights[label], biasReg[label] + biasUnreg[label]);
                        var instanceWeight = GetInstanceWeight(cursor);

                        // This will be the new dual variable corresponding to the label class.
                        float labelDual = 0;

                        // This will be used to update the weights and regularized bias corresponding to the label class.
                        float labelPrimalUpdate = 0;

                        // This will be used to update the unregularized bias corresponding to the label class.
                        float labelAdjustment = 0;

                        // Iterates through all classes.
                        for (int iClass = 0; iClass < numClasses; iClass++)
                        {
                            // Skip the dual/weights/bias update for label class. Will be taken care of at the end.
                            if (iClass == label)
                            {
                                continue;
                            }

                            // Editors give write access to the value buffers of this class. The intermediate
                            // editor is only created when the soft-thresholding path is in use.
                            var weightsEditor = VBufferEditor.CreateFromBuffer(ref weights[iClass]);
                            var l1IntermediateWeightsEditor =
                                !l1ThresholdZero?VBufferEditor.CreateFromBuffer(ref l1IntermediateWeights[iClass]) :
                                    default;

                            // Loop trials for compare-and-swap updates of duals.
                            // In general, concurrent update conflict to the same dual variable is rare
                            // if data is shuffled.
                            // If every trial fails, this class is left without an update for this row.
                            for (int numTrials = 0; numTrials < maxUpdateTrials; numTrials++)
                            {
                                long dualIndex  = iClass + dualIndexInitPos;
                                // Snapshot of the dual; the swap below succeeds only if it is still this value.
                                var  dual       = duals[dualIndex];
                                // Label output (including the updates accumulated so far for this row) minus
                                // the output of the current class.
                                var  output     = labelOutput + labelPrimalUpdate * normSquared - WDot(in features, in weights[iClass], biasReg[iClass] + biasUnreg[iClass]);
                                var  dualUpdate = _loss.DualUpdate(output, 1, dual, invariant, numThreads);

                                // The successive over-relaxation approach to adjust the sum of dual variables (biasReg) to zero.
                                // Reference to details: http://stat.rutgers.edu/home/tzhang/papers/ml02_dual.pdf, pp. 16-17.
                                var adjustment = l1ThresholdZero ? lr * biasReg[iClass] : lr * l1IntermediateBias[iClass];
                                dualUpdate -= adjustment;
                                bool success = false;
                                // Publish dual + dualUpdate only if the stored value still equals the snapshot.
                                duals.ApplyAt(dualIndex, (long index, ref float value) =>
                                              success = Interlocked.CompareExchange(ref value, dual + dualUpdate, dual) == dual);

                                if (success)
                                {
                                    // Note: dualConstraint[iClass] = lambdaNInv * (sum of duals[iClass])
                                    var primalUpdate = dualUpdate * lambdaNInv * instanceWeight;
                                    // Accumulate the opposite-signed totals that are applied to the label class
                                    // after the class loop.
                                    labelDual         -= dual + dualUpdate;
                                    labelPrimalUpdate += primalUpdate;
                                    biasUnreg[iClass] += adjustment * lambdaNInv * instanceWeight;
                                    labelAdjustment   -= adjustment;

                                    if (l1ThresholdZero)
                                    {
                                        // No L1: subtract the scaled features from the weights directly.
                                        VectorUtils.AddMult(in features, weightsEditor.Values, -primalUpdate);
                                        biasReg[iClass] -= primalUpdate;
                                    }
                                    else
                                    {
                                        // Iterative shrinkage-thresholding (aka. soft-thresholding)
                                        // Update v=denseWeights as if there's no L1
                                        // Thresholding: if |v[j]| < threshold, turn off weights[j]
                                        // If not, shrink: w[j] = v[j] - sign(v[j]) * threshold
                                        l1IntermediateBias[iClass] -= primalUpdate;
                                        // The regularized bias is re-derived from the intermediate bias only
                                        // when the bias learning rate is zero.
                                        if (SdcaTrainerOptions.BiasLearningRate == 0)
                                        {
                                            biasReg[iClass] = Math.Abs(l1IntermediateBias[iClass]) - l1Threshold > 0.0
                                        ? l1IntermediateBias[iClass] - Math.Sign(l1IntermediateBias[iClass]) * l1Threshold
                                        : 0;
                                        }

                                        // Dense and sparse feature vectors use different update routines;
                                        // a sparse vector with no values needs no update.
                                        var featureValues = features.GetValues();
                                        if (features.IsDense)
                                        {
                                            CpuMathUtils.SdcaL1UpdateDense(-primalUpdate, featureValues.Length, featureValues, l1Threshold, l1IntermediateWeightsEditor.Values, weightsEditor.Values);
                                        }
                                        else if (featureValues.Length > 0)
                                        {
                                            CpuMathUtils.SdcaL1UpdateSparse(-primalUpdate, featureValues.Length, featureValues, features.GetIndices(), l1Threshold, l1IntermediateWeightsEditor.Values, weightsEditor.Values);
                                        }
                                    }

                                    // The update was published; stop retrying for this class.
                                    break;
                                }
                            }
                        }

                        // Updating with label class weights and dual variable.
                        // The label dual is stored with a plain write (no compare-and-swap).
                        duals[label + dualIndexInitPos] = labelDual;
                        biasUnreg[label] += labelAdjustment * lambdaNInv * instanceWeight;
                        if (l1ThresholdZero)
                        {
                            var weightsEditor = VBufferEditor.CreateFromBuffer(ref weights[label]);
                            VectorUtils.AddMult(in features, weightsEditor.Values, labelPrimalUpdate);
                            biasReg[label] += labelPrimalUpdate;
                        }
                        else
                        {
                            // Soft-thresholding path for the label class. Unlike the per-class branch above,
                            // the regularized bias is re-derived here regardless of the bias learning rate.
                            l1IntermediateBias[label] += labelPrimalUpdate;
                            var intermediateBias = l1IntermediateBias[label];
                            biasReg[label] = Math.Abs(intermediateBias) - l1Threshold > 0.0
                            ? intermediateBias - Math.Sign(intermediateBias) * l1Threshold
                            : 0;

                            var weightsEditor = VBufferEditor.CreateFromBuffer(ref weights[label]);
                            var l1IntermediateWeightsEditor = VBufferEditor.CreateFromBuffer(ref l1IntermediateWeights[label]);
                            var featureValues = features.GetValues();
                            if (features.IsDense)
                            {
                                CpuMathUtils.SdcaL1UpdateDense(labelPrimalUpdate, featureValues.Length, featureValues, l1Threshold, l1IntermediateWeightsEditor.Values, weightsEditor.Values);
                            }
                            else if (featureValues.Length > 0)
                            {
                                CpuMathUtils.SdcaL1UpdateSparse(labelPrimalUpdate, featureValues.Length, featureValues, features.GetIndices(), l1Threshold, l1IntermediateWeightsEditor.Values, weightsEditor.Values);
                            }
                        }

                        rowCount++;
                    }
                }
        }
示例#5
0
 /// <summary>
 /// Invokes <c>CpuMathUtils.SdcaL1UpdateSparse</c> once on the <c>src</c>, <c>idx</c>, <c>dst</c> and
 /// <c>result</c> members, using <c>DefaultScale</c> for both the update scale and the threshold
 /// and <c>_smallInputLength</c> as the element count.
 /// </summary>
 public void SdcaL1UpdateSU()
 {
     CpuMathUtils.SdcaL1UpdateSparse(DefaultScale, _smallInputLength, src, idx, DefaultScale, dst, result);
 }
示例#6
0
 /// <summary>
 /// Invokes <c>CpuMathUtils.SdcaL1UpdateSparse</c> once on the <c>src</c>, <c>idx</c>, <c>dst</c> and
 /// <c>result</c> members, passing <c>LEN</c> and <c>IDXLEN</c> as the two length arguments and
 /// <c>DEFAULT_SCALE</c> for both the update scale and the threshold.
 /// </summary>
 public void ManagedSdcaL1UpdateSUPerf()
 {
     CpuMathUtils.SdcaL1UpdateSparse(DEFAULT_SCALE, LEN, src, idx, IDXLEN, DEFAULT_SCALE, dst, result);
 }