public void MatchReferenceImplementation()
{
    // BUG FIX: the stream was previously opened without `using`.
    // JsonDocument.Parse(Stream) does not take ownership of the stream,
    // so the file handle leaked; dispose it with a using declaration.
    using var stream = File.OpenRead("Ebisu/test.json");
    using var testData = JsonDocument.Parse(stream);

    // Format of the test.json:
    // Array of test cases, where each test case is one of the following
    //   ["update", [a, b, t0], [k, n, t], {"post": [a, b, t]}]
    //   ["predict", [a, b, t0], [t], {"post": [mean]}]
    // test.json can be obtained from the ebisu python repository.
    foreach (var test in testData.RootElement.EnumerateArray())
    {
        var data = test.EnumerateArray().ToArray();
        var operation = data[0].GetString();
        var modelData = data[1].EnumerateArray().ToArray();

        // Reference data stores parameters as [alpha, beta, time];
        // the EbisuModel constructor takes (time, alpha, beta).
        var model = new EbisuModel(
            modelData[2].GetDouble(),
            modelData[0].GetDouble(),
            modelData[1].GetDouble());
        switch (operation)
        {
            case "update":
                var paramData = data[2].EnumerateArray().ToArray();
                var successes = paramData[0].GetInt32();
                var total = paramData[1].GetInt32();
                var time = paramData[2].GetDouble();
                var expectedData = data[3].EnumerateObject().First()
                    .Value.EnumerateArray().ToArray();
                var expected = new EbisuModel(
                    expectedData[2].GetDouble(),
                    expectedData[0].GetDouble(),
                    expectedData[1].GetDouble());
                var updatedModel = model.UpdateRecall(successes, total, time);
                Assert.AreEqual(expected.Time, updatedModel.Time, Tolerance, $"Test: {test}");
                Assert.AreEqual(expected.Alpha, updatedModel.Alpha, Tolerance, $"Test: {test}");
                Assert.AreEqual(expected.Beta, updatedModel.Beta, Tolerance, $"Test: {test}");
                break;
            case "predict":
                time = data[2].EnumerateArray().First().GetDouble();
                var expectedRecall = data[3].EnumerateObject().First().Value
                    .GetDouble();
                var predictRecall = model.PredictRecall(time, true);
                // Include the failing case in the message, consistent with the
                // "update" branch above, so a mismatch identifies its test case.
                Assert.AreEqual(expectedRecall, predictRecall, Tolerance, $"Test: {test}");
                break;
            default:
                Assert.Fail("Reference data has invalid operation.");
                break;
        }
    }
}
public void UpdateRecall()
{
    // Arrange: a symmetric prior (alpha = beta = 2) reviewed exactly at its
    // half life (time = 2).
    var prior = new EbisuModel(2, 2, 2);

    // Act: one successful and one failed bernoulli review after 2 time units.
    var afterSuccess = prior.UpdateRecall(1, 1, 2.0);
    var afterFailure = prior.UpdateRecall(0, 1, 2.0);

    // Assert: a success bumps alpha, a failure bumps beta, and the other
    // shape parameter stays unchanged.
    Assert.AreEqual(3.0, afterSuccess.Alpha, Epsilon, "success/alpha");
    Assert.AreEqual(2.0, afterSuccess.Beta, Epsilon, "success/beta");
    Assert.AreEqual(2.0, afterFailure.Alpha, Epsilon, "failure/alpha");
    Assert.AreEqual(3.0, afterFailure.Beta, Epsilon, "failure/beta");
}
public void ShouldReturnExactProbabilityOfRecallAfterDuration(
    double alpha, double beta, double time, double duration, double expectedRecall)
{
    // Arrange: build the prior belief from the supplied parameters.
    var model = new EbisuModel(time, alpha, beta);

    // Act: request the exact (non-log) recall probability after `duration`.
    var actualRecall = model.PredictRecall(duration, exact: true);

    // Assert
    Assert.AreEqual(expectedRecall, actualRecall, Tolerance);
}
/// <summary>
/// Update the parameters of a prior model with new observations and return
/// an updated model with posterior distribution of recall probability at
/// <paramref name="timeNow"/> time units after review.
///
/// <paramref name="prior"/> is the given belief about remembrance of the fact. We
/// attempt to calculate the posterior given additional data i.e. <paramref name="successes"/>
/// indicating the successful recalls in <paramref name="total"/> review attempts in
/// <paramref name="timeNow"/> duration since last review.
/// </summary>
/// <param name="prior">Existing model representing the beta distribution for a fact.</param>
/// <param name="successes">Number of successful reviews for the fact.</param>
/// <param name="total">Number of total reviews for the fact.</param>
/// <param name="timeNow">Elapsed time units since last review was recorded.</param>
/// <returns>Updated model for the fact.</returns>
/// <remarks>By default, this method will rebalance the returned model to represent
/// recall probability distribution after half life time units since last review.
/// See <c>UpdateRecall</c> overload to modify this behavior.</remarks>
public static EbisuModel UpdateRecall(
    this EbisuModel prior, int successes, int total, double timeNow) =>
    // Delegate to the full overload, rebalancing and anchoring the
    // posterior at the prior's own time horizon.
    prior.UpdateRecall(successes, total, timeNow, rebalance: true, timeBack: prior.Time);
[DataRow(new[] { 2.0, 2.0, 2.0 }, 0, 1, 2.0, new[] { 2.0, 3.0, 0.0 })] // fail recalls
public void ShouldReturnModelWithUpdatedParameters(
    double[] priorParams, int successes, int total, double duration,
    double[] expectedParams)
{
    // Parameters arrive as [alpha, beta, time]; the constructor takes
    // (time, alpha, beta).
    var prior = new EbisuModel(priorParams[2], priorParams[0], priorParams[1]);
    var expected = new EbisuModel(expectedParams[2], expectedParams[0], expectedParams[1]);

    var actual = prior.UpdateRecall(successes, total, duration);

    // BUG FIX: previously only Alpha was asserted even though the data row
    // carries an expected Beta as well; assert both shape parameters.
    // Time is intentionally not asserted: the rows use 0.0 as a placeholder
    // for the posterior time.
    Assert.AreEqual(expected.Alpha, actual.Alpha, Tolerance);
    Assert.AreEqual(expected.Beta, actual.Beta, Tolerance);
}
/// <summary>
/// Estimate the recall probability of an existing model given the time units
/// elapsed since last review.
/// </summary>
/// <param name="prior">Existing ebisu model.</param>
/// <param name="timeNow">Time elapsed since last review.</param>
/// <param name="exact">Return log probabilities if false (default).</param>
/// <returns>Probability of recall. 0 represents fail, and 1 for pass.</returns>
public static double PredictRecall(
    this EbisuModel prior, double timeNow, bool exact = false)
{
    // Elapsed time normalized by the model's own time horizon.
    double dt = timeNow / prior.Time;

    // Ebisu represents the events as a GB1 distribution. Expected recall
    // probability is `B(alpha + dt, beta)/B(alpha, beta)`, where `B()` is
    // the beta function. Working in the log domain, the quotient becomes a
    // difference: `log(a/b) = log(a) - log(b)`.
    // See https://en.wikipedia.org/wiki/Generalized_beta_distribution#Generalized_beta_of_first_kind_(GB1)
    // and the notes at https://fasiha.github.io/ebisu/ (Recall probability right now).
    double logRecall =
        BetaLn(prior.Alpha + dt, prior.Beta) - BetaLn(prior.Alpha, prior.Beta);

    return exact ? Math.Exp(logRecall) : logRecall;
}
/// <summary>
/// Rebalance a proposed posterior model so that its <c>Alpha</c> and
/// <c>Beta</c> parameters stay close to each other.
/// Since <c>Alpha = Beta</c> corresponds to the half life, anchoring the
/// model there keeps the shape parameters numerically stable.
/// </summary>
/// <param name="prior">Existing memory model.</param>
/// <param name="successes">Count of successful reviews.</param>
/// <param name="total">Count of total number of reviews.</param>
/// <param name="timeNow">Duration since last review.</param>
/// <param name="proposed">Proposed memory model.</param>
/// <returns>Updated model with duration nearer to the half life.</returns>
private static EbisuModel Rebalance(
    this EbisuModel prior, int successes, int total, double timeNow,
    EbisuModel proposed)
{
    double a = proposed.Alpha;
    double b = proposed.Beta;

    // Shape parameters within a factor of two of each other: close enough,
    // keep the proposal as-is.
    if (a <= 2 * b && b <= 2 * a)
    {
        return proposed;
    }

    // Otherwise, estimate the elapsed time at which recall probability drops
    // to one half (the half life) and recompute the update anchored there.
    double roughHalflife = ModelToPercentileDecay(proposed, 0.5, true, 1e-4);
    return prior.UpdateRecall(successes, total, timeNow, false, roughHalflife);
}
public void BetaLnFunctionForZeroSuccesses()
{
    // Regression test: an update with zero successes exercises the
    // alternating log-sum-exp series, which is numerically delicate, so the
    // posterior parameters are pinned with a loose tolerance.
    var delta = 0.05;
    var prior = new EbisuModel(1.0, 34.4, 3.4);

    var updatedModel = prior.UpdateRecall(0, 5, 0.1);

    Assert.AreEqual(3.0652051705190964, updatedModel.Time, delta);
    Assert.AreEqual(8.706432410647471, updatedModel.Beta, delta);
    Assert.AreEqual(8.760308130181903, updatedModel.Alpha, delta);

    // NOTE: a large `#if NONE` block that compared SpecialFunctions.BetaLn
    // against Beta.BetaLn term-by-term used to live here. It was dead code
    // (never compiled) and has been removed.
}
/// <summary>
/// Compute the time duration for a <see cref="EbisuModel"/> to decay to
/// a given percentile.
/// </summary>
/// <param name="model">Given model for the fact.</param>
/// <param name="percentile">Target percentile for the decay; must lie strictly between 0 and 1.</param>
/// <param name="coarse">If true, use an approximation for the duration returned.</param>
/// <param name="tolerance">Allowed tolerance (accuracy) for the root search.</param>
/// <returns>Duration in time units (of provided model) for the decay to given percentile.</returns>
/// <exception cref="ArgumentException">Thrown when <paramref name="percentile"/> lies outside (0, 1).</exception>
/// <exception cref="EbisuConstraintViolationException">Thrown when a root bracket cannot be established.</exception>
private static double ModelToPercentileDecay(
    this EbisuModel model, double percentile, bool coarse, double tolerance)
{
    // BUG FIX: the old check (`percentile < 0 || percentile > 1`) admitted the
    // endpoints although the message promises an exclusive range.
    // percentile == 0 makes Math.Log produce -Infinity, so the bracketing
    // loops below could never succeed; reject the endpoints too.
    if (percentile <= 0 || percentile >= 1)
    {
        throw new ArgumentException(
            "Percentiles must be between (0, 1) exclusive", nameof(percentile));
    }

    double alpha = model.Alpha;
    double beta = model.Beta;
    double t0 = model.Time;

    double logBab = BetaLn(alpha, beta);
    double logPercentile = Math.Log(percentile);

    // f is monotonically decreasing in lndelta; its root is the (log of the)
    // normalized elapsed time at which recall hits `percentile`.
    Func<double, double> f = lndelta =>
    {
        return (BetaLn(alpha + Math.Exp(lndelta), beta) - logBab) - logPercentile;
    };

    double bracketWidth = coarse ? 1.0 : 6.0;
    double blow = -bracketWidth / 2.0;
    double bhigh = bracketWidth / 2.0;
    double flow = f(blow);
    double fhigh = f(bhigh);
    while (flow > 0 && fhigh > 0)
    {
        // Move the bracket up.
        blow = bhigh;
        flow = fhigh;
        bhigh += bracketWidth;
        fhigh = f(bhigh);
    }

    while (flow < 0 && fhigh < 0)
    {
        // Move the bracket down.
        bhigh = blow;
        fhigh = flow;
        blow -= bracketWidth;
        flow = f(blow);
    }

    if (!(flow > 0 && fhigh < 0))
    {
        throw new EbisuConstraintViolationException($"Failed to bracket: flow={flow}, fhigh={fhigh}");
    }

    if (coarse)
    {
        return (Math.Exp(blow) + Math.Exp(bhigh)) / 2 * t0;
    }

    // Similar to the `root_scalar` api with bracketing
    // See https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.root_scalar.html#scipy.optimize.root_scalar
    // BUG FIX: `tolerance` was accepted but never used; forward it to the
    // Brent solver as the requested accuracy.
    var sol = Brent.FindRoot(f, blow, bhigh, tolerance);
    return Math.Exp(sol) * t0;
}
/// <summary>
/// Update the parameters of a prior model with new observations and return
/// an updated model with posterior distribution of recall probability at
/// <paramref name="timeBack"/> time units after review.
///
/// <paramref name="prior"/> is the given belief about remembrance of the fact. We
/// attempt to calculate the posterior given additional data i.e. <paramref name="successes"/>
/// indicating the successful recalls in <paramref name="total"/> review attempts in
/// <paramref name="timeNow"/> duration since last review.
/// </summary>
/// <param name="prior">Existing model representing the beta distribution for a fact.</param>
/// <param name="successes">Number of successful reviews for the fact.</param>
/// <param name="total">Number of total reviews for the fact.</param>
/// <param name="timeNow">Elapsed time units since last review was recorded.</param>
/// <param name="rebalance">If true, the updated model is computed with <paramref name="timeBack"/> set to half life.</param>
/// <param name="timeBack">Time stamp for calculating recall in the updated model.</param>
/// <returns>Updated model for the fact.</returns>
/// <exception cref="ArgumentException">Thrown for out-of-range <paramref name="successes"/> or <paramref name="total"/>.</exception>
/// <exception cref="EbisuConstraintViolationException">Thrown when the computed posterior moments are invalid.</exception>
/// <remarks>
/// Each review of the fact can be modelled as a binomial experiment with
/// `k` successes in `n` trials. These are represented as the `successes` and
/// `total` variables here. If `total` is 1, this is a bernoulli experiment.
/// Second, we're assuming the experiments (reviews) to be independent. They're not
/// independent if the app shows a hint to the user, obviously the next review is biased.
///
/// Given the `prior` recall probability and the results of new experiments, what is the
/// `posterior` recall probability? Note we're being bayesian, and asking hard questions.
/// </remarks>
public static EbisuModel UpdateRecall(
    this EbisuModel prior, int successes, int total, double timeNow,
    bool rebalance, double timeBack)
{
    if (successes < 0 || successes > total)
    {
        // BUG FIX: the previous message claimed successes must be "less than
        // Total", but successes == total is valid; the real constraint is
        // 0 <= successes <= total.
        throw new ArgumentException(
            "Successes must not be negative and must not exceed Total.",
            nameof(successes));
    }

    if (total < 1)
    {
        throw new ArgumentException(
            "Total experiments must be one or more.", nameof(total));
    }

    // See https://fasiha.github.io/ebisu/ (Updating the posterior with quiz results)
    // section for detailed derivation.
    double alpha = prior.Alpha;
    double beta = prior.Beta;
    double t = prior.Time;
    double dt = timeNow / t;
    double et = timeBack / timeNow;
    var failures = total - successes;

    // Most of the calculations are summations over the range [0, failures].
    var binomlns = Enumerable.Range(0, failures + 1)
        .Select(i => BinomialLn(failures, i)).ToArray();

    // logs[0]: log of the normalizing denominator,
    // logs[1]: log of the first-moment numerator,
    // logs[2]: log of the second-moment numerator.
    // The alternating signs come from the binomial expansion of (1 - p)^failures.
    var logs = Enumerable.Range(0, 3)
        .Select(m =>
        {
            var a = Enumerable.Range(0, failures + 1)
                .Select(i => binomlns[i] + BetaLn(
                    beta, alpha + (dt * (successes + i)) + (m * dt * et)))
                .ToList();
            var b = Enumerable.Range(0, failures + 1)
                .Select(i => Math.Pow(-1.0, i))
                .ToList();
            return LogSumExp(a, b).Value;
        })
        .ToArray();

    double logDenominator = logs[0];
    double logMeanNum = logs[1];
    double logM2Num = logs[2];

    double mean = Math.Exp(logMeanNum - logDenominator);
    double m2 = Math.Exp(logM2Num - logDenominator);
    double meanSq = Math.Exp(2 * (logMeanNum - logDenominator));
    double sig2 = m2 - meanSq;

    if (mean <= 0)
    {
        throw new EbisuConstraintViolationException(
            $"Invalid mean found: a={alpha}, b={beta}, t={t}, k={successes}, n={total}, tnow={timeNow}, mean={mean}, m2={m2}, sig2={sig2}");
    }

    if (m2 <= 0)
    {
        throw new EbisuConstraintViolationException(
            $"Invalid second moment found: a={alpha}, b={beta}, t={t}, k={successes}, n={total}, tnow={timeNow}, mean={mean}, m2={m2}, sig2={sig2}");
    }

    if (sig2 <= 0)
    {
        // BUG FIX: this interpolated string previously contained a raw line
        // break, which is invalid inside a non-verbatim string literal; the
        // message is now a single line, matching the two messages above.
        throw new EbisuConstraintViolationException(
            $"Invalid variance found: a={alpha}, b={beta}, t={t}, k={successes}, n={total}, tnow={timeNow}, mean={mean}, m2={m2}, sig2={sig2}");
    }

    // Convert the posterior mean/variance back into Beta shape parameters.
    // See https://en.wikipedia.org/wiki/Beta_distribution#Mean_and_variance
    var (newAlpha, newBeta) = MeanVarToBeta(mean, sig2);
    var proposed = new EbisuModel(timeBack, newAlpha, newBeta);
    return rebalance ? prior.Rebalance(successes, total, timeNow, proposed) : proposed;
}