コード例 #1
0
        /// <summary>
        /// Creates an attention unit: stores the configuration, logs it, and allocates the
        /// weight tensors (plus the optional coverage-model parameters).
        /// </summary>
        /// <param name="name">Prefix for every tensor name; each tensor is named "&lt;name&gt;.&lt;field&gt;".</param>
        /// <param name="hiddenDim">Hidden dimension used by <c>m_Wa</c>, <c>m_V</c>, the biases and the coverage cell input.</param>
        /// <param name="contextDim">Row count of <c>m_Ua</c>; also part of the coverage cell input width.</param>
        /// <param name="deviceId">Device identifier forwarded to every tensor and to the coverage cell.</param>
        /// <param name="enableCoverageModel">When true, additionally allocates <c>m_Wc</c>, <c>m_bWc</c> and <c>m_coverage</c>.</param>
        /// <param name="isTrainable">Forwarded to every tensor and to the coverage cell.</param>
        public AttentionUnit(string name, int hiddenDim, int contextDim, int deviceId, bool enableCoverageModel, bool isTrainable)
        {
            m_name = name;
            m_hiddenDim = hiddenDim;
            m_contextDim = contextDim;
            m_deviceId = deviceId;
            m_enableCoverageModel = enableCoverageModel;
            m_isTrainable = isTrainable;

            Logger.WriteLine($"Creating attention unit '{name}' HiddenDim = '{hiddenDim}', ContextDim = '{contextDim}', DeviceId = '{deviceId}', EnableCoverageModel = '{enableCoverageModel}'");

            // Every tensor owned by this unit is named "<unit name>.<field name>".
            string QualifiedName(string fieldName) => $"{name}.{fieldName}";

            // rows x cols matrix created with NormType.Uniform initialisation.
            WeightTensor CreateUniformMatrix(long rows, long cols, string fieldName) =>
                new WeightTensor(new long[] { rows, cols }, deviceId, normType: NormType.Uniform, name: QualifiedName(fieldName), isTrainable: isTrainable);

            // 1 x hiddenDim row created through the overload that takes the constant 0
            // (presumably a zero-initialised bias — confirm against WeightTensor).
            WeightTensor CreateConstantRow(string fieldName) =>
                new WeightTensor(new long[] { 1, hiddenDim }, 0, deviceId, name: QualifiedName(fieldName), isTrainable: isTrainable);

            // NOTE(review): creation order is kept as-is in case uniform initialisation
            // consumes a shared random source.
            m_Ua = CreateUniformMatrix(contextDim, hiddenDim, nameof(m_Ua));
            m_Wa = CreateUniformMatrix(hiddenDim, hiddenDim, nameof(m_Wa));
            m_bUa = CreateConstantRow(nameof(m_bUa));
            m_bWa = CreateConstantRow(nameof(m_bWa));
            m_V = CreateUniformMatrix(hiddenDim, 1, nameof(m_V));

            if (!m_enableCoverageModel)
            {
                return;
            }

            m_Wc = CreateUniformMatrix(k_coverageModelDim, hiddenDim, nameof(m_Wc));
            m_bWc = CreateConstantRow(nameof(m_bWc));

            // Coverage cell input width is 1 + contextDim + hiddenDim
            // (presumably score + context vector + hidden state — confirm against the forward pass).
            m_coverage = new LSTMCell(name: QualifiedName(nameof(m_coverage)), hdim: k_coverageModelDim, dim: 1 + contextDim + hiddenDim, deviceId: deviceId, isTrainable: isTrainable);
        }