Exemple #1
0
        public void CanTrainCbowWithProgressCallback()
        {
            // Trains an unsupervised CBOW model with a progress callback attached and
            // verifies both that the callback fired and that the model artifacts exist.
            using var wrapper = new FastTextWrapper(loggerFactory: _loggerFactory);

            string modelBase = Path.Combine(_tempDir, "cooking");
            string binFile = modelBase + ".bin";
            string vecFile = modelBase + ".vec";

            // Number of times the trainer invoked the progress callback.
            int progressReports = 0;

            var trainingArgs = new UnsupervisedArgs
            {
                // The reported values themselves are irrelevant here; only the invocation count matters.
                TrainProgressCallback = (progress, loss, wst, lr, eta) => progressReports++
            };

            wrapper.Unsupervised(UnsupervisedModel.CBow, "cooking.train.nolabels.txt", modelBase, trainingArgs);

            // The callback must have been called at least once during training.
            progressReports.Should().BeGreaterThan(0);

            // The wrapper should now hold a ready model pointing at the produced .bin file.
            wrapper.IsModelReady().Should().BeTrue();
            wrapper.GetModelDimension().Should().Be(100);
            wrapper.ModelPath.Should().Be(binFile);

            // Both output files are expected on disk.
            File.Exists(binFile).Should().BeTrue();
            File.Exists(vecFile).Should().BeTrue();
        }
Exemple #2
0
        public void CanTrainCbowModel()
        {
            // Trains an unsupervised CBOW model with default arguments and checks
            // the resulting wrapper state and the files written next to the output path.
            using var wrapper = new FastTextWrapper(loggerFactory: _loggerFactory);

            string modelBase = Path.Combine(_tempDir, "cooking");
            string binFile = modelBase + ".bin";
            string vecFile = modelBase + ".vec";

            wrapper.Unsupervised(UnsupervisedModel.CBow, "cooking.train.nolabels.txt", modelBase);

            // The wrapper should report a ready model that points at the .bin file.
            wrapper.IsModelReady().Should().BeTrue();
            wrapper.GetModelDimension().Should().Be(100);
            wrapper.ModelPath.Should().Be(binFile);

            // Both output files are expected on disk.
            File.Exists(binFile).Should().BeTrue();
            File.Exists(vecFile).Should().BeTrue();
        }
Exemple #3
0
        public void SkipgramAndCBowLearnDifferentRepresentations()
        {
            // Trains a skip-gram and a CBOW model on the same corpus and checks that their
            // nearest neighbours for "pot" mostly differ from each other and from the
            // neighbours returned by the fixture's model.

            // Each model gets its own output base name. Sharing a single path would make
            // the second training run overwrite the first model's files on disk.
            using var sg = new FastTextWrapper(loggerFactory: _loggerFactory);
            string outSG = Path.Combine(_tempDir, "cooking_skipgram");

            sg.Unsupervised(UnsupervisedModel.SkipGram, "cooking.train.nolabels.txt", outSG);

            using var cbow = new FastTextWrapper(loggerFactory: _loggerFactory);
            string outCbow = Path.Combine(_tempDir, "cooking_cbow");

            cbow.Unsupervised(UnsupervisedModel.CBow, "cooking.train.nolabels.txt", outCbow);

            var nnSg   = sg.GetNearestNeighbours("pot", 10);
            var nnCbow = cbow.GetNearestNeighbours("pot", 10);
            var nnSup  = _fixture.FastText.GetNearestNeighbours("pot", 10);

            // Asserts that fewer than half of the labels in `first` also appear in `second`.
            void CheckPair(Prediction[] first, Prediction[] second)
            {
                int samePredictions = first.Count(
                    prediction => second.Any(x => x.Label == prediction.Label));

                // Integer division: with 10 neighbours the overlap must be at most 4.
                samePredictions.Should().BeLessThan(first.Length / 2);
            }

            CheckPair(nnSg, nnCbow);
            CheckPair(nnSg, nnSup);
            CheckPair(nnCbow, nnSup);
        }