/// <summary>
/// Runs one training step on a batch: forward pass, loss evaluation, backward pass and,
/// when <paramref name="isUpdate"/> is true, a call to <c>functionStack.Update()</c>.
/// </summary>
/// <param name="functionStack">The stack of functions being trained.</param>
/// <param name="input">The batch input handed to <c>Forward</c>.</param>
/// <param name="teach">The target values handed to the loss function.</param>
/// <param name="lossFunction">Loss function used to score the forward output.</param>
/// <param name="isUpdate">
/// When false, only the forward/backward passes run and <c>Update</c> is skipped
/// (presumably so gradients can be accumulated across calls — confirm).
/// </param>
/// <returns>The value returned by <c>lossFunction.Evaluate</c> for this batch.</returns>
public static Real Train(FunctionStack functionStack, NdArray input, NdArray teach, LossFunction lossFunction, bool isUpdate = true)
{
    NdArray[] predictions = functionStack.Forward(input);

    // Evaluate is called before Backward; presumably it seeds the output gradients
    // that Backward propagates, so this order must be kept.
    Real batchLoss = lossFunction.Evaluate(predictions, teach);

    functionStack.Backward(predictions);

    if (isUpdate)
    {
        functionStack.Update();
    }

    return batchLoss;
}
/// <summary>
/// Performs the learning process for a single batch: forward, loss, backward, and an
/// optional parameter update.
/// </summary>
/// <param name="functionStack">The function stack that is trained.</param>
/// <param name="input">Batch input for the forward pass.</param>
/// <param name="teach">Expected outputs compared against the forward result.</param>
/// <param name="lossFunction">The loss function that evaluates the forward result.</param>
/// <param name="isUpdate">If true (default), <c>functionStack.Update()</c> is invoked after back-propagation.</param>
/// <returns>The loss computed by <c>lossFunction.Evaluate</c>.</returns>
public static Real Train(FunctionStack functionStack, NdArray input, NdArray teach, LossFunction lossFunction, bool isUpdate = true)
{
    // Forward pass over the whole batch.
    NdArray[] outputs = functionStack.Forward(input);

    // Score the outputs; this happens before Backward, which then runs on the same outputs.
    Real loss = lossFunction.Evaluate(outputs, teach);

    // Back-propagate through the stack.
    functionStack.Backward(outputs);

    // Without an update request we are done once gradients have been propagated.
    if (!isUpdate)
    {
        return loss;
    }

    functionStack.Update();
    return loss;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
/// <summary>
/// Performs one training step on a batch: forward pass, loss evaluation, backward pass and,
/// optionally, a parameter update of the function stack.
/// </summary>
///
/// <param name="functionStack"> Stack of functions to train. This cannot be null. </param>
/// <param name="input">         The batch input passed to Forward. This may be null. </param>
/// <param name="teach">         The target values passed to the loss function. This may be null. </param>
/// <param name="lossFunction">  The loss function. This cannot be null. </param>
/// <param name="isUpdate">      (Optional) When true (default), functionStack.Update() is called after
///                              the backward pass. </param>
/// <param name="verbose">       (Optional) When true (default), traces the steps through
///                              RILogManager.Default and passes the flag on to Forward/Backward. </param>
///
/// <returns> The loss value returned by lossFunction.Evaluate for this batch. </returns>
///
/// <exception cref="System.ArgumentNullException"> functionStack or lossFunction is null. </exception>
////////////////////////////////////////////////////////////////////////////////////////////////////
public static Real Train([NotNull] SortedFunctionStack functionStack, [CanBeNull] NdArray input, [CanBeNull] NdArray teach, [NotNull] LossFunction lossFunction, bool isUpdate = true, bool verbose = true)
{
    // Enforce the [NotNull] contract up front, so a null argument fails before any forward work is done.
    if (functionStack == null)
    {
        throw new System.ArgumentNullException(nameof(functionStack));
    }

    if (lossFunction == null)
    {
        throw new System.ArgumentNullException(nameof(lossFunction));
    }

    // Built once, so EnterMethod and ExitMethod always receive the same label.
    string traceLabel = verbose ? "Training " + functionStack.Name : null;

    if (verbose)
    {
        RILogManager.Default?.EnterMethod(traceLabel);
    }

    Real sumLoss;
    try
    {
        if (verbose)
        {
            RILogManager.Default?.SendDebug("Forward propagation");
        }

        NdArray[] result = functionStack.Forward(verbose, input);

        if (verbose)
        {
            RILogManager.Default?.SendDebug("Evaluating loss");
        }

        // Evaluate runs before Backward; presumably it also seeds the output gradients
        // that Backward consumes, so keep this order.
        sumLoss = lossFunction.Evaluate(result, teach);

        if (verbose)
        {
            RILogManager.Default?.SendDebug("Backward propagation");
        }

        functionStack.Backward(verbose, result);

        if (isUpdate)
        {
            if (verbose)
            {
                RILogManager.Default?.SendDebug("Updating stack");
            }

            functionStack.Update();
        }
    }
    finally
    {
        // Match every EnterMethod with an ExitMethod, even when one of the steps above throws.
        if (verbose)
        {
            RILogManager.Default?.ExitMethod(traceLabel);
        }
    }

    if (verbose)
    {
        RILogManager.Default?.ViewerSendWatch("Local Loss", sumLoss.ToString(), sumLoss);
    }

    return sumLoss;
}