/// <summary>
        /// Fill a sparse DMatrix using CSR (compressed sparse row) compression.
        /// See http://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.html.
        /// Rows whose label is NaN are skipped.
        /// </summary>
        /// <param name="ch">Channel used to report data-validation failures.</param>
        /// <param name="nbDim">Number of feature columns of the resulting matrix.</param>
        /// <param name="nbRows">Number of rows used to size the per-row buffers; the cursor must not yield more rows than this
        /// (presumably the row count of <paramref name="data"/> — confirm at the caller).</param>
        /// <param name="data">Training data providing features, label and optional weight/group columns.</param>
        /// <param name="labels">Receives the labels of the kept rows; entries past the number of kept rows are not meaningful.</param>
        /// <param name="groupCount">Receives the length of each run of consecutive rows sharing a group id
        /// (unused trailing entries are 0), or null when the data has no group column.</param>
        /// <returns>The CSR-backed <see cref="DMatrix"/>.</returns>
        private DMatrix FillSparseMatrix(IChannel ch, int nbDim, long nbRows, RoleMappedData data,
                                         out Float[] labels, out uint[] groupCount)
        {
            // Allocation: start with room for two stored values per row and grow on demand.
            if ((2 * nbRows) >= Utils.ArrayMaxSize)
            {
                throw _host.Except("The training dataset is too big to hold in memory. " +
                                   "2 features multiplied by the number of rows must be less than {0}.", Utils.ArrayMaxSize);
            }

            var  features = new Float[nbRows * 2];
            var  indices  = new uint[features.Length];
            var  indptr   = new ulong[nbRows + 1];
            long nelem    = 0;

            labels = new Float[nbRows];
            var hasWeights = data.Schema.Weight != null;
            var hasGroup   = data.Schema.Group != null;
            var weights    = hasWeights ? new Float[nbRows] : null;
            var groupsML   = hasGroup ? new uint[nbRows] : null;

            groupCount = hasGroup ? new uint[nbRows] : null;
            var groupId = hasGroup ? new HashSet<uint>() : null;

            int count     = 0;
            int lastGroup = -1;
            var flags     = CursOpt.Features | CursOpt.Label | CursOpt.AllowBadEverything | CursOpt.Weight | CursOpt.Group;

            var featureVector = default(VBuffer<float>);
            var labelProxy    = float.NaN;
            var groupProxy    = ulong.MaxValue;

            using (var cursor = data.CreateRowCursor(flags, null))
            {
                var featureGetter = cursor.GetFeatureFloatVectorGetter(data);
                var labelGetter   = cursor.GetLabelFloatGetter(data);
                var weighGetter   = cursor.GetOptWeightFloatGetter(data);
                var groupGetter   = cursor.GetOptGroupGetter(data);
                while (cursor.MoveNext())
                {
                    featureGetter(ref featureVector);
                    labelGetter(ref labelProxy);
                    labels[count] = labelProxy;
                    // Unlabeled rows are dropped; this slot is reused by the next row.
                    if (Single.IsNaN(labels[count]))
                    {
                        continue;
                    }

                    // CSR row pointer: offset of this row's first stored value.
                    indptr[count] = (ulong)nelem;
                    int nbValues = featureVector.Count;
                    if (nbValues > 0)
                    {
                        long required = nelem + nbValues;
                        if (required > features.Length)
                        {
                            if (required >= Utils.ArrayMaxSize)
                            {
                                throw _host.Except("The training dataset is too big to hold in memory. " +
                                                   "It should be half of {0}.", Utils.ArrayMaxSize);
                            }
                            // Grow geometrically (2L keeps the product in long arithmetic), but cap the
                            // capacity so doubling alone never rejects data that would still fit.
                            long newSize = Math.Min(Math.Max(required, 2L * features.Length), Utils.ArrayMaxSize - 1L);
                            Array.Resize(ref features, (int)newSize);
                            Array.Resize(ref indices, (int)newSize);
                        }

                        Array.Copy(featureVector.Values, 0, features, nelem, nbValues);
                        if (featureVector.IsDense)
                        {
                            // Dense vectors store every column, so the column index is the position.
                            for (int i = 0; i < nbValues; ++i)
                            {
                                indices[nelem++] = (uint)i;
                            }
                        }
                        else
                        {
                            for (int i = 0; i < nbValues; ++i)
                            {
                                indices[nelem++] = (uint)featureVector.Indices[i];
                            }
                        }
                    }

                    if (hasWeights)
                    {
                        weighGetter(ref weights[count]);
                    }
                    if (hasGroup)
                    {
                        groupGetter(ref groupProxy);
                        if (groupProxy >= uint.MaxValue)
                        {
                            throw _host.Except($"Group is above {uint.MaxValue}");
                        }
                        groupsML[count] = (uint)groupProxy;
                        if (count == 0 || groupsML[count - 1] != groupsML[count])
                        {
                            // A new run starts; each group id may form only one contiguous run.
                            groupCount[++lastGroup] = 1;
                            bool firstSeen = groupId.Add(groupsML[count]);
                            ch.Check(firstSeen, "Group Id are not contiguous.");
                        }
                        else
                        {
                            ++groupCount[lastGroup];
                        }
                    }
                    ++count;
                }
            }
            // Closing row pointer: total number of stored values (same ulong width as the per-row entries).
            indptr[count] = (ulong)nelem;

            // Trim the buffers when more than a quarter is unused (3L avoids int overflow).
            if (nelem < features.Length * 3L / 4)
            {
                Array.Resize(ref features, (int)nelem);
                Array.Resize(ref indices, (int)nelem);
            }

            PostProcessLabelsBeforeCreatingXGBoostContainer(ch, data, labels);

            // We create a DMatrix.
            DMatrix dtrain = new DMatrix((uint)nbDim, indptr, indices, features, (uint)count, (uint)nelem, labels: labels, weights: weights, groups: groupCount);

            return dtrain;
        }
Example #2
0
 /// <summary>
 /// Opens a row cursor over <paramref name="data"/> with the given cursor options.
 /// </summary>
 /// <param name="data">Data to iterate over; asserted to be non-null.</param>
 /// <param name="opt">Cursor option flags (e.g. which role columns to activate).</param>
 /// <param name="rand">Optional random source forwarded to the cursor; may be null.</param>
 /// <param name="extraCols">Additional column indices forwarded to the cursor.</param>
 /// <returns>The cursor produced by <c>data.CreateRowCursor</c>.</returns>
 protected static IRowCursor CreateCursor(RoleMappedData data, CursOpt opt, IRandom rand, params int[] extraCols)
 {
     // Assert-style sanity checks (presumably debug-only): data is required, rand is optional.
     Contracts.AssertValue(data);
     Contracts.AssertValueOrNull(rand);

     var cursor = data.CreateRowCursor(opt, rand, extraCols);
     return cursor;
 }
        /// <summary>
        /// Fill a dense DMatrix laid out row after row, one row per labeled example.
        /// Rows whose label is NaN are skipped; entries a sparse feature vector does not store
        /// explicitly are written as NaN (the default value passed to <c>CopyTo</c>).
        /// </summary>
        /// <param name="ch">Channel used to report data-validation failures.</param>
        /// <param name="nbDim">Number of feature columns per row.</param>
        /// <param name="nbRows">Number of rows used to size the buffers; the cursor must not yield more rows than this
        /// (presumably the row count of <paramref name="data"/> — confirm at the caller).</param>
        /// <param name="data">Training data providing features, label and optional weight/group columns.</param>
        /// <param name="labels">Receives the labels of the kept rows; entries past the number of kept rows are not meaningful.</param>
        /// <param name="groupCount">Receives the length of each run of consecutive rows sharing a group id
        /// (unused trailing entries are 0), or null when the data has no group column.</param>
        /// <returns>The dense <see cref="DMatrix"/>.</returns>
        private DMatrix FillDenseMatrix(IChannel ch, int nbDim, long nbRows,
                                        RoleMappedData data, out Float[] labels, out uint[] groupCount)
        {
            // Allocation.
            string errorMessageGroup = string.Format("Group is above {0}.", uint.MaxValue);

            if (nbDim * nbRows >= Utils.ArrayMaxSize)
            {
                throw _host.Except("The training dataset is too big to hold in memory. " +
                                   "Number of features ({0}) multiplied by the number of rows ({1}) must be less than {2}.", nbDim, nbRows, Utils.ArrayMaxSize);
            }
            var features = new Float[nbDim * nbRows];

            labels = new Float[nbRows];
            var hasWeights = data.Schema.Weight != null;
            var hasGroup   = data.Schema.Group != null;
            var weights    = hasWeights ? new Float[nbRows] : null;
            var groupsML   = hasGroup ? new uint[nbRows] : null;

            groupCount = hasGroup ? new uint[nbRows] : null;
            var groupId = hasGroup ? new HashSet<uint>() : null;

            int count     = 0;
            int lastGroup = -1;
            int fcount    = 0;
            var flags     = CursOpt.Features | CursOpt.Label | CursOpt.AllowBadEverything | CursOpt.Weight | CursOpt.Group;

            var featureVector = default(VBuffer<float>);
            var labelProxy    = float.NaN;
            var groupProxy    = ulong.MaxValue;

            using (var cursor = data.CreateRowCursor(flags, null))
            {
                var featureGetter = cursor.GetFeatureFloatVectorGetter(data);
                var labelGetter   = cursor.GetLabelFloatGetter(data);
                var weighGetter   = cursor.GetOptWeightFloatGetter(data);
                var groupGetter   = cursor.GetOptGroupGetter(data);

                while (cursor.MoveNext())
                {
                    featureGetter(ref featureVector);
                    labelGetter(ref labelProxy);

                    labels[count] = labelProxy;
                    // Unlabeled rows are dropped; this slot is reused by the next row.
                    if (Single.IsNaN(labels[count]))
                    {
                        continue;
                    }

                    // CopyTo writes the vector densely (Length values, implicit entries as NaN),
                    // so advance by Length. Count only covers explicitly stored values and would
                    // make the next row overwrite the tail of this one for sparse vectors.
                    // NOTE(review): assumes featureVector.Length == nbDim so rows stay aligned.
                    featureVector.CopyTo(features, fcount, Single.NaN);
                    fcount += featureVector.Length;

                    if (hasWeights)
                    {
                        weighGetter(ref weights[count]);
                    }
                    if (hasGroup)
                    {
                        groupGetter(ref groupProxy);
                        _host.Check(groupProxy < uint.MaxValue, errorMessageGroup);
                        groupsML[count] = (uint)groupProxy;
                        if (count == 0 || groupsML[count - 1] != groupsML[count])
                        {
                            // A new run starts; each group id may form only one contiguous run.
                            groupCount[++lastGroup] = 1;
                            bool firstSeen = groupId.Add(groupsML[count]);
                            ch.Check(firstSeen, "Group Id are not contiguous.");
                        }
                        else
                        {
                            ++groupCount[lastGroup];
                        }
                    }
                    ++count;
                }
            }

            PostProcessLabelsBeforeCreatingXGBoostContainer(ch, data, labels);

            // We create a DMatrix.
            DMatrix dtrain = new DMatrix(features, (uint)count, (uint)nbDim, labels: labels, weights: weights, groups: groupCount);

            return dtrain;
        }