/// <summary>
/// Builds an automaton in which three 'a'-transitions of weight 0.9 form a loop whose last hop
/// targets the start state. Each state reached by a 0.9 hop also gets a 0.1-weight 'a'-transition
/// to a final state. Checks that the automaton is stochastic and that its log-normalizer is zero
/// when computed directly, via GetValue, and via GetValue with transducers.
/// </summary>
public void ComputeNormalizerWithNonTrivialLoop1()
{
    StringAutomaton automaton = StringAutomaton.Zero();

    // First hop of the loop, plus its exit to a final state.
    var firstHop = automaton.Start.AddTransition('a', Weight.FromValue(0.9));
    firstHop.AddTransition('a', Weight.FromValue(0.1)).EndWeight = Weight.One;

    // Second hop, plus its exit.
    var secondHop = firstHop.AddTransition('a', Weight.FromValue(0.9));
    secondHop.AddTransition('a', Weight.FromValue(0.1)).EndWeight = Weight.One;

    // Third hop closes the loop by targeting the start state; the state it returns also gets an exit.
    var loopTarget = secondHop.AddTransition('a', Weight.FromValue(0.9), automaton.Start);
    loopTarget.AddTransition('a', Weight.FromValue(0.1)).EndWeight = Weight.One;

    AssertStochastic(automaton);
    Assert.Equal(0.0, automaton.GetLogNormalizer(), 1e-6);
    Assert.Equal(0.0, GetLogNormalizerByGetValue(automaton), 1e-6);
    Assert.Equal(0.0, GetLogNormalizerByGetValueWithTransducers(automaton), 1e-6);
}
/// <summary>
/// Builds an automaton with two transitions back to the start state and a 0.75 self-loop on its
/// end state. Checks that, on a clone, <c>NormalizeValues</c> throws
/// <see cref="InvalidOperationException"/> and that <c>TryNormalizeValues</c> then returns false.
/// </summary>
public void NonNormalizableLoop2()
{
    StringAutomaton automaton = StringAutomaton.Zero();

    var endState = automaton.Start.AddTransition('a', Weight.FromValue(2.0));
    endState.SetEndWeight(Weight.FromValue(5.0));

    // Loops: two ways back to the start state, plus a self-loop on the end state.
    endState.AddTransition('b', Weight.FromValue(0.1), automaton.Start);
    endState.AddTransition('c', Weight.FromValue(0.05), automaton.Start);
    endState.AddSelfTransition('!', Weight.FromValue(0.75));

    StringAutomaton clone = automaton.Clone();
    Assert.Throws<InvalidOperationException>(() => clone.NormalizeValues());
    Assert.False(clone.TryNormalizeValues());

    ////Assert.Equal(f, copyOfF); // TODO: fix equality first
}
/// <summary>
/// Asserts that, in a dead-state-free clone of <paramref name="automaton"/>, each state's end weight
/// plus the weights of all its outgoing transitions sums to one (log-weight 0 within 1e-6).
/// </summary>
/// <param name="automaton">
/// The automaton to check. Dead states are removed from a clone, not from this argument.
/// </param>
private static void AssertStochastic(StringAutomaton automaton)
{
    StringAutomaton pruned = automaton.Clone();
    pruned.RemoveDeadStates();

    for (int stateIndex = 0; stateIndex < pruned.States.Count; ++stateIndex)
    {
        var state = pruned.States[stateIndex];

        // Start from the end weight, then accumulate every outgoing transition weight.
        Weight total = state.EndWeight;
        for (int transitionIndex = 0; transitionIndex < state.TransitionCount; ++transitionIndex)
        {
            total = Weight.Sum(total, state.GetTransition(transitionIndex).Weight);
        }

        Assert.Equal(0.0, total.LogValue, 1e-6);
    }
}
/// <summary>
/// Checks that NormalCdfIntegral yields a non-negative (or, in the third case, positive) mantissa
/// on three specific, presumably regression, argument sets. It then calls
/// MMath.NormalCdfIntegral as a smoke test over every x and y from OperatorTests.Doubles(), with
/// r restricted to the values in [-1, 1]. The smoke-test results are computed but not asserted.
/// </summary>
public void NormalCdfIntegralTest()
{
    // Symmetric, enormous bounds with r = -1.
    double hugeBound = 190187183095334850882507750944849586799124505055478568794871547478488387682304.0;
    double hugeCase4th = 0.817880416082724044547388352452631856079457366800004151664125953519049673808376291470533145141236089924006896061006277409614237094627499958581030715374379576478204968748786874650796450332240045653919846557755590765736997127532958984375e-78;
    Assert.True(0 <= NormalCdfIntegral(hugeBound, -hugeBound, -1, hugeCase4th).Mantissa);

    // Symmetric, large bounds with r = -1.
    double largeBound = 213393529.2046706974506378173828125;
    double largeCase4th = 0.72893668811495072384656764856902984306419313043079455383121967315673828125e-9;
    Assert.True(0 <= NormalCdfIntegral(largeBound, -largeBound, -1, largeCase4th).Mantissa);

    // Nearly symmetric small bounds with r just above -1; the result must be strictly positive.
    double smallX = -0.421468532207607216033551367218024097383022308349609375;
    double smallY = 0.42146843802130329326161017888807691633701324462890625;
    double nearMinusOne = -0.99999999999999989;
    double smallCase4th = 0.62292398855983019004972723654291189010479001808562316000461578369140625e-8;
    Assert.True(0 < NormalCdfIntegral(smallX, smallY, nearMinusOne, smallCase4th).Mantissa);

    // Smoke test over the grid of test doubles (x in parallel).
    Parallel.ForEach(OperatorTests.Doubles(), xValue =>
    {
        foreach (var yValue in OperatorTests.Doubles())
        {
            var correlations = OperatorTests.Doubles().Where(candidate => -1 <= candidate && candidate <= 1);
            foreach (var rValue in correlations)
            {
                MMath.NormalCdfIntegral(xValue, yValue, rValue);
            }
        }
    });
}
/// <summary>
/// Checks product, ratio and power of sparse Bernoulli lists. Each result's common value and
/// special entries are compared against the same operation applied elementwise to the
/// corresponding Bernoulli values, and each result's sparse-entry count is checked.
/// </summary>
public void SparseBernoulliListArithmetic()
{
    const double Tolerance = 1e-10;
    const int ListSize = 100;

    var common01 = new Bernoulli(0.1);
    var common02 = new Bernoulli(0.2);
    var special07 = new Bernoulli(0.7);
    var special08 = new Bernoulli(0.8);
    var special09 = new Bernoulli(0.9);

    // lhs: 0.1 everywhere, except [20] = 0.7 and [55] = 0.8.
    var lhs = SparseBernoulliList.Constant(ListSize, common01);
    // rhs: 0.2 everywhere, except [25] = 0.8 and [55] = 0.9.
    var rhs = SparseBernoulliList.Constant(ListSize, common02);
    lhs[20] = special07;
    lhs[55] = special08;
    rhs[25] = special08;
    rhs[55] = special09;

    // Product: indices 20, 25 and 55 differ from the common value.
    var product = lhs * rhs;
    Assert.Equal(3, product.SparseValues.Count);
    Assert.Equal(common01 * common02, product.CommonValue);
    Assert.Equal(special07 * common02, product[20]);
    Assert.Equal(common01 * special08, product[25]);
    Assert.Equal(special08 * special09, product[55]);

    // Ratio.
    // NOTE(review): only 2 sparse entries are expected. Presumably entry 55 collapses into the
    // common value because 0.8/0.9 and 0.1/0.2 share the odds ratio 4/9 -- confirm.
    var ratio = lhs / rhs;
    Assert.Equal(2, ratio.SparseValues.Count);
    Assert.Equal((common01 / common02).GetProbTrue(), ratio.CommonValue.GetProbTrue(), Tolerance);
    Assert.Equal((special07 / common02).GetProbTrue(), ratio[20].GetProbTrue(), Tolerance);
    Assert.Equal((common01 / special08).GetProbTrue(), ratio[25].GetProbTrue(), Tolerance);
    Assert.Equal((special08 / special09).GetProbTrue(), ratio[55].GetProbTrue(), Tolerance);

    // Power: only lhs's own special entries (20 and 55) remain sparse.
    const double Exponent = 1.2;
    var power = lhs ^ Exponent;
    Assert.Equal(2, power.SparseValues.Count);
    Assert.Equal(common01 ^ Exponent, power.CommonValue);
    Assert.Equal(special07 ^ Exponent, power[20]);
    Assert.Equal(special08 ^ Exponent, power[55]);
}
/// <summary>
/// Performance check: builds a chain of three literal segments, each ending in a final state with
/// its own self-loop. It then profiles 100000 <c>GetLogNormalizer</c> calls under a 20000 timeout.
/// </summary>
public void AutomatonNormalizationPerformance2()
{
    Assert.Timeout(
        () =>
        {
            StringAutomaton automaton = StringAutomaton.Zero();

            var afterAbc = automaton.Start.AddTransitionsForSequence("abc");
            afterAbc.EndWeight = Weight.One;
            afterAbc.AddSelfTransition('d', Weight.FromValue(0.1));

            var afterEfg = afterAbc.AddTransitionsForSequence("efg");
            afterEfg.EndWeight = Weight.One;
            afterEfg.AddSelfTransition('h', Weight.FromValue(0.2));

            var afterTail = afterEfg.AddTransitionsForSequence("grlkhgn;lk3rng");
            afterTail.EndWeight = Weight.One;
            afterTail.AddSelfTransition('h', Weight.FromValue(0.3));

            ProfileAction(() => automaton.GetLogNormalizer(), 100000);
        },
        20000);
}
/// <summary>
/// Prints <paramref name="dist"/> to the console. Asserts that the probabilities it assigns to
/// <paramref name="vals"/> sum to one (within 1e-8).
/// </summary>
/// <param name="name">Label printed before the distribution.</param>
/// <param name="dist">
/// The distribution under test. It is cast to <see cref="StringDistribution"/>, so any other
/// implementation causes an <see cref="InvalidCastException"/>.
/// </param>
/// <param name="vals">Strings expected to carry all of the distribution's probability mass.</param>
private static void Test(string name, IDistribution<string> dist, params string[] vals)
{
    const double Tolerance = 1E-8;

    var sa = (StringDistribution)dist;

    // WriteLine, not Write: otherwise the output of consecutive calls runs together on one line.
    Console.WriteLine(name + "=" + sa);

    double sum = 0;
    foreach (var s in vals)
    {
        sum += Math.Exp(dist.GetLogProb(s));
    }

    var ok = Math.Abs(sum - 1.0) < Tolerance;
    var valstr = string.Join("|", vals.Select(s => s + "$"));
    Assert.True(ok, $"Result was {sa} should be ({valstr})");
}
/// <summary>
/// Checks that a group attached to a transition of the left operand is still present after
/// <c>SetToProduct</c> with a right operand that uses no groups.
/// </summary>
public void ProductWithGroups()
{
    StringDistribution lhsWithoutGroup = StringDistribution.String("ab");

    // Put the first transition out of the start state into group 1.
    var weightFunction = lhsWithoutGroup.GetWorkspaceOrPoint();
    var groupedTransition = weightFunction.Start.GetTransitions()[0];
    groupedTransition.Group = 1;
    weightFunction.Start.SetTransition(0, groupedTransition);

    StringDistribution lhs = StringDistribution.FromWeightFunction(weightFunction);
    StringDistribution rhs = StringDistribution.OneOf("ab", "ac");

    // Preconditions: only the left operand carries a group.
    Assert.True(lhs.GetWorkspaceOrPoint().HasGroup(1));
    Assert.False(rhs.GetWorkspaceOrPoint().UsesGroups());

    var product = StringDistribution.Zero();
    product.SetToProduct(lhs, rhs);
    Assert.True(product.GetWorkspaceOrPoint().HasGroup(1));
}
/// <summary>
/// Chains <c>CopyCount</c> copy transducers and projects a run of <c>LetterCount</c> 'a' characters
/// through them. Checks that the log-value of that same string equals
/// log(Partitions(LetterCount, CopyCount)).
/// </summary>
public void CopySequence2()
{
    const int CopyCount = 10;
    const int LetterCount = 20;

    // One copy transducer, plus CopyCount - 1 more appended in place.
    StringTransducer copy = StringTransducer.Copy();
    for (int appended = 1; appended < CopyCount; ++appended)
    {
        copy.AppendInPlace(StringTransducer.Copy());
    }

    var sequence = new string('a', LetterCount);
    StringAutomaton result = copy.ProjectSource(sequence);

    // Presumably Partitions counts the ways the letters can be split among the chained copies.
    var expectedLogValue = Math.Log(StringInferenceTestUtilities.Partitions(LetterCount, CopyCount));
    Assert.Equal(expectedLogValue, result.GetLogValue(sequence), 1e-8);
}
/// <summary>
/// Checks that mixtures whose components are all uniform are detected as uniform, including
/// nested mixtures and mixtures with a zero-weight component. Also checks that each mixture
/// assigns unnormalized probability 1 to sample strings.
/// </summary>
public void UniformDetection()
{
    // Zero-weight component mixed with Any().
    StringDistribution dist1 = StringDistribution.OneOf(0.0, StringDistribution.Zero(), 1.0, StringDistribution.Any());
    Assert.True(dist1.IsUniform());
    StringInferenceTestUtilities.TestProbability(dist1, 1.0, string.Empty, "a", "bc");

    // Two Any() components with different weights.
    StringDistribution dist2 = StringDistribution.OneOf(0.3, StringDistribution.Any(), 0.7, StringDistribution.Any());
    Assert.True(dist2.IsUniform());
    // Test dist2 itself; re-testing dist1 here would leave dist2's probabilities unchecked.
    StringInferenceTestUtilities.TestProbability(dist2, 1.0, string.Empty, "a", "bc");

    // Nested mixture of Any() components.
    StringDistribution dist3 = StringDistribution.OneOf(
        1.0,
        StringDistribution.Any(),
        2.0,
        StringDistribution.OneOf(0.1, StringDistribution.Any(), 0.2, StringDistribution.Any()));
    Assert.True(dist3.IsUniform());
    StringInferenceTestUtilities.TestProbability(dist3, 1.0, string.Empty, "a", "bc");
}
/// <summary>
/// Infers a free-form template formatted with two fixed facts. Checks that a sentence containing
/// both facts has non-zero probability under the inferred distribution of the formatted text.
/// </summary>
public void SemanticWebTest1()
{
    const string Person = "Tony Blair";
    const string BirthDate = "6 May 1953";

    var template = Variable.Random(StringDistribution.Any());
    var text = Variable.StringFormat(template, Person, BirthDate);

    var engine = new InferenceEngine();
    engine.Compiler.RecommendedQuality = QualityBand.Experimental;
    engine.NumberOfIterations = 1;

    var textDist = engine.Infer<StringDistribution>(text);
    Console.WriteLine("textDist={0}", textDist);

    double sentenceLogProb = textDist.GetLogProb("6 May 1953 is the date of birth of Tony Blair.");
    Assert.False(double.IsNegativeInfinity(sentenceLogProb));
}
/// <summary>
/// Performance check for nested StringFormat models. Inference on a random observed string is
/// run once to exclude compilation time, then profiled 100 times under a 10000 timeout.
/// </summary>
public void StringFormatPerformanceTest2()
{
    Assert.Timeout(
        () =>
        {
            Rand.Restart(777);

            Variable<string> template = Variable.Constant("{0} {1}").Named("template");
            Variable<string> arg1 = Variable.Random(StringDistribution.Any(minLength: 1, maxLength: 15)).Named("arg1");
            Variable<string> arg2 = Variable.Random(StringDistribution.Any(minLength: 1)).Named("arg2");
            Variable<string> text = Variable.StringFormat(template, arg1, arg2).Named("text");
            Variable<string> fullTextFormat = Variable.Random(StringDistribution.Any()).Named("fullTextFormat");
            Variable<string> fullText = Variable.StringFormat(fullTextFormat, text).Named("fullText");

            var engine = new InferenceEngine();
            engine.Compiler.RecommendedQuality = QualityBand.Experimental;
            engine.ShowProgress = false;

            Action action = () =>
            {
                // Observed string: five 5-character words drawn via Rand.Int('a', 'z' + 1),
                // each followed by a space.
                string observedPattern = string.Empty;
                for (int word = 0; word < 5; ++word)
                {
                    var letters = new char[5];
                    for (int pos = 0; pos < letters.Length; ++pos)
                    {
                        letters[pos] = (char)Rand.Int('a', 'z' + 1);
                    }

                    observedPattern += new string(letters) + ' ';
                }

                fullText.ObservedValue = observedPattern;
                engine.Infer<StringDistribution>(arg1);
                engine.Infer<StringDistribution>(arg2);
                engine.Infer<StringDistribution>(text);
                engine.Infer<StringDistribution>(fullTextFormat);
            };

            action(); // To exclude the compilation time from the profile
            ProfileAction(action, 100);
        },
        10000);
}
/// <summary>
/// Builds an automaton where one of two 'a'-branches carries a self-loop of weight 2.0. That loop
/// presumably makes the total mass infinite. Checks that <c>NormalizeValues</c> throws
/// <see cref="InvalidOperationException"/> on the automaton and that <c>TryNormalizeValues</c>
/// returns false on an independent clone.
/// </summary>
public void NonNormalizableLoop4()
{
    StringAutomaton automaton = StringAutomaton.Zero();
    automaton.Start.AddSelfTransition('a', Weight.FromValue(0.1));

    // Branch with a self-loop of weight 2.0.
    var loopingBranch = automaton.Start.AddTransition('a', Weight.FromValue(2.0));
    loopingBranch.AddSelfTransition('a', Weight.FromValue(2.0));
    loopingBranch.EndWeight = Weight.One;

    // Branch that ends immediately.
    var plainBranch = automaton.Start.AddTransition('a', Weight.FromValue(2.0));
    plainBranch.EndWeight = Weight.One;

    StringAutomaton clone = automaton.Clone();
    Assert.Throws<InvalidOperationException>(() => automaton.NormalizeValues());
    Assert.False(clone.TryNormalizeValues());

    ////Assert.Equal(f, copyOfF); // TODO: fix equality first
}
/// <summary>
/// Samples 10000 strings from OneOf("a", "ab") followed by OneOf("c", "bc"). The possible results
/// are "ac", "abc" (reachable two ways) and "abbc". Checks the empirical frequencies are 0.25,
/// 0.5 and 0.25 within 1e-2.
/// </summary>
public void SampleFiniteSupport()
{
    Rand.Restart(69);
    StringDistribution dist = StringDistribution.OneOf("a", "ab").Append(StringDistribution.OneOf("c", "bc"));

    const int SampleCount = 10000;
    int countAc = 0;
    int countAbc = 0;
    int countOther = 0;
    for (int draw = 0; draw < SampleCount; ++draw)
    {
        switch (dist.Sample())
        {
            case "ac":
                ++countAc;
                break;
            case "abc":
                ++countAbc;
                break;
            default:
                ++countOther;
                break;
        }
    }

    Assert.Equal(0.25, countAc / (double)SampleCount, 1e-2);
    Assert.Equal(0.5, countAbc / (double)SampleCount, 1e-2);
    Assert.Equal(0.25, countOther / (double)SampleCount, 1e-2);
}
/// <summary>
/// Performance check: takes the product of a small looping automaton with itself three times.
/// It then profiles 100 <c>GetLogNormalizer</c> calls under a 120000 timeout.
/// </summary>
public void AutomatonNormalizationPerformance3()
{
    Assert.Timeout(
        () =>
        {
            StringAutomaton automaton = StringAutomaton.Zero();

            // Accepting start state with an 'a' self-loop.
            automaton.Start.AddSelfTransition('a', Weight.FromValue(0.5));
            automaton.Start.EndWeight = Weight.One;

            // "aa" leads to a second accepting state that also loops on 'a'.
            var afterAa = automaton.Start.AddTransitionsForSequence("aa");
            afterAa.AddSelfTransition('a', Weight.FromValue(0.5));
            afterAa.EndWeight = Weight.One;

            const int SelfProductCount = 3;
            for (int round = 0; round < SelfProductCount; ++round)
            {
                automaton = automaton.Product(automaton);
            }

            ProfileAction(() => automaton.GetLogNormalizer(), 100);
        },
        120000);
}
/// <summary>
/// Asserts that every distribution-valued field of this instance has a MaxDiff of zero against the
/// same field of <paramref name="that"/>, and that the quantile fields compare equal. The two
/// string distributions are compared only when this instance's <c>stringDistribution1</c> is
/// non-null.
/// </summary>
/// <param name="that">The instance to compare against.</param>
public void AssertEqualTo(MyClass that)
{
    // Scalar and vector distributions.
    Assert.Equal(0, this.bernoulli.MaxDiff(that.bernoulli));
    Assert.Equal(0, this.beta.MaxDiff(that.beta));
    Assert.Equal(0, this.binomial.MaxDiff(that.binomial));
    Assert.Equal(0, this.conjugateDirichlet.MaxDiff(that.conjugateDirichlet));
    Assert.Equal(0, this.dirichlet.MaxDiff(that.dirichlet));
    Assert.Equal(0, this.discrete.MaxDiff(that.discrete));
    Assert.Equal(0, this.gamma.MaxDiff(that.gamma));
    Assert.Equal(0, this.gammaPower.MaxDiff(that.gammaPower));
    Assert.Equal(0, this.gaussian.MaxDiff(that.gaussian));
    Assert.Equal(0, this.nonconjugateGaussian.MaxDiff(that.nonconjugateGaussian));
    Assert.Equal(0, this.pointMass.MaxDiff(that.pointMass));
    Assert.Equal(0, this.sparseBernoulliList.MaxDiff(that.sparseBernoulliList));
    Assert.Equal(0, this.sparseBetaList.MaxDiff(that.sparseBetaList));
    Assert.Equal(0, this.sparseGammaList.MaxDiff(that.sparseGammaList));
    Assert.Equal(0, this.truncatedGamma.MaxDiff(that.truncatedGamma));
    Assert.Equal(0, this.truncatedGaussian.MaxDiff(that.truncatedGaussian));
    Assert.Equal(0, this.wrappedGaussian.MaxDiff(that.wrappedGaussian));
    Assert.Equal(0, this.sparseGaussianList.MaxDiff(that.sparseGaussianList));
    Assert.Equal(0, this.unnormalizedDiscrete.MaxDiff(that.unnormalizedDiscrete));
    Assert.Equal(0, this.vectorGaussian.MaxDiff(that.vectorGaussian));
    Assert.Equal(0, this.wishart.MaxDiff(that.wishart));
    Assert.Equal(0, this.pareto.MaxDiff(that.pareto));
    Assert.Equal(0, this.poisson.MaxDiff(that.poisson));

    // Array-valued fields.
    Assert.Equal(0, this.ga.MaxDiff(that.ga));
    Assert.Equal(0, this.vga.MaxDiff(that.vga));
    Assert.Equal(0, this.ga2D.MaxDiff(that.ga2D));
    Assert.Equal(0, this.vga2D.MaxDiff(that.vga2D));
    Assert.Equal(0, this.gaJ.MaxDiff(that.gaJ));
    Assert.Equal(0, this.vgaJ.MaxDiff(that.vgaJ));
    Assert.Equal(0, this.sparseGp.MaxDiff(that.sparseGp));

    // Quantile-related fields are compared by value equality rather than MaxDiff.
    Assert.True(this.quantileEstimator.ValueEquals(that.quantileEstimator));
    Assert.True(this.innerQuantiles.Equals(that.innerQuantiles));
    Assert.True(this.outerQuantiles.Equals(that.outerQuantiles));

    if (this.stringDistribution1 == null)
    {
        return;
    }

    Assert.Equal(0, this.stringDistribution1.MaxDiff(that.stringDistribution1));
    Assert.Equal(0, this.stringDistribution2.MaxDiff(that.stringDistribution2));
}
/// <summary>
/// Checks <c>StringDistribution.ZeroOrMore</c> over a character distribution. The result is
/// neither uniform nor proper, and it assigns unnormalized probability 1 to strings drawn from the
/// character support and 0 to strings outside it. This holds even when the character probabilities
/// are not uniform.
/// </summary>
public void UniformOf()
{
    // Lowercase-only strings.
    var lowercaseStrings = StringDistribution.ZeroOrMore(DiscreteChar.Lower());
    Assert.False(lowercaseStrings.IsUniform());
    Assert.False(lowercaseStrings.IsProper());
    StringInferenceTestUtilities.TestProbability(lowercaseStrings, 1.0, "hello", "a", string.Empty);
    StringInferenceTestUtilities.TestProbability(lowercaseStrings, 0.0, "123", "!", "Abc");

    // Test if non-uniform element distribution does not affect the outcome.
    // Digits with '1' removed from the support and uneven weights on '2' and '3'.
    Vector digitProbs = DiscreteChar.Digit().GetProbs();
    digitProbs['1'] = 0;
    digitProbs['2'] = 0.3;
    digitProbs['3'] = 0.0001;
    var skewedDigitStrings = StringDistribution.ZeroOrMore(DiscreteChar.FromVector(digitProbs));
    StringInferenceTestUtilities.TestProbability(skewedDigitStrings, 1.0, "0", "234", string.Empty);
    StringInferenceTestUtilities.TestProbability(skewedDigitStrings, 0.0, "1", "231", "!", "Abc");
}
/// <summary>
/// Checks the Friendly, Regexp and GraphViz renderings of a point mass on a string that contains
/// a double quote.
/// </summary>
public void PointMassToString()
{
    StringDistribution point = StringDistribution.PointMass("ab\"");

    Assert.Equal("ab\"", point.ToString(SequenceDistributionFormats.Friendly));
    Assert.Equal("ab\"", point.ToString(SequenceDistributionFormats.Regexp));

    // Expected GraphViz output, one entry per line, each line terminated by Environment.NewLine.
    string[] graphVizLines =
    {
        @"digraph finite_state_machine {",
        @" rankdir=LR;",
        @" node [shape = doublecircle; label = ""0\nE=0""]; N0",
        @" node [shape = circle; label = ""1\nE=0""]; N1",
        @" node [shape = circle; label = ""2\nE=0""]; N2",
        @" node [shape = circle; label = ""3\nE=1""]; N3",
        @" N0 -> N1 [ label = ""W=1\na"" ];",
        @" N1 -> N2 [ label = ""W=1\nb"" ];",
        @" N2 -> N3 [ label = ""W=1\n\"""" ];",
        @"}",
    };
    string expectedGraphViz = string.Join(Environment.NewLine, graphVizLines) + Environment.NewLine;

    Assert.Equal(expectedGraphViz, point.ToString(SequenceDistributionFormats.GraphViz));
}
/// <summary>
/// Builds a <see cref="Discrete"/> from a 100-element sparse vector with common value
/// <c>commonValue</c> and two entries equal to <c>nonCommonValue</c>. Checks that the
/// probabilities stay sparse and are normalized by the total mass: both the common value and the
/// two special entries are divided by the sum.
/// </summary>
public void SparseDiscreteNormalise()
{
    // Absolute tolerance for comparing normalized probabilities (about 1e-2 in magnitude here).
    const double Tolerance = 1e-12;
    double commonValue = 1.0;
    double nonCommonValue = 10.0;

    // Use commonValue rather than a duplicated literal so the expected sum below stays in sync.
    Discrete d = new Discrete(
        SparseVector.FromSparseValues(100, commonValue, new List<SparseElement>
        {
            new SparseElement(20, nonCommonValue),
            new SparseElement(55, nonCommonValue),
        }));

    Vector v = d.GetProbs();
    Assert.Equal(Sparsity.Sparse, d.Sparsity);

    // Total unnormalized mass: (Count - 2) common entries plus the two special entries.
    double sum = (v.Count - 2) * commonValue + 2 * nonCommonValue;
    SparseVector sv = (SparseVector)v;

    // Expected value first, actual second; compare with a tolerance rather than exact equality.
    Assert.Equal(commonValue / sum, sv.CommonValue, Tolerance);
    Assert.Equal(2, sv.SparseValues.Count);
    Assert.Equal(nonCommonValue / sum, sv[20], Tolerance);
    Assert.Equal(nonCommonValue / sum, sv[55], Tolerance);
}
/// <summary>
/// Performance check using the builder API: builds a chain of three literal segments, each ending
/// in a final state with its own self-loop. It then profiles 100000 <c>GetLogNormalizer</c> calls
/// under a 20000 timeout.
/// </summary>
public void AutomatonNormalizationPerformance2()
{
    Assert.Timeout(
        () =>
        {
            var builder = new StringAutomaton.Builder();

            var afterAbc = builder.Start.AddTransitionsForSequence("abc");
            afterAbc.SetEndWeight(Weight.One);
            afterAbc.AddSelfTransition('d', Weight.FromValue(0.1));

            var afterEfg = afterAbc.AddTransitionsForSequence("efg");
            afterEfg.SetEndWeight(Weight.One);
            afterEfg.AddSelfTransition('h', Weight.FromValue(0.2));

            var afterTail = afterEfg.AddTransitionsForSequence("grlkhgn;lk3rng");
            afterTail.SetEndWeight(Weight.One);
            afterTail.AddSelfTransition('h', Weight.FromValue(0.3));

            var automaton = builder.GetAutomaton();
            ProfileAction(() => automaton.GetLogNormalizer(), 100000);
        },
        20000);
}
/// <summary>
/// Performance check using the builder API: takes the product of a small looping automaton with
/// itself three times. It then profiles 100 <c>GetLogNormalizer</c> calls under a 120000 timeout.
/// </summary>
public void AutomatonNormalizationPerformance3()
{
    Assert.Timeout(
        () =>
        {
            var builder = new StringAutomaton.Builder();

            // Accepting start state with an 'a' self-loop.
            builder.Start.AddSelfTransition('a', Weight.FromValue(0.5));
            builder.Start.SetEndWeight(Weight.One);

            // "aa" leads to a second accepting state that also loops on 'a'.
            var afterAa = builder.Start.AddTransitionsForSequence("aa");
            afterAa.AddSelfTransition('a', Weight.FromValue(0.5));
            afterAa.SetEndWeight(Weight.One);

            var automaton = builder.GetAutomaton();
            const int SelfProductCount = 3;
            for (int round = 0; round < SelfProductCount; ++round)
            {
                automaton = automaton.Product(automaton);
            }

            ProfileAction(() => automaton.GetLogNormalizer(), 100);
        },
        120000);
}
/// <summary>
/// A helper function for testing messages to <c>format</c>. It runs all four StringFormat operator
/// variants (with and without argument names, strict and permissive about placeholders) and
/// compares each result with the matching expectation.
/// </summary>
/// <param name="str">The message from <c>str</c>.</param>
/// <param name="args">The message from <c>args</c>.</param>
/// <param name="expectedFormatRequireEveryPlaceholder">
/// The expected message to <c>format</c> if the format string is required to contain placeholders for all arguments.
/// </param>
/// <param name="expectedFormatAllowMissingPlaceholders">
/// The expected message to <c>format</c> if the format string may not contain placeholders for some arguments.
/// </param>
private static void TestMessageToFormat(
    StringDistribution str,
    StringDistribution[] args,
    StringDistribution expectedFormatRequireEveryPlaceholder,
    StringDistribution expectedFormatAllowMissingPlaceholders)
{
    string[] argNames = GetDefaultArgumentNames(args.Length);

    // Every argument must have a placeholder: without, then with, explicit argument names.
    var strictNoNames = StringFormatOp_RequireEveryPlaceholder_NoArgumentNames.FormatAverageConditional(str, args);
    Assert.Equal(expectedFormatRequireEveryPlaceholder, strictNoNames);

    var strictNamed = StringFormatOp_RequireEveryPlaceholder.FormatAverageConditional(str, args, argNames);
    Assert.Equal(expectedFormatRequireEveryPlaceholder, strictNamed);

    // Placeholders may be missing: without, then with, explicit argument names.
    var permissiveNoNames = StringFormatOp_AllowMissingPlaceholders_NoArgumentNames.FormatAverageConditional(str, args);
    Assert.Equal(expectedFormatAllowMissingPlaceholders, permissiveNoNames);

    var permissiveNamed = StringFormatOp_AllowMissingPlaceholders.FormatAverageConditional(str, args, argNames);
    Assert.Equal(expectedFormatAllowMissingPlaceholders, permissiveNamed);
}
/// <summary>
/// Checks that <paramref name="testedOperation"/> carries <c>LogValueOverride</c> and
/// <c>PruneStatesWithLogEndWeightLessThan</c> from its input to its output. The check is done
/// twice: once with both properties cleared and once with both set (to -1 and -128).
/// </summary>
/// <param name="automaton">The automaton used as the base input.</param>
/// <param name="testedOperation">The automaton operation under test.</param>
public static void TestAutomatonPropertyPreservation(StringAutomaton automaton, Func<StringAutomaton, StringAutomaton> testedOperation)
{
    // Both properties cleared: the output must have them cleared as well.
    var withoutOverride = automaton.WithLogValueOverride(null);
    var cleared = withoutOverride.WithPruneStatesWithLogEndWeightLessThan(null);
    var clearedOutput = testedOperation(cleared);
    Assert.Null(clearedOutput.LogValueOverride);
    Assert.Null(clearedOutput.PruneStatesWithLogEndWeightLessThan);

    // Both properties set: the output must carry the same values.
    var withOverride = automaton.WithLogValueOverride(-1);
    var populated = withOverride.WithPruneStatesWithLogEndWeightLessThan(-128);
    var populatedOutput = testedOperation(populated);
    Assert.Equal(populated.LogValueOverride, populatedOutput.LogValueOverride);
    Assert.Equal(populated.PruneStatesWithLogEndWeightLessThan, populatedOutput.PruneStatesWithLogEndWeightLessThan);
}
/// <summary>
/// Builds a stochastic automaton whose states carry several epsilon loops added via AddEpsilonLoop.
/// In each state the loop weights, end weight and outgoing transition weights sum to one. Checks
/// that the log-normalizer is zero when computed directly, via GetValue, and via GetValue with
/// transducers.
/// </summary>
public void ComputeNormalizerWithManyNonTrivialLoops1()
{
    StringAutomaton automaton = StringAutomaton.Zero();

    // Start state: loops 0.2 + 0.3, end 0.1, 'a' 0.4.
    // The int argument of AddEpsilonLoop is presumably the loop length -- confirm.
    AddEpsilonLoop(automaton.Start, 3, 0.2);
    AddEpsilonLoop(automaton.Start, 5, 0.3);
    automaton.Start.EndWeight = Weight.FromValue(0.1);

    // State after 'a': end 0.6, loop 0.3, 'b' 0.1.
    var afterA = automaton.Start.AddTransition('a', Weight.FromValue(0.4));
    afterA.EndWeight = Weight.FromValue(0.6);
    AddEpsilonLoop(afterA, 0, 0.3);

    // State after 'b': loop 0.9, end 0.1.
    var afterB = afterA.AddTransition('b', Weight.FromValue(0.1));
    AddEpsilonLoop(afterB, 1, 0.9);
    afterB.EndWeight = Weight.FromValue(0.1);

    AssertStochastic(automaton);
    Assert.Equal(0.0, automaton.GetLogNormalizer(), 1e-6);
    Assert.Equal(0.0, GetLogNormalizerByGetValue(automaton), 1e-6);
    Assert.Equal(0.0, GetLogNormalizerByGetValueWithTransducers(automaton), 1e-6);
}
/// <summary>
/// Builds a small stochastic automaton: a start state with an 'a' self-loop and two branches, one
/// of which loops on 'a'. Checks that the log-normalizer is zero when computed directly, via
/// GetValue, and via GetValue with transducers.
/// </summary>
public void ComputeNormalizerSimple3()
{
    StringAutomaton automaton = StringAutomaton.Zero();

    // Start state: self-loop 0.7 + end 0.1 + 'b' 0.15 + 'c' 0.05 = 1.
    automaton.Start.AddSelfTransition('a', Weight.FromValue(0.7));
    automaton.Start.EndWeight = Weight.FromValue(0.1);

    // 'b' branch: self-loop 0.4 + end 0.6 = 1.
    var afterB = automaton.Start.AddTransition('b', Weight.FromValue(0.15));
    afterB.AddSelfTransition('a', Weight.FromValue(0.4));
    afterB.EndWeight = Weight.FromValue(0.6);

    // 'c' branch: ends immediately.
    var afterC = automaton.Start.AddTransition('c', Weight.FromValue(0.05));
    afterC.EndWeight = Weight.One;

    AssertStochastic(automaton);
    Assert.Equal(0.0, automaton.GetLogNormalizer(), 1e-6);
    Assert.Equal(0.0, GetLogNormalizerByGetValue(automaton), 1e-6);
    Assert.Equal(0.0, GetLogNormalizerByGetValueWithTransducers(automaton), 1e-6);
}
/// <summary>
/// Builder-API variant: checks that a group attached to a transition of the left operand is still
/// present after <c>SetToProduct</c> with a right operand that uses no groups.
/// </summary>
public void ProductWithGroups()
{
    StringDistribution lhsWithoutGroup = StringDistribution.String("ab");

    // Put the first transition out of the start state into group 1 (copy, modify, write back).
    var builder = StringAutomaton.Builder.FromAutomaton(lhsWithoutGroup.GetWorkspaceOrPoint());
    var iterator = builder.Start.TransitionIterator;
    var groupedTransition = iterator.Value;
    groupedTransition.Group = 1;
    iterator.Value = groupedTransition;

    StringDistribution lhs = StringDistribution.FromWeightFunction(builder.GetAutomaton());
    StringDistribution rhs = StringDistribution.OneOf("ab", "ac");

    // Preconditions: only the left operand carries a group.
    Assert.True(lhs.GetWorkspaceOrPoint().HasGroup(1));
    Assert.False(rhs.GetWorkspaceOrPoint().UsesGroups);

    var product = StringDistribution.Zero();
    product.SetToProduct(lhs, rhs);
    Assert.True(product.GetWorkspaceOrPoint().HasGroup(1));
}
/// <summary>
/// Checks <c>StringDistribution.Capitalized</c> with and without a maximum length:
/// <list type="bullet">
/// <item>Whether each result is proper.</item>
/// <item>The probability of in-range capitalized strings.</item>
/// <item>Zero probability for strings that are too short, too long, or not capitalized.</item>
/// </list>
/// </summary>
public void Capitalized()
{
    int lowerCount = DiscreteChar.Lower().GetProbs().Count(prob => prob > 0);
    int upperCount = DiscreteChar.Upper().GetProbs().Count(prob => prob > 0);

    // Bounded length 3..5: one uppercase letter followed by 2..4 lowercase letters.
    var bounded = StringDistribution.Capitalized(minLength: 3, maxLength: 5);
    Assert.True(bounded.IsProper());
    double boundedProbability =
        StringInferenceTestUtilities.StringUniformProbability(2, 4, lowerCount) / upperCount;
    StringInferenceTestUtilities.TestProbability(bounded, boundedProbability, "Abc", "Bcde", "Abcde");
    StringInferenceTestUtilities.TestProbability(bounded, 0.0, "A", "abc", "Ab", "Abcdef", string.Empty);

    // Only a minimum length: improper, every admissible string gets unnormalized probability 1.
    var unbounded = StringDistribution.Capitalized(minLength: 3);
    Assert.False(unbounded.IsProper());
    StringInferenceTestUtilities.TestProbability(unbounded, 1.0, "Abc", "Bcde", "Abcde", "Abfjrhfjlrl");
    StringInferenceTestUtilities.TestProbability(unbounded, 0.0, "A", "abc", "Ab", string.Empty);
}
/// <summary>
/// A helper function for testing messages to <c>args</c>. It runs all four StringFormat operator
/// variants against a shared result buffer and compares each returned message array with the
/// matching expectation.
/// </summary>
/// <param name="str">The message from <c>str</c>.</param>
/// <param name="format">The message from <c>format</c>.</param>
/// <param name="args">The message from <c>args</c>.</param>
/// <param name="expectedArgsRequireEveryPlaceholder">
/// The expected message to <c>args</c> if the format string is required to contain placeholders for all arguments.
/// </param>
/// <param name="expectedArgsAllowMissingPlaceholders">
/// The expected message to <c>args</c> if the format string may not contain placeholders for some arguments.
/// </param>
private static void TestMessageToArgs(
    StringDistribution str,
    StringDistribution format,
    StringDistribution[] args,
    StringDistribution[] expectedArgsRequireEveryPlaceholder,
    StringDistribution[] expectedArgsAllowMissingPlaceholders)
{
    string[] argNames = GetDefaultArgumentNames(args.Length);

    // One buffer is reused by every call; each result is asserted before the next call.
    var buffer = new StringDistribution[args.Length];

    // Every argument must have a placeholder: without, then with, explicit argument names.
    var strictNoNames = StringFormatOp_RequireEveryPlaceholder_NoArgumentNames.ArgsAverageConditional(str, format, args, buffer);
    Assert.Equal(expectedArgsRequireEveryPlaceholder, strictNoNames);

    var strictNamed = StringFormatOp_RequireEveryPlaceholder.ArgsAverageConditional(str, format, args, argNames, buffer);
    Assert.Equal(expectedArgsRequireEveryPlaceholder, strictNamed);

    // Placeholders may be missing: without, then with, explicit argument names.
    var permissiveNoNames = StringFormatOp_AllowMissingPlaceholders_NoArgumentNames.ArgsAverageConditional(str, format, args, buffer);
    Assert.Equal(expectedArgsAllowMissingPlaceholders, permissiveNoNames);

    var permissiveNamed = StringFormatOp_AllowMissingPlaceholders.ArgsAverageConditional(str, format, args, argNames, buffer);
    Assert.Equal(expectedArgsAllowMissingPlaceholders, permissiveNamed);
}
/// <summary>
/// Builds a start state with an epsilon self-loop (0.5) and an epsilon cycle (0.4, then 1, then 1)
/// back to itself. Checks the start state's epsilon closure:
/// <list type="bullet">
/// <item>It has 3 states.</item>
/// <item>Its end weight is 1 (log 0).</item>
/// <item>The start state has weight 10 and the other two states have weight 4.</item>
/// </list>
/// </summary>
public void LoopyEpsilonClosure1()
{
    StringAutomaton automaton = StringAutomaton.Zero();

    // Start state: epsilon self-loop 0.5, epsilon cycle entry 0.4, end weight 0.1.
    automaton.Start.AddEpsilonTransition(Weight.FromValue(0.5), automaton.Start);
    var cycleEntry = automaton.Start.AddEpsilonTransition(Weight.FromValue(0.4));
    var cycleMiddle = cycleEntry.AddEpsilonTransition(Weight.One);
    cycleMiddle.AddEpsilonTransition(Weight.One, automaton.Start);
    automaton.Start.EndWeight = Weight.FromValue(0.1);

    AssertStochastic(automaton);

    StringAutomaton.EpsilonClosure startClosure = automaton.Start.GetEpsilonClosure();
    Assert.Equal(3, startClosure.Size);
    Assert.Equal(0.0, startClosure.EndWeight.LogValue, 1e-8);

    for (int index = 0; index < startClosure.Size; ++index)
    {
        bool isStart = startClosure.GetStateByIndex(index) == automaton.Start;
        double expectedWeight = isStart ? 10.0 : 4.0;
        Assert.Equal(expectedWeight, startClosure.GetStateWeightByIndex(index).Value, 1e-8);
    }
}
/// <summary>
/// Three gated ways of forming <c>text</c> from lowercase strings: three words, two words, or one
/// word. After observing "abc def", checks that only the two-word case survives. In that case
/// str1 and str3 are point masses on the two words, and the unused str2 keeps its
/// <c>StringDistribution.Lower()</c> prior.
/// </summary>
public void ImpossibleBranchTest4()
{
    Variable<string> firstWord = Variable.StringLower().Named("str1");
    Variable<string> middleWord = Variable.StringLower().Named("str2");
    Variable<string> lastWord = Variable.StringLower().Named("str3");
    Variable<string> text = Variable.New<string>().Named("text");
    Variable<int> selector = Variable.DiscreteUniform(3).Named("selector");

    using (Variable.Case(selector, 0))
    {
        text.SetTo(firstWord + Variable.Constant(" ") + middleWord + Variable.Constant(" ") + lastWord);
    }

    using (Variable.Case(selector, 1))
    {
        text.SetTo(firstWord + Variable.Constant(" ") + lastWord);
    }

    using (Variable.Case(selector, 2))
    {
        text.SetTo(firstWord);
    }

    text.ObservedValue = "abc def";

    var engine = new InferenceEngine();
    var selectorPosterior = engine.Infer<Discrete>(selector);
    var firstPosterior = engine.Infer<StringDistribution>(firstWord);
    var middlePosterior = engine.Infer<StringDistribution>(middleWord);
    var lastPosterior = engine.Infer<StringDistribution>(lastWord);

    Assert.True(selectorPosterior.IsPointMass && selectorPosterior.Point == 1);
    Assert.True(firstPosterior.IsPointMass && firstPosterior.Point == "abc");
    Assert.Equal(StringDistribution.Lower(), middlePosterior);
    Assert.True(lastPosterior.IsPointMass && lastPosterior.Point == "def");
}