/// <summary>
        /// Builds a two-row streaming view, checks it can be read back directly,
        /// then wraps it in an <c>ExtendedCacheTransform</c> configured with the
        /// requested thread count and async flag and checks the cached view still
        /// yields two rows.
        /// </summary>
        /// <param name="nt">
        /// Value assigned to the cache transform's <c>numTheads</c> argument. When it
        /// is 1 the test environment is created with <c>conc: 1</c>, otherwise with
        /// <c>conc: 0</c>.
        /// </param>
        /// <param name="async">Value assigned to the cache transform's <c>async</c> argument.</param>
        /// <exception cref="Exception">
        /// Thrown when the direct read does not yield exactly two rows whose column 1
        /// values are 1 then 0, or when the cached read does not yield exactly two rows.
        /// </exception>
        static void TestCacheTransformSimple(int nt, bool async)
        {
            using (var host = EnvHelper.NewTestEnvironment(conc: nt == 1 ? 1 : 0))
            {
                var inputs = new InputOutput[] {
                    new InputOutput()
                    {
                        X = new float[] { 0, 1 }, Y = 1
                    },
                    new InputOutput()
                    {
                        X = new float[] { 0, 1 }, Y = 0
                    }
                };

                var data = host.CreateStreamingDataView(inputs);

                // First pass: read column 1 straight from the streaming view.
                // The expected values (1 then 0) match the Y fields of the inputs above.
                using (var cursor = data.GetRowCursor(i => true))
                {
                    var sortedValues     = new List <int>();
                    var sortColumnGetter = cursor.GetGetter <int>(1);
                    while (cursor.MoveNext())
                    {
                        int got = 0;
                        sortColumnGetter(ref got);
                        sortedValues.Add(got);
                    }
                    if (sortedValues.Count != 2)
                    {
                        throw new Exception($"Streaming view: expected 2 rows but read {sortedValues.Count}.");
                    }
                    if (sortedValues[0] != 1)
                    {
                        throw new Exception($"Streaming view: expected value 1 in row 0 but got {sortedValues[0]}.");
                    }
                    if (sortedValues[1] != 0)
                    {
                        throw new Exception($"Streaming view: expected value 0 in row 1 but got {sortedValues[1]}.");
                    }
                }

                var args = new ExtendedCacheTransform.Arguments {
                    numTheads = nt, async = async
                };
                var transformedData = new ExtendedCacheTransform(host, args, data);
                var lastTransform   = transformedData;

                // NOTE(review): the result of CreateMap is discarded, so the cursor below
                // reads the cache transform and not the lambda map. Confirm whether
                // lastTransform was meant to be the mapped view.
                LambdaTransform.CreateMap <InputOutput, InputOutput, EnvHelper.EmptyState>(host, data,
                                                                                           (input, output, state) =>
                {
                    output.X = input.X;
                    output.Y = input.Y;
                }, (EnvHelper.EmptyState state) => { });

                // Second pass: read column 1 through the cache transform. Only the row
                // count is verified here, not the values or their order.
                using (var cursor = lastTransform.GetRowCursor(i => true))
                {
                    var sortedValues     = new List <int>();
                    var sortColumnGetter = cursor.GetGetter <int>(1);
                    while (cursor.MoveNext())
                    {
                        int got = 0;
                        sortColumnGetter(ref got);
                        sortedValues.Add(got);
                    }
                    if (sortedValues.Count != 2)
                    {
                        throw new Exception($"Cached view (nt={nt}, async={async}): expected 2 rows but read {sortedValues.Count}.");
                    }
                }
            }
        }
        /// <summary>
        /// Sets up the test data pipeline and the prediction engine for the requested
        /// <paramref name="engine"/>, then records the elapsed time of
        /// <paramref name="ncall"/> measurement rounds.
        /// The prediction loops are currently commented out, so every recorded duration
        /// covers an empty round and every score array is empty.
        /// </summary>
        /// <param name="conc">Concurrency value passed to the test environment.</param>
        /// <param name="strategy">
        /// When it contains "extcache" the data is wrapped in an <c>ExtendedCacheTransform</c>,
        /// otherwise in a <c>CacheDataView</c>. It is also tested for "array" inside each round.
        /// </param>
        /// <param name="engine">Either "mlnet" or "scikit".</param>
        /// <param name="scorer">Scorer used by the "scikit" engine only.</param>
        /// <param name="trscorer">Transformer used by the "mlnet" engine only.</param>
        /// <param name="ncall">Number of measurement rounds to record.</param>
        /// <returns>
        /// One tuple per round: (repetition count, which is always 1; elapsed time;
        /// 1-based round index; collected scores).
        /// </returns>
        /// <exception cref="NotImplementedException">Thrown when <paramref name="engine"/> is neither "mlnet" nor "scikit".</exception>
        private List <Tuple <int, TimeSpan, int, float[]> > _MeasureTime(int conc,
                                                                         string strategy, string engine, IDataScorerTransform scorer, ITransformer trscorer, int ncall)
        {
            var testFilename = FileHelper.GetTestFile("wikipedia-detox-250-line-test.tsv");
            var results      = new List <Tuple <int, TimeSpan, int, float[]> >();

            // The environment is intentionally not wrapped in a using statement
            // (the using was disabled in this method).
            var env = EnvHelper.NewTestEnvironment(seed: 1, conc: conc);

            // Reject unknown engines before any data is loaded.
            bool useMlnet = engine == "mlnet";
            if (!useMlnet && engine != "scikit")
            {
                throw new NotImplementedException($"Unknown engine '{engine}'.");
            }

            // Both engines read the same two columns from the same tab-separated file.
            var loaderOptions = new TextLoader.Options
            {
                Separators = new[] { '\t' },
                HasHeader  = true,
                Columns    = new[]
                {
                    new TextLoader.Column("Label", DataKind.Boolean, 0),
                    new TextLoader.Column("SentimentText", DataKind.String, 1)
                }
            };

            // Take a couple examples out of the test data and run predictions on top.
            var       testLoader = new TextLoader(env, loaderOptions).Load(new MultiFileSource(testFilename));
            IDataView cache      = strategy.Contains("extcache")
                ? (IDataView) new ExtendedCacheTransform(env, new ExtendedCacheTransform.Arguments(), testLoader)
                : new CacheDataView(env, testLoader, new[] { 0, 1 });

            // Number of passes over the data inside one measured round.
            const int repeat = 1;

            if (useMlnet)
            {
                // Disabled: var testData = cache.AsEnumerable<SentimentDataBoolFloat>(env, false);
                // Disabled: var testDataArray = cache.AsEnumerable<SentimentDataBoolFloat>(env, false).ToArray();
                var model = ComponentCreation.CreatePredictionEngine <SentimentDataBoolFloat, SentimentPrediction>(env, trscorer);
                var watch = new Stopwatch();
                for (int call = 1; call <= ncall; ++call)
                {
                    var scores = new List <float>();
                    watch.Restart();
                    for (int pass = 0; pass < repeat; ++pass)
                    {
                        if (strategy.Contains("array"))
                        {
                            // Disabled:
                            // foreach (var input in testDataArray)
                            //     pred.Add(model.Predict(input).Score);
                        }
                        else
                        {
                            // Disabled:
                            // foreach (var input in testData)
                            //     pred.Add(model.Predict(input).Score);
                        }
                    }
                    watch.Stop();
                    results.Add(Tuple.Create(repeat, watch.Elapsed, call, scores.ToArray()));
                }
            }
            else
            {
                // Disabled: var testData = cache.AsEnumerable<SentimentDataBool>(env, false);
                // Disabled: var testDataArray = cache.AsEnumerable<SentimentDataBool>(env, false).ToArray();

                // The scorer schema must expose the three binary classification outputs.
                string allSchema = SchemaHelper.ToString(scorer.Schema);
                Assert.IsTrue(allSchema.Contains("PredictedLabel:Bool:4; Score:R4:5; Probability:R4:6"));
                var model  = new ValueMapperPredictionEngine <SentimentDataBool>(env, scorer);
                var output = new ValueMapperPredictionEngine <SentimentDataBool> .PredictionTypeForBinaryClassification();

                var watch = new Stopwatch();
                for (int call = 1; call <= ncall; ++call)
                {
                    var scores = new List <float>();
                    watch.Restart();
                    for (int pass = 0; pass < repeat; ++pass)
                    {
                        if (strategy.Contains("array"))
                        {
                            // Disabled:
                            // foreach (var input in testDataArray)
                            // {
                            //     model.Predict(input, ref output);
                            //     pred.Add(output.Score);
                            // }
                        }
                        else
                        {
                            // Disabled:
                            // foreach (var input in testData)
                            // {
                            //     model.Predict(input, ref output);
                            //     pred.Add(output.Score);
                            // }
                        }
                    }
                    watch.Stop();
                    results.Add(Tuple.Create(repeat, watch.Elapsed, call, scores.ToArray()));
                }
            }

            return results;
        }
        /// <summary>
        /// Loads the test file, caches it, and measures how long the selected
        /// <paramref name="engine"/> takes to predict every row of the cached data
        /// <paramref name="N"/> times, repeated for <paramref name="ncall"/> rounds.
        /// </summary>
        /// <param name="conc">Concurrency value passed to the test environment and to the "scikit" prediction engine.</param>
        /// <param name="engine">Either "mlnet" or "scikit".</param>
        /// <param name="scorer">Scorer used by the "scikit" engine only.</param>
        /// <param name="transformer">Transformer used by the "mlnet" engine only.</param>
        /// <param name="N">Number of passes over the data inside one measured round.</param>
        /// <param name="ncall">Number of measured rounds.</param>
        /// <param name="cacheScikit">
        /// When true the data is cached with an <c>ExtendedCacheTransform</c>,
        /// otherwise with a <c>CacheDataView</c>.
        /// </param>
        /// <returns>One tuple per round: (<paramref name="N"/>, elapsed time, 1-based round index).</returns>
        /// <exception cref="NotImplementedException">Thrown when <paramref name="engine"/> is neither "mlnet" nor "scikit".</exception>
        private static List <Tuple <int, TimeSpan, int> > _MeasureTime(int conc,
                                                                       string engine, IDataScorerTransform scorer, ITransformer transformer,
                                                                       int N, int ncall, bool cacheScikit)
        {
            var loaderArgs = new TextLoader.Arguments
            {
                Separator = "tab",
                HasHeader = true,
                Column    = new[]
                {
                    new TextLoader.Column("Label", DataKind.BL, 0),
                    new TextLoader.Column("SentimentText", DataKind.Text, 1)
                }
            };

            var testFilename = FileHelper.GetTestFile("wikipedia-detox-250-line-test.tsv");
            var results      = new List <Tuple <int, TimeSpan, int> >();

            using (var env = EnvHelper.NewTestEnvironment(seed: 1, conc: conc))
            {
                // Take a couple examples out of the test data and run predictions on top.
                var       testLoader = TextLoader.ReadFile(env, loaderArgs, new MultiFileSource(testFilename));
                IDataView cache      = cacheScikit
                    ? (IDataView) new ExtendedCacheTransform(env, new ExtendedCacheTransform.Arguments(), testLoader)
                    : new CacheDataView(env, testLoader, new[] { 0, 1 });
                var testData = cache.AsEnumerable <SentimentData>(env, false);

                // Unknown engines are rejected only after the data pipeline has been built.
                bool useMlnet = engine == "mlnet";
                if (!useMlnet && engine != "scikit")
                {
                    throw new NotImplementedException($"Unknown engine '{engine}'.");
                }

                Console.WriteLine("engine={0} N={1} ncall={2} cacheScikit={3}", engine, N, ncall, cacheScikit);

                if (useMlnet)
                {
                    var fct   = transformer.MakePredictionFunction <SentimentData, SentimentPrediction>(env);
                    var watch = new Stopwatch();
                    for (int call = 1; call <= ncall; ++call)
                    {
                        watch.Restart();
                        for (int pass = 0; pass < N; ++pass)
                        {
                            foreach (var row in testData)
                            {
                                fct.Predict(row);
                            }
                        }
                        watch.Stop();
                        results.Add(Tuple.Create(N, watch.Elapsed, call));
                    }
                }
                else
                {
                    var model  = new ValueMapperPredictionEngine <SentimentData>(env, scorer, conc: conc);
                    var output = new ValueMapperPredictionEngine <SentimentData> .PredictionTypeForBinaryClassification();

                    var watch = new Stopwatch();
                    for (int call = 1; call <= ncall; ++call)
                    {
                        watch.Restart();
                        for (int pass = 0; pass < N; ++pass)
                        {
                            foreach (var row in testData)
                            {
                                // The same output instance is reused for every prediction.
                                model.Predict(row, ref output);
                            }
                        }
                        watch.Stop();
                        results.Add(Tuple.Create(N, watch.Elapsed, call));
                    }
                }
            }

            return results;
        }
예제 #4
0
        IDataTransform AppendToPipeline(IDataView input)
        {
            IDataView current = input;

            if (_shuffleInput)
            {
                var args1 = new RowShufflingTransformer.Arguments()
                {
                    ForceShuffle     = false,
                    ForceShuffleSeed = _seedShuffle,
                    PoolRows         = _poolRows,
                    PoolOnly         = false,
                };
                current = new RowShufflingTransformer(Host, args1, current);
            }

            // We generate a random number.
            var columnName = current.Schema.GetTempColumnName();
            var args2      = new GenerateNumberTransform.Arguments()
            {
                Column = new GenerateNumberTransform.Column[] { new GenerateNumberTransform.Column()
                                                                {
                                                                    Name = columnName
                                                                } },
                Seed = _seed ?? 42
            };
            IDataTransform currentTr = new GenerateNumberTransform(Host, args2, current);

            // We convert this random number into a part.
            var cRatios = new float[_ratios.Length];

            cRatios[0] = 0;
            for (int i = 1; i < _ratios.Length; ++i)
            {
                cRatios[i] = cRatios[i - 1] + _ratios[i - 1];
            }

            ValueMapper <float, int> mapper = (in float src, ref int dst) =>
            {
                for (int i = cRatios.Length - 1; i > 0; --i)
                {
                    if (src >= cRatios[i])
                    {
                        dst = i;
                        return;
                    }
                }
                dst = 0;
            };

            // Get location of columnName

            int index;

            currentTr.Schema.TryGetColumnIndex(columnName, out index);
            var ct   = currentTr.Schema.GetColumnType(index);
            var view = LambdaColumnMapper.Create(Host, "Key to part mapper", currentTr,
                                                 columnName, _newColumn, ct, NumberType.I4, mapper);

            // We cache the result to avoid the pipeline to change the random number.
            var args3 = new ExtendedCacheTransform.Arguments()
            {
                inDataFrame = string.IsNullOrEmpty(_cacheFile),
                numTheads   = _numThreads,
                cacheFile   = _cacheFile,
                reuse       = _reuse,
            };

            currentTr = new ExtendedCacheTransform(Host, args3, view);

            // Removing the temporary column.
            var finalTr     = ColumnSelectingTransformer.CreateDrop(Host, currentTr, new string[] { columnName });
            var taggedViews = new List <Tuple <string, ITaggedDataView> >();

            // filenames
            if (_filenames != null || _tags != null)
            {
                int nbf = _filenames == null ? 0 : _filenames.Length;
                if (nbf > 0 && nbf != _ratios.Length)
                {
                    throw Host.Except("Differen number of filenames and ratios.");
                }
                int nbt = _tags == null ? 0 : _tags.Length;
                if (nbt > 0 && nbt != _ratios.Length)
                {
                    throw Host.Except("Differen number of filenames and ratios.");
                }
                int nb = Math.Max(nbf, nbt);

                using (var ch = Host.Start("Split the datasets and stores each part."))
                {
                    for (int i = 0; i < nb; ++i)
                    {
                        if (_filenames == null || !_filenames.Any())
                        {
                            ch.Info("Create part {0}: {1} (tag: {2})", i + 1, _ratios[i], _tags[i]);
                        }
                        else
                        {
                            ch.Info("Create part {0}: {1} (file: {2})", i + 1, _ratios[i], _filenames[i]);
                        }
                        var ar1 = new RangeFilter.Arguments()
                        {
                            Column = _newColumn, Min = i, Max = i, IncludeMax = true
                        };
                        int pardId   = i;
                        var filtView = LambdaFilter.Create <int>(Host, string.Format("Select part {0}", i), currentTr,
                                                                 _newColumn, NumberType.I4,
                                                                 (in int part) => { return(part.Equals(pardId)); });