コード例 #1
0
        /// <summary>
        /// Batched matrix multiplication. Both operands are viewed as <paramref name="batchSize"/> stacked
        /// matrices and multiplied pairwise through <c>Ops.AddmmBatch</c>.
        /// </summary>
        /// <param name="m1">Left operand; must be a <c>WeightTensor</c>. Viewed as (batchSize, Rows / batchSize, Columns).</param>
        /// <param name="m2">Right operand; must be a <c>WeightTensor</c>. Viewed as (batchSize, Rows / batchSize, Columns).</param>
        /// <param name="batchSize">Number of stacked matrices; must be positive. The row counts are divided by it,
        /// so they are presumably expected to be exact multiples of it (not verified here).</param>
        /// <returns>A new tensor with <c>m1.Rows</c> rows and <c>m2.Columns</c> columns that receives the products.</returns>
        /// <exception cref="ArgumentException">An operand is null or is not a <c>WeightTensor</c>.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="batchSize"/> is zero or negative.</exception>
        /// <remarks>
        /// When <c>needs_backprop</c> is set, a backward action is appended to <c>backprop</c>. That action
        /// accumulates into the gradients of both operands and finally disposes the result tensor.
        /// </remarks>
        public IWeightMatrix MulBatch(IWeightMatrix m1, IWeightMatrix m2, int batchSize)
        {
            WeightTensor t1 = m1 as WeightTensor;
            WeightTensor t2 = m2 as WeightTensor;
            if (t1 == null)
            {
                throw new ArgumentException("Operand must be a non-null WeightTensor.", "m1");
            }
            if (t2 == null)
            {
                throw new ArgumentException("Operand must be a non-null WeightTensor.", "m2");
            }
            if (batchSize <= 0)
            {
                // batchSize is used as a divisor below.
                throw new ArgumentOutOfRangeException("batchSize", "Batch size must be positive.");
            }

            var n = t1.Rows;
            var d = t2.Columns;
            WeightTensor res = weightTensorFactory.CreateWeightTensor(n, d, deviceId);

            // The operand views outlive this method when a backward action is registered,
            // so they cannot simply be wrapped in using blocks.
            Tensor t1W = null;
            Tensor t2W = null;
            try
            {
                t1W = t1.TWeight.View(batchSize, t1.Rows / batchSize, t1.Columns);
                t2W = t2.TWeight.View(batchSize, t2.Rows / batchSize, t2.Columns);

                using (Tensor rW = res.TWeight.View(batchSize, n / batchSize, d))
                {
                    // 0.0f is the factor on the existing contents of rW, 1.0f the factor on the product.
                    Ops.AddmmBatch(rW, 0.0f, rW, 1.0f, t1W, t2W);
                }
            }
            catch
            {
                // No backward action owns the views yet; release them so a failed multiply does not leak.
                if (t1W != null)
                {
                    t1W.Dispose();
                }
                if (t2W != null)
                {
                    t2W.Dispose();
                }
                throw;
            }

            if (!this.needs_backprop)
            {
                t1W.Dispose();
                t2W.Dispose();
                return res;
            }

            Action backward = () =>
            {
                try
                {
                    res.ReleaseWeight();

                    using (Tensor t1G = t1.TGradient.View(batchSize, t1.Rows / batchSize, t1.Columns))
                    using (Tensor t2G = t2.TGradient.View(batchSize, t2.Rows / batchSize, t2.Columns))
                    using (Tensor rG = res.TGradient.View(batchSize, n / batchSize, d))
                    {
                        // Gradient w.r.t. the left operand: t1G += rG * t2W^T
                        using (Tensor tW2 = t2W.Transpose(1, 2))
                        {
                            Ops.AddmmBatch(t1G, 1.0f, t1G, 1.0f, rG, tW2);
                        }

                        // Gradient w.r.t. the right operand: t2G += t1W^T * rG
                        using (Tensor tW1 = t1W.Transpose(1, 2))
                        {
                            Ops.AddmmBatch(t2G, 1.0f, t2G, 1.0f, tW1, rG);
                        }
                    }
                }
                finally
                {
                    // The forward views and the result are released even if the gradient computation throws.
                    t1W.Dispose();
                    t2W.Dispose();
                    res.Dispose();
                }
            };
            this.backprop.Add(backward);

            return res;
        }
コード例 #2
0
        /// <summary>
        /// Batched matrix multiplication of two 3-D weight tensors through <c>Ops.AddmmBatch</c>, with the
        /// product scaled by <paramref name="alpha"/>.
        /// </summary>
        /// <param name="m1">Left operand; must be a <c>WeightTensor</c> whose weight is indexed on three dimensions.</param>
        /// <param name="m2">Right operand; must be a <c>WeightTensor</c> whose weight is indexed on three dimensions.</param>
        /// <param name="batchSize">Number of matrices in the batch. Presumably equal to the first dimension of both
        /// operands; this is not verified here.</param>
        /// <param name="alpha">Scale factor applied to the product in the forward pass and to both gradient
        /// contributions in the backward pass.</param>
        /// <returns>A new tensor with <c>batchSize * m1.Sizes[1]</c> rows and <c>m2.Sizes[2]</c> columns.</returns>
        /// <exception cref="ArgumentException">An operand is null or is not a <c>WeightTensor</c>.</exception>
        /// <remarks>
        /// When <c>m_needsBackprop</c> is set, a backward action is appended to <c>m_backprop</c> and both
        /// operands are unbound from the compute graph. The backward action disposes the result tensor.
        /// </remarks>
        public IWeightTensor MulBatch(IWeightTensor m1, IWeightTensor m2, int batchSize, float alpha = 1.0f)
        {
            WeightTensor t1 = m1 as WeightTensor;
            WeightTensor t2 = m2 as WeightTensor;
            if (t1 == null)
            {
                throw new ArgumentException("Operand must be a non-null WeightTensor.", nameof(m1));
            }
            if (t2 == null)
            {
                throw new ArgumentException("Operand must be a non-null WeightTensor.", nameof(m2));
            }

            WeightTensor res = m_weightTensorFactory.CreateWeightTensor((int)(batchSize * t1.TWeight.Sizes[1]), (int)t2.TWeight.Sizes[2], m_deviceId, name: $"{GetHashString(m1.Name, m2.Name)}.MulBatch", graphToBind: this);

            VisualizeNodes(new IWeightTensor[] { m1, m2 }, res);

            Tensor t1W = t1.TWeight;
            Tensor t2W = t2.TWeight;

            using (Tensor rW = res.TWeight.View(batchSize, t1.TWeight.Sizes[1], t2.TWeight.Sizes[2]))
            {
                // 0.0f is the factor on the existing contents of rW, alpha the factor on the product.
                Ops.AddmmBatch(rW, 0.0f, rW, alpha, t1W, t2W);
            }

            if (m_needsBackprop)
            {
                Action backward = () =>
                {
                    res.ReleaseWeight();
                    using (Tensor rG = res.TGradient.View(batchSize, t1.TWeight.Sizes[1], t2.TWeight.Sizes[2]))
                    {
                        // The forward result is alpha * (t1W * t2W), so each gradient contribution must carry
                        // the same alpha factor; a hard-coded 1.0f is only correct when alpha == 1.
                        using (Tensor t1G = t1.TGradient.View(t1.TWeight.Sizes[0], t1.TWeight.Sizes[1], t1.TWeight.Sizes[2]))
                        using (Tensor tW2 = t2W.Transpose(1, 2))
                        {
                            // t1G += alpha * rG * t2W^T
                            Ops.AddmmBatch(t1G, 1.0f, t1G, alpha, rG, tW2);
                        }

                        using (Tensor t2G = t2.TGradient.View(t2.TWeight.Sizes[0], t2.TWeight.Sizes[1], t2.TWeight.Sizes[2]))
                        using (Tensor tW1 = t1W.Transpose(1, 2))
                        {
                            // t2G += alpha * t1W^T * rG
                            Ops.AddmmBatch(t2G, 1.0f, t2G, alpha, tW1, rG);
                        }
                    }

                    res.Dispose();
                };
                m_backprop.Add(backward);

                t1.UnbindFromComputeGraph();
                t2.UnbindFromComputeGraph();
            }

            return res;
        }