示例#1
0
        /// <summary>
        /// Builds the model: records the model meta data and training options, creates the
        /// trainable parameters, and optionally loads external word embeddings onto every device.
        /// </summary>
        /// <param name="srcEmbeddingFilePath">Source-side embedding file; skipped when null or empty.</param>
        /// <param name="tgtEmbeddingFilePath">Target-side embedding file; skipped when null or empty.</param>
        public AttentionSeq2Seq(int embeddingDim, int hiddenDim, int encoderLayerDepth, int decoderLayerDepth, Vocab vocab, string srcEmbeddingFilePath, string tgtEmbeddingFilePath,
                                string modelFilePath, float dropoutRatio, int multiHeadNum, ProcessorTypeEnums processorType, EncoderTypeEnums encoderType, DecoderTypeEnums decoderType, bool enableCoverageModel, int[] deviceIds,
                                bool isSrcEmbTrainable = true, bool isTgtEmbTrainable = true, bool isEncoderTrainable = true, bool isDecoderTrainable = true, int maxTgtSntSize = 128)
            : base(deviceIds, processorType, modelFilePath)
        {
            // Describe the network shape first; the parameter creation below consumes it.
            m_modelMetaData = new Seq2SeqModelMetaData(hiddenDim, embeddingDim, encoderLayerDepth, decoderLayerDepth, multiHeadNum, encoderType, decoderType, vocab, enableCoverageModel);
            m_dropoutRatio = dropoutRatio;

            // Remember which parts of the model may be updated, and the target length limit.
            m_isSrcEmbTrainable = isSrcEmbTrainable;
            m_isTgtEmbTrainable = isTgtEmbTrainable;
            m_isEncoderTrainable = isEncoderTrainable;
            m_isDecoderTrainable = isDecoderTrainable;
            m_maxTgtSntSize = maxTgtSntSize;

            // Initialize the weights of the encoders and decoders.
            CreateTrainableParameters(m_modelMetaData);

            // Whether a pre-trained embedding file was specified does not change per device.
            bool hasSrcEmbeddingFile = !string.IsNullOrEmpty(srcEmbeddingFilePath);
            bool hasTgtEmbeddingFile = !string.IsNullOrEmpty(tgtEmbeddingFilePath);

            // Load the external embeddings into the embedding tensors of each device.
            for (int deviceIdx = 0; deviceIdx < DeviceIds.Length; deviceIdx++)
            {
                if (hasSrcEmbeddingFile)
                {
                    Logger.WriteLine($"Loading ExtEmbedding model from '{srcEmbeddingFilePath}' for source side.");
                    var srcEmbeddingOnDevice = m_srcEmbedding.GetNetworkOnDevice(deviceIdx);
                    LoadWordEmbedding(srcEmbeddingFilePath, srcEmbeddingOnDevice, m_modelMetaData.Vocab.SrcWordToIndex);
                }

                if (hasTgtEmbeddingFile)
                {
                    Logger.WriteLine($"Loading ExtEmbedding model from '{tgtEmbeddingFilePath}' for target side.");
                    var tgtEmbeddingOnDevice = m_tgtEmbedding.GetNetworkOnDevice(deviceIdx);
                    LoadWordEmbedding(tgtEmbeddingFilePath, tgtEmbeddingOnDevice, m_modelMetaData.Vocab.TgtWordToIndex);
                }
            }
        }
示例#2
0
 /// <summary>
 /// Fetches the per-device instance of each sub-network held by this model.
 /// </summary>
 /// <param name="deviceIdIdx">Index passed to each wrapper's GetNetworkOnDevice call.</param>
 /// <returns>
 /// A tuple of (encoder, decoder, decoder feed-forward layer, source embedding, target embedding,
 /// position embedding, segment embedding, pointer generator). The last three elements are null
 /// when the corresponding field is null.
 /// </returns>
 private (IEncoder, IDecoder, IFeedForwardLayer, IWeightTensor, IWeightTensor, IWeightTensor, IWeightTensor, IFeedForwardLayer) GetNetworksOnDeviceAt(int deviceIdIdx)
 {
     // Networks that are always present.
     IEncoder encoder = m_encoder.GetNetworkOnDevice(deviceIdIdx);
     IDecoder decoder = m_decoder.GetNetworkOnDevice(deviceIdIdx);
     IFeedForwardLayer decoderFFLayer = m_decoderFFLayer.GetNetworkOnDevice(deviceIdIdx);
     IWeightTensor srcEmbedding = m_srcEmbedding.GetNetworkOnDevice(deviceIdIdx);
     IWeightTensor tgtEmbedding = m_tgtEmbedding.GetNetworkOnDevice(deviceIdIdx);

     // Optional networks: the null-conditional call yields null when the field is unset.
     IWeightTensor posEmbedding = m_posEmbedding?.GetNetworkOnDevice(deviceIdIdx);
     IWeightTensor segmentEmbedding = m_segmentEmbedding?.GetNetworkOnDevice(deviceIdIdx);
     IFeedForwardLayer pointerGenerator = m_pointerGenerator?.GetNetworkOnDevice(deviceIdIdx);

     return (encoder, decoder, decoderFFLayer, srcEmbedding, tgtEmbedding, posEmbedding, segmentEmbedding, pointerGenerator);
 }
示例#3
0
 /// <summary>
 /// Fetches the per-device instance of each sub-network held by this model.
 /// </summary>
 /// <param name="deviceIdIdx">Index passed to each wrapper's GetNetworkOnDevice call.</param>
 /// <returns>A tuple of (encoder, source embedding, decoder feed-forward layer).</returns>
 private (IEncoder, IWeightTensor, FeedForwardLayer) GetNetworksOnDeviceAt(int deviceIdIdx)
 {
     IEncoder encoder = m_encoder.GetNetworkOnDevice(deviceIdIdx);
     IWeightTensor srcEmbedding = m_srcEmbedding.GetNetworkOnDevice(deviceIdIdx);
     FeedForwardLayer decoderFFLayer = m_decoderFFLayer.GetNetworkOnDevice(deviceIdIdx);

     return (encoder, srcEmbedding, decoderFFLayer);
 }
示例#4
0
 /// <summary>
 /// Fetches the per-device instance of each sub-network held by this model.
 /// </summary>
 /// <param name="deviceIdIdx">Index passed to each wrapper's GetNetworkOnDevice call.</param>
 /// <returns>A tuple of (encoder, decoder, source embedding, target embedding).</returns>
 private (IEncoder, IDecoder, IWeightTensor, IWeightTensor) GetNetworksOnDeviceAt(int deviceIdIdx)
 {
     IEncoder encoder = m_encoder.GetNetworkOnDevice(deviceIdIdx);
     IDecoder decoder = m_decoder.GetNetworkOnDevice(deviceIdIdx);
     IWeightTensor srcEmbedding = m_srcEmbedding.GetNetworkOnDevice(deviceIdIdx);
     IWeightTensor tgtEmbedding = m_tgtEmbedding.GetNetworkOnDevice(deviceIdIdx);

     return (encoder, decoder, srcEmbedding, tgtEmbedding);
 }