Exemplo n.º 1
0
        /// <summary>
        /// Runs one training step on a batch: forward pass, loss evaluation,
        /// backward pass and, optionally, a call to <c>functionStack.Update()</c>.
        /// </summary>
        /// <param name="functionStack">The function stack being trained.</param>
        /// <param name="input">The batch fed to <c>functionStack.Forward</c>.</param>
        /// <param name="teach">The target values passed to <paramref name="lossFunction"/>.</param>
        /// <param name="lossFunction">Loss used to score the forward-pass output.</param>
        /// <param name="isUpdate">When true, <c>functionStack.Update()</c> is called after the backward pass.</param>
        /// <returns>The loss value returned by <paramref name="lossFunction"/>.</returns>
        public static Real Train(FunctionStack functionStack, NdArray input, NdArray teach, LossFunction lossFunction, bool isUpdate = true)
        {
            // Forward pass, then score its output against the targets.
            NdArray[] outputs = functionStack.Forward(input);
            Real loss = lossFunction.Evaluate(outputs, teach);

            // Backward pass over the same output arrays the loss was evaluated on.
            functionStack.Backward(outputs);

            if (!isUpdate)
            {
                return loss;
            }

            functionStack.Update();
            return loss;
        }
Exemplo n.º 2
0
        /// <summary>
        /// Performs one learning step on a batch and reports its loss.
        /// </summary>
        /// <remarks>
        /// Order of operations: <c>Forward</c>, loss <c>Evaluate</c>, <c>Backward</c>,
        /// then <c>Update</c> only when <paramref name="isUpdate"/> is true.
        /// </remarks>
        /// <param name="functionStack">The function stack being trained.</param>
        /// <param name="input">The batch fed to <c>functionStack.Forward</c>.</param>
        /// <param name="teach">The target values passed to <paramref name="lossFunction"/>.</param>
        /// <param name="lossFunction">Loss used to score the forward-pass output.</param>
        /// <param name="isUpdate">When true, <c>functionStack.Update()</c> is called at the end.</param>
        /// <returns>The loss value returned by <paramref name="lossFunction"/>.</returns>
        public static Real Train(FunctionStack functionStack, NdArray input, NdArray teach, LossFunction lossFunction, bool isUpdate = true)
        {
            NdArray[] forwardResult = functionStack.Forward(input);
            Real batchLoss = lossFunction.Evaluate(forwardResult, teach);
            functionStack.Backward(forwardResult);

            // Update() is only invoked when the caller asks for it.
            bool shouldUpdate = isUpdate;
            if (shouldUpdate)
            {
                functionStack.Update();
            }

            return batchLoss;
        }
Exemplo n.º 3
0
        ////////////////////////////////////////////////////////////////////////////////////////////////////
        /// <summary>
        ///     Do a learning process with a batch: forward pass, loss evaluation, backward pass and,
        ///     optionally, a call to <c>functionStack.Update()</c>.
        /// </summary>
        ///
        /// <remarks>
        ///     When <paramref name="verbose"/> is true, progress is written to
        ///     <c>RILogManager.Default</c> (if it is non-null). The EnterMethod/ExitMethod pair is
        ///     always balanced, even if a training step throws, so the log's method nesting stays
        ///     consistent.
        /// </remarks>
        ///
        /// <exception cref="System.ArgumentNullException">
        ///     <paramref name="functionStack"/> or <paramref name="lossFunction"/> is null.
        /// </exception>
        ///
        /// <param name="functionStack">    Stack of functions. This cannot be null. </param>
        /// <param name="input">            The input batch passed to the forward pass. This may be null. </param>
        /// <param name="teach">            The target values passed to the loss function. This may be null. </param>
        /// <param name="lossFunction">     The loss function. This cannot be null. </param>
        /// <param name="isUpdate">         (Optional) True to call <c>functionStack.Update()</c> after the backward pass. </param>
        /// <param name="verbose">          (Optional) True to log each training step; also forwarded to Forward/Backward. </param>
        ///
        /// <returns>   The loss value returned by <paramref name="lossFunction"/>. </returns>
        ////////////////////////////////////////////////////////////////////////////////////////////////////

        public static Real Train([NotNull] SortedFunctionStack functionStack, [CanBeNull] NdArray input,
                                 [CanBeNull] NdArray teach, [NotNull] LossFunction lossFunction, bool isUpdate = true,
                                 bool verbose = true)
        {
            // Fail fast, before any forward pass work, instead of an NRE midway through training.
            if (functionStack == null)
            {
                throw new System.ArgumentNullException(nameof(functionStack));
            }
            if (lossFunction == null)
            {
                throw new System.ArgumentNullException(nameof(lossFunction));
            }

            // Resolve the logger once; a null logger (verbose off, or no default logger) disables tracing.
            var log = verbose ? RILogManager.Default : null;
            string scope = log != null ? "Training " + functionStack.Name : null;

            log?.EnterMethod(scope);

            Real sumLoss;
            try
            {
                log?.SendDebug("Forward propagation");
                NdArray[] result = functionStack.Forward(verbose, input);

                log?.SendDebug("Evaluating loss");
                sumLoss = lossFunction.Evaluate(result, teach);

                // Run Backward batch
                log?.SendDebug("Backward propagation");
                functionStack.Backward(verbose, result);

                if (isUpdate)
                {
                    log?.SendDebug("Updating stack");
                    functionStack.Update();
                }
            }
            finally
            {
                // Always pair with EnterMethod so the trace nesting is not left open on an exception.
                log?.ExitMethod(scope);
            }

            log?.ViewerSendWatch("Local Loss", sumLoss.ToString(), sumLoss);

            return sumLoss;
        }