Пример #1
0
        public void MatchReferenceImplementation()
        {
            // Replays the reference vectors in Ebisu/test.json (obtained from the
            // ebisu python repository) and checks that UpdateRecall/PredictRecall
            // agree with them within Tolerance.
            //
            // Both the file stream and the parsed document are disposed at the end of
            // the method; the stream used to be leaked.
            using var stream   = File.OpenRead("Ebisu/test.json");
            using var testData = JsonDocument.Parse(stream);

            // Format of the test.json:
            // Array of test cases, where each test case is one of the following
            //  ["update", [a, b, t0], [k, n, t], {"post": [a, b, t]}]
            //  ["predict", [a, b, t0], [t], {"post": [mean]}]
            foreach (var test in testData.RootElement.EnumerateArray())
            {
                var data      = test.EnumerateArray().ToArray();
                var operation = data[0].GetString();
                var modelData = data[1].EnumerateArray().ToArray();

                // The JSON stores [a, b, t] while EbisuModel takes (time, alpha, beta).
                var model = new EbisuModel(
                    modelData[2].GetDouble(),
                    modelData[0].GetDouble(),
                    modelData[1].GetDouble());

                switch (operation)
                {
                case "update":
                {
                    var paramData    = data[2].EnumerateArray().ToArray();
                    var successes    = paramData[0].GetInt32();
                    var total        = paramData[1].GetInt32();
                    var time         = paramData[2].GetDouble();
                    var expectedData = data[3].EnumerateObject().First()
                                       .Value.EnumerateArray().ToArray();
                    var expected = new EbisuModel(
                        expectedData[2].GetDouble(),
                        expectedData[0].GetDouble(),
                        expectedData[1].GetDouble());

                    var updatedModel = model.UpdateRecall(successes, total, time);

                    Assert.AreEqual(expected.Time, updatedModel.Time, Tolerance, $"Test: {test}");
                    Assert.AreEqual(expected.Alpha, updatedModel.Alpha, Tolerance, $"Test: {test}");
                    Assert.AreEqual(expected.Beta, updatedModel.Beta, Tolerance, $"Test: {test}");
                    break;
                }

                case "predict":
                {
                    var elapsed        = data[2].EnumerateArray().First().GetDouble();
                    var expectedRecall = data[3].EnumerateObject().First().Value
                                         .GetDouble();

                    // exact: true => compare linear-domain probabilities, not logs.
                    var predictRecall = model.PredictRecall(elapsed, true);

                    Assert.AreEqual(expectedRecall, predictRecall, Tolerance, $"Test: {test}");
                    break;
                }

                default:
                    Assert.Fail("Reference data has invalid operation.");
                    break;
                }
            }
        }
Пример #2
0
        public void UpdateRecall()
        {
            // A model with time = alpha = beta = 2, quizzed once exactly at its own
            // time, is expected to gain one on alpha after a pass and one on beta
            // after a failure.
            var balanced = new EbisuModel(2, 2, 2);

            var afterPass = balanced.UpdateRecall(1, 1, 2.0);
            var afterFail = balanced.UpdateRecall(0, 1, 2.0);

            Assert.AreEqual(3.0, afterPass.Alpha, Epsilon, "success/alpha");
            Assert.AreEqual(2.0, afterPass.Beta, Epsilon, "success/beta");

            Assert.AreEqual(2.0, afterFail.Alpha, Epsilon, "failure/alpha");
            Assert.AreEqual(3.0, afterFail.Beta, Epsilon, "failure/beta");
        }
Пример #3
0
        public void ShouldReturnExactProbabilityOfRecallAfterDuration(
            double alpha,
            double beta,
            double time,
            double duration,
            double expectedRecall)
        {
            // Builds the model from the data-row parameters and checks the
            // linear-domain (exact) recall probability after `duration`.
            double actualRecall = new EbisuModel(time, alpha, beta)
                                  .PredictRecall(duration, exact: true);

            Assert.AreEqual(expectedRecall, actualRecall, Tolerance);
        }
Пример #4
0
 /// <summary>
 /// Update a prior model with the outcome of new reviews and return the posterior
 /// model, expressed at the prior's own <c>Time</c> (or at a rebalanced time, see remarks).
 ///
 /// <paramref name="prior"/> is the current belief about remembrance of the fact.
 /// The posterior incorporates <paramref name="successes"/> successful recalls out of
 /// <paramref name="total"/> review attempts made <paramref name="timeNow"/> time units
 /// after the last review.
 /// </summary>
 /// <param name="prior">Existing model representing the beta distribution for a fact.</param>
 /// <param name="successes">Number of successful reviews for the fact.</param>
 /// <param name="total">Number of total reviews for the fact.</param>
 /// <param name="timeNow">Elapsed time units since last review was recorded.</param>
 /// <returns>Updated model for the fact.</returns>
 /// <remarks>This overload forwards to the six-argument <c>UpdateRecall</c> with
 /// <c>rebalance: true</c> and <c>timeBack: prior.Time</c>, so the returned model may be
 /// re-expressed near its half life. Call that overload directly to change either setting.</remarks>
 public static EbisuModel UpdateRecall(
     this EbisuModel prior,
     int successes,
     int total,
     double timeNow) =>
     prior.UpdateRecall(successes, total, timeNow, rebalance: true, timeBack: prior.Time);
Пример #5
0
        [DataRow(new[] { 2.0, 2.0, 2.0 }, 0, 1, 2.0, new[] { 2.0, 3.0, 0.0 })] // fail recalls
        public void ShouldReturnModelWithUpdatedParameters(
            double[] priorParams,
            int successes,
            int total,
            double duration,
            double[] expectedParams)
        {
            // Parameter arrays are laid out as [alpha, beta, time]; EbisuModel takes
            // (time, alpha, beta). Only the shape parameters are asserted; the time
            // entry of expectedParams is not checked.
            var prior    = new EbisuModel(priorParams[2], priorParams[0], priorParams[1]);
            var expected = new EbisuModel(expectedParams[2], expectedParams[0], expectedParams[1]);

            var actual = prior.UpdateRecall(successes, total, duration);

            Assert.AreEqual(expected.Alpha, actual.Alpha, Tolerance, "alpha");
            // Beta was previously parsed from the data row but never verified.
            Assert.AreEqual(expected.Beta, actual.Beta, Tolerance, "beta");
        }
Пример #6
0
        /// <summary>
        /// Estimate the recall probability of an existing model after
        /// <paramref name="timeNow"/> time units have elapsed since the last review.
        /// </summary>
        /// <param name="prior">Existing ebisu model.</param>
        /// <param name="timeNow">Time elapsed since last review.</param>
        /// <param name="exact">
        /// If true, return the recall probability itself; if false (default), return its
        /// natural logarithm.
        /// </param>
        /// <returns>
        /// The expected recall probability when <paramref name="exact"/> is true, otherwise
        /// the natural log of that probability.
        /// </returns>
        public static double PredictRecall(
            this EbisuModel prior,
            double timeNow,
            bool exact = false)
        {
            // Elapsed time measured in units of the model's own time.
            double elapsedRatio = timeNow / prior.Time;

            // Ebisu models recall as a GB1 distribution whose expected value is
            // B(alpha + dt, beta) / B(alpha, beta), with B the beta function.
            // Work in the log domain, where the ratio becomes a difference.
            // See https://en.wikipedia.org/wiki/Generalized_beta_distribution#Generalized_beta_of_first_kind_(GB1)
            // and https://fasiha.github.io/ebisu/ (Recall probability right now).
            double logNumerator   = BetaLn(prior.Alpha + elapsedRatio, prior.Beta);
            double logDenominator = BetaLn(prior.Alpha, prior.Beta);
            double logRecall      = logNumerator - logDenominator;

            if (exact)
            {
                return Math.Exp(logRecall);
            }

            return logRecall;
        }
Пример #7
0
        /// <summary>
        /// Rebalance a proposed posterior model so that its <c>Alpha</c> and
        /// <c>Beta</c> shape parameters stay within a factor of two of each other.
        /// If either parameter exceeds twice the other, the update is recomputed from
        /// <paramref name="prior"/> at a coarse estimate of the proposed model's half
        /// life, where the shape parameters are expected to be closer (for numerical
        /// stability). Otherwise <paramref name="proposed"/> is returned unchanged.
        /// </summary>
        /// <param name="prior">Existing memory model.</param>
        /// <param name="successes">Count of successful reviews.</param>
        /// <param name="total">Count of total number of reviews.</param>
        /// <param name="timeNow">Duration since last review.</param>
        /// <param name="proposed">Proposed memory model.</param>
        /// <returns>Either <paramref name="proposed"/> or a model re-expressed near its half life.</returns>
        private static EbisuModel Rebalance(
            this EbisuModel prior,
            int successes,
            int total,
            double timeNow,
            EbisuModel proposed)
        {
            bool lopsided = proposed.Alpha > 2 * proposed.Beta ||
                            proposed.Beta > 2 * proposed.Alpha;

            if (!lopsided)
            {
                return proposed;
            }

            // Coarse estimate of the elapsed time at which the proposed model's
            // recall probability falls to 0.5, i.e. its half life.
            double halfLife = proposed.ModelToPercentileDecay(0.5, coarse: true, tolerance: 1e-4);

            return prior.UpdateRecall(successes, total, timeNow, rebalance: false, timeBack: halfLife);
        }
Пример #8
0
        public void BetaLnFunctionForZeroSuccesses()
        {
            // Zero successes out of five reviews: every term of UpdateRecall's
            // alternating-sign sum over [0, failures] is exercised. The expected Time
            // differs from the prior's (1.0), so the rebalanced result is what is checked.
            //
            // A disabled `#if NONE` scratch comparison of two BetaLn implementations used
            // to live here; it was stale (it indexed LogSumExp(...)[0]) and has been removed.
            var delta = 0.05;
            var prior = new EbisuModel(1.0, 34.4, 3.4);

            var updatedModel = prior.UpdateRecall(0, 5, 0.1);

            Assert.AreEqual(3.0652051705190964, updatedModel.Time, delta);
            Assert.AreEqual(8.706432410647471, updatedModel.Beta, delta);
            Assert.AreEqual(8.760308130181903, updatedModel.Alpha, delta);
        }
Пример #9
0
        /// <summary>
        /// Compute the time duration for a <see cref="EbisuModel"/> to decay to
        /// a given percentile of recall probability.
        /// </summary>
        /// <param name="model">Given model for the fact.</param>
        /// <param name="percentile">Target recall probability; must lie strictly between 0 and 1.</param>
        /// <param name="coarse">
        /// If true, skip root refinement and return the average of the bracket endpoints
        /// (bracket width 1 in log-time), converted back to time units.
        /// </param>
        /// <param name="tolerance">
        /// Allowed tolerance for the duration.
        /// NOTE(review): currently not used; the root finder runs with its default accuracy.
        /// </param>
        /// <returns>Duration in time units (of provided model) for the decay to given percentile.</returns>
        /// <exception cref="ArgumentException"><paramref name="percentile"/> is not in (0, 1), or is NaN.</exception>
        /// <exception cref="EbisuConstraintViolationException">A sign change of the objective could not be bracketed.</exception>
        private static double ModelToPercentileDecay(
            this EbisuModel model,
            double percentile,
            bool coarse,
            double tolerance)
        {
            // The range is exclusive, as the message states: percentile 0 would make
            // log(percentile) = -inf. The negated form also rejects NaN.
            if (!(percentile > 0 && percentile < 1))
            {
                throw new ArgumentException(
                          "Percentiles must be between (0, 1) exclusive",
                          nameof(percentile));
            }

            double alpha = model.Alpha;
            double beta  = model.Beta;
            double t0    = model.Time;

            // f(ln delta) = log(expected recall after delta * t0) - log(percentile);
            // its root is the log of the decay time in units of t0.
            double logBab           = BetaLn(alpha, beta);
            double logPercentile    = Math.Log(percentile);
            Func <double, double> f = lndelta =>
            {
                return((BetaLn(alpha + Math.Exp(lndelta), beta) - logBab) -
                       logPercentile);
            };

            double bracketWidth = coarse ? 1.0 : 6.0;
            double blow         = -bracketWidth / 2.0;
            double bhigh        = bracketWidth / 2.0;
            double flow         = f(blow);
            double fhigh        = f(bhigh);

            while (flow > 0 && fhigh > 0)
            {
                // Move the bracket up.
                blow   = bhigh;
                flow   = fhigh;
                bhigh += bracketWidth;
                fhigh  = f(bhigh);
            }

            while (flow < 0 && fhigh < 0)
            {
                // Move the bracket down.
                bhigh = blow;
                fhigh = flow;
                blow -= bracketWidth;
                flow  = f(blow);
            }

            if (!(flow > 0 && fhigh < 0))
            {
                throw new EbisuConstraintViolationException($"Failed to bracket: flow={flow}, fhigh={fhigh}");
            }

            if (coarse)
            {
                return((Math.Exp(blow) + Math.Exp(bhigh)) / 2 * t0);
            }

            // Similar to the `root_scalar` api with bracketing
            // See https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.root_scalar.html#scipy.optimize.root_scalar
            var sol = Brent.FindRoot(f, blow, bhigh);

            return(Math.Exp(sol) * t0);
        }
Пример #10
0
        /// <summary>
        /// Update the parameters of a prior model with new observations and return
        /// an updated model with posterior distribution of recall probability at
        /// <paramref name="timeBack"/> time units after review.
        ///
        /// <paramref name="prior"/> is the given belief about remembrance of the fact. We
        /// attempt to calculate the posterior given additional data i.e. <paramref name="successes"/>
        /// indicating the successful recalls in <paramref name="total"/> review attempts in
        /// <paramref name="timeNow"/> duration since last review.
        /// </summary>
        /// <param name="prior">Existing model representing the beta distribution for a fact.</param>
        /// <param name="successes">Number of successful reviews for the fact.</param>
        /// <param name="total">Number of total reviews for the fact.</param>
        /// <param name="timeNow">Elapsed time units since last review was recorded.</param>
        /// <param name="rebalance">If true, the proposed posterior is passed through <c>Rebalance</c>, which may re-express it near its half life.</param>
        /// <param name="timeBack">Time stamp for calculating recall in the updated model.</param>
        /// <returns>Updated model for the fact.</returns>
        /// <exception cref="ArgumentException">
        /// <paramref name="successes"/> is negative or greater than <paramref name="total"/>,
        /// or <paramref name="total"/> is less than one.
        /// </exception>
        /// <exception cref="EbisuConstraintViolationException">
        /// The posterior mean, second moment or variance is not strictly positive (NaN included).
        /// </exception>
        /// <remarks>
        /// Each review of the fact can be modelled as a binomial experiment with
        /// `k` successes in `n` trials. These are represented as the `successes` and
        /// `total` variables here. If `total` is 1, this is a bernoulli experiment.
        /// Second, we're assuming the experiments (reviews) to be independent. They're not
        /// independent if the app shows a hint to the user, obviously the next review is biased.
        ///
        /// Given the `prior` recall probability and the results of new experiments, what is the
        /// `posterior` recall probability? Note we're being bayesian, and asking hard questions.
        /// </remarks>
        public static EbisuModel UpdateRecall(
            this EbisuModel prior,
            int successes,
            int total,
            double timeNow,
            bool rebalance,
            double timeBack)
        {
            // successes == total is valid (all reviews passed); only values above total are rejected.
            if (successes < 0 || successes > total)
            {
                throw new ArgumentException(
                          "Successes must be non-negative and not greater than Total.",
                          nameof(successes));
            }

            if (total < 1)
            {
                throw new ArgumentException(
                          "Total experiments must be one or more.",
                          nameof(total));
            }

            // See https://fasiha.github.io/ebisu/ (Updating the posterior with quiz results)
            // section for detailed derivation.
            double alpha    = prior.Alpha;
            double beta     = prior.Beta;
            double t        = prior.Time;
            double dt       = timeNow / t;
            double et       = timeBack / timeNow;
            var    failures = total - successes;

            // Most of the calculations are summations over the range [0, failures]
            var binomlns = Enumerable.Range(0, failures + 1)
                           .Select(i => BinomialLn(failures, i)).ToArray();

            // Log of the m-th (unnormalized) moment: an alternating-sign sum over
            // i in [0, failures] of binomial and beta-function terms.
            // NOTE(review): assumes LogSumExp(a, b) computes log(sum(b * exp(a))),
            // like scipy.special.logsumexp with weights — confirm against the helper.
            double LogMoment(int m)
            {
                var a =
                    Enumerable.Range(0, failures + 1)
                    .Select(i => binomlns[i] + BetaLn(
                                beta,
                                alpha + (dt * (successes + i)) + (m * dt * et)))
                    .ToList();
                var b = Enumerable.Range(0, failures + 1)
                        .Select(i => Math.Pow(-1.0, i))
                        .ToList();
                return(LogSumExp(a, b).Value);
            }

            double logDenominator = LogMoment(0);
            double logMeanNum     = LogMoment(1);
            double logM2Num       = LogMoment(2);

            double mean   = Math.Exp(logMeanNum - logDenominator);
            double m2     = Math.Exp(logM2Num - logDenominator);
            double meanSq = Math.Exp(2 * (logMeanNum - logDenominator));
            double sig2   = m2 - meanSq;

            // NaN compares false with everything, so `x <= 0` would let a NaN moment
            // reach MeanVarToBeta; `!(x > 0)` rejects it as well.
            if (!(mean > 0))
            {
                throw new EbisuConstraintViolationException($"Invalid mean found: a={alpha}, b={beta}, t={t}, k={successes}, n={total}, tnow={timeNow}, mean={mean}, m2={m2}, sig2={sig2}");
            }

            if (!(m2 > 0))
            {
                throw new EbisuConstraintViolationException($"Invalid second moment found: a={alpha}, b={beta}, t={t}, k={successes}, n={total}, tnow={timeNow}, mean={mean}, m2={m2}, sig2={sig2}");
            }

            if (!(sig2 > 0))
            {
                throw new EbisuConstraintViolationException(
                          $"Invalid variance found: a={alpha}, b={beta}, t={t}, k={successes}, n={total}, tnow={timeNow}, mean={mean}, m2={m2}, sig2={sig2}");
            }

            // Compute the Beta function from mean and variance
            // See https://en.wikipedia.org/wiki/Beta_distribution#Mean_and_variance
            var(newAlpha, newBeta) = MeanVarToBeta(mean, sig2);
            var proposed = new EbisuModel(timeBack, newAlpha, newBeta);

            return(rebalance ? prior.Rebalance(successes, total, timeNow, proposed) : proposed);
        }