// Convenience overload: trains on a TestDataSet by passing its Data and Label arrays through
public static T Train<T>(FunctionStack<T> functionStack, TestDataSet<T> dataSet, LossFunction<T, int> lossFunction, Optimizer<T> optimizer = null)
    where T : unmanaged, IComparable<T>
{
    return Train(functionStack, dataSet.Data, dataSet.Label, lossFunction, optimizer);
}
public static T Train<T, LabelType>(FunctionStack<T> functionStack, NdArray<T> input, NdArray<LabelType> teach, LossFunction<T, LabelType> lossFunction, Optimizer<T> optimizer = null)
    where T : unmanaged, IComparable<T>
    where LabelType : unmanaged, IComparable<LabelType>
{
    optimizer?.SetUp(functionStack);

    // Keep the forward result so the error (loss) can be computed from it
    NdArray<T> result = functionStack.Forward(input)[0];
    T loss = lossFunction.Evaluate(result, teach);

    // Run Backward over the batch
    functionStack.Backward(result);

    // Update the parameters
    optimizer?.Update();

    return loss;
}
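The method above performs one training step in a fixed order: forward pass, loss evaluation, backward pass, optimizer update. A minimal, self-contained sketch of that same sequence follows. It assumes a hypothetical linear model y = w * x + b, mean squared error, and plain SGD in place of FunctionStack, LossFunction and Optimizer; none of the names below are the library's API.

```csharp
using System;

// Hypothetical stand-ins (not the library's types): a linear model y = w * x + b,
// mean squared error as the loss, and plain SGD as the optimizer.
double[] parameters = { 0.0, 0.0 }; // parameters[0] = w, parameters[1] = b
double[] gradients = { 0.0, 0.0 };
const double LearningRate = 0.1;

// Forward pass: predictions for every sample in the batch.
double[] Forward(double[] xs)
{
    var ys = new double[xs.Length];
    for (int i = 0; i < xs.Length; i++) ys[i] = parameters[0] * xs[i] + parameters[1];
    return ys;
}

// Mean squared error; also writes dLoss/dPrediction into gy for the backward pass.
double Evaluate(double[] pred, double[] teach, double[] gy)
{
    double sum = 0.0;
    for (int i = 0; i < pred.Length; i++)
    {
        double diff = pred[i] - teach[i];
        sum += diff * diff;
        gy[i] = 2.0 * diff / pred.Length;
    }
    return sum / pred.Length;
}

// Backward pass: accumulate parameter gradients from the output gradient gy.
void Backward(double[] xs, double[] gy)
{
    for (int i = 0; i < xs.Length; i++)
    {
        gradients[0] += gy[i] * xs[i];
        gradients[1] += gy[i];
    }
}

// Optimizer step (SGD); this sketch also zeroes the gradients for the next step.
void Update()
{
    for (int k = 0; k < parameters.Length; k++)
    {
        parameters[k] -= LearningRate * gradients[k];
        gradients[k] = 0.0;
    }
}

// One call = one training step, in the same order as Train above.
double Train(double[] xs, double[] teach)
{
    double[] result = Forward(xs);
    var gy = new double[result.Length];
    double stepLoss = Evaluate(result, teach, gy);
    Backward(xs, gy);
    Update();
    return stepLoss;
}
```

Calling `Train` repeatedly on inputs {1, 2, 3, 4} with targets {3, 5, 7, 9} drives w toward 2 and b toward 1. As in the original method, the returned loss comes from the forward pass *before* that step's update. Where the real library clears accumulated gradients is not shown in the source.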
// Perform training on a batch
public static T Train<T, LabelType>(FunctionStack<T> functionStack, T[][] input, LabelType[][] teach, LossFunction<T, LabelType> lossFunction, Optimizer<T> optimizer = null)
    where T : unmanaged, IComparable<T>
    where LabelType : unmanaged, IComparable<LabelType>
{
    return Train(functionStack, NdArray.FromArrays(input), NdArray.FromArrays(teach), lossFunction, optimizer);
}
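The batch overload relies on `NdArray.FromArrays` to turn jagged arrays into a single batched array before delegating. The sketch below shows one plausible way such packing can work. It assumes every sample has the same length and that samples are stored row-major with the batch index outermost. `PackBatch` is a hypothetical helper; the source does not show how `FromArrays` actually lays out its data or shape.

```csharp
using System;

// Hypothetical helper (not the library's API): pack equal-length rows into one
// row-major buffer. Returns the flat data, the per-sample shape, and the batch count.
(T[] Data, int[] Shape, int BatchCount) PackBatch<T>(T[][] rows)
{
    if (rows.Length == 0) throw new ArgumentException("batch is empty");
    int width = rows[0].Length;
    var data = new T[rows.Length * width];
    for (int r = 0; r < rows.Length; r++)
    {
        // Every sample must have the same length, or the batch cannot be rectangular.
        if (rows[r].Length != width)
            throw new ArgumentException($"row {r} has length {rows[r].Length}, expected {width}");
        // Copy row r into its slice [r * width, (r + 1) * width).
        Array.Copy(rows[r], 0, data, r * width, width);
    }
    return (data, new[] { width }, rows.Length);
}
```

Under these assumptions both `input` and `teach` are packed the same way, and a ragged batch (rows of different lengths) is rejected rather than silently padded.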