Ejemplo n.º 1
0
        /// <summary>
        /// Console entry point. Parses the command line into an <c>Options</c> instance and then,
        /// based on <c>options.TaskName</c> (compared case-insensitively), either trains a model
        /// ("train"), runs prediction over every line of an input file ("test"), or prints usage.
        /// </summary>
        /// <param name="args">Raw command-line arguments, handed to <c>ArgParser</c> and <c>ShowOptions</c>.</param>
        static void Main(string[] args)
        {
            Logger.LogFile = $"{nameof(Seq2SeqConsole)}_{GetTimeStamp(DateTime.Now)}.log";

            Options options = new Options();
            ArgParser argParser = new ArgParser(args, options);

            ShowOptions(args, options);

            ArchTypeEnums archType = (ArchTypeEnums)options.ArchType;

            // Device ids are supplied as one comma-separated string
            int[] deviceIds = options.DeviceIds.Split(',').Select(idText => int.Parse(idText)).ToArray();

            if (String.Equals(options.TaskName, "train", StringComparison.InvariantCultureIgnoreCase))
            {
                // The corpus batch size is the per-device batch size multiplied by the number of devices
                Corpus trainCorpus = new Corpus(options.TrainCorpusPath, options.SrcLang, options.TgtLang, options.BatchSize * deviceIds.Length, options.ShuffleBlockSize);

                AttentionSeq2Seq trainer;
                if (File.Exists(options.ModelFilePath))
                {
                    // A model file is already on disk: load it and attach the corpus (incremental training)
                    Logger.WriteLine($"Loading model from '{options.ModelFilePath}'...");
                    trainer = new AttentionSeq2Seq(options.ModelFilePath, options.BatchSize, archType, deviceIds);
                    trainer.TrainCorpus = trainCorpus;
                }
                else
                {
                    // No model file yet: build a new model from the options
                    trainer = new AttentionSeq2Seq(options.WordVectorSize, options.HiddenSize, options.Depth, trainCorpus, options.SrcVocab, options.TgtVocab, options.SrcEmbeddingModelFilePath, options.TgtEmbeddingModelFilePath,
                                                   true, options.ModelFilePath, options.BatchSize, options.DropoutRatio, archType, deviceIds);
                }

                trainer.IterationDone += ss_IterationDone;
                // First argument is a fixed 100 (presumably the epoch limit - confirm against AttentionSeq2Seq.Train)
                trainer.Train(100, options.LearningRate, options.GradClip);
                return;
            }

            if (!String.Equals(options.TaskName, "test", StringComparison.InvariantCultureIgnoreCase))
            {
                // Neither "train" nor "test"
                argParser.Usage();
                return;
            }

            // Test a trained model; the second constructor argument is fixed to 1 here
            AttentionSeq2Seq predictor = new AttentionSeq2Seq(options.ModelFilePath, 1, archType, deviceIds);

            string[] inputLines = File.ReadAllLines(options.InputTestFile);

            // Materialize all predictions before the output file is touched
            List<string> predictedLines = inputLines
                .Select(inputLine =>
                {
                    List<string> predictedWords = predictor.Predict(inputLine.ToLower().Trim().Split(' ').ToList());
                    return String.Join(" ", predictedWords);
                })
                .ToList();

            File.WriteAllLines(options.OutputTestFile, predictedLines);
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Console entry point. Parses the command line into an <c>Options</c> instance and then,
        /// based on <c>options.TaskName</c> (compared case-insensitively), either trains a model
        /// ("train"), runs prediction over every line of an input file ("test"), or prints usage.
        /// </summary>
        /// <param name="args">Raw command-line arguments, handed to <c>ArgParser</c>.</param>
        static void Main(string[] args)
        {
            Logger.LogFile = $"{nameof(Seq2SeqConsole)}_{GetTimeStamp(DateTime.Now)}.log";

            Options options = new Options();
            ArgParser argParser = new ArgParser(args, options);

            if (String.Equals(options.TaskName, "train", StringComparison.InvariantCultureIgnoreCase))
            {
                Corpus trainCorpus = new Corpus(options.TrainCorpusPath, options.SrcLang, options.TgtLang, options.ShuffleBlockSize);

                AttentionSeq2Seq trainer;
                if (File.Exists(options.ModelFilePath))
                {
                    // A model file is already on disk: load it and attach the corpus
                    Logger.WriteLine($"Loading model from '{options.ModelFilePath}'...");
                    trainer = new AttentionSeq2Seq();
                    trainer.Load(options.ModelFilePath);
                    trainer.TrainCorpus = trainCorpus;
                }
                else
                {
                    // No model file yet: build a new model from the options
                    trainer = new AttentionSeq2Seq(options.WordVectorSize, options.HiddenSize, options.Depth, trainCorpus, options.SrcVocab, options.TgtVocab, options.SrcEmbeddingModelFilePath, options.TgtEmbeddingModelFilePath,
                                                   options.SparseFeature, true, options.ModelFilePath);
                }

                // Log the effective configuration and hardware details before training starts
                Logger.WriteLine($"Source Language = '{options.SrcLang}'");
                Logger.WriteLine($"Target Language = '{options.TgtLang}'");
                Logger.WriteLine($"SSE Enable = '{System.Numerics.Vector.IsHardwareAccelerated}'");
                Logger.WriteLine($"SSE Size = '{System.Numerics.Vector<float>.Count * 32}'");
                Logger.WriteLine($"Processor counter = '{Environment.ProcessorCount}'");
                Logger.WriteLine($"Hidden Size = '{trainer.HiddenSize}'");
                Logger.WriteLine($"Word Vector Size = '{trainer.WordVectorSize}'");
                Logger.WriteLine($"Learning Rate = '{options.LearningRate}'");
                Logger.WriteLine($"Network Layer = '{trainer.Depth}'");
                Logger.WriteLine($"Use Sparse Feature = '{options.SparseFeature}'");

                trainer.IterationDone += ss_IterationDone;
                // First argument is a fixed 300 (presumably the epoch limit - confirm against AttentionSeq2Seq.Train)
                trainer.Train(300, options.LearningRate);
                return;
            }

            if (!String.Equals(options.TaskName, "test", StringComparison.InvariantCultureIgnoreCase))
            {
                // Neither "train" nor "test"
                argParser.Usage();
                return;
            }

            AttentionSeq2Seq predictor = new AttentionSeq2Seq();
            predictor.Load(options.ModelFilePath);

            string[] inputLines = File.ReadAllLines(options.InputTestFile);
            List<string> predictedLines = new List<string>();
            for (int lineIndex = 0; lineIndex < inputLines.Length; lineIndex++)
            {
                // Tokenize: lower-case, trim, split on single spaces
                List<string> tokens = inputLines[lineIndex].ToLower().Trim().Split(' ').ToList();
                List<string> predictedWords = predictor.Predict(tokens);
                predictedLines.Add(String.Join(" ", predictedWords));
            }

            File.WriteAllLines(options.OutputTestFile, predictedLines);
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Click handler: tokenizes the text of <c>SrcTxtBox</c> (lower-cased, trimmed, split on
        /// single spaces), passes the tokens to <c>ss.Predict</c> and shows the predicted items in
        /// <c>ResultTxtBox</c>, each item followed by one space.
        /// </summary>
        /// <param name="sender">Event source (unused).</param>
        /// <param name="e">Event arguments (unused).</param>
        private void Predict_Click(object sender, EventArgs e)
        {
            var pred = ss.Predict(SrcTxtBox.Text.ToLower().Trim().Split(' ').ToList());

            // Build the whole result once and assign it once, instead of appending to the
            // control's Text per item (which re-creates the string and updates the control each time).
            // An empty prediction yields an empty string, matching the previous Clear()-only outcome.
            ResultTxtBox.Text = string.Concat(pred.Select(item => item + " "));
        }
Ejemplo n.º 4
0
        /// <summary>
        /// Console command loop. Each input line is lower-cased and passed through
        /// <c>RemoveAccentMark</c>. While <c>MainThread</c> is alive only "!stop" is honoured
        /// (abort the thread, then save via <c>S2S.Save</c>). While it is not alive, "!resume"
        /// starts a new thread running <c>Train</c>, and any other line is sent to
        /// <c>S2S.Predict</c> and the result is echoed to the console.
        /// The loop ends when standard input is closed.
        /// </summary>
        static void ReadingConsole()
        {
            while (true)
            {
                // Console.ReadLine returns null once stdin is closed/redirected input is exhausted;
                // dereferencing it would throw NullReferenceException, so stop reading instead.
                string rawInput = Console.ReadLine();
                if (rawInput == null)
                {
                    return;
                }

                var line = RemoveAccentMark(rawInput.ToLower());

                if (MainThread.IsAlive)
                {
                    if (line == "!stop")
                    {
                        // NOTE(review): Thread.Abort is not supported on .NET Core / .NET 5+ (throws
                        // PlatformNotSupportedException); a cooperative cancellation flag would be needed there.
                        MainThread.Abort();
                        Console.WriteLine("Process stoped");
                        // NOTE(review): Save runs without waiting for the aborted thread to finish - confirm this is safe.
                        S2S.Save();
                    }

                    // Any other input is ignored while the worker thread is alive
                    continue;
                }

                if (line == "!resume")
                {
                    MainThread = new Thread(new ThreadStart(Train));
                    MainThread.Start();
                    continue;
                }

                try
                {
                    var pred = S2S.Predict(line.Trim().Split(' ').ToList());
                    Console.WriteLine($"<< {line}");

                    Console.Write(">>");
                    for (int i = 0; i < pred.Count; i++)
                    {
                        Console.Write(pred[i] + " ");
                    }
                    Console.WriteLine();
                }
                catch (Exception)
                {
                    // Best effort: any prediction failure is reported with a fixed message and the loop continues
                    // (presumably triggered by words the model does not know - confirm).
                    Console.WriteLine("Deja de usar palabras raras crj");
                }
            }
        }
Ejemplo n.º 5
0
        /// <summary>
        /// Console entry point. Parses the command line (optionally replaced by a JSON config file),
        /// converts the textual option values into their enum counterparts and dispatches on the
        /// requested mode: Train, Valid, Test or DumpVocab. Any other mode prints usage.
        /// Every exception is caught and written to the log (message and stack trace).
        /// </summary>
        /// <param name="args">Raw command-line arguments, handed to <c>ShowOptions</c> and <c>ArgParser</c>.</param>
        private static void Main(string[] args)
        {
            try
            {
                Logger.LogFile = $"{nameof(Seq2SeqConsole)}_{GetTimeStamp(DateTime.Now)}.log";
                ShowOptions(args);

                // Command line first; when a config file is named, the deserialized options replace the parsed ones entirely
                Options opts = new Options();
                ArgParser argParser = new ArgParser(args, opts);

                if (!string.IsNullOrEmpty(opts.ConfigFilePath))
                {
                    Logger.WriteLine($"Loading config file from '{opts.ConfigFilePath}'");
                    string configJson = File.ReadAllText(opts.ConfigFilePath);
                    opts = JsonConvert.DeserializeObject<Options>(configJson);
                }

                AttentionSeq2Seq model = null;

                // Textual options -> enums (Enum.Parse throws on unknown names; the outer catch logs it)
                ProcessorTypeEnums processorType = (ProcessorTypeEnums)Enum.Parse(typeof(ProcessorTypeEnums), opts.ProcessorType);
                EncoderTypeEnums encoderType = (EncoderTypeEnums)Enum.Parse(typeof(EncoderTypeEnums), opts.EncoderType);
                DecoderTypeEnums decoderType = (DecoderTypeEnums)Enum.Parse(typeof(DecoderTypeEnums), opts.DecoderType);
                ModeEnums mode = (ModeEnums)Enum.Parse(typeof(ModeEnums), opts.TaskName);
                ShuffleEnums shuffleType = (ShuffleEnums)Enum.Parse(typeof(ShuffleEnums), opts.ShuffleType);

                // Space-separated compiler options; stays null when none were given
                string[] cudaCompilerOptions = null;
                if (!String.IsNullOrEmpty(opts.CompilerOptions))
                {
                    cudaCompilerOptions = opts.CompilerOptions.Split(' ', StringSplitOptions.RemoveEmptyEntries);
                }

                // Device ids are supplied as one comma-separated string
                int[] deviceIds = Array.ConvertAll(opts.DeviceIds.Split(','), idText => int.Parse(idText));

                switch (mode)
                {
                    case ModeEnums.Train:
                    {
                        // Training corpus
                        ParallelCorpus trainCorpus = new ParallelCorpus(
                            corpusFilePath: opts.TrainCorpusPath,
                            srcLangName: opts.SrcLang,
                            tgtLangName: opts.TgtLang,
                            batchSize: opts.BatchSize,
                            shuffleBlockSize: opts.ShuffleBlockSize,
                            maxSrcSentLength: opts.MaxSrcSentLength,
                            maxTgtSentLength: opts.MaxTgtSentLength,
                            shuffleEnums: shuffleType);

                        // Validation corpus is optional
                        ParallelCorpus validCorpus = null;
                        if (!string.IsNullOrEmpty(opts.ValidCorpusPath))
                        {
                            validCorpus = new ParallelCorpus(opts.ValidCorpusPath, opts.SrcLang, opts.TgtLang, opts.ValBatchSize, opts.ShuffleBlockSize, opts.MaxSrcSentLength, opts.MaxTgtSentLength);
                        }

                        // Learning-rate schedule and optimizer
                        ILearningRate learningRate = new DecayLearningRate(opts.StartLearningRate, opts.WarmUpSteps, opts.WeightsUpdateCount);
                        AdamOptimizer optimizer = new AdamOptimizer(opts.GradClip, opts.Beta1, opts.Beta2);

                        // Metrics handed to Train
                        List<IMetric> metrics = new List<IMetric>();
                        metrics.Add(new BleuMetric());
                        metrics.Add(new LengthRatioMetric());

                        bool modelFileExists = !String.IsNullOrEmpty(opts.ModelFilePath) && File.Exists(opts.ModelFilePath);
                        if (modelFileExists)
                        {
                            // Incremental training from the existing model file
                            Logger.WriteLine($"Loading model from '{opts.ModelFilePath}'...");
                            model = new AttentionSeq2Seq(
                                modelFilePath: opts.ModelFilePath,
                                processorType: processorType,
                                dropoutRatio: opts.DropoutRatio,
                                deviceIds: deviceIds,
                                isSrcEmbTrainable: opts.IsSrcEmbeddingTrainable,
                                isTgtEmbTrainable: opts.IsTgtEmbeddingTrainable,
                                isEncoderTrainable: opts.IsEncoderTrainable,
                                isDecoderTrainable: opts.IsDecoderTrainable,
                                maxSrcSntSize: opts.MaxSrcSentLength,
                                maxTgtSntSize: opts.MaxTgtSentLength,
                                memoryUsageRatio: opts.MemoryUsageRatio,
                                shuffleType: shuffleType,
                                compilerOptions: cudaCompilerOptions);
                        }
                        else
                        {
                            // Vocabulary: load from files when both are specified, otherwise build it from the training corpus
                            bool hasVocabFiles = !string.IsNullOrEmpty(opts.SrcVocab) && !string.IsNullOrEmpty(opts.TgtVocab);
                            Vocab vocab = hasVocabFiles
                                ? new Vocab(opts.SrcVocab, opts.TgtVocab)
                                : new Vocab(trainCorpus);

                            // New training
                            model = new AttentionSeq2Seq(
                                embeddingDim: opts.WordVectorSize,
                                hiddenDim: opts.HiddenSize,
                                encoderLayerDepth: opts.EncoderLayerDepth,
                                decoderLayerDepth: opts.DecoderLayerDepth,
                                srcEmbeddingFilePath: opts.SrcEmbeddingModelFilePath,
                                tgtEmbeddingFilePath: opts.TgtEmbeddingModelFilePath,
                                vocab: vocab,
                                modelFilePath: opts.ModelFilePath,
                                dropoutRatio: opts.DropoutRatio,
                                processorType: processorType,
                                deviceIds: deviceIds,
                                multiHeadNum: opts.MultiHeadNum,
                                encoderType: encoderType,
                                decoderType: decoderType,
                                maxSrcSntSize: opts.MaxSrcSentLength,
                                maxTgtSntSize: opts.MaxTgtSentLength,
                                enableCoverageModel: opts.EnableCoverageModel,
                                memoryUsageRatio: opts.MemoryUsageRatio,
                                shuffleType: shuffleType,
                                compilerOptions: cudaCompilerOptions);
                        }

                        // Subscribe the progress handler, then start training
                        model.IterationDone += ss_IterationDone;
                        model.Train(maxTrainingEpoch: opts.MaxEpochNum, trainCorpus: trainCorpus, validCorpus: validCorpus, learningRate: learningRate, optimizer: optimizer, metrics: metrics);
                        break;
                    }

                    case ModeEnums.Valid:
                    {
                        Logger.WriteLine($"Evaluate model '{opts.ModelFilePath}' by valid corpus '{opts.ValidCorpusPath}'");

                        // Metrics handed to Valid
                        List<IMetric> metrics = new List<IMetric>();
                        metrics.Add(new BleuMetric());
                        metrics.Add(new LengthRatioMetric());

                        // Validation corpus is mandatory in this mode
                        ParallelCorpus validCorpus = new ParallelCorpus(opts.ValidCorpusPath, opts.SrcLang, opts.TgtLang, opts.ValBatchSize, opts.ShuffleBlockSize, opts.MaxSrcSentLength, opts.MaxTgtSentLength);

                        model = new AttentionSeq2Seq(
                            modelFilePath: opts.ModelFilePath,
                            processorType: processorType,
                            deviceIds: deviceIds,
                            memoryUsageRatio: opts.MemoryUsageRatio,
                            shuffleType: shuffleType,
                            compilerOptions: cudaCompilerOptions);
                        model.Valid(validCorpus: validCorpus, metrics: metrics);
                        break;
                    }

                    case ModeEnums.Test:
                    {
                        Logger.WriteLine($"Test model '{opts.ModelFilePath}' by input corpus '{opts.InputTestFile}'");

                        // Load the trained model
                        model = new AttentionSeq2Seq(
                            modelFilePath: opts.ModelFilePath,
                            processorType: processorType,
                            deviceIds: deviceIds,
                            memoryUsageRatio: opts.MemoryUsageRatio,
                            shuffleType: shuffleType,
                            maxSrcSntSize: opts.MaxSrcSentLength,
                            maxTgtSntSize: opts.MaxTgtSentLength,
                            compilerOptions: cudaCompilerOptions);

                        List<string> outputLines = new List<string>();
                        string[] inputLines = File.ReadAllLines(opts.InputTestFile);
                        foreach (string inputLine in inputLines)
                        {
                            bool useBeamSearch = opts.BeamSearch > 1;

                            // Tokenize: lower-case, trim, split on single spaces
                            List<string> tokens = inputLine.ToLower().Trim().Split(' ').ToList();

                            if (useBeamSearch)
                            {
                                // Beam search: Predict returns several hypotheses, one output line each
                                List<List<string>> hypotheses = model.Predict(tokens, opts.BeamSearch);
                                outputLines.AddRange(hypotheses.Select(x => string.Join(" ", x)));
                            }
                            else
                            {
                                var outputTokensBatch = model.Test(ParallelCorpus.ConstructInputTokens(tokens));
                                outputLines.AddRange(outputTokensBatch.Select(x => String.Join(" ", x)));
                            }
                        }

                        File.WriteAllLines(opts.OutputTestFile, outputLines);
                        break;
                    }

                    case ModeEnums.DumpVocab:
                    {
                        model = new AttentionSeq2Seq(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds, compilerOptions: cudaCompilerOptions);
                        model.DumpVocabToFiles(opts.SrcVocab, opts.TgtVocab);
                        break;
                    }

                    default:
                    {
                        argParser.Usage();
                        break;
                    }
                }
            }
            catch (Exception err)
            {
                Logger.WriteLine($"Exception: '{err.Message}'");
                Logger.WriteLine($"Call stack: '{err.StackTrace}'");
            }
        }
Ejemplo n.º 6
0
        /// <summary>
        /// Console entry point. Parses the command line, converts the textual option values into
        /// their enum counterparts and dispatches on the requested mode: Train, Test or
        /// VisualizeNetwork. Any other mode prints usage.
        /// </summary>
        /// <param name="args">Raw command-line arguments, handed to <c>ArgParser</c> (and to <c>ShowOptions</c> in Train mode).</param>
        static void Main(string[] args)
        {
            Logger.LogFile = $"{nameof(Seq2SeqConsole)}_{GetTimeStamp(DateTime.Now)}.log";

            // Parse command line
            Options opts = new Options();
            ArgParser argParser = new ArgParser(args, opts);

            AttentionSeq2Seq model = null;

            // Textual options -> enums (Enum.Parse throws on unknown names)
            ArchTypeEnums archType = (ArchTypeEnums)Enum.Parse(typeof(ArchTypeEnums), opts.ArchType);
            EncoderTypeEnums encoderType = (EncoderTypeEnums)Enum.Parse(typeof(EncoderTypeEnums), opts.EncoderType);
            ModeEnums mode = (ModeEnums)Enum.Parse(typeof(ModeEnums), opts.TaskName);

            // Device ids are supplied as one comma-separated string
            string[] deviceIdTexts = opts.DeviceIds.Split(',');
            int[] deviceIds = new int[deviceIdTexts.Length];
            for (int idx = 0; idx < deviceIdTexts.Length; idx++)
            {
                deviceIds[idx] = int.Parse(deviceIdTexts[idx]);
            }

            switch (mode)
            {
                case ModeEnums.Train:
                {
                    ShowOptions(args, opts);

                    // The corpus batch size is the per-device batch size multiplied by the number of devices
                    Corpus trainCorpus = new Corpus(opts.TrainCorpusPath, opts.SrcLang, opts.TgtLang, opts.BatchSize * deviceIds.Length,
                                                    opts.ShuffleBlockSize, opts.MaxSentLength);

                    if (File.Exists(opts.ModelFilePath))
                    {
                        // Incremental training from the existing model file
                        Logger.WriteLine($"Loading model from '{opts.ModelFilePath}'...");
                        model = new AttentionSeq2Seq(opts.ModelFilePath, opts.BatchSize, archType, deviceIds);
                        model.TrainCorpus = trainCorpus;
                    }
                    else
                    {
                        // New training
                        model = new AttentionSeq2Seq(
                            embeddingDim: opts.WordVectorSize,
                            hiddenDim: opts.HiddenSize,
                            encoderLayerDepth: opts.EncoderLayerDepth,
                            decoderLayerDepth: opts.DecoderLayerDepth,
                            trainCorpus: trainCorpus,
                            srcVocabFilePath: opts.SrcVocab,
                            tgtVocabFilePath: opts.TgtVocab,
                            srcEmbeddingFilePath: opts.SrcEmbeddingModelFilePath,
                            tgtEmbeddingFilePath: opts.TgtEmbeddingModelFilePath,
                            modelFilePath: opts.ModelFilePath,
                            batchSize: opts.BatchSize,
                            dropoutRatio: opts.DropoutRatio,
                            archType: archType,
                            deviceIds: deviceIds,
                            multiHeadNum: opts.MultiHeadNum,
                            warmupSteps: opts.WarmUpSteps,
                            encoderType: encoderType);
                    }

                    model.IterationDone += ss_IterationDone;
                    model.Train(opts.MaxEpochNum, opts.LearningRate, opts.GradClip);
                    break;
                }

                case ModeEnums.Test:
                {
                    // Test a trained model; the second constructor argument is fixed to 1 here
                    model = new AttentionSeq2Seq(opts.ModelFilePath, 1, archType, deviceIds);

                    List<string> outputLines = new List<string>();
                    string[] inputLines = File.ReadAllLines(opts.InputTestFile);
                    foreach (string inputLine in inputLines)
                    {
                        // Tokenize: lower-case, trim, split on single spaces
                        List<string> tokens = inputLine.ToLower().Trim().Split(' ').ToList();

                        // Predict returns a list of hypotheses; each becomes one output line
                        List<List<string>> hypotheses = model.Predict(tokens, opts.BeamSearch);
                        foreach (List<string> hypothesis in hypotheses)
                        {
                            outputLines.Add(String.Join(" ", hypothesis));
                        }
                    }

                    File.WriteAllLines(opts.OutputTestFile, outputLines);
                    break;
                }

                case ModeEnums.VisualizeNetwork:
                {
                    // Build a model with no corpus, vocabulary or embedding files, batch size 1 and device 0 only
                    model = new AttentionSeq2Seq(
                        embeddingDim: opts.WordVectorSize,
                        hiddenDim: opts.HiddenSize,
                        encoderLayerDepth: opts.EncoderLayerDepth,
                        decoderLayerDepth: opts.DecoderLayerDepth,
                        trainCorpus: null,
                        srcVocabFilePath: null,
                        tgtVocabFilePath: null,
                        srcEmbeddingFilePath: null,
                        tgtEmbeddingFilePath: null,
                        modelFilePath: opts.ModelFilePath,
                        batchSize: 1,
                        dropoutRatio: opts.DropoutRatio,
                        archType: archType,
                        deviceIds: new[] { 0 },
                        multiHeadNum: opts.MultiHeadNum,
                        warmupSteps: opts.WarmUpSteps,
                        encoderType: encoderType);

                    model.VisualizeNeuralNetwork(opts.VisualizeNNFilePath);
                    break;
                }

                default:
                {
                    argParser.Usage();
                    break;
                }
            }
        }