public void SortFeaturesByCountWhenApplyingCountLimit()
        {
            // Fits one vectorizer with MaxFeatures = 5 and obtains a second one from
            // GetFittedVectorizer (presumably fitted without a limit - confirm against that helper),
            // then checks that the largest counts produced by both transforms are the same.

            // arrange
            var limitedSettings = new CountVectorizerSettings { MaxFeatures = 5 };

            var target = new CountVectorizer(limitedSettings);
            var corpus = GetTrainingData();

            target.Fit(corpus);

            var referenceVectorizer = GetFittedVectorizer(corpus);

            // Flattens every per-document count, sorts descending and keeps the first
            // MaxFeatures values. The query is returned unevaluated, so it only runs
            // when the assertion enumerates it.
            IEnumerable<uint> TopCounts(IEnumerable<IDictionary<string, uint>> documents)
            {
                return documents
                    .SelectMany(document => document.Values)
                    .OrderByDescending(count => count)
                    .Take((int)limitedSettings.MaxFeatures);
            }

            var expected = TopCounts(referenceVectorizer.Transform(corpus));

            // act
            var transformed = target.Transform(corpus);

            // assert
            var actual = TopCounts(transformed);

            Assert.Equal(expected, actual);
        }
        public void CanLimitFeaturesCount()
        {
            // Checks that fitting with MaxFeatures = 5 leaves exactly five entries
            // in the vectorizer's vocabulary.

            // arrange
            var settings = new CountVectorizerSettings
            {
                MaxFeatures = 5
            };

            var target       = new CountVectorizer(settings);
            var trainingData = GetTrainingData();

            // act
            target.Fit(trainingData);

            // assert
            // Assert.Equal reports the expected and actual counts on failure, whereas
            // Assert.True(count == 5) would only report "False".
            Assert.Equal(5, target.Vocabulary.Count());
        }
        public void CanFit()
        {
            // Checks that Fit returns the very instance it was called on and that
            // the returned object exposes a non-null Vocabulary.

            // arrange
            var target = new CountVectorizer();
            var documents = new[]
            {
                "Some cool text",
                "Another cool text"
            };

            // act
            var fitted = target.Fit(documents);

            // assert
            Assert.Same(target, fitted);
            Assert.NotNull(fitted.Vocabulary);
        }
        // Example no. 4
        // 0
        public void TestCountVectorizer()
        {
            // Configures a CountVectorizer through its fluent setters, fits it on a
            // generated DataFrame, and checks that every getter returns the configured
            // value, that a save/load round trip keeps the Uid, and that the textual
            // descriptions are non-empty.

            const int    expectedVocabSize = 10000;
            const double expectedMinTf     = 10;
            const double expectedMinDf     = 1;
            // Binary is never set below, so this is the value GetBinary is expected
            // to report for an untouched vectorizer.
            const bool   expectedBinary    = false;
            const string inputName         = "input";
            const string outputName        = "output";

            string query = "SELECT array('hello', 'I', 'AM', 'a', 'string', 'TO', " +
                           "'TOKENIZE') as input from range(100)";
            DataFrame input = _spark.Sql(query);

            var countVectorizer = new CountVectorizer();

            countVectorizer
                .SetInputCol(inputName)
                .SetOutputCol(outputName)
                .SetMinDF(expectedMinDf)
                .SetMinTF(expectedMinTf)
                .SetVocabSize(expectedVocabSize);

            // Fitting must produce a CountVectorizerModel.
            Assert.IsType<CountVectorizerModel>(countVectorizer.Fit(input));

            // Each getter must reflect what was configured above.
            Assert.Equal(inputName, countVectorizer.GetInputCol());
            Assert.Equal(outputName, countVectorizer.GetOutputCol());
            Assert.Equal(expectedMinDf, countVectorizer.GetMinDF());
            Assert.Equal(expectedMinTf, countVectorizer.GetMinTF());
            Assert.Equal(expectedVocabSize, countVectorizer.GetVocabSize());
            Assert.Equal(expectedBinary, countVectorizer.GetBinary());

            // Save/load round trip inside a temporary directory that is disposed afterwards.
            using (var scratch = new TemporaryDirectory())
            {
                string target = Path.Join(scratch.Path, "countVectorizer");
                countVectorizer.Save(target);

                CountVectorizer reloaded = CountVectorizer.Load(target);
                Assert.Equal(countVectorizer.Uid(), reloaded.Uid());
            }

            Assert.NotEmpty(countVectorizer.ExplainParams());
            Assert.NotEmpty(countVectorizer.ToString());
        }