Beispiel #1
0
        public void CanTrainSupervised()
        {
            using var fastText = new FastTextWrapper(loggerFactory: _loggerFactory);
            string modelPrefix = Path.Combine(_tempDir, "cooking");
            string binPath     = modelPrefix + ".bin";
            string vecPath     = modelPrefix + ".vec";

            var supervisedArgs = new SupervisedArgs();
            var autotuneArgs   = new AutotuneArgs();

            // The trailing 'true' presumably enables the argument dump that DebugArgs.Load reads
            // back below — TODO confirm against the Supervised(...) signature.
            fastText.Supervised("cooking.train.txt", modelPrefix, supervisedArgs, autotuneArgs, true);

            // The trained model must be usable and written where requested.
            fastText.IsModelReady().Should().BeTrue();
            fastText.GetModelDimension().Should().Be(100);
            fastText.ModelPath.Should().Be(binPath);

            AssertLabels(fastText.GetLabels());

            File.Exists(binPath).Should().BeTrue();
            File.Exists(vecPath).Should().BeTrue();

            var dumped = DebugArgs.Load("_train.txt");

            // Both the External* and Converted* views of the dumped arguments must match the inputs.
            AssertSupervisedArgs(supervisedArgs, dumped.ExternalArgs);
            AssertSupervisedArgs(supervisedArgs, dumped.ConvertedArgs);
            AssertAutotuneArgs(autotuneArgs, dumped.ExternalTune);
            AssertAutotuneArgs(autotuneArgs, dumped.ConvertedTune);

            dumped.ExternalInput.Should().Be("cooking.train.txt");
            dumped.ConvertedInput.Should().Be("cooking.train.txt");
            dumped.ExternalOutput.Should().Be(modelPrefix);
            dumped.ConvertedOutput.Should().Be(modelPrefix);
        }
Beispiel #2
0
        public void CanTrainSupervisedWithProgressCallback()
        {
            using var fastText = new FastTextWrapper();
            string modelPrefix = Path.Combine(_tempDir, "cooking");
            int callbackInvocations = 0;

            // Only count progress reports; the callback's arguments are not inspected.
            var args = new SupervisedArgs
            {
                TrainProgressCallback = (progress, loss, wst, lr, eta) => callbackInvocations++
            };

            fastText.Supervised("cooking.train.txt", modelPrefix, args);

            // Training must have reported progress at least once and produced a usable model.
            callbackInvocations.Should().BeGreaterThan(0);
            fastText.IsModelReady().Should().BeTrue();
            fastText.GetModelDimension().Should().Be(100);

            string binPath = modelPrefix + ".bin";
            fastText.ModelPath.Should().Be(binPath);

            AssertLabels(fastText.GetLabels());

            File.Exists(binPath).Should().BeTrue();
            File.Exists(modelPrefix + ".vec").Should().BeTrue();
        }
Beispiel #3
0
        public void CanHandleUtf8()
        {
            // Trains on a Cyrillic data set and checks that labels, predictions and nearest
            // neighbours round-trip non-ASCII text intact.
            using var fastText = new FastTextWrapper(loggerFactory: _loggerFactory);
            string outPath = Path.Combine(_tempDir, "rus");

            fastText.Supervised("data.rus.txt", outPath, new SupervisedArgs());

            var labels = fastText.GetLabels();

            labels.Length.Should().Be(2);
            labels.Should().Contain(new[] { "__label__оператор", "__label__выход" });

            var pred = fastText.PredictSingle("Позови оператора");

            pred.Probability.Should().BeGreaterThan(0);
            pred.Label.Should().Be("__label__оператор");

            // Vocabulary of the training file: every whitespace-separated token that is not a label.
            // Tabs are separators too, because fastText's tokenizer splits on them.
            // The label prefix is an ASCII marker rather than linguistic text, so compare ordinally
            // instead of with the culture-sensitive default (CA1310).
            var sourceWords = File.ReadAllText("data.rus.txt")
                              .Split(new[] { " ", "\t", "\r\n", "\n" }, StringSplitOptions.RemoveEmptyEntries)
                              .Where(x => !x.StartsWith("__label__", StringComparison.Ordinal))
                              .Distinct().ToArray();
            var nn = fastText.GetNearestNeighbours("оператор", 2);

            nn.Length.Should().Be(2);
            sourceWords.Should().Contain(nn.Select(x => x.Label));
            foreach (var prediction in nn)
            {
                prediction.Probability.Should().BeGreaterThan(0);
            }
        }
Beispiel #4
0
 private static void Test(FastTextWrapper fastText)
 {
     // Smoke-exercises the main inference APIs on one cooking question; results are discarded.
     const string question = "Can I use a larger crockpot than the recipe calls for?";

     _ = fastText.GetLabels();
     _ = fastText.PredictSingle(question);
     _ = fastText.PredictMultiple(question, 4);
     _ = fastText.GetSentenceVector(question);
 }
Beispiel #5
0
        public void CanLoadSupervisedModel()
        {
            using var fastText = new FastTextWrapper(loggerFactory: _loggerFactory);

            // Reuse the model the shared fixture already trained instead of training a new one.
            string fixtureModel = _fixture.FastText.ModelPath;
            fastText.LoadModel(fixtureModel);

            // A freshly loaded model must be ready for inference, with dimension 100.
            fastText.IsModelReady().Should().BeTrue();
            fastText.GetModelDimension().Should().Be(100);

            AssertLabels(fastText.GetLabels());
        }
 private static void LoadModel()
 {
     // NOTE(review): machine-specific hard-coded model path; this sample only runs where it exists.
     const string modelFile = @"D:\__Models\cooking.bin";
     const string question  = "Can I use a larger crockpot than the recipe calls for?";

     // Disposed at the end of the method, exactly like the former using-block.
     using var fastText = new FastTextWrapper();
     fastText.LoadModel(modelFile);

     // Exercise the inference APIs; the results are intentionally unused.
     _ = fastText.GetLabels();
     _ = fastText.PredictSingle(question);
     _ = fastText.PredictMultiple(question, 4);
     _ = fastText.GetSentenceVector(question);
 }
        static void Main(string[] args)
        {
            // Trains a supervised model in a throw-away temp directory, validates it, and plots
            // precision/recall curves from both the managed result and the native debug dump.
            Log.Logger = new LoggerConfiguration()
                         .MinimumLevel.Debug()
                         .WriteTo.Console(theme: ConsoleTheme.None)
                         .CreateLogger();

            var log     = Log.ForContext <Program>();
            var tempDir = Path.Combine(Path.GetTempPath(), Guid.NewGuid().ToString("N"));

            Directory.CreateDirectory(tempDir);

            log.Information($"Temp dir: {tempDir}");

            try
            {
                string outPath = Path.Combine(tempDir, "cooking.bin");

                // Both the logger factory and the wrapper are disposable. The using-blocks release
                // them before the temp directory is removed in the finally-block below.
                using (var loggerFactory = new LoggerFactory(new[] { new SerilogLoggerProvider() }))
                using (var fastText = new FastTextWrapper(loggerFactory: loggerFactory))
                {
                    var ftArgs = FastTextArgs.SupervisedDefaults();

                    ftArgs.epoch      = 15;
                    ftArgs.lr         = 1;
                    ftArgs.dim        = 300;
                    ftArgs.wordNgrams = 2;
                    ftArgs.minn       = 3;
                    ftArgs.maxn       = 6;
                    fastText.Supervised("cooking.train.txt", outPath, ftArgs);

                    // Remove a stale debug dump before TestInternal(..., true) is run and the dump
                    // is read back below. This is best effort, but failures are now logged, not hidden.
                    try
                    {
                        File.Delete("_debug.txt");
                    }
                    catch (Exception ex) when (ex is IOException || ex is UnauthorizedAccessException)
                    {
                        log.Warning(ex, "Could not delete stale _debug.txt");
                    }

                    var result = fastText.TestInternal("cooking.valid.txt", 1, 0.0f, true);

                    log.Information($"Results:\n\tPrecision: {result.GlobalMetrics.GetPrecision()}" +
                                    $"\n\tRecall: {result.GlobalMetrics.GetRecall()}" +
                                    $"\n\tF1: {result.GlobalMetrics.GetF1()}");

                    var curve = result.GetPrecisionRecallCurve();

                    var(_, debugCurve) = TestResult.LoadDebugResult("_debug.txt", fastText.GetLabels());

                    string plotPath = PlotCurves(tempDir, new [] { curve, debugCurve });

                    log.Information($"Precision-Recall plot: {plotPath}");

                    Console.WriteLine("\nPress any key to exit.");
                    Console.ReadKey();
                }
            }
            finally
            {
                // Always clean up the temp directory, even if training or validation threw.
                Directory.Delete(tempDir, true);
            }
        }
        public void CanTrainModelWithOldApi()
        {
            // FastTextWrapper is disposable (the other tests wrap it in `using`). Dispose it here too,
            // so its resources are released even when an assertion below fails.
            using var fastText = new FastTextWrapper(loggerFactory: _loggerFactory);
            string outPath  = Path.Combine(_tempDir, "cooking");

            // Deliberately exercises the legacy Train(...) entry point instead of Supervised(...).
            fastText.Train("cooking.train.txt", outPath, FastTextArgs.SupervisedDefaults());

            CheckLabels(fastText.GetLabels());

            File.Exists(outPath + ".bin").Should().BeTrue();
            File.Exists(outPath + ".vec").Should().BeTrue();
        }
        public void CanTrainSupervisedWithNoLogging()
        {
            // No logger factory is passed, to verify training works without logging. The wrapper is
            // disposable and is now disposed, like in the other tests.
            using var fastText = new FastTextWrapper();
            string outPath  = Path.Combine(_tempDir, "cooking");

            fastText.Supervised("cooking.train.txt", outPath, FastTextArgs.SupervisedDefaults());

            fastText.IsModelReady().Should().BeTrue();
            fastText.GetModelDimension().Should().Be(100);

            CheckLabels(fastText.GetLabels());

            File.Exists(outPath + ".bin").Should().BeTrue();
            File.Exists(outPath + ".vec").Should().BeTrue();
        }
        public void CanUsePretrainedVectorsForSupervisedModel()
        {
            // Dispose the wrapper like the other tests do, so it is released even if an assertion fails.
            using var fastText = new FastTextWrapper(loggerFactory: _loggerFactory);

            string outPath = Path.Combine(_tempDir, "cooking");
            var    args    = FastTextArgs.SupervisedDefaults();

            // dim is set to 300 to match the pretrained vector file (its name suggests 300 dims).
            args.PretrainedVectors = "cooking.unsup.300.vec";
            args.dim = 300;

            fastText.Supervised("cooking.train.txt", outPath, args);

            fastText.IsModelReady().Should().BeTrue();
            fastText.GetModelDimension().Should().Be(300);

            CheckLabels(fastText.GetLabels());

            File.Exists(outPath + ".bin").Should().BeTrue();
            File.Exists(outPath + ".vec").Should().BeTrue();
        }
Beispiel #11
0
        public void CanQuantizeLoadedSupervisedModel()
        {
            using var fastText = new FastTextWrapper(loggerFactory: _loggerFactory);
            fastText.LoadModel(_fixture.FastText.ModelPath);

            // Sanity-check the loaded (unquantized) model first.
            fastText.IsModelReady().Should().BeTrue();
            fastText.GetModelDimension().Should().Be(100);

            AssertLabels(fastText.GetLabels());

            // The quantized output is expected next to the source model, sharing its base name.
            string sourceModel = _fixture.FastText.ModelPath;
            var    modelDir    = Path.GetDirectoryName(sourceModel);
            string newPath     = Path.Combine(modelDir, Path.GetFileNameWithoutExtension(sourceModel));
            string ftzPath     = newPath + ".ftz";

            fastText.Quantize();

            // Quantization keeps the model usable, preserves its dimension and switches ModelPath.
            fastText.IsModelReady().Should().BeTrue();
            fastText.GetModelDimension().Should().Be(100);
            fastText.ModelPath.Should().Be(ftzPath);

            File.Exists(ftzPath).Should().BeTrue();
            File.Exists(newPath + ".vec").Should().BeTrue();
        }
Beispiel #12
0
        public void CanTrainSupervisedWithPretrainedVectors()
        {
            using var fastText = new FastTextWrapper(loggerFactory: _loggerFactory);

            string modelPrefix = Path.Combine(_tempDir, "cooking");

            // Start from pretrained vectors. dim is set to 300 to match them (see the file name).
            var supervisedArgs = new SupervisedArgs
            {
                PretrainedVectors = "cooking.unsup.300.vec",
                dim               = 300
            };

            fastText.Supervised("cooking.train.txt", modelPrefix, supervisedArgs, new AutotuneArgs(), true);

            string binPath = modelPrefix + ".bin";

            fastText.IsModelReady().Should().BeTrue();
            fastText.GetModelDimension().Should().Be(300);
            fastText.ModelPath.Should().Be(binPath);

            AssertLabels(fastText.GetLabels());

            File.Exists(binPath).Should().BeTrue();
            File.Exists(modelPrefix + ".vec").Should().BeTrue();
        }
Beispiel #13
0
        public void CanTrainSupervisedWithRelativeOutput()
        {
            // Trains with a relative output prefix, so the model files land in the current working
            // directory rather than the per-test temp dir.
            using var fastText = new FastTextWrapper(loggerFactory: _loggerFactory);

            var args     = new SupervisedArgs();
            var tuneArgs = new AutotuneArgs();

            try
            {
                fastText.Supervised("cooking.train.txt", "cooking", args, tuneArgs, true);

                fastText.IsModelReady().Should().BeTrue();
                fastText.GetModelDimension().Should().Be(100);
                fastText.ModelPath.Should().Be("cooking.bin");

                AssertLabels(fastText.GetLabels());

                File.Exists("cooking.bin").Should().BeTrue();
                File.Exists("cooking.vec").Should().BeTrue();
            }
            finally
            {
                // The working directory is shared with other tests. Previously a failed assertion
                // skipped this cleanup and left the files behind. File.Delete does not throw when
                // a file does not exist.
                File.Delete("cooking.bin");
                File.Delete("cooking.vec");
            }
        }
    static void Main(string[] args)
    {
        // Trains a supervised model with a Spectre.Console progress bar in a throw-away temp directory,
        // validates it, and plots precision/recall curves from the managed result and the debug dump.
        Log.Logger = new LoggerConfiguration()
                     .MinimumLevel.Debug()
                     .WriteTo.Console(theme: ConsoleTheme.None)
                     .CreateLogger();

        var log     = Log.ForContext <Program>();
        var tempDir = Path.Combine(Path.GetTempPath(), Guid.NewGuid().ToString("N"));

        Directory.CreateDirectory(tempDir);

        log.Information($"Temp dir: {tempDir}");

        try
        {
            string outPath = Path.Combine(tempDir, "cooking.bin");

            // Both the logger factory and the wrapper are disposable. The using-blocks release them
            // before the temp directory is removed in the finally-block below.
            using (var loggerFactory = new LoggerFactory(new[] { new SerilogLoggerProvider() }))
            using (var fastText = new FastTextWrapper(loggerFactory: loggerFactory))
            {
                AnsiConsole.Progress()
                .Start(ctx =>
                {
                    var task   = ctx.AddTask("Training");
                    var ftArgs = new SupervisedArgs
                    {
                        epoch                 = 15,
                        lr                    = 1,
                        dim                   = 300,
                        wordNgrams            = 2,
                        minn                  = 3,
                        maxn                  = 6,
                        verbose               = 0,
                        // progress is scaled by 100 for the progress bar, so it appears to be 0..1.
                        TrainProgressCallback = (progress, loss, wst, lr, eta) =>
                        {
                            task.Value       = Math.Ceiling(progress * 100);
                            task.Description = $"Loss: {loss:N3}, words/thread/sec: {wst}, LR: {lr:N5}, ETA: {eta}";
                        }
                    };

                    fastText.Supervised("cooking.train.txt", outPath, ftArgs);
                });

                // Remove a stale debug dump before TestInternal(..., true) is run and the dump is read
                // back below. This is best effort, but failures are now logged instead of silently swallowed.
                try
                {
                    File.Delete("_debug.txt");
                }
                catch (Exception ex) when (ex is IOException || ex is UnauthorizedAccessException)
                {
                    log.Warning(ex, "Could not delete stale _debug.txt");
                }

                log.Information("Validating model on the test set");

                var result = fastText.TestInternal("cooking.valid.txt", 1, 0.0f, true);

                log.Information($"Results:\n\tPrecision: {result.GlobalMetrics.GetPrecision()}" +
                                $"\n\tRecall: {result.GlobalMetrics.GetRecall()}" +
                                $"\n\tF1: {result.GlobalMetrics.GetF1()}");

                var curve = result.GetPrecisionRecallCurve();

                var(_, debugCurve) = TestResult.LoadDebugResult("_debug.txt", fastText.GetLabels());

                string plotPath = PlotCurves(tempDir, new [] { curve, debugCurve });

                log.Information($"Precision-Recall plot: {plotPath}");

                Console.WriteLine("\nPress any key to exit.");
                Console.ReadKey();
            }
        }
        finally
        {
            // Always clean up the temp directory, even if training or validation threw.
            Directory.Delete(tempDir, true);
        }
    }