/// <summary>
/// Batched matrix product of <paramref name="m1"/> and <paramref name="m2"/>.
/// The rows of each operand are split into <paramref name="batchSize"/> equal groups.
/// Each group of m1 is multiplied by the matching group of m2 with <c>Ops.AddmmBatch</c>.
/// When backprop is enabled, a closure that accumulates gradients into both inputs is
/// registered on <c>this.backprop</c>.
/// </summary>
/// <param name="m1">Left operand; expected to be a <c>WeightTensor</c> (cast with <c>as</c>, not null-checked).</param>
/// <param name="m2">Right operand; expected to be a <c>WeightTensor</c> (cast with <c>as</c>, not null-checked).</param>
/// <param name="batchSize">
/// Number of batches. Rows are divided with integer division, so this presumably must
/// divide <c>m1.Rows</c> and <c>m2.Rows</c> evenly — TODO confirm whether <c>View</c>
/// rejects a mismatched element count.
/// </param>
/// <returns>A new weight tensor with <c>m1.Rows</c> rows and <c>m2.Columns</c> columns.</returns>
public IWeightMatrix MulBatch(IWeightMatrix m1, IWeightMatrix m2, int batchSize)
{
    WeightTensor t1 = m1 as WeightTensor;
    WeightTensor t2 = m2 as WeightTensor;
    var n = t1.Rows;
    var d = t2.Columns;
    WeightTensor res = weightTensorFactory.CreateWeightTensor(n, d, deviceId);

    // 3-D views of the input weights. They are needed again by the backward closure.
    // Ownership therefore passes to that closure once it is registered. Until then
    // (or if backprop is off, or anything throws) they are disposed in the finally below.
    Tensor t1W = null;
    Tensor t2W = null;
    bool viewsOwnedByBackward = false;
    try
    {
        t1W = t1.TWeight.View(batchSize, t1.Rows / batchSize, t1.Columns);
        t2W = t2.TWeight.View(batchSize, t2.Rows / batchSize, t2.Columns);

        // beta = 0 overwrites the result, alpha = 1 leaves the product unscaled
        // (assuming AddmmBatch(result, beta, src, alpha, a, b) = beta*src + alpha*(a·b)).
        using (Tensor rW = res.TWeight.View(batchSize, n / batchSize, d))
        {
            Ops.AddmmBatch(rW, 0.0f, rW, 1.0f, t1W, t2W);
        }

        if (this.needs_backprop)
        {
            Action backward = () =>
            {
                res.ReleaseWeight();
                try
                {
                    using (Tensor rG = res.TGradient.View(batchSize, n / batchSize, d))
                    {
                        // dL/dm1 += dL/dres · m2ᵀ (beta = 1 accumulates into the existing gradient).
                        using (Tensor t1G = t1.TGradient.View(batchSize, t1.Rows / batchSize, t1.Columns))
                        using (Tensor tW2 = t2W.Transpose(1, 2))
                        {
                            Ops.AddmmBatch(t1G, 1.0f, t1G, 1.0f, rG, tW2);
                        }

                        // dL/dm2 += m1ᵀ · dL/dres
                        using (Tensor t2G = t2.TGradient.View(batchSize, t2.Rows / batchSize, t2.Columns))
                        using (Tensor tW1 = t1W.Transpose(1, 2))
                        {
                            Ops.AddmmBatch(t2G, 1.0f, t2G, 1.0f, tW1, rG);
                        }
                    }
                }
                finally
                {
                    // The forward views are owned by this closure; release them even on failure.
                    t1W.Dispose();
                    t2W.Dispose();
                }

                res.Dispose();
            };
            this.backprop.Add(backward);
            viewsOwnedByBackward = true;
        }
    }
    finally
    {
        if (!viewsOwnedByBackward)
        {
            t1W?.Dispose();
            t2W?.Dispose();
        }
    }

    return res;
}
/// <summary>
/// Batched matrix product <c>res = alpha * (m1 · m2)</c>, computed per batch slice with
/// <c>Ops.AddmmBatch</c>. When backprop is enabled, a closure is registered that accumulates
/// <c>alpha * (dRes · m2ᵀ)</c> into m1's gradient and <c>alpha * (m1ᵀ · dRes)</c> into m2's
/// gradient.
/// </summary>
/// <param name="m1">
/// Left operand; expected to be a <c>WeightTensor</c> (cast with <c>as</c>, not null-checked).
/// Its <c>TWeight</c> is read as 3-D through <c>Sizes[0..2]</c>.
/// </param>
/// <param name="m2">
/// Right operand; expected to be a <c>WeightTensor</c> (cast with <c>as</c>, not null-checked).
/// Its <c>TWeight.Sizes[2]</c> gives the output column count.
/// </param>
/// <param name="batchSize">Number of batch slices the result is viewed as.</param>
/// <param name="alpha">Scale applied to the product, in both the forward and backward passes.</param>
/// <returns>
/// A new weight tensor of <c>batchSize * m1.Sizes[1]</c> rows by <c>m2.Sizes[2]</c> columns,
/// bound to this graph.
/// </returns>
public IWeightTensor MulBatch(IWeightTensor m1, IWeightTensor m2, int batchSize, float alpha = 1.0f)
{
    WeightTensor t1 = m1 as WeightTensor;
    WeightTensor t2 = m2 as WeightTensor;
    WeightTensor res = m_weightTensorFactory.CreateWeightTensor((int)(batchSize * t1.TWeight.Sizes[1]), (int)t2.TWeight.Sizes[2], m_deviceId, name: $"{GetHashString(m1.Name, m2.Name)}.MulBatch", graphToBind: this);
    VisualizeNodes(new IWeightTensor[] { m1, m2 }, res);

    Tensor t1W = t1.TWeight;
    Tensor t2W = t2.TWeight;

    // beta = 0 overwrites rW; the 4th argument (alpha) scales the product.
    using (Tensor rW = res.TWeight.View(batchSize, t1.TWeight.Sizes[1], t2.TWeight.Sizes[2]))
    {
        Ops.AddmmBatch(rW, 0.0f, rW, alpha, t1W, t2W);
    }

    if (m_needsBackprop)
    {
        Action backward = () =>
        {
            res.ReleaseWeight();
            using (Tensor rG = res.TGradient.View(batchSize, t1.TWeight.Sizes[1], t2.TWeight.Sizes[2]))
            {
                // d(alpha·A·B)/dA = alpha · G · Bᵀ. The forward pass scaled by alpha, so the
                // gradient must be scaled too (beta = 1 accumulates into the existing gradient).
                using (Tensor t1G = t1.TGradient.View(t1.TWeight.Sizes[0], t1.TWeight.Sizes[1], t1.TWeight.Sizes[2]))
                using (Tensor tW2 = t2W.Transpose(1, 2))
                {
                    Ops.AddmmBatch(t1G, 1.0f, t1G, alpha, rG, tW2);
                }

                // d(alpha·A·B)/dB = alpha · Aᵀ · G
                using (Tensor t2G = t2.TGradient.View(t2.TWeight.Sizes[0], t2.TWeight.Sizes[1], t2.TWeight.Sizes[2]))
                using (Tensor tW1 = t1W.Transpose(1, 2))
                {
                    Ops.AddmmBatch(t2G, 1.0f, t2G, alpha, tW1, rG);
                }
            }
            res.Dispose();
        };
        m_backprop.Add(backward);

        t1.UnbindFromComputeGraph();
        t2.UnbindFromComputeGraph();
    }

    return res;
}