Ejemplo n.º 1
0
            /// <summary>
            /// Reads the label column from <paramref name="trans"/> and stores it in <c>_labels</c> as
            /// zero-based integer bin indices, setting <c>_numLabels</c> to the number of distinct bins.
            /// </summary>
            /// <param name="trans">Transposer from which the single label slot is read.</param>
            /// <param name="labelType">Type of the label column; selects the binning routine.</param>
            /// <param name="labelCol">Index of the label column inside <paramref name="trans"/>.</param>
            private void GetLabels(Transposer trans, ColumnType labelType, int labelCol)
            {
                var binned = default(VBuffer<int>);
                int lowest;
                int limit;

                // Note: NAs have their own separate bin.
                if (labelType == NumberType.I4)
                {
                    var raw = default(VBuffer<int>);
                    trans.GetSingleSlotValue(labelCol, ref raw);
                    BinInts(in raw, ref binned, _numBins, out lowest, out limit);
                }
                else if (labelType == NumberType.R4)
                {
                    var raw = default(VBuffer<Single>);
                    trans.GetSingleSlotValue(labelCol, ref raw);
                    BinSingles(in raw, ref binned, _numBins, out lowest, out limit);
                }
                else if (labelType == NumberType.R8)
                {
                    var raw = default(VBuffer<Double>);
                    trans.GetSingleSlotValue(labelCol, ref raw);
                    BinDoubles(in raw, ref binned, _numBins, out lowest, out limit);
                }
                else if (labelType is BoolType)
                {
                    var raw = default(VBuffer<bool>);
                    trans.GetSingleSlotValue(labelCol, ref raw);
                    BinBools(in raw, ref binned);

                    // Booleans bin into the range [-1, 2), i.e. three labels.
                    lowest = -1;
                    limit  = 2;
                }
                else
                {
                    // Any other type is handled as a key: the generic key-label getter is
                    // instantiated for the column's raw type through reflection.
                    ulong keyCount = labelType.GetKeyCount();
                    Contracts.Assert(keyCount < Utils.ArrayMaxSize);

                    KeyLabelGetter<int> template = GetKeyLabels<int>;
                    var genericDefinition = template.GetMethodInfo().GetGenericMethodDefinition();
                    var typedGetter       = genericDefinition.MakeGenericMethod(labelType.RawType);
                    var arguments         = new object[] { trans, labelCol, labelType };

                    _labels    = (VBuffer<int>)typedGetter.Invoke(this, arguments);
                    _numLabels = labelType.GetKeyCountAsInt32(_host) + 1;

                    // Key labels need neither densifying nor shifting.
                    return;
                }

                _numLabels = limit - lowest;

                // Densify, then shift every bin index so that the smallest one becomes zero.
                VBufferUtils.Densify(ref binned);
                Contracts.Assert(binned.IsDense);

                var editor = VBufferEditor.CreateFromBuffer(ref binned);
                var values = editor.Values;
                for (int slot = 0; slot < binned.Length; slot++)
                {
                    values[slot] -= lowest;
                    Contracts.Assert(values[slot] < _numLabels);
                }

                _labels = editor.Commit();
            }
Ejemplo n.º 2
0
        /// <summary>
        /// Trains a linear model by handing loaded data to the native SymSGD learner, either in a
        /// single call (when all data fits in memory) or batch by batch.
        /// </summary>
        /// <param name="ch">Channel used for the argument check and informational messages; its
        /// <c>Info</c> method is also handed to <c>Native.LearnAll</c>.</param>
        /// <param name="data">Training data; the vector size of its feature column determines the number of weights.</param>
        /// <param name="predictor">Optional model to start from. When non-null, its densified feature
        /// weights and its bias seed the training; otherwise a dense buffer of <c>numFeatures</c>
        /// entries is created and the bias starts at 0.</param>
        /// <param name="weightSetCount">Not read by this method.</param>
        /// <returns>The result of <c>CreatePredictor</c> applied to the trained weights and bias.</returns>
        private TPredictor TrainCore(IChannel ch, RoleMappedData data, LinearModelParameters predictor, int weightSetCount)
        {
            int numFeatures   = data.Schema.Feature.Value.Type.GetVectorSize();
            var cursorFactory = new FloatLabelCursor.Factory(data, CursOpt.Label | CursOpt.Features);

            // The thread count is fixed at 1 in this method, so the check below always passes and
            // the MapBackWeightVector step further down is never taken.
            int numThreads = 1;

            ch.CheckUserArg(numThreads > 0, nameof(_options.NumberOfThreads),
                            "The number of threads must be either null or a positive integer.");

            VBuffer<float> weights = default;
            float          bias    = 0.0f;

            if (predictor != null)
            {
                // Warm start from the supplied model.
                predictor.GetFeatureWeights(ref weights);
                VBufferUtils.Densify(ref weights);
                bias = predictor.Bias;
            }
            else
            {
                weights = VBufferUtils.CreateDense<float>(numFeatures);
            }

            // NOTE(review): the editor is never committed; the method relies on editor.Values
            // aliasing the storage of `weights`, which is what gets passed to CreatePredictor.
            var weightsEditor = VBufferEditor.CreateFromBuffer(ref weights);

            // Reference: Parasail. SymSGD.
            // A null option means "let the native learner tune it"; 1 is only the starting value.
            bool tuneLR = _options.LearningRate == null;
            var  lr     = _options.LearningRate ?? 1.0f;

            bool tuneNumLocIter = (_options.UpdateFrequency == null);
            var  numLocIter     = _options.UpdateFrequency ?? 1;

            var l2Const = _options.L2Regularization;
            var piw     = _options.PositiveInstanceWeight;

            // This is state of the learner that is shared with the native code.
            State    state         = new State();
            GCHandle stateGCHandle = default;

            try
            {
                // Pinned so that the handle given to the native calls stays valid for their duration.
                stateGCHandle = GCHandle.Alloc(state, GCHandleType.Pinned);

                state.TotalInstancesProcessed = 0;
                using (InputDataManager inputDataManager = new InputDataManager(this, cursorFactory, ch))
                {
                    // True only for the first LearnAll call.
                    bool shouldInitialize = true;
                    using (var pch = Host.StartProgressChannel("Preprocessing"))
                    {
                        inputDataManager.LoadAsMuchAsPossible();
                    }

                    int iter = 0;
                    if (inputDataManager.IsFullyLoaded)
                    {
                        ch.Info("Data fully loaded into memory.");
                    }
                    using (var pch = Host.StartProgressChannel("Training"))
                    {
                        if (inputDataManager.IsFullyLoaded)
                        {
                            pch.SetHeader(new ProgressHeader(new[] { "iterations" }),
                                          entry => entry.SetProgress(0, state.PassIteration, _options.NumberOfIterations));
                            // If fully loaded, call the SymSGDNative and do not come back until learned for all iterations.
                            Native.LearnAll(inputDataManager, tuneLR, ref lr, l2Const, piw, weightsEditor.Values, ref bias, numFeatures,
                                            _options.NumberOfIterations, numThreads, tuneNumLocIter, ref numLocIter, _options.Tolerance, _options.Shuffle, shouldInitialize,
                                            stateGCHandle, ch.Info);
                            shouldInitialize = false;
                        }
                        else
                        {
                            pch.SetHeader(new ProgressHeader(new[] { "iterations" }),
                                          entry => entry.SetProgress(0, iter, _options.NumberOfIterations));

                            // Since we loaded data in batch sizes, multiple passes over the loaded data is feasible.
                            int numPassesForABatch = inputDataManager.Count / 10000;
                            while (iter < _options.NumberOfIterations)
                            {
                                // We want to train on the final passes thoroughly (without learning on the same batch multiple times)
                                // This is for fine tuning the AUC. Experimentally, we found that 1 or 2 passes is enough
                                int numFinalPassesToTrainThoroughly = 2;
                                // We also do not want to learn for more passes than what the user asked
                                int numPassesForThisBatch = Math.Min(numPassesForABatch, _options.NumberOfIterations - iter - numFinalPassesToTrainThoroughly);
                                // If all of this leaves us with 0 passes, then set numPassesForThisBatch to 1
                                numPassesForThisBatch = Math.Max(1, numPassesForThisBatch);
                                state.PassIteration   = iter;
                                Native.LearnAll(inputDataManager, tuneLR, ref lr, l2Const, piw, weightsEditor.Values, ref bias, numFeatures,
                                                numPassesForThisBatch, numThreads, tuneNumLocIter, ref numLocIter, _options.Tolerance, _options.Shuffle, shouldInitialize,
                                                stateGCHandle, ch.Info);
                                shouldInitialize = false;

                                // Check if we are done with going through the data
                                if (inputDataManager.FinishedTheLoad)
                                {
                                    // The iteration counter only advances once the whole data set has been seen.
                                    iter += numPassesForThisBatch;
                                    // Check if more passes are left
                                    if (iter < _options.NumberOfIterations)
                                    {
                                        inputDataManager.RestartLoading(_options.Shuffle, Host);
                                    }
                                }

                                // If more passes are left, load as much as possible
                                if (iter < _options.NumberOfIterations)
                                {
                                    inputDataManager.LoadAsMuchAsPossible();
                                }
                            }
                        }

                        // Maps back the dense features that are mislocated
                        if (numThreads > 1)
                        {
                            Native.MapBackWeightVector(weightsEditor.Values, stateGCHandle);
                        }
                        // NOTE(review): this is skipped if anything above throws — confirm whether
                        // native-side state can be left allocated in that case.
                        Native.DeallocateSequentially(stateGCHandle);
                    }
                }
            }
            finally
            {
                // Always release the pin, including on failure paths.
                if (stateGCHandle.IsAllocated)
                {
                    stateGCHandle.Free();
                }
            }

            return CreatePredictor(weights, bias);
        }
Ejemplo n.º 3
0
        /// <inheritdoc/>
        /// <remarks>
        /// Makes one pass over the rows produced by <paramref name="cursorFactory"/>. For each row it
        /// updates the dual variable, weights and biases of every class other than the row's label
        /// (using a bounded number of compare-and-swap attempts per dual), and then applies the
        /// accumulated opposite update to the label class.
        /// </remarks>
        private protected override void TrainWithoutLock(IProgressChannelProvider progress, FloatLabelCursor.Factory cursorFactory, Random rand,
                                                         IdToIdxLookup idToIdx, int numThreads, DualsTableBase duals, Float[] biasReg, Float[] invariants, Float lambdaNInv,
                                                         VBuffer <Float>[] weights, Float[] biasUnreg, VBuffer <Float>[] l1IntermediateWeights, Float[] l1IntermediateBias, Float[] featureNormSquared)
        {
            // progress, idToIdx, invariants and featureNormSquared may be null; L1Threshold must be set.
            Contracts.AssertValueOrNull(progress);
            Contracts.Assert(Args.L1Threshold.HasValue);
            Contracts.AssertValueOrNull(idToIdx);
            Contracts.AssertValueOrNull(invariants);
            Contracts.AssertValueOrNull(featureNormSquared);
            // One weight vector per class; both bias arrays must have the same number of entries.
            int numClasses = Utils.Size(weights);

            Contracts.Assert(Utils.Size(biasReg) == numClasses);
            Contracts.Assert(Utils.Size(biasUnreg) == numClasses);

            // Upper bound on compare-and-swap attempts for a single (row, class) dual variable.
            int  maxUpdateTrials = 2 * numThreads;
            var  l1Threshold     = Args.L1Threshold.Value;
            bool l1ThresholdZero = l1Threshold == 0;
            // Factor applied to the regularized bias when computing the per-class adjustment below.
            var  lr = Args.BiasLearningRate * Args.L2Const.Value;

            // Progress reporting is optional; a null channel makes the outer using a no-op.
            var pch = progress != null?progress.StartProgressChannel("Dual update") : null;

            using (pch)
                using (var cursor = Args.Shuffle ? cursorFactory.Create(rand) : cursorFactory.Create())
                {
                    long rowCount = 0;
                    if (pch != null)
                    {
                        pch.SetHeader(new ProgressHeader("examples"), e => e.SetProgress(0, rowCount));
                    }

                    Func <RowId, long> getIndexFromId = GetIndexFromIdGetter(idToIdx, biasReg.Length);
                    while (cursor.MoveNext())
                    {
                        long  idx = getIndexFromId(cursor.Id);
                        // Duals are addressed as idx * numClasses + class.
                        long  dualIndexInitPos = idx * numClasses;
                        var   features         = cursor.Features;
                        // The label is cast to int and used as the class index into weights and biases.
                        var   label            = (int)cursor.Label;
                        Float invariant;
                        Float normSquared;
                        if (invariants != null)
                        {
                            // Per-row values were supplied by the caller; look them up.
                            invariant = invariants[idx];
                            Contracts.AssertValue(featureNormSquared);
                            normSquared = featureNormSquared[idx];
                        }
                        else
                        {
                            // Nothing precomputed: derive both values from the current row.
                            normSquared = VectorUtils.NormSquared(in features);
                            // NOTE(review): presumably accounts for the bias as an extra constant feature — confirm.
                            if (Args.BiasLearningRate == 0)
                            {
                                normSquared += 1;
                            }

                            invariant = _loss.ComputeDualUpdateInvariant(2 * normSquared * lambdaNInv * GetInstanceWeight(cursor));
                        }

                        // The output for the label class using current weights and bias.
                        var labelOutput    = WDot(in features, in weights[label], biasReg[label] + biasUnreg[label]);
                        var instanceWeight = GetInstanceWeight(cursor);

                        // This will be the new dual variable corresponding to the label class.
                        Float labelDual = 0;

                        // This will be used to update the weights and regularized bias corresponding to the label class.
                        Float labelPrimalUpdate = 0;

                        // This will be used to update the unregularized bias corresponding to the label class.
                        Float labelAdjustment = 0;

                        // Iterates through all classes.
                        for (int iClass = 0; iClass < numClasses; iClass++)
                        {
                            // Skip the dual/weights/bias update for label class. Will be taken care of at the end.
                            if (iClass == label)
                            {
                                continue;
                            }

                            var weightsEditor = VBufferEditor.CreateFromBuffer(ref weights[iClass]);
                            // Only created (and only used below) when the L1 threshold is non-zero.
                            var l1IntermediateWeightsEditor =
                                !l1ThresholdZero?VBufferEditor.CreateFromBuffer(ref l1IntermediateWeights[iClass]) :
                                    default;

                            // Loop trials for compare-and-swap updates of duals.
                            // In general, concurrent update conflict to the same dual variable is rare
                            // if data is shuffled.
                            // If every trial fails, this class receives no update for the current row.
                            for (int numTrials = 0; numTrials < maxUpdateTrials; numTrials++)
                            {
                                long dualIndex  = iClass + dualIndexInitPos;
                                var  dual       = duals[dualIndex];
                                // Margin between the label class (including updates accumulated so far for this row) and class iClass.
                                var  output     = labelOutput + labelPrimalUpdate * normSquared - WDot(in features, in weights[iClass], biasReg[iClass] + biasUnreg[iClass]);
                                var  dualUpdate = _loss.DualUpdate(output, 1, dual, invariant, numThreads);

                                // The successive over-relaxation approach to adjust the sum of dual variables (biasReg) to zero.
                                // Reference to details: http://stat.rutgers.edu/home/tzhang/papers/ml02_dual.pdf, pp. 16-17.
                                var adjustment = l1ThresholdZero ? lr * biasReg[iClass] : lr * l1IntermediateBias[iClass];
                                dualUpdate -= adjustment;
                                bool success = false;
                                // The new dual is stored only if the slot still holds the value read above;
                                // CompareExchange returns the previous content, so equality with `dual` means the swap happened.
                                duals.ApplyAt(dualIndex, (long index, ref Float value) =>
                                              success = Interlocked.CompareExchange(ref value, dual + dualUpdate, dual) == dual);

                                if (success)
                                {
                                    // Note: dualConstraint[iClass] = lambdaNInv * (sum of duals[iClass])
                                    var primalUpdate = dualUpdate * lambdaNInv * instanceWeight;
                                    // Accumulate the opposite-signed contributions that are applied to the label class after this loop.
                                    labelDual         -= dual + dualUpdate;
                                    labelPrimalUpdate += primalUpdate;
                                    biasUnreg[iClass] += adjustment * lambdaNInv * instanceWeight;
                                    labelAdjustment   -= adjustment;

                                    if (l1ThresholdZero)
                                    {
                                        // No L1: update the weights and regularized bias directly.
                                        VectorUtils.AddMult(in features, weightsEditor.Values, -primalUpdate);
                                        biasReg[iClass] -= primalUpdate;
                                    }
                                    else
                                    {
                                        //Iterative shrinkage-thresholding (aka. soft-thresholding)
                                        //Update v=denseWeights as if there's no L1
                                        //Thresholding: if |v[j]| < threshold, turn off weights[j]
                                        //If not, shrink: w[j] = v[j] - sign(v[j]) * threshold
                                        l1IntermediateBias[iClass] -= primalUpdate;
                                        // NOTE(review): here biasReg is refreshed only when BiasLearningRate == 0, whereas the
                                        // label-class path after the loop refreshes it unconditionally — confirm this is intended.
                                        if (Args.BiasLearningRate == 0)
                                        {
                                            biasReg[iClass] = Math.Abs(l1IntermediateBias[iClass]) - l1Threshold > 0.0
                                        ? l1IntermediateBias[iClass] - Math.Sign(l1IntermediateBias[iClass]) * l1Threshold
                                        : 0;
                                        }

                                        var featureValues = features.GetValues();
                                        if (features.IsDense)
                                        {
                                            CpuMathUtils.SdcaL1UpdateDense(-primalUpdate, featureValues.Length, featureValues, l1Threshold, l1IntermediateWeightsEditor.Values, weightsEditor.Values);
                                        }
                                        else if (featureValues.Length > 0)
                                        {
                                            CpuMathUtils.SdcaL1UpdateSparse(-primalUpdate, featureValues.Length, featureValues, features.GetIndices(), l1Threshold, l1IntermediateWeightsEditor.Values, weightsEditor.Values);
                                        }
                                    }

                                    // The dual was updated successfully; no further trials for this class.
                                    break;
                                }
                            }
                        }

                        // Updating with label class weights and dual variable.
                        // The label dual is written with a plain assignment, not a compare-and-swap.
                        duals[label + dualIndexInitPos] = labelDual;
                        biasUnreg[label] += labelAdjustment * lambdaNInv * instanceWeight;
                        if (l1ThresholdZero)
                        {
                            var weightsEditor = VBufferEditor.CreateFromBuffer(ref weights[label]);
                            VectorUtils.AddMult(in features, weightsEditor.Values, labelPrimalUpdate);
                            biasReg[label] += labelPrimalUpdate;
                        }
                        else
                        {
                            // Soft-threshold the label-class bias and weights, mirroring the per-class L1 branch above.
                            l1IntermediateBias[label] += labelPrimalUpdate;
                            var intermediateBias = l1IntermediateBias[label];
                            biasReg[label] = Math.Abs(intermediateBias) - l1Threshold > 0.0
                            ? intermediateBias - Math.Sign(intermediateBias) * l1Threshold
                            : 0;

                            var weightsEditor = VBufferEditor.CreateFromBuffer(ref weights[label]);
                            var l1IntermediateWeightsEditor = VBufferEditor.CreateFromBuffer(ref l1IntermediateWeights[label]);
                            var featureValues = features.GetValues();
                            if (features.IsDense)
                            {
                                CpuMathUtils.SdcaL1UpdateDense(labelPrimalUpdate, featureValues.Length, featureValues, l1Threshold, l1IntermediateWeightsEditor.Values, weightsEditor.Values);
                            }
                            else if (featureValues.Length > 0)
                            {
                                CpuMathUtils.SdcaL1UpdateSparse(labelPrimalUpdate, featureValues.Length, featureValues, features.GetIndices(), l1Threshold, l1IntermediateWeightsEditor.Values, weightsEditor.Values);
                            }
                        }

                        // Read by the progress callback registered above.
                        rowCount++;
                    }
                }
        }