/// <summary>
/// Creates an attention unit: records the configuration in fields, logs it, and allocates
/// the weight tensors. When <paramref name="enableCoverageModel"/> is true, the coverage
/// weights and the coverage LSTM cell are allocated as well.
/// </summary>
/// <param name="name">Unit name; written to the log and used as the prefix of every tensor / sub-cell name created here.</param>
/// <param name="hiddenDim">Hidden dimension; sizes m_Wa, m_V, the bias rows and the second axis of m_Ua and m_Wc.</param>
/// <param name="contextDim">Context dimension; sizes the first axis of m_Ua and contributes to the coverage cell's input width.</param>
/// <param name="deviceId">Device id forwarded to every tensor and to the coverage cell.</param>
/// <param name="enableCoverageModel">Whether the coverage parameters (m_Wc, m_bWc, m_coverage) are created.</param>
/// <param name="isTrainable">Trainable flag forwarded to every tensor and to the coverage cell.</param>
public AttentionUnit(string name, int hiddenDim, int contextDim, int deviceId, bool enableCoverageModel, bool isTrainable)
{
    // Keep the configuration around for later use by the instance.
    m_name = name;
    m_hiddenDim = hiddenDim;
    m_contextDim = contextDim;
    m_deviceId = deviceId;
    m_enableCoverageModel = enableCoverageModel;
    m_isTrainable = isTrainable;

    Logger.WriteLine($"Creating attention unit '{name}' HiddenDim = '{hiddenDim}', ContextDim = '{contextDim}', DeviceId = '{deviceId}', EnableCoverageModel = '{enableCoverageModel}'");

    // NOTE: tensors are created in a fixed order (Ua, Wa, bUa, bWa, V, then coverage);
    // the order is kept as-is in case initialisation depends on it.

    // Projection weights, created with NormType.Uniform.
    long[] uaShape = new long[] { contextDim, hiddenDim };
    m_Ua = new WeightTensor(uaShape, deviceId, normType: NormType.Uniform, name: $"{name}.{nameof(m_Ua)}", isTrainable: isTrainable);

    long[] waShape = new long[] { hiddenDim, hiddenDim };
    m_Wa = new WeightTensor(waShape, deviceId, normType: NormType.Uniform, name: $"{name}.{nameof(m_Wa)}", isTrainable: isTrainable);

    // Bias rows of shape [1, hiddenDim]; the literal 0 is the second constructor argument
    // (presumably the constant fill value -- confirm against WeightTensor).
    long[] bUaShape = new long[] { 1, hiddenDim };
    m_bUa = new WeightTensor(bUaShape, 0, deviceId, name: $"{name}.{nameof(m_bUa)}", isTrainable: isTrainable);

    long[] bWaShape = new long[] { 1, hiddenDim };
    m_bWa = new WeightTensor(bWaShape, 0, deviceId, name: $"{name}.{nameof(m_bWa)}", isTrainable: isTrainable);

    // Column tensor of shape [hiddenDim, 1].
    long[] vShape = new long[] { hiddenDim, 1 };
    m_V = new WeightTensor(vShape, deviceId, normType: NormType.Uniform, name: $"{name}.{nameof(m_V)}", isTrainable: isTrainable);

    // Everything below exists only for the coverage model.
    if (!m_enableCoverageModel)
    {
        return;
    }

    long[] wcShape = new long[] { k_coverageModelDim, hiddenDim };
    m_Wc = new WeightTensor(wcShape, deviceId, normType: NormType.Uniform, name: $"{name}.{nameof(m_Wc)}", isTrainable: isTrainable);

    long[] bWcShape = new long[] { 1, hiddenDim };
    m_bWc = new WeightTensor(bWcShape, 0, deviceId, name: $"{name}.{nameof(m_bWc)}", isTrainable: isTrainable);

    // Coverage cell: hidden size k_coverageModelDim, input width 1 + contextDim + hiddenDim.
    m_coverage = new LSTMCell(
        name: $"{name}.{nameof(m_coverage)}",
        hdim: k_coverageModelDim,
        dim: 1 + contextDim + hiddenDim,
        deviceId: deviceId,
        isTrainable: isTrainable);
}