/// <summary>
/// Creates the trainable components of this network for the given model: the encoders,
/// the feed-forward layer named "FeedForward_Encoder", the auxiliary embeddings
/// (<c>m_posEmbedding</c> and <c>m_segmentEmbedding</c>) and the source embedding table.
/// </summary>
/// <param name="model">Model that supplies the vocabularies and the encoder embedding dimension.</param>
/// <returns>Always <c>true</c>.</returns>
private bool CreateTrainableParameters(IModel model)
{
    Logger.WriteLine("Creating encoders...");

    // Every component below draws its device id from this single RoundArray, so the
    // order of the GetNextItem() calls determines the device each component is given.
    var deviceRing = new RoundArray<int>(DeviceIds);

    int contextDim;
    (m_encoder, contextDim) = Encoder.CreateEncoders(model, m_options, deviceRing);

    // Feed-forward layer mapping the encoder context dimension to the size of ClsVocab.
    var encoderFeedForward = new FeedForwardLayer(
        "FeedForward_Encoder",
        contextDim,
        model.ClsVocab.Count,
        dropoutRatio: 0.0f,
        deviceId: deviceRing.GetNextItem(),
        isTrainable: true);
    m_encoderFFLayer = new MultiProcessorNetworkWrapper<IFeedForwardLayer>(encoderFeedForward, DeviceIds);

    // The auxiliary embeddings are sized for the longer of the train / test sentence limits.
    var maxSentLength = Math.Max(m_options.MaxTrainSentLength, m_options.MaxTestSentLength);
    (m_posEmbedding, m_segmentEmbedding) = Misc.CreateAuxEmbeddings(deviceRing, contextDim, maxSentLength, model);

    Logger.WriteLine($"Creating embeddings. Shape = '({model.SrcVocab.Count} ,{model.EncoderEmbeddingDim})'");

    // Source embedding table of shape (source vocabulary size, encoder embedding dimension).
    var srcEmbeddingTensor = new WeightTensor(
        new long[2] { model.SrcVocab.Count, model.EncoderEmbeddingDim },
        deviceRing.GetNextItem(),
        normType: NormType.Uniform,
        fanOut: true,
        name: "SrcEmbeddings",
        isTrainable: m_options.IsEmbeddingTrainable);
    m_srcEmbedding = new MultiProcessorNetworkWrapper<IWeightTensor>(srcEmbeddingTensor, DeviceIds);

    return true;
}
/// <summary>
/// Creates the trainable components of this network for the given model: the encoders and
/// decoders, one feed-forward layer each for the encoder and decoder outputs, the auxiliary
/// embeddings (<c>m_posEmbedding</c> and <c>m_segmentEmbedding</c>) and the source / target
/// embedding tables.
/// </summary>
/// <param name="model">Model that supplies the hidden dimension and the vocabularies.</param>
/// <returns>Always <c>true</c>.</returns>
private bool CreateTrainableParameters(IModel model)
{
    Logger.WriteLine("Creating encoders and decoders...");

    // Every component below draws its device id from this single RoundArray, so the
    // order of the GetNextItem() calls determines the device each component is given.
    var deviceRing = new RoundArray<int>(DeviceIds);

    int contextDim;
    (m_encoder, contextDim) = Encoder.CreateEncoders(model, m_options, deviceRing);
    m_decoder = Decoder.CreateDecoders(model, m_options, deviceRing, contextDim);

    // Both feed-forward layers take model.HiddenDim as their input size (not contextDim);
    // the encoder one outputs ClsVocab.Count values, the decoder one TgtVocab.Count values.
    var encoderFeedForward = new FeedForwardLayer(
        "FeedForward_Encoder_0",
        model.HiddenDim,
        model.ClsVocab.Count,
        dropoutRatio: 0.0f,
        deviceId: deviceRing.GetNextItem(),
        isTrainable: true);
    m_encoderFFLayer = new MultiProcessorNetworkWrapper<IFeedForwardLayer>(encoderFeedForward, DeviceIds);

    var decoderFeedForward = new FeedForwardLayer(
        "FeedForward_Decoder_0",
        model.HiddenDim,
        model.TgtVocab.Count,
        dropoutRatio: 0.0f,
        deviceId: deviceRing.GetNextItem(),
        isTrainable: true);
    m_decoderFFLayer = new MultiProcessorNetworkWrapper<IFeedForwardLayer>(decoderFeedForward, DeviceIds);

    // The auxiliary embeddings are sized for the longest sentence limit across
    // source / target and train / test.
    var maxSrcSentLength = Math.Max(m_options.MaxTrainSrcSentLength, m_options.MaxTestSrcSentLength);
    var maxTgtSentLength = Math.Max(m_options.MaxTrainTgtSentLength, m_options.MaxTestTgtSentLength);
    var maxSentLength = Math.Max(maxSrcSentLength, maxTgtSentLength);
    (m_posEmbedding, m_segmentEmbedding) = Misc.CreateAuxEmbeddings(deviceRing, contextDim, maxSentLength, model);

    (m_srcEmbedding, m_tgtEmbedding) = CreateSrcTgtEmbeddings(
        model,
        deviceRing,
        m_options.IsSrcEmbeddingTrainable,
        m_options.IsTgtEmbeddingTrainable,
        m_options.EncoderStartLearningRateFactor,
        m_options.DecoderStartLearningRateFactor);

    return true;
}