/// <summary>
/// Splits a tab-separated input line into fields, passes them to
/// <c>SeqClassificationInstance.Call</c>, logs the input/output pair and returns the result.
/// </summary>
/// <param name="input">Tab-separated text; each tab-delimited field becomes one entry in the list passed to the model.</param>
/// <returns>The value produced by <c>SeqClassificationInstance.Call</c>.</returns>
public string Classify(string input)
        {
            string[] fields = input.Split('\t');
            var groups = new List<string>(fields);

            // Kept as 'var' so the conversion of Call's result to string happens exactly as before.
            var result = SeqClassificationInstance.Call(groups);
            Logger.WriteLine($"'{input}' -> '{result}'");

            return result;
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Stores <paramref name="configuration"/> and loads every model whose configuration
        /// section is populated: Seq2Seq, SeqClassification, SeqSimilarity and the keyed
        /// Seq2SeqClassification models. A section whose gating setting is missing or empty
        /// is skipped.
        /// </summary>
        /// <param name="configuration">Configuration that all model settings are read from.</param>
        /// <exception cref="System.InvalidOperationException">
        /// A "Seq2SeqClassification:Models:{i}" entry has a Key but no FilePath, or two entries
        /// share the same Key.
        /// </exception>
        public Startup(IConfiguration configuration)
        {
            Configuration = configuration;

            // Each section is independent; they run in the same order as before.
            LoadSeq2SeqModel();
            LoadSeqClassificationModel();
            LoadSeqSimilarityModel();
            LoadSeq2SeqClassificationModels();
        }

        /// <summary>Initializes <c>Seq2SeqInstance</c> when "Seq2Seq:ModelFilePath" is set.</summary>
        private void LoadSeq2SeqModel()
        {
            var modelFilePath = Configuration["Seq2Seq:ModelFilePath"];
            if (modelFilePath.IsNullOrEmpty())
            {
                return;
            }

            Logger.WriteLine($"Loading Seq2Seq model '{modelFilePath}'");

            int                maxTestSrcSentLength = Configuration["Seq2Seq:MaxSrcTokenSize"].ToInt();
            int                maxTestTgtSentLength = Configuration["Seq2Seq:MaxTgtTokenSize"].ToInt();
            ProcessorTypeEnums processorType        = Configuration["Seq2Seq:ProcessorType"].ToEnum <ProcessorTypeEnums>();
            string             deviceIds            = Configuration["Seq2Seq:DeviceIds"];

            var srcSentPiece = new SentencePiece(Configuration["Seq2Seq:SrcSentencePieceModelPath"]);
            var tgtSentPiece = new SentencePiece(Configuration["Seq2Seq:TgtSentencePieceModelPath"]);

            Seq2SeqInstance.Initialization(modelFilePath, maxTestSrcSentLength, maxTestTgtSentLength,
                                           processorType, deviceIds, (srcSentPiece, tgtSentPiece));
        }

        /// <summary>Initializes <c>SeqClassificationInstance</c> when "SeqClassification:ModelFilePath" is set.</summary>
        private void LoadSeqClassificationModel()
        {
            var modelFilePath = Configuration["SeqClassification:ModelFilePath"];
            if (modelFilePath.IsNullOrEmpty())
            {
                return;
            }

            Logger.WriteLine($"Loading SeqClassification model '{modelFilePath}'");

            int                maxTestSentLength = Configuration["SeqClassification:MaxTokenSize"].ToInt();
            ProcessorTypeEnums processorType     = Configuration["SeqClassification:ProcessorType"].ToEnum <ProcessorTypeEnums>();
            string             deviceIds         = Configuration["SeqClassification:DeviceIds"];

            SeqClassificationInstance.Initialization(modelFilePath, maxTestSentLength, processorType, deviceIds);
        }

        /// <summary>Initializes <c>SeqSimilarityInstance</c> when "SeqSimilarity:ModelFilePath" is set.</summary>
        private void LoadSeqSimilarityModel()
        {
            var modelFilePath = Configuration["SeqSimilarity:ModelFilePath"];
            if (modelFilePath.IsNullOrEmpty())
            {
                return;
            }

            Logger.WriteLine($"Loading SeqSimilarity model '{modelFilePath}'");

            int                maxTestSentLength = Configuration["SeqSimilarity:MaxTokenSize"].ToInt();
            ProcessorTypeEnums processorType     = Configuration["SeqSimilarity:ProcessorType"].ToEnum <ProcessorTypeEnums>();
            string             deviceIds         = Configuration["SeqSimilarity:DeviceIds"];

            SeqSimilarityInstance.Initialization(modelFilePath, maxTestSentLength, processorType, deviceIds);
        }

        /// <summary>
        /// Initializes <c>Seq2SeqClassificationInstances</c> from the
        /// "Seq2SeqClassification:Models:{i}:Key" / ":FilePath" entries when
        /// "Seq2SeqClassification:ProcessorType" is set.
        /// </summary>
        /// <exception cref="System.InvalidOperationException">
        /// An entry has a Key but no FilePath, or a Key repeats an earlier entry's Key.
        /// </exception>
        private void LoadSeq2SeqClassificationModels()
        {
            if (Configuration["Seq2SeqClassification:ProcessorType"].IsNullOrEmpty())
            {
                return;
            }

            var key2ModelFilePath = new Dictionary<string, string>();

            // Entries are read at indices 0, 1, 2, ...; the list ends at the first index without a Key.
            for (int i = 0; ; i++)
            {
                var key = Configuration[$"Seq2SeqClassification:Models:{i}:Key"];
                if (key.IsNullOrEmpty())
                {
                    break;
                }

                var filePath = Configuration[$"Seq2SeqClassification:Models:{i}:FilePath"];
                if (filePath.IsNullOrEmpty())
                {
                    // Fail at startup rather than handing a null/empty path to the model loader.
                    throw new System.InvalidOperationException(
                        $"Seq2SeqClassification:Models:{i}:FilePath is not configured for model key '{key}'.");
                }

                if (!key2ModelFilePath.TryAdd(key, filePath))
                {
                    throw new System.InvalidOperationException(
                        $"Seq2SeqClassification:Models:{i}:Key '{key}' duplicates an earlier model key.");
                }
            }

            // These settings are read and parsed even when no model entries exist, matching the previous behavior.
            int                maxTestSrcSentLength = Configuration["Seq2SeqClassification:MaxSrcTokenSize"].ToInt();
            int                maxTestTgtSentLength = Configuration["Seq2SeqClassification:MaxTgtTokenSize"].ToInt();
            ProcessorTypeEnums processorType        = Configuration["Seq2SeqClassification:ProcessorType"].ToEnum <ProcessorTypeEnums>();
            string             deviceIds            = Configuration["Seq2SeqClassification:DeviceIds"];

            if (key2ModelFilePath.Count > 0)
            {
                Seq2SeqClassificationInstances.Initialization(key2ModelFilePath, maxTestSrcSentLength, maxTestTgtSentLength, processorType, deviceIds);
            }
        }