Exemple #1
0
        /// <summary>
        /// Creates the trainable components held by this instance: the encoder, the
        /// "FeedForward_Encoder" layer, the position/segment embeddings and the source embeddings.
        /// </summary>
        /// <param name="model">Model metadata; supplies the vocabularies and the encoder embedding dimension.</param>
        /// <returns>Always <c>true</c>.</returns>
        private bool CreateTrainableParameters(IModel model)
        {
            Logger.WriteLine("Creating encoders...");

            // All components draw their device id from this single instance, so the
            // order of the GetNextItem() calls below is kept exactly as it was.
            RoundArray<int> deviceIdPool = new RoundArray<int>(DeviceIds);

            var (encoders, contextDim) = Encoder.CreateEncoders(model, m_options, deviceIdPool);
            m_encoder = encoders;

            // Feed-forward layer built from the encoder context dimension and the class vocabulary size.
            var encoderFFLayer = new FeedForwardLayer("FeedForward_Encoder", contextDim, model.ClsVocab.Count, dropoutRatio: 0.0f, deviceId: deviceIdPool.GetNextItem(), isTrainable: true);
            m_encoderFFLayer = new MultiProcessorNetworkWrapper<IFeedForwardLayer>(encoderFFLayer, DeviceIds);

            // The auxiliary embeddings receive the larger of the train/test sentence length limits.
            var maxSentLength = Math.Max(m_options.MaxTrainSentLength, m_options.MaxTestSentLength);
            (m_posEmbedding, m_segmentEmbedding) = Misc.CreateAuxEmbeddings(deviceIdPool, contextDim, maxSentLength, model);

            Logger.WriteLine($"Creating embeddings. Shape = '({model.SrcVocab.Count} ,{model.EncoderEmbeddingDim})'");

            long[] embeddingShape = { model.SrcVocab.Count, model.EncoderEmbeddingDim };
            var srcEmbeddingWeights = new WeightTensor(embeddingShape, deviceIdPool.GetNextItem(), normType: NormType.Uniform, fanOut: true, name: "SrcEmbeddings", isTrainable: m_options.IsEmbeddingTrainable);
            m_srcEmbedding = new MultiProcessorNetworkWrapper<IWeightTensor>(srcEmbeddingWeights, DeviceIds);

            return true;
        }
Exemple #2
0
        /// <summary>
        /// Builds the decoder selected by <paramref name="modelMetaData"/> and wraps it in a
        /// <see cref="MultiProcessorNetworkWrapper{T}"/> over the device ids in <paramref name="raDeviceIds"/>.
        /// </summary>
        /// <param name="modelMetaData">Model metadata; <c>DecoderType</c> chooses the implementation.</param>
        /// <param name="options">Options supplying the dropout ratio, trainable flag and decoder learning rate factor.</param>
        /// <param name="raDeviceIds">Source of the device id for the new decoder and of the device id array for the wrapper.</param>
        /// <param name="contextDim">Context dimension; only passed to the attention LSTM decoder.</param>
        /// <returns>
        /// A wrapped <c>AttentionDecoder</c> when the decoder type is <c>AttentionLSTM</c>;
        /// a wrapped <c>TransformerDecoder</c> for every other decoder type.
        /// </returns>
        public static MultiProcessorNetworkWrapper<IDecoder> CreateDecoders(IModel modelMetaData, Seq2SeqOptions options, RoundArray<int> raDeviceIds, int contextDim)
        {
            if (modelMetaData.DecoderType == DecoderTypeEnums.AttentionLSTM)
            {
                var lstmDecoder = new AttentionDecoder(
                    "AttnLSTMDecoder",
                    modelMetaData.HiddenDim,
                    modelMetaData.DecoderEmbeddingDim,
                    contextDim,
                    options.DropoutRatio,
                    modelMetaData.DecoderLayerDepth,
                    raDeviceIds.GetNextItem(),
                    modelMetaData.EnableCoverageModel,
                    isTrainable: options.IsDecoderTrainable);

                return new MultiProcessorNetworkWrapper<IDecoder>(lstmDecoder, raDeviceIds.ToArray());
            }

            // Every decoder type other than AttentionLSTM falls through to the transformer decoder.
            var transformerDecoder = new TransformerDecoder(
                "TransformerDecoder",
                modelMetaData.MultiHeadNum,
                modelMetaData.HiddenDim,
                modelMetaData.DecoderEmbeddingDim,
                modelMetaData.DecoderLayerDepth,
                options.DropoutRatio,
                raDeviceIds.GetNextItem(),
                isTrainable: options.IsDecoderTrainable,
                learningRateFactor: options.DecoderStartLearningRateFactor);

            return new MultiProcessorNetworkWrapper<IDecoder>(transformerDecoder, raDeviceIds.ToArray());
        }
Exemple #3
0
        /// <summary>
        /// Builds the encoder selected by <paramref name="modelMetaData"/>, wraps it in a
        /// <see cref="MultiProcessorNetworkWrapper{T}"/>, and returns it with the matching context dimension.
        /// </summary>
        /// <param name="modelMetaData">Model metadata; <c>EncoderType</c> chooses the implementation.</param>
        /// <param name="options">Options supplying the trainable flag, dropout ratio and encoder learning rate factor.</param>
        /// <param name="raDeviceIds">Source of the device id for the new encoder and of the device id array for the wrapper.</param>
        /// <returns>
        /// The wrapped encoder and the context dimension: <c>HiddenDim * 2</c> for the BiLSTM encoder,
        /// <c>HiddenDim</c> for the transformer encoder (used for every non-BiLSTM encoder type).
        /// </returns>
        public static (MultiProcessorNetworkWrapper<IEncoder>, int) CreateEncoders(IModel modelMetaData, Options options, RoundArray<int> raDeviceIds)
        {
            if (modelMetaData.EncoderType == EncoderTypeEnums.BiLSTM)
            {
                var biLstmEncoder = new BiEncoder(
                    "BiLSTMEncoder",
                    modelMetaData.HiddenDim,
                    modelMetaData.EncoderEmbeddingDim,
                    modelMetaData.EncoderLayerDepth,
                    raDeviceIds.GetNextItem(),
                    isTrainable: options.IsEncoderTrainable);
                var wrappedBiLstm = new MultiProcessorNetworkWrapper<IEncoder>(biLstmEncoder, raDeviceIds.ToArray());

                return (wrappedBiLstm, modelMetaData.HiddenDim * 2);
            }

            // Every encoder type other than BiLSTM falls through to the transformer encoder.
            var transformerEncoder = new TransformerEncoder(
                "TransformerEncoder",
                modelMetaData.MultiHeadNum,
                modelMetaData.HiddenDim,
                modelMetaData.EncoderEmbeddingDim,
                modelMetaData.EncoderLayerDepth,
                options.DropoutRatio,
                raDeviceIds.GetNextItem(),
                isTrainable: options.IsEncoderTrainable,
                learningRateFactor: options.EncoderStartLearningRateFactor);
            var wrappedTransformer = new MultiProcessorNetworkWrapper<IEncoder>(transformerEncoder, raDeviceIds.ToArray());

            return (wrappedTransformer, modelMetaData.HiddenDim);
        }
        /// <summary>
        /// Creates the trainable components held by this instance: encoder, decoder, the two
        /// feed-forward layers, the position/segment embeddings and the source/target embeddings.
        /// </summary>
        /// <param name="model">Model metadata; supplies the hidden dimension and the class/target vocabularies.</param>
        /// <returns>Always <c>true</c>.</returns>
        private bool CreateTrainableParameters(IModel model)
        {
            Logger.WriteLine("Creating encoders and decoders...");

            // All components draw their device id from this single instance, so the
            // order of the GetNextItem() calls below is kept exactly as it was.
            RoundArray<int> deviceIdPool = new RoundArray<int>(DeviceIds);

            var (encoders, contextDim) = Encoder.CreateEncoders(model, m_options, deviceIdPool);
            m_encoder = encoders;
            m_decoder = Decoder.CreateDecoders(model, m_options, deviceIdPool, contextDim);

            // NOTE(review): this layer is built with model.HiddenDim rather than contextDim;
            // confirm that is intended for encoders whose context dimension differs from HiddenDim.
            var encoderFFLayer = new FeedForwardLayer("FeedForward_Encoder_0", model.HiddenDim, model.ClsVocab.Count, dropoutRatio: 0.0f, deviceId: deviceIdPool.GetNextItem(), isTrainable: true);
            m_encoderFFLayer = new MultiProcessorNetworkWrapper<IFeedForwardLayer>(encoderFFLayer, DeviceIds);

            var decoderFFLayer = new FeedForwardLayer("FeedForward_Decoder_0", model.HiddenDim, model.TgtVocab.Count, dropoutRatio: 0.0f, deviceId: deviceIdPool.GetNextItem(), isTrainable: true);
            m_decoderFFLayer = new MultiProcessorNetworkWrapper<IFeedForwardLayer>(decoderFFLayer, DeviceIds);

            // The auxiliary embeddings receive the largest of the four configured sentence length limits.
            var maxSrcSentLength = Math.Max(m_options.MaxTrainSrcSentLength, m_options.MaxTestSrcSentLength);
            var maxTgtSentLength = Math.Max(m_options.MaxTrainTgtSentLength, m_options.MaxTestTgtSentLength);
            var maxSentLength = Math.Max(maxSrcSentLength, maxTgtSentLength);
            (m_posEmbedding, m_segmentEmbedding) = Misc.CreateAuxEmbeddings(deviceIdPool, contextDim, maxSentLength, model);

            (m_srcEmbedding, m_tgtEmbedding) = CreateSrcTgtEmbeddings(
                model,
                deviceIdPool,
                m_options.IsSrcEmbeddingTrainable,
                m_options.IsTgtEmbeddingTrainable,
                m_options.EncoderStartLearningRateFactor,
                m_options.DecoderStartLearningRateFactor);

            return true;
        }