/// <summary>
/// Adds <paramref name="offset"/> to the local address of <paramref name="o"/>
/// when it is a <see cref="SequenceLabel"/>; any other object is left untouched.
/// </summary>
/// <param name="o">Candidate object; only a SequenceLabel is modified.</param>
/// <param name="offset">Value added to the label's LocalAddress.</param>
public void OffsetObject(Object o, int offset)
{
    // Nothing to do for null or for objects that are not labels.
    if (!(o is SequenceLabel))
    {
        return;
    }

    SequenceLabel label = (SequenceLabel)o;
    label.LocalAddress += offset;
}
/// <summary>
/// Walks every symbol stored in <paramref name="container"/>, descending into
/// nested symbol containers, and adds <paramref name="offset"/> to each
/// sequence label whose local address is strictly greater than
/// <paramref name="startAddr"/>.
/// </summary>
/// <param name="container">Container whose Symbols are visited.</param>
/// <param name="startAddr">Only labels located after this address are moved.</param>
/// <param name="offset">Value added to the LocalAddress of each affected label.</param>
public void RecurseOffset(SymbolContainer container, int startAddr, int offset)
{
    foreach (KeyValuePair<String, Object> entry in container.Symbols)
    {
        Object symbol = entry.Value;

        // Shift labels that sit beyond the start address.
        if (symbol is SequenceLabel)
        {
            SequenceLabel label = (SequenceLabel)symbol;
            if (label.LocalAddress > startAddr)
            {
                label.LocalAddress += offset;
            }
        }

        // Checked independently of the label test above: nested containers
        // are processed with the same start address and offset.
        if (symbol is SymbolContainer)
        {
            RecurseOffset((SymbolContainer)symbol, startAddr, offset);
        }
    }
}
/// <summary>
/// Console entry point. Reads the options (from the command line, or from a
/// JSON config file when one is specified), then runs the task selected by
/// <c>opts.TaskName</c>: Train, Valid or Test. Any other mode falls through to
/// <c>argParser.Usage()</c>.
/// </summary>
/// <param name="args">Raw command-line arguments.</param>
static void Main(string[] args)
{
    ShowOptions(args);

    Logger.LogFile = $"{nameof(SeqLabelConsole)}_{GetTimeStamp(DateTime.Now)}.log";

    // Parse command line
    Options opts = new Options();
    ArgParser argParser = new ArgParser(args, opts);

    // When a config file is given, the options deserialized from it replace
    // the command-line options entirely.
    if (!String.IsNullOrEmpty(opts.ConfigFilePath))
    {
        Logger.WriteLine($"Loading config file from '{opts.ConfigFilePath}'");
        opts = JsonConvert.DeserializeObject<Options>(File.ReadAllText(opts.ConfigFilePath));
    }

    // All three enum options are parsed up front, regardless of the selected mode.
    ProcessorTypeEnums processorType = (ProcessorTypeEnums)Enum.Parse(typeof(ProcessorTypeEnums), opts.ProcessorType);
    EncoderTypeEnums encoderType = (EncoderTypeEnums)Enum.Parse(typeof(EncoderTypeEnums), opts.EncoderType);
    ModeEnums mode = (ModeEnums)Enum.Parse(typeof(ModeEnums), opts.TaskName);

    // Parse device ids from options (comma-separated integers)
    int[] deviceIds = opts.DeviceIds.Split(',').Select(x => int.Parse(x)).ToArray();

    switch (mode)
    {
        case ModeEnums.Train:
            RunTrainMode(opts, processorType, encoderType, deviceIds);
            break;

        case ModeEnums.Valid:
            RunValidMode(opts, processorType, deviceIds);
            break;

        case ModeEnums.Test:
            RunTestMode(opts, processorType, deviceIds);
            break;

        default:
            argParser.Usage();
            break;
    }
}

/// <summary>
/// Trains a sequence-labelling model, either from scratch or incrementally
/// from an existing model file.
/// </summary>
private static void RunTrainMode(Options opts, ProcessorTypeEnums processorType, EncoderTypeEnums encoderType, int[] deviceIds)
{
    // Load train corpus
    ParallelCorpus trainCorpus = new ParallelCorpus(opts.TrainCorpusPath, opts.SrcLang, opts.TgtLang, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength, addBOSEOS: false);

    // Load valid corpus (optional: stays null when no path is configured)
    ParallelCorpus validCorpus = String.IsNullOrEmpty(opts.ValidCorpusPath)
        ? null
        : new ParallelCorpus(opts.ValidCorpusPath, opts.SrcLang, opts.TgtLang, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength, addBOSEOS: false);

    // Load or build vocabulary
    Vocab vocab;
    if (!String.IsNullOrEmpty(opts.SrcVocab) && !String.IsNullOrEmpty(opts.TgtVocab))
    {
        // Vocabulary files are specified, so we load them
        vocab = new Vocab(opts.SrcVocab, opts.TgtVocab);
    }
    else
    {
        // We don't specify vocabulary, so we build it from train corpus
        vocab = new Vocab(trainCorpus);
    }

    // Create learning rate
    ILearningRate learningRate = new DecayLearningRate(opts.StartLearningRate, opts.WarmUpSteps, opts.WeightsUpdateCount);

    // Create optimizer
    AdamOptimizer optimizer = new AdamOptimizer(opts.GradClip, opts.Beta1, opts.Beta2);

    // Create metrics
    List<IMetric> metrics = BuildFscoreMetrics(vocab);

    SequenceLabel model;
    if (!File.Exists(opts.ModelFilePath))
    {
        // New training
        model = new SequenceLabel(hiddenDim: opts.HiddenSize, embeddingDim: opts.WordVectorSize, encoderLayerDepth: opts.EncoderLayerDepth, multiHeadNum: opts.MultiHeadNum,
            encoderType: encoderType, dropoutRatio: opts.DropoutRatio, deviceIds: deviceIds, processorType: processorType, modelFilePath: opts.ModelFilePath, vocab: vocab);
    }
    else
    {
        // Incremental training
        Logger.WriteLine($"Loading model from '{opts.ModelFilePath}'...");
        model = new SequenceLabel(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds, dropoutRatio: opts.DropoutRatio);
    }

    // Add event handler for monitoring
    model.IterationDone += ss_IterationDone;

    // Kick off training
    model.Train(maxTrainingEpoch: opts.MaxEpochNum, trainCorpus: trainCorpus, validCorpus: validCorpus, learningRate: learningRate, optimizer: optimizer, metrics: metrics);
}

/// <summary>
/// Evaluates an existing model on the validation corpus using per-label F-score metrics.
/// </summary>
private static void RunValidMode(Options opts, ProcessorTypeEnums processorType, int[] deviceIds)
{
    Logger.WriteLine($"Evaluate model '{opts.ModelFilePath}' by valid corpus '{opts.ValidCorpusPath}'");

    // Load valid corpus
    ParallelCorpus validCorpus = new ParallelCorpus(opts.ValidCorpusPath, opts.SrcLang, opts.TgtLang, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength, false);

    // The vocabulary used for the metrics is built from the valid corpus itself.
    Vocab vocab = new Vocab(validCorpus);

    // Create metrics
    List<IMetric> metrics = BuildFscoreMetrics(vocab);

    SequenceLabel model = new SequenceLabel(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds);
    model.Valid(validCorpus: validCorpus, metrics: metrics);
}

/// <summary>
/// Runs a trained model over every line of the input test file and writes the
/// labelled output lines to the output test file.
/// </summary>
private static void RunTestMode(Options opts, ProcessorTypeEnums processorType, int[] deviceIds)
{
    Logger.WriteLine($"Test model '{opts.ModelFilePath}' by input corpus '{opts.InputTestFile}'");

    // Test trained model
    SequenceLabel model = new SequenceLabel(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds);

    List<string> outputLines = new List<string>();
    string[] inputLines = File.ReadAllLines(opts.InputTestFile);
    foreach (string line in inputLines)
    {
        // Each line is lower-cased, trimmed and split on single spaces before being fed to the model.
        List<string> tokens = line.ToLower().Trim().Split(' ').ToList();
        var outputTokensBatch = model.Test(ParallelCorpus.ConstructInputTokens(tokens, false));
        outputLines.AddRange(outputTokensBatch.Select(x => String.Join(" ", x)));
    }

    File.WriteAllLines(opts.OutputTestFile, outputLines);
}

/// <summary>
/// Creates one <c>SequenceLabelFscoreMetric</c> per entry of the target vocabulary.
/// </summary>
private static List<IMetric> BuildFscoreMetrics(Vocab vocab)
{
    List<IMetric> metrics = new List<IMetric>();
    foreach (var word in vocab.TgtVocab)
    {
        metrics.Add(new SequenceLabelFscoreMetric(word));
    }

    return metrics;
}
/// <summary>
/// Parses a sequence and returns an object representation (used for both
/// normal sequences and macros).
/// </summary>
/// <param name="isMacro">
/// True when the sequence being parsed is a macro. Only macros may be followed
/// by a parameter list.
/// </param>
/// <returns>The parsed sequence, including its parameters, labels and steps.</returns>
/// <exception cref="MicroassemblerParseException">
/// Thrown when a non-macro declares a parameter list, when a parameter is not a
/// single whitespace-free word, when a label token is malformed, or when a
/// word inside the block is neither a label, a macro reference nor "assert".
/// </exception>
public Sequence ParseSequence(Boolean isMacro)
{
    Sequence sequence = new Sequence() { IsMacro = isMacro };

    // Optional parameter list directly before the block; legal for macros only.
    if (Enumerator.HasToken() && Enumerator.Current.TokenType == TokenType.ParenList)
    {
        if (!isMacro)
        {
            throw new MicroassemblerParseException(Enumerator.Last, "Only macros can define a parameter list");
        }

        List<Object> parameters = (List<Object>)Enumerator.Current.Value;
        Enumerator.Advance();

        // Every parameter must be a string that contains no whitespace.
        if (parameters.Any(p => !(p is String) || ((String)p).Any(Char.IsWhiteSpace)))
        {
            throw new MicroassemblerParseException(Enumerator.Last, "Parameter definitions may only contain non-whitespace-delimited words");
        }

        sequence.Parameters.AddRange(parameters.Cast<String>());
    }

    VerifySyntaxToken(TokenType.OpenBlock, "{");

    while (true)
    {
        // The closing brace ends the sequence body.
        if (Enumerator.HasToken() && Enumerator.Current.TokenType == TokenType.CloseBlock)
        {
            Enumerator.Advance();
            break;
        }

        String word = GetWordToken();

        if (word.StartsWith("::", StringComparison.Ordinal) && word.EndsWith("::", StringComparison.Ordinal))
        {
            // "::" and ":::" satisfy both tests with overlapping delimiters; without
            // this guard Substring would be called with a negative length and throw
            // ArgumentOutOfRangeException instead of a parse error.
            if (word.Length < 4)
            {
                throw new MicroassemblerParseException(Enumerator.Last, "Malformed label '" + word + "'");
            }

            // Label definition: strip the leading and trailing "::" and record
            // the index of the next step as the label's address.
            String label = word.Substring(2, word.Length - 4);
            sequence[label] = new SequenceLabel { LocalAddress = sequence.Steps.Count };
        }
        else if (Enumerator.HasToken() && Enumerator.Current.TokenType == TokenType.ParenList)
        {
            // A word followed by a parenthesised list is a macro reference.
            List<Object> arguments = (List<Object>)Enumerator.Current.Value;
            Enumerator.Advance();

            SequenceMacroReference step = new SequenceMacroReference()
            {
                Arguments = arguments,
                Symbol = word,
                Line = Enumerator.Last.Line
            };
            sequence.Steps.Add(step);
        }
        else if (word.Equals("assert", StringComparison.OrdinalIgnoreCase))
        {
            SequenceAssertion step = ParseSequenceAssertion();
            sequence.Steps.Add(step);
        }
        else
        {
            throw new MicroassemblerParseException(Enumerator.Last);
        }
    }

    return sequence;
}