/// <summary>
/// Builds the trainable components held by this instance: the encoder, the feed-forward
/// output layer, the source embedding table and, for Transformer encoders only, the
/// position embedding table.
/// </summary>
/// <param name="model">
/// Model description supplying the vocabularies, the encoder embedding dimension and the encoder type.
/// </param>
/// <returns>Always <c>true</c>.</returns>
private bool CreateTrainableParameters(IModel model)
{
    // NOTE(review): the message mentions decoders, but this method only builds an encoder.
    Logger.WriteLine($"Creating encoders and decoders...");

    // Every component below draws its device id from this one shared sequence, so the
    // order of the GetNextItem() calls decides which device each component gets.
    var deviceRing = new RoundArray<int>(DeviceIds);

    int contextDim;
    (m_encoder, contextDim) = Encoder.CreateEncoders(model, m_options, deviceRing);

    // Output layer: encoder context dimension -> size of the classification vocabulary.
    var feedForward = new FeedForwardLayer(
        "FeedForward",
        contextDim,
        model.ClsVocab.Count,
        dropoutRatio: 0.0f,
        deviceId: deviceRing.GetNextItem(),
        isTrainable: true);
    m_ffLayer = new MultiProcessorNetworkWrapper<FeedForwardLayer>(feedForward, DeviceIds);

    // Source embedding table: one row per source vocabulary entry.
    var srcEmbeddingWeights = new WeightTensor(
        new long[] { model.SrcVocab.Count, model.EncoderEmbeddingDim },
        deviceRing.GetNextItem(),
        normType: NormType.Uniform,
        name: "SrcEmbeddings",
        isTrainable: true);
    m_srcEmbedding = new MultiProcessorNetworkWrapper<IWeightTensor>(srcEmbeddingWeights, DeviceIds);

    if (model.EncoderType != EncoderTypeEnums.Transformer)
    {
        // Only Transformer encoders get a position embedding table.
        m_posEmbedding = null;
    }
    else
    {
        // The "+ 2" presumably leaves room for extra marker tokens around the sentence - TODO confirm.
        var positionWeights = PositionEmbedding.BuildPositionWeightTensor(
            m_options.MaxTestSentLength + 2,
            model.EncoderEmbeddingDim,
            deviceRing.GetNextItem(),
            "PosEmbedding",
            false);
        m_posEmbedding = new MultiProcessorNetworkWrapper<IWeightTensor>(positionWeights, DeviceIds, true);
    }

    return true;
}
/// <summary>
/// Converts the source sequences into input embeddings, adds position embeddings when the
/// model uses a Transformer encoder, and runs the encoder over the result.
/// </summary>
/// <param name="g">Compute graph passed to the embedding helpers and to the encoder.</param>
/// <param name="seqs">
/// Batch of source sequences; the number of sequences is used as the batch size.
/// Presumably each inner list holds token ids - TODO confirm against callers.
/// </param>
/// <param name="encoder">Encoder that consumes the prepared input embeddings.</param>
/// <param name="modelMetaData">
/// Model description supplying the source vocabulary, the tag-embedding switch and the encoder type.
/// </param>
/// <param name="embeddings">
/// Token embedding weights; the square root of its column count is passed along as the embedding scale factor.
/// </param>
/// <param name="selfMask">Mask forwarded unchanged to <c>encoder.Encode</c>.</param>
/// <param name="posEmbeddings">Position embedding weights; only read when the encoder type is Transformer.</param>
/// <param name="segmentEmbeddings">Segment embedding weights forwarded to the token-embedding builder.</param>
/// <returns>The tensor returned by <c>encoder.Encode</c>.</returns>
private static IWeightTensor RunEncoder(
    IComputeGraph g,
    List<List<int>> seqs,
    IEncoder encoder,
    IModel modelMetaData,
    IWeightTensor embeddings,
    IWeightTensor selfMask,
    IWeightTensor posEmbeddings,
    IWeightTensor segmentEmbeddings)
{
    int batchSize = seqs.Count;

    // Scale factor handed to the embedding builder: sqrt of the embedding width.
    float embeddingScale = (float)Math.Sqrt(embeddings.Columns);

    var inputEmbs = TensorUtils.CreateTokensEmbeddings(
        seqs,
        g,
        embeddings,
        segmentEmbeddings,
        modelMetaData.SrcVocab,
        embeddingScale,
        enableTagEmbedding: modelMetaData.EnableTagEmbeddings);

    if (modelMetaData.EncoderType == EncoderTypeEnums.Transformer)
    {
        // The trailing 0.0f is presumably a dropout ratio (i.e. disabled) - TODO confirm.
        inputEmbs = PositionEmbedding.AddPositionEmbedding(g, posEmbeddings, batchSize, inputEmbs, 0.0f);
    }

    return encoder.Encode(inputEmbs, batchSize, g, selfMask);
}