// Compares CpuMathUtils.SdcaL1UpdateDense against a scalar reference computation for the
// input selected by `test` (an index into testArrays).
public void SdcaL1UpdateUTest(int test)
{
    float[] input = (float[])testArrays[test].Clone();
    float[] intermediate = (float[])input.Clone();
    float[] output = (float[])input.Clone();
    float[] expected = (float[])output.Clone();

    // Reference result: multiply each element by (1 + DEFAULT_SCALE), then pull it towards zero
    // by DEFAULT_SCALE, clamping to 0 when its magnitude does not exceed DEFAULT_SCALE.
    for (int index = 0; index < expected.Length; index++)
    {
        float scaled = input[index] * (1 + DEFAULT_SCALE);
        float shrunk;
        if (Math.Abs(scaled) > DEFAULT_SCALE)
        {
            shrunk = scaled > 0 ? scaled - DEFAULT_SCALE : scaled + DEFAULT_SCALE;
        }
        else
        {
            shrunk = 0;
        }

        expected[index] = shrunk;
    }

    // The call writes its result into `output`, which is what gets compared below.
    CpuMathUtils.SdcaL1UpdateDense(DEFAULT_SCALE, input.Length, input, DEFAULT_SCALE, intermediate, output);

    Assert.Equal(expected, output, comparer);
}
// Verifies CpuMathUtils.SdcaL1UpdateDense on the array at index `test` of _testArrays by
// comparing its output with an element-by-element reference computation.
public void SdcaL1UpdateUTest(int test)
{
    float[] source = (float[])_testArrays[test].Clone();
    float[] dense = (float[])source.Clone();
    float[] updated = (float[])source.Clone();
    float[] expected = (float[])updated.Clone();

    // Build the expected values: scale by (1 + DefaultScale), then shrink the magnitude by
    // DefaultScale; values whose magnitude is not above DefaultScale become 0.
    int position = 0;
    while (position < expected.Length)
    {
        float scaledValue = source[position] * (1 + DefaultScale);

        if (Math.Abs(scaledValue) > DefaultScale)
        {
            expected[position] = scaledValue > 0
                ? scaledValue - DefaultScale
                : scaledValue + DefaultScale;
        }
        else
        {
            expected[position] = 0;
        }

        position++;
    }

    // `updated` receives the result that is asserted against `expected`.
    CpuMathUtils.SdcaL1UpdateDense(DefaultScale, source.Length, source, DefaultScale, dense, updated);

    Assert.Equal(expected, updated, _comparer);
}
// Runs the SdcaL1UpdateDense check through RemoteExecutor.RemoteInvoke, forwarding
// `environmentVariables` via RemoteInvokeOptions.
//   mode  - passed to CheckProperFlag inside the remote lambda.
//   test  - index into _testArrays, encoded as a string.
//   scale - the scale/threshold value, encoded as an invariant-culture float string.
// All inputs travel as strings because the remote lambda only receives string arguments here;
// it must not capture locals of this method.
public void SdcaL1UpdateUTest(string mode, string test, string scale, Dictionary<string, string> environmentVariables)
{
    RemoteExecutor.RemoteInvoke((arg0, arg1, arg2) =>
    {
        CheckProperFlag(arg0);

        // Parse both numeric arguments with the invariant culture so the test does not depend
        // on the culture of the machine (or remote process) running it.
        float defaultScale = float.Parse(arg2, CultureInfo.InvariantCulture);
        int testIndex = int.Parse(arg1, CultureInfo.InvariantCulture);

        float[] src = (float[])_testArrays[testIndex].Clone();
        float[] v = (float[])src.Clone();
        float[] w = (float[])src.Clone();
        float[] expected = (float[])w.Clone();

        // Reference result: scale by (1 + defaultScale), then shrink the magnitude by
        // defaultScale; values whose magnitude is not above defaultScale become 0.
        for (int i = 0; i < expected.Length; i++)
        {
            float value = src[i] * (1 + defaultScale);
            expected[i] = Math.Abs(value) > defaultScale
                ? (value > 0 ? value - defaultScale : value + defaultScale)
                : 0;
        }

        // The result is written into `w`.
        CpuMathUtils.SdcaL1UpdateDense(defaultScale, src.Length, src, defaultScale, v, w);

        Assert.Equal(expected, w, _comparer);
        return RemoteExecutor.SuccessExitCode;
    }, mode, test, scale, new RemoteInvokeOptions(environmentVariables));
}
// TrainWithoutLock: one pass over the training rows that updates the dual variables, weights and
// biases of a multi-class SDCA model without taking a lock. Concurrent writers to the same dual
// variable are handled with Interlocked.CompareExchange inside duals.ApplyAt.
//
// Inputs and preconditions (as asserted at the top of the method):
//   - progress, idToIdx, invariants and featureNormSquared may be null (AssertValueOrNull).
//   - SdcaTrainerOptions.L1Threshold must have a value.
//   - biasReg and biasUnreg must have one entry per class (numClasses = Utils.Size(weights)).
//
// Per-row flow:
//   - The cursor comes from cursorFactory.Create(rand) when SdcaTrainerOptions.Shuffle is set,
//     otherwise from cursorFactory.Create(). When a progress channel exists, rowCount is reported.
//   - If invariants is non-null, invariants[idx] and featureNormSquared[idx] are read; otherwise
//     the squared feature norm is computed (plus 1 when BiasLearningRate == 0) and the invariant
//     is obtained from _loss.ComputeDualUpdateInvariant.
//   - For every class other than the row's label, the dual at (idx * numClasses + iClass) is
//     updated by compare-and-swap, retried at most maxUpdateTrials = 2 * numThreads times. On
//     success, biasUnreg[iClass] is adjusted and the weights/regularized bias are updated with
//     -primalUpdate: via VectorUtils.AddMult when l1Threshold == 0, otherwise via
//     l1IntermediateBias plus CpuMathUtils.SdcaL1UpdateDense / SdcaL1UpdateSparse.
//   - After the class loop, the label class's dual is overwritten with the accumulated labelDual
//     (a plain store, not a compare-and-swap), and its weights and biases are updated with the
//     accumulated labelPrimalUpdate / labelAdjustment.
//
// NOTE(review): if none of the maxUpdateTrials compare-and-swap attempts succeeds, the retry loop
// ends without applying any update for that class on this row, and nothing is logged.
// NOTE(review): in the L1 branch for non-label classes, biasReg[iClass] is recomputed only when
// SdcaTrainerOptions.BiasLearningRate == 0, while the label-class branch recomputes biasReg[label]
// unconditionally — confirm this asymmetry is intended.
// NOTE(review): in this copy several `//` comments share a physical line with the statements that
// follow them; restore the line break after each comment before compiling.
/// <inheritdoc/> private protected override void TrainWithoutLock(IProgressChannelProvider progress, FloatLabelCursor.Factory cursorFactory, Random rand, IdToIdxLookup idToIdx, int numThreads, DualsTableBase duals, float[] biasReg, float[] invariants, float lambdaNInv, VBuffer <float>[] weights, float[] biasUnreg, VBuffer <float>[] l1IntermediateWeights, float[] l1IntermediateBias, float[] featureNormSquared) { Contracts.AssertValueOrNull(progress); Contracts.Assert(SdcaTrainerOptions.L1Threshold.HasValue); Contracts.AssertValueOrNull(idToIdx); Contracts.AssertValueOrNull(invariants); Contracts.AssertValueOrNull(featureNormSquared); int numClasses = Utils.Size(weights); Contracts.Assert(Utils.Size(biasReg) == numClasses); Contracts.Assert(Utils.Size(biasUnreg) == numClasses); int maxUpdateTrials = 2 * numThreads; var l1Threshold = SdcaTrainerOptions.L1Threshold.Value; bool l1ThresholdZero = l1Threshold == 0; var lr = SdcaTrainerOptions.BiasLearningRate * SdcaTrainerOptions.L2Regularization.Value; var pch = progress != null?progress.StartProgressChannel("Dual update") : null; using (pch) using (var cursor = SdcaTrainerOptions.Shuffle ? 
cursorFactory.Create(rand) : cursorFactory.Create()) { long rowCount = 0; if (pch != null) { pch.SetHeader(new ProgressHeader("examples"), e => e.SetProgress(0, rowCount)); } Func <DataViewRowId, long> getIndexFromId = GetIndexFromIdGetter(idToIdx, biasReg.Length); while (cursor.MoveNext()) { long idx = getIndexFromId(cursor.Id); long dualIndexInitPos = idx * numClasses; var features = cursor.Features; var label = (int)cursor.Label; float invariant; float normSquared; if (invariants != null) { invariant = invariants[idx]; Contracts.AssertValue(featureNormSquared); normSquared = featureNormSquared[idx]; } else { normSquared = VectorUtils.NormSquared(in features); if (SdcaTrainerOptions.BiasLearningRate == 0) { normSquared += 1; } invariant = _loss.ComputeDualUpdateInvariant(2 * normSquared * lambdaNInv * GetInstanceWeight(cursor)); } // The output for the label class using current weights and bias. var labelOutput = WDot(in features, in weights[label], biasReg[label] + biasUnreg[label]); var instanceWeight = GetInstanceWeight(cursor); // This will be the new dual variable corresponding to the label class. float labelDual = 0; // This will be used to update the weights and regularized bias corresponding to the label class. float labelPrimalUpdate = 0; // This will be used to update the unregularized bias corresponding to the label class. float labelAdjustment = 0; // Iterates through all classes. for (int iClass = 0; iClass < numClasses; iClass++) { // Skip the dual/weights/bias update for label class. Will be taken care of at the end. if (iClass == label) { continue; } var weightsEditor = VBufferEditor.CreateFromBuffer(ref weights[iClass]); var l1IntermediateWeightsEditor = !l1ThresholdZero?VBufferEditor.CreateFromBuffer(ref l1IntermediateWeights[iClass]) : default; // Loop trials for compare-and-swap updates of duals. // In general, concurrent update conflict to the same dual variable is rare // if data is shuffled. 
for (int numTrials = 0; numTrials < maxUpdateTrials; numTrials++) { long dualIndex = iClass + dualIndexInitPos; var dual = duals[dualIndex]; var output = labelOutput + labelPrimalUpdate * normSquared - WDot(in features, in weights[iClass], biasReg[iClass] + biasUnreg[iClass]); var dualUpdate = _loss.DualUpdate(output, 1, dual, invariant, numThreads); // The successive over-relaxation approach to adjust the sum of dual variables (biasReg) to zero. // Reference to details: http://stat.rutgers.edu/home/tzhang/papers/ml02_dual.pdf, pp. 16-17. var adjustment = l1ThresholdZero ? lr * biasReg[iClass] : lr * l1IntermediateBias[iClass]; dualUpdate -= adjustment; bool success = false; duals.ApplyAt(dualIndex, (long index, ref float value) => success = Interlocked.CompareExchange(ref value, dual + dualUpdate, dual) == dual); if (success) { // Note: dualConstraint[iClass] = lambdaNInv * (sum of duals[iClass]) var primalUpdate = dualUpdate * lambdaNInv * instanceWeight; labelDual -= dual + dualUpdate; labelPrimalUpdate += primalUpdate; biasUnreg[iClass] += adjustment * lambdaNInv * instanceWeight; labelAdjustment -= adjustment; if (l1ThresholdZero) { VectorUtils.AddMult(in features, weightsEditor.Values, -primalUpdate); biasReg[iClass] -= primalUpdate; } else { //Iterative shrinkage-thresholding (aka. soft-thresholding) //Update v=denseWeights as if there's no L1 //Thresholding: if |v[j]| < threshold, turn off weights[j] //If not, shrink: w[j] = v[i] - sign(v[j]) * threshold l1IntermediateBias[iClass] -= primalUpdate; if (SdcaTrainerOptions.BiasLearningRate == 0) { biasReg[iClass] = Math.Abs(l1IntermediateBias[iClass]) - l1Threshold > 0.0 ? 
l1IntermediateBias[iClass] - Math.Sign(l1IntermediateBias[iClass]) * l1Threshold : 0; } var featureValues = features.GetValues(); if (features.IsDense) { CpuMathUtils.SdcaL1UpdateDense(-primalUpdate, featureValues.Length, featureValues, l1Threshold, l1IntermediateWeightsEditor.Values, weightsEditor.Values); } else if (featureValues.Length > 0) { CpuMathUtils.SdcaL1UpdateSparse(-primalUpdate, featureValues.Length, featureValues, features.GetIndices(), l1Threshold, l1IntermediateWeightsEditor.Values, weightsEditor.Values); } } break; } } } // Updating with label class weights and dual variable. duals[label + dualIndexInitPos] = labelDual; biasUnreg[label] += labelAdjustment * lambdaNInv * instanceWeight; if (l1ThresholdZero) { var weightsEditor = VBufferEditor.CreateFromBuffer(ref weights[label]); VectorUtils.AddMult(in features, weightsEditor.Values, labelPrimalUpdate); biasReg[label] += labelPrimalUpdate; } else { l1IntermediateBias[label] += labelPrimalUpdate; var intermediateBias = l1IntermediateBias[label]; biasReg[label] = Math.Abs(intermediateBias) - l1Threshold > 0.0 ? intermediateBias - Math.Sign(intermediateBias) * l1Threshold : 0; var weightsEditor = VBufferEditor.CreateFromBuffer(ref weights[label]); var l1IntermediateWeightsEditor = VBufferEditor.CreateFromBuffer(ref l1IntermediateWeights[label]); var featureValues = features.GetValues(); if (features.IsDense) { CpuMathUtils.SdcaL1UpdateDense(labelPrimalUpdate, featureValues.Length, featureValues, l1Threshold, l1IntermediateWeightsEditor.Values, weightsEditor.Values); } else if (featureValues.Length > 0) { CpuMathUtils.SdcaL1UpdateSparse(labelPrimalUpdate, featureValues.Length, featureValues, features.GetIndices(), l1Threshold, l1IntermediateWeightsEditor.Values, weightsEditor.Values); } } rowCount++; } } }
// Invokes CpuMathUtils.SdcaL1UpdateDense on the src/dst/result buffers, passing DefaultScale as
// both the first and the fourth argument and _smallInputLength as the length argument.
public void SdcaL1UpdateU()
{
    CpuMathUtils.SdcaL1UpdateDense(
        DefaultScale,
        _smallInputLength,
        src,
        DefaultScale,
        dst,
        result);
}
// Invokes CpuMathUtils.SdcaL1UpdateDense on the src/dst/result buffers, passing DEFAULT_SCALE as
// both the first and the fourth argument and LEN as the length argument.
public void ManagedSdcaL1UpdateUPerf()
{
    CpuMathUtils.SdcaL1UpdateDense(
        DEFAULT_SCALE,
        LEN,
        src,
        DEFAULT_SCALE,
        dst,
        result);
}