예제 #1
0
        /// <summary>
        /// Creates a sequence labeling instance from <paramref name="options"/>.
        /// If a file exists at <c>options.ModelFilePath</c>, the model is loaded from it;
        /// otherwise a new <c>SeqLabelModel</c> is created from the options and the supplied
        /// vocabularies, and its trainable parameters are created.
        /// </summary>
        /// <param name="options">Settings forwarded to the base class and used to build or locate the model.</param>
        /// <param name="srcVocab">Source vocabulary for a newly created model. Must be null when the model file already exists.</param>
        /// <param name="clsVocab">Class vocabulary for a newly created model. Must be null when the model file already exists.</param>
        /// <exception cref="FileNotFoundException">The task is not <c>ModeEnums.Train</c> and the model file does not exist.</exception>
        /// <exception cref="ArgumentException">The model file exists and at least one vocabulary was supplied.</exception>
        public SeqLabel(SeqLabelOptions options, Vocab srcVocab = null, Vocab clsVocab = null)
            : base(options.DeviceIds, options.ProcessorType, options.ModelFilePath, options.MemoryUsageRatio, options.CompilerOptions, options.ValidIntervalHours, updateFreq: options.UpdateFreq)
        {
            m_options     = options;
            m_shuffleType = options.ShuffleType;

            string modelPath = m_options.ModelFilePath;

            // Any task other than training requires an already existing model file.
            bool isTrainingTask = (m_options.Task == ModeEnums.Train);
            if (!isTrainingTask && !File.Exists(modelPath))
            {
                throw new FileNotFoundException($"Model '{modelPath}' doesn't exist.");
            }

            if (!File.Exists(modelPath))
            {
                // No model file yet: create a new model, then initialize its weights.
                m_modelMetaData = new SeqLabelModel(
                    options.HiddenSize,
                    options.EmbeddingDim,
                    options.EncoderLayerDepth,
                    options.MultiHeadNum,
                    options.EncoderType,
                    srcVocab,
                    clsVocab,
                    options.MaxSegmentNum);

                CreateTrainableParameters(m_modelMetaData);
            }
            else
            {
                // The stored model already includes its vocabulary, so callers must not pass one in.
                if (srcVocab != null || clsVocab != null)
                {
                    throw new ArgumentException($"Model '{modelPath}' exists and it includes vocabulary, so input vocabulary must be null.");
                }

                // Load the existing model from file.
                m_modelMetaData = LoadModelImpl_WITH_CONVERT(CreateTrainableParameters);
            }

            m_modelMetaData.ShowModelInfo();
        }
예제 #2
0
 /// <summary>
 /// Creates the metadata object and stores each argument in the property of the same name.
 /// </summary>
 /// <param name="hiddenDim">Value stored in <c>HiddenDim</c>.</param>
 /// <param name="embeddingDim">Value stored in <c>EmbeddingDim</c>.</param>
 /// <param name="encoderLayerDepth">Value stored in <c>EncoderLayerDepth</c>.</param>
 /// <param name="multiHeadNum">Value stored in <c>MultiHeadNum</c>.</param>
 /// <param name="encoderType">Value stored in <c>EncoderType</c>.</param>
 /// <param name="vocab">Value stored in <c>Vocab</c>.</param>
 public SeqLabelModelMetaData(
     int hiddenDim,
     int embeddingDim,
     int encoderLayerDepth,
     int multiHeadNum,
     EncoderTypeEnums encoderType,
     Vocab vocab)
 {
     this.HiddenDim = hiddenDim;
     this.EmbeddingDim = embeddingDim;
     this.EncoderLayerDepth = encoderLayerDepth;
     this.MultiHeadNum = multiHeadNum;
     this.EncoderType = encoderType;
     this.Vocab = vocab;
 }
예제 #3
0
        /// <summary>
        /// Creates a sequence labeling network: stores the dropout ratio, builds the model
        /// metadata from the given dimensions and vocabulary, and creates the trainable parameters.
        /// </summary>
        /// <param name="hiddenDim">Forwarded to the model metadata.</param>
        /// <param name="embeddingDim">Forwarded to the model metadata.</param>
        /// <param name="encoderLayerDepth">Forwarded to the model metadata.</param>
        /// <param name="multiHeadNum">Forwarded to the model metadata.</param>
        /// <param name="encoderType">Forwarded to the model metadata.</param>
        /// <param name="dropoutRatio">Value stored in <c>m_dropoutRatio</c>.</param>
        /// <param name="vocab">Forwarded to the model metadata.</param>
        /// <param name="deviceIds">Forwarded to the base constructor.</param>
        /// <param name="processorType">Forwarded to the base constructor.</param>
        /// <param name="modelFilePath">Forwarded to the base constructor.</param>
        public SequenceLabel(int hiddenDim, int embeddingDim, int encoderLayerDepth, int multiHeadNum, EncoderTypeEnums encoderType,
                             float dropoutRatio, Vocab vocab, int[] deviceIds, ProcessorTypeEnums processorType, string modelFilePath) :
            base(deviceIds, processorType, modelFilePath)
        {
            m_dropoutRatio = dropoutRatio;

            // NOTE(review): the literal 0 presumably fills the decoder layer depth slot,
            // since this network only has an encoder - confirm against Seq2SeqModelMetaData.
            m_modelMetaData = new Seq2SeqModelMetaData(
                hiddenDim,
                embeddingDim,
                encoderLayerDepth,
                0,
                multiHeadNum,
                encoderType,
                vocab);

            // Create and initialize the network weights described by the metadata.
            CreateTrainableParameters(m_modelMetaData);
        }
예제 #4
0
 /// <summary>
 /// Creates the metadata object and stores each argument in the property of the same name.
 /// </summary>
 /// <param name="hiddenDim">Value stored in <c>HiddenDim</c>.</param>
 /// <param name="embeddingDim">Value stored in <c>EmbeddingDim</c>.</param>
 /// <param name="encoderLayerDepth">Value stored in <c>EncoderLayerDepth</c>.</param>
 /// <param name="decoderLayerDepth">Value stored in <c>DecoderLayerDepth</c>.</param>
 /// <param name="multiHeadNum">Value stored in <c>MultiHeadNum</c>.</param>
 /// <param name="encoderType">Value stored in <c>EncoderType</c>.</param>
 /// <param name="vocab">Value stored in <c>Vocab</c>.</param>
 /// <param name="enableCoverageModel">Value stored in <c>EnableCoverageModel</c>.</param>
 public Seq2SeqModelMetaData(
     int hiddenDim,
     int embeddingDim,
     int encoderLayerDepth,
     int decoderLayerDepth,
     int multiHeadNum,
     EncoderTypeEnums encoderType,
     Vocab vocab,
     bool enableCoverageModel)
 {
     this.HiddenDim = hiddenDim;
     this.EmbeddingDim = embeddingDim;
     this.EncoderLayerDepth = encoderLayerDepth;
     this.DecoderLayerDepth = decoderLayerDepth;
     this.MultiHeadNum = multiHeadNum;
     this.EncoderType = encoderType;
     this.Vocab = vocab;
     this.EnableCoverageModel = enableCoverageModel;
 }
예제 #5
0
 /// <summary>
 /// Creates a sequence-to-sequence model description. Encoder-side and shared settings are
 /// forwarded to the base constructor; decoder-side settings are stored on this instance.
 /// </summary>
 /// <param name="hiddenDim">Forwarded to the base constructor.</param>
 /// <param name="encoderEmbeddingDim">Forwarded to the base constructor.</param>
 /// <param name="decoderEmbeddingDim">Value stored in <c>DecoderEmbeddingDim</c>.</param>
 /// <param name="encoderLayerDepth">Forwarded to the base constructor.</param>
 /// <param name="decoderLayerDepth">Value stored in <c>DecoderLayerDepth</c>.</param>
 /// <param name="multiHeadNum">Forwarded to the base constructor and also stored in <c>MultiHeadNum</c>.</param>
 /// <param name="encoderType">Forwarded to the base constructor.</param>
 /// <param name="decoderType">Value stored in <c>DecoderType</c>.</param>
 /// <param name="srcVocab">Forwarded to the base constructor.</param>
 /// <param name="tgtVocab">Value stored in <c>TgtVocab</c>.</param>
 /// <param name="enableCoverageModel">Value stored in <c>EnableCoverageModel</c>.</param>
 /// <param name="sharedEmbeddings">Value stored in <c>SharedEmbeddings</c>.</param>
 /// <param name="enableSegmentEmbeddings">Forwarded to the base constructor.</param>
 /// <param name="enableTagEmbeddings">Forwarded to the base constructor.</param>
 /// <param name="maxSegmentNum">Forwarded to the base constructor.</param>
 /// <param name="pointerGenerator">Forwarded to the base constructor.</param>
 public Seq2SeqModel(int hiddenDim, int encoderEmbeddingDim, int decoderEmbeddingDim, int encoderLayerDepth, int decoderLayerDepth, int multiHeadNum,
                     EncoderTypeEnums encoderType, DecoderTypeEnums decoderType, Vocab srcVocab, Vocab tgtVocab, bool enableCoverageModel,
                     bool sharedEmbeddings, bool enableSegmentEmbeddings, bool enableTagEmbeddings, int maxSegmentNum, bool pointerGenerator)
     : base(
         hiddenDim,
         encoderLayerDepth,
         encoderType,
         encoderEmbeddingDim,
         multiHeadNum,
         srcVocab,
         enableSegmentEmbeddings,
         enableTagEmbeddings,
         maxSegmentNum,
         pointerGenerator)
 {
     // Decoder-specific settings that are not passed to the base constructor.
     this.DecoderEmbeddingDim = decoderEmbeddingDim;
     this.DecoderLayerDepth = decoderLayerDepth;
     this.MultiHeadNum = multiHeadNum;
     this.DecoderType = decoderType;
     this.EnableCoverageModel = enableCoverageModel;
     this.SharedEmbeddings = sharedEmbeddings;
     this.TgtVocab = tgtVocab;
 }
예제 #6
0
        /// <summary>
        /// Creates an attention-based sequence-to-sequence network: builds the model metadata,
        /// stores the training flags, creates the trainable parameters and, when embedding file
        /// paths are given, loads external word embeddings once per entry in <c>DeviceIds</c>.
        /// </summary>
        /// <param name="embeddingDim">Forwarded to the model metadata.</param>
        /// <param name="hiddenDim">Forwarded to the model metadata.</param>
        /// <param name="encoderLayerDepth">Forwarded to the model metadata.</param>
        /// <param name="decoderLayerDepth">Forwarded to the model metadata.</param>
        /// <param name="vocab">Forwarded to the model metadata.</param>
        /// <param name="srcEmbeddingFilePath">Optional source-side embedding file; skipped when null or empty.</param>
        /// <param name="tgtEmbeddingFilePath">Optional target-side embedding file; skipped when null or empty.</param>
        /// <param name="modelFilePath">Forwarded to the base constructor.</param>
        /// <param name="dropoutRatio">Value stored in <c>m_dropoutRatio</c>.</param>
        /// <param name="multiHeadNum">Forwarded to the model metadata.</param>
        /// <param name="processorType">Forwarded to the base constructor.</param>
        /// <param name="encoderType">Forwarded to the model metadata.</param>
        /// <param name="decoderType">Forwarded to the model metadata.</param>
        /// <param name="enableCoverageModel">Forwarded to the model metadata.</param>
        /// <param name="deviceIds">Forwarded to the base constructor.</param>
        /// <param name="isSrcEmbTrainable">Value stored in <c>m_isSrcEmbTrainable</c>.</param>
        /// <param name="isTgtEmbTrainable">Value stored in <c>m_isTgtEmbTrainable</c>.</param>
        /// <param name="isEncoderTrainable">Value stored in <c>m_isEncoderTrainable</c>.</param>
        /// <param name="isDecoderTrainable">Value stored in <c>m_isDecoderTrainable</c>.</param>
        /// <param name="maxTgtSntSize">Value stored in <c>m_maxTgtSntSize</c>.</param>
        public AttentionSeq2Seq(int embeddingDim, int hiddenDim, int encoderLayerDepth, int decoderLayerDepth, Vocab vocab, string srcEmbeddingFilePath, string tgtEmbeddingFilePath,
                                string modelFilePath, float dropoutRatio, int multiHeadNum, ProcessorTypeEnums processorType, EncoderTypeEnums encoderType, DecoderTypeEnums decoderType, bool enableCoverageModel, int[] deviceIds,
                                bool isSrcEmbTrainable = true, bool isTgtEmbTrainable = true, bool isEncoderTrainable = true, bool isDecoderTrainable = true, int maxTgtSntSize = 128)
            : base(deviceIds, processorType, modelFilePath)
        {
            m_dropoutRatio = dropoutRatio;
            m_modelMetaData = new Seq2SeqModelMetaData(
                hiddenDim,
                embeddingDim,
                encoderLayerDepth,
                decoderLayerDepth,
                multiHeadNum,
                encoderType,
                decoderType,
                vocab,
                enableCoverageModel);

            m_maxTgtSntSize = maxTgtSntSize;
            m_isSrcEmbTrainable = isSrcEmbTrainable;
            m_isTgtEmbTrainable = isTgtEmbTrainable;
            m_isEncoderTrainable = isEncoderTrainable;
            m_isDecoderTrainable = isDecoderTrainable;

            // Create and initialize the encoder/decoder weights described by the metadata.
            CreateTrainableParameters(m_modelMetaData);

            // Whether a pre-trained embedding file was specified does not change between iterations.
            bool loadSrcEmbedding = !string.IsNullOrEmpty(srcEmbeddingFilePath);
            bool loadTgtEmbedding = !string.IsNullOrEmpty(tgtEmbeddingFilePath);

            // Load the external embeddings into the network copy at every index of DeviceIds.
            for (int deviceIdx = 0; deviceIdx < DeviceIds.Length; deviceIdx++)
            {
                if (loadSrcEmbedding)
                {
                    Logger.WriteLine($"Loading ExtEmbedding model from '{srcEmbeddingFilePath}' for source side.");
                    LoadWordEmbedding(srcEmbeddingFilePath, m_srcEmbedding.GetNetworkOnDevice(deviceIdx), m_modelMetaData.Vocab.SrcWordToIndex);
                }

                if (loadTgtEmbedding)
                {
                    Logger.WriteLine($"Loading ExtEmbedding model from '{tgtEmbeddingFilePath}' for target side.");
                    LoadWordEmbedding(tgtEmbeddingFilePath, m_tgtEmbedding.GetNetworkOnDevice(deviceIdx), m_modelMetaData.Vocab.TgtWordToIndex);
                }
            }
        }
예제 #7
0
 /// <summary>
 /// Creates a sequence labeling model description. All settings except the class vocabulary
 /// are forwarded to the base constructor; the class vocabulary is stored in <c>ClsVocab</c>.
 /// </summary>
 /// <param name="hiddenDim">Forwarded to the base constructor.</param>
 /// <param name="embeddingDim">Forwarded to the base constructor.</param>
 /// <param name="encoderLayerDepth">Forwarded to the base constructor.</param>
 /// <param name="multiHeadNum">Forwarded to the base constructor.</param>
 /// <param name="encoderType">Forwarded to the base constructor.</param>
 /// <param name="srcVocab">Forwarded to the base constructor.</param>
 /// <param name="clsVocab">Value stored in <c>ClsVocab</c>.</param>
 /// <param name="maxSegmentNum">Forwarded to the base constructor.</param>
 public SeqLabelModel(int hiddenDim, int embeddingDim, int encoderLayerDepth, int multiHeadNum, EncoderTypeEnums encoderType, Vocab srcVocab, Vocab clsVocab, int maxSegmentNum)
     : base(
         hiddenDim,
         encoderLayerDepth,
         encoderType,
         embeddingDim,
         multiHeadNum,
         srcVocab,
         false,   // NOTE(review): presumably the segment-embeddings flag - confirm against the base constructor
         false,   // NOTE(review): presumably the tag-embeddings flag - confirm against the base constructor
         maxSegmentNum,
         false)   // NOTE(review): presumably the pointer-generator flag - confirm against the base constructor
 {
     this.ClsVocab = clsVocab;
 }