Esempio n. 1
0
        // This method does as many passes over the data as needed by the evaluator, and computes the metrics, outputting the
        // results in a dictionary from the metric kind (overall/per-fold/confusion matrix/PR-curves etc.), to a data view containing
        // the metric. If there are stratified metrics, an additional column is added to the data view containing the
        // stratification value as text in the format "column x = y".
        //
        // data:              the data view that is cursored over, once per pass.
        // schema:            handed to every active aggregator when a pass is initialized.
        // activeColsIndices: predicate over column indices selecting the columns the cursor must activate.
        // aggregator:        the main aggregator; it is registered below with stratification key 0.
        // dictionaries:      per-stratification aggregator dictionaries.
        //                    NOTE(review): Utils.Size is presumably null-safe, so null would be allowed here - confirm.
        private Dictionary <string, IDataView> ProcessData(IDataView data, RoleMappedSchema schema,
                                                           Func <int, bool> activeColsIndices, TAgg aggregator, AggregatorDictionaryBase[] dictionaries)
        {
            // Ends the current pass on the main aggregator and on every stratified aggregator, and reports
            // whether any of them asked for another pass. The non-short-circuiting '|=' guarantees that
            // FinishPass is invoked on every aggregator, even after one has already requested another pass.
            Func <bool> finishPass =
                () =>
            {
                var need = aggregator.FinishPass();
                foreach (var agg in dictionaries.SelectMany(dict => dict.GetAll()))
                {
                    need |= agg.FinishPass();
                }
                return(need);
            };

            // The main aggregator decides whether a first pass over the data is needed at all.
            bool needMorePasses = aggregator.Start();

            // The schema columns whose index is accepted by the caller-supplied predicate.
            var activeCols = data.Schema.Where(x => activeColsIndices(x.Index));

            // REVIEW: Add progress reporting.
            while (needMorePasses)
            {
                // Every pass gets a fresh cursor over the active columns.
                using (var cursor = data.GetRowCursor(activeCols))
                {
                    // Bind the aggregators that are still active to the cursor of this pass.
                    if (aggregator.IsActive())
                    {
                        aggregator.InitializeNextPass(cursor, schema);
                    }
                    for (int i = 0; i < Utils.Size(dictionaries); i++)
                    {
                        // The dictionary is reset against the new cursor before its aggregators are initialized.
                        dictionaries[i].Reset(cursor);

                        foreach (var agg in dictionaries[i].GetAll())
                        {
                            if (agg.IsActive())
                            {
                                agg.InitializeNextPass(cursor, schema);
                            }
                        }
                    }
                    // Feed each row to the main aggregator and to one aggregator per dictionary.
                    while (cursor.MoveNext())
                    {
                        if (aggregator.IsActive())
                        {
                            aggregator.ProcessRow();
                        }
                        for (int i = 0; i < Utils.Size(dictionaries); i++)
                        {
                            // NOTE(review): Get() presumably returns the aggregator matching the current row's
                            // stratification value - confirm against AggregatorDictionaryBase.
                            var agg = dictionaries[i].Get();
                            if (agg.IsActive())
                            {
                                agg.ProcessRow();
                            }
                        }
                    }
                }
                // Another pass is made only if at least one aggregator requested it.
                needMorePasses = finishPass();
            }

            // addAgg registers one aggregator's results; consolidate produces the final dictionary.
            Action <uint, ReadOnlyMemory <char>, TAgg> addAgg;
            Func <Dictionary <string, IDataView> >     consolidate;

            GetAggregatorConsolidationFuncs(aggregator, dictionaries, out addAgg, out consolidate);

            uint stratColKey = 0;

            // The main aggregator is registered with key 0 and a default (empty) stratification value.
            addAgg(stratColKey, default, aggregator);
        /// <summary>
        /// Loads the "cifar_saved_model" TensorFlow model, feeds it the images listed in images/images.tsv
        /// (loaded, resized to the model's input dimensions and converted to pixels), and verifies that every
        /// scored row yields a 10-element "Output" vector and that exactly four rows are produced.
        /// </summary>
        [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only
        public void TensorFlowTransformCifarSavedModel()
        {
            var modelLocation = "cifar_saved_model";

            var mlContext = new MLContext(seed: 1, conc: 1);
            var tensorFlowModel = TensorFlowUtils.LoadTensorFlowModel(mlContext, modelLocation);
            var schema = tensorFlowModel.GetInputSchema();

            // The image size is taken from the model's "Input" column rather than being hard-coded.
            Assert.True(schema.TryGetColumnIndex("Input", out int column));
            var type = (VectorType)schema[column].Type;
            var imageHeight = type.Dimensions[0];
            var imageWidth = type.Dimensions[1];

            var dataFile = GetDataPath("images/images.tsv");
            var imageFolder = Path.GetDirectoryName(dataFile);
            var data = mlContext.Data.ReadFromTextFile(dataFile, columns: new[]
            {
                new TextLoader.Column("ImagePath", DataKind.TX, 0),
                new TextLoader.Column("Name", DataKind.TX, 1),
            });

            // ImagePath -> ImageReal: load the images relative to the folder that holds the .tsv file.
            var images = ImageLoaderTransform.Create(mlContext, new ImageLoaderTransform.Arguments()
            {
                Column = new ImageLoaderTransform.Column[1]
                {
                    new ImageLoaderTransform.Column() { Source = "ImagePath", Name = "ImageReal" }
                },
                ImageFolder = imageFolder
            }, data);

            // ImageReal -> ImageCropped: resize to the dimensions the model expects.
            var cropped = ImageResizerTransform.Create(mlContext, new ImageResizerTransform.Arguments()
            {
                Column = new ImageResizerTransform.Column[1]
                {
                    new ImageResizerTransform.Column()
                    {
                        Source = "ImageReal",
                        Name = "ImageCropped",
                        ImageHeight = imageHeight,
                        ImageWidth = imageWidth,
                        Resizing = ImageResizerTransform.ResizingKind.IsoCrop
                    }
                }
            }, images);

            // ImageCropped -> Input: extract the pixel values into the column the model reads.
            var pixels = ImagePixelExtractorTransform.Create(mlContext, new ImagePixelExtractorTransform.Arguments()
            {
                Column = new ImagePixelExtractorTransform.Column[1]
                {
                    new ImagePixelExtractorTransform.Column()
                    {
                        Source = "ImageCropped",
                        Name = "Input",
                        UseAlpha = false,
                        InterleaveArgb = true
                    }
                }
            }, cropped);

            IDataView trans = TensorFlowTransform.Create(mlContext, pixels, tensorFlowModel, new[] { "Output" }, new[] { "Input" });

            // Fail right here if the transform did not produce an "Output" column. Ignoring the lookup
            // result would leave the index at its default of 0 and silently read an unrelated column.
            Assert.True(trans.Schema.TryGetColumnIndex("Output", out int output));
            using (var cursor = trans.GetRowCursor(col => col == output))
            {
                var buffer = default(VBuffer<float>);
                var getter = cursor.GetGetter<VBuffer<float>>(output);
                var numRows = 0;
                while (cursor.MoveNext())
                {
                    getter(ref buffer);
                    Assert.Equal(10, buffer.Length);
                    numRows++;
                }
                Assert.Equal(4, numRows);
            }
        }
Esempio n. 3
0
        // Builds one n-gram dictionary (SequencePool) per column from a single cursor over trainingData.
        // For columns whose weighting requires IDF, the number of documents containing each n-gram is
        // accumulated during the pass and then converted in place to log(totalDocs / documentCount);
        // those values are handed back through invDocFreqs. Returns the per-column pools.
        private SequencePool[] Train(Arguments args, IDataView trainingData, out double[][] invDocFreqs)
        {
            // termLimits[col][n - 1] is the maximum number of n-grams to store in the dictionary of
            // column col, for n from 1 up to that column's ngramLength.
            var termLimits = new int[Infos.Length][];

            for (int col = 0; col < Infos.Length; col++)
            {
                var all = args.Column[col].AllLengths ?? args.AllLengths;
                var ngramLength = _exes[col].NgramLength;
                var maxNumTerms = Utils.Size(args.Column[col].MaxNumTerms) > 0 ? args.Column[col].MaxNumTerms : args.MaxNumTerms;
                var numTerms = Utils.Size(maxNumTerms);

                if (all)
                {
                    // Every level gets a limit: the explicitly given ones first, then the last given
                    // value (or the default when none was given) for the remaining levels.
                    Host.CheckUserArg(numTerms <= ngramLength, nameof(args.MaxNumTerms));
                    Host.CheckUserArg(
                        numTerms == 0 || (maxNumTerms.All(limit => limit >= 0) && maxNumTerms[maxNumTerms.Length - 1] > 0),
                        nameof(args.MaxNumTerms));

                    var fallback = numTerms == 0 ? Arguments.DefaultMaxTerms : maxNumTerms[maxNumTerms.Length - 1];
                    termLimits[col] = Utils.BuildArray(ngramLength,
                        level => level < numTerms ? maxNumTerms[level] : fallback);
                }
                else
                {
                    // Only the longest n-grams are kept; all shorter levels keep a limit of zero.
                    Host.CheckUserArg(numTerms == 0 || (numTerms == 1 && maxNumTerms[0] > 0), nameof(args.MaxNumTerms));

                    var longestOnly = new int[ngramLength];
                    longestOnly[ngramLength - 1] = numTerms == 0 ? Arguments.DefaultMaxTerms : maxNumTerms[0];
                    termLimits[col] = longestOnly;
                }
            }

            var builders = new NgramBufferBuilder[Infos.Length];
            var srcGetters = new ValueGetter<VBuffer<uint>>[Infos.Length];
            var srcBuffers = new VBuffer<uint>[Infos.Length];

            // gramCounts[col][n - 1] is the number of n-grams currently in the pool of column col.
            var gramCounts = new int[Infos.Length][];
            var pools = new SequencePool[Infos.Length];

            // Only the source columns have to be active on the cursor.
            var isSourceColumn = new bool[trainingData.Schema.ColumnCount];
            for (int col = 0; col < Infos.Length; col++)
            {
                isSourceColumn[Infos[col].Source] = true;
            }

            using (var cursor = trainingData.GetRowCursor(index => isSourceColumn[index]))
            using (var progress = Host.StartProgressChannel("Building n-gram dictionary"))
            {
                for (int col = 0; col < Infos.Length; col++)
                {
                    Host.Assert(Infos[col].TypeSrc.IsVector && Infos[col].TypeSrc.ItemType.IsKey);
                    var ngramLength = _exes[col].NgramLength;
                    var skipLength = _exes[col].SkipLength;

                    srcGetters[col] = RowCursorUtils.GetVecGetterAs<uint>(NumberType.U4, cursor, Infos[col].Source);
                    srcBuffers[col] = default(VBuffer<uint>);
                    gramCounts[col] = new int[ngramLength];
                    pools[col] = new SequencePool();

                    // Note: GetNgramIdFinderAdd controls how many n-grams of each length get added
                    // (through termLimits[col]), so the builder's slot limit is set to the maximum.
                    builders[col] = new NgramBufferBuilder(
                        ngramLength,
                        skipLength,
                        Utils.ArrayMaxSize,
                        GetNgramIdFinderAdd(gramCounts[col], termLimits[col], pools[col], _exes[col].RequireIdf(), Host));
                }

                int fullColumnCount = 0;
                var isColumnFull = new bool[Infos.Length];

                invDocFreqs = new double[Infos.Length][];

                long totalDocs = 0;
                double rowCount = trainingData.GetRowCount(true) ?? double.NaN;
                var resultBuffers = new VBuffer<float>[Infos.Length];
                progress.SetHeader(
                    new ProgressHeader(new[] { "Total n-grams" }, new[] { "documents" }),
                    entry => entry.SetProgress(0, totalDocs, rowCount));

                while (fullColumnCount < Infos.Length && cursor.MoveNext())
                {
                    totalDocs++;
                    for (int col = 0; col < Infos.Length; col++)
                    {
                        srcGetters[col](ref srcBuffers[col]);

                        // A key count of zero is replaced by the largest possible key value.
                        var declaredKeyCount = (uint)Infos[col].TypeSrc.ItemType.KeyCount;
                        var keyCount = declaredKeyCount == 0 ? uint.MaxValue : declaredKeyCount;

                        if (!isColumnFull[col])
                        {
                            // With IDF the builder is reset for every row, so that its result
                            // describes the current document only.
                            if (_exes[col].RequireIdf())
                            {
                                builders[col].Reset();
                            }

                            builders[col].AddNgrams(ref srcBuffers[col], 0, keyCount);

                            if (_exes[col].RequireIdf())
                            {
                                int totalNgrams = gramCounts[col].Sum();
                                Utils.EnsureSize(ref invDocFreqs[col], totalNgrams);
                                builders[col].GetResult(ref resultBuffers[col]);

                                // Each n-gram seen at least once in this document adds one to its document count.
                                double[] docCounts = invDocFreqs[col];
                                foreach (var item in resultBuffers[col].Items())
                                {
                                    if (item.Value >= 1)
                                    {
                                        docCounts[item.Key]++;
                                    }
                                }
                            }
                        }

                        AssertValid(gramCounts[col], termLimits[col], pools[col]);
                    }
                }

                progress.Checkpoint(gramCounts.Sum(perLevel => perLevel.Sum()), totalDocs);

                // Convert the accumulated document counts into inverse document frequencies.
                for (int col = 0; col < Infos.Length; col++)
                {
                    double[] docCounts = invDocFreqs[col];
                    int length = Utils.Size(docCounts);
                    for (int slot = 0; slot < length; slot++)
                    {
                        if (docCounts[slot] != 0)
                        {
                            docCounts[slot] = Math.Log(totalDocs / docCounts[slot]);
                        }
                    }
                }

                // Record, per column, which n-gram lengths ended up with at least one entry.
                for (int col = 0; col < Infos.Length; col++)
                {
                    AssertValid(gramCounts[col], termLimits[col], pools[col]);

                    int ngramLength = _exes[col].NgramLength;
                    for (int level = 0; level < ngramLength; level++)
                    {
                        _exes[col].NonEmptyLevels[level] = gramCounts[col][level] > 0;
                    }
                }

                return pools;
            }
        }