Exemplo n.º 1
0
        public void CanTrainCbowWithProgressCallback()
        {
            // Arrange: a wrapper, the output base path (without extension) and a counter for progress reports.
            using var fastText = new FastTextWrapper(loggerFactory: _loggerFactory);
            var modelBase = Path.Combine(_tempDir, "cooking");
            var binPath = modelBase + ".bin";
            var vecPath = modelBase + ".vec";
            var progressReports = 0;

            var args = new UnsupervisedArgs
            {
                // Only the number of invocations matters; the reported values are not inspected.
                TrainProgressCallback = (progress, loss, wst, lr, eta) => progressReports++
            };

            // Act: train a CBOW model on the unlabeled cooking corpus.
            fastText.Unsupervised(UnsupervisedModel.CBow, "cooking.train.nolabels.txt", modelBase, args);

            // Assert: the callback fired, the model is loaded, and both output files were written.
            progressReports.Should().BeGreaterThan(0);

            fastText.IsModelReady().Should().BeTrue();
            fastText.GetModelDimension().Should().Be(100);
            fastText.ModelPath.Should().Be(binPath);

            File.Exists(binPath).Should().BeTrue();
            File.Exists(vecPath).Should().BeTrue();
        }
Exemplo n.º 2
0
        public void CanGetDefaultSkipgramArgs()
        {
            var defaults = new UnsupervisedArgs();

            // Spot-check a representative subset of the defaults; exhaustive coverage is not the goal here.
            defaults.lr.Should().BeApproximately(0.05d, 1e-4);
            defaults.bucket.Should().Be(2_000_000);
            defaults.dim.Should().Be(100);
            defaults.loss.Should().Be(LossName.NegativeSampling);
            defaults.model.Should().Be(ModelName.SkipGram);
            defaults.LabelPrefix.Should().Be("__label__");
        }
Exemplo n.º 3
0
    /// <summary>
    /// Trains a new unsupervised model and, on success, stores the resulting model path in <see cref="ModelPath"/>.
    /// </summary>
    /// <param name="model">Type of unsupervised model: Skipgram or Cbow.</param>
    /// <param name="inputPath">Path to a training set.</param>
    /// <param name="outputPath">Path to write the model to (excluding extension).</param>
    /// <param name="args">
    /// Low-level training arguments. Note that this object is modified: its <c>model</c> field is
    /// overwritten with the value of <paramref name="model"/> before training starts.
    /// </param>
    /// <remarks>Trained model will consist of two files: .bin (main model) and .vec (word vectors).</remarks>
    /// <exception cref="ArgumentNullException"><paramref name="args"/> is <c>null</c>.</exception>
    public void Unsupervised(UnsupervisedModel model, string inputPath, string outputPath, UnsupervisedArgs args)
    {
        // Fail fast with a descriptive exception instead of a NullReferenceException on args.PretrainedVectors.
        if (args is null)
        {
            throw new ArgumentNullException(nameof(args));
        }

        ValidatePaths(inputPath, outputPath, args.PretrainedVectors);

        // The public UnsupervisedModel choice is forwarded through the general-purpose ModelName field.
        // NOTE(review): this cast assumes UnsupervisedModel and ModelName share underlying values — confirm.
        args.model = (ModelName)model;

        var argsStruct = _mapper.Map<FastTextArgsStruct>(args);

        CheckForErrors(Train(
                           _fastText,
                           inputPath,
                           outputPath,
                           argsStruct,
                           new AutotuneArgsStruct(), // default-constructed autotune settings
                           args.TrainProgressCallback,
                           null,  // NOTE(review): presumably "no autotune callback" — confirm against Train's signature
                           args.LabelPrefix,
                           args.PretrainedVectors,
                           false)); // NOTE(review): meaning of this flag is not visible here — see Train's declaration

        // Reset the cached label length now that a different model is loaded.
        _maxLabelLen = 0;

        ModelPath = AdjustPath(outputPath, false);
    }