/// <summary>
/// Builds a new attention-based sequence-to-sequence model: records the model meta data and
/// training switches, creates the trainable parameters and, when embedding file paths are
/// given, loads external word embeddings for every entry in <c>DeviceIds</c>.
/// </summary>
/// <param name="embeddingDim">Embedding dimension, forwarded to <c>Seq2SeqModelMetaData</c>.</param>
/// <param name="hiddenDim">Hidden dimension, forwarded to <c>Seq2SeqModelMetaData</c>.</param>
/// <param name="encoderLayerDepth">Encoder layer depth, forwarded to <c>Seq2SeqModelMetaData</c>.</param>
/// <param name="decoderLayerDepth">Decoder layer depth, forwarded to <c>Seq2SeqModelMetaData</c>.</param>
/// <param name="vocab">Vocabulary stored in the model meta data; its source/target word-to-index maps are used when loading external embeddings.</param>
/// <param name="srcEmbeddingFilePath">Path of an external source-side embedding file. Nothing is loaded when null or empty.</param>
/// <param name="tgtEmbeddingFilePath">Path of an external target-side embedding file. Nothing is loaded when null or empty.</param>
/// <param name="modelFilePath">Model file path, passed to the base constructor.</param>
/// <param name="dropoutRatio">Dropout ratio, stored in <c>m_dropoutRatio</c>.</param>
/// <param name="multiHeadNum">Multi-head count, forwarded to <c>Seq2SeqModelMetaData</c>.</param>
/// <param name="processorType">Processor type, passed to the base constructor.</param>
/// <param name="encoderType">Encoder type, forwarded to <c>Seq2SeqModelMetaData</c>.</param>
/// <param name="decoderType">Decoder type, forwarded to <c>Seq2SeqModelMetaData</c>.</param>
/// <param name="enableCoverageModel">Coverage model switch, forwarded to <c>Seq2SeqModelMetaData</c>.</param>
/// <param name="deviceIds">Device ids, passed to the base constructor.</param>
/// <param name="isSrcEmbTrainable">Stored in <c>m_isSrcEmbTrainable</c>; defaults to true.</param>
/// <param name="isTgtEmbTrainable">Stored in <c>m_isTgtEmbTrainable</c>; defaults to true.</param>
/// <param name="isEncoderTrainable">Stored in <c>m_isEncoderTrainable</c>; defaults to true.</param>
/// <param name="isDecoderTrainable">Stored in <c>m_isDecoderTrainable</c>; defaults to true.</param>
/// <param name="maxTgtSntSize">Stored in <c>m_maxTgtSntSize</c>; defaults to 128.</param>
public AttentionSeq2Seq(
    int embeddingDim,
    int hiddenDim,
    int encoderLayerDepth,
    int decoderLayerDepth,
    Vocab vocab,
    string srcEmbeddingFilePath,
    string tgtEmbeddingFilePath,
    string modelFilePath,
    float dropoutRatio,
    int multiHeadNum,
    ProcessorTypeEnums processorType,
    EncoderTypeEnums encoderType,
    DecoderTypeEnums decoderType,
    bool enableCoverageModel,
    int[] deviceIds,
    bool isSrcEmbTrainable = true,
    bool isTgtEmbTrainable = true,
    bool isEncoderTrainable = true,
    bool isDecoderTrainable = true,
    int maxTgtSntSize = 128)
    : base(deviceIds, processorType, modelFilePath)
{
    m_modelMetaData = new Seq2SeqModelMetaData(
        hiddenDim,
        embeddingDim,
        encoderLayerDepth,
        decoderLayerDepth,
        multiHeadNum,
        encoderType,
        decoderType,
        vocab,
        enableCoverageModel);

    // Training options are recorded before the parameters are created.
    m_dropoutRatio = dropoutRatio;
    m_isSrcEmbTrainable = isSrcEmbTrainable;
    m_isTgtEmbTrainable = isTgtEmbTrainable;
    m_isEncoderTrainable = isEncoderTrainable;
    m_isDecoderTrainable = isDecoderTrainable;
    m_maxTgtSntSize = maxTgtSntSize;

    // Initialize the weights of the encoders and decoders.
    CreateTrainableParameters(m_modelMetaData);

    // The file paths do not change per device, so decide once whether each side has
    // pre-trained embedding weights to load.
    bool hasSrcEmbeddingFile = !string.IsNullOrEmpty(srcEmbeddingFilePath);
    bool hasTgtEmbeddingFile = !string.IsNullOrEmpty(tgtEmbeddingFilePath);

    // Load the external embeddings for every device entry: source side first, then target side.
    for (int deviceIdx = 0; deviceIdx < DeviceIds.Length; deviceIdx++)
    {
        if (hasSrcEmbeddingFile)
        {
            Logger.WriteLine($"Loading ExtEmbedding model from '{srcEmbeddingFilePath}' for source side.");
            LoadWordEmbedding(
                srcEmbeddingFilePath,
                m_srcEmbedding.GetNetworkOnDevice(deviceIdx),
                m_modelMetaData.Vocab.SrcWordToIndex);
        }

        if (hasTgtEmbeddingFile)
        {
            Logger.WriteLine($"Loading ExtEmbedding model from '{tgtEmbeddingFilePath}' for target side.");
            LoadWordEmbedding(
                tgtEmbeddingFilePath,
                m_tgtEmbedding.GetNetworkOnDevice(deviceIdx),
                m_modelMetaData.Vocab.TgtWordToIndex);
        }
    }
}
/// <summary>
/// Collects the per-device instances of every network component for one device slot.
/// </summary>
/// <param name="deviceIdIdx">Index handed to each component's <c>GetNetworkOnDevice</c>.</param>
/// <returns>
/// A tuple of, in order: encoder, decoder, decoder feed-forward layer, source embedding,
/// target embedding, position embedding, segment embedding and pointer generator.
/// The last three elements are null when the corresponding field is null.
/// </returns>
private (IEncoder, IDecoder, IFeedForwardLayer, IWeightTensor, IWeightTensor, IWeightTensor, IWeightTensor, IFeedForwardLayer) GetNetworksOnDeviceAt(int deviceIdIdx)
{
    // Mandatory components, fetched in the same order as the tuple elements.
    var encoder = m_encoder.GetNetworkOnDevice(deviceIdIdx);
    var decoder = m_decoder.GetNetworkOnDevice(deviceIdIdx);
    var decoderFFLayer = m_decoderFFLayer.GetNetworkOnDevice(deviceIdIdx);
    var srcEmbedding = m_srcEmbedding.GetNetworkOnDevice(deviceIdIdx);
    var tgtEmbedding = m_tgtEmbedding.GetNetworkOnDevice(deviceIdIdx);

    // Optional components: the null-conditional call yields null when the field is unset.
    var posEmbedding = m_posEmbedding?.GetNetworkOnDevice(deviceIdIdx);
    var segmentEmbedding = m_segmentEmbedding?.GetNetworkOnDevice(deviceIdIdx);
    var pointerGenerator = m_pointerGenerator?.GetNetworkOnDevice(deviceIdIdx);

    return (encoder, decoder, decoderFFLayer, srcEmbedding, tgtEmbedding, posEmbedding, segmentEmbedding, pointerGenerator);
}
/// <summary>
/// Collects the per-device instances of the encoder, the source embedding and the
/// decoder feed-forward layer for one device slot.
/// </summary>
/// <param name="deviceIdIdx">Index handed to each component's <c>GetNetworkOnDevice</c>.</param>
/// <returns>A tuple of, in order: encoder, source embedding and decoder feed-forward layer.</returns>
private (IEncoder, IWeightTensor, FeedForwardLayer) GetNetworksOnDeviceAt(int deviceIdIdx)
{
    // Fetched in the same order as the tuple elements.
    var encoder = m_encoder.GetNetworkOnDevice(deviceIdIdx);
    var srcEmbedding = m_srcEmbedding.GetNetworkOnDevice(deviceIdIdx);
    var decoderFFLayer = m_decoderFFLayer.GetNetworkOnDevice(deviceIdIdx);

    return (encoder, srcEmbedding, decoderFFLayer);
}
/// <summary>
/// Collects the per-device instances of the encoder, the decoder and both embeddings
/// for one device slot.
/// </summary>
/// <param name="deviceIdIdx">Index handed to each component's <c>GetNetworkOnDevice</c>.</param>
/// <returns>A tuple of, in order: encoder, decoder, source embedding and target embedding.</returns>
private (IEncoder, IDecoder, IWeightTensor, IWeightTensor) GetNetworksOnDeviceAt(int deviceIdIdx)
{
    // Fetched in the same order as the tuple elements.
    var encoder = m_encoder.GetNetworkOnDevice(deviceIdIdx);
    var decoder = m_decoder.GetNetworkOnDevice(deviceIdIdx);
    var srcEmbedding = m_srcEmbedding.GetNetworkOnDevice(deviceIdIdx);
    var tgtEmbedding = m_tgtEmbedding.GetNetworkOnDevice(deviceIdIdx);

    return (encoder, decoder, srcEmbedding, tgtEmbedding);
}