/// <summary>
        /// Creates a reader over a pair of line-aligned text files. Both files are read
        /// completely into memory, their line counts are checked for equality, and an
        /// optional SentencePiece model is loaded for each side.
        /// </summary>
        /// <param name="srcFilePath">Path of the source-side text file; every line is loaded.</param>
        /// <param name="tgtFilePath">Path of the target-side text file; every line is loaded.</param>
        /// <param name="batchSize">Stored as the reader's batch size.</param>
        /// <param name="maxSentLength">Stored as the reader's maximum sentence length.</param>
        /// <param name="srcSPMPath">Path of the source SentencePiece model; no model is loaded when null or empty.</param>
        /// <param name="tgtSPMPath">Path of the target SentencePiece model; no model is loaded when null or empty.</param>
        /// <exception cref="DataMisalignedException">The two files contain a different number of lines.</exception>
        public SntPairBatchStreamReader(string srcFilePath, string tgtFilePath, int batchSize, int maxSentLength, string srcSPMPath = null, string tgtSPMPath = null)
        {
            this.batchSize     = batchSize;
            this.maxSentLength = maxSentLength;
            currentIdx         = 0;

            // Source side first, then target side; each load is announced before it starts.
            Logger.WriteLine($"Loading lines from '{srcFilePath}'");
            srcLines = File.ReadAllLines(srcFilePath);

            Logger.WriteLine($"Loading lines from '{tgtFilePath}'");
            tgtLines = File.ReadAllLines(tgtFilePath);

            // The two files are paired line by line, so a count mismatch is fatal.
            bool lineCountsMatch = srcLines.Length == tgtLines.Length;
            if (!lineCountsMatch)
            {
                string message = $"The number of lines between source file '{srcFilePath}' (line# '{srcLines.Length}') and target file '{tgtFilePath}' (line# '{tgtLines.Length}') are different.";
                throw new DataMisalignedException(message);
            }

            // SentencePiece models are optional and configured independently per side.
            if (!string.IsNullOrEmpty(srcSPMPath))
            {
                Logger.WriteLine($"Loading sentence piece model '{srcSPMPath}' for encoding.");
                srcSP = new SentencePiece(srcSPMPath);
            }

            if (!string.IsNullOrEmpty(tgtSPMPath))
            {
                Logger.WriteLine($"Loading sentence piece model '{tgtSPMPath}' for encoding.");
                tgtSP = new SentencePiece(tgtSPMPath);
            }
        }
예제 #2
0
        /// <summary>
        /// Creates a writer that opens <paramref name="filePath"/> for writing (not appending)
        /// and optionally loads a SentencePiece model.
        /// </summary>
        /// <param name="filePath">Path of the output file; it is opened with append disabled.</param>
        /// <param name="sentencePieceModelPath">Path of the SentencePiece model; no model is loaded when null or empty.</param>
        public SntBatchStreamWriter(string filePath, string sentencePieceModelPath = null)
        {
            this.filePath = filePath;
            sw            = new StreamWriter(filePath, false);

            try
            {
                if (!string.IsNullOrEmpty(sentencePieceModelPath))
                {
                    Logger.WriteLine($"Loading sentence piece model '{sentencePieceModelPath}' for decoding.");
                    sp = new SentencePiece(sentencePieceModelPath);
                }
            }
            catch
            {
                // When a constructor throws, the caller never receives the instance and so
                // cannot dispose it; release the already-opened file handle here before
                // propagating the original exception.
                sw.Dispose();
                throw;
            }
        }
예제 #3
0
        /// <summary>
        /// Creates a reader over a single text file. The file is read completely into
        /// memory and an optional SentencePiece model is loaded.
        /// </summary>
        /// <param name="filePath">Path of the text file; every line is loaded.</param>
        /// <param name="batchSize">Stored as the reader's batch size.</param>
        /// <param name="maxSentLength">Stored as the reader's maximum sentence length.</param>
        /// <param name="sentencePieceModelPath">Path of the SentencePiece model; no model is loaded when null or empty.</param>
        public SntBatchStreamReader(string filePath, int batchSize, int maxSentLength, string sentencePieceModelPath = null)
        {
            this.batchSize     = batchSize;
            this.maxSentLength = maxSentLength;
            currentIdx         = 0;

            lines = File.ReadAllLines(filePath);

            // The SentencePiece model is optional; skip it when no path is configured.
            bool hasSentencePieceModel = !string.IsNullOrEmpty(sentencePieceModelPath);
            if (hasSentencePieceModel)
            {
                Logger.WriteLine($"Loading sentence piece model '{sentencePieceModelPath}' for encoding.");
                sp = new SentencePiece(sentencePieceModelPath);
            }
        }
예제 #4
0
        /// <summary>
        /// Stores the application configuration and initializes each model family whose
        /// settings are present: Seq2Seq, SeqClassification, SeqSimilarity and
        /// Seq2SeqClassification. A family is skipped when its gating configuration
        /// value is null or empty.
        /// </summary>
        /// <param name="configuration">Configuration source that all model settings are read from.</param>
        public Startup(IConfiguration configuration)
        {
            Configuration = configuration;

            // --- Seq2Seq: gated on the model file path ---
            var seq2SeqModelPath = Configuration["Seq2Seq:ModelFilePath"];
            if (!seq2SeqModelPath.IsNullOrEmpty())
            {
                Logger.WriteLine($"Loading Seq2Seq model '{seq2SeqModelPath}'");

                int maxSrcTokens  = Configuration["Seq2Seq:MaxSrcTokenSize"].ToInt();
                int maxTgtTokens  = Configuration["Seq2Seq:MaxTgtTokenSize"].ToInt();
                var processor     = Configuration["Seq2Seq:ProcessorType"].ToEnum<ProcessorTypeEnums>();
                string devices    = Configuration["Seq2Seq:DeviceIds"];

                // Both SentencePiece models are constructed unconditionally for this family.
                var sourcePieces = new SentencePiece(Configuration["Seq2Seq:SrcSentencePieceModelPath"]);
                var targetPieces = new SentencePiece(Configuration["Seq2Seq:TgtSentencePieceModelPath"]);

                Seq2SeqInstance.Initialization(seq2SeqModelPath, maxSrcTokens, maxTgtTokens,
                                               processor, devices, (sourcePieces, targetPieces));
            }

            // --- SeqClassification: gated on the model file path ---
            var classificationModelPath = Configuration["SeqClassification:ModelFilePath"];
            if (!classificationModelPath.IsNullOrEmpty())
            {
                Logger.WriteLine($"Loading SeqClassification model '{classificationModelPath}'");

                int maxTokens  = Configuration["SeqClassification:MaxTokenSize"].ToInt();
                var processor  = Configuration["SeqClassification:ProcessorType"].ToEnum<ProcessorTypeEnums>();
                string devices = Configuration["SeqClassification:DeviceIds"];

                SeqClassificationInstance.Initialization(classificationModelPath, maxTokens, processor, devices);
            }

            // --- SeqSimilarity: gated on the model file path ---
            var similarityModelPath = Configuration["SeqSimilarity:ModelFilePath"];
            if (!similarityModelPath.IsNullOrEmpty())
            {
                Logger.WriteLine($"Loading SeqSimilarity model '{similarityModelPath}'");

                int maxTokens  = Configuration["SeqSimilarity:MaxTokenSize"].ToInt();
                var processor  = Configuration["SeqSimilarity:ProcessorType"].ToEnum<ProcessorTypeEnums>();
                string devices = Configuration["SeqSimilarity:DeviceIds"];

                SeqSimilarityInstance.Initialization(similarityModelPath, maxTokens, processor, devices);
            }

            // --- Seq2SeqClassification: gated on the processor type, not on a model path ---
            if (!Configuration["Seq2SeqClassification:ProcessorType"].IsNullOrEmpty())
            {
                // Collect the indexed "Models:{n}" entries until the first index with no key.
                var modelPathsByKey = new Dictionary<string, string>();
                for (int index = 0; ; index++)
                {
                    string modelKey = Configuration[$"Seq2SeqClassification:Models:{index}:Key"];
                    if (modelKey.IsNullOrEmpty())
                    {
                        break;
                    }

                    modelPathsByKey.Add(modelKey, Configuration[$"Seq2SeqClassification:Models:{index}:FilePath"]);
                }

                // These settings are parsed even when no model entries were found.
                int maxSrcTokens  = Configuration["Seq2SeqClassification:MaxSrcTokenSize"].ToInt();
                int maxTgtTokens  = Configuration["Seq2SeqClassification:MaxTgtTokenSize"].ToInt();
                var processor     = Configuration["Seq2SeqClassification:ProcessorType"].ToEnum<ProcessorTypeEnums>();
                string devices    = Configuration["Seq2SeqClassification:DeviceIds"];

                if (modelPathsByKey.Count > 0)
                {
                    Seq2SeqClassificationInstances.Initialization(modelPathsByKey, maxSrcTokens, maxTgtTokens, processor, devices);
                }
            }
        }