/// <summary>
/// Applies the rectified linear unit element-wise to <paramref name="w"/> and returns the
/// activation in a newly created tensor with the same sizes as the input.
/// </summary>
/// <param name="w">Input tensor; must be a <see cref="WeightTensor"/>.</param>
/// <returns>A new tensor, named "&lt;hash of w.Name&gt;.Relu", holding ReLU(w).</returns>
/// <exception cref="ArgumentNullException"><paramref name="w"/> is null.</exception>
/// <exception cref="ArgumentException"><paramref name="w"/> is not a <see cref="WeightTensor"/>.</exception>
/// <remarks>
/// When this graph records backprop, a closure is queued that propagates the gradient of the
/// result back into <paramref name="w"/>'s gradient and then disposes the result tensor.
/// </remarks>
public IWeightTensor Relu(IWeightTensor w)
{
    if (w is null)
    {
        throw new ArgumentNullException(nameof(w));
    }

    // The ops below need the concrete tensor's TWeight/TGradient. Reject other
    // IWeightTensor implementations with a clear message instead of a later NRE.
    var m = w as WeightTensor;
    if (m is null)
    {
        throw new ArgumentException($"Relu requires a {nameof(WeightTensor)}, but got '{w.GetType().FullName}'.", nameof(w));
    }

    var res = m_weightTensorFactory.CreateWeightTensor(m.Sizes, m_deviceId, name: $"{GetHashString(w.Name)}.Relu");
    VisualizeNodes(w, res);

    // Forward pass. Assumes Ops.Relu(result, src) argument order, i.e. res.TWeight = ReLU(m.TWeight).
    Ops.Relu(res.TWeight, m.TWeight);

    if (m_needsBackprop)
    {
        Action backward = () =>
        {
            // Propagate res's gradient into m's gradient. Judging by the arguments, presumably
            // m.TGradient += relu'(m.TWeight) * res.TGradient -- confirm against Ops.AddReluD.
            Ops.AddReluD(m.TGradient, m.TGradient, m.TWeight, res.TGradient);

            // res is disposed only after its gradient has been consumed above.
            res.Dispose();
        };
        m_backprop.Add(backward);
    }

    return res;
}
/// <summary>
/// Applies the rectified linear unit element-wise to <paramref name="w"/>, either into a new
/// tensor or, when <paramref name="inPlace"/> is true, into a tensor created by
/// <c>CopyWeightsRef</c> on the input.
/// </summary>
/// <param name="w">Input tensor; must be a <see cref="WeightTensor"/>.</param>
/// <param name="inPlace">
/// When true, the result comes from <c>m.CopyWeightsRef</c> instead of the factory.
/// Presumably it shares the input's weight storage, so the input's weights are overwritten -- confirm.
/// </param>
/// <returns>The tensor named "&lt;hash of w.Name&gt;.Relu" holding ReLU(w).</returns>
/// <exception cref="ArgumentNullException"><paramref name="w"/> is null.</exception>
/// <exception cref="ArgumentException"><paramref name="w"/> is not a <see cref="WeightTensor"/>.</exception>
public IWeightTensor Relu(IWeightTensor w, bool inPlace = false)
{
    if (w is null)
    {
        throw new ArgumentNullException(nameof(w));
    }

    // The ops below need the concrete tensor's TWeight/TGradient. Reject other
    // IWeightTensor implementations with a clear message instead of a later NRE.
    WeightTensor m = w as WeightTensor;
    if (m is null)
    {
        throw new ArgumentException($"Relu requires a {nameof(WeightTensor)}, but got '{w.GetType().FullName}'.", nameof(w));
    }

    WeightTensor res = null;
    if (inPlace)
    {
        res = m.CopyWeightsRef($"{GetHashString(w.Name)}.Relu");
    }
    else
    {
        res = m_weightTensorFactory.CreateWeightTensor(m.Sizes, m_deviceId, name: $"{GetHashString(w.Name)}.Relu", graphToBind: this);
    }
    VisualizeNodes(w, res);

    // Forward pass. Assumes Ops.Relu(result, src) argument order, i.e. res.TWeight = ReLU(m.TWeight).
    Ops.Relu(res.TWeight, m.TWeight);

    if (m_needsBackprop)
    {
        Action backward = () =>
        {
            // Neither gradient path below reads res.TWeight (the mask comes from m.TWeight),
            // so res's weight is released first. In the in-place case, presumably this only
            // drops res's reference to storage shared with m -- confirm ReleaseWeight semantics.
            res.ReleaseWeight();
            if (inPlace)
            {
                // NOTE(review): this replaces m.TGradient with a reference to res's gradient rather
                // than accumulating into it. That is only correct if res is m's sole consumer.
                // Also verify that the previous m.TGradient is not leaked.
                m.TGradient = res.TGradient.CopyRef();

                // m.TWeight now holds ReLU(m) after the forward pass; its sign pattern matches the
                // pre-activation, so it still yields the correct ReLU mask.
                Ops.ReluD(m.TGradient, m.TWeight, m.TGradient);
            }
            else
            {
                // Presumably m.TGradient += relu'(m.TWeight) * res.TGradient -- confirm against Ops.AddReluD.
                Ops.AddReluD(m.TGradient, m.TGradient, m.TWeight, res.TGradient);
            }
            res.Dispose();
        };
        m_backprop.Add(backward);

        // The queued closure still references m. Presumably unbinding keeps the graph from
        // disposing m before backward runs -- confirm UnbindFromComputeGraph semantics.
        m.UnbindFromComputeGraph();
    }

    return res;
}