Exemplo n.º 1
0
        // Verifies that training, label listing, prediction and nearest-neighbour lookup
        // all work when the training data and the queries contain non-ASCII (Cyrillic) text.
        public void CanHandleUtf8()
        {
            // Training file is referenced twice (training + vocabulary extraction), so name it once.
            const string trainingFile = "data.rus.txt";
            const string labelPrefix = "__label__";

            using var fastText = new FastTextWrapper(loggerFactory: _loggerFactory);
            string outPath = Path.Combine(_tempDir, "rus");

            fastText.Supervised(trainingFile, outPath, new SupervisedArgs());

            // The trained model must expose exactly the two Cyrillic labels, unmangled.
            var labels = fastText.GetLabels();

            labels.Length.Should().Be(2);
            labels.Should().Contain(new[] { "__label__оператор", "__label__выход" });

            // A Cyrillic query must be classified into the expected label.
            var pred = fastText.PredictSingle("Позови оператора");

            pred.Probability.Should().BeGreaterThan(0);
            pred.Label.Should().Be("__label__оператор");

            // Collect the distinct non-label tokens of the training file.
            // The label prefix is a fixed machine token, so compare ordinally rather than
            // with the current culture's linguistic rules.
            var sourceWords = File.ReadAllText(trainingFile)
                              .Split(new[] { " ", "\r\n", "\n" }, StringSplitOptions.RemoveEmptyEntries)
                              .Where(x => !x.StartsWith(labelPrefix, StringComparison.Ordinal))
                              .Distinct().ToArray();
            var nn = fastText.GetNearestNeighbours("оператор", 2);

            // Every returned neighbour must be a word that actually occurs in the training
            // file; a mis-decoded string would not be found there.
            nn.Length.Should().Be(2);
            sourceWords.Should().Contain(nn.Select(x => x.Label));
            foreach (var prediction in nn)
            {
                prediction.Probability.Should().BeGreaterThan(0);
            }
        }
Exemplo n.º 2
0
        // Trains a skip-gram and a CBOW model on the same unlabeled corpus and checks that
        // the nearest neighbours of "pot" differ substantially between those two models and
        // the model provided by the shared fixture.
        public void SkipgramAndCBowLearnDifferentRepresentations()
        {
            const string corpus = "cooking.train.nolabels.txt";

            // Each model gets its own output path; sharing one path would make the second
            // training run write over the first model's output files.
            using var sg = new FastTextWrapper(loggerFactory: _loggerFactory);
            string outSG = Path.Combine(_tempDir, "cooking.skipgram");

            sg.Unsupervised(UnsupervisedModel.SkipGram, corpus, outSG);

            using var cbow = new FastTextWrapper(loggerFactory: _loggerFactory);
            string outCbow = Path.Combine(_tempDir, "cooking.cbow");

            cbow.Unsupervised(UnsupervisedModel.CBow, corpus, outCbow);

            var nnSg   = sg.GetNearestNeighbours("pot", 10);
            var nnCbow = cbow.GetNearestNeighbours("pot", 10);
            // NOTE(review): presumably the fixture holds a supervised model (hence "Sup") - confirm.
            var nnSup  = _fixture.FastText.GetNearestNeighbours("pot", 10);

            // Asserts that fewer than half of the labels in `first` also appear in `second`.
            void CheckPair(Prediction[] first, Prediction[] second)
            {
                int samePredictions = first.Count(
                    prediction => second.Any(x => x.Label == prediction.Label));

                // We want less than a half of same predictions.
                samePredictions.Should().BeLessThan(first.Length / 2);
            }

            CheckPair(nnSg, nnCbow);
            CheckPair(nnSg, nnSup);
            CheckPair(nnCbow, nnSup);
        }