/// <summary>
/// Adds <paramref name="offset"/> to the local address of <paramref name="o"/> when it is a
/// <c>SequenceLabel</c>. Objects of any other type (and null) are left untouched.
/// </summary>
/// <param name="o">Candidate object; only a <c>SequenceLabel</c> is modified.</param>
/// <param name="offset">Amount added to <c>LocalAddress</c>.</param>
public void OffsetObject(Object o, int offset)
 {
     // A failed 'as' cast yields null, which covers both "wrong type" and "null input"
     SequenceLabel label = o as SequenceLabel;
     if (label == null)
     {
         return;
     }

     label.LocalAddress += offset;
 }
 /// <summary>
 /// Walks every symbol stored in <paramref name="container"/>. Each <c>SequenceLabel</c> whose
 /// local address is strictly greater than <paramref name="startAddr"/> has
 /// <paramref name="offset"/> added to it, and each nested <c>SymbolContainer</c> is processed
 /// recursively with the same arguments.
 /// </summary>
 /// <param name="container">Container whose <c>Symbols</c> are enumerated.</param>
 /// <param name="startAddr">Labels at or below this address are not moved.</param>
 /// <param name="offset">Amount added to each affected label's <c>LocalAddress</c>.</param>
 public void RecurseOffset(SymbolContainer container, int startAddr, int offset)
 {
     foreach (KeyValuePair<String, Object> entry in container.Symbols)
     {
         Object symbol = entry.Value;

         SequenceLabel label = symbol as SequenceLabel;
         if (label != null && label.LocalAddress > startAddr)
         {
             label.LocalAddress += offset;
         }

         // Checked independently of the label test above, so a symbol matching both is handled twice over
         SymbolContainer nested = symbol as SymbolContainer;
         if (nested != null)
         {
             RecurseOffset(nested, startAddr, offset);
         }
     }
 }
示例#3
0
        /// <summary>
        /// Console entry point. Loads the options (from the command line, or from a config file when
        /// one is given), then trains, validates, or tests a sequence labeling model depending on the
        /// task name. An unrecognized task name prints the usage text instead of crashing.
        /// </summary>
        /// <param name="args">Command line arguments, parsed into <c>Options</c> by <c>ArgParser</c>.</param>
        static void Main(string[] args)
        {
            ShowOptions(args);

            Logger.LogFile = $"{nameof(SeqLabelConsole)}_{GetTimeStamp(DateTime.Now)}.log";

            //Parse command line
            Options   opts      = new Options();
            ArgParser argParser = new ArgParser(args, opts);

            if (String.IsNullOrEmpty(opts.ConfigFilePath) == false)
            {
                // The deserialized config replaces the command line options object entirely
                Logger.WriteLine($"Loading config file from '{opts.ConfigFilePath}'");
                opts = JsonConvert.DeserializeObject <Options>(File.ReadAllText(opts.ConfigFilePath));
            }


            SequenceLabel      sl            = null;
            ProcessorTypeEnums processorType = (ProcessorTypeEnums)Enum.Parse(typeof(ProcessorTypeEnums), opts.ProcessorType);
            EncoderTypeEnums   encoderType   = (EncoderTypeEnums)Enum.Parse(typeof(EncoderTypeEnums), opts.EncoderType);

            // Enum.Parse throws on an unknown or missing task name, which would bypass the usage
            // message below; TryParse lets that case end with usage output instead of a crash.
            ModeEnums mode;
            if (Enum.TryParse <ModeEnums>(opts.TaskName, out mode) == false)
            {
                Logger.WriteLine($"Unknown task name '{opts.TaskName}'");
                argParser.Usage();
                return;
            }

            //Parse device ids from options
            int[] deviceIds = opts.DeviceIds.Split(',').Select(x => int.Parse(x)).ToArray();
            if (mode == ModeEnums.Train)
            {
                // Load train corpus
                ParallelCorpus trainCorpus = new ParallelCorpus(opts.TrainCorpusPath, opts.SrcLang, opts.TgtLang, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength, addBOSEOS: false);

                // Load valid corpus (optional: left null when no path is configured)
                ParallelCorpus validCorpus = String.IsNullOrEmpty(opts.ValidCorpusPath) ? null : new ParallelCorpus(opts.ValidCorpusPath, opts.SrcLang, opts.TgtLang, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength, addBOSEOS: false);

                // Load or build vocabulary
                Vocab vocab = null;
                if (!String.IsNullOrEmpty(opts.SrcVocab) && !String.IsNullOrEmpty(opts.TgtVocab))
                {
                    // Vocabulary files are specified, so we load them
                    vocab = new Vocab(opts.SrcVocab, opts.TgtVocab);
                }
                else
                {
                    // We don't specify vocabulary, so we build it from train corpus
                    vocab = new Vocab(trainCorpus);
                }

                // Create learning rate
                ILearningRate learningRate = new DecayLearningRate(opts.StartLearningRate, opts.WarmUpSteps, opts.WeightsUpdateCount);

                // Create optimizer
                AdamOptimizer optimizer = new AdamOptimizer(opts.GradClip, opts.Beta1, opts.Beta2);

                // Create metrics: one F-score metric per entry of the target vocabulary
                List <IMetric> metrics = new List <IMetric>();
                foreach (var word in vocab.TgtVocab)
                {
                    metrics.Add(new SequenceLabelFscoreMetric(word));
                }

                if (File.Exists(opts.ModelFilePath) == false)
                {
                    //New training
                    sl = new SequenceLabel(hiddenDim: opts.HiddenSize, embeddingDim: opts.WordVectorSize, encoderLayerDepth: opts.EncoderLayerDepth, multiHeadNum: opts.MultiHeadNum,
                                           encoderType: encoderType,
                                           dropoutRatio: opts.DropoutRatio, deviceIds: deviceIds, processorType: processorType, modelFilePath: opts.ModelFilePath, vocab: vocab);
                }
                else
                {
                    //Incremental training
                    Logger.WriteLine($"Loading model from '{opts.ModelFilePath}'...");
                    sl = new SequenceLabel(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds, dropoutRatio: opts.DropoutRatio);
                }

                // Add event handler for monitoring
                sl.IterationDone += ss_IterationDone;

                // Kick off training
                sl.Train(maxTrainingEpoch: opts.MaxEpochNum, trainCorpus: trainCorpus, validCorpus: validCorpus, learningRate: learningRate, optimizer: optimizer, metrics: metrics);
            }
            else if (mode == ModeEnums.Valid)
            {
                Logger.WriteLine($"Evaluate model '{opts.ModelFilePath}' by valid corpus '{opts.ValidCorpusPath}'");

                // Load valid corpus
                ParallelCorpus validCorpus = new ParallelCorpus(opts.ValidCorpusPath, opts.SrcLang, opts.TgtLang, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength, false);

                // The vocabulary used for the metrics is built from the valid corpus itself
                Vocab vocab = new Vocab(validCorpus);
                // Create metrics
                List <IMetric> metrics = new List <IMetric>();
                foreach (var word in vocab.TgtVocab)
                {
                    metrics.Add(new SequenceLabelFscoreMetric(word));
                }

                sl = new SequenceLabel(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds);
                sl.Valid(validCorpus: validCorpus, metrics: metrics);
            }
            else if (mode == ModeEnums.Test)
            {
                Logger.WriteLine($"Test model '{opts.ModelFilePath}' by input corpus '{opts.InputTestFile}'");

                //Test trained model
                sl = new SequenceLabel(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds);

                List <string> outputLines     = new List <string>();
                var           data_sents_raw1 = File.ReadAllLines(opts.InputTestFile);
                foreach (string line in data_sents_raw1)
                {
                    // Each input line is lower-cased, trimmed, and split on single spaces before being labeled
                    var outputTokensBatch = sl.Test(ParallelCorpus.ConstructInputTokens(line.ToLower().Trim().Split(' ').ToList(), false));
                    outputLines.AddRange(outputTokensBatch.Select(x => String.Join(" ", x)));
                }

                File.WriteAllLines(opts.OutputTestFile, outputLines);
            }
            //else if (mode == ModeEnums.VisualizeNetwork)
            //{
            //    ss = new AttentionSeq2Seq(embeddingDim: opts.WordVectorSize, hiddenDim: opts.HiddenSize, encoderLayerDepth: opts.EncoderLayerDepth, decoderLayerDepth: opts.DecoderLayerDepth,
            //        vocab: new Vocab(), srcEmbeddingFilePath: null, tgtEmbeddingFilePath: null, modelFilePath: opts.ModelFilePath, dropoutRatio: opts.DropoutRatio,
            //        processorType: processorType, deviceIds: new int[1] { 0 }, multiHeadNum: opts.MultiHeadNum, encoderType: encoderType);

            //    ss.VisualizeNeuralNetwork(opts.VisualizeNNFilePath);
            //}
            else
            {
                // A valid ModeEnums value that this console does not handle
                argParser.Usage();
            }
        }
示例#4
0
        /// <summary>
        /// Parses a sequence and returns an object representation (used for both normal sequences and macros).
        /// An optional parenthesized parameter list (macros only) is read first, followed by a block whose
        /// entries are label definitions (<c>::name::</c>), macro references (a word followed by a
        /// parenthesized argument list), or assertions (the word <c>assert</c>, case-insensitive).
        /// </summary>
        /// <param name="isMacro">True when parsing a macro; only macros may define a parameter list.</param>
        /// <returns>The parsed sequence, with its parameters, labels, and steps filled in.</returns>
        /// <exception cref="MicroassemblerParseException">
        /// Thrown when a non-macro has a parameter list, a parameter is not a whitespace-free word,
        /// a label definition is malformed, or a word in the block is not a recognized construct.
        /// </exception>
        public Sequence ParseSequence(Boolean isMacro)
        {
            Sequence sequence = new Sequence()
            {
                IsMacro = isMacro
            };

            if (Enumerator.HasToken() && Enumerator.Current.TokenType == TokenType.ParenList)
            {
                if (!isMacro)
                {
                    throw new MicroassemblerParseException(Enumerator.Last, "Only macros can define a parameter list");
                }
                List <Object> parameters = (List <Object>)Enumerator.Current.Value;
                Enumerator.Advance();
                // Reject anything that is not a string, and any string containing whitespace
                if (parameters.Any(p => !(p is String) || ((String)p).Any(Char.IsWhiteSpace)))
                {
                    throw new MicroassemblerParseException(Enumerator.Last, "Parameter definitions may only contain non-whitespace-delimited words");
                }
                sequence.Parameters.AddRange(parameters.Cast <String>());
            }
            VerifySyntaxToken(TokenType.OpenBlock, "{");

            while (true)
            {
                // The closing brace ends the sequence body
                if (Enumerator.HasToken() && Enumerator.Current.TokenType == TokenType.CloseBlock)
                {
                    Enumerator.Advance();
                    break;
                }

                String word = GetWordToken();
                if (word.StartsWith("::") && word.EndsWith("::"))
                {
                    // "::" and ":::" pass both checks with overlapping delimiters; without this guard
                    // Substring would receive a negative length and throw ArgumentOutOfRangeException.
                    if (word.Length < 4)
                    {
                        throw new MicroassemblerParseException(Enumerator.Last, "Label definitions must have the form ::name::");
                    }

                    // A label records the index of the next step to be added
                    String label = word.Substring(2, word.Length - 4);
                    sequence[label] = new SequenceLabel {
                        LocalAddress = sequence.Steps.Count
                    };
                }
                else if (Enumerator.HasToken() && Enumerator.Current.TokenType == TokenType.ParenList)
                {
                    // A word followed by a parenthesized list is a macro reference with arguments
                    List <Object> arguments = (List <Object>)Enumerator.Current.Value;
                    Enumerator.Advance();
                    SequenceMacroReference step = new SequenceMacroReference()
                    {
                        Arguments = arguments, Symbol = word, Line = Enumerator.Last.Line
                    };
                    sequence.Steps.Add(step);
                }
                else if (word.ToLower().Equals("assert"))
                {
                    SequenceAssertion step = ParseSequenceAssertion();
                    sequence.Steps.Add(step);
                }
                else
                {
                    throw new MicroassemblerParseException(Enumerator.Last);
                }
            }

            return(sequence);
        }