/// <summary>
/// Adds a graph node computing r = w1 + w2 * v.
/// </summary>
/// <param name="w1">Operand added unscaled; it is used as a <see cref="WeightTensor"/> and its sizes determine the sizes of the result.</param>
/// <param name="w2">Operand that is multiplied by <paramref name="v"/> before being added; used as a <see cref="WeightTensor"/>.</param>
/// <param name="v">Scalar factor applied to <paramref name="w2"/>.</param>
/// <param name="runGradientW1">When false, the backward step does not touch the gradient of <paramref name="w1"/>.</param>
/// <param name="runGradientW2">When false, the backward step does not touch the gradient of <paramref name="w2"/>.</param>
/// <returns>The newly created tensor that holds the result.</returns>
public IWeightTensor AddMul(IWeightTensor w1, IWeightTensor w2, float v, bool runGradientW1 = true, bool runGradientW2 = true)
{
    WeightTensor addend = w1 as WeightTensor;
    WeightTensor scaled = w2 as WeightTensor;

    WeightTensor output = m_weightTensorFactory.CreateWeightTensor(
        addend.Sizes,
        m_deviceId,
        name: $"{GetHashString(w1.Name, w2.Name)}.AddMulV",
        graphToBind: this);

    IWeightTensor[] inputs = new IWeightTensor[] { w1, w2 };
    VisualizeNodes(inputs, output);

    // Forward pass: output = addend + scaled * v.
    Ops.AddMulV(output.TWeight, addend.TWeight, scaled.TWeight, v);

    if (!m_needsBackprop)
    {
        return output;
    }

    Action backwardStep = () =>
    {
        // Only the gradient of the output is read below, so its weight is released up front.
        output.ReleaseWeight();

        if (runGradientW1)
        {
            // w1 enters the sum unscaled, so it receives the output gradient as-is.
            addend.CopyOrAddGradient(output);
        }

        if (runGradientW2)
        {
            // w2 was scaled by v, so the output gradient is scaled by v before being accumulated.
            Ops.AddMulV(scaled.TGradient, scaled.TGradient, output.TGradient, v);
        }

        output.Dispose();
    };
    m_backprop.Add(backwardStep);

    return output;
}
/// <summary>
/// Applies a softmax across all elements of <paramref name="w"/>: the maximum and the
/// normalizing sum are both taken over the whole tensor (Ops.MaxAll / Ops.SumAll),
/// not per row or per column.
/// </summary>
/// <param name="w">Input matrix; it is used as a <see cref="WeightTensor"/>.</param>
/// <returns>A new matrix with the same rows and columns as <paramref name="w"/> holding the normalized values.</returns>
public IWeightMatrix Softmax(IWeightMatrix w)
{
    WeightTensor m = w as WeightTensor;
    var res = weightTensorFactory.CreateWeightTensor(m.Rows, m.Columns, deviceId);

    // The global maximum is handed to Ops.ExpSub, presumably so that exp(x - max) is
    // computed and large inputs cannot overflow -- TODO confirm against Ops.ExpSub.
    var maxval = Ops.MaxAll(m.TWeight);
    Ops.ExpSub(res.TWeight, m.TWeight, maxval);

    // Normalize in place by the sum of all elements.
    float s = Ops.SumAll(res.TWeight);
    Ops.Mul(res.TWeight, res.TWeight, 1.0f / s);

    if (this.needs_backprop)
    {
        Action backward = () =>
        {
            // tTmp = dL/dy * y (element-wise); it is a temporary owned by this closure.
            Tensor tTmp = Ops.Mul(null, res.TGradient, res.TWeight);
            try
            {
                // dL/dx += y * dL/dy - y * sum(y * dL/dy)
                Ops.Add(m.TGradient, m.TGradient, tTmp);
                float ss = Ops.SumAll(tTmp);
                Ops.AddMulV(m.TGradient, m.TGradient, res.TWeight, -ss);
            }
            finally
            {
                // Dispose the temporary even when one of the ops above throws,
                // so a failed backward step does not leak the tensor.
                tTmp.Dispose();
            }
        };
        this.backprop.Add(backward);
    }

    return res;
}
/// <summary>
/// Multiplies every element of <paramref name="w"/> by the scalar <paramref name="v"/>.
/// </summary>
/// <param name="w">Input matrix; it is used as a <see cref="WeightTensor"/>.</param>
/// <param name="v">Scalar factor.</param>
/// <returns>A new matrix with the same rows and columns as <paramref name="w"/> holding the scaled values.</returns>
public IWeightMatrix Mul(IWeightMatrix w, float v)
{
    WeightTensor input = w as WeightTensor;
    var scaled = weightTensorFactory.CreateWeightTensor(input.Rows, input.Columns, deviceId);

    Ops.Mul(scaled.TWeight, input.TWeight, v);

    if (!this.needs_backprop)
    {
        return scaled;
    }

    Action backwardStep = () =>
    {
        // The input gradient accumulates the output gradient scaled by the same factor v.
        Ops.AddMulV(input.TGradient, input.TGradient, scaled.TGradient, v);
    };
    this.backprop.Add(backwardStep);

    return scaled;
}
/// <summary>
/// Multiplies every element of <paramref name="w"/> by the scalar <paramref name="v"/>.
/// </summary>
/// <param name="w">Input tensor; it is used as a <see cref="WeightTensor"/> and its sizes determine the sizes of the result.</param>
/// <param name="v">Scalar factor.</param>
/// <returns>The newly created tensor that holds the scaled values.</returns>
public IWeightTensor Mul(IWeightTensor w, float v)
{
    WeightTensor input = w as WeightTensor;
    var scaled = m_weightTensorFactory.CreateWeightTensor(
        input.Sizes,
        m_deviceId,
        name: $"{GetHashString(w.Name)}.MulV");

    VisualizeNodes(w, scaled);

    Ops.Mul(scaled.TWeight, input.TWeight, v);

    if (!m_needsBackprop)
    {
        return scaled;
    }

    Action backwardStep = () =>
    {
        // The input gradient accumulates the output gradient scaled by the same factor v.
        Ops.AddMulV(input.TGradient, input.TGradient, scaled.TGradient, v);

        // The result tensor is disposed once its gradient has been consumed.
        scaled.Dispose();
    };
    m_backprop.Add(backwardStep);

    return scaled;
}