コード例 #1
0
        /// <summary>
        /// Graph operation computing r = w1 + w2 * v, where v is a scalar.
        /// </summary>
        /// <param name="w1">Operand added as-is; the result takes its sizes.</param>
        /// <param name="w2">Operand that is scaled by <paramref name="v"/> before being added.</param>
        /// <param name="v">Scalar multiplier applied to <paramref name="w2"/>.</param>
        /// <param name="runGradientW1">If true, the backward pass propagates the result gradient into <paramref name="w1"/>.</param>
        /// <param name="runGradientW2">If true, the backward pass accumulates the result gradient scaled by <paramref name="v"/> into <paramref name="w2"/>.</param>
        /// <returns>A tensor created with the sizes of <paramref name="w1"/> that holds w1 + w2 * v.</returns>
        public IWeightTensor AddMul(IWeightTensor w1, IWeightTensor w2, float v, bool runGradientW1 = true, bool runGradientW2 = true)
        {
            var lhs = w1 as WeightTensor;
            var rhs = w2 as WeightTensor;

            WeightTensor output = m_weightTensorFactory.CreateWeightTensor(lhs.Sizes, m_deviceId, name: $"{GetHashString(w1.Name, w2.Name)}.AddMulV", graphToBind: this);

            VisualizeNodes(new[] { w1, w2 }, output);

            // Forward: output = lhs + rhs * v
            Ops.AddMulV(output.TWeight, lhs.TWeight, rhs.TWeight, v);

            if (!m_needsBackprop)
            {
                return output;
            }

            m_backprop.Add(() =>
            {
                // None of the gradient updates below read output.TWeight, so it is released first.
                output.ReleaseWeight();

                if (runGradientW1)
                {
                    // dW1 gets the output gradient (copied or accumulated by CopyOrAddGradient).
                    lhs.CopyOrAddGradient(output);
                }

                if (runGradientW2)
                {
                    // dW2 += dOutput * v
                    Ops.AddMulV(rhs.TGradient, rhs.TGradient, output.TGradient, v);
                }

                output.Dispose();
            });

            return output;
        }
コード例 #2
0
        /// <summary>
        /// Softmax over all elements of <paramref name="w"/>:
        /// r = exp(w - max(w)) / sum(exp(w - max(w))).
        /// </summary>
        /// <remarks>
        /// The global maximum (Ops.MaxAll) is passed to Ops.ExpSub before normalizing. Subtracting it
        /// keeps the exponentials from overflowing without changing the result. Normalization uses
        /// Ops.SumAll, i.e. the sum over the whole tensor.
        /// NOTE(review): this is not a per-row softmax; confirm that is intended.
        /// </remarks>
        /// <param name="w">Input matrix.</param>
        /// <returns>A new Rows x Columns matrix holding the softmax output.</returns>
        public IWeightMatrix Softmax(IWeightMatrix w)
        {
            WeightTensor m = w as WeightTensor;
            var res = weightTensorFactory.CreateWeightTensor(m.Rows, m.Columns, deviceId);

            var maxval = Ops.MaxAll(m.TWeight);

            // res = exp(m - maxval), presumably per ExpSub's contract.
            Ops.ExpSub(res.TWeight, m.TWeight, maxval);
            float s = Ops.SumAll(res.TWeight);

            Ops.Mul(res.TWeight, res.TWeight, 1.0f / s);

            if (this.needs_backprop)
            {
                Action backward = () =>
                {
                    // With y = softmax(x): dx += y * dy - y * sum(y * dy).
                    Tensor tTmp = Ops.Mul(null, res.TGradient, res.TWeight);
                    try
                    {
                        Ops.Add(m.TGradient, m.TGradient, tTmp);
                        float ss = Ops.SumAll(tTmp);

                        Ops.AddMulV(m.TGradient, m.TGradient, res.TWeight, -ss);
                    }
                    finally
                    {
                        // Release the temporary even if one of the ops above throws.
                        tTmp.Dispose();
                    }
                };
                this.backprop.Add(backward);
            }

            return res;
        }
コード例 #3
0
        /// <summary>
        /// Graph operation scaling every element of <paramref name="w"/> by the scalar <paramref name="v"/>.
        /// </summary>
        /// <param name="w">Input matrix.</param>
        /// <param name="v">Scalar factor.</param>
        /// <returns>A new Rows x Columns matrix holding w * v.</returns>
        public IWeightMatrix Mul(IWeightMatrix w, float v)
        {
            var source = w as WeightTensor;
            var scaled = weightTensorFactory.CreateWeightTensor(source.Rows, source.Columns, deviceId);

            // Forward: scaled = source * v
            Ops.Mul(scaled.TWeight, source.TWeight, v);

            if (!this.needs_backprop)
            {
                return scaled;
            }

            // Backward: dSource += dScaled * v
            this.backprop.Add(() => Ops.AddMulV(source.TGradient, source.TGradient, scaled.TGradient, v));

            return scaled;
        }
コード例 #4
0
        /// <summary>
        /// Graph operation scaling every element of <paramref name="w"/> by the scalar <paramref name="v"/>.
        /// </summary>
        /// <param name="w">Input tensor.</param>
        /// <param name="v">Scalar factor.</param>
        /// <returns>A tensor with the sizes of <paramref name="w"/> holding w * v.</returns>
        public IWeightTensor Mul(IWeightTensor w, float v)
        {
            var input = w as WeightTensor;
            var scaled = m_weightTensorFactory.CreateWeightTensor(input.Sizes, m_deviceId, name: $"{GetHashString(w.Name)}.MulV");

            VisualizeNodes(w, scaled);

            // Forward: scaled = input * v
            Ops.Mul(scaled.TWeight, input.TWeight, v);

            if (!m_needsBackprop)
            {
                return scaled;
            }

            m_backprop.Add(() =>
            {
                // dInput += dScaled * v, then dispose the intermediate result.
                Ops.AddMulV(input.TGradient, input.TGradient, scaled.TGradient, v);
                scaled.Dispose();
            });

            return scaled;
        }