/// <summary>
/// Applies one optimizer update step to a single row (<paramref name="rowId"/>) of
/// <paramref name="m"/>. It works on one-row views (taken along dimension 0) of the
/// weight, gradient and the two accumulator tensors <c>TCash</c> and <c>TLrW</c>.
/// </summary>
/// <param name="m">Weight tensor whose row is updated.</param>
/// <param name="batchSize">The gradient row is multiplied by 1/batchSize unless batchSize is 1.</param>
/// <param name="step_size">Step size; passed negated to <c>Ops.UpdateWeight2</c>.</param>
/// <param name="clipval">Gradient values are clamped to [-clipval, clipval].</param>
/// <param name="regc">Regularization coefficient; passed negated to <c>Ops.UpdateWeight2</c>.</param>
/// <param name="rowId">Index along dimension 0 of the row to update.</param>
/// <remarks>
/// The narrowed views are owned by nested <c>using</c> statements, so they are disposed
/// even if one of the <c>Ops</c> calls throws. Previously they leaked on that path.
/// </remarks>
private void UpdateWeightsTensor(WeightTensor m, int batchSize, float step_size, float clipval, float regc, int rowId)
{
    // NOTE(review): relies on Narrow returning views that share storage with m's tensors,
    // so that in-place writes to the views update m - TODO confirm.
    using (Tensor weight = m.TWeight.Narrow(0, rowId, 1))
    using (Tensor gradient = m.TGradient.Narrow(0, rowId, 1))
    using (Tensor cash = m.TCash.Narrow(0, rowId, 1))
    using (Tensor lrW = m.TLrW.Narrow(0, rowId, 1))
    {
        // Skip the scale when it would be a multiply by 1.
        if (batchSize != 1)
        {
            Ops.Mul(gradient, gradient, 1.0f / batchSize);
        }

        Ops.Clamp(gradient, gradient, -clipval, clipval);
        Ops.UpdateCash(cash, cash, gradient, decay_rate);
        Ops.UpdateDelta(gradient, gradient, cash, smooth_eps);
        Ops.UpdateCash(lrW, lrW, gradient, lr_decay_rate);
        Ops.UpdateWeight2(weight, weight, gradient, lrW, -step_size, -regc);
    }
}
/// <summary>
/// Applies one optimizer update step to the whole of <paramref name="m"/>. It operates
/// directly on its <c>TWeight</c>, <c>TGradient</c>, <c>TCash</c> and <c>TLrW</c> tensors.
/// </summary>
/// <param name="m">Weight tensor to update.</param>
/// <param name="batchSize">The gradient is multiplied by 1/batchSize.
/// NOTE(review): there is no guard, so a batchSize of 0 yields an infinite scale factor.</param>
/// <param name="step_size">Step size; passed negated to <c>Ops.UpdateWeight2</c>.</param>
/// <param name="clipval">Gradient values are clamped to [-clipval, clipval].</param>
/// <param name="regc">Regularization coefficient; passed negated to <c>Ops.UpdateWeight2</c>.</param>
/// <remarks>The gradient tensor is modified in place and is not reset to zero here.</remarks>
private void UpdateWeightsTensor(WeightTensor m, int batchSize, float step_size, float clipval, float regc)
{
    var weight = m.TWeight;
    var gradient = m.TGradient;
    var cash = m.TCash;
    var lrW = m.TLrW;

    // Scale the gradient, then clip it to the symmetric range.
    float gradScale = 1.0f / batchSize;
    Ops.Mul(gradient, gradient, gradScale);
    Ops.Clamp(gradient, gradient, -clipval, clipval);

    // Refresh TCash, rewrite the gradient from it, then refresh TLrW from the new gradient.
    Ops.UpdateCash(cash, cash, gradient, decay_rate);
    Ops.UpdateDelta(gradient, gradient, cash, smooth_eps);
    Ops.UpdateCash(lrW, lrW, gradient, lr_decay_rate);

    Ops.UpdateWeight2(weight, weight, gradient, lrW, -step_size, -regc);
}
/// <summary>
/// Builds a deferred expression that applies <c>Ops.Clamp</c> to this variable's
/// expression using the bounds <paramref name="min"/> and <paramref name="max"/>.
/// </summary>
/// <param name="min">Lower bound; its <c>Evaluate()</c> is called only when the resulting expression runs.</param>
/// <param name="max">Upper bound; its <c>Evaluate()</c> is called only when the resulting expression runs.</param>
/// <returns>A new <see cref="Variable"/> wrapping the clamp expression.</returns>
public Variable Clamp(ScalarVar min, ScalarVar max)
{
    var clampExpression = new UnaryTensorExpression(
        this.Expression,
        (result, source) => Ops.Clamp(result, source, min.Evaluate(), max.Evaluate()));

    return new Variable(clampExpression);
}
/// <summary>
/// Builds a deferred expression that applies <c>Ops.Clamp</c> to this variable's
/// expression using the bounds <paramref name="min"/> and <paramref name="max"/>.
/// </summary>
/// <param name="min">Lower bound; evaluated lazily, inside the expression.</param>
/// <param name="max">Upper bound; evaluated lazily, inside the expression.</param>
/// <returns>A new <see cref="TVar"/> wrapping the clamp expression.</returns>
public TVar Clamp(SVar min, SVar max)
{
    var inner = this.Expression;
    var clamped = new UnaryTensorExpression(
        inner,
        (dest, input) => Ops.Clamp(dest, input, min.Evaluate(), max.Evaluate()));

    return new TVar(clamped);
}