Example #1
0
 /// <summary>
 /// Initializes the model description by forwarding the sequence-to-sequence settings to the base
 /// constructor and keeping <paramref name="clsVocab"/> in <c>ClsVocab</c>.
 /// </summary>
 /// <remarks>
 /// The last argument handed to the base constructor is always <c>false</c>; it cannot be configured here.
 /// NOTE(review): presumably this is the base model's pointer-generator switch - confirm against the base constructor.
 /// </remarks>
 public Seq2SeqClassificationModel(
     int hiddenDim,
     int srcEmbeddingDim,
     int tgtEmbeddingDim,
     int encoderLayerDepth,
     int decoderLayerDepth,
     int multiHeadNum,
     EncoderTypeEnums encoderType,
     DecoderTypeEnums decoderType,
     Vocab srcVocab,
     Vocab tgtVocab,
     Vocab clsVocab,
     bool enableCoverageModel,
     bool sharedEmbeddings,
     bool enableSegmentEmbeddings,
     bool enableTagEmbeddings,
     int maxSegmentNum)
     : base(
         hiddenDim,
         srcEmbeddingDim,
         tgtEmbeddingDim,
         encoderLayerDepth,
         decoderLayerDepth,
         multiHeadNum,
         encoderType,
         decoderType,
         srcVocab,
         tgtVocab,
         enableCoverageModel,
         sharedEmbeddings,
         enableSegmentEmbeddings,
         enableTagEmbeddings,
         maxSegmentNum,
         false)
 {
     // Only the extra vocabulary is handled at this level; everything else lives in the base class.
     this.ClsVocab = clsVocab;
 }
Example #2
0
 /// <summary>
 /// Creates the meta data record by copying every constructor argument into the member of the same name.
 /// </summary>
 /// <param name="hiddenDim">Value stored in <c>HiddenDim</c>.</param>
 /// <param name="embeddingDim">Value stored in <c>EmbeddingDim</c>.</param>
 /// <param name="encoderLayerDepth">Value stored in <c>EncoderLayerDepth</c>.</param>
 /// <param name="multiHeadNum">Value stored in <c>MultiHeadNum</c>.</param>
 /// <param name="encoderType">Value stored in <c>EncoderType</c>.</param>
 /// <param name="vocab">Value stored in <c>Vocab</c>.</param>
 public SeqLabelModelMetaData(
     int hiddenDim,
     int embeddingDim,
     int encoderLayerDepth,
     int multiHeadNum,
     EncoderTypeEnums encoderType,
     Vocab vocab)
 {
     this.HiddenDim = hiddenDim;
     this.EmbeddingDim = embeddingDim;
     this.EncoderLayerDepth = encoderLayerDepth;
     this.MultiHeadNum = multiHeadNum;
     this.EncoderType = encoderType;
     this.Vocab = vocab;
 }
Example #3
0
 /// <summary>
 /// Creates the meta data record by copying every constructor argument into the member of the same name.
 /// </summary>
 /// <param name="hiddenDim">Value stored in <c>HiddenDim</c>.</param>
 /// <param name="embeddingDim">Value stored in <c>EmbeddingDim</c>.</param>
 /// <param name="encoderLayerDepth">Value stored in <c>EncoderLayerDepth</c>.</param>
 /// <param name="decoderLayerDepth">Value stored in <c>DecoderLayerDepth</c>.</param>
 /// <param name="multiHeadNum">Value stored in <c>MultiHeadNum</c>.</param>
 /// <param name="encoderType">Value stored in <c>EncoderType</c>.</param>
 /// <param name="vocab">Value stored in <c>Vocab</c>.</param>
 /// <param name="enableCoverageModel">Value stored in <c>EnableCoverageModel</c>.</param>
 public Seq2SeqModelMetaData(
     int hiddenDim,
     int embeddingDim,
     int encoderLayerDepth,
     int decoderLayerDepth,
     int multiHeadNum,
     EncoderTypeEnums encoderType,
     Vocab vocab,
     bool enableCoverageModel)
 {
     this.HiddenDim = hiddenDim;
     this.EmbeddingDim = embeddingDim;
     this.EncoderLayerDepth = encoderLayerDepth;
     this.DecoderLayerDepth = decoderLayerDepth;
     this.MultiHeadNum = multiHeadNum;
     this.EncoderType = encoderType;
     this.Vocab = vocab;
     this.EnableCoverageModel = enableCoverageModel;
 }
Example #4
0
        /// <summary>
        /// Sets up a new model for training: initializes the devices, stores the hyper-parameters, obtains the
        /// vocabularies, creates the networks and optionally loads pre-trained word embeddings.
        /// </summary>
        /// <remarks>
        /// Vocabularies are loaded from files only when BOTH vocabulary paths are non-empty; otherwise they are
        /// built from <paramref name="trainCorpus"/>. Each embedding file is optional and, when given, is loaded
        /// once per entry of <paramref name="deviceIds"/>.
        /// </remarks>
        public AttentionSeq2Seq(int embeddingDim, int hiddenDim, int encoderLayerDepth, int decoderLayerDepth, Corpus trainCorpus, string srcVocabFilePath, string tgtVocabFilePath,
                                string srcEmbeddingFilePath, string tgtEmbeddingFilePath, string modelFilePath, int batchSize, float dropoutRatio, int multiHeadNum, int warmupSteps,
                                ArchTypeEnums archType, EncoderTypeEnums encoderType, int[] deviceIds)
        {
            TensorAllocator.InitDevices(archType, deviceIds);
            SetDefaultDeviceIds(deviceIds.Length);

            // Run-time and file settings
            m_deviceIds = deviceIds;
            m_batchSize = batchSize;
            m_dropoutRatio = dropoutRatio;
            m_modelFilePath = modelFilePath;
            m_encoderType = encoderType;
            m_multiHeadNum = multiHeadNum;
            // The stored warm-up value is one larger than the requested one.
            m_warmupSteps = warmupSteps + 1;

            TrainCorpus = trainCorpus;

            // Network shape
            m_embeddingDim = embeddingDim;
            m_hiddenDim = hiddenDim;
            m_encoderLayerDepth = encoderLayerDepth;
            m_decoderLayerDepth = decoderLayerDepth;

            // Vocabulary: read from disk when both files are named, otherwise derive it from the corpus.
            bool hasBothVocabFiles = !String.IsNullOrEmpty(srcVocabFilePath) && !String.IsNullOrEmpty(tgtVocabFilePath);
            if (hasBothVocabFiles)
            {
                Logger.WriteLine($"Loading vocabulary files from '{srcVocabFilePath}' and '{tgtVocabFilePath}'...");
                LoadVocab(srcVocabFilePath, tgtVocabFilePath);
            }
            else
            {
                Logger.WriteLine("Building vocabulary from training corpus...");
                BuildVocab(trainCorpus);
            }

            // Create the encoder, decoder and embedding weights.
            CreateEncoderDecoderEmbeddings();

            // Pre-trained embeddings are optional; whether to load them does not change between devices.
            bool loadSrcEmbedding = !String.IsNullOrEmpty(srcEmbeddingFilePath);
            bool loadTgtEmbedding = !String.IsNullOrEmpty(tgtEmbeddingFilePath);

            for (int deviceIdx = 0; deviceIdx < m_deviceIds.Length; deviceIdx++)
            {
                if (loadSrcEmbedding)
                {
                    Logger.WriteLine($"Loading ExtEmbedding model from '{srcEmbeddingFilePath}' for source side.");
                    LoadWordEmbedding(srcEmbeddingFilePath, m_srcEmbedding[deviceIdx], m_srcWordToIndex);
                }

                if (loadTgtEmbedding)
                {
                    Logger.WriteLine($"Loading ExtEmbedding model from '{tgtEmbeddingFilePath}' for target side.");
                    LoadWordEmbedding(tgtEmbeddingFilePath, m_tgtEmbedding[deviceIdx], m_tgtWordToIndex);
                }
            }
        }
Example #5
0
 /// <summary>
 /// Initializes the sequence-to-sequence model description. Encoder-side settings are forwarded to the base
 /// constructor; the decoder-side settings and the target vocabulary are stored on this instance.
 /// </summary>
 /// <remarks>
 /// <paramref name="multiHeadNum"/> is both passed to the base constructor and assigned to <c>MultiHeadNum</c> here.
 /// </remarks>
 public Seq2SeqModel(
     int hiddenDim,
     int encoderEmbeddingDim,
     int decoderEmbeddingDim,
     int encoderLayerDepth,
     int decoderLayerDepth,
     int multiHeadNum,
     EncoderTypeEnums encoderType,
     DecoderTypeEnums decoderType,
     Vocab srcVocab,
     Vocab tgtVocab,
     bool enableCoverageModel,
     bool sharedEmbeddings,
     bool enableSegmentEmbeddings,
     bool enableTagEmbeddings,
     int maxSegmentNum,
     bool pointerGenerator)
     : base(
         hiddenDim,
         encoderLayerDepth,
         encoderType,
         encoderEmbeddingDim,
         multiHeadNum,
         srcVocab,
         enableSegmentEmbeddings,
         enableTagEmbeddings,
         maxSegmentNum,
         pointerGenerator)
 {
     this.DecoderEmbeddingDim = decoderEmbeddingDim;
     this.DecoderLayerDepth = decoderLayerDepth;
     this.MultiHeadNum = multiHeadNum;
     this.DecoderType = decoderType;
     this.EnableCoverageModel = enableCoverageModel;
     this.SharedEmbeddings = sharedEmbeddings;
     this.TgtVocab = tgtVocab;
 }
Example #6
0
        /// <summary>
        /// Initializes the shared model description: copies every argument into the member of the same name
        /// and starts with an empty <c>Name2Weights</c> dictionary.
        /// </summary>
        public Model(
            int hiddenDim,
            int encoderLayerDepth,
            EncoderTypeEnums encoderType,
            int encoderEmbeddingDim,
            int multiHeadNum,
            Vocab srcVocab,
            bool enableSegmentEmbeddings,
            bool enableTagEmbeddings,
            int maxSegmentNum,
            bool pointerGenerator)
        {
            this.HiddenDim = hiddenDim;
            this.EncoderLayerDepth = encoderLayerDepth;
            this.EncoderType = encoderType;
            this.MultiHeadNum = multiHeadNum;
            this.SrcVocab = srcVocab;
            this.EncoderEmbeddingDim = encoderEmbeddingDim;
            this.EnableSegmentEmbeddings = enableSegmentEmbeddings;
            this.EnableTagEmbeddings = enableTagEmbeddings;
            this.MaxSegmentNum = maxSegmentNum;
            this.PointerGenerator = pointerGenerator;

            // No weights are known yet; the map from name to weight array starts out empty.
            this.Name2Weights = new Dictionary<string, float[]>();
        }
Example #7
0
        /// <summary>
        /// Restores a previously saved model from <paramref name="modelFilePath"/>: reads the serialized meta
        /// data, recreates the networks from it and then loads the encoder, decoder, embeddings and decoder
        /// feed-forward layer from the same stream, in that order.
        /// </summary>
        /// <param name="modelFilePath">Path of the model file to read.</param>
        /// <param name="batchSize">Batch size stored in <c>m_batchSize</c>.</param>
        /// <param name="archType">Architecture type passed to <c>TensorAllocator.InitDevices</c>.</param>
        /// <param name="deviceIds">Device ids passed to <c>TensorAllocator.InitDevices</c>.</param>
        /// <exception cref="InvalidDataException">
        /// The file does not start with a serialized <c>ModelAttentionMetaData</c> instance.
        /// </exception>
        public AttentionSeq2Seq(string modelFilePath, int batchSize, ArchTypeEnums archType, int[] deviceIds)
        {
            m_batchSize     = batchSize;
            m_deviceIds     = deviceIds;
            m_modelFilePath = modelFilePath;

            TensorAllocator.InitDevices(archType, deviceIds);
            SetDefaultDeviceIds(deviceIds.Length);

            Logger.WriteLine($"Loading model from '{modelFilePath}'...");

            // The using block releases the file handle even when deserialization or one of the Load calls throws.
            using (FileStream fs = new FileStream(m_modelFilePath, FileMode.Open, FileAccess.Read))
            {
                // NOTE(review): BinaryFormatter is unsafe on untrusted input and is removed from recent .NET
                // versions. Only load model files from trusted sources; consider migrating the file format.
                BinaryFormatter bf = new BinaryFormatter();

                ModelAttentionMetaData modelMetaData = bf.Deserialize(fs) as ModelAttentionMetaData;
                if (modelMetaData == null)
                {
                    // Fail with a descriptive error instead of a NullReferenceException on the first field access.
                    throw new InvalidDataException($"Model file '{m_modelFilePath}' does not contain valid model meta data.");
                }

                // Hyper-parameters
                m_clipvalue         = modelMetaData.Clipval;
                m_encoderLayerDepth = modelMetaData.EncoderLayerDepth;
                m_decoderLayerDepth = modelMetaData.DecoderLayerDepth;
                m_hiddenDim         = modelMetaData.HiddenDim;
                m_startLearningRate = modelMetaData.LearningRate;
                m_embeddingDim      = modelMetaData.EmbeddingDim;
                m_multiHeadNum      = modelMetaData.MultiHeadNum;
                m_encoderType       = modelMetaData.EncoderType;
                m_regc              = modelMetaData.Regc;
                m_dropoutRatio      = modelMetaData.DropoutRatio;

                // Vocabularies
                m_srcWordToIndex = modelMetaData.SrcWordToIndex;
                m_srcIndexToWord = modelMetaData.SrcIndexToWord;
                m_tgtWordToIndex = modelMetaData.TgtWordToIndex;
                m_tgtIndexToWord = modelMetaData.TgtIndexToWord;

                // The networks must exist before their weights can be read from the stream.
                CreateEncoderDecoderEmbeddings();

                // The read order below follows the original implementation and must not be changed.
                m_encoder[m_encoderDefaultDeviceId].Load(fs);
                m_decoder[m_decoderDefaultDeviceId].Load(fs);

                m_srcEmbedding[m_srcEmbeddingDefaultDeviceId].Load(fs);
                m_tgtEmbedding[m_tgtEmbeddingDefaultDeviceId].Load(fs);

                m_decoderFFLayer[m_DecoderFFLayerDefaultDeviceId].Load(fs);
            }
        }
Example #8
0
        /// <summary>
        /// Creates a new sequence labeling network: the device and model-file settings go to the base
        /// constructor, the remaining settings are packed into the model meta data, and the trainable
        /// parameters are created from that meta data.
        /// </summary>
        public SequenceLabel(
            int hiddenDim,
            int embeddingDim,
            int encoderLayerDepth,
            int multiHeadNum,
            EncoderTypeEnums encoderType,
            float dropoutRatio,
            Vocab vocab,
            int[] deviceIds,
            ProcessorTypeEnums processorType,
            string modelFilePath)
            : base(deviceIds, processorType, modelFilePath)
        {
            m_dropoutRatio = dropoutRatio;

            // The fourth argument is a literal 0.
            // NOTE(review): presumably the decoder layer depth, unused for labeling - confirm against Seq2SeqModelMetaData.
            m_modelMetaData = new Seq2SeqModelMetaData(
                hiddenDim,
                embeddingDim,
                encoderLayerDepth,
                0,
                multiHeadNum,
                encoderType,
                vocab);

            // Initialize the weights described by the meta data.
            CreateTrainableParameters(m_modelMetaData);
        }
Example #9
0
        /// <summary>
        /// Creates a new sequence-to-sequence network from explicit hyper-parameters and a vocabulary, then
        /// optionally loads pre-trained word embeddings for the source and/or target side.
        /// </summary>
        /// <remarks>
        /// The four "trainable" switches default to <c>true</c> and <paramref name="maxTgtSntSize"/> defaults
        /// to 128. Each embedding file, when given, is loaded once per entry of <c>DeviceIds</c>.
        /// </remarks>
        public AttentionSeq2Seq(int embeddingDim, int hiddenDim, int encoderLayerDepth, int decoderLayerDepth, Vocab vocab, string srcEmbeddingFilePath, string tgtEmbeddingFilePath,
                                string modelFilePath, float dropoutRatio, int multiHeadNum, ProcessorTypeEnums processorType, EncoderTypeEnums encoderType, DecoderTypeEnums decoderType, bool enableCoverageModel, int[] deviceIds,
                                bool isSrcEmbTrainable = true, bool isTgtEmbTrainable = true, bool isEncoderTrainable = true, bool isDecoderTrainable = true, int maxTgtSntSize = 128)
            : base(deviceIds, processorType, modelFilePath)
        {
            m_dropoutRatio = dropoutRatio;
            m_maxTgtSntSize = maxTgtSntSize;

            // Which parts of the network may be updated
            m_isEncoderTrainable = isEncoderTrainable;
            m_isDecoderTrainable = isDecoderTrainable;
            m_isSrcEmbTrainable = isSrcEmbTrainable;
            m_isTgtEmbTrainable = isTgtEmbTrainable;

            m_modelMetaData = new Seq2SeqModelMetaData(
                hiddenDim,
                embeddingDim,
                encoderLayerDepth,
                decoderLayerDepth,
                multiHeadNum,
                encoderType,
                decoderType,
                vocab,
                enableCoverageModel);

            // Initialize the weights described by the meta data.
            CreateTrainableParameters(m_modelMetaData);

            // External embeddings are optional; whether to load them does not change between devices.
            bool loadSrcEmbedding = !string.IsNullOrEmpty(srcEmbeddingFilePath);
            bool loadTgtEmbedding = !string.IsNullOrEmpty(tgtEmbeddingFilePath);

            for (int deviceIdx = 0; deviceIdx < DeviceIds.Length; deviceIdx++)
            {
                if (loadSrcEmbedding)
                {
                    Logger.WriteLine($"Loading ExtEmbedding model from '{srcEmbeddingFilePath}' for source side.");
                    LoadWordEmbedding(srcEmbeddingFilePath, m_srcEmbedding.GetNetworkOnDevice(deviceIdx), m_modelMetaData.Vocab.SrcWordToIndex);
                }

                if (loadTgtEmbedding)
                {
                    Logger.WriteLine($"Loading ExtEmbedding model from '{tgtEmbeddingFilePath}' for target side.");
                    LoadWordEmbedding(tgtEmbeddingFilePath, m_tgtEmbedding.GetNetworkOnDevice(deviceIdx), m_modelMetaData.Vocab.TgtWordToIndex);
                }
            }
        }
Example #10
0
        /// <summary>
        /// Console entry point. Reads the options from the command line (or from a JSON config file when one is
        /// named), then trains, validates, tests or dumps the vocabulary of a sequence-to-sequence model
        /// depending on the task name in the options. Any other task name prints the usage text.
        /// </summary>
        /// <param name="args">Command line arguments; parsed into an <c>Options</c> instance.</param>
        /// <remarks>
        /// Failures are logged instead of being rethrown. The whole inner-exception chain is written to the
        /// log and the process exit code is set to 1, so calling scripts can detect the failure.
        /// </remarks>
        private static void Main(string[] args)
        {
            try
            {
                Logger.LogFile = $"{nameof(Seq2SeqConsole)}_{GetTimeStamp(DateTime.Now)}.log";
                ShowOptions(args);

                //Parse command line
                Options   opts      = new Options();
                ArgParser argParser = new ArgParser(args, opts);

                // A config file, when named, replaces the options parsed from the command line.
                if (string.IsNullOrEmpty(opts.ConfigFilePath) == false)
                {
                    Logger.WriteLine($"Loading config file from '{opts.ConfigFilePath}'");
                    opts = JsonConvert.DeserializeObject<Options>(File.ReadAllText(opts.ConfigFilePath));
                }

                AttentionSeq2Seq   ss            = null;
                ProcessorTypeEnums processorType = (ProcessorTypeEnums)Enum.Parse(typeof(ProcessorTypeEnums), opts.ProcessorType);
                EncoderTypeEnums   encoderType   = (EncoderTypeEnums)Enum.Parse(typeof(EncoderTypeEnums), opts.EncoderType);
                DecoderTypeEnums   decoderType   = (DecoderTypeEnums)Enum.Parse(typeof(DecoderTypeEnums), opts.DecoderType);
                ModeEnums          mode          = (ModeEnums)Enum.Parse(typeof(ModeEnums), opts.TaskName);
                ShuffleEnums       shuffleType   = (ShuffleEnums)Enum.Parse(typeof(ShuffleEnums), opts.ShuffleType);

                // Compiler options are space separated; no options at all is represented by null.
                string[] cudaCompilerOptions = String.IsNullOrEmpty(opts.CompilerOptions) ? null : opts.CompilerOptions.Split(' ', StringSplitOptions.RemoveEmptyEntries);

                //Parse device ids from options
                int[] deviceIds = opts.DeviceIds.Split(',').Select(x => int.Parse(x)).ToArray();
                if (mode == ModeEnums.Train)
                {
                    // Load train corpus
                    ParallelCorpus trainCorpus = new ParallelCorpus(corpusFilePath: opts.TrainCorpusPath, srcLangName: opts.SrcLang, tgtLangName: opts.TgtLang, batchSize: opts.BatchSize, shuffleBlockSize: opts.ShuffleBlockSize,
                                                                    maxSrcSentLength: opts.MaxSrcSentLength, maxTgtSentLength: opts.MaxTgtSentLength, shuffleEnums: shuffleType);
                    // Load valid corpus (optional)
                    ParallelCorpus validCorpus = string.IsNullOrEmpty(opts.ValidCorpusPath) ? null : new ParallelCorpus(opts.ValidCorpusPath, opts.SrcLang, opts.TgtLang, opts.ValBatchSize, opts.ShuffleBlockSize, opts.MaxSrcSentLength, opts.MaxTgtSentLength);

                    // Create learning rate
                    ILearningRate learningRate = new DecayLearningRate(opts.StartLearningRate, opts.WarmUpSteps, opts.WeightsUpdateCount);

                    // Create optimizer
                    AdamOptimizer optimizer = new AdamOptimizer(opts.GradClip, opts.Beta1, opts.Beta2);

                    // Create metrics
                    List<IMetric> metrics = new List<IMetric>
                    {
                        new BleuMetric(),
                        new LengthRatioMetric()
                    };

                    if (!String.IsNullOrEmpty(opts.ModelFilePath) && File.Exists(opts.ModelFilePath))
                    {
                        //Incremental training: continue from the existing model file
                        Logger.WriteLine($"Loading model from '{opts.ModelFilePath}'...");
                        ss = new AttentionSeq2Seq(modelFilePath: opts.ModelFilePath, processorType: processorType, dropoutRatio: opts.DropoutRatio, deviceIds: deviceIds,
                                                  isSrcEmbTrainable: opts.IsSrcEmbeddingTrainable, isTgtEmbTrainable: opts.IsTgtEmbeddingTrainable, isEncoderTrainable: opts.IsEncoderTrainable, isDecoderTrainable: opts.IsDecoderTrainable,
                                                  maxSrcSntSize: opts.MaxSrcSentLength, maxTgtSntSize: opts.MaxTgtSentLength, memoryUsageRatio: opts.MemoryUsageRatio, shuffleType: shuffleType, compilerOptions: cudaCompilerOptions);
                    }
                    else
                    {
                        // Load or build vocabulary
                        Vocab vocab = null;
                        if (!string.IsNullOrEmpty(opts.SrcVocab) && !string.IsNullOrEmpty(opts.TgtVocab))
                        {
                            // Vocabulary files are specified, so we load them
                            vocab = new Vocab(opts.SrcVocab, opts.TgtVocab);
                        }
                        else
                        {
                            // We don't specify vocabulary, so we build it from train corpus
                            vocab = new Vocab(trainCorpus);
                        }

                        //New training
                        ss = new AttentionSeq2Seq(embeddingDim: opts.WordVectorSize, hiddenDim: opts.HiddenSize, encoderLayerDepth: opts.EncoderLayerDepth, decoderLayerDepth: opts.DecoderLayerDepth,
                                                  srcEmbeddingFilePath: opts.SrcEmbeddingModelFilePath, tgtEmbeddingFilePath: opts.TgtEmbeddingModelFilePath, vocab: vocab, modelFilePath: opts.ModelFilePath,
                                                  dropoutRatio: opts.DropoutRatio, processorType: processorType, deviceIds: deviceIds, multiHeadNum: opts.MultiHeadNum, encoderType: encoderType, decoderType: decoderType,
                                                  maxSrcSntSize: opts.MaxSrcSentLength, maxTgtSntSize: opts.MaxTgtSentLength, enableCoverageModel: opts.EnableCoverageModel, memoryUsageRatio: opts.MemoryUsageRatio, shuffleType: shuffleType, compilerOptions: cudaCompilerOptions);
                    }

                    // Add event handler for monitoring
                    ss.IterationDone += ss_IterationDone;

                    // Kick off training
                    ss.Train(maxTrainingEpoch: opts.MaxEpochNum, trainCorpus: trainCorpus, validCorpus: validCorpus, learningRate: learningRate, optimizer: optimizer, metrics: metrics);
                }
                else if (mode == ModeEnums.Valid)
                {
                    Logger.WriteLine($"Evaluate model '{opts.ModelFilePath}' by valid corpus '{opts.ValidCorpusPath}'");

                    // Create metrics
                    List<IMetric> metrics = new List<IMetric>
                    {
                        new BleuMetric(),
                        new LengthRatioMetric()
                    };

                    // Load valid corpus
                    ParallelCorpus validCorpus = new ParallelCorpus(opts.ValidCorpusPath, opts.SrcLang, opts.TgtLang, opts.ValBatchSize, opts.ShuffleBlockSize, opts.MaxSrcSentLength, opts.MaxTgtSentLength);

                    ss = new AttentionSeq2Seq(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds, memoryUsageRatio: opts.MemoryUsageRatio, shuffleType: shuffleType, compilerOptions: cudaCompilerOptions);
                    ss.Valid(validCorpus: validCorpus, metrics: metrics);
                }
                else if (mode == ModeEnums.Test)
                {
                    Logger.WriteLine($"Test model '{opts.ModelFilePath}' by input corpus '{opts.InputTestFile}'");

                    //Test trained model
                    ss = new AttentionSeq2Seq(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds, memoryUsageRatio: opts.MemoryUsageRatio,
                                              shuffleType: shuffleType, maxSrcSntSize: opts.MaxSrcSentLength, maxTgtSntSize: opts.MaxTgtSentLength, compilerOptions: cudaCompilerOptions);

                    List<string> outputLines     = new List<string>();
                    string[]     data_sents_raw1 = File.ReadAllLines(opts.InputTestFile);
                    foreach (string line in data_sents_raw1)
                    {
                        // Both decoding paths consume the same lower-cased, space-separated tokens.
                        List<string> tokens = line.ToLower().Trim().Split(' ').ToList();

                        if (opts.BeamSearch > 1)
                        {
                            // Beam search: one output line per returned hypothesis
                            List<List<string>> outputWordsList = ss.Predict(tokens, opts.BeamSearch);
                            outputLines.AddRange(outputWordsList.Select(x => string.Join(" ", x)));
                        }
                        else
                        {
                            var outputTokensBatch = ss.Test(ParallelCorpus.ConstructInputTokens(tokens));
                            outputLines.AddRange(outputTokensBatch.Select(x => String.Join(" ", x)));
                        }
                    }

                    File.WriteAllLines(opts.OutputTestFile, outputLines);
                }
                else if (mode == ModeEnums.DumpVocab)
                {
                    ss = new AttentionSeq2Seq(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds, compilerOptions: cudaCompilerOptions);
                    ss.DumpVocabToFiles(opts.SrcVocab, opts.TgtVocab);
                }
                else
                {
                    argParser.Usage();
                }
            }
            catch (Exception err)
            {
                Logger.WriteLine($"Exception: '{err.Message}'");
                Logger.WriteLine($"Call stack: '{err.StackTrace}'");

                // The root cause is often wrapped; log the whole chain so it is not lost.
                for (Exception inner = err.InnerException; inner != null; inner = inner.InnerException)
                {
                    Logger.WriteLine($"Inner exception: '{inner.Message}'");
                    Logger.WriteLine($"Inner call stack: '{inner.StackTrace}'");
                }

                // Report the failure to the caller; without this the process would exit with code 0.
                Environment.ExitCode = 1;
            }
        }
Example #11
0
        /// <summary>
        /// Console entry point. Parses the options and, depending on the task name, trains a model (new or
        /// continued from an existing model file), runs prediction on an input file, or writes a visualization
        /// of the network. Any other task name prints the usage text.
        /// </summary>
        /// <param name="args">Command line arguments; parsed into an <c>Options</c> instance.</param>
        static void Main(string[] args)
        {
            Logger.LogFile = $"{nameof(Seq2SeqConsole)}_{GetTimeStamp(DateTime.Now)}.log";

            // Parse command line
            Options opts = new Options();
            ArgParser argParser = new ArgParser(args, opts);

            ArchTypeEnums archType = (ArchTypeEnums)Enum.Parse(typeof(ArchTypeEnums), opts.ArchType);
            EncoderTypeEnums encoderType = (EncoderTypeEnums)Enum.Parse(typeof(EncoderTypeEnums), opts.EncoderType);
            ModeEnums mode = (ModeEnums)Enum.Parse(typeof(ModeEnums), opts.TaskName);

            // Device ids arrive as a comma separated list
            int[] deviceIds = opts.DeviceIds.Split(',').Select(id => int.Parse(id)).ToArray();

            AttentionSeq2Seq ss;
            switch (mode)
            {
                case ModeEnums.Train:
                {
                    ShowOptions(args, opts);

                    // The corpus batch size is the per-device batch size times the number of devices.
                    int corpusBatchSize = opts.BatchSize * deviceIds.Length;
                    Corpus trainCorpus = new Corpus(opts.TrainCorpusPath, opts.SrcLang, opts.TgtLang, corpusBatchSize,
                                                    opts.ShuffleBlockSize, opts.MaxSentLength);

                    if (File.Exists(opts.ModelFilePath))
                    {
                        // Incremental training: continue from the existing model file
                        Logger.WriteLine($"Loading model from '{opts.ModelFilePath}'...");
                        ss = new AttentionSeq2Seq(opts.ModelFilePath, opts.BatchSize, archType, deviceIds);
                        ss.TrainCorpus = trainCorpus;
                    }
                    else
                    {
                        // New training
                        ss = new AttentionSeq2Seq(
                            embeddingDim: opts.WordVectorSize,
                            hiddenDim: opts.HiddenSize,
                            encoderLayerDepth: opts.EncoderLayerDepth,
                            decoderLayerDepth: opts.DecoderLayerDepth,
                            trainCorpus: trainCorpus,
                            srcVocabFilePath: opts.SrcVocab,
                            tgtVocabFilePath: opts.TgtVocab,
                            srcEmbeddingFilePath: opts.SrcEmbeddingModelFilePath,
                            tgtEmbeddingFilePath: opts.TgtEmbeddingModelFilePath,
                            modelFilePath: opts.ModelFilePath,
                            batchSize: opts.BatchSize,
                            dropoutRatio: opts.DropoutRatio,
                            archType: archType,
                            deviceIds: deviceIds,
                            multiHeadNum: opts.MultiHeadNum,
                            warmupSteps: opts.WarmUpSteps,
                            encoderType: encoderType);
                    }

                    ss.IterationDone += ss_IterationDone;
                    ss.Train(opts.MaxEpochNum, opts.LearningRate, opts.GradClip);
                    break;
                }

                case ModeEnums.Test:
                {
                    // Test a trained model; the batch size is fixed to 1 here
                    ss = new AttentionSeq2Seq(opts.ModelFilePath, 1, archType, deviceIds);

                    List<string> outputLines = new List<string>();
                    foreach (string line in File.ReadAllLines(opts.InputTestFile))
                    {
                        List<string> tokens = line.ToLower().Trim().Split(' ').ToList();
                        List<List<string>> hypotheses = ss.Predict(tokens, opts.BeamSearch);
                        outputLines.AddRange(hypotheses.Select(words => String.Join(" ", words)));
                    }

                    File.WriteAllLines(opts.OutputTestFile, outputLines);
                    break;
                }

                case ModeEnums.VisualizeNetwork:
                {
                    // No corpus, vocabulary or embedding files are supplied; device 0 and batch size 1 are fixed.
                    ss = new AttentionSeq2Seq(
                        embeddingDim: opts.WordVectorSize,
                        hiddenDim: opts.HiddenSize,
                        encoderLayerDepth: opts.EncoderLayerDepth,
                        decoderLayerDepth: opts.DecoderLayerDepth,
                        trainCorpus: null,
                        srcVocabFilePath: null,
                        tgtVocabFilePath: null,
                        srcEmbeddingFilePath: null,
                        tgtEmbeddingFilePath: null,
                        modelFilePath: opts.ModelFilePath,
                        batchSize: 1,
                        dropoutRatio: opts.DropoutRatio,
                        archType: archType,
                        deviceIds: new int[1] { 0 },
                        multiHeadNum: opts.MultiHeadNum,
                        warmupSteps: opts.WarmUpSteps,
                        encoderType: encoderType);

                    ss.VisualizeNeuralNetwork(opts.VisualizeNNFilePath);
                    break;
                }

                default:
                    argParser.Usage();
                    break;
            }
        }
Example #12
0
 /// <summary>
 /// Initializes the sequence labeling model description: encoder settings go to the base constructor and
 /// <paramref name="clsVocab"/> is kept in <c>ClsVocab</c>.
 /// </summary>
 /// <remarks>
 /// Three of the base constructor arguments are fixed to <c>false</c>.
 /// NOTE(review): presumably the segment-embedding, tag-embedding and pointer-generator switches - confirm against the base constructor.
 /// </remarks>
 public SeqLabelModel(
     int hiddenDim,
     int embeddingDim,
     int encoderLayerDepth,
     int multiHeadNum,
     EncoderTypeEnums encoderType,
     Vocab srcVocab,
     Vocab clsVocab,
     int maxSegmentNum)
     : base(
         hiddenDim,
         encoderLayerDepth,
         encoderType,
         embeddingDim,
         multiHeadNum,
         srcVocab,
         false,
         false,
         maxSegmentNum,
         false)
 {
     this.ClsVocab = clsVocab;
 }
Example #13
0
        /// <summary>
        /// Console entry point of the sequence labeling tool. Reads the options from the command line (or from
        /// a JSON config file when one is named) and then trains, validates or tests a model depending on the
        /// task name. Any other task name prints the usage text.
        /// </summary>
        /// <param name="args">Command line arguments; parsed into an <c>Options</c> instance.</param>
        static void Main(string[] args)
        {
            ShowOptions(args);

            Logger.LogFile = $"{nameof(SeqLabelConsole)}_{GetTimeStamp(DateTime.Now)}.log";

            // Parse command line
            Options opts = new Options();
            ArgParser argParser = new ArgParser(args, opts);

            // A config file, when named, replaces the options parsed from the command line.
            if (!String.IsNullOrEmpty(opts.ConfigFilePath))
            {
                Logger.WriteLine($"Loading config file from '{opts.ConfigFilePath}'");
                opts = JsonConvert.DeserializeObject<Options>(File.ReadAllText(opts.ConfigFilePath));
            }

            SequenceLabel sl = null;
            ProcessorTypeEnums processorType = (ProcessorTypeEnums)Enum.Parse(typeof(ProcessorTypeEnums), opts.ProcessorType);
            EncoderTypeEnums encoderType = (EncoderTypeEnums)Enum.Parse(typeof(EncoderTypeEnums), opts.EncoderType);
            ModeEnums mode = (ModeEnums)Enum.Parse(typeof(ModeEnums), opts.TaskName);

            // Device ids arrive as a comma separated list
            int[] deviceIds = opts.DeviceIds.Split(',').Select(id => int.Parse(id)).ToArray();

            if (mode == ModeEnums.Train)
            {
                // Load train corpus
                ParallelCorpus trainCorpus = new ParallelCorpus(opts.TrainCorpusPath, opts.SrcLang, opts.TgtLang, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength, addBOSEOS: false);

                // Load valid corpus (optional)
                ParallelCorpus validCorpus = null;
                if (!String.IsNullOrEmpty(opts.ValidCorpusPath))
                {
                    validCorpus = new ParallelCorpus(opts.ValidCorpusPath, opts.SrcLang, opts.TgtLang, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength, addBOSEOS: false);
                }

                // Load the vocabulary when both files are named, otherwise build it from the train corpus
                bool hasVocabFiles = !String.IsNullOrEmpty(opts.SrcVocab) && !String.IsNullOrEmpty(opts.TgtVocab);
                Vocab vocab = hasVocabFiles ? new Vocab(opts.SrcVocab, opts.TgtVocab) : new Vocab(trainCorpus);

                // Create learning rate
                ILearningRate learningRate = new DecayLearningRate(opts.StartLearningRate, opts.WarmUpSteps, opts.WeightsUpdateCount);

                // Create optimizer
                AdamOptimizer optimizer = new AdamOptimizer(opts.GradClip, opts.Beta1, opts.Beta2);

                // One F-score metric per entry of the target vocabulary
                List<IMetric> metrics = new List<IMetric>();
                foreach (var label in vocab.TgtVocab)
                {
                    metrics.Add(new SequenceLabelFscoreMetric(label));
                }

                if (File.Exists(opts.ModelFilePath))
                {
                    // Incremental training: continue from the existing model file
                    Logger.WriteLine($"Loading model from '{opts.ModelFilePath}'...");
                    sl = new SequenceLabel(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds, dropoutRatio: opts.DropoutRatio);
                }
                else
                {
                    // New training
                    sl = new SequenceLabel(
                        hiddenDim: opts.HiddenSize,
                        embeddingDim: opts.WordVectorSize,
                        encoderLayerDepth: opts.EncoderLayerDepth,
                        multiHeadNum: opts.MultiHeadNum,
                        encoderType: encoderType,
                        dropoutRatio: opts.DropoutRatio,
                        deviceIds: deviceIds,
                        processorType: processorType,
                        modelFilePath: opts.ModelFilePath,
                        vocab: vocab);
                }

                // Add event handler for monitoring
                sl.IterationDone += ss_IterationDone;

                // Kick off training
                sl.Train(maxTrainingEpoch: opts.MaxEpochNum, trainCorpus: trainCorpus, validCorpus: validCorpus, learningRate: learningRate, optimizer: optimizer, metrics: metrics);
                return;
            }

            if (mode == ModeEnums.Valid)
            {
                Logger.WriteLine($"Evaluate model '{opts.ModelFilePath}' by valid corpus '{opts.ValidCorpusPath}'");

                // Load valid corpus
                ParallelCorpus validCorpus = new ParallelCorpus(opts.ValidCorpusPath, opts.SrcLang, opts.TgtLang, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength, false);

                // The metrics are derived from a vocabulary built on the validation corpus itself.
                Vocab vocab = new Vocab(validCorpus);
                List<IMetric> metrics = new List<IMetric>();
                foreach (var label in vocab.TgtVocab)
                {
                    metrics.Add(new SequenceLabelFscoreMetric(label));
                }

                sl = new SequenceLabel(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds);
                sl.Valid(validCorpus: validCorpus, metrics: metrics);
                return;
            }

            if (mode == ModeEnums.Test)
            {
                Logger.WriteLine($"Test model '{opts.ModelFilePath}' by input corpus '{opts.InputTestFile}'");

                // Test trained model
                sl = new SequenceLabel(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds);

                List<string> outputLines = new List<string>();
                foreach (string line in File.ReadAllLines(opts.InputTestFile))
                {
                    List<string> tokens = line.ToLower().Trim().Split(' ').ToList();
                    var outputTokensBatch = sl.Test(ParallelCorpus.ConstructInputTokens(tokens, false));
                    outputLines.AddRange(outputTokensBatch.Select(words => String.Join(" ", words)));
                }

                File.WriteAllLines(opts.OutputTestFile, outputLines);
                return;
            }

            // Every other mode (including the network visualization branch that used to be sketched here
            // as commented-out code) just prints the usage text.
            argParser.Usage();
        }
 /// <summary>
 /// Initializes the sequence classification model description: encoder settings go to the base constructor
 /// and <paramref name="clsVocabs"/> is kept in <c>ClsVocabs</c>.
 /// </summary>
 /// <remarks>
 /// The last argument handed to the base constructor is always <c>false</c>.
 /// NOTE(review): presumably the pointer-generator switch - confirm against the base constructor.
 /// </remarks>
 public SeqClassificationModel(
     int hiddenDim,
     int embeddingDim,
     int encoderLayerDepth,
     int multiHeadNum,
     EncoderTypeEnums encoderType,
     Vocab srcVocab,
     List<Vocab> clsVocabs,
     bool enableSegmentEmbeddings,
     bool enableTagEmbeddings,
     int maxSegmentNum)
     : base(
         hiddenDim,
         encoderLayerDepth,
         encoderType,
         embeddingDim,
         multiHeadNum,
         srcVocab,
         enableSegmentEmbeddings,
         enableTagEmbeddings,
         maxSegmentNum,
         false)
 {
     this.ClsVocabs = clsVocabs;
 }
Example #15
0
 /// <summary>
 /// Initializes the sequence similarity model description: encoder settings go to the base constructor,
 /// while <paramref name="clsVocab"/> and <paramref name="similarityType"/> are stored on this instance.
 /// </summary>
 /// <remarks>
 /// Two of the base constructor arguments are fixed to <c>false</c>.
 /// NOTE(review): presumably the tag-embedding and pointer-generator switches - confirm against the base constructor.
 /// </remarks>
 public SeqSimilarityModel(
     int hiddenDim,
     int embeddingDim,
     int encoderLayerDepth,
     int multiHeadNum,
     EncoderTypeEnums encoderType,
     Vocab srcVocab,
     Vocab clsVocab,
     bool enableSegmentEmbeddings,
     string similarityType,
     int maxSegmentNum)
     : base(
         hiddenDim,
         encoderLayerDepth,
         encoderType,
         embeddingDim,
         multiHeadNum,
         srcVocab,
         enableSegmentEmbeddings,
         false,
         maxSegmentNum,
         false)
 {
     this.ClsVocab = clsVocab;
     this.SimilarityType = similarityType;
 }