protected override void EvaluateCrossValidationSet(IEnumerable<BatchInputWrapper> cvData)
        {
            // Runs every cross-validation batch through Predict, lets
            // LossFunctions.DqnStanfordEvaluation tally reward sums/counts for the
            // "match" and "non-match" buckets, and prints the two overall averages.
            // NOTE(review): if a bucket never receives a sample its count stays 0 and the
            // float/int division below yields NaN or Infinity, which is printed as-is.
            float totalMatchReward = 0;
            float totalNonMatchReward = 0;
            int totalMatchCount = 0;
            int totalNonMatchCount = 0;

            foreach (var batch in cvData)
            {
                // Tuple layout: Item1 = state, Item2 = chosen action indices, Item3 = rewards.
                var states  = batch.StateActionRewardNewStateIsLastEpisodeTuple.Item1;
                var actions = batch.StateActionRewardNewStateIsLastEpisodeTuple.Item2;
                var rewards = batch.StateActionRewardNewStateIsLastEpisodeTuple.Item3;

                this.VerifyInputVectorDimentionality(states, batch.BatchSize);
                this.CheckAndResizeNetwork(batch.BatchSize);

                Matrix qValues = this.Predict(states);

                // Per-batch outputs filled in by DqnStanfordEvaluation via ref.
                float batchMatchReward = 0, batchNonMatchReward = 0;
                int batchMatchCount = 0, batchNonMatchCount = 0;

                // The temporary action/reward matrices are disposed as soon as the evaluation returns.
                using (Matrix actionMatrix = this.mathManager.CreateMatrix(actions, actions.Count()))
                using (Matrix rewardMatrix = this.mathManager.CreateMatrix(rewards, rewards.Count()))
                {
                    LossFunctions.DqnStanfordEvaluation(
                        this.mathManager, qValues, actionMatrix, rewardMatrix,
                        ref batchMatchReward, ref batchMatchCount,
                        ref batchNonMatchReward, ref batchNonMatchCount);
                }

                totalMatchReward    += batchMatchReward;
                totalMatchCount     += batchMatchCount;
                totalNonMatchReward += batchNonMatchReward;
                totalNonMatchCount  += batchNonMatchCount;
            }

            float averageMatchReward = totalMatchReward / totalMatchCount;
            Console.WriteLine("Average predict match reward:{0}", averageMatchReward);

            float averageNonMatchReward = totalNonMatchReward / totalNonMatchCount;
            Console.WriteLine("Average predict non-match reward:{0}", averageNonMatchReward);
        }
 /// <summary>
 /// Wraps the chosen action indices, rewards and last-episode flags in temporary
 /// <c>Matrix</c> objects and hands them to <c>LossFunctions.BellmanLossAndDerivative</c>
 /// together with the output layer's <c>ErrorGradient</c> and the configured discount factor.
 /// The temporary matrices are disposed when the call returns.
 /// </summary>
 /// <param name="predictedQValues">Q-values predicted by the network.</param>
 /// <param name="QHatValues">Second set of Q-values passed to the Bellman loss (presumably target-network estimates — confirm).</param>
 /// <param name="chosenActionIndices">Index of the action taken for each sample.</param>
 /// <param name="currentRewards">Reward observed for each sample.</param>
 /// <param name="errorAvg">Passed by ref to the loss function, which presumably stores the average Bellman error in it.</param>
 /// <param name="isLastEpisode">Per-sample episode-end flags (presumably 1 = terminal — confirm against callers).</param>
 private void GetBellmanErrorAndDerivative(Matrix predictedQValues, Matrix QHatValues, float[] chosenActionIndices, float[] currentRewards, ref float errorAvg, float[] isLastEpisode)
 {
     using (Matrix actionMatrix = this.mathManager.CreateMatrix(chosenActionIndices, chosenActionIndices.Count()))
     using (Matrix rewardMatrix = this.mathManager.CreateMatrix(currentRewards, currentRewards.Count()))
     using (Matrix terminalMatrix = this.mathManager.CreateMatrix(isLastEpisode, isLastEpisode.Count()))
     {
         // The derivative is written into the last layer's gradient buffer.
         var outputGradient = this.neuralLayers.Last().ErrorGradient;
         var discount = this.DQNConfiguration.DiscountFactor;

         LossFunctions.BellmanLossAndDerivative(
             this.mathManager, predictedQValues, QHatValues, actionMatrix, rewardMatrix,
             ref errorAvg, outputGradient, discount, terminalMatrix);
     }
 }
 /// <summary>
 /// Computes the average loss for the configured loss function and writes its derivative
 /// into the output layer's <c>ErrorGradient</c>.
 /// </summary>
 /// <param name="preSigmoidScores">Output-layer scores passed to the loss function.</param>
 /// <param name="trueLabels">Target labels, laid out as a matrix with the output layer's Row x Column dimensions.</param>
 /// <param name="errorAvg">Passed by ref to the loss function, which presumably stores the average error in it.</param>
 /// <exception cref="NotSupportedException">
 /// The configured <c>LossFunction</c> has no implementation here. Previously such values were
 /// silently ignored, leaving <paramref name="errorAvg"/> and the gradient stale.
 /// </exception>
 private void GetErrorAndDerivative(Matrix preSigmoidScores, float[] trueLabels, ref float errorAvg)
 {
     var outputLayer = this.neuralLayers.Last();

     switch (this.configuration.LossFunction)
     {
     case LossFunctionType.CrossEntropyError:
         // The label matrix is only needed (and allocated) for a supported loss; disposed right after use.
         using (Matrix trueLabelsMatrix = this.mathManager.CreateMatrix(trueLabels, outputLayer.Data.Row, outputLayer.Data.Column))
         {
             LossFunctions.CrossEntropyErrorAndDerivative(this.mathManager, preSigmoidScores, trueLabelsMatrix, outputLayer.ErrorGradient, ref errorAvg);
         }
         break;

     default:
         // Fail loudly instead of back-propagating a gradient that was never computed.
         throw new NotSupportedException(
             "Loss function " + this.configuration.LossFunction + " is not supported by GetErrorAndDerivative.");
     }
 }
Esempio n. 4
0
        /// <summary>
        /// Creates a new loss-function instance for the given identifier.
        /// </summary>
        /// <param name="loss">Which loss function to instantiate.</param>
        /// <returns>
        /// A fresh <c>MSE</c>, <c>MLE</c> or <c>CE</c> instance; <c>null</c> for any other value.
        /// </returns>
        public static LossFunction GetLossFunction(LossFunctions loss)
        {
            if (loss == LossFunctions.MSE)
            {
                return new MSE();
            }

            if (loss == LossFunctions.MLE)
            {
                return new MLE();
            }

            if (loss == LossFunctions.CE)
            {
                return new CE();
            }

            // Unrecognized identifiers yield null rather than throwing; callers must check the result.
            return null;
        }
Esempio n. 5
0
        /// <summary>
        /// Reports whether an output activation and a loss function form one of the three
        /// pairings treated as canonical here: Sigmoid/MLE, IdentityFunction/MSE, SoftMax/CE.
        /// </summary>
        /// <param name="logistic">Output-layer activation.</param>
        /// <param name="loss">Loss function paired with that activation.</param>
        /// <returns><c>true</c> only for the three listed pairings; <c>false</c> otherwise, including unknown activations.</returns>
        public static bool IsCanonicalLink(LogisticFunctions logistic, LossFunctions loss)
        {
            // NOTE(review): presumably callers use this to skip the activation derivative when
            // computing the output delta — confirm against call sites.
            return (logistic == LogisticFunctions.Sigmoid && loss == LossFunctions.MLE)
                || (logistic == LogisticFunctions.IdentityFunction && loss == LossFunctions.MSE)
                || (logistic == LogisticFunctions.SoftMax && loss == LossFunctions.CE);
        }