/// <inheritdoc/>
protected override void TrainWithoutLock(IProgressChannelProvider progress, FloatLabelCursor.Factory cursorFactory, IRandom rand,
    IdToIdxLookup idToIdx, int numThreads, DualsTableBase duals, Float[] biasReg, Float[] invariants, Float lambdaNInv,
    VBuffer<Float>[] weights, Float[] biasUnreg, VBuffer<Float>[] l1IntermediateWeights, Float[] l1IntermediateBias, Float[] featureNormSquared)
{
    Contracts.AssertValueOrNull(progress);
    Contracts.Assert(_args.L1Threshold.HasValue);
    Contracts.AssertValueOrNull(idToIdx);
    Contracts.AssertValueOrNull(invariants);
    Contracts.AssertValueOrNull(featureNormSquared);
    int weightArraySize = WeightArraySize;
    Contracts.Assert(weightArraySize == _numClasses);
    Contracts.Assert(Utils.Size(weights) == weightArraySize);
    Contracts.Assert(Utils.Size(biasReg) == weightArraySize);
    Contracts.Assert(Utils.Size(biasUnreg) == weightArraySize);

    int maxUpdateTrials = 2 * numThreads;
    var l1Threshold = _args.L1Threshold.Value;
    bool l1ThresholdZero = l1Threshold == 0;
    var lr = _args.BiasLearningRate * _args.L2Const.Value;

    var pch = progress != null ? progress.StartProgressChannel("Dual update") : null;
    using (pch)
    using (var cursor = _args.Shuffle ? cursorFactory.Create(rand) : cursorFactory.Create())
    {
        long rowCount = 0;
        if (pch != null)
            pch.SetHeader(new ProgressHeader("examples"), e => e.SetProgress(0, rowCount));

        Func<UInt128, long> getIndexFromId = GetIndexFromIdGetter(idToIdx);
        while (cursor.MoveNext())
        {
            long idx = getIndexFromId(cursor.Id);
            long dualIndexInitPos = idx * weightArraySize;
            var features = cursor.Features;
            var label = (int)cursor.Label;
            Float invariant;
            Float normSquared;
            if (invariants != null)
            {
                invariant = invariants[idx];
                Contracts.AssertValue(featureNormSquared);
                normSquared = featureNormSquared[idx];
            }
            else
            {
                normSquared = VectorUtils.NormSquared(features);
                if (_args.BiasLearningRate == 0)
                    normSquared += 1;

                invariant = _loss.ComputeDualUpdateInvariant(2 * normSquared * lambdaNInv * GetInstanceWeight(cursor));
            }

            // The output for the label class using current weights and bias.
            var labelOutput = WDot(ref features, ref weights[label], biasReg[label] + biasUnreg[label]);
            var instanceWeight = GetInstanceWeight(cursor);

            // This will be the new dual variable corresponding to the label class.
            Float labelDual = 0;

            // This will be used to update the weights and regularized bias corresponding to the label class.
            Float labelPrimalUpdate = 0;

            // This will be used to update the unregularized bias corresponding to the label class.
            Float labelAdjustment = 0;

            // Iterates through all classes.
            for (int iClass = 0; iClass < _numClasses; iClass++)
            {
                // Skip the dual/weights/bias update for the label class; it is handled at the end.
                if (iClass == label)
                    continue;

                // Loop trials for compare-and-swap updates of duals.
                // In general, concurrent updates that conflict on the same dual variable are rare
                // if the data is shuffled.
                for (int numTrials = 0; numTrials < maxUpdateTrials; numTrials++)
                {
                    long dualIndex = iClass + dualIndexInitPos;
                    var dual = duals[dualIndex];
                    var output = labelOutput + labelPrimalUpdate * normSquared - WDot(ref features, ref weights[iClass], biasReg[iClass] + biasUnreg[iClass]);
                    var dualUpdate = _loss.DualUpdate(output, 1, dual, invariant, numThreads);

                    // The successive over-relaxation approach to adjust the sum of dual variables (biasReg) to zero.
                    // Reference for details: http://stat.rutgers.edu/home/tzhang/papers/ml02_dual.pdf, pp. 16-17.
                    var adjustment = l1ThresholdZero ? lr * biasReg[iClass] : lr * l1IntermediateBias[iClass];
                    dualUpdate -= adjustment;
                    bool success = false;
                    duals.ApplyAt(dualIndex, (long index, ref Float value) =>
                    {
                        success = Interlocked.CompareExchange(ref value, dual + dualUpdate, dual) == dual;
                    });

                    if (success)
                    {
                        // Note: dualConstraint[iClass] = lambdaNInv * (sum of duals[iClass])
                        var primalUpdate = dualUpdate * lambdaNInv * instanceWeight;
                        labelDual -= dual + dualUpdate;
                        labelPrimalUpdate += primalUpdate;
                        biasUnreg[iClass] += adjustment * lambdaNInv * instanceWeight;
                        labelAdjustment -= adjustment;

                        if (l1ThresholdZero)
                        {
                            VectorUtils.AddMult(ref features, weights[iClass].Values, -primalUpdate);
                            biasReg[iClass] -= primalUpdate;
                        }
                        else
                        {
                            // Iterative shrinkage-thresholding (aka soft-thresholding):
                            // update v = denseWeights as if there were no L1;
                            // thresholding: if |v[j]| < threshold, turn off weights[j];
                            // otherwise, shrink: w[j] = v[j] - sign(v[j]) * threshold.
                            l1IntermediateBias[iClass] -= primalUpdate;
                            if (_args.BiasLearningRate == 0)
                            {
                                biasReg[iClass] = Math.Abs(l1IntermediateBias[iClass]) - l1Threshold > 0.0
                                    ? l1IntermediateBias[iClass] - Math.Sign(l1IntermediateBias[iClass]) * l1Threshold
                                    : 0;
                            }

                            if (features.IsDense)
                                SseUtils.SdcaL1UpdateDense(-primalUpdate, features.Length, features.Values, l1Threshold, l1IntermediateWeights[iClass].Values, weights[iClass].Values);
                            else if (features.Count > 0)
                                SseUtils.SdcaL1UpdateSparse(-primalUpdate, features.Length, features.Values, features.Indices, features.Count, l1Threshold, l1IntermediateWeights[iClass].Values, weights[iClass].Values);
                        }

                        break;
                    }
                }
            }

            // Updating with label class weights and dual variable.
            duals[label + dualIndexInitPos] = labelDual;
            biasUnreg[label] += labelAdjustment * lambdaNInv * instanceWeight;
            if (l1ThresholdZero)
            {
                VectorUtils.AddMult(ref features, weights[label].Values, labelPrimalUpdate);
                biasReg[label] += labelPrimalUpdate;
            }
            else
            {
                l1IntermediateBias[label] += labelPrimalUpdate;
                var intermediateBias = l1IntermediateBias[label];
                biasReg[label] = Math.Abs(intermediateBias) - l1Threshold > 0.0
                    ? intermediateBias - Math.Sign(intermediateBias) * l1Threshold
                    : 0;

                if (features.IsDense)
                    SseUtils.SdcaL1UpdateDense(labelPrimalUpdate, features.Length, features.Values, l1Threshold, l1IntermediateWeights[label].Values, weights[label].Values);
                else if (features.Count > 0)
                    SseUtils.SdcaL1UpdateSparse(labelPrimalUpdate, features.Length, features.Values, features.Indices, features.Count, l1Threshold, l1IntermediateWeights[label].Values, weights[label].Values);
            }

            rowCount++;
        }
    }
}
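// Illustrative sketch (not part of the original trainer): the scalar soft-thresholding rule that the
// method above applies to the regularized bias when L1 regularization is enabled. Given an
// intermediate, un-thresholded value v and an L1 threshold, the parameter is zeroed when
// |v| <= threshold and otherwise shrunk toward zero by the threshold. The helper name SoftThreshold
// is hypothetical and included only for exposition.
private static Float SoftThreshold(Float v, Float threshold)
{
    // Mirrors: biasReg[c] = |v| - threshold > 0 ? v - sign(v) * threshold : 0
    return Math.Abs(v) - threshold > 0.0 ? v - Math.Sign(v) * threshold : 0;
}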
// Forwards the dense SDCA L1 (soft-thresholding) weight update to the SSE-accelerated implementation.
public static void SdcaL1UpdateDense(float primalUpdate, int length, float[] src, float threshold, float[] v, float[] w)
    => SseUtils.SdcaL1UpdateDense(primalUpdate, length, src, threshold, v, w);
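// Illustrative managed sketch (not the shipped SSE code path): the element-wise update the dense
// SDCA L1 routine is understood to perform, assuming it matches the scalar rule used for the
// regularized bias in the trainer above. The intermediate weights v accumulate the primal update as
// if there were no L1 term, and the live weights w are re-derived from v by soft-thresholding.
// The method name SdcaL1UpdateDenseManaged is hypothetical and shown only for exposition.
public static void SdcaL1UpdateDenseManaged(float primalUpdate, int length, float[] src, float threshold, float[] v, float[] w)
{
    for (int i = 0; i < length; i++)
    {
        // Accumulate the un-thresholded (intermediate) weight.
        v[i] += src[i] * primalUpdate;

        // Soft-threshold: zero out small weights, shrink the rest toward zero by the threshold.
        w[i] = Math.Abs(v[i]) > threshold ? v[i] - Math.Sign(v[i]) * threshold : 0;
    }
}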