Пример #1
0
 /// <summary>
 /// Creates a sequence labeler whose model meta data is obtained through <c>LoadModel</c>.
 /// </summary>
 /// <param name="modelFilePath">Model file path; forwarded to the base constructor.</param>
 /// <param name="processorType">Processor type; forwarded to the base constructor.</param>
 /// <param name="deviceIds">Device ids; forwarded to the base constructor.</param>
 /// <param name="dropoutRatio">Dropout ratio stored on the instance.</param>
 /// <param name="maxSntSize">Maximum sentence size stored on the instance.</param>
 /// <exception cref="System.InvalidCastException">
 /// The meta data returned by <c>LoadModel</c> is not a <c>SeqLabelModelMetaData</c>.
 /// </exception>
 public SequenceLabel(string modelFilePath, ProcessorTypeEnums processorType, int[] deviceIds, float dropoutRatio = 0.0f, int maxSntSize = 128)
     : base(deviceIds, processorType, modelFilePath)
 {
     // Store every constructor argument BEFORE LoadModel runs: LoadModel is handed
     // CreateTrainableParameters as a callback, and any field that callback reads must
     // already hold the caller's value rather than the field's default.
     m_dropoutRatio = dropoutRatio;
     m_maxSntSize   = maxSntSize;

     // Direct cast instead of "as": a model file holding a different meta data type now
     // fails here with InvalidCastException instead of leaving m_modelMetaData silently
     // null. A null result from LoadModel still passes through unchanged.
     m_modelMetaData = (SeqLabelModelMetaData)LoadModel(CreateTrainableParameters);
 }
Пример #2
0
        /// <summary>
        /// Creates a sequence labeler with newly constructed model meta data and then
        /// builds its trainable parameters from that meta data.
        /// </summary>
        /// <param name="hiddenDim">Hidden dimension passed to the meta data.</param>
        /// <param name="embeddingDim">Embedding dimension passed to the meta data.</param>
        /// <param name="encoderLayerDepth">Encoder layer depth passed to the meta data.</param>
        /// <param name="multiHeadNum">Multi-head count passed to the meta data.</param>
        /// <param name="encoderType">Encoder type passed to the meta data.</param>
        /// <param name="dropoutRatio">Dropout ratio stored on the instance.</param>
        /// <param name="vocab">Vocabulary passed to the meta data.</param>
        /// <param name="deviceIds">Device ids; forwarded to the base constructor.</param>
        /// <param name="processorType">Processor type; forwarded to the base constructor.</param>
        /// <param name="modelFilePath">Model file path; forwarded to the base constructor.</param>
        public SequenceLabel(
            int hiddenDim, int embeddingDim, int encoderLayerDepth, int multiHeadNum, EncoderTypeEnums encoderType,
            float dropoutRatio, Vocab vocab, int[] deviceIds, ProcessorTypeEnums processorType, string modelFilePath)
            : base(deviceIds, processorType, modelFilePath)
        {
            m_dropoutRatio = dropoutRatio;

            // NOTE(review): the literal 0 in the fourth position is presumably the decoder
            // layer depth (a labeler has no decoder) - confirm against Seq2SeqModelMetaData.
            m_modelMetaData = new Seq2SeqModelMetaData(hiddenDim, embeddingDim, encoderLayerDepth, 0, multiHeadNum, encoderType, vocab);

            // Build the trainable weights from the freshly created meta data.
            CreateTrainableParameters(m_modelMetaData);
        }
Пример #3
0
        /// <summary>
        /// Creates an attention seq2seq model whose meta data is obtained through <c>LoadModel</c>.
        /// </summary>
        /// <param name="modelFilePath">Model file path; forwarded to the base constructor.</param>
        /// <param name="processorType">Processor type; forwarded to the base constructor.</param>
        /// <param name="deviceIds">Device ids; forwarded to the base constructor.</param>
        /// <param name="dropoutRatio">Dropout ratio stored on the instance.</param>
        /// <param name="isSrcEmbTrainable">Stored flag for the source embedding.</param>
        /// <param name="isTgtEmbTrainable">Stored flag for the target embedding.</param>
        /// <param name="isEncoderTrainable">Stored flag for the encoder.</param>
        /// <param name="isDecoderTrainable">Stored flag for the decoder.</param>
        /// <exception cref="System.InvalidCastException">
        /// The meta data returned by <c>LoadModel</c> is not a <c>Seq2SeqModelMetaData</c>.
        /// </exception>
        public AttentionSeq2Seq(string modelFilePath, ProcessorTypeEnums processorType, int[] deviceIds, float dropoutRatio = 0.0f,
                                bool isSrcEmbTrainable = true, bool isTgtEmbTrainable = true, bool isEncoderTrainable = true, bool isDecoderTrainable = true)
            : base(deviceIds, processorType, modelFilePath)
        {
            // All settings are stored first because LoadModel receives
            // CreateTrainableParameters as a callback.
            m_dropoutRatio       = dropoutRatio;
            m_isSrcEmbTrainable  = isSrcEmbTrainable;
            m_isTgtEmbTrainable  = isTgtEmbTrainable;
            m_isEncoderTrainable = isEncoderTrainable;
            m_isDecoderTrainable = isDecoderTrainable;

            // Direct cast instead of "as": an unexpected meta data type fails here with
            // InvalidCastException rather than leaving m_modelMetaData silently null.
            // A null result from LoadModel still passes through unchanged.
            m_modelMetaData = (Seq2SeqModelMetaData)LoadModel(CreateTrainableParameters);
        }
Пример #4
0
        /// <summary>
        /// Creates an attention seq2seq model whose meta data is obtained through <c>LoadModel</c>.
        /// </summary>
        /// <param name="modelFilePath">Model file path; forwarded to the base constructor.</param>
        /// <param name="processorType">Processor type; forwarded to the base constructor.</param>
        /// <param name="deviceIds">Device ids; forwarded to the base constructor.</param>
        /// <param name="dropoutRatio">Dropout ratio stored on the instance.</param>
        /// <param name="isSrcEmbTrainable">Stored flag for the source embedding.</param>
        /// <param name="isTgtEmbTrainable">Stored flag for the target embedding.</param>
        /// <param name="isEncoderTrainable">Stored flag for the encoder.</param>
        /// <param name="isDecoderTrainable">Stored flag for the decoder.</param>
        /// <param name="maxSrcSntSize">Maximum source sentence size stored on the instance.</param>
        /// <param name="maxTgtSntSize">Maximum target sentence size stored on the instance.</param>
        /// <param name="memoryUsageRatio">Memory usage ratio; forwarded to the base constructor.</param>
        /// <param name="shuffleType">Shuffle type stored on the instance.</param>
        /// <param name="compilerOptions">Compiler options; forwarded to the base constructor.</param>
        /// <exception cref="System.InvalidCastException">
        /// The meta data returned by <c>LoadModel</c> is not a <c>Seq2SeqModelMetaData</c>.
        /// </exception>
        public AttentionSeq2Seq(string modelFilePath, ProcessorTypeEnums processorType, int[] deviceIds, float dropoutRatio = 0.0f,
                                bool isSrcEmbTrainable = true, bool isTgtEmbTrainable = true, bool isEncoderTrainable = true, bool isDecoderTrainable = true,
                                int maxSrcSntSize      = 128, int maxTgtSntSize       = 128, float memoryUsageRatio = 0.9f, ShuffleEnums shuffleType  = ShuffleEnums.Random, string[] compilerOptions = null)
            : base(deviceIds, processorType, modelFilePath, memoryUsageRatio, compilerOptions)
        {
            // All settings are stored first because LoadModel receives
            // CreateTrainableParameters as a callback.
            m_dropoutRatio       = dropoutRatio;
            m_isSrcEmbTrainable  = isSrcEmbTrainable;
            m_isTgtEmbTrainable  = isTgtEmbTrainable;
            m_isEncoderTrainable = isEncoderTrainable;
            m_isDecoderTrainable = isDecoderTrainable;
            m_maxSrcSntSize      = maxSrcSntSize;
            m_maxTgtSntSize      = maxTgtSntSize;
            m_shuffleType        = shuffleType;

            // Direct cast instead of "as": an unexpected meta data type fails here with
            // InvalidCastException rather than leaving m_modelMetaData silently null.
            // A null result from LoadModel still passes through unchanged.
            m_modelMetaData = (Seq2SeqModelMetaData)LoadModel(CreateTrainableParameters);
        }
Пример #5
0
        /// <summary>
        /// Records the processor type and device ids, sizes the allocator array to one slot
        /// per device id, and, for the GPU processor type only, creates and prepares the CUDA context.
        /// </summary>
        /// <param name="archType">Processor type to record.</param>
        /// <param name="ids">Device ids to record; its length sizes the allocator array.</param>
        /// <param name="memoryUsageRatio">Passed to the <c>TSCudaContext</c> constructor.</param>
        /// <param name="compilerOptions">Passed to the <c>TSCudaContext</c> constructor.</param>
        public static void InitDevices(ProcessorTypeEnums archType, int[] ids, float memoryUsageRatio = 0.9f, string[] compilerOptions = null)
        {
            m_archType  = archType;
            m_deviceIds = ids;
            m_allocator = new IAllocator[m_deviceIds.Length];

            // Nothing else to set up unless the GPU processor type was requested.
            if (m_archType != ProcessorTypeEnums.GPU)
            {
                return;
            }

            for (int idx = 0; idx < m_deviceIds.Length; idx++)
            {
                int id = m_deviceIds[idx];
                Logger.WriteLine($"Initialize device '{id}'");
            }

            m_cudaContext = new TSCudaContext(m_deviceIds, memoryUsageRatio, compilerOptions);
            m_cudaContext.Precompile(Console.Write);
            m_cudaContext.CleanUnusedPTX();
        }
Пример #6
0
        /// <summary>
        /// Records the processor type and prepares the allocator array: a single slot for
        /// non-GPU processor types, or one slot per device id after the CUDA context has
        /// been created for the GPU processor type.
        /// </summary>
        /// <param name="archType">Processor type to record.</param>
        /// <param name="ids">Device ids; only read when <paramref name="archType"/> is GPU.</param>
        public static void InitDevices(ProcessorTypeEnums archType, int[] ids)
        {
            m_archType = archType;

            // Non-GPU path: one allocator slot; the device ids are left as they were.
            if (m_archType != ProcessorTypeEnums.GPU)
            {
                m_allocator = new IAllocator[1];
                return;
            }

            m_deviceIds = ids;
            for (int idx = 0; idx < m_deviceIds.Length; idx++)
            {
                int id = m_deviceIds[idx];
                Logger.WriteLine($"Initialize device '{id}'");
            }

            m_cudaContext = new TSCudaContext(m_deviceIds);
            m_cudaContext.Precompile(Console.Write);
            m_cudaContext.CleanUnusedPTX();

            // Allocator slots are created only after the context was set up successfully.
            m_allocator = new IAllocator[m_deviceIds.Length];
        }
Пример #7
0
        /// <summary>
        /// Records the processor type and prepares the allocator array: a single slot for
        /// non-GPU processor types, or one slot per device id after the CUDA context has
        /// been created for the GPU processor type.
        /// </summary>
        /// <param name="archType">Processor type to record.</param>
        /// <param name="ids">Device ids; only read when <paramref name="archType"/> is GPU.</param>
        /// <param name="memoryUsageRatio">Passed to the <c>TSCudaContext</c> constructor.</param>
        /// <param name="compilerOptions">Passed to the <c>TSCudaContext</c> constructor.</param>
        public static void InitDevices(ProcessorTypeEnums archType, int[] ids, float memoryUsageRatio = 0.9f, string[] compilerOptions = null)
        {
            architectureType = archType;

            // Non-GPU path: one allocator slot; the device ids are left as they were.
            if (architectureType != ProcessorTypeEnums.GPU)
            {
                allocator = new IAllocator[1];
                return;
            }

            deviceIds = ids;
            for (int idx = 0; idx < deviceIds.Length; idx++)
            {
                int id = deviceIds[idx];
                Logger.WriteLine($"Initialize device '{id}'");
            }

            context = new TSCudaContext(deviceIds, memoryUsageRatio, compilerOptions);
            context.Precompile(Console.Write);
            context.CleanUnusedPTX();

            // Allocator slots are created only after the context was set up successfully.
            allocator = new IAllocator[deviceIds.Length];
        }
        /// <summary>
        /// Fills a new <c>SeqClassificationOptions</c> from the arguments and creates the
        /// shared <c>SeqClassification</c> instance from it.
        /// </summary>
        /// <param name="modelFilePath">Value for <c>ModelFilePath</c>.</param>
        /// <param name="maxTestSentLength">Value for <c>MaxTestSentLength</c>.</param>
        /// <param name="processorType">Value for <c>ProcessorType</c>.</param>
        /// <param name="deviceIds">Value for <c>DeviceIds</c>.</param>
        public static void Initialization(string modelFilePath, int maxTestSentLength, ProcessorTypeEnums processorType, string deviceIds)
        {
            opts = new SeqClassificationOptions
            {
                ModelFilePath     = modelFilePath,
                MaxTestSentLength = maxTestSentLength,
                ProcessorType     = processorType,
                DeviceIds         = deviceIds
            };

            m_seqClassification = new SeqClassification(opts);
        }
Пример #9
0
 /// <summary>
 /// Creates a sequence labeler whose model meta data is obtained through <c>LoadModel</c>.
 /// </summary>
 /// <param name="modelFilePath">Model file path; forwarded to the base constructor.</param>
 /// <param name="processorType">Processor type; forwarded to the base constructor.</param>
 /// <param name="deviceIds">Device ids; forwarded to the base constructor.</param>
 /// <param name="dropoutRatio">Dropout ratio stored on the instance.</param>
 /// <exception cref="System.InvalidCastException">
 /// The meta data returned by <c>LoadModel</c> is not a <c>Seq2SeqModelMetaData</c>.
 /// </exception>
 public SequenceLabel(string modelFilePath, ProcessorTypeEnums processorType, int[] deviceIds, float dropoutRatio = 0.0f)
     : base(deviceIds, processorType, modelFilePath)
 {
     // Stored first because LoadModel receives CreateTrainableParameters as a callback.
     m_dropoutRatio = dropoutRatio;

     // Direct cast instead of "as": an unexpected meta data type fails here with
     // InvalidCastException rather than leaving m_modelMetaData silently null.
     // A null result from LoadModel still passes through unchanged.
     m_modelMetaData = (Seq2SeqModelMetaData)LoadModel(CreateTrainableParameters);
 }
Пример #10
0
        /// <summary>
        /// Creates an attention seq2seq model with newly constructed meta data, builds its
        /// trainable parameters, and then, for every device, loads external word embeddings
        /// from whichever embedding file paths were supplied.
        /// </summary>
        /// <param name="embeddingDim">Embedding dimension passed to the meta data.</param>
        /// <param name="hiddenDim">Hidden dimension passed to the meta data.</param>
        /// <param name="encoderLayerDepth">Encoder layer depth passed to the meta data.</param>
        /// <param name="decoderLayerDepth">Decoder layer depth passed to the meta data.</param>
        /// <param name="vocab">Vocabulary passed to the meta data.</param>
        /// <param name="srcEmbeddingFilePath">Source-side embedding file; skipped when null or empty.</param>
        /// <param name="tgtEmbeddingFilePath">Target-side embedding file; skipped when null or empty.</param>
        /// <param name="modelFilePath">Model file path; forwarded to the base constructor.</param>
        /// <param name="dropoutRatio">Dropout ratio stored on the instance.</param>
        /// <param name="multiHeadNum">Multi-head count passed to the meta data.</param>
        /// <param name="processorType">Processor type; forwarded to the base constructor.</param>
        /// <param name="encoderType">Encoder type passed to the meta data.</param>
        /// <param name="decoderType">Decoder type passed to the meta data.</param>
        /// <param name="enableCoverageModel">Coverage-model switch passed to the meta data.</param>
        /// <param name="deviceIds">Device ids; forwarded to the base constructor.</param>
        /// <param name="isSrcEmbTrainable">Stored flag for the source embedding.</param>
        /// <param name="isTgtEmbTrainable">Stored flag for the target embedding.</param>
        /// <param name="isEncoderTrainable">Stored flag for the encoder.</param>
        /// <param name="isDecoderTrainable">Stored flag for the decoder.</param>
        /// <param name="maxTgtSntSize">Maximum target sentence size stored on the instance.</param>
        public AttentionSeq2Seq(int embeddingDim, int hiddenDim, int encoderLayerDepth, int decoderLayerDepth, Vocab vocab, string srcEmbeddingFilePath, string tgtEmbeddingFilePath,
                                string modelFilePath, float dropoutRatio, int multiHeadNum, ProcessorTypeEnums processorType, EncoderTypeEnums encoderType, DecoderTypeEnums decoderType, bool enableCoverageModel, int[] deviceIds,
                                bool isSrcEmbTrainable = true, bool isTgtEmbTrainable = true, bool isEncoderTrainable = true, bool isDecoderTrainable = true, int maxTgtSntSize = 128)
            : base(deviceIds, processorType, modelFilePath)
        {
            // Plain settings first; they must all be in place before the parameters are built.
            m_dropoutRatio       = dropoutRatio;
            m_maxTgtSntSize      = maxTgtSntSize;
            m_isSrcEmbTrainable  = isSrcEmbTrainable;
            m_isTgtEmbTrainable  = isTgtEmbTrainable;
            m_isEncoderTrainable = isEncoderTrainable;
            m_isDecoderTrainable = isDecoderTrainable;

            m_modelMetaData = new Seq2SeqModelMetaData(hiddenDim, embeddingDim, encoderLayerDepth, decoderLayerDepth, multiHeadNum, encoderType, decoderType, vocab, enableCoverageModel);

            // Initialize the weights of the encoders and decoders.
            CreateTrainableParameters(m_modelMetaData);

            // Whether a pre-trained embedding file was specified does not change per device.
            bool hasSrcEmbeddingFile = !string.IsNullOrEmpty(srcEmbeddingFilePath);
            bool hasTgtEmbeddingFile = !string.IsNullOrEmpty(tgtEmbeddingFilePath);

            // Load the external embeddings into the network copy of each device.
            for (int deviceIdx = 0; deviceIdx < DeviceIds.Length; deviceIdx++)
            {
                if (hasSrcEmbeddingFile)
                {
                    Logger.WriteLine($"Loading ExtEmbedding model from '{srcEmbeddingFilePath}' for source side.");
                    LoadWordEmbedding(srcEmbeddingFilePath, m_srcEmbedding.GetNetworkOnDevice(deviceIdx), m_modelMetaData.Vocab.SrcWordToIndex);
                }

                if (hasTgtEmbeddingFile)
                {
                    Logger.WriteLine($"Loading ExtEmbedding model from '{tgtEmbeddingFilePath}' for target side.");
                    LoadWordEmbedding(tgtEmbeddingFilePath, m_tgtEmbedding.GetNetworkOnDevice(deviceIdx), m_modelMetaData.Vocab.TgtWordToIndex);
                }
            }
        }
Пример #11
0
        /// <summary>
        /// Console entry point. Reads the options (command line, optionally replaced by a
        /// JSON config file), then trains, validates, tests or dumps the vocabulary of an
        /// <c>AttentionSeq2Seq</c> model depending on the task name. Any exception is
        /// written to the log instead of being propagated.
        /// </summary>
        /// <param name="args">Command line arguments.</param>
        private static void Main(string[] args)
        {
            try
            {
                Logger.LogFile = $"{nameof(Seq2SeqConsole)}_{GetTimeStamp(DateTime.Now)}.log";
                ShowOptions(args);

                // Command line first; a config file, when given, replaces the options wholesale.
                Options options = new Options();
                ArgParser parser = new ArgParser(args, options);
                if (!string.IsNullOrEmpty(options.ConfigFilePath))
                {
                    Logger.WriteLine($"Loading config file from '{options.ConfigFilePath}'");
                    options = JsonConvert.DeserializeObject<Options>(File.ReadAllText(options.ConfigFilePath));
                }

                // All enum-valued options are parsed up front, whatever the task is.
                ProcessorTypeEnums processorType = Enum.Parse<ProcessorTypeEnums>(options.ProcessorType);
                EncoderTypeEnums encoderType = Enum.Parse<EncoderTypeEnums>(options.EncoderType);
                DecoderTypeEnums decoderType = Enum.Parse<DecoderTypeEnums>(options.DecoderType);
                ModeEnums mode = Enum.Parse<ModeEnums>(options.TaskName);
                ShuffleEnums shuffleType = Enum.Parse<ShuffleEnums>(options.ShuffleType);

                // Space separated compiler options; null when none were given.
                string[] compilerOptions = null;
                if (!string.IsNullOrEmpty(options.CompilerOptions))
                {
                    compilerOptions = options.CompilerOptions.Split(' ', StringSplitOptions.RemoveEmptyEntries);
                }

                // Comma separated device ids.
                int[] deviceIds = options.DeviceIds.Split(',').Select(idText => int.Parse(idText)).ToArray();

                switch (mode)
                {
                    case ModeEnums.Train:
                    {
                        // Training corpus.
                        ParallelCorpus trainCorpus = new ParallelCorpus(corpusFilePath: options.TrainCorpusPath, srcLangName: options.SrcLang, tgtLangName: options.TgtLang,
                                                                        batchSize: options.BatchSize, shuffleBlockSize: options.ShuffleBlockSize,
                                                                        maxSrcSentLength: options.MaxSrcSentLength, maxTgtSentLength: options.MaxTgtSentLength, shuffleEnums: shuffleType);

                        // Validation corpus is optional.
                        ParallelCorpus validCorpus = null;
                        if (!string.IsNullOrEmpty(options.ValidCorpusPath))
                        {
                            validCorpus = new ParallelCorpus(options.ValidCorpusPath, options.SrcLang, options.TgtLang, options.ValBatchSize, options.ShuffleBlockSize, options.MaxSrcSentLength, options.MaxTgtSentLength);
                        }

                        ILearningRate learningRate = new DecayLearningRate(options.StartLearningRate, options.WarmUpSteps, options.WeightsUpdateCount);
                        AdamOptimizer optimizer = new AdamOptimizer(options.GradClip, options.Beta1, options.Beta2);
                        List<IMetric> metrics = new List<IMetric> { new BleuMetric(), new LengthRatioMetric() };

                        AttentionSeq2Seq model;
                        bool modelFileExists = !string.IsNullOrEmpty(options.ModelFilePath) && File.Exists(options.ModelFilePath);
                        if (modelFileExists)
                        {
                            // Incremental training on top of the existing model file.
                            Logger.WriteLine($"Loading model from '{options.ModelFilePath}'...");
                            model = new AttentionSeq2Seq(modelFilePath: options.ModelFilePath, processorType: processorType, dropoutRatio: options.DropoutRatio, deviceIds: deviceIds,
                                                         isSrcEmbTrainable: options.IsSrcEmbeddingTrainable, isTgtEmbTrainable: options.IsTgtEmbeddingTrainable,
                                                         isEncoderTrainable: options.IsEncoderTrainable, isDecoderTrainable: options.IsDecoderTrainable,
                                                         maxSrcSntSize: options.MaxSrcSentLength, maxTgtSntSize: options.MaxTgtSentLength, memoryUsageRatio: options.MemoryUsageRatio,
                                                         shuffleType: shuffleType, compilerOptions: compilerOptions);
                        }
                        else
                        {
                            // Vocabulary comes from the given files, or is built from the training corpus.
                            bool vocabFilesGiven = !string.IsNullOrEmpty(options.SrcVocab) && !string.IsNullOrEmpty(options.TgtVocab);
                            Vocab vocab = vocabFilesGiven
                                ? new Vocab(options.SrcVocab, options.TgtVocab)
                                : new Vocab(trainCorpus);

                            // Training from scratch.
                            model = new AttentionSeq2Seq(embeddingDim: options.WordVectorSize, hiddenDim: options.HiddenSize,
                                                         encoderLayerDepth: options.EncoderLayerDepth, decoderLayerDepth: options.DecoderLayerDepth,
                                                         srcEmbeddingFilePath: options.SrcEmbeddingModelFilePath, tgtEmbeddingFilePath: options.TgtEmbeddingModelFilePath,
                                                         vocab: vocab, modelFilePath: options.ModelFilePath,
                                                         dropoutRatio: options.DropoutRatio, processorType: processorType, deviceIds: deviceIds, multiHeadNum: options.MultiHeadNum,
                                                         encoderType: encoderType, decoderType: decoderType,
                                                         maxSrcSntSize: options.MaxSrcSentLength, maxTgtSntSize: options.MaxTgtSentLength, enableCoverageModel: options.EnableCoverageModel,
                                                         memoryUsageRatio: options.MemoryUsageRatio, shuffleType: shuffleType, compilerOptions: compilerOptions);
                        }

                        // Progress monitoring, then start training.
                        model.IterationDone += ss_IterationDone;
                        model.Train(maxTrainingEpoch: options.MaxEpochNum, trainCorpus: trainCorpus, validCorpus: validCorpus, learningRate: learningRate, optimizer: optimizer, metrics: metrics);
                        break;
                    }

                    case ModeEnums.Valid:
                    {
                        Logger.WriteLine($"Evaluate model '{options.ModelFilePath}' by valid corpus '{options.ValidCorpusPath}'");

                        List<IMetric> metrics = new List<IMetric> { new BleuMetric(), new LengthRatioMetric() };
                        ParallelCorpus validCorpus = new ParallelCorpus(options.ValidCorpusPath, options.SrcLang, options.TgtLang, options.ValBatchSize, options.ShuffleBlockSize, options.MaxSrcSentLength, options.MaxTgtSentLength);

                        AttentionSeq2Seq model = new AttentionSeq2Seq(modelFilePath: options.ModelFilePath, processorType: processorType, deviceIds: deviceIds,
                                                                      memoryUsageRatio: options.MemoryUsageRatio, shuffleType: shuffleType, compilerOptions: compilerOptions);
                        model.Valid(validCorpus: validCorpus, metrics: metrics);
                        break;
                    }

                    case ModeEnums.Test:
                    {
                        Logger.WriteLine($"Test model '{options.ModelFilePath}' by input corpus '{options.InputTestFile}'");

                        AttentionSeq2Seq model = new AttentionSeq2Seq(modelFilePath: options.ModelFilePath, processorType: processorType, deviceIds: deviceIds,
                                                                      memoryUsageRatio: options.MemoryUsageRatio, shuffleType: shuffleType,
                                                                      maxSrcSntSize: options.MaxSrcSentLength, maxTgtSntSize: options.MaxTgtSentLength, compilerOptions: compilerOptions);

                        List<string> outputLines = new List<string>();
                        foreach (string inputLine in File.ReadAllLines(options.InputTestFile))
                        {
                            // Lower-cased, trimmed, split on single spaces.
                            List<string> tokens = inputLine.ToLower().Trim().Split(' ').ToList();

                            if (options.BeamSearch > 1)
                            {
                                // Beam search yields one word list per hypothesis.
                                List<List<string>> hypotheses = model.Predict(tokens, options.BeamSearch);
                                outputLines.AddRange(hypotheses.Select(words => string.Join(" ", words)));
                            }
                            else
                            {
                                var outputTokensBatch = model.Test(ParallelCorpus.ConstructInputTokens(tokens));
                                outputLines.AddRange(outputTokensBatch.Select(words => String.Join(" ", words)));
                            }
                        }

                        File.WriteAllLines(options.OutputTestFile, outputLines);
                        break;
                    }

                    case ModeEnums.DumpVocab:
                    {
                        AttentionSeq2Seq model = new AttentionSeq2Seq(modelFilePath: options.ModelFilePath, processorType: processorType, deviceIds: deviceIds, compilerOptions: compilerOptions);
                        model.DumpVocabToFiles(options.SrcVocab, options.TgtVocab);
                        break;
                    }

                    default:
                        parser.Usage();
                        break;
                }
            }
            catch (Exception err)
            {
                Logger.WriteLine($"Exception: '{err.Message}'");
                Logger.WriteLine($"Call stack: '{err.StackTrace}'");
            }
        }
Пример #12
0
 static public void Initialization(string modelFilePath, int maxTestSrcSentLength, int maxTestTgtSentLength,
                                   ProcessorTypeEnums processorType, string deviceIds, in (SentencePiece src, SentencePiece tgt) sentPieces)
        /// <summary>
        /// Creates one <c>Seq2SeqClassification</c> instance per entry of
        /// <paramref name="key2ModelFilePath"/> and registers it under the entry's key.
        /// </summary>
        /// <param name="key2ModelFilePath">Maps a key to the model file path to load for it.</param>
        /// <param name="maxTestSrcSentLength">Value for <c>MaxTestSrcSentLength</c> of every instance.</param>
        /// <param name="maxTestTgtSentLength">Value for <c>MaxTestTgtSentLength</c> of every instance.</param>
        /// <param name="processorType">Value for <c>ProcessorType</c> of every instance.</param>
        /// <param name="deviceIds">Value for <c>DeviceIds</c> of every instance.</param>
        public static void Initialization(Dictionary<string, string> key2ModelFilePath, int maxTestSrcSentLength, int maxTestTgtSentLength, ProcessorTypeEnums processorType, string deviceIds)
        {
            foreach (KeyValuePair<string, string> entry in key2ModelFilePath)
            {
                Logger.WriteLine($"Loading '{entry.Key}' model from '{entry.Value}'");

                // The shared options field is replaced on every iteration, so it ends up
                // describing the last model that was loaded.
                opts = new Seq2SeqClassificationOptions
                {
                    ModelFilePath        = entry.Value,
                    MaxTestSrcSentLength = maxTestSrcSentLength,
                    MaxTestTgtSentLength = maxTestTgtSentLength,
                    ProcessorType        = processorType,
                    DeviceIds            = deviceIds
                };

                m_key2Instance.Add(entry.Key, new Seq2SeqClassification(opts));
            }
        }
Пример #14
0
 /// <summary>
 /// Stores the device ids and the model file path, then initializes the devices
 /// through <c>TensorAllocator.InitDevices</c>.
 /// </summary>
 /// <param name="deviceIds">Device ids to store and initialize.</param>
 /// <param name="processorType">Processor type passed to <c>InitDevices</c>.</param>
 /// <param name="modelFilePath">Model file path to store.</param>
 /// <param name="memoryUsageRatio">Memory usage ratio passed to <c>InitDevices</c>.</param>
 /// <param name="compilerOptions">Compiler options passed to <c>InitDevices</c>.</param>
 public BaseSeq2SeqFramework(
     int[] deviceIds,
     ProcessorTypeEnums processorType,
     string modelFilePath,
     float memoryUsageRatio = 0.9f,
     string[] compilerOptions = null)
 {
     m_modelFilePath = modelFilePath;
     m_deviceIds     = deviceIds;

     // Device setup runs last, on the ids that were just stored.
     TensorAllocator.InitDevices(processorType, m_deviceIds, memoryUsageRatio, compilerOptions);
 }
Пример #15
0
        /// <summary>
        /// Console entry point. Reads the options (command line, optionally replaced by a
        /// JSON config file), then trains, validates or tests a <c>SequenceLabel</c> model
        /// depending on the task name; any other task prints the usage text.
        /// </summary>
        /// <param name="args">Command line arguments.</param>
        static void Main(string[] args)
        {
            ShowOptions(args);

            Logger.LogFile = $"{nameof(SeqLabelConsole)}_{GetTimeStamp(DateTime.Now)}.log";

            // Command line first; a config file, when given, replaces the options wholesale.
            Options options = new Options();
            ArgParser parser = new ArgParser(args, options);
            if (!String.IsNullOrEmpty(options.ConfigFilePath))
            {
                Logger.WriteLine($"Loading config file from '{options.ConfigFilePath}'");
                options = JsonConvert.DeserializeObject<Options>(File.ReadAllText(options.ConfigFilePath));
            }

            // All enum-valued options are parsed up front, whatever the task is.
            ProcessorTypeEnums processorType = (ProcessorTypeEnums)Enum.Parse(typeof(ProcessorTypeEnums), options.ProcessorType);
            EncoderTypeEnums encoderType = (EncoderTypeEnums)Enum.Parse(typeof(EncoderTypeEnums), options.EncoderType);
            ModeEnums mode = (ModeEnums)Enum.Parse(typeof(ModeEnums), options.TaskName);

            // Comma separated device ids.
            int[] deviceIds = options.DeviceIds.Split(',').Select(idText => int.Parse(idText)).ToArray();

            switch (mode)
            {
                case ModeEnums.Train:
                {
                    // Training corpus.
                    ParallelCorpus trainCorpus = new ParallelCorpus(options.TrainCorpusPath, options.SrcLang, options.TgtLang, options.BatchSize, options.ShuffleBlockSize, options.MaxSentLength, addBOSEOS: false);

                    // Validation corpus is optional.
                    ParallelCorpus validCorpus = null;
                    if (!String.IsNullOrEmpty(options.ValidCorpusPath))
                    {
                        validCorpus = new ParallelCorpus(options.ValidCorpusPath, options.SrcLang, options.TgtLang, options.BatchSize, options.ShuffleBlockSize, options.MaxSentLength, addBOSEOS: false);
                    }

                    // Vocabulary comes from the given files, or is built from the training corpus.
                    bool vocabFilesGiven = !String.IsNullOrEmpty(options.SrcVocab) && !String.IsNullOrEmpty(options.TgtVocab);
                    Vocab vocab = vocabFilesGiven
                        ? new Vocab(options.SrcVocab, options.TgtVocab)
                        : new Vocab(trainCorpus);

                    ILearningRate learningRate = new DecayLearningRate(options.StartLearningRate, options.WarmUpSteps, options.WeightsUpdateCount);
                    AdamOptimizer optimizer = new AdamOptimizer(options.GradClip, options.Beta1, options.Beta2);

                    // One F-score metric per entry of the target vocabulary.
                    List<IMetric> metrics = new List<IMetric>();
                    foreach (var label in vocab.TgtVocab)
                    {
                        metrics.Add(new SequenceLabelFscoreMetric(label));
                    }

                    SequenceLabel labeler;
                    if (File.Exists(options.ModelFilePath))
                    {
                        // Incremental training on top of the existing model file.
                        Logger.WriteLine($"Loading model from '{options.ModelFilePath}'...");
                        labeler = new SequenceLabel(modelFilePath: options.ModelFilePath, processorType: processorType, deviceIds: deviceIds, dropoutRatio: options.DropoutRatio);
                    }
                    else
                    {
                        // Training from scratch.
                        labeler = new SequenceLabel(hiddenDim: options.HiddenSize, embeddingDim: options.WordVectorSize, encoderLayerDepth: options.EncoderLayerDepth,
                                                    multiHeadNum: options.MultiHeadNum, encoderType: encoderType,
                                                    dropoutRatio: options.DropoutRatio, deviceIds: deviceIds, processorType: processorType,
                                                    modelFilePath: options.ModelFilePath, vocab: vocab);
                    }

                    // Progress monitoring, then start training.
                    labeler.IterationDone += ss_IterationDone;
                    labeler.Train(maxTrainingEpoch: options.MaxEpochNum, trainCorpus: trainCorpus, validCorpus: validCorpus, learningRate: learningRate, optimizer: optimizer, metrics: metrics);
                    break;
                }

                case ModeEnums.Valid:
                {
                    Logger.WriteLine($"Evaluate model '{options.ModelFilePath}' by valid corpus '{options.ValidCorpusPath}'");

                    ParallelCorpus validCorpus = new ParallelCorpus(options.ValidCorpusPath, options.SrcLang, options.TgtLang, options.BatchSize, options.ShuffleBlockSize, options.MaxSentLength, false);

                    // The metric labels are taken from a vocabulary built on the validation corpus.
                    Vocab vocab = new Vocab(validCorpus);
                    List<IMetric> metrics = new List<IMetric>();
                    foreach (var label in vocab.TgtVocab)
                    {
                        metrics.Add(new SequenceLabelFscoreMetric(label));
                    }

                    SequenceLabel labeler = new SequenceLabel(modelFilePath: options.ModelFilePath, processorType: processorType, deviceIds: deviceIds);
                    labeler.Valid(validCorpus: validCorpus, metrics: metrics);
                    break;
                }

                case ModeEnums.Test:
                {
                    Logger.WriteLine($"Test model '{options.ModelFilePath}' by input corpus '{options.InputTestFile}'");

                    SequenceLabel labeler = new SequenceLabel(modelFilePath: options.ModelFilePath, processorType: processorType, deviceIds: deviceIds);

                    List<string> outputLines = new List<string>();
                    foreach (string inputLine in File.ReadAllLines(options.InputTestFile))
                    {
                        // Lower-cased, trimmed, split on single spaces.
                        List<string> tokens = inputLine.ToLower().Trim().Split(' ').ToList();
                        var outputTokensBatch = labeler.Test(ParallelCorpus.ConstructInputTokens(tokens, false));
                        outputLines.AddRange(outputTokensBatch.Select(words => String.Join(" ", words)));
                    }

                    File.WriteAllLines(options.OutputTestFile, outputLines);
                    break;
                }

                default:
                    parser.Usage();
                    break;
            }
        }
Пример #16
0
 /// <summary>
 /// Stores the device ids and the model file path, then initializes the devices
 /// through <c>TensorAllocator.InitDevices</c>.
 /// </summary>
 /// <param name="deviceIds">Device ids to store and initialize.</param>
 /// <param name="processorType">Processor type passed to <c>InitDevices</c>.</param>
 /// <param name="modelFilePath">Model file path to store.</param>
 public BaseSeq2SeqFramework(
     int[] deviceIds,
     ProcessorTypeEnums processorType,
     string modelFilePath)
 {
     m_modelFilePath = modelFilePath;
     m_deviceIds     = deviceIds;

     // Device setup runs last, on the ids that were just stored.
     TensorAllocator.InitDevices(processorType, m_deviceIds);
 }