C# (CSharp) VBufferUtils.Densify usage examples

Example #1
            protected TrainStateBase(IChannel ch, int numFeatures, LinearModelParameters predictor, OnlineLinearTrainer <TTransformer, TModel> parent)
            {
                Contracts.CheckValue(ch, nameof(ch));
                ch.Check(numFeatures > 0, "Cannot train with zero features!");
                ch.AssertValueOrNull(predictor);
                ch.AssertValue(parent);
                ch.Assert(Iteration == 0);
                ch.Assert(Bias == 0);

                ParentHost = parent.Host;

                ch.Trace("{0} Initializing {1} on {2} features", DateTime.UtcNow, parent.Name, numFeatures);

                // We want a dense vector to prevent memory allocation during training,
                // unless we have a lot of features.
                if (predictor != null)
                {
                    predictor.GetFeatureWeights(ref Weights);
                    VBufferUtils.Densify(ref Weights);
                    Bias = predictor.Bias;
                }
                else if (!string.IsNullOrWhiteSpace(parent.Args.InitialWeights))
                {
                    ch.Info("Initializing weights and bias to " + parent.Args.InitialWeights);
                    string[] weightStr = parent.Args.InitialWeights.Split(',');
                    if (weightStr.Length != numFeatures + 1)
                    {
                        throw ch.Except(
                                  "Could not initialize weights from 'initialWeights': expecting {0} values to initialize {1} weights and the intercept",
                                  numFeatures + 1, numFeatures);
                    }

                    var weightValues = new float[numFeatures];
                    for (int i = 0; i < numFeatures; i++)
                    {
                        weightValues[i] = float.Parse(weightStr[i], CultureInfo.InvariantCulture);
                    }
                    Weights = new VBuffer <float>(numFeatures, weightValues);
                    Bias    = float.Parse(weightStr[numFeatures], CultureInfo.InvariantCulture);
                }
                else if (parent.Args.InitWtsDiameter > 0)
                {
                    var weightValues = new float[numFeatures];
                    for (int i = 0; i < numFeatures; i++)
                    {
                        weightValues[i] = parent.Args.InitWtsDiameter * (parent.Host.Rand.NextSingle() - (float)0.5);
                    }
                    Weights = new VBuffer <float>(numFeatures, weightValues);
                    Bias    = parent.Args.InitWtsDiameter * (parent.Host.Rand.NextSingle() - (float)0.5);
                }
                else if (numFeatures <= 1000)
                {
                    Weights = VBufferUtils.CreateDense <float>(numFeatures);
                }
                else
                {
                    Weights = VBufferUtils.CreateEmpty <float>(numFeatures);
                }
                WeightsScale = 1;
            }
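The InitialWeights branch above parses a comma-separated string into per-feature weights plus a trailing intercept. A minimal standalone sketch of just that parsing step (the class and method names are illustrative, not part of ML.NET):

using System;
using System.Globalization;

static class InitialWeightsSketch
{
    // Parses "w0,w1,...,w(n-1),bias" using the invariant culture,
    // mirroring the parsing in the constructor above.
    public static (float[] Weights, float Bias) Parse(string initialWeights, int numFeatures)
    {
        string[] parts = initialWeights.Split(',');
        if (parts.Length != numFeatures + 1)
            throw new ArgumentException(
                $"Expected {numFeatures + 1} values: {numFeatures} weights plus the intercept.");

        var weights = new float[numFeatures];
        for (int i = 0; i < numFeatures; i++)
            weights[i] = float.Parse(parts[i], CultureInfo.InvariantCulture);
        float bias = float.Parse(parts[numFeatures], CultureInfo.InvariantCulture);
        return (weights, bias);
    }
}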
Example #2
            private protected AveragedTrainStateBase(IChannel ch, int numFeatures, LinearModelParameters predictor, AveragedLinearTrainer <TTransformer, TModel> parent)
                : base(ch, numFeatures, predictor, parent)
            {
                // Do the other initializations by setting the setters as if the user had set them.
                // Initialize the averaged weights if needed (i.e., do what happens when Averaged is set).
                Averaged = parent.Args.Averaged;
                if (Averaged)
                {
                    if (parent.Args.AveragedTolerance > 0)
                    {
                        VBufferUtils.Densify(ref Weights);
                    }
                    Weights.CopyTo(ref TotalWeights);
                }
                else
                {
                    // It is definitely advantageous to keep weights dense if we aren't adding them
                    // to another vector with each update.
                    VBufferUtils.Densify(ref Weights);
                }
                _resetWeightsAfterXExamples = parent.Args.ResetWeightsAfterXExamples ?? 0;
                _args = parent.Args;
                _loss = parent.LossFunction;

                Gain = 1;
            }
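When Averaged is set, averaged trainers keep adding the current weights into the running TotalWeights, so Weights is densified up front (here when AveragedTolerance > 0) to make those repeated vector additions cheap. A sketch of the final averaging step under that assumption, illustrative only and not the actual ML.NET update:

static class AveragingSketch
{
    // Averages dense weight snapshots the way a running TotalWeights
    // would: one dense add per update, one divide at the end.
    // Assumes at least one snapshot, all of equal length.
    public static float[] Average(float[][] weightSnapshots)
    {
        int n = weightSnapshots[0].Length;
        var total = new float[n];
        foreach (var w in weightSnapshots)
            for (int i = 0; i < n; i++)
                total[i] += w[i];

        var avg = new float[n];
        for (int i = 0; i < n; i++)
            avg[i] = total[i] / weightSnapshots.Length;
        return avg;
    }
}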
Example #3
        protected override void InitCore(IChannel ch, int numFeatures, LinearPredictor predictor)
        {
            base.InitCore(ch, numFeatures, predictor);

            // Verify user didn't specify parameters that conflict
            Contracts.Check(!Args.DoLazyUpdates || !Args.RecencyGainMulti && Args.RecencyGain == 0,
                            "Cannot have both recency gain and lazy updates.");

            // Do the other initializations by setting the setters as if the user had set them.
            // Initialize the averaged weights if needed (i.e., do what happens when Averaged is set).
            if (Args.Averaged)
            {
                if (Args.AveragedTolerance > 0)
                {
                    VBufferUtils.Densify(ref Weights);
                }
                Weights.CopyTo(ref TotalWeights);
            }
            else
            {
                // It is definitely advantageous to keep weights dense if we aren't adding them
                // to another vector with each update.
                VBufferUtils.Densify(ref Weights);
            }
            Gain = 1;
        }
Example #4
        protected Float DifferentiableFunctionStream(FloatLabelCursor.Factory cursorFactory, ref VBuffer <Float> xDense, ref VBuffer <Float> grad, IProgressChannel pch)
        {
            Contracts.AssertValue(cursorFactory);

            VBufferUtils.Clear(ref grad);
            VBufferUtils.Densify(ref grad);

            Float[] scratch = null;
            double  loss    = 0;
            long    count   = 0;

            if (pch != null)
            {
                pch.SetHeader(new ProgressHeader(null, new[] { "examples" }), e => e.SetProgress(0, count));
            }
            using (var cursor = cursorFactory.Create())
            {
                while (cursor.MoveNext())
                {
                    loss += AccumulateOneGradient(ref cursor.Features, cursor.Label, cursor.Weight,
                                                  ref xDense, ref grad, ref scratch);
                    count++;
                }
            }

            // Accumulate the loss in a double to avoid roundoff error
            // (see http://mathworld.wolfram.com/RoundoffError.html for a definition),
            // then convert back to float to match the function's Float return type.
            return((Float)loss);
        }
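The comment above the return is worth demonstrating: once a float accumulator grows large, each small addend loses low-order bits, so the loss is accumulated in a double and only cast to Float at the end. A small self-contained demonstration:

using System;

static class RoundoffDemo
{
    public static void Main()
    {
        float floatSum = 0f;
        double doubleSum = 0.0;
        for (int i = 0; i < 10_000_000; i++)
        {
            floatSum += 0.1f;    // float accumulator drifts as the sum grows
            doubleSum += 0.1f;   // double accumulator stays close
        }
        Console.WriteLine(floatSum);   // noticeably off from 1,000,000
        Console.WriteLine(doubleSum);  // ~1,000,000, up to 0.1f's own representation error
    }
}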
Example #5
        protected Float DifferentiableFunctionComputeChunk(int ichk, ref VBuffer <Float> xDense, ref VBuffer <Float> grad, IProgressChannel pch)
        {
            Contracts.Assert(0 <= ichk && ichk < _numChunks);
            Contracts.AssertValueOrNull(pch);

            VBufferUtils.Clear(ref grad);
            VBufferUtils.Densify(ref grad);

            Float[] scratch = null;
            double  loss    = 0;
            int     ivMin   = _ranges[ichk];
            int     ivLim   = _ranges[ichk + 1];
            int     iv      = ivMin;

            if (pch != null)
            {
                pch.SetHeader(new ProgressHeader(null, new[] { "examples" }), e => e.SetProgress(0, iv - ivMin, ivLim - ivMin));
            }
            for (iv = ivMin; iv < ivLim; iv++)
            {
                Float weight = _weights != null ? _weights[iv] : 1;
                loss += AccumulateOneGradient(ref _features[iv], _labels[iv], weight, ref xDense, ref grad, ref scratch);
            }
            // Accumulate the loss in a double to avoid roundoff error
            // (see http://mathworld.wolfram.com/RoundoffError.html for a definition),
            // then convert back to float to match the function's Float return type.
            return((Float)loss);
        }
Example #6
        protected virtual void InitCore(IChannel ch, int numFeatures, LinearPredictor predictor)
        {
            Contracts.Check(numFeatures > 0, "Can't train with zero features!");
            Contracts.Check(NumFeatures == 0, "Can't re-use trainer!");
            Contracts.Assert(Iteration == 0);
            Contracts.Assert(Bias == 0);

            ch.Trace("{0} Initializing {1} on {2} features", DateTime.UtcNow, Name, numFeatures);
            NumFeatures = numFeatures;

            // We want a dense vector to prevent memory allocation during training,
            // unless we have a lot of features.
            // REVIEW: make this a setting.
            if (predictor != null)
            {
                predictor.GetFeatureWeights(ref Weights);
                VBufferUtils.Densify(ref Weights);
                Bias = predictor.Bias;
            }
            else if (!string.IsNullOrWhiteSpace(Args.InitialWeights))
            {
                ch.Info("Initializing weights and bias to " + Args.InitialWeights);
                string[] weightStr = Args.InitialWeights.Split(',');
                if (weightStr.Length != NumFeatures + 1)
                {
                    throw Contracts.Except(
                              "Could not initialize weights from 'initialWeights': expecting {0} values to initialize {1} weights and the intercept",
                              NumFeatures + 1, NumFeatures);
                }

                Weights = VBufferUtils.CreateDense <Float>(NumFeatures);
                for (int i = 0; i < NumFeatures; i++)
                {
                    Weights.Values[i] = Float.Parse(weightStr[i], CultureInfo.InvariantCulture);
                }
                Bias = Float.Parse(weightStr[NumFeatures], CultureInfo.InvariantCulture);
            }
            else if (Args.InitWtsDiameter > 0)
            {
                Weights = VBufferUtils.CreateDense <Float>(NumFeatures);
                for (int i = 0; i < NumFeatures; i++)
                {
                    Weights.Values[i] = Args.InitWtsDiameter * (Host.Rand.NextSingle() - (Float)0.5);
                }
                Bias = Args.InitWtsDiameter * (Host.Rand.NextSingle() - (Float)0.5);
            }
            else if (NumFeatures <= 1000)
            {
                Weights = VBufferUtils.CreateDense <Float>(NumFeatures);
            }
            else
            {
                Weights = VBufferUtils.CreateEmpty <Float>(NumFeatures);
            }
            WeightsScale = 1;
        }
Example #7
            private VBuffer <int> GetKeyLabels <T>(Transposer trans, int labelCol, DataViewType labelColumnType)
            {
                var tmp    = default(VBuffer <T>);
                var labels = default(VBuffer <int>);

                trans.GetSingleSlotValue(labelCol, ref tmp);
                BinKeys <T>(labelColumnType)(in tmp, ref labels);
                VBufferUtils.Densify(ref labels);
                return(labels);
            }
        internal static bool TryGetCategoricalFeatureIndices(Schema schema, int colIndex, out int[] categoricalFeatures)
        {
            Contracts.CheckValue(schema, nameof(schema));
            Contracts.Check(colIndex >= 0, nameof(colIndex));

            bool isValid = false;

            categoricalFeatures = null;
            if (!(schema[colIndex].Type is VectorType vecType && vecType.Size > 0))
            {
                return(isValid);
            }

            var type = schema[colIndex].Metadata.Schema.GetColumnOrNull(Kinds.CategoricalSlotRanges)?.Type;

            if (type?.RawType == typeof(VBuffer <int>))
            {
                VBuffer <int> catIndices = default(VBuffer <int>);
                schema[colIndex].Metadata.GetValue(Kinds.CategoricalSlotRanges, ref catIndices);
                VBufferUtils.Densify(ref catIndices);
                int columnSlotsCount = vecType.Size;
                if (catIndices.Length > 0 && catIndices.Length % 2 == 0 && catIndices.Length <= columnSlotsCount * 2)
                {
                    int previousEndIndex = -1;
                    isValid = true;
                    var catIndicesValues = catIndices.GetValues();
                    for (int i = 0; i < catIndicesValues.Length; i += 2)
                    {
                        if (catIndicesValues[i] > catIndicesValues[i + 1] ||
                            catIndicesValues[i] <= previousEndIndex ||
                            catIndicesValues[i] >= columnSlotsCount ||
                            catIndicesValues[i + 1] >= columnSlotsCount)
                        {
                            isValid = false;
                            break;
                        }

                        previousEndIndex = catIndicesValues[i + 1];
                    }
                    if (isValid)
                    {
                        categoricalFeatures = catIndicesValues.ToArray();
                    }
                }
            }

            return(isValid);
        }
Example #9
        protected override void InitCore(IChannel ch, int numFeatures, LinearPredictor predictor)
        {
            base.InitCore(ch, numFeatures, predictor);

            if (Args.NoBias)
            {
                Bias = 0;
            }

            if (predictor == null)
            {
                VBufferUtils.Densify(ref Weights);
            }

            _weightsUpdate = VBufferUtils.CreateEmpty <Float>(numFeatures);
        }
Example #10
            private KeyToValueMap GetKeyMetadata <TKey, TValue>(int iinfo, ColumnType typeKey, ColumnType typeVal)
            {
                Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length);
                Host.AssertValue(typeKey);
                Host.AssertValue(typeVal);
                Host.Assert(typeKey.ItemType.RawType == typeof(TKey));
                Host.Assert(typeVal.ItemType.RawType == typeof(TValue));

                var keyMetadata = default(VBuffer <TValue>);

                InputSchema[ColMapNewToOld[iinfo]].Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref keyMetadata);
                Host.Check(keyMetadata.Length == typeKey.ItemType.KeyCount);

                VBufferUtils.Densify(ref keyMetadata);
                return(new KeyToValueMap <TKey, TValue>(this, typeKey.ItemType.AsKey, typeVal.ItemType.AsPrimitive, keyMetadata.Values, iinfo));
            }
Example #11
        private KeyToValueMap GetKeyMetadata <TKey, TValue>(int iinfo, ColumnType typeKey, ColumnType typeVal)
        {
            Host.Assert(0 <= iinfo && iinfo < Infos.Length);
            Host.AssertValue(typeKey);
            Host.AssertValue(typeVal);
            Host.Assert(typeKey.ItemType.RawType == typeof(TKey));
            Host.Assert(typeVal.ItemType.RawType == typeof(TValue));

            var keyMetadata = default(VBuffer <TValue>);

            Source.Schema.GetMetadata <VBuffer <TValue> >(MetadataUtils.Kinds.KeyValues, Infos[iinfo].Source, ref keyMetadata);
            Host.Check(keyMetadata.Length == typeKey.ItemType.KeyCount);

            VBufferUtils.Densify <TValue>(ref keyMetadata);
            return(new KeyToValueMap <TKey, TValue>(this, typeKey.ItemType.AsKey, typeVal.ItemType.AsPrimitive, keyMetadata.Values, iinfo));
        }
Example #12
            private int[] GetKeyLabels <T>(Transposer trans, int labelCol, ColumnType labelColumnType)
            {
                var tmp    = default(VBuffer <T>);
                var labels = default(VBuffer <int>);

                trans.GetSingleSlotValue(labelCol, ref tmp);
                BinKeys <T>(labelColumnType)(ref tmp, ref labels);
                VBufferUtils.Densify(ref labels);
                var values = labels.Values;

                if (labels.Length < values.Length)
                {
                    Array.Resize(ref values, labels.Length);
                }
                return(values);
            }
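Note the trim at the end: with the old VBuffer API, labels.Values could be longer than labels.Length, so the backing array is cut down with Array.Resize before it is handed out. A tiny sketch of that trim in isolation:

using System;

static class TrimSketch
{
    // Trims a backing array that may be longer than the buffer's
    // logical length, as with the old VBuffer<T>.Values property.
    public static int[] TrimToLength(int[] values, int logicalLength)
    {
        if (logicalLength < values.Length)
            Array.Resize(ref values, logicalLength);
        return values;
    }
}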
Example #13
        /// <summary>
        /// categoricalFeatures is a vector of the indices of the categorical feature slots.
        /// The vector should always have an even number of elements, read in groups of two consecutive numbers.
        /// For example, the values 0,2,3,4,8,9
        /// should be read as the pairs [0,2],[3,4],[8,9].
        /// That means features with indices 0, 1, and 2 form one categorical,
        /// and features with indices 3 and 4 form another. Features 5 through 7 do not appear, so they are not categorical.
        /// </summary>
        public static bool TryGetCategoricalFeatureIndices(ISchema schema, int colIndex, out int[] categoricalFeatures)
        {
            Contracts.CheckValue(schema, nameof(schema));
            Contracts.Check(colIndex >= 0, nameof(colIndex));

            bool isValid = false;

            categoricalFeatures = null;
            if (!schema.GetColumnType(colIndex).IsKnownSizeVector)
            {
                return(isValid);
            }

            var type = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.CategoricalSlotRanges, colIndex);

            if (type?.RawType == typeof(VBuffer <int>))
            {
                VBuffer <int> catIndices = default(VBuffer <int>);
                schema.GetMetadata(MetadataUtils.Kinds.CategoricalSlotRanges, colIndex, ref catIndices);
                VBufferUtils.Densify(ref catIndices);
                int columnSlotsCount = schema.GetColumnType(colIndex).AsVector.VectorSizeCore;
                if (catIndices.Length > 0 && catIndices.Length % 2 == 0 && catIndices.Length <= columnSlotsCount * 2)
                {
                    int previousEndIndex = -1;
                    isValid = true;
                    for (int i = 0; i < catIndices.Values.Length; i += 2)
                    {
                        if (catIndices.Values[i] > catIndices.Values[i + 1] ||
                            catIndices.Values[i] <= previousEndIndex ||
                            catIndices.Values[i] >= columnSlotsCount ||
                            catIndices.Values[i + 1] >= columnSlotsCount)
                        {
                            isValid = false;
                            break;
                        }

                        previousEndIndex = catIndices.Values[i + 1];
                    }
                    if (isValid)
                    {
                        categoricalFeatures = catIndices.Values.Select(val => val).ToArray();
                    }
                }
            }

            return(isValid);
        }
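Per the summary above, the flattened metadata comes in [begin, end] pairs. A small illustrative helper (not part of ML.NET) that expands each pair into the slot indices of its categorical:

using System;

static class CategoricalRangesSketch
{
    // ExpandRanges(new[] { 0, 2, 3, 4, 8, 9 }) yields {0,1,2}, {3,4}, {8,9}.
    public static int[][] ExpandRanges(int[] flat)
    {
        if (flat.Length % 2 != 0)
            throw new ArgumentException("Ranges must come in [begin, end] pairs.");

        var result = new int[flat.Length / 2][];
        for (int i = 0; i < flat.Length; i += 2)
        {
            int begin = flat[i], end = flat[i + 1];
            var slots = new int[end - begin + 1];
            for (int j = 0; j < slots.Length; j++)
                slots[j] = begin + j;
            result[i / 2] = slots;
        }
        return result;
    }
}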
Example #14
        private T[] GetValuesArray <T>(VBuffer <T> src, ColumnType srcType, int iinfo)
        {
            Host.Assert(srcType.IsVector);
            Host.Assert(srcType.VectorSize == src.Length);
            VBufferUtils.Densify <T>(ref src);
            RefPredicate <T> defaultPred = Conversions.Instance.GetIsDefaultPredicate <T>(srcType.ItemType);

            _repIsDefault[iinfo] = new BitArray(srcType.VectorSize);
            for (int slot = 0; slot < src.Length; slot++)
            {
                if (defaultPred(ref src.Values[slot]))
                {
                    _repIsDefault[iinfo][slot] = true;
                }
            }
            T[] valReturn = src.Values;
            Array.Resize <T>(ref valReturn, srcType.VectorSize);
            Host.Assert(valReturn.Length == src.Length);
            return(valReturn);
        }
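GetValuesArray records which slots hold the type's default value in a BitArray so later per-slot lookups are cheap. A simplified sketch of that bookkeeping, with EqualityComparer<T>.Default standing in for Conversions.Instance.GetIsDefaultPredicate<T> (an assumption; the real predicate may define "default" differently for some types):

using System.Collections;
using System.Collections.Generic;

static class DefaultSlotsSketch
{
    // Returns a BitArray with bit i set when values[i] is default(T).
    public static BitArray FlagDefaults<T>(T[] values)
    {
        var isDefault = new BitArray(values.Length);
        var cmp = EqualityComparer<T>.Default;
        for (int slot = 0; slot < values.Length; slot++)
        {
            if (cmp.Equals(values[slot], default(T)))
                isDefault[slot] = true;
        }
        return isDefault;
    }
}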
Example #15
            public TrainState(IChannel ch, int numFeatures, LinearPredictor predictor, LinearSvm parent)
                : base(ch, numFeatures, predictor, parent)
            {
                _batchSize         = parent.Args.BatchSize;
                _noBias            = parent.Args.NoBias;
                _performProjection = parent.Args.PerformProjection;
                _lambda            = parent.Args.Lambda;

                if (_noBias)
                {
                    Bias = 0;
                }

                if (predictor == null)
                {
                    VBufferUtils.Densify(ref Weights);
                }

                _weightsUpdate = VBufferUtils.CreateEmpty <Float>(numFeatures);
            }
Example #16
            private void GetLabels(Transposer trans, DataViewType labelType, int labelCol)
            {
                int min;
                int lim;
                var labels = default(VBuffer <int>);

                // Note: NAs have their own separate bin.
                if (labelType == NumberDataViewType.Int32)
                {
                    var tmp = default(VBuffer <int>);
                    trans.GetSingleSlotValue(labelCol, ref tmp);
                    BinInts(in tmp, ref labels, _numBins, out min, out lim);
                    _numLabels = lim - min;
                }
                else if (labelType == NumberDataViewType.Single)
                {
                    var tmp = default(VBuffer <Single>);
                    trans.GetSingleSlotValue(labelCol, ref tmp);
                    BinSingles(in tmp, ref labels, _numBins, out min, out lim);
                    _numLabels = lim - min;
                }
                else if (labelType == NumberDataViewType.Double)
                {
                    var tmp = default(VBuffer <Double>);
                    trans.GetSingleSlotValue(labelCol, ref tmp);
                    BinDoubles(in tmp, ref labels, _numBins, out min, out lim);
                    _numLabels = lim - min;
                }
                else if (labelType is BooleanDataViewType)
                {
                    var tmp = default(VBuffer <bool>);
                    trans.GetSingleSlotValue(labelCol, ref tmp);
                    BinBools(in tmp, ref labels);
                    _numLabels = 3;
                    min        = -1;
                    lim        = 2;
                }
                else
                {
                    ulong labelKeyCount = labelType.GetKeyCount();
                    Contracts.Assert(labelKeyCount < Utils.ArrayMaxSize);
                    KeyLabelGetter <int> del = GetKeyLabels <int>;
                    var methodInfo           = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(labelType.RawType);
                    var parameters           = new object[] { trans, labelCol, labelType };
                    _labels    = (VBuffer <int>)methodInfo.Invoke(this, parameters);
                    _numLabels = labelType.GetKeyCountAsInt32(_host) + 1;

                    // No need to densify or shift in this case.
                    return;
                }

                // Densify and shift labels.
                VBufferUtils.Densify(ref labels);
                Contracts.Assert(labels.IsDense);
                var labelsEditor = VBufferEditor.CreateFromBuffer(ref labels);

                for (int i = 0; i < labels.Length; i++)
                {
                    labelsEditor.Values[i] -= min;
                    Contracts.Assert(labelsEditor.Values[i] < _numLabels);
                }
                _labels = labelsEditor.Commit();
            }
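The key-label branch at the end of GetLabels only learns the label's raw type at runtime, so it rebinds a generic method with GetGenericMethodDefinition and MakeGenericMethod before invoking it. A minimal self-contained sketch of that reflection pattern:

using System;
using System.Reflection;

static class GenericDispatchSketch
{
    private delegate string Formatter<T>(T value);

    private static string Format<T>(T value) => $"{typeof(T).Name}: {value}";

    // Calls Format<T> with T chosen at runtime, mirroring how GetLabels
    // rebinds KeyLabelGetter<int> to the label column's raw type.
    public static string FormatRuntime(object value, Type runtimeType)
    {
        Formatter<int> del = Format<int>;   // any instantiation gives us the MethodInfo
        MethodInfo generic = del.GetMethodInfo().GetGenericMethodDefinition();
        MethodInfo bound = generic.MakeGenericMethod(runtimeType);
        return (string)bound.Invoke(null, new object[] { value });
    }
}

For example, FormatRuntime(3.5, typeof(double)) returns "Double: 3.5".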
Example #17
        private TPredictor TrainCore(IChannel ch, RoleMappedData data, LinearModelParameters predictor, int weightSetCount)
        {
            int numFeatures   = data.Schema.Feature.Value.Type.GetVectorSize();
            var cursorFactory = new FloatLabelCursor.Factory(data, CursOpt.Label | CursOpt.Features);
            int numThreads    = 1;

            ch.CheckUserArg(numThreads > 0, nameof(_options.NumberOfThreads),
                            "The number of threads must be either null or a positive integer.");

            var             positiveInstanceWeight = _options.PositiveInstanceWeight;
            VBuffer <float> weights = default;
            float           bias    = 0.0f;

            if (predictor != null)
            {
                predictor.GetFeatureWeights(ref weights);
                VBufferUtils.Densify(ref weights);
                bias = predictor.Bias;
            }
            else
            {
                weights = VBufferUtils.CreateDense <float>(numFeatures);
            }

            var weightsEditor = VBufferEditor.CreateFromBuffer(ref weights);

            // Reference: Parasail. SymSGD.
            bool tuneLR = _options.LearningRate == null;
            var  lr     = _options.LearningRate ?? 1.0f;

            bool tuneNumLocIter = (_options.UpdateFrequency == null);
            var  numLocIter     = _options.UpdateFrequency ?? 1;

            var l2Const = _options.L2Regularization;
            var piw     = _options.PositiveInstanceWeight;

            // This is state of the learner that is shared with the native code.
            State    state         = new State();
            GCHandle stateGCHandle = default;

            try
            {
                stateGCHandle = GCHandle.Alloc(state, GCHandleType.Pinned);

                state.TotalInstancesProcessed = 0;
                using (InputDataManager inputDataManager = new InputDataManager(this, cursorFactory, ch))
                {
                    bool shouldInitialize = true;
                    using (var pch = Host.StartProgressChannel("Preprocessing"))
                        inputDataManager.LoadAsMuchAsPossible();

                    int iter = 0;
                    if (inputDataManager.IsFullyLoaded)
                    {
                        ch.Info("Data fully loaded into memory.");
                    }
                    using (var pch = Host.StartProgressChannel("Training"))
                    {
                        if (inputDataManager.IsFullyLoaded)
                        {
                            pch.SetHeader(new ProgressHeader(new[] { "iterations" }),
                                          entry => entry.SetProgress(0, state.PassIteration, _options.NumberOfIterations));
                            // If fully loaded, call the SymSGDNative and do not come back until learned for all iterations.
                            Native.LearnAll(inputDataManager, tuneLR, ref lr, l2Const, piw, weightsEditor.Values, ref bias, numFeatures,
                                            _options.NumberOfIterations, numThreads, tuneNumLocIter, ref numLocIter, _options.Tolerance, _options.Shuffle, shouldInitialize,
                                            stateGCHandle, ch.Info);
                            shouldInitialize = false;
                        }
                        else
                        {
                            pch.SetHeader(new ProgressHeader(new[] { "iterations" }),
                                          entry => entry.SetProgress(0, iter, _options.NumberOfIterations));

                            // Since we loaded data in batch sizes, multiple passes over the loaded data is feasible.
                            int numPassesForABatch = inputDataManager.Count / 10000;
                            while (iter < _options.NumberOfIterations)
                            {
                                // We want to train the final passes thoroughly (without learning on the same batch multiple times).
                                // This is for fine-tuning the AUC. Experimentally, we found that 1 or 2 passes are enough.
                                int numFinalPassesToTrainThoroughly = 2;
                                // We also do not want to learn for more passes than what the user asked
                                int numPassesForThisBatch = Math.Min(numPassesForABatch, _options.NumberOfIterations - iter - numFinalPassesToTrainThoroughly);
                                // If all of this leaves us with 0 passes, then set numPassesForThisBatch to 1
                                numPassesForThisBatch = Math.Max(1, numPassesForThisBatch);
                                state.PassIteration   = iter;
                                Native.LearnAll(inputDataManager, tuneLR, ref lr, l2Const, piw, weightsEditor.Values, ref bias, numFeatures,
                                                numPassesForThisBatch, numThreads, tuneNumLocIter, ref numLocIter, _options.Tolerance, _options.Shuffle, shouldInitialize,
                                                stateGCHandle, ch.Info);
                                shouldInitialize = false;

                                // Check if we are done with going through the data
                                if (inputDataManager.FinishedTheLoad)
                                {
                                    iter += numPassesForThisBatch;
                                    // Check if more passes are left
                                    if (iter < _options.NumberOfIterations)
                                    {
                                        inputDataManager.RestartLoading(_options.Shuffle, Host);
                                    }
                                }

                                // If more passes are left, load as much as possible
                                if (iter < _options.NumberOfIterations)
                                {
                                    inputDataManager.LoadAsMuchAsPossible();
                                }
                            }
                        }

                        // Maps back the dense features that are mislocated
                        if (numThreads > 1)
                        {
                            Native.MapBackWeightVector(weightsEditor.Values, stateGCHandle);
                        }
                        Native.DeallocateSequentially(stateGCHandle);
                    }
                }
            }
            finally
            {
                if (stateGCHandle.IsAllocated)
                {
                    stateGCHandle.Free();
                }
            }
            return(CreatePredictor(weights, bias));
        }
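TrainCore shares the State object with native code by pinning it (GCHandle.Alloc with GCHandleType.Pinned) so the GC cannot move it while native code holds its address, and it frees the handle in a finally block so the pin cannot leak on failure. A minimal sketch of that pattern, with a plain array standing in for State:

using System;
using System.Runtime.InteropServices;

static class PinningSketch
{
    public static void Main()
    {
        var buffer = new double[1024];   // stands in for the shared State object
        GCHandle handle = default;
        try
        {
            // Pin so the address stays valid for the native side.
            handle = GCHandle.Alloc(buffer, GCHandleType.Pinned);
            IntPtr ptr = handle.AddrOfPinnedObject();
            // ... pass ptr to native code here ...
        }
        finally
        {
            if (handle.IsAllocated)
                handle.Free();           // always release, as TrainCore does
        }
    }
}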
Example #18
            private void GetLabels(Transposer trans, ColumnType labelType, int labelCol)
            {
                int min;
                int lim;
                var labels = default(VBuffer <int>);

                // Note: NAs have their own separate bin.
                if (labelType == NumberType.I4)
                {
                    var tmp = default(VBuffer <DvInt4>);
                    trans.GetSingleSlotValue(labelCol, ref tmp);
                    BinInts(ref tmp, ref labels, _numBins, out min, out lim);
                    _numLabels = lim - min;
                }
                else if (labelType == NumberType.R4)
                {
                    var tmp = default(VBuffer <Single>);
                    trans.GetSingleSlotValue(labelCol, ref tmp);
                    BinSingles(ref tmp, ref labels, _numBins, out min, out lim);
                    _numLabels = lim - min;
                }
                else if (labelType == NumberType.R8)
                {
                    var tmp = default(VBuffer <Double>);
                    trans.GetSingleSlotValue(labelCol, ref tmp);
                    BinDoubles(ref tmp, ref labels, _numBins, out min, out lim);
                    _numLabels = lim - min;
                }
                else if (labelType.IsBool)
                {
                    var tmp = default(VBuffer <DvBool>);
                    trans.GetSingleSlotValue(labelCol, ref tmp);
                    BinBools(ref tmp, ref labels);
                    _numLabels = 3;
                    min        = -1;
                    lim        = 2;
                }
                else
                {
                    Contracts.Assert(0 < labelType.KeyCount && labelType.KeyCount < Utils.ArrayMaxSize);
                    KeyLabelGetter <int> del = GetKeyLabels <int>;
                    var methodInfo           = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(labelType.RawType);
                    var parameters           = new object[] { trans, labelCol, labelType };
                    _labels    = (int[])methodInfo.Invoke(this, parameters);
                    _numLabels = labelType.KeyCount + 1;

                    // No need to densify or shift in this case.
                    return;
                }

                // Densify and shift labels.
                VBufferUtils.Densify(ref labels);
                Contracts.Assert(labels.IsDense);
                _labels = labels.Values;
                if (labels.Length < _labels.Length)
                {
                    Array.Resize(ref _labels, labels.Length);
                }
                for (int i = 0; i < _labels.Length; i++)
                {
                    _labels[i] -= min;
                    Contracts.Assert(_labels[i] < _numLabels);
                }
            }