Esempio n. 1
    /// <summary>
    /// Loads the tiny test model "seq2seq_mt_enu_chs_tiny_test.model" on CPU (device "0") and checks
    /// that two pre-tokenized source sentences decode to the exact expected target token strings.
    /// </summary>
    public void TestSeq2SeqInference()
    {
        var opts = new Seq2SeqOptions();

        opts.ModelFilePath        = "seq2seq_mt_enu_chs_tiny_test.model";
        opts.MaxTestSrcSentLength = 110;
        opts.MaxTestTgtSentLength = 110;
        opts.ProcessorType        = ProcessorTypeEnums.CPU;
        opts.DeviceIds            = "0";

        var             seq2seq         = new Seq2Seq(opts);
        DecodingOptions decodingOptions = opts.CreateDecodingOptions();

        // Runs one tokenized sentence through the model and returns the tokens found at
        // nrs[0].Output[0][0] joined by single spaces.
        string Translate(string sentence)
        {
            List<List<List<string>>> groupBatchTokens = BuildInputGroupBatchTokens(sentence);
            var nrs = seq2seq.Test<Seq2SeqCorpusBatch>(groupBatchTokens, null, decodingOptions);
            return string.Join(" ", nrs[0].Output[0][0]);
        }

        // AreEqual (unlike IsTrue(a == b)) reports both the expected and the actual string on failure.
        Assert.AreEqual("<s> ▁是的 , 解决方案 存在 。 </s>", Translate("▁yes , ▁solutions ▁do ▁exist ."));
        Assert.AreEqual("<s> ▁ 气候 问题 。 </s>", Translate("▁a ▁question ▁of ▁climate ."));
    }
Esempio n. 2
    /// <summary>
    /// End-to-end training test. It builds options from the test train/valid folders and loads the
    /// training corpus plus any ';'-separated validation corpora. It then builds vocabularies from the
    /// training corpus and trains a new model with a BLEU metric. Finally it saves the model with the
    /// ".test" suffix and asserts that <c>opts.ModelFilePath + ".test"</c> exists.
    /// </summary>
    public void TestSeq2SeqTraining()
    {
        // Build configs for training
        Seq2SeqOptions opts = CreateOptions(trainFolderPath, validFolderPath);

        DecodingOptions decodingOptions = opts.CreateDecodingOptions();

        // Load training corpus
        var trainCorpus = new Seq2SeqCorpus(corpusFilePath: opts.TrainCorpusPath, srcLangName: opts.SrcLang, tgtLangName: opts.TgtLang, batchSize: opts.BatchSize, shuffleBlockSize: opts.ShuffleBlockSize, maxSrcSentLength: opts.MaxTrainSrcSentLength, maxTgtSentLength: opts.MaxTrainTgtSentLength, shuffleEnums: opts.ShuffleType, tooLongSequence: opts.TooLongSequence);

        // Load valid corpus (ValidCorpusPaths is a ';'-separated list of paths)
        var validCorpusList = new List<Seq2SeqCorpus>();

        if (!opts.ValidCorpusPaths.IsNullOrEmpty())
        {
            string[] validCorpusPathList = opts.ValidCorpusPaths.Split(';');
            foreach (var validCorpusPath in validCorpusPathList)
            {
                validCorpusList.Add(new Seq2SeqCorpus(validCorpusPath, opts.SrcLang, opts.TgtLang, opts.ValBatchSize, opts.ShuffleBlockSize, opts.MaxTestSrcSentLength, opts.MaxTestTgtSentLength, shuffleEnums: opts.ShuffleType, tooLongSequence: opts.TooLongSequence));
            }
        }

        // Create learning rate
        ILearningRate learningRate = new DecayLearningRate(opts.StartLearningRate, opts.WarmUpSteps, opts.WeightsUpdateCount);

        // Create optimizer
        IOptimizer optimizer = Misc.CreateOptimizer(opts);

        // Build vocabularies for training
        (var srcVocab, var tgtVocab) = trainCorpus.BuildVocabs(opts.SrcVocabSize, opts.TgtVocabSize, opts.SharedEmbeddings);

        // Create metrics
        List<IMetric> metrics = new List<IMetric>
        {
            new BleuMetric()
        };

        //New training
        var ss = new Seq2Seq(opts, srcVocab, tgtVocab);

        // Add event handler for monitoring
        ss.StatusUpdateWatcher += Ss_StatusUpdateWatcher;
        ss.EpochEndWatcher     += Ss_EpochEndWatcher;

        // Kick off training
        ss.Train(maxTrainingEpoch: opts.MaxEpochNum, trainCorpus: trainCorpus, validCorpusList: validCorpusList.ToArray(), learningRate: learningRate, optimizer: optimizer, metrics: metrics, decodingOptions: decodingOptions);

        ss.SaveModel(suffix: ".test");

        // Check if model file exists; the message names the path that was expected.
        string savedModelPath = opts.ModelFilePath + ".test";
        Assert.IsTrue(File.Exists(savedModelPath), $"Model file '{savedModelPath}' was not saved.");
    }
Esempio n. 3
        /// <summary>
        /// Console entry point. It parses the command line into <c>opts</c>, which a JSON config file
        /// replaces when one is given. It then sets a timestamped log file and runs the task selected by
        /// <c>opts.Task</c>: Train, Valid, Test or DumpVocab. Any exception is caught and logged with its
        /// call stack rather than propagated.
        /// </summary>
        /// <param name="args">Raw command-line arguments, handed to <c>ArgParser</c>.</param>
        private static void Main(string[] args)
        {
            try
            {
                //Parse command line
                ArgParser argParser = new ArgParser(args, opts);
                if (!opts.ConfigFilePath.IsNullOrEmpty())
                {
                    // A config file, when given, replaces the options parsed from the command line.
                    Console.WriteLine($"Loading config file from '{opts.ConfigFilePath}'");
                    opts = JsonConvert.DeserializeObject<Seq2SeqOptions>(File.ReadAllText(opts.ConfigFilePath));
                }

                Logger.LogFile = $"{nameof(Seq2SeqConsole)}_{opts.Task}_{Utils.GetTimeStamp(DateTime.Now)}.log";
                ShowOptions(args, opts);

                DecodingOptions decodingOptions = opts.CreateDecodingOptions();
                Seq2Seq ss = null;
                if (opts.Task == ModeEnums.Train)
                {
                    // Load train corpus
                    var trainCorpus = new Seq2SeqCorpus(corpusFilePath: opts.TrainCorpusPath, srcLangName: opts.SrcLang, tgtLangName: opts.TgtLang, batchSize: opts.BatchSize, shuffleBlockSize: opts.ShuffleBlockSize,
                                                        maxSrcSentLength: opts.MaxTrainSrcSentLength, maxTgtSentLength: opts.MaxTrainTgtSentLength, shuffleEnums: opts.ShuffleType, tooLongSequence: opts.TooLongSequence);

                    // Load valid corpora (optional)
                    List<Seq2SeqCorpus> validCorpusList = LoadValidCorpora();

                    // Create learning rate
                    ILearningRate learningRate = new DecayLearningRate(opts.StartLearningRate, opts.WarmUpSteps, opts.WeightsUpdateCount);

                    // Create optimizer
                    IOptimizer optimizer = Misc.CreateOptimizer(opts);

                    // Create metrics
                    List<IMetric> metrics = CreateMetrics();

                    if (!opts.ModelFilePath.IsNullOrEmpty() && File.Exists(opts.ModelFilePath))
                    {
                        //Incremental training
                        Logger.WriteLine($"Loading model from '{opts.ModelFilePath}'...");
                        ss = new Seq2Seq(opts);
                    }
                    else
                    {
                        // Load or build vocabulary
                        Vocab srcVocab = null;
                        Vocab tgtVocab = null;
                        if (!opts.SrcVocab.IsNullOrEmpty() && !opts.TgtVocab.IsNullOrEmpty())
                        {
                            Logger.WriteLine($"Loading source vocabulary from '{opts.SrcVocab}' and target vocabulary from '{opts.TgtVocab}'. Shared vocabulary is '{opts.SharedEmbeddings}'");
                            if (opts.SharedEmbeddings == true && (opts.SrcVocab != opts.TgtVocab))
                            {
                                throw new ArgumentException("The source and target vocabularies must be identical if their embeddings are shared.");
                            }

                            // Vocabulary files are specified, so we load them
                            srcVocab = new Vocab(opts.SrcVocab);
                            tgtVocab = new Vocab(opts.TgtVocab);
                        }
                        else
                        {
                            Logger.WriteLine($"Building vocabulary from training corpus. Shared vocabulary is '{opts.SharedEmbeddings}'");
                            // We don't specify vocabulary, so we build it from train corpus
                            (srcVocab, tgtVocab) = trainCorpus.BuildVocabs(opts.SrcVocabSize, opts.TgtVocabSize, opts.SharedEmbeddings);
                        }

                        //New training
                        ss = new Seq2Seq(opts, srcVocab, tgtVocab);
                    }

                    // Add event handler for monitoring
                    ss.StatusUpdateWatcher += Misc.Ss_StatusUpdateWatcher;
                    ss.EvaluationWatcher   += Ss_EvaluationWatcher;

                    // Kick off training
                    ss.Train(maxTrainingEpoch: opts.MaxEpochNum, trainCorpus: trainCorpus, validCorpusList: validCorpusList.ToArray(), learningRate: learningRate, optimizer: optimizer, metrics: metrics, decodingOptions: decodingOptions);
                }
                else if (opts.Task == ModeEnums.Valid)
                {
                    Logger.WriteLine($"Evaluate model '{opts.ModelFilePath}' by valid corpus '{opts.ValidCorpusPaths}'");

                    // ValidCorpusPaths uses the same ';'-separated format as in Train mode, so each listed
                    // corpus is evaluated instead of treating the whole string as a single path.
                    // Corpora are loaded before the model so a bad path fails before the model is loaded.
                    List<Seq2SeqCorpus> validCorpusList = LoadValidCorpora();
                    if (validCorpusList.Count == 0)
                    {
                        Logger.WriteLine(Logger.Level.err, ConsoleColor.Red, "Valid mode requires at least one corpus path in ValidCorpusPaths.");
                    }
                    else
                    {
                        ss = new Seq2Seq(opts);
                        ss.EvaluationWatcher += Ss_EvaluationWatcher;
                        foreach (Seq2SeqCorpus validCorpus in validCorpusList)
                        {
                            // Fresh metric objects per corpus, in case metrics keep state between Valid calls.
                            ss.Valid(validCorpus: validCorpus, metrics: CreateMetrics(), decodingOptions: decodingOptions);
                        }
                    }
                }
                else if (opts.Task == ModeEnums.Test)
                {
                    if (File.Exists(opts.OutputFile))
                    {
                        Logger.WriteLine(Logger.Level.err, ConsoleColor.Yellow, $"Output file '{opts.OutputFile}' exist. Delete it.");
                        // Do what the message announces; previously the stale file was reported but kept.
                        File.Delete(opts.OutputFile);
                    }

                    //Test trained model
                    ss = new Seq2Seq(opts);
                    Stopwatch stopwatch = Stopwatch.StartNew();

                    if (String.IsNullOrEmpty(opts.OutputPromptFile))
                    {
                        ss.Test<Seq2SeqCorpusBatch>(opts.InputTestFile, opts.OutputFile, opts.BatchSize, decodingOptions, opts.SrcSentencePieceModelPath, opts.TgtSentencePieceModelPath);
                    }
                    else
                    {
                        Logger.WriteLine($"Test with prompt file '{opts.OutputPromptFile}'");
                        ss.Test<Seq2SeqCorpusBatch>(opts.InputTestFile, opts.OutputPromptFile, opts.OutputFile, opts.BatchSize, decodingOptions, opts.SrcSentencePieceModelPath, opts.TgtSentencePieceModelPath);
                    }

                    Logger.WriteLine($"Test mode execution time elapsed: '{stopwatch.Elapsed}'");
                }
                else if (opts.Task == ModeEnums.DumpVocab)
                {
                    ss = new Seq2Seq(opts);
                    ss.DumpVocabToFiles(opts.SrcVocab, opts.TgtVocab);
                }
                else
                {
                    Logger.WriteLine(Logger.Level.err, ConsoleColor.Red, $"Task '{opts.Task}' is not supported.");
                }
            }
            catch (Exception err)
            {
                Logger.WriteLine($"Exception: '{err.Message}'");
                Logger.WriteLine($"Call stack: '{err.StackTrace}'");
            }

            // Loads every corpus listed in opts.ValidCorpusPaths (';'-separated). Empty entries, such as
            // the one produced by a trailing ';', are skipped. A null/empty setting yields an empty list.
            List<Seq2SeqCorpus> LoadValidCorpora()
            {
                var corpora = new List<Seq2SeqCorpus>();
                if (opts.ValidCorpusPaths.IsNullOrEmpty())
                {
                    return corpora;
                }

                foreach (string validCorpusPath in opts.ValidCorpusPaths.Split(new[] { ';' }, StringSplitOptions.RemoveEmptyEntries))
                {
                    corpora.Add(new Seq2SeqCorpus(validCorpusPath, opts.SrcLang, opts.TgtLang, opts.ValBatchSize, opts.ShuffleBlockSize, opts.MaxTestSrcSentLength, opts.MaxTestTgtSentLength, shuffleEnums: opts.ShuffleType, tooLongSequence: opts.TooLongSequence));
                }

                return corpora;
            }
        }
Esempio n. 4
        /// <summary>
        /// Console entry point (the method continues beyond this excerpt). It applies optional numeric
        /// overrides from <paramref name="args"/> and loads a saved Seq2Seq model or creates a new one.
        /// It then subscribes console progress handlers to the model's training events and creates the
        /// worker threads.
        /// </summary>
        /// <param name="args">Optional overrides: args[0] -> times, args[1] -> dim, args[2] -> hdim.</param>
        static void Main(string[] args)
            // NOTE(review): int.TryParse writes 0 to its out argument when parsing fails, so a
            // non-numeric argument resets times/dim/hdim to 0 instead of keeping the previous value.
            if (args.Length > 0)
                int.TryParse(args[0], out times);
            if (args.Length > 1)
                int.TryParse(args[1], out dim);
            if (args.Length > 2)
                int.TryParse(args[2], out hdim);

            //Random r = new Random(5);
            // Try to load a previously saved model. On any exception S2S is reset to null so that a
            // fresh model is created below.
                S2S = Seq2Seq.Load();
            catch (Exception ex)
                // NOTE(review): the value returned by LogService.Exception is never used; presumably the
                // call is made for its logging side effect — confirm.
                var msg = LogService.Exception(ex);
                S2S = null;

            // No saved model (or loading failed): build a new one from the configured sizes.
            if (S2S == null)
                S2S = new Seq2Seq(dim, hdim, dep, input, output, true);

            // c is printed as the "completed lines" counter below; it is not updated in the visible code.
            int c        = 0;
            var lastdone = DateTime.Now;

            // Per-iteration progress, throttled to at most one console line every 10 seconds.
            // Printed text means "training round {n}/{times}, lines completed {c}".
            // NOTE(review): `as` yields null when the event args are not a CostEventArg, and reading
            // ep.Iteration would then throw; the handlers below use the same pattern.
            S2S.IterationDone += (a1, a2) =>
                CostEventArg ep    = a2 as CostEventArg;
                var          dtnow = DateTime.Now;
                if ((dtnow - lastdone).TotalSeconds >= 10)//c % 300 == 0 ||
                    lastdone = dtnow;
                    Console.WriteLine($"训练次数 {ep.Iteration + 1}/{times} 完成行数 {c}");
            // Printed text means "training batch {n}/{times} completed".
            S2S.EpochDone += (o, e) => {
                CostEventArg ep = e as CostEventArg;
                Console.WriteLine($"训练批次 {ep.Iteration + 1}/{times} 完成");
            // Printed text means "training task of {times} rounds completed".
            S2S.TrainDone += (o, e) => {
                CostEventArg ep = e as CostEventArg;
                Console.WriteLine($"训练任务 {times}次 完成");
            // Printed text means "training started, will train {times} times".
            S2S.TrainStart += (o, e) => {
                CostEventArg ep = e as CostEventArg;
                Console.WriteLine($"训练开始 将训{times}次");

            // Create mode, not load mode: the training thread is created only when S2S.newType is "new"
            // (presumably set for a freshly constructed model — verify against Seq2Seq).
            if (S2S.newType == "new")
                MainThread = new Thread(new ThreadStart(Train));

            // The console-reading thread is created regardless of mode; neither thread is started in
            // the visible code.
            ReadThread = new Thread(new ThreadStart(ReadingConsole));

            //System.Threading.AutoResetEvent resetEvent = new AutoResetEvent(false);