Example #1
0
        /// <summary>
        /// Creates the encoder, the output feed-forward layer and the source embedding tensor
        /// described by <paramref name="mmd"/>. Each network is wrapped in a
        /// <c>MultiProcessorNetworkWrapper</c> built over <c>DeviceIds</c>.
        /// </summary>
        /// <param name="mmd">Model metadata; must be a <c>Seq2SeqModelMetaData</c> instance.</param>
        /// <returns>Always <c>true</c>.</returns>
        /// <exception cref="System.ArgumentException">
        /// <paramref name="mmd"/> is null or is not a <c>Seq2SeqModelMetaData</c>.
        /// </exception>
        private bool CreateTrainableParameters(IModelMetaData mmd)
        {
            Logger.WriteLine($"Creating encoders and decoders...");

            Seq2SeqModelMetaData modelMetaData = mmd as Seq2SeqModelMetaData;
            if (modelMetaData == null)
            {
                // Fail with a descriptive error instead of a NullReferenceException on first use.
                throw new System.ArgumentException($"Model metadata must be a non-null {nameof(Seq2SeqModelMetaData)} instance.", nameof(mmd));
            }

            RoundArray<int> raDeviceIds = new RoundArray<int>(DeviceIds);

            // Device ids are handed out in a fixed order: encoder, feed-forward layer, source embedding.
            int ffInputDim;
            if (modelMetaData.EncoderType == EncoderTypeEnums.BiLSTM)
            {
                m_encoder = new MultiProcessorNetworkWrapper<IEncoder>(
                    new BiEncoder("BiLSTMEncoder", modelMetaData.HiddenDim, modelMetaData.EmbeddingDim, modelMetaData.EncoderLayerDepth, raDeviceIds.GetNextItem()), DeviceIds);

                // The BiLSTM variant feeds the output layer with twice the hidden size
                // (presumably forward and backward states concatenated - TODO confirm).
                ffInputDim = modelMetaData.HiddenDim * 2;
            }
            else
            {
                // Every encoder type other than BiLSTM is built as a transformer encoder.
                m_encoder = new MultiProcessorNetworkWrapper<IEncoder>(
                    new TransformerEncoder("TransformerEncoder", modelMetaData.MultiHeadNum, modelMetaData.HiddenDim, modelMetaData.EmbeddingDim, modelMetaData.EncoderLayerDepth, m_dropoutRatio, raDeviceIds.GetNextItem()), DeviceIds);

                ffInputDim = modelMetaData.HiddenDim;
            }

            m_decoderFFLayer = new MultiProcessorNetworkWrapper<FeedForwardLayer>(
                new FeedForwardLayer("FeedForward", ffInputDim, modelMetaData.Vocab.TargetWordSize, dropoutRatio: 0.0f, deviceId: raDeviceIds.GetNextItem()), DeviceIds);

            // Source embedding table of shape (source vocabulary size, embedding dimension).
            m_srcEmbedding = new MultiProcessorNetworkWrapper<IWeightTensor>(new WeightTensor(new long[2] {
                modelMetaData.Vocab.SourceWordSize, modelMetaData.EmbeddingDim
            }, raDeviceIds.GetNextItem(), normal: true, name: "SrcEmbeddings", isTrainable: true), DeviceIds);

            // The CRF decoder is currently disabled:
            //      m_crfDecoder = new CRFDecoder(modelMetaData.Vocab.TargetWordSize);

            return true;
        }
Example #2
0
        /// <summary>
        /// Creates the encoder, the output feed-forward layer, the source embedding tensor and,
        /// for transformer encoders only, the position embedding tensor described by
        /// <paramref name="mmd"/>. Each network is wrapped in a
        /// <c>MultiProcessorNetworkWrapper</c> built over <c>DeviceIds</c>.
        /// </summary>
        /// <param name="mmd">Model metadata; must be a <c>SeqLabelModelMetaData</c> instance.</param>
        /// <returns>Always <c>true</c>.</returns>
        /// <exception cref="ArgumentException">
        /// <paramref name="mmd"/> is null or is not a <c>SeqLabelModelMetaData</c>.
        /// </exception>
        private bool CreateTrainableParameters(IModelMetaData mmd)
        {
            Logger.WriteLine($"Creating encoders and decoders...");

            var modelMetaData = mmd as SeqLabelModelMetaData;
            if (modelMetaData == null)
            {
                // Fail with a descriptive error instead of a NullReferenceException on first use.
                throw new ArgumentException($"Model metadata must be a non-null {nameof(SeqLabelModelMetaData)} instance.", nameof(mmd));
            }

            var raDeviceIds = new RoundArray<int>(this.DeviceIds);

            // Device ids are handed out in a fixed order: encoder, feed-forward layer,
            // source embedding, then (transformer only) position embedding.
            if (modelMetaData.EncoderType == EncoderTypeEnums.BiLSTM)
            {
                this.m_encoder = new MultiProcessorNetworkWrapper<IEncoder>(
                    new BiEncoder("BiLSTMEncoder", modelMetaData.HiddenDim, modelMetaData.EmbeddingDim, modelMetaData.EncoderLayerDepth, raDeviceIds.GetNextItem(), true), this.DeviceIds);

                // The BiLSTM variant feeds the output layer with twice the hidden size.
                this.m_decoderFFLayer = new MultiProcessorNetworkWrapper<FeedForwardLayer>(
                    new FeedForwardLayer("FeedForward", modelMetaData.HiddenDim * 2, modelMetaData.Vocab.TargetWordSize, 0.0f, raDeviceIds.GetNextItem(), true), this.DeviceIds);
            }
            else
            {
                // Every encoder type other than BiLSTM is built as a transformer encoder.
                this.m_encoder = new MultiProcessorNetworkWrapper<IEncoder>(
                    new TransformerEncoder("TransformerEncoder", modelMetaData.MultiHeadNum, modelMetaData.HiddenDim, modelMetaData.EmbeddingDim, modelMetaData.EncoderLayerDepth, this.m_dropoutRatio, raDeviceIds.GetNextItem(), true), this.DeviceIds);

                this.m_decoderFFLayer = new MultiProcessorNetworkWrapper<FeedForwardLayer>(
                    new FeedForwardLayer("FeedForward", modelMetaData.HiddenDim, modelMetaData.Vocab.TargetWordSize, 0.0f, raDeviceIds.GetNextItem(), true), this.DeviceIds);
            }

            // Source embedding table of shape (source vocabulary size, embedding dimension).
            this.m_srcEmbedding = new MultiProcessorNetworkWrapper<IWeightTensor>(new WeightTensor(new long[2] {
                modelMetaData.Vocab.SourceWordSize, modelMetaData.EmbeddingDim
            }, raDeviceIds.GetNextItem(), normal: NormType.Normal, name: "SrcEmbeddings", isTrainable: true), this.DeviceIds);

            // The CRF decoder is currently disabled:
            //      m_crfDecoder = new CRFDecoder(modelMetaData.Vocab.TargetWordSize);

            if (modelMetaData.EncoderType == EncoderTypeEnums.Transformer)
            {
                // There is a single sentence-length limit here, so the table is sized from it
                // directly; the "+ 2" leaves two extra rows (presumably for begin/end markers -
                // TODO confirm).
                this.m_posEmbedding = new MultiProcessorNetworkWrapper<IWeightTensor>(
                    this.BuildPositionWeightTensor(this.m_maxSntSize + 2, modelMetaData.EmbeddingDim, raDeviceIds.GetNextItem(), "PosEmbedding", false), this.DeviceIds, true);
            }
            else
            {
                this.m_posEmbedding = null;
            }

            return true;
        }
Example #3
0
        /// <summary>
        /// Creates the encoder, the attention decoder and the source/target embedding tensors
        /// described by <paramref name="mmd"/>. Each network is wrapped in a
        /// <c>MultiProcessorNetworkWrapper</c> built over <c>DeviceIds</c>; trainability of each
        /// part is taken from the corresponding <c>m_is*Trainable</c> field.
        /// </summary>
        /// <param name="mmd">Model metadata; must be a <c>Seq2SeqModelMetaData</c> instance.</param>
        /// <returns>Always <c>true</c>.</returns>
        /// <exception cref="System.ArgumentException">
        /// <paramref name="mmd"/> is null or is not a <c>Seq2SeqModelMetaData</c>.
        /// </exception>
        private bool CreateTrainableParameters(IModelMetaData mmd)
        {
            Logger.WriteLine($"Creating encoders and decoders...");

            Seq2SeqModelMetaData modelMetaData = mmd as Seq2SeqModelMetaData;
            if (modelMetaData == null)
            {
                // Fail with a descriptive error instead of a NullReferenceException on first use.
                throw new System.ArgumentException($"Model metadata must be a non-null {nameof(Seq2SeqModelMetaData)} instance.", nameof(mmd));
            }

            RoundArray<int> raDeviceIds = new RoundArray<int>(DeviceIds);

            // Device ids are handed out in a fixed order: encoder, decoder, source embedding,
            // target embedding.
            int contextDim;
            if (modelMetaData.EncoderType == EncoderTypeEnums.BiLSTM)
            {
                m_encoder = new MultiProcessorNetworkWrapper<IEncoder>(
                    new BiEncoder("BiLSTMEncoder", modelMetaData.HiddenDim, modelMetaData.EmbeddingDim, modelMetaData.EncoderLayerDepth, raDeviceIds.GetNextItem(), isTrainable: m_isEncoderTrainable), DeviceIds);

                // The BiLSTM variant gives the decoder a context of twice the hidden size.
                contextDim = modelMetaData.HiddenDim * 2;
            }
            else
            {
                // Every encoder type other than BiLSTM is built as a transformer encoder.
                m_encoder = new MultiProcessorNetworkWrapper<IEncoder>(
                    new TransformerEncoder("TransformerEncoder", modelMetaData.MultiHeadNum, modelMetaData.HiddenDim, modelMetaData.EmbeddingDim, modelMetaData.EncoderLayerDepth, m_dropoutRatio, raDeviceIds.GetNextItem(),
                                           isTrainable: m_isEncoderTrainable), DeviceIds);

                contextDim = modelMetaData.HiddenDim;
            }

            // The decoder is identical for both encoder types except for the context dimension.
            m_decoder = new MultiProcessorNetworkWrapper<AttentionDecoder>(
                new AttentionDecoder("AttnLSTMDecoder", modelMetaData.HiddenDim, modelMetaData.EmbeddingDim, contextDim,
                                     modelMetaData.Vocab.TargetWordSize, m_dropoutRatio, modelMetaData.DecoderLayerDepth, raDeviceIds.GetNextItem(), modelMetaData.EnableCoverageModel, isTrainable: m_isDecoderTrainable), DeviceIds);

            // Embedding tables of shape (vocabulary size, embedding dimension).
            m_srcEmbedding = new MultiProcessorNetworkWrapper<IWeightTensor>(new WeightTensor(new long[2] {
                modelMetaData.Vocab.SourceWordSize, modelMetaData.EmbeddingDim
            }, raDeviceIds.GetNextItem(), normal: true, name: "SrcEmbeddings", isTrainable: m_isSrcEmbTrainable), DeviceIds);
            m_tgtEmbedding = new MultiProcessorNetworkWrapper<IWeightTensor>(new WeightTensor(new long[2] {
                modelMetaData.Vocab.TargetWordSize, modelMetaData.EmbeddingDim
            }, raDeviceIds.GetNextItem(), normal: true, name: "TgtEmbeddings", isTrainable: m_isTgtEmbTrainable), DeviceIds);

            return true;
        }
Example #4
0
        /// <summary>
        /// Creates the encoder, the classification feed-forward layer, the source embedding
        /// tensor and, for transformer encoders only, the position embedding tensor for
        /// <paramref name="model"/>.
        /// </summary>
        /// <param name="model">Model description supplying vocabularies, dimensions and encoder type.</param>
        /// <returns>Always <c>true</c>.</returns>
        private bool CreateTrainableParameters(IModel model)
        {
            Logger.WriteLine($"Creating encoders and decoders...");
            var deviceRing = new RoundArray<int>(DeviceIds);

            // The encoder factory also reports the dimension of the context it produces.
            int contextDim;
            (m_encoder, contextDim) = Encoder.CreateEncoders(model, m_options, deviceRing);

            // Output layer maps the encoder context onto the classification vocabulary.
            var feedForward = new FeedForwardLayer("FeedForward", contextDim, model.ClsVocab.Count, dropoutRatio: 0.0f, deviceId: deviceRing.GetNextItem(), isTrainable: true);
            m_ffLayer = new MultiProcessorNetworkWrapper<FeedForwardLayer>(feedForward, DeviceIds);

            // Source embedding table of shape (source vocabulary size, encoder embedding dimension).
            var srcEmbeddingWeights = new WeightTensor(
                new long[] { model.SrcVocab.Count, model.EncoderEmbeddingDim },
                deviceRing.GetNextItem(),
                normType: NormType.Uniform,
                name: "SrcEmbeddings",
                isTrainable: true);
            m_srcEmbedding = new MultiProcessorNetworkWrapper<IWeightTensor>(srcEmbeddingWeights, DeviceIds);

            // Position embeddings are only built for the transformer encoder; otherwise the field is cleared.
            m_posEmbedding = model.EncoderType == EncoderTypeEnums.Transformer
                ? new MultiProcessorNetworkWrapper<IWeightTensor>(
                    PositionEmbedding.BuildPositionWeightTensor(m_options.MaxTestSentLength + 2, model.EncoderEmbeddingDim, deviceRing.GetNextItem(), "PosEmbedding", false),
                    DeviceIds,
                    true)
                : null;

            return true;
        }
Example #5
0
        /// <summary>
        /// Creates the encoder, decoder, decoder output layer, auxiliary (position/segment)
        /// embeddings, source/target embeddings and, when enabled on <paramref name="model"/>,
        /// the pointer-generator layer.
        /// </summary>
        /// <param name="model">Model description supplying vocabularies, dimensions and feature flags.</param>
        /// <returns>Always <c>true</c>.</returns>
        /// <exception cref="ArgumentException">
        /// The pointer generator is enabled but <c>model.SharedEmbeddings</c> is false.
        /// </exception>
        private bool CreateTrainableParameters(IModel model)
        {
            Logger.WriteLine($"Creating encoders and decoders...");
            var deviceRing = new RoundArray<int>(DeviceIds);

            // The encoder factory also reports the dimension of the context it produces.
            int contextDim;
            (m_encoder, contextDim) = Encoder.CreateEncoders(model, m_options, deviceRing);
            m_decoder = Decoder.CreateDecoders(model, m_options, deviceRing, contextDim);

            var decoderOutputLayer = new FeedForwardLayer(
                "FeedForward_Decoder_0",
                model.HiddenDim,
                model.TgtVocab.Count,
                dropoutRatio: 0.0f,
                deviceId: deviceRing.GetNextItem(),
                isTrainable: true,
                learningRateFactor: m_options.DecoderStartLearningRateFactor);
            m_decoderFFLayer = new MultiProcessorNetworkWrapper<IFeedForwardLayer>(decoderOutputLayer, DeviceIds);

            // Auxiliary embeddings are sized by the longest configured sentence on either side,
            // across both training and test limits.
            var longestSrc = Math.Max(m_options.MaxTrainSrcSentLength, m_options.MaxTestSrcSentLength);
            var longestTgt = Math.Max(m_options.MaxTrainTgtSentLength, m_options.MaxTestTgtSentLength);
            (m_posEmbedding, m_segmentEmbedding) = Misc.CreateAuxEmbeddings(deviceRing, contextDim, Math.Max(longestSrc, longestTgt), model);

            (m_srcEmbedding, m_tgtEmbedding) = CreateSrcTgtEmbeddings(
                model,
                deviceRing,
                m_options.IsSrcEmbeddingTrainable,
                m_options.IsTgtEmbeddingTrainable,
                m_options.EncoderStartLearningRateFactor,
                m_options.DecoderStartLearningRateFactor);

            // Without a pointer generator there is nothing more to build.
            if (!model.PointerGenerator)
            {
                m_pointerGenerator = null;
                return true;
            }

            // The pointer generator is only allowed together with shared embeddings.
            if (model.SharedEmbeddings == false)
            {
                throw new ArgumentException($"Shared embeddings is required to true for pointer generator.");
            }

            Logger.WriteLine($"Create pointer generator weights...");
            var pointerLayer = new FeedForwardLayer(
                "PointerGenerator_0",
                model.HiddenDim,
                1,
                dropoutRatio: 0.0f,
                deviceId: deviceRing.GetNextItem(),
                isTrainable: true,
                learningRateFactor: m_options.DecoderStartLearningRateFactor);
            m_pointerGenerator = new MultiProcessorNetworkWrapper<IFeedForwardLayer>(pointerLayer, DeviceIds);

            return true;
        }
Example #6
0
        /// <summary>
        /// Creates the encoder, the decoder, the optional position embedding tensor and the
        /// source/target embedding tensors described by <paramref name="mmd"/>. Each network is
        /// wrapped in a <c>MultiProcessorNetworkWrapper</c> built over <c>DeviceIds</c>;
        /// trainability of each part is taken from the corresponding <c>m_is*Trainable</c> field.
        /// </summary>
        /// <param name="mmd">Model metadata; must be a <c>Seq2SeqModelMetaData</c> instance.</param>
        /// <returns>Always <c>true</c>.</returns>
        /// <exception cref="ArgumentException">
        /// <paramref name="mmd"/> is null or is not a <c>Seq2SeqModelMetaData</c>.
        /// </exception>
        private bool CreateTrainableParameters(IModelMetaData mmd)
        {
            Logger.WriteLine($"Creating encoders and decoders...");

            Seq2SeqModelMetaData modelMetaData = mmd as Seq2SeqModelMetaData;
            if (modelMetaData == null)
            {
                // Fail with a descriptive error instead of a NullReferenceException on first use.
                throw new ArgumentException($"Model metadata must be a non-null {nameof(Seq2SeqModelMetaData)} instance.", nameof(mmd));
            }

            RoundArray<int> raDeviceIds = new RoundArray<int>(DeviceIds);

            // Device ids are handed out in a fixed order: encoder, decoder, (optional) position
            // embedding, source embedding, target embedding.
            int contextDim;

            if (modelMetaData.EncoderType == EncoderTypeEnums.BiLSTM)
            {
                m_encoder = new MultiProcessorNetworkWrapper<IEncoder>(
                    new BiEncoder("BiLSTMEncoder", modelMetaData.HiddenDim, modelMetaData.SrcEmbeddingDim, modelMetaData.EncoderLayerDepth, raDeviceIds.GetNextItem(), isTrainable: m_isEncoderTrainable), DeviceIds);

                // The BiLSTM variant produces a context of twice the hidden size.
                contextDim = modelMetaData.HiddenDim * 2;
            }
            else
            {
                // Every encoder type other than BiLSTM is built as a transformer encoder.
                m_encoder = new MultiProcessorNetworkWrapper<IEncoder>(
                    new TransformerEncoder("TransformerEncoder", modelMetaData.MultiHeadNum, modelMetaData.HiddenDim, modelMetaData.SrcEmbeddingDim, modelMetaData.EncoderLayerDepth, m_dropoutRatio, raDeviceIds.GetNextItem(),
                                           isTrainable: m_isEncoderTrainable), DeviceIds);

                contextDim = modelMetaData.HiddenDim;
            }

            if (modelMetaData.DecoderType == DecoderTypeEnums.AttentionLSTM)
            {
                m_decoder = new MultiProcessorNetworkWrapper<IDecoder>(
                    new AttentionDecoder("AttnLSTMDecoder", modelMetaData.HiddenDim, modelMetaData.TgtEmbeddingDim, contextDim,
                                         modelMetaData.Vocab.TargetWordSize, m_dropoutRatio, modelMetaData.DecoderLayerDepth, raDeviceIds.GetNextItem(), modelMetaData.EnableCoverageModel, isTrainable: m_isDecoderTrainable), DeviceIds);
            }
            else
            {
                // The decoder stack is sized by DecoderLayerDepth, matching the AttentionLSTM branch.
                // NOTE(review): checkpoints saved while this branch was sized by EncoderLayerDepth
                // may hold a different number of decoder layers - confirm before loading them.
                m_decoder = new MultiProcessorNetworkWrapper<IDecoder>(
                    new TransformerDecoder("TransformerDecoder", modelMetaData.MultiHeadNum, modelMetaData.HiddenDim, modelMetaData.TgtEmbeddingDim, modelMetaData.Vocab.TargetWordSize, modelMetaData.DecoderLayerDepth, m_dropoutRatio, raDeviceIds.GetNextItem(),
                                           isTrainable: m_isDecoderTrainable), DeviceIds);
            }

            // Position embeddings are needed as soon as either side is a transformer.
            if (modelMetaData.EncoderType == EncoderTypeEnums.Transformer || modelMetaData.DecoderType == DecoderTypeEnums.Transformer)
            {
                // Sized by the longer of the source/target sentence limits plus two extra rows.
                m_posEmbedding = new MultiProcessorNetworkWrapper<IWeightTensor>(BuildPositionWeightTensor(Math.Max(m_maxSrcSntSize, m_maxTgtSntSize) + 2, contextDim, raDeviceIds.GetNextItem(), "PosEmbedding", false), DeviceIds, true);
            }
            else
            {
                m_posEmbedding = null;
            }

            Logger.WriteLine($"Creating embeddings for source side. Shape = '({modelMetaData.Vocab.SourceWordSize} ,{modelMetaData.SrcEmbeddingDim})'");
            m_srcEmbedding = new MultiProcessorNetworkWrapper<IWeightTensor>(new WeightTensor(new long[2] {
                modelMetaData.Vocab.SourceWordSize, modelMetaData.SrcEmbeddingDim
            }, raDeviceIds.GetNextItem(), normType: NormType.Uniform, fanOut: true, name: "SrcEmbeddings", isTrainable: m_isSrcEmbTrainable), DeviceIds);

            Logger.WriteLine($"Creating embeddings for target side. Shape = '({modelMetaData.Vocab.TargetWordSize} ,{modelMetaData.TgtEmbeddingDim})'");
            m_tgtEmbedding = new MultiProcessorNetworkWrapper<IWeightTensor>(new WeightTensor(new long[2] {
                modelMetaData.Vocab.TargetWordSize, modelMetaData.TgtEmbeddingDim
            }, raDeviceIds.GetNextItem(), normType: NormType.Uniform, fanOut: true, name: "TgtEmbeddings", isTrainable: m_isTgtEmbTrainable), DeviceIds);

            return true;
        }