Beispiel #1
0
        /// <summary>
        /// Entry point: preprocesses the data, builds the attention seq2seq model
        /// (then tries to restore saved state into it), hooks a progress/checkpoint
        /// callback, and starts the training and console-reading threads.
        /// </summary>
        /// <param name="args">Command-line arguments (not used).</param>
        static void Main(string[] args)
        {
            Preprocess();
            S2S = new AttentionSeq2Seq(32, 16, 1, input, output, true);

            // Restoring saved state is best-effort (presumably it fails when no model
            // has been saved yet). Report the failure instead of hiding it, then carry on
            // with S2S as it is.
            try
            {
                S2S.Load();
            }
            catch (Exception ex)
            {
                Console.WriteLine($"Model not loaded, continuing without saved state: {ex.Message}");
            }

            int iterationCount = 0;

            S2S.IterationDone += (sender, eventArgs) =>
            {
                // On every 100th callback (including the very first), log progress and checkpoint.
                if (iterationCount % 100 == 0)
                {
                    // Guard the downcast: a different EventArgs type must not crash the training callback.
                    CostEventArg costArgs = eventArgs as CostEventArg;
                    if (costArgs != null)
                    {
                        Console.WriteLine($"Cost {costArgs.Cost} Iteration {costArgs.Iteration} k {iterationCount}");
                    }
                    S2S.Save();
                }
                iterationCount++;
            };

            MainThread = new Thread(new ThreadStart(Train));
            MainThread.Start();

            ReadThread = new Thread(new ThreadStart(ReadingConsole));
            ReadThread.Start();
        }
        /// <summary>
        /// Console entry point. Sets up the log file, parses <paramref name="args"/> into
        /// an <see cref="Options"/> instance, and dispatches on <c>options.TaskName</c>:
        /// "train" trains a new or previously saved model, "test" runs the model over an
        /// input file, and anything else prints usage.
        /// </summary>
        /// <param name="args">Raw command-line arguments.</param>
        static void Main(string[] args)
        {
            Logger.LogFile = $"{nameof(Seq2SeqConsole)}_{GetTimeStamp(DateTime.Now)}.log";

            Options options = new Options();
            ArgParser argParser = new ArgParser(args, options);

            // Task names are keywords, not linguistic text: compare ordinally, ignoring case.
            if (String.Equals(options.TaskName, "train", StringComparison.OrdinalIgnoreCase))
            {
                RunTrainTask(options);
            }
            else if (String.Equals(options.TaskName, "test", StringComparison.OrdinalIgnoreCase))
            {
                RunTestTask(options);
            }
            else
            {
                argParser.Usage();
            }
        }

        /// <summary>
        /// Training task. Creates a new model when <c>options.ModelFilePath</c> does not
        /// exist, otherwise loads it from that path and attaches the training corpus. It then
        /// logs the configuration and trains.
        /// </summary>
        private static void RunTrainTask(Options options)
        {
            Corpus trainCorpus = new Corpus(options.TrainCorpusPath, options.SrcLang, options.TgtLang, options.ShuffleBlockSize);

            AttentionSeq2Seq ss;
            if (!File.Exists(options.ModelFilePath))
            {
                ss = new AttentionSeq2Seq(options.WordVectorSize, options.HiddenSize, options.Depth, trainCorpus, options.SrcVocab, options.TgtVocab, options.SrcEmbeddingModelFilePath, options.TgtEmbeddingModelFilePath,
                                          options.SparseFeature, true, options.ModelFilePath);
            }
            else
            {
                Logger.WriteLine($"Loading model from '{options.ModelFilePath}'...");
                ss = new AttentionSeq2Seq();
                ss.Load(options.ModelFilePath);
                ss.TrainCorpus = trainCorpus;
            }

            Logger.WriteLine($"Source Language = '{options.SrcLang}'");
            Logger.WriteLine($"Target Language = '{options.TgtLang}'");
            Logger.WriteLine($"SSE Enable = '{System.Numerics.Vector.IsHardwareAccelerated}'");
            Logger.WriteLine($"SSE Size = '{System.Numerics.Vector<float>.Count * 32}'");
            Logger.WriteLine($"Processor counter = '{Environment.ProcessorCount}'");
            Logger.WriteLine($"Hidden Size = '{ss.HiddenSize}'");
            Logger.WriteLine($"Word Vector Size = '{ss.WordVectorSize}'");
            Logger.WriteLine($"Learning Rate = '{options.LearningRate}'");
            Logger.WriteLine($"Network Layer = '{ss.Depth}'");
            Logger.WriteLine($"Use Sparse Feature = '{options.SparseFeature}'");

            ss.IterationDone += ss_IterationDone;
            // NOTE(review): 300 is Train's first argument; presumably a maximum epoch count, so confirm before changing.
            ss.Train(300, options.LearningRate);
        }

        /// <summary>
        /// Test task. Loads the model from <c>options.ModelFilePath</c> and predicts one
        /// output line per line of <c>options.InputTestFile</c>, writing the results to
        /// <c>options.OutputTestFile</c>.
        /// </summary>
        private static void RunTestTask(Options options)
        {
            AttentionSeq2Seq ss = new AttentionSeq2Seq();
            ss.Load(options.ModelFilePath);

            List<string> outputLines = new List<string>();
            foreach (string line in File.ReadAllLines(options.InputTestFile))
            {
                // Lowercase with the invariant culture so tokens do not depend on the
                // machine's locale (ToLower() maps 'I' to dotless 'ı' under tr-TR).
                List<string> inputWords = line.ToLowerInvariant().Trim().Split(' ').ToList();
                List<string> outputWords = ss.Predict(inputWords);
                outputLines.Add(String.Join(" ", outputWords));
            }

            File.WriteAllLines(options.OutputTestFile, outputLines);
        }
 /// <summary>
 /// Click handler: enables the Train button, calls <c>ss.Load()</c>, and then enables
 /// the Predict button plus the source/result text boxes.
 /// </summary>
 private void button5_Click(object sender, EventArgs e)
 {
     // Training is unlocked before the load is attempted.
     TrainButton.Enabled = true;

     ss.Load();

     // Only reached when ss.Load() returned without throwing.
     const bool unlock = true;
     PredictButton.Enabled = unlock;
     ResultTxtBox.Enabled  = unlock;
     SrcTxtBox.Enabled     = unlock;
 }