Example #1
0
            private static void SetupCursor(TextLoader parent, bool[] active, int n,
                                            out int srcNeeded, out int cthd)
            {
                // Note that files is allowed to be empty.
                Contracts.AssertValue(parent);
                Contracts.Assert(active == null || active.Length == parent._bindings.Infos.Length);

                var bindings = parent._bindings;

                // Start at 1 so srcNeeded (= limit - 1) can never go negative.
                int maxLim = 1;
                for (int col = 0; col < bindings.Infos.Length; col++)
                {
                    // Skip columns that are explicitly inactive.
                    if (active == null || active[col])
                    {
                        foreach (var segment in bindings.Infos[col].Segments)
                            maxLim = Math.Max(maxLim, segment.Lim);
                    }
                }

                // Clamp to the declared input size when one is known.
                if (parent._inputSize > 0 && maxLim > parent._inputSize)
                    maxLim = parent._inputSize;
                srcNeeded = maxLim - 1;
                Contracts.Assert(srcNeeded >= 0);

                // Determine the number of threads to use, then cap it by the number of
                // whole batches available so we never spin up more threads than batches.
                cthd = DataViewUtils.GetThreadCount(parent._host, n, !parent._useThreads);
                long maxBatches = parent._maxRows / BatchSize;
                if (cthd > maxBatches)
                    cthd = Math.Max(1, (int)maxBatches);
            }
                public IRowCursor CreateCursor(IChannelProvider provider, IRowCursor[] inputs)
                {
                    Contracts.AssertValue(provider);

                    // Atomically claim the expected cursor count; any later call observes 0 and fails.
                    int cthd = Interlocked.Exchange(ref _cthd, 0);
                    provider.Check(cthd > 1, "Consolidator can only be used once");
                    provider.Check(Utils.Size(inputs) == cthd, "Unexpected number of cursors");

                    // ConsolidateGeneric does all the standard validity checks: all cursors non-null,
                    // all have the same schema, all have the same active columns, and all active
                    // column types are cachable.
                    using (var channel = provider.Start("Consolidator"))
                    {
                        var consolidated = DataViewUtils.ConsolidateGeneric(provider, inputs, BatchSize);
                        channel.Done();
                        return consolidated;
                    }
                }
Example #3
0
        /// <summary>
        /// Create a set of cursors with additional active columns.
        /// </summary>
        /// <param name="additionalColumnsPredicate">Predicate that denotes which additional columns to include in the cursor,
        /// in addition to the columns that are needed for populating the <typeparamref name="TRow"/> object.</param>
        /// <param name="n">Number of cursors to create</param>
        /// <param name="rand">Random generator to use</param>
        public RowCursor<TRow>[] GetCursorSet(Func<int, bool> additionalColumnsPredicate, int n, Random rand)
        {
            _host.CheckValue(additionalColumnsPredicate, nameof(additionalColumnsPredicate));
            _host.CheckValueOrNull(rand);

            // A column is active if the typed row needs it or the caller asked for it.
            var neededColumns = _data.Schema.Where(
                col => _columnIndices.Contains(col.Index) || additionalColumnsPredicate(col.Index));
            var inputs = _data.GetRowCursorSet(neededColumns, n, rand);
            _host.AssertNonEmpty(inputs);

            // If the source handed back a single cursor but more were requested, split it.
            if (n > 1 && inputs.Length == 1)
                inputs = DataViewUtils.CreateSplitCursors(_host, inputs[0], n);
            _host.AssertNonEmpty(inputs);

            // Wrap each raw cursor in the typed-row implementation.
            return inputs
                .Select(input => (RowCursor<TRow>)(new RowCursorImplementation(new TypedCursor(this, input))))
                .ToArray();
        }
Example #4
0
        public IRowCursor GetRowCursor(Func <int, bool> predicate, IRandom rand = null)
        {
            Host.CheckValue(predicate, nameof(predicate));
            Host.CheckValueOrNull(rand);

            // Prefer a consolidating cursor (parallel underneath, single cursor on top) when available.
            IRowCursor cursor;
            if (DataViewUtils.TryCreateConsolidatingCursor(out cursor, this, predicate, Host, rand))
                return cursor;

            // Map the requested output columns down to the input columns they depend on.
            var activeInputs = _schema.GetActiveInput(predicate);
            Func<int, bool> srcPredicate = c => activeInputs[c];

            int? seed = rand == null ? (int?)null : rand.Next();
            var input = _typedSource.GetCursor(srcPredicate, seed);
            return new Cursor(Host, this, input, predicate);
        }
        public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func<int, bool> predicate, int n, IRandom rand = null)
        {
            Host.CheckValue(predicate, nameof(predicate));
            Host.CheckValueOrNull(rand);

            // Translate the output-column request into input dependencies and active flags.
            var inputPredicate = _bindings.GetDependencies(predicate);
            var activeColumns = _bindings.GetActive(predicate);
            var inputs = Source.GetRowCursorSet(out consolidator, inputPredicate, n, rand);
            Host.AssertNonEmpty(inputs);

            // Split a lone input cursor only when parallelism was requested and at least
            // one of the columns this transform adds is actually active.
            bool shouldSplit = inputs.Length == 1 && n > 1 && _bindings.AnyNewColumnsActive(predicate);
            if (shouldSplit)
                inputs = DataViewUtils.CreateSplitCursors(out consolidator, Host, inputs[0], n);
            Host.AssertNonEmpty(inputs);

            var cursors = new IRowCursor[inputs.Length];
            for (int i = 0; i < cursors.Length; i++)
                cursors[i] = new RowCursor(this, inputs[i], activeColumns);
            return cursors;
        }
Example #6
0
        /// <summary>
        /// Reference ("naive") transpose used as the expected-value oracle in transpose tests:
        /// reads column <paramref name="col"/> of <paramref name="view"/> row by row and lays
        /// values out slot-major, so retval[slot * rowCount + row] holds slot's value for that row.
        /// </summary>
        private static T[] NaiveTranspose <T>(IDataView view, int col)
        {
            var type     = view.Schema[col].Type;
            int rc       = checked ((int)DataViewUtils.ComputeRowCount(view));
            var vecType  = type as VectorDataViewType;
            var itemType = vecType?.ItemType ?? type;

            // The generic argument must match the column's item type, and a vector column
            // must have a known (non-zero) size; scalar columns occupy a single slot.
            Assert.Equal(typeof(T), itemType.RawType);
            Assert.NotEqual(0, vecType?.Size);
            T[] retval = new T[rc * (vecType?.Size ?? 1)];

            using (var cursor = view.GetRowCursor(view.Schema[col]))
            {
                if (type is VectorDataViewType)
                {
                    var         getter = cursor.GetGetter <VBuffer <T> >(cursor.Schema[col]);
                    VBuffer <T> temp   = default;
                    int         offset = 0;
                    while (cursor.MoveNext())
                    {
                        // offset tracks the row index and must agree with the cursor position.
                        Assert.True(0 <= offset && offset < rc && offset == cursor.Position);
                        getter(ref temp);
                        var tempValues  = temp.GetValues();
                        var tempIndices = temp.GetIndices();
                        for (int i = 0; i < tempValues.Length; ++i)
                        {
                            // Dense buffers use i directly as the slot; sparse buffers map
                            // through the indices array. Unset sparse slots keep default(T).
                            retval[(temp.IsDense ? i : tempIndices[i]) * rc + offset] = tempValues[i];
                        }
                        offset++;
                    }
                }
                else
                {
                    // Scalar column: the "transpose" is just the column read out in row order.
                    var getter = cursor.GetGetter <T>(cursor.Schema[col]);
                    while (cursor.MoveNext())
                    {
                        Assert.True(0 <= cursor.Position && cursor.Position < rc);
                        getter(ref retval[(int)cursor.Position]);
                    }
                }
            }
            return(retval);
        }
        /// <summary>
        /// Verifies that the slot cursor of <paramref name="trans"/> yields, for column
        /// <paramref name="viewCol"/>, exactly the values produced by the naive row-by-row
        /// transpose of <paramref name="view"/>. Checks schema/type agreement first, then
        /// compares the two flattened slot-major arrays element by element.
        /// </summary>
        private static void TransposeCheckHelper <T>(IDataView view, int viewCol, ITransposeDataView trans)
        {
            int col     = viewCol;
            var type    = trans.TransposeSchema.GetSlotType(col);
            var colType = trans.Schema.GetColumnType(col);

            // The transposed view must expose the same column name and type as the original.
            Assert.Equal(view.Schema.GetColumnName(viewCol), trans.Schema.GetColumnName(col));
            var expectedType = view.Schema.GetColumnType(viewCol);

            // Unfortunately can't use equals because column type equality is a simple reference comparison. :P
            Assert.Equal(expectedType, colType);
            // Each slot vector must have one element per row of the original view.
            Assert.Equal(DataViewUtils.ComputeRowCount(view), (long)type.VectorSize);
            string desc = string.Format("Column {0} named '{1}'", col, trans.Schema.GetColumnName(col));

            Assert.True(typeof(T) == type.ItemType.RawType, $"{desc} had wrong type for slot cursor");
            Assert.True(type.IsVector, $"{desc} expected to be vector but is not");
            Assert.True(type.VectorSize > 0, $"{desc} expected to be known sized vector but is not");
            Assert.True(0 != colType.ValueCount, $"{desc} expected to have fixed size, but does not");
            int rc = type.VectorSize;

            // Expected values come from the naive oracle; actual values are filled below,
            // one rc-sized chunk per slot, so both arrays share the slot-major layout.
            T[] expectedVals = NaiveTranspose <T>(view, viewCol);
            T[] vals         = new T[rc * colType.ValueCount];
            Contracts.Assert(vals.Length == expectedVals.Length);
            using (var cursor = trans.GetSlotCursor(col))
            {
                var         getter = cursor.GetGetter <T>();
                VBuffer <T> temp   = default(VBuffer <T>);
                int         offset = 0;
                while (cursor.MoveNext())
                {
                    Assert.True(offset < vals.Length, $"{desc} slot cursor went further than it should have");
                    getter(ref temp);
                    Assert.True(rc == temp.Length, $"{desc} slot cursor yielded vector with unexpected length");
                    temp.CopyTo(vals, offset);
                    offset += rc;
                }
                // The cursor must have produced exactly one vector per slot.
                Assert.True(colType.ValueCount == offset / rc, $"{desc} slot cursor yielded fewer than expected values");
            }
            for (int i = 0; i < vals.Length; ++i)
            {
                Assert.Equal(expectedVals[i], vals[i]);
            }
        }
Example #8
0
        public void SdcaBinaryClassificationNoClaibration()
        {
            var env = new TlcEnvironment(seed: 0);
            var dataSource = new MultiFileSource(GetDataPath("breast-cancer.txt"));

            var reader = TextLoader.CreateReader(env,
                c => (label: c.LoadBool(0), features: c.LoadFloat(1, 9)));

            // With a custom loss function we no longer get calibrated predictions.
            var loss = new HingeLoss(new HingeLoss.Arguments() { Margin = 1 });

            LinearBinaryPredictor pred = null;
            var est = reader.MakeNewEstimator()
                .Append(r => (r.label, preds: r.label.PredictSdcaBinaryClassification(r.features,
                    maxIterations: 2,
                    loss: loss, onFit: p => pred = p)));
            var pipe = reader.Append(est);

            // The onFit callback only populates the predictor once fitting has happened.
            Assert.Null(pred);
            var model = pipe.Fit(dataSource);
            Assert.NotNull(pred);

            // 9 input features, so we ought to have 9 weights.
            Assert.Equal(9, pred.Weights2.Count);

            var data = model.Read(dataSource);

            // Just output some data on the schema for fun.
            var rows = DataViewUtils.ComputeRowCount(data.AsDynamic);
            var schema = data.AsDynamic.Schema;
            for (int c = 0; c < schema.ColumnCount; ++c)
            {
                Console.WriteLine($"{schema.GetColumnName(c)}, {schema.GetColumnType(c)}");
            }
        }
        /// <summary>
        /// Create a set of cursors with additional active columns.
        /// </summary>
        /// <param name="consolidator">The consolidator for the original row cursors</param>
        /// <param name="additionalColumnsPredicate">Predicate that denotes which additional columns to include in the cursor,
        /// in addition to the columns that are needed for populating the <typeparamref name="TRow"/> object.</param>
        /// <param name="n">Number of cursors to create</param>
        /// <param name="rand">Random generator to use</param>
        public IRowCursor <TRow>[] GetCursorSet(out IRowCursorConsolidator consolidator,
                                                Func <int, bool> additionalColumnsPredicate, int n, IRandom rand)
        {
            _host.CheckValue(additionalColumnsPredicate, nameof(additionalColumnsPredicate));
            _host.CheckValueOrNull(rand);

            // A column is needed if the typed row reads it or the caller requested it.
            Func<int, bool> inputPredicate =
                col => _columnIndices.Contains(col) || additionalColumnsPredicate(col);
            var inputs = _data.GetRowCursorSet(out consolidator, inputPredicate, n, rand);
            _host.AssertNonEmpty(inputs);

            // When only one cursor came back but several were requested, split it.
            if (n > 1 && inputs.Length == 1)
                inputs = DataViewUtils.CreateSplitCursors(out consolidator, _host, inputs[0], n);
            _host.AssertNonEmpty(inputs);

            // Wrap each raw cursor in the typed-cursor adapter.
            var cursors = new IRowCursor<TRow>[inputs.Length];
            for (int i = 0; i < inputs.Length; i++)
                cursors[i] = (IRowCursor<TRow>)(new TypedCursor(this, inputs[i]));
            return cursors;
        }
Example #10
0
        public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func<int, bool> predicate, int n, Random rand = null)
        {
            Host.CheckValue(predicate, nameof(predicate));
            Host.CheckValueOrNull(rand);

            // Resolve the request into input/mapper predicates and the active-column set.
            var bindings = GetBindings();
            Func<int, bool> predicateInput;
            Func<int, bool> predicateMapper;
            var active = GetActive(bindings, predicate, out predicateInput, out predicateMapper);

            var inputs = Source.GetRowCursorSet(out consolidator, predicateInput, n, rand);
            Contracts.AssertNonEmpty(inputs);

            // Split a lone cursor only when parallelism was requested, this transform
            // benefits from it, and there are enough rows to make splitting worthwhile.
            long rowCount = Source.GetRowCount() ?? int.MaxValue;
            if (inputs.Length == 1 && n > 1 && WantParallelCursors(predicate) && rowCount > n)
                inputs = DataViewUtils.CreateSplitCursors(out consolidator, Host, inputs[0], n);
            Contracts.AssertNonEmpty(inputs);

            var cursors = new IRowCursor[inputs.Length];
            for (int i = 0; i < cursors.Length; i++)
                cursors[i] = new RowCursor(Host, this, inputs[i], active, predicateMapper);
            return cursors;
        }
Example #11
0
        public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func <int, bool> predicate, int n, IRandom rand = null)
        {
            // A throwaway host is used only to estimate the effective thread count.
            var host = new ConsoleEnvironment().Register("Estimate n threads");
            n = DataViewUtils.GetThreadCount(host, n);

            if (n <= 1)
            {
                // Sequential case: no consolidation needed, a single plain cursor suffices.
                consolidator = null;
                return new IRowCursor[] { GetRowCursor(predicate, rand) };
            }

            // A column is needed either directly or as the source of a mapped column.
            Func<int, bool> inputPredicate =
                i => predicate(i) || predicate(SchemaHelper.NeedColumn(_columnMapping, i));
            var cursors = Source.GetRowCursorSet(out consolidator, inputPredicate, n, rand);
            for (int i = 0; i < cursors.Length; ++i)
                cursors[i] = new AddRandomCursor(this, cursors[i]);
            return cursors;
        }
Example #12
0
        public IRowCursor GetRowCursor(Func <int, bool> predicate, IRandom rand = null)
        {
            Host.CheckValue(predicate, nameof(predicate));
            Host.CheckValueOrNull(rand);

            var rng = CanShuffle ? rand : null;
            bool? useParallel = ShouldUseParallelCursors(predicate);

            // When useParallel is null, let the input decide, so go ahead and ask for parallel.
            // When the input wants to be split, this puts the consolidation after this transform
            // instead of before. This is likely to produce better performance, for example, when
            // this is RangeFilter.
            if (useParallel != false)
            {
                IRowCursor consolidated;
                if (DataViewUtils.TryCreateConsolidatingCursor(out consolidated, this, predicate, Host, rng))
                    return consolidated;
            }

            return GetRowCursorCore(predicate, rng);
        }
Example #13
0
        /// <summary>
        /// Creates a set of row cursors over this view. Falls back to a single cursor when
        /// the effective thread count is one; otherwise wraps each source cursor in an
        /// <c>AddRandomCursor</c>.
        /// </summary>
        /// <param name="columnsNeeded">Columns that must be active in the returned cursors.</param>
        /// <param name="n">Suggested degree of parallelism.</param>
        /// <param name="rand">Optional random generator for shuffling.</param>
        public DataViewRowCursor[] GetRowCursorSet(IEnumerable <DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
        {
            // Resolve the requested degree of parallelism to a usable thread count.
            n = DataViewUtils.GetThreadCount(n);

            if (n <= 1)
            {
                // BUGFIX: the semicolon terminating this return statement had drifted
                // outside the if-block (`} ; else`), which does not compile.
                return new DataViewRowCursor[] { GetRowCursor(columnsNeeded, rand) };
            }
            else
            {
                //var cols = SchemaHelper.ColumnsNeeded(columnsNeeded, Schema, _args.columns);
                var cols = SchemaHelper.ColumnsNeeded(columnsNeeded, Source.Schema);

                var cursors = Source.GetRowCursorSet(cols, n, rand);
                for (int i = 0; i < cursors.Length; ++i)
                {
                    cursors[i] = new AddRandomCursor(this, cursors[i]);
                }
                return cursors;
            }
        }
        /// <summary>
        /// Reference ("naive") transpose over the older column-type API: reads column
        /// <paramref name="col"/> of <paramref name="view"/> row by row and lays values out
        /// slot-major, so retval[slot * rowCount + row] holds slot's value for that row.
        /// </summary>
        private static T[] NaiveTranspose <T>(IDataView view, int col)
        {
            var type = view.Schema.GetColumnType(col);
            int rc   = checked ((int)DataViewUtils.ComputeRowCount(view));

            // The generic argument must match the item type, and the column must have a
            // known fixed value count (1 for scalars).
            Assert.True(type.ItemType.RawType == typeof(T));
            Assert.True(type.ValueCount > 0);
            T[] retval = new T[rc * type.ValueCount];

            using (var cursor = view.GetRowCursor(c => c == col))
            {
                if (type.IsVector)
                {
                    var         getter = cursor.GetGetter <VBuffer <T> >(col);
                    VBuffer <T> temp   = default(VBuffer <T>);
                    int         offset = 0;
                    while (cursor.MoveNext())
                    {
                        // offset tracks the row index and must agree with the cursor position.
                        Assert.True(0 <= offset && offset < rc && offset == cursor.Position);
                        getter(ref temp);
                        for (int i = 0; i < temp.Count; ++i)
                        {
                            // Dense buffers use i directly as the slot; sparse buffers map
                            // through Indices. Unset sparse slots keep default(T).
                            retval[(temp.IsDense ? i : temp.Indices[i]) * rc + offset] = temp.Values[i];
                        }
                        offset++;
                    }
                }
                else
                {
                    // Scalar column: the "transpose" is just the column read out in row order.
                    var getter = cursor.GetGetter <T>(col);
                    while (cursor.MoveNext())
                    {
                        Assert.True(0 <= cursor.Position && cursor.Position < rc);
                        getter(ref retval[(int)cursor.Position]);
                    }
                }
            }
            return(retval);
        }
Example #15
0
        public sealed override RowCursor[] GetRowCursorSet(IEnumerable<Schema.Column> columnsNeeded, int n, Random rand = null)
        {
            Host.CheckValueOrNull(rand);

            // Translate the requested output columns into a predicate, then into the
            // input dependencies and active-column flags of this transform.
            var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema);
            var inputPred = _bindings.GetDependencies(predicate);
            var active = _bindings.GetActive(predicate);

            var inputCols = Source.Schema.Where(x => inputPred(x.Index));
            var inputs = Source.GetRowCursorSet(inputCols, n, rand);
            Host.AssertNonEmpty(inputs);

            // Split a lone cursor when parallelism was requested and this transform benefits.
            if (n > 1 && inputs.Length == 1 && WantParallelCursors(predicate))
                inputs = DataViewUtils.CreateSplitCursors(Host, inputs[0], n);
            Host.AssertNonEmpty(inputs);

            var cursors = new RowCursor[inputs.Length];
            for (int i = 0; i < cursors.Length; i++)
                cursors[i] = new Cursor(Host, this, inputs[i], active);
            return cursors;
        }
Example #16
0
        public RowCursor GetRowCursor(IEnumerable<Schema.Column> columnsNeeded, Random rand = null)
        {
            Host.CheckValueOrNull(rand);

            var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema);
            var rng = CanShuffle ? rand : null;
            bool? useParallel = ShouldUseParallelCursors(predicate);

            // When useParallel is null, let the input decide, so go ahead and ask for parallel.
            // When the input wants to be split, this puts the consolidation after this transform
            // instead of before. This is likely to produce better performance, for example, when
            // this is RangeFilter.
            if (useParallel != false)
            {
                RowCursor consolidated;
                if (DataViewUtils.TryCreateConsolidatingCursor(out consolidated, this, columnsNeeded, Host, rng))
                    return consolidated;
            }

            return GetRowCursorCore(columnsNeeded, rng);
        }
        /// <summary>
        /// Trains a Scaler transform on a tiny in-memory dataset and checks it round-trips
        /// through serialization: the output produced before saving the pipeline must match
        /// the output produced after reloading it.
        /// </summary>
        public void TestI_ScalerTransformSerialize()
        {
            // NOTE(review): the `using` is commented out, so the environment is not
            // disposed here — presumably deliberate; confirm EnvHelper ownership.
            /*using (*/
            var host = EnvHelper.NewTestEnvironment();
            {
                // Two rows of 3-dimensional float features.
                var inputs = new[] {
                    new ExampleA()
                    {
                        X = new float[] { 1, 10, 100 }
                    },
                    new ExampleA()
                    {
                        X = new float[] { 2, 3, 5 }
                    }
                };

                IDataView loader = DataViewConstructionUtils.CreateFromEnumerable(host, inputs);
                var       data   = host.CreateTransform("Scaler{col=X}", loader);
                // The scaler must be trained before it can be serialized.
                (data as ITrainableTransform).Estimate();

                // We create a specific folder in build/UnitTest which will contain the output.
                var methodName       = System.Reflection.MethodBase.GetCurrentMethod().Name;
                var outModelFilePath = FileHelper.GetOutputFile("outModelFilePath.zip", methodName);
                var outData          = FileHelper.GetOutputFile("outData.txt", methodName);
                var outData2         = FileHelper.GetOutputFile("outData2.txt", methodName);
                // Sanity check: the transformed view must not be empty.
                var nb = DataViewUtils.ComputeRowCount(data);
                if (nb < 1)
                {
                    throw new Exception("empty view");
                }

                // This function serializes the output data twice, once before saving the pipeline, once after loading the pipeline.
                // It checks it gives the same result.
                TestTransformHelper.SerializationTestTransform(host, outModelFilePath, data, loader, outData, outData2);
            }
        }
        /// <summary>
        /// Creates the transform. Tags the current view under <c>args.tag</c> and switches to
        /// the view identified by <c>args.selectTag</c>: either a view tagged earlier in the
        /// pipeline (when no filename is given) or one loaded from <c>args.filename</c>.
        /// </summary>
        /// <param name="env">Hosting environment used for validation and exceptions.</param>
        /// <param name="args">Arguments; must carry a fresh <c>tag</c> and the <c>selectTag</c> to switch to.</param>
        /// <param name="input">The current pipeline view; it stays reachable via <c>args.tag</c>.</param>
        /// <param name="sourceCtx">Set to <paramref name="input"/> so the caller keeps the original view.</param>
        static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input, out IDataView sourceCtx)
        {
            sourceCtx = input;
            env.CheckValue(args.tag, "Tag cannot be empty.");
            // The new tag must not collide with any tag already reachable from the input.
            if (TagHelper.EnumerateTaggedView(true, input).Where(c => c.Item1 == args.tag).Any())
            {
                throw env.Except("Tag '{0}' is already used.", args.tag);
            }
            env.CheckValue(args.selectTag, "Selected tag cannot be empty.");

            if (string.IsNullOrEmpty(args.filename))
            {
                // No file given: switch to a view tagged earlier in the pipeline.
                // There must be exactly one view carrying the selected tag.
                var selected = TagHelper.EnumerateTaggedView(true, input).Where(c => c.Item1 == args.selectTag);
                if (!selected.Any())
                {
                    throw env.Except("Unable to find a view to select with tag '{0}'. Did you forget to specify a filename?", args.selectTag);
                }
                var first = selected.First();
                if (selected.Skip(1).Any())
                {
                    throw env.Except("Tag '{0}' is ambiguous, {1} views were found.", args.selectTag, selected.Count());
                }
                // Ensure the input participates in the tagging protocol, wrapping it if needed.
                var tagged = input as ITaggedDataView;
                if (tagged == null)
                {
                    var ag = new TagViewTransform.Arguments {
                        tag = args.tag
                    };
                    tagged = new TagViewTransform(env, ag, input);
                }
                // Cross-link the two views so each can reach the other by tag.
                first.Item2.AddRange(new[] { new Tuple <string, ITaggedDataView>(args.tag, tagged) });
                tagged.AddRange(new[] { new Tuple <string, ITaggedDataView>(args.selectTag, first.Item2) });
#if (DEBUG_TIP)
                // Debug-only sanity checks: both views should contain rows.
                long count = DataViewUtils.ComputeRowCount(tagged);
                if (count == 0)
                {
                    throw env.Except("Replaced view is empty.");
                }
                count = DataViewUtils.ComputeRowCount(first.Item2);
                if (count == 0)
                {
                    throw env.Except("Selected view is empty.");
                }
#endif
                var tr = first.Item2 as IDataTransform;
                env.AssertValue(tr);
                return(tr);
            }
            else
            {
                // Filename given: load a view from disk and tag it with selectTag.
                if (!File.Exists(args.filename))
                {
                    throw env.Except("Unable to find file '{0}'.", args.filename);
                }
                // The selected tag must be new — it will name the freshly loaded view.
                var selected = TagHelper.EnumerateTaggedView(true, input).Where(c => c.Item1 == args.selectTag);
                if (selected.Any())
                {
                    throw env.Except("Tag '{0}' was already given. It cannot be assigned to the new file.", args.selectTag);
                }
                var loaderArgs   = new BinaryLoader.Arguments();
                var file         = new MultiFileSource(args.filename);
                var loadSettings = ScikitSubComponent <ILegacyDataLoader, SignatureDataLoader> .AsSubComponent(args.loaderSettings);

                IDataView loader = loadSettings.CreateInstance(env, file);

                var ag = new TagViewTransform.Arguments {
                    tag = args.selectTag
                };
                var newInput = new TagViewTransform(env, ag, loader);
                // Ensure the input participates in the tagging protocol, wrapping it if needed.
                var tagged   = input as ITaggedDataView;
                if (tagged == null)
                {
                    ag = new TagViewTransform.Arguments {
                        tag = args.tag
                    };
                    tagged = new TagViewTransform(env, ag, input);
                }

                // Cross-link the loaded view and the original so each can reach the other by tag.
                newInput.AddRange(new[] { new Tuple <string, ITaggedDataView>(args.tag, tagged) });
                tagged.AddRange(new[] { new Tuple <string, ITaggedDataView>(args.selectTag, newInput) });

                var schema = loader.Schema;
                if (schema.Count == 0)
                {
                    throw env.Except("The loaded view '{0}' is empty (empty schema).", args.filename);
                }
                return(newInput);
            }
        }
        protected override TVectorPredictor TrainPredictor(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int count)
        {
            var data0 = data;

            #region adding group ID

            // We insert a group Id.
            string groupColumnTemp = DataViewUtils.GetTempColumnName(data.Schema.Schema) + "GR";
            var    groupArgs       = new GenerateNumberTransform.Options
            {
                Columns    = new[] { GenerateNumberTransform.Column.Parse(groupColumnTemp) },
                UseCounter = true
            };

            var withGroup = new GenerateNumberTransform(Host, groupArgs, data.Data);
            data = new RoleMappedData(withGroup, data.Schema.GetColumnRoleNames());

            #endregion

            #region preparing the training dataset

            string dstName, labName;
            var    trans       = MapLabelsAndInsertTransform(ch, data, out dstName, out labName, count, true, _args);
            var    newFeatures = trans.Schema.GetTempColumnName() + "NF";

            // We check the label is not boolean.
            int indexLab = SchemaHelper.GetColumnIndex(trans.Schema, dstName);
            var typeLab  = trans.Schema[indexLab].Type;
            if (typeLab.RawKind() == DataKind.Boolean)
            {
                throw Host.Except("Column '{0}' has an unexpected type {1}.", dstName, typeLab.RawKind());
            }

            var args3 = new DescribeTransform.Arguments {
                columns = new string[] { labName, dstName }, oneRowPerColumn = true
            };
            var desc = new DescribeTransform(Host, args3, trans);

            IDataView viewI;
            if (_args.singleColumn && data.Schema.Label.Value.Type.RawKind() == DataKind.Single)
            {
                viewI = desc;
            }
            else if (_args.singleColumn)
            {
                var sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { NumberDataViewType.Single });
                viewI = new TypeReplacementDataView(desc, sch);
                #region debug
#if (DEBUG)
                DebugChecking0(viewI, labName, false);
#endif
                #endregion
            }
            else if (data.Schema.Label.Value.Type.IsKey())
            {
                ulong nb  = data.Schema.Label.Value.Type.AsKey().GetKeyCount();
                var   sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { new VectorDataViewType(NumberDataViewType.Single, (int)nb) });
                viewI = new TypeReplacementDataView(desc, sch);
                #region debug
#if (DEBUG)
                int nb_;
                MinMaxLabelOverDataSet(trans, labName, out nb_);
                int count3;
                data.CheckMulticlassLabel(out count3);
                if ((ulong)count3 != nb)
                {
                    throw ch.Except("Count mismatch (KeyCount){0} != {1}", nb, count3);
                }
                DebugChecking0(viewI, labName, true);
                DebugChecking0Vfloat(viewI, labName, nb);
#endif
                #endregion
            }
            else
            {
                int nb;
                if (count <= 0)
                {
                    MinMaxLabelOverDataSet(trans, labName, out nb);
                }
                else
                {
                    nb = count;
                }
                var sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { new VectorDataViewType(NumberDataViewType.Single, nb) });
                viewI = new TypeReplacementDataView(desc, sch);
                #region debug
#if (DEBUG)
                DebugChecking0(viewI, labName, true);
#endif
                #endregion
            }

            ch.Info("Merging column label '{0}' with features '{1}'", labName, data.Schema.Feature.Value.Name);
            var args = string.Format("Concat{{col={0}:{1},{2}}}", newFeatures, data.Schema.Feature.Value.Name, labName);
            var after_concatenation_ = ComponentCreation.CreateTransform(Host, args, viewI);

            #endregion

            #region converting label and group into keys

            // We need to convert the label into a Key.
            var convArgs = new MulticlassConvertTransform.Arguments
            {
                column     = new[] { MulticlassConvertTransform.Column.Parse(string.Format("{0}k:{0}", dstName)) },
                keyCount   = new KeyCount(4),
                resultType = DataKind.UInt32
            };
            IDataView after_concatenation_key_label = new MulticlassConvertTransform(Host, convArgs, after_concatenation_);

            // The group must be a key too!
            convArgs = new MulticlassConvertTransform.Arguments
            {
                column     = new[] { MulticlassConvertTransform.Column.Parse(string.Format("{0}k:{0}", groupColumnTemp)) },
                keyCount   = new KeyCount(),
                resultType = _args.groupIsU4 ? DataKind.UInt32 : DataKind.UInt64
            };
            after_concatenation_key_label = new MulticlassConvertTransform(Host, convArgs, after_concatenation_key_label);

            #endregion

            #region preparing the RoleMapData view

            string groupColumn = groupColumnTemp + "k";
            dstName += "k";

            var roles      = data.Schema.GetColumnRoleNames();
            var rolesArray = roles.ToArray();
            roles = roles
                    .Where(kvp => kvp.Key.Value != RoleMappedSchema.ColumnRole.Label.Value)
                    .Where(kvp => kvp.Key.Value != RoleMappedSchema.ColumnRole.Feature.Value)
                    .Where(kvp => kvp.Key.Value != groupColumn)
                    .Where(kvp => kvp.Key.Value != groupColumnTemp);
            rolesArray = roles.ToArray();
            if (rolesArray.Any() && rolesArray[0].Value == groupColumnTemp)
            {
                throw ch.Except("Duplicated group.");
            }
            roles = roles
                    .Prepend(RoleMappedSchema.ColumnRole.Feature.Bind(newFeatures))
                    .Prepend(RoleMappedSchema.ColumnRole.Label.Bind(dstName))
                    .Prepend(RoleMappedSchema.ColumnRole.Group.Bind(groupColumn));
            var trainer_input = new RoleMappedData(after_concatenation_key_label, roles);

            #endregion

            ch.Info("New Features: {0}:{1}", trainer_input.Schema.Feature.Value.Name, trainer_input.Schema.Feature.Value.Type);
            ch.Info("New Label: {0}:{1}", trainer_input.Schema.Label.Value.Name, trainer_input.Schema.Label.Value.Type);

            // We train the unique binary classifier.
            var trainedPredictor = trainer.Train(trainer_input);
            var predictors       = new TScalarPredictor[] { trainedPredictor };

            // We train the reclassification classifier.
            if (_args.reclassicationPredictor != null)
            {
                var pred = CreateFinalPredictor(ch, data, trans, count, _args, predictors, null);
                TrainReclassificationPredictor(data0, pred, ScikitSubComponent <ITrainer, SignatureTrainer> .AsSubComponent(_args.reclassicationPredictor));
            }

            return(CreateFinalPredictor(ch, data, trans, count, _args, predictors, _reclassPredictor));
        }
Example #20
0
        /// <summary>
        /// Builds a small pipeline (load iris data with ids, concat two features, kNN transform),
        /// verifies the transform output (dense buffers, distances sorted ascending, odd ids)
        /// and round-trips the model through serialization.
        /// </summary>
        /// <param name="k">Number of neighbors requested from the kNN transform.</param>
        /// <param name="weight">Neighbor weighting scheme (currently only reflected in the output path).</param>
        /// <param name="threads">Thread count (currently only reflected in the output path).</param>
        /// <param name="distance">Distance kind; "cosine" inserts a scaler before the kNN transform.</param>
        public static void TrainkNNTransformId(int k, NearestNeighborsWeights weight, int threads, string distance = "L2")
        {
            var methodName       = string.Format("{0}-k{1}-W{2}-T{3}-D{4}", System.Reflection.MethodBase.GetCurrentMethod().Name, k, weight, threads, distance);
            var dataFilePath     = FileHelper.GetTestFile("iris_binary_id.txt");
            var outModelFilePath = FileHelper.GetOutputFile("outModelFilePath.zip", methodName);
            var outData          = FileHelper.GetOutputFile("outData1.txt", methodName);
            var outData2         = FileHelper.GetOutputFile("outData2.txt", methodName);

            // Single-threaded environment when k == 1 keeps that case deterministic.
            var env = k == 1 ? EnvHelper.NewTestEnvironment(conc: 1) : EnvHelper.NewTestEnvironment();

            using (env)
            {
                var loader = env.CreateLoader("Text{col=Label:R4:0 col=Slength:R4:1 col=Swidth:R4:2 col=Plength:R4:3 col=Pwidth:R4:4 col=Uid:I8:5 header=+}",
                                              new MultiFileSource(dataFilePath));

                var concat = env.CreateTransform("Concat{col=Features:Slength,Swidth}", loader);
                if (distance == "cosine")
                {
                    concat = env.CreateTransform("Scaler{col=Features}", concat);
                }
                // BUG FIX: the 'k' parameter was previously ignored (the transform hard-coded k=5);
                // format the actual value into the transform definition.
                // NOTE(review): 'weight', 'threads' and 'distance' are still not forwarded to
                // knntr — confirm whether the transform accepts them and wire them through if so.
                concat = env.CreateTransform(string.Format("knntr{{k={0} id=Uid}}", k), concat);
                long nb = DataViewUtils.ComputeRowCount(concat);
                if (nb == 0)
                {
                    throw new System.Exception("Empty pipeline.");
                }

                using (var cursor = concat.GetRowCursor(i => true))
                {
                    // Columns 7 and 8 hold the neighbor distances and neighbor ids appended by knntr.
                    var getdist = cursor.GetGetter <VBuffer <float> >(7);
                    var getid   = cursor.GetGetter <VBuffer <long> >(8);
                    var ddist   = new VBuffer <float>();
                    var did     = new VBuffer <long>();
                    while (cursor.MoveNext())
                    {
                        getdist(ref ddist);
                        getid(ref did);
                        if (!ddist.IsDense || !did.IsDense)
                        {
                            throw new System.Exception("not dense");
                        }
                        if (ddist.Count != did.Count)
                        {
                            throw new System.Exception("not the same dimension");
                        }
                        for (int i = 1; i < ddist.Count; ++i)
                        {
                            // Distances must come back sorted ascending.
                            if (ddist.Values[i - 1] > ddist.Values[i])
                            {
                                throw new System.Exception("not sorted");
                            }
                            // The test file only contains odd ids.
                            if (did.Values[i] % 2 != 1)
                            {
                                throw new System.Exception("wrong id");
                            }
                        }
                    }
                }

                // Serialize the model, reload it and check both produce the same output.
                TestTransformHelper.SerializationTestTransform(env, outModelFilePath, concat, loader, outData, outData2, false);
            }
        }
Example #21
0
        /// <summary>
        /// Infers the trainer-signature type (clustering, regression, binary/multiclass
        /// classification, ranking, anomaly detection) from the label column(s) of
        /// <paramref name="data"/> by scanning up to 1000 rows of label values.
        /// Returns null when no category can be inferred.
        /// </summary>
        public static Type InferPredictorCategoryType(IDataView data, PurposeInference.Column[] columns)
        {
            List <PurposeInference.Column> labels = columns.Where(col => col.Purpose == ColumnPurpose.Label).ToList();

            // No label at all: the only option left is clustering.
            if (labels.Count == 0)
            {
                return(typeof(SignatureClusteringTrainer));
            }

            // Several labels: multi-output regression.
            if (labels.Count > 1)
            {
                return(typeof(SignatureMultiOutputRegressorTrainer));
            }

            PurposeInference.Column label             = labels.First();
            HashSet <string>        uniqueLabelValues = new HashSet <string>();

            // Sample at most 1000 rows to collect the distinct label values.
            data = data.Take(1000);
            using (var cursor = data.GetRowCursor(index => index == label.ColumnIndex))
            {
                ValueGetter <DvText> getter = DataViewUtils.PopulateGetterArray(cursor, new List <int> {
                    label.ColumnIndex
                })[0];
                while (cursor.MoveNext())
                {
                    var currentLabel = new DvText();
                    getter(ref currentLabel);
                    string currentLabelString = currentLabel.ToString();
                    // HashSet<T>.Add is a no-op on duplicates, so a separate Contains check
                    // would only be a redundant second lookup.
                    if (!string.IsNullOrEmpty(currentLabelString))
                    {
                        uniqueLabelValues.Add(currentLabelString);
                    }
                }
            }

            // A single label value suggests anomaly detection (positives are the anomalies).
            if (uniqueLabelValues.Count == 1)
            {
                return(typeof(SignatureAnomalyDetectorTrainer));
            }

            if (uniqueLabelValues.Count == 2)
            {
                return(typeof(SignatureBinaryClassifierTrainer));
            }

            if (uniqueLabelValues.Count > 2)
            {
                // Float labels that look continuous (out of [0, 50] or containing a decimal
                // point) indicate regression rather than a class index.
                if ((label.ItemKind == DataKind.R4) &&
                    uniqueLabelValues.Any(val =>
                {
                    float fVal;
                    return(float.TryParse(val, out fVal) && (fVal > 50 || fVal < 0 || val.Contains('.')));
                }))
                {
                    return(typeof(SignatureRegressorTrainer));
                }

                // Discrete labels: ranking when a group column exists, multiclass otherwise.
                if (label.ItemKind == DataKind.R4 ||
                    label.ItemKind == DataKind.TX ||
                    data.Schema.GetColumnType(label.ColumnIndex).IsKey)
                {
                    if (columns.Any(col => col.Purpose == ColumnPurpose.Group))
                    {
                        return(typeof(SignatureRankerTrainer));
                    }
                    else
                    {
                        return(typeof(SignatureMultiClassClassifierTrainer));
                    }
                }
            }

            return(null);
        }
 public IRowCursor CreateCursor(IChannelProvider provider, IRowCursor[] inputs)
 {
     // Merge the input cursors into a single cursor via the shared helper,
     // batching by the configured _batchSize.
     var consolidated = DataViewUtils.ConsolidateGeneric(provider, inputs, _batchSize);
     return consolidated;
 }
        /// <summary>
        /// Core XGBoost training: validates the input, converts the feature column
        /// into a DMatrix (dense or sparse), trains a booster and stores the raw
        /// model plus feature counts in the instance fields.
        /// </summary>
        private void TrainCore(IChannel ch, IProgressChannel pch, RoleMappedData data, TPredictor predictor)
        {
            // Input sanity checks.
            _host.AssertValue(ch);
            ch.CheckValue(data, nameof(data));

            ValidateTrainInput(ch, data);

            var featureColumns = data.Schema.GetColumns(RoleMappedSchema.ColumnRole.Feature);
            ch.Check(featureColumns.Count == 1, "Only one vector of features is allowed.");

            // The feature column must be a vector of known dimension.
            int featureIndex = data.Schema.Feature.Index;
            var featureType  = data.Schema.Schema.GetColumnType(featureIndex);
            ch.Assert(featureType.IsVector, "Feature must be a vector.");
            ch.Assert(featureType.VectorSize > 0, "Feature dimension must be known.");

            int       nbDim  = featureType.VectorSize;
            IDataView view   = data.Data;
            long      nbRows = DataViewUtils.ComputeRowCount(view);

            Float[] labels;
            uint[]  groupCount;
            DMatrix dtrain;
            // REVIEW xadupre: this can be avoided by using method XGDMatrixCreateFromDataIter from the XGBoost API.
            // XGBoost removes NaN values from a dense matrix and stores it in sparse format anyway.
            bool isDense   = DetectDensity(data);
            var  startTime = DateTime.Now;

            if (isDense)
            {
                dtrain = FillDenseMatrix(ch, nbDim, nbRows, data, out labels, out groupCount);
                ch.Info("Dense matrix created with nbFeatures={0} and nbRows={1} in {2}.", nbDim, nbRows, DateTime.Now - startTime);
            }
            else
            {
                dtrain = FillSparseMatrix(ch, nbDim, nbRows, data, out labels, out groupCount);
                ch.Info("Sparse matrix created with nbFeatures={0} and nbRows={1} in {2}.", nbDim, nbRows, DateTime.Now - startTime);
            }

            // Complete the option set with values derived from the data.
            var options = _args.ToDict(_host);
            UpdateXGBoostOptions(ch, options, labels, groupCount);

            // For multi class, the number of labels is required.
            ch.Assert(PredictionKind != PredictionKind.MultiClassClassification || options.ContainsKey("num_class"),
                      "XGBoost requires the number of classes to be specified in the parameters.");

            ch.Info("XGBoost objective={0}", options["objective"]);

            int     numTrees;
            Booster booster = WrappedXGBoostTraining.Train(ch, pch, out numTrees, options, dtrain,
                                                           numBoostRound: _args.numBoostRound,
                                                           obj: null, verboseEval: _args.verboseEval,
                                                           xgbModel: predictor == null ? null : predictor.GetBooster(),
                                                           saveBinaryDMatrix: _args.saveXGBoostDMatrixAsBinary);

            int totalTrees = booster.GetNumTrees();

            ch.Info("Training is complete. Number of added trees={0}, total={1}.", numTrees, totalTrees);

            // Persist the raw model and remember the feature counts seen on both sides.
            _model             = booster.SaveRaw();
            _nbFeaturesXGboost = (int)dtrain.GetNumCols();
            _nbFeaturesML      = nbDim;
        }