public void UpdateRecall()
{
    // Symmetric prior: (time, alpha, beta) = (2, 2, 2).
    var prior = new EbisuModel(2, 2, 2);

    // One review observed 2.0 time units later, once passed and once failed.
    var afterSuccess = prior.UpdateRecall(1, 1, 2.0);
    var afterFailure = prior.UpdateRecall(0, 1, 2.0);

    // Shared tolerance/label wrapper for the checks below.
    void Check(double expectedValue, double actualValue, string label) =>
        Assert.AreEqual(expectedValue, actualValue, Epsilon, label);

    // Expected: a success raises alpha by one, a failure raises beta by one.
    Check(3.0, afterSuccess.Alpha, "success/alpha");
    Check(2.0, afterSuccess.Beta, "success/beta");
    Check(2.0, afterFailure.Alpha, "failure/alpha");
    Check(3.0, afterFailure.Beta, "failure/beta");
}
public void MatchReferenceImplementation()
{
    // Replays reference test vectors against this implementation.
    //
    // Format of the test.json (it can be obtained from the ebisu python
    // repository): an array of test cases, each of which is one of
    //   ["update",  [a, b, t0], [k, n, t], {"post": [a, b, t]}]
    //   ["predict", [a, b, t0], [t],       {<key>: mean}]
    // For "predict" the expected mean is a single number. Only the first
    // property of the trailing object is read, so its key is not checked.
    //
    // JsonDocument.Parse(Stream) reads the stream but does not dispose it,
    // so the stream gets its own using declaration.
    using var stream = File.OpenRead("Ebisu/test.json");
    using var testData = JsonDocument.Parse(stream);

    foreach (var test in testData.RootElement.EnumerateArray())
    {
        var data = test.EnumerateArray().ToArray();
        var operation = data[0].GetString();
        var model = ReadModel(data[1]);

        switch (operation)
        {
            case "update":
            {
                var paramData = data[2].EnumerateArray().ToArray();
                var successes = paramData[0].GetInt32();
                var total = paramData[1].GetInt32();
                var time = paramData[2].GetDouble();
                var expected = ReadModel(data[3].EnumerateObject().First().Value);

                var updatedModel = model.UpdateRecall(successes, total, time);

                Assert.AreEqual(expected.Time, updatedModel.Time, Tolerance, $"Test: {test}");
                Assert.AreEqual(expected.Alpha, updatedModel.Alpha, Tolerance, $"Test: {test}");
                Assert.AreEqual(expected.Beta, updatedModel.Beta, Tolerance, $"Test: {test}");
                break;
            }

            case "predict":
            {
                var time = data[2].EnumerateArray().First().GetDouble();
                var expectedRecall = data[3].EnumerateObject().First().Value.GetDouble();

                var predictRecall = model.PredictRecall(time, true);

                // Include the failing test case, as the update branch does.
                Assert.AreEqual(expectedRecall, predictRecall, Tolerance, $"Test: {test}");
                break;
            }

            default:
                Assert.Fail("Reference data has invalid operation.");
                break;
        }
    }

    // Builds a model from a JSON array laid out as [a, b, t]. EbisuModel's
    // constructor takes (time, alpha, beta), hence the reordering.
    static EbisuModel ReadModel(JsonElement element)
    {
        var values = element.EnumerateArray().ToArray();
        return new EbisuModel(
            values[2].GetDouble(),
            values[0].GetDouble(),
            values[1].GetDouble());
    }
}
/// <summary>
/// Updates a prior model with new review outcomes and returns the posterior
/// model.
///
/// <paramref name="prior"/> is the current belief about how well the fact is
/// remembered. The new evidence is <paramref name="successes"/> successful
/// recalls out of <paramref name="total"/> review attempts, observed
/// <paramref name="timeNow"/> time units after the last review.
/// </summary>
/// <param name="prior">Existing model representing the beta distribution for a fact.</param>
/// <param name="successes">Number of successful reviews for the fact.</param>
/// <param name="total">Number of total reviews for the fact.</param>
/// <param name="timeNow">Elapsed time units since last review was recorded.</param>
/// <returns>Updated model for the fact.</returns>
/// <remarks>This calls the five-argument <c>UpdateRecall</c> overload with
/// rebalancing enabled and <c>prior.Time</c> as the time at which the posterior
/// is expressed. When the resulting shape parameters are lopsided, the posterior
/// may be rebalanced towards its half life. See that overload to modify this
/// behavior.</remarks>
public static EbisuModel UpdateRecall(
    this EbisuModel prior,
    int successes,
    int total,
    double timeNow) =>
    prior.UpdateRecall(successes, total, timeNow, true, prior.Time);
[DataRow(new[] { 2.0, 2.0, 2.0 }, 0, 1, 2.0, new[] { 2.0, 3.0, 0.0 })] // fail recalls
public void ShouldReturnModelWithUpdatedParameters(
    double[] priorParams,
    int successes,
    int total,
    double duration,
    double[] expectedParams)
{
    // Parameter arrays are laid out as [alpha, beta, time]; EbisuModel's
    // constructor takes (time, alpha, beta).
    var prior = new EbisuModel(priorParams[2], priorParams[0], priorParams[1]);
    var expected = new EbisuModel(expectedParams[2], expectedParams[0], expectedParams[1]);

    var actual = prior.UpdateRecall(successes, total, duration);

    // Check both shape parameters, not just alpha.
    // NOTE(review): the expected time entry (0.0) looks like a placeholder,
    // so Time is intentionally not asserted here; confirm and fill it in.
    Assert.AreEqual(expected.Alpha, actual.Alpha, Tolerance, "alpha");
    Assert.AreEqual(expected.Beta, actual.Beta, Tolerance, "beta");
}
/// <summary>
/// Rebalances a proposed posterior model so that its <c>Alpha</c> and
/// <c>Beta</c> parameters stay close to each other.
/// Since <c>Alpha = Beta</c> corresponds to the half life, this operation
/// re-expresses a lopsided posterior near its half life for numerical
/// stability.
/// </summary>
/// <param name="prior">Existing memory model.</param>
/// <param name="successes">Count of successful reviews.</param>
/// <param name="total">Count of total number of reviews.</param>
/// <param name="timeNow">Duration since last review.</param>
/// <param name="proposed">Proposed memory model.</param>
/// <returns><paramref name="proposed"/> unchanged when neither shape parameter
/// exceeds twice the other. Otherwise, the posterior recomputed from
/// <paramref name="prior"/>, without further rebalancing, at a rough half-life
/// estimate of <paramref name="proposed"/>.</returns>
private static EbisuModel Rebalance(
    this EbisuModel prior,
    int successes,
    int total,
    double timeNow,
    EbisuModel proposed)
{
    (double alpha, double beta) = (proposed.Alpha, proposed.Beta);

    bool isLopsided = alpha > 2 * beta || beta > 2 * alpha;
    if (!isLopsided)
    {
        return proposed;
    }

    // Rough estimate of the elapsed time at which the proposed model's recall
    // probability drops to one half, i.e. its half life.
    double roughHalfLife = ModelToPercentileDecay(proposed, 0.5, true, 1e-4);

    // Redo the update at that time, with rebalancing disabled so this cannot recurse.
    return prior.UpdateRecall(successes, total, timeNow, false, roughHalfLife);
}
public void BetaLnFunctionForZeroSuccesses()
{
    // Loose tolerance: the expected values are only checked approximately.
    const double delta = 0.05;

    // Prior with (time, alpha, beta) = (1.0, 34.4, 3.4).
    var prior = new EbisuModel(1.0, 34.4, 3.4);

    // Zero successes out of five reviews, 0.1 time units after the last review.
    var updatedModel = prior.UpdateRecall(0, 5, 0.1);

    // NOTE(review): expected values presumably come from the reference ebisu
    // implementation; confirm their provenance.
    Assert.AreEqual(3.0652051705190964, updatedModel.Time, delta, "time");
    Assert.AreEqual(8.706432410647471, updatedModel.Beta, delta, "beta");
    Assert.AreEqual(8.760308130181903, updatedModel.Alpha, delta, "alpha");
}