public void CanTrainCbowWithProgressCallback()
{
    // Trains a CBOW model on the cooking corpus with a progress callback attached, then
    // checks that the callback fired and that the model is loaded and written to disk.
    using var fastText = new FastTextWrapper(loggerFactory: _loggerFactory);
    string outPath = Path.Combine(_tempDir, "cooking");
    string binFile = outPath + ".bin";
    string vecFile = outPath + ".vec";

    int callbackInvocations = 0;
    var args = new UnsupervisedArgs
    {
        // Only the number of invocations matters here; the reported values are ignored.
        TrainProgressCallback = (progress, loss, wordsPerSec, lr, eta) => callbackInvocations++
    };

    fastText.Unsupervised(UnsupervisedModel.CBow, "cooking.train.nolabels.txt", outPath, args);

    callbackInvocations.Should().BeGreaterThan(0);
    fastText.IsModelReady().Should().BeTrue();
    fastText.GetModelDimension().Should().Be(100);
    fastText.ModelPath.Should().Be(binFile);
    File.Exists(binFile).Should().BeTrue();
    File.Exists(vecFile).Should().BeTrue();
}
public void CanTrainCbowModel()
{
    // Trains a CBOW model on the cooking corpus with default arguments and checks that
    // the model is loaded, has the expected dimension, and both output files exist.
    using var fastText = new FastTextWrapper(loggerFactory: _loggerFactory);
    string outPath = Path.Combine(_tempDir, "cooking");
    string binFile = outPath + ".bin";
    string vecFile = outPath + ".vec";

    fastText.Unsupervised(UnsupervisedModel.CBow, "cooking.train.nolabels.txt", outPath);

    fastText.IsModelReady().Should().BeTrue();
    fastText.GetModelDimension().Should().Be(100);
    fastText.ModelPath.Should().Be(binFile);
    File.Exists(binFile).Should().BeTrue();
    File.Exists(vecFile).Should().BeTrue();
}
public void SkipgramAndCBowLearnDifferentRepresentations()
{
    // Trains skipgram and CBOW models on the same corpus and checks that their nearest
    // neighbours for "pot" differ substantially from each other and from the supervised
    // fixture model.
    //
    // Each model writes to its own output path so the CBOW run cannot overwrite the
    // skipgram model's .bin/.vec files.
    using var sg = new FastTextWrapper(loggerFactory: _loggerFactory);
    string outSg = Path.Combine(_tempDir, "cooking_sg");
    sg.Unsupervised(UnsupervisedModel.SkipGram, "cooking.train.nolabels.txt", outSg);

    using var cbow = new FastTextWrapper(loggerFactory: _loggerFactory);
    string outCbow = Path.Combine(_tempDir, "cooking_cbow");
    cbow.Unsupervised(UnsupervisedModel.CBow, "cooking.train.nolabels.txt", outCbow);

    var nnSg = sg.GetNearestNeighbours("pot", 10);
    var nnCbow = cbow.GetNearestNeighbours("pot", 10);
    var nnSup = _fixture.FastText.GetNearestNeighbours("pot", 10);

    CheckPair(nnSg, nnCbow);
    CheckPair(nnSg, nnSup);
    CheckPair(nnCbow, nnSup);

    // Asserts that fewer than half of the labels in `first` also appear in `second`.
    static void CheckPair(Prediction[] first, Prediction[] second)
    {
        int samePredictions = first.Count(p => second.Any(x => x.Label == p.Label));

        // We want less than a half of same predictions.
        samePredictions.Should().BeLessThan(first.Length / 2);
    }
}