Ejemplo n.º 1
0
        /// <summary>
        /// Creates a sequence-labeling model. If the model file named by
        /// <c>options.ModelFilePath</c> already exists it is loaded from disk; otherwise a new
        /// <see cref="SeqLabelModel"/> is built from the options and supplied vocabularies.
        /// </summary>
        /// <param name="options">Model, training and runtime options.</param>
        /// <param name="srcVocab">Source vocabulary for a new model; must be null when the model file exists.</param>
        /// <param name="clsVocab">Label vocabulary for a new model; must be null when the model file exists.</param>
        /// <exception cref="FileNotFoundException">The task is not training and the model file is missing.</exception>
        /// <exception cref="ArgumentException">The model file exists but a vocabulary was also supplied.</exception>
        public SeqLabel(SeqLabelOptions options, Vocab srcVocab = null, Vocab clsVocab = null)
            : base(options.DeviceIds, options.ProcessorType, options.ModelFilePath, options.MemoryUsageRatio, options.CompilerOptions, options.ValidIntervalHours, updateFreq: options.UpdateFreq)
        {
            m_shuffleType = options.ShuffleType;
            m_options = options;

            bool modelOnDisk = File.Exists(m_options.ModelFilePath);

            // Only training may start without a model file; every other task needs one to load.
            if (!modelOnDisk && m_options.Task != ModeEnums.Train)
            {
                throw new FileNotFoundException($"Model '{m_options.ModelFilePath}' doesn't exist.");
            }

            // A saved model carries its own vocabulary, so caller-supplied vocabularies are rejected.
            if (modelOnDisk && (srcVocab != null || clsVocab != null))
            {
                throw new ArgumentException($"Model '{m_options.ModelFilePath}' exists and it includes vocabulary, so input vocabulary must be null.");
            }

            if (!modelOnDisk)
            {
                // Fresh model: build the metadata first, then initialize encoder weights from it.
                m_modelMetaData = new SeqLabelModel(options.HiddenSize, options.EmbeddingDim, options.EncoderLayerDepth, options.MultiHeadNum, options.EncoderType, srcVocab, clsVocab, options.MaxSegmentNum);
                CreateTrainableParameters(m_modelMetaData);
            }
            else
            {
                // Existing model: the loader is handed CreateTrainableParameters to rebuild the weights.
                m_modelMetaData = LoadModelImpl_WITH_CONVERT(CreateTrainableParameters);
            }

            m_modelMetaData.ShowModelInfo();
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Console entry point for sequence labeling. Depending on <c>opts.Task</c> it trains
        /// (new or incremental), validates, or tests a <see cref="SeqLabel"/> model; any other
        /// task value prints usage.
        /// </summary>
        /// <param name="args">Command-line arguments, parsed into <see cref="SeqLabelOptions"/>.</param>
        private static void Main(string[] args)
        {
            ShowOptions(args);

            Logger.LogFile = $"{nameof(SeqLabelConsole)}_{Utils.GetTimeStamp(DateTime.Now)}.log";

            // Parse command line
            SeqLabelOptions opts = new SeqLabelOptions();
            ArgParser argParser = new ArgParser(args, opts);

            // When a config file is given, its contents replace the parsed command-line options entirely.
            if (!opts.ConfigFilePath.IsNullOrEmpty())
            {
                Logger.WriteLine($"Loading config file from '{opts.ConfigFilePath}'");
                opts = JsonConvert.DeserializeObject<SeqLabelOptions>(File.ReadAllText(opts.ConfigFilePath));
            }

            DecodingOptions decodingOptions = opts.CreateDecodingOptions();

            SeqLabel sl = null;

            if (opts.Task == ModeEnums.Train)
            {
                // Load train corpus
                SeqLabelingCorpus trainCorpus = new SeqLabelingCorpus(opts.TrainCorpusPath, opts.BatchSize, opts.ShuffleBlockSize, maxSentLength: opts.MaxTrainSentLength);

                // Load valid corpora: ValidCorpusPaths is a ';'-separated list, one corpus per path.
                // Empty entries (e.g. from a trailing ';') are skipped rather than opened as files.
                List<SeqLabelingCorpus> validCorpusList = new List<SeqLabelingCorpus>();
                if (!opts.ValidCorpusPaths.IsNullOrEmpty())
                {
                    string[] validCorpusPathList = opts.ValidCorpusPaths.Split(new[] { ';' }, StringSplitOptions.RemoveEmptyEntries);
                    foreach (string validCorpusPath in validCorpusPathList)
                    {
                        validCorpusList.Add(new SeqLabelingCorpus(validCorpusPath, opts.BatchSize, opts.ShuffleBlockSize, maxSentLength: opts.MaxTestSentLength));
                    }
                }

                // Load or build vocabulary
                Vocab srcVocab = null;
                Vocab tgtVocab = null;
                if (!opts.SrcVocab.IsNullOrEmpty() && !opts.TgtVocab.IsNullOrEmpty())
                {
                    // Vocabulary files are specified, so we load them
                    srcVocab = new Vocab(opts.SrcVocab);
                    tgtVocab = new Vocab(opts.TgtVocab);
                }
                else
                {
                    // No vocabulary files given, so build them from the train corpus
                    (srcVocab, tgtVocab) = trainCorpus.BuildVocabs(opts.SrcVocabSize);
                }

                // Create learning rate
                ILearningRate learningRate = new DecayLearningRate(opts.StartLearningRate, opts.WarmUpSteps, opts.WeightsUpdateCount);

                // Create optimizer
                IOptimizer optimizer = Misc.CreateOptimizer(opts);

                // Create metrics
                List<IMetric> metrics = BuildLabelMetrics(tgtVocab);

                if (File.Exists(opts.ModelFilePath) == false)
                {
                    // New training: the model is created from the vocabularies above
                    sl = new SeqLabel(opts, srcVocab: srcVocab, clsVocab: tgtVocab);
                }
                else
                {
                    // Incremental training: the saved model supplies its own vocabulary
                    Logger.WriteLine($"Loading model from '{opts.ModelFilePath}'...");
                    sl = new SeqLabel(opts);
                }

                // Add event handler for monitoring
                sl.StatusUpdateWatcher += Misc.Ss_StatusUpdateWatcher;

                // Kick off training
                sl.Train(maxTrainingEpoch: opts.MaxEpochNum, trainCorpus: trainCorpus, validCorpusList: validCorpusList.ToArray(), learningRate: learningRate, optimizer: optimizer, metrics: metrics, decodingOptions: decodingOptions);
            }
            else if (opts.Task == ModeEnums.Valid)
            {
                Logger.WriteLine($"Evaluate model '{opts.ModelFilePath}' by valid corpus '{opts.ValidCorpusPaths}'");

                // Load valid corpus.
                // NOTE(review): unlike training, ValidCorpusPaths is used here as a single path, not split on ';'.
                SeqLabelingCorpus validCorpus = new SeqLabelingCorpus(opts.ValidCorpusPaths, opts.BatchSize, opts.ShuffleBlockSize, maxSentLength: opts.MaxTestSentLength);

                // Only the label vocabulary is needed, to decide which labels get a metric.
                var (_, tgtVocab) = validCorpus.BuildVocabs();
                List<IMetric> metrics = BuildLabelMetrics(tgtVocab);

                sl = new SeqLabel(opts);
                sl.Valid(validCorpus: validCorpus, metrics: metrics, decodingOptions: decodingOptions);
            }
            else if (opts.Task == ModeEnums.Test)
            {
                Logger.WriteLine($"Test model '{opts.ModelFilePath}' by input corpus '{opts.InputTestFile}'");

                // Test trained model
                sl = new SeqLabel(opts);

                // Each input line is tokenized on single spaces and labeled independently.
                List<string> outputLines = new List<string>();
                string[] data_sents_raw1 = File.ReadAllLines(opts.InputTestFile);
                foreach (string line in data_sents_raw1)
                {
                    var nrs = sl.Test<SeqLabelingCorpusBatch>(ConstructInputTokens(line.Trim().Split(' ').ToList()), null, decodingOptions: decodingOptions);
                    outputLines.AddRange(nrs[0].Output[0].Select(x => string.Join(" ", x)));
                }

                File.WriteAllLines(opts.OutputFile, outputLines);
            }
            else
            {
                argParser.Usage();
            }
        }

        /// <summary>
        /// Builds one <see cref="SequenceLabelFscoreMetric"/> per label in <paramref name="tgtVocab"/>,
        /// skipping tokens that <c>BuildInTokens.IsPreDefinedToken</c> reports as predefined.
        /// </summary>
        /// <param name="tgtVocab">Label vocabulary whose items become metric labels.</param>
        /// <returns>A new list of metrics, one per non-predefined label.</returns>
        private static List<IMetric> BuildLabelMetrics(Vocab tgtVocab)
        {
            List<IMetric> metrics = new List<IMetric>();
            foreach (string word in tgtVocab.Items)
            {
                if (!BuildInTokens.IsPreDefinedToken(word))
                {
                    metrics.Add(new SequenceLabelFscoreMetric(word));
                }
            }

            return metrics;
        }