/// <summary>
/// Performs one training step on <paramref name="functionStack"/>: forward pass,
/// loss evaluation, backward pass, and (when an optimizer is supplied) an update.
/// </summary>
/// <typeparam name="T">Element type of the network's data.</typeparam>
/// <typeparam name="LabelType">Element type of the teacher data.</typeparam>
/// <param name="functionStack">Network that is run forward and backward.</param>
/// <param name="input">Input passed to <c>functionStack.Forward</c>.</param>
/// <param name="teach">Teacher data passed to <c>lossFunction.Evaluate</c>.</param>
/// <param name="lossFunction">Loss evaluated against the first forward output.</param>
/// <param name="optimizer">
/// Optional optimizer. When null, both <c>SetUp</c> and <c>Update</c> are skipped.
/// </param>
/// <returns>The value returned by <c>lossFunction.Evaluate</c>.</returns>
public static T Train<T, LabelType>(FunctionStack<T> functionStack, NdArray<T> input, NdArray<LabelType> teach, LossFunction<T, LabelType> lossFunction, Optimizer<T> optimizer = null)
    where T : unmanaged, IComparable<T>
    where LabelType : unmanaged, IComparable<LabelType>
{
    // NOTE(review): SetUp is invoked on every call to Train; confirm it is safe to repeat.
    optimizer?.SetUp(functionStack);

    // Only the first output of the forward pass feeds the loss.
    var forwardOutputs = functionStack.Forward(input);
    NdArray<T> prediction = forwardOutputs[0];

    T lossValue = lossFunction.Evaluate(prediction, teach);

    // Back-propagate from the prediction through the whole stack.
    functionStack.Backward(prediction);

    // Apply the accumulated gradients, if an optimizer was given.
    optimizer?.Update();

    return lossValue;
}
/// <summary>
/// Performs one training step on <paramref name="functionStack"/>: forward pass,
/// loss evaluation, backward pass, and optionally a parameter update.
/// </summary>
/// <param name="functionStack">Network that is run forward and backward.</param>
/// <param name="input">Input passed to <c>functionStack.Forward</c>.</param>
/// <param name="teach">Teacher data passed to <c>lossFunction.Evaluate</c>.</param>
/// <param name="lossFunction">Loss evaluated against the first forward output.</param>
/// <param name="isUpdate">
/// When true (the default), <c>functionStack.Update()</c> is called after the backward pass.
/// </param>
/// <returns>The value returned by <c>lossFunction.Evaluate</c>.</returns>
public static Real Train(FunctionStack functionStack, NdArray input, NdArray teach, LossFunction lossFunction, bool isUpdate = true)
{
    // Only the first output of the forward pass feeds the loss.
    var forwardOutputs = functionStack.Forward(input);
    NdArray prediction = forwardOutputs[0];

    Real lossTotal = lossFunction.Evaluate(prediction, teach);

    // Back-propagate from the prediction through the whole stack.
    functionStack.Backward(prediction);

    // Caller asked to only accumulate gradients: skip the update step.
    if (!isUpdate)
    {
        return lossTotal;
    }

    functionStack.Update();
    return lossTotal;
}