コード例 #1
0
        public void ComputeNormalizerWithNonTrivialLoop1()
        {
            // Three states arranged in a cycle: start -> s1 -> s2 -> start, all on 'a'.
            // Each state continues around the cycle with weight 0.9 and leaves (on 'a')
            // to a fresh accepting state with weight 0.1, so every state is stochastic
            // and the normalizer must be exactly one (log normalizer 0).
            const double Tolerance = 1e-6;
            StringAutomaton automaton = StringAutomaton.Zero();
            Weight continueWeight = Weight.FromValue(0.9);
            Weight exitWeight = Weight.FromValue(0.1);

            var s1 = automaton.Start.AddTransition('a', continueWeight);
            s1.AddTransition('a', exitWeight).EndWeight = Weight.One;

            var s2 = s1.AddTransition('a', continueWeight);
            s2.AddTransition('a', exitWeight).EndWeight = Weight.One;

            var backAtStart = s2.AddTransition('a', continueWeight, automaton.Start);
            backAtStart.AddTransition('a', exitWeight).EndWeight = Weight.One;

            AssertStochastic(automaton);
            Assert.Equal(0.0, automaton.GetLogNormalizer(), Tolerance);
            Assert.Equal(0.0, GetLogNormalizerByGetValue(automaton), Tolerance);
            Assert.Equal(0.0, GetLogNormalizerByGetValueWithTransducers(automaton), Tolerance);
        }
コード例 #2
0
        public void NonNormalizableLoop2()
        {
            // start --a(2.0)--> endState (end weight 5). endState returns to start on 'b' (0.1)
            // and 'c' (0.05) and loops on '!' (0.75). The cycle start -> endState -> start has
            // total weight 2 * (0.1 + 0.05) / (1 - 0.75) = 1.2 > 1, so the total path weight
            // diverges and normalization must fail.
            StringAutomaton automaton = StringAutomaton.Zero();

            var endState = automaton.Start.AddTransition('a', Weight.FromValue(2.0));

            endState.SetEndWeight(Weight.FromValue(5.0));
            endState.AddTransition('b', Weight.FromValue(0.1), automaton.Start);
            endState.AddTransition('c', Weight.FromValue(0.05), automaton.Start);
            endState.AddSelfTransition('!', Weight.FromValue(0.75));

            // Take a copy before any normalization attempt. The throwing NormalizeValues runs on
            // the original, so TryNormalizeValues is exercised on an automaton that no failed
            // normalization has touched (matches NonNormalizableLoop4).
            StringAutomaton copyOfAutomaton = automaton.Clone();

            Assert.Throws<InvalidOperationException>(() => automaton.NormalizeValues());
            Assert.False(copyOfAutomaton.TryNormalizeValues());
            ////Assert.Equal(f, copyOfF); // TODO: fix equality first
        }
コード例 #3
0
        /// <summary>
        /// Asserts that, in every state that survives dead-state removal, the end weight plus the
        /// weights of all outgoing transitions sums to one (log value 0, within 1e-6).
        /// </summary>
        /// <param name="automaton">The automaton to check. A clone is pruned, so the argument itself is not modified.</param>
        private static void AssertStochastic(StringAutomaton automaton)
        {
            StringAutomaton pruned = automaton.Clone();
            pruned.RemoveDeadStates();

            for (int stateIndex = 0; stateIndex < pruned.States.Count; ++stateIndex)
            {
                var state = pruned.States[stateIndex];

                // Start from the end weight, then fold in each outgoing transition weight.
                Weight total = state.EndWeight;
                for (int transitionIndex = 0; transitionIndex < state.TransitionCount; ++transitionIndex)
                {
                    total = Weight.Sum(total, state.GetTransition(transitionIndex).Weight);
                }

                Assert.Equal(0.0, total.LogValue, 1e-6);
            }
        }
コード例 #4
0
        public void NormalCdfIntegralTest()
        {
            // Specific inputs with nearly opposite x and y, a correlation r at or just above -1, and a
            // tiny last argument. The returned value's mantissa must be non-negative.
            // NOTE(review): these look like previously failing regression cases - confirm.
            Assert.True(0 <= NormalCdfIntegral(190187183095334850882507750944849586799124505055478568794871547478488387682304.0, -190187183095334850882507750944849586799124505055478568794871547478488387682304.0, -1, 0.817880416082724044547388352452631856079457366800004151664125953519049673808376291470533145141236089924006896061006277409614237094627499958581030715374379576478204968748786874650796450332240045653919846557755590765736997127532958984375e-78).Mantissa);
            Assert.True(0 <= NormalCdfIntegral(213393529.2046706974506378173828125, -213393529.2046706974506378173828125, -1, 0.72893668811495072384656764856902984306419313043079455383121967315673828125e-9).Mantissa);
            // This case requires a strictly positive mantissa, not merely a non-negative one.
            Assert.True(0 < NormalCdfIntegral(-0.421468532207607216033551367218024097383022308349609375, 0.42146843802130329326161017888807691633701324462890625, -0.99999999999999989, 0.62292398855983019004972723654291189010479001808562316000461578369140625e-8).Mantissa);

            // Smoke test over every x, y from OperatorTests.Doubles() and every r from the same set
            // restricted to [-1, 1]. The results are discarded: this part only fails if a call throws.
            Parallel.ForEach(OperatorTests.Doubles(), x =>
            {
                foreach (var y in OperatorTests.Doubles())
                {
                    foreach (var r in OperatorTests.Doubles().Where(d => d >= -1 && d <= 1))
                    {
                        MMath.NormalCdfIntegral(x, y, r);
                    }
                }
            });
        }
コード例 #5
0
        public void SparseBernoulliListArithmetic()
        {
            const double Tolerance = 1e-10;
            const int ListSize = 100;

            var commonA = new Bernoulli(0.1);
            var commonB = new Bernoulli(0.2);
            var point7 = new Bernoulli(0.7);
            var point8 = new Bernoulli(0.8);
            var point9 = new Bernoulli(0.9);

            // listA: common 0.1 with [20] = 0.7 and [55] = 0.8.
            // listB: common 0.2 with [25] = 0.8 and [55] = 0.9.
            var listA = SparseBernoulliList.Constant(ListSize, commonA);
            var listB = SparseBernoulliList.Constant(ListSize, commonB);

            listA[20] = point7;
            listA[55] = point8;
            listB[25] = point8;
            listB[55] = point9;

            // Element-wise product: indices 20, 25 and 55 all differ from the common value.
            var product = listA * listB;

            Assert.Equal(3, product.SparseValues.Count);
            Assert.Equal(commonA * commonB, product.CommonValue);
            Assert.Equal(point7 * commonB, product[20]);
            Assert.Equal(commonA * point8, product[25]);
            Assert.Equal(point8 * point9, product[55]);

            // Element-wise ratio. 0.8 / 0.9 and 0.1 / 0.2 have the same log-odds difference,
            // ln(4/9), which is presumably why entry 55 folds into the common value and only
            // two sparse entries remain.
            var ratio = listA / listB;

            Assert.Equal(2, ratio.SparseValues.Count);
            Assert.Equal((commonA / commonB).GetProbTrue(), ratio.CommonValue.GetProbTrue(), Tolerance);
            Assert.Equal((point7 / commonB).GetProbTrue(), ratio[20].GetProbTrue(), Tolerance);
            Assert.Equal((commonA / point8).GetProbTrue(), ratio[25].GetProbTrue(), Tolerance);
            Assert.Equal((point8 / point9).GetProbTrue(), ratio[55].GetProbTrue(), Tolerance);

            // Element-wise power of listA alone keeps its two sparse entries.
            const double Exponent = 1.2;
            var power = listA ^ Exponent;

            Assert.Equal(2, power.SparseValues.Count);
            Assert.Equal(commonA ^ Exponent, power.CommonValue);
            Assert.Equal(point7 ^ Exponent, power[20]);
            Assert.Equal(point8 ^ Exponent, power[55]);
        }
コード例 #6
0
        public void AutomatonNormalizationPerformance2()
        {
            Assert.Timeout(() =>
            {
                // A chain of three literal segments. After each segment the automaton may stop
                // (end weight one) or repeat a single character through a self-loop.
                string[] segments = { "abc", "efg", "grlkhgn;lk3rng" };
                char[] loopChars = { 'd', 'h', 'h' };
                double[] loopWeights = { 0.1, 0.2, 0.3 };

                StringAutomaton automaton = StringAutomaton.Zero();
                var tail = automaton.Start;
                for (int segment = 0; segment < segments.Length; ++segment)
                {
                    tail = tail.AddTransitionsForSequence(segments[segment]);
                    tail.EndWeight = Weight.One;
                    tail.AddSelfTransition(loopChars[segment], Weight.FromValue(loopWeights[segment]));
                }

                ProfileAction(() => automaton.GetLogNormalizer(), 100000);
            }, 20000);
        }
コード例 #7
0
        /// <summary>
        /// Asserts that the given values carry all of the probability mass of a string distribution
        /// (their probabilities sum to one within 1e-8).
        /// </summary>
        /// <param name="name">A label written to the console in front of the distribution.</param>
        /// <param name="dist">The distribution under test. It is cast to <see cref="StringDistribution"/>, so any other implementation throws.</param>
        /// <param name="vals">The values whose probabilities must sum to one.</param>
        private static void Test(string name, IDistribution<string> dist, params string[] vals)
        {
            var sa = (StringDistribution)dist;

            // WriteLine (not Write) so output from consecutive calls does not run together on one line.
            Console.WriteLine(name + "=" + sa);

            double sum = vals.Sum(s => Math.Exp(dist.GetLogProb(s)));

            var ok = Math.Abs(sum - 1.0) < 1E-8;
            var valstr = string.Join("|", vals.Select(s => s + "$"));

            Assert.True(ok, $"Result was {sa} should be ({valstr})");
        }
コード例 #8
0
        public void ProductWithGroups()
        {
            StringDistribution lhsWithoutGroup = StringDistribution.String("ab");

            // Tag the first transition out of the start state with group 1. The transition is
            // read, modified and written back through SetTransition.
            var automaton = lhsWithoutGroup.GetWorkspaceOrPoint();
            var firstTransition = automaton.Start.GetTransitions()[0];
            firstTransition.Group = 1;
            automaton.Start.SetTransition(0, firstTransition);

            StringDistribution lhs = StringDistribution.FromWeightFunction(automaton);
            StringDistribution rhs = StringDistribution.OneOf("ab", "ac");

            // Only the left operand uses groups; the product must keep group 1.
            Assert.True(lhs.GetWorkspaceOrPoint().HasGroup(1));
            Assert.False(rhs.GetWorkspaceOrPoint().UsesGroups());

            var product = StringDistribution.Zero();
            product.SetToProduct(lhs, rhs);
            Assert.True(product.GetWorkspaceOrPoint().HasGroup(1));
        }
コード例 #9
0
        public void CopySequence2()
        {
            const int CopyCount = 10;
            const int LetterCount = 20;

            // Chain CopyCount copy transducers together, one after another.
            StringTransducer copy = StringTransducer.Copy();
            for (int appended = 1; appended < CopyCount; ++appended)
            {
                copy.AppendInPlace(StringTransducer.Copy());
            }

            // Project a run of LetterCount 'a's through the chain. The resulting weight is
            // expected to equal StringInferenceTestUtilities.Partitions(LetterCount, CopyCount).
            string sequence = new string('a', LetterCount);
            StringAutomaton result = copy.ProjectSource(sequence);
            double expectedLogValue = Math.Log(StringInferenceTestUtilities.Partitions(LetterCount, CopyCount));

            Assert.Equal(expectedLogValue, result.GetLogValue(sequence), 1e-8);
        }
コード例 #10
0
        public void UniformDetection()
        {
            // A mixture whose only non-zero component is Any() is uniform.
            StringDistribution dist1 = StringDistribution.OneOf(0.0, StringDistribution.Zero(), 1.0, StringDistribution.Any());

            Assert.True(dist1.IsUniform());
            StringInferenceTestUtilities.TestProbability(dist1, 1.0, string.Empty, "a", "bc");

            // A mixture of two uniform components is still uniform.
            StringDistribution dist2 = StringDistribution.OneOf(0.3, StringDistribution.Any(), 0.7, StringDistribution.Any());

            Assert.True(dist2.IsUniform());
            // Probe dist2 itself; this line previously re-tested dist1 (copy-paste slip).
            StringInferenceTestUtilities.TestProbability(dist2, 1.0, string.Empty, "a", "bc");

            // Nested mixtures of uniform components are uniform as well.
            StringDistribution dist3 =
                StringDistribution.OneOf(1.0, StringDistribution.Any(), 2.0, StringDistribution.OneOf(0.1, StringDistribution.Any(), 0.2, StringDistribution.Any()));

            Assert.True(dist3.IsUniform());
            StringInferenceTestUtilities.TestProbability(dist3, 1.0, string.Empty, "a", "bc");
        }
コード例 #11
0
        public void SemanticWebTest1()
        {
            // The template is unknown (uniform over strings); the text is the template with the
            // two facts substituted in. A sentence containing both facts must remain possible.
            const string Name = "Tony Blair";
            const string BirthDate = "6 May 1953";
            var template = Variable.Random(StringDistribution.Any());
            var text = Variable.StringFormat(template, Name, BirthDate);

            var engine = new InferenceEngine();
            engine.Compiler.RecommendedQuality = QualityBand.Experimental;
            engine.NumberOfIterations = 1;

            var textPosterior = engine.Infer<StringDistribution>(text);
            Console.WriteLine("textDist={0}", textPosterior);

            double sentenceLogProb = textPosterior.GetLogProb("6 May 1953 is the date of birth of Tony Blair.");
            Assert.False(double.IsNegativeInfinity(sentenceLogProb));
        }
コード例 #12
0
        public void StringFormatPerformanceTest2()
        {
            Assert.Timeout(() =>
            {
                Rand.Restart(777);

                // text = "{0} {1}" filled with two random words; fullText = text substituted into
                // an unknown outer format string.
                Variable<string> template = Variable.Constant("{0} {1}").Named("template");
                Variable<string> arg1 = Variable.Random(StringDistribution.Any(minLength: 1, maxLength: 15)).Named("arg1");
                Variable<string> arg2 = Variable.Random(StringDistribution.Any(minLength: 1)).Named("arg2");
                Variable<string> text = Variable.StringFormat(template, arg1, arg2).Named("text");
                Variable<string> fullTextFormat = Variable.Random(StringDistribution.Any()).Named("fullTextFormat");
                Variable<string> fullText = Variable.StringFormat(fullTextFormat, text).Named("fullText");

                var engine = new InferenceEngine();
                engine.Compiler.RecommendedQuality = QualityBand.Experimental;
                engine.ShowProgress = false;

                Action action = () =>
                {
                    // Build the observation: five groups of five characters drawn with
                    // Rand.Int('a', 'z' + 1), each group followed by a space.
                    var buffer = new char[5 * 6];
                    int position = 0;
                    for (int group = 0; group < 5; ++group)
                    {
                        for (int letter = 0; letter < 5; ++letter)
                        {
                            buffer[position++] = (char)Rand.Int('a', 'z' + 1);
                        }

                        buffer[position++] = ' ';
                    }

                    fullText.ObservedValue = new string(buffer);

                    // Query every latent variable of the model.
                    engine.Infer<StringDistribution>(arg1);
                    engine.Infer<StringDistribution>(arg2);
                    engine.Infer<StringDistribution>(text);
                    engine.Infer<StringDistribution>(fullTextFormat);
                };

                action(); // To exclude the compilation time from the profile
                ProfileAction(action, 100);
            }, 10000);
        }
コード例 #13
0
        public void NonNormalizableLoop4()
        {
            StringAutomaton automaton = StringAutomaton.Zero();
            var start = automaton.Start;

            // The start state has an 'a' self-loop of weight 0.1 and two 'a' branches of weight 2 each.
            start.AddSelfTransition('a', Weight.FromValue(0.1));

            // First branch: an accepting state with an 'a' self-loop of weight 2, so the sum over
            // its paths diverges.
            var divergingBranch = start.AddTransition('a', Weight.FromValue(2.0));
            divergingBranch.AddSelfTransition('a', Weight.FromValue(2.0));
            divergingBranch.EndWeight = Weight.One;

            // Second branch: a plain accepting state.
            var acceptingBranch = start.AddTransition('a', Weight.FromValue(2.0));
            acceptingBranch.EndWeight = Weight.One;

            // The copy is taken before any normalization attempt on the original.
            StringAutomaton snapshot = automaton.Clone();

            Assert.Throws<InvalidOperationException>(() => automaton.NormalizeValues());
            Assert.False(snapshot.TryNormalizeValues());
            ////Assert.Equal(f, copyOfF); // TODO: fix equality first
        }
コード例 #14
0
        public void SampleFiniteSupport()
        {
            Rand.Restart(69);

            // {"a", "ab"} followed by {"c", "bc"}: "ac" and "abbc" arise one way each, while
            // "abc" arises two ways ("a"+"bc" and "ab"+"c"). Expected frequencies: 1/4, 1/2, 1/4.
            StringDistribution dist = StringDistribution.OneOf("a", "ab").Append(StringDistribution.OneOf("c", "bc"));
            const int SampleCount = 10000;

            int[] counts = new int[3];
            for (int draw = 0; draw < SampleCount; ++draw)
            {
                string sample = dist.Sample();
                int bucket;
                if (sample == "ac")
                {
                    bucket = 0;
                }
                else if (sample == "abc")
                {
                    bucket = 1;
                }
                else
                {
                    bucket = 2;
                }

                counts[bucket]++;
            }

            double[] expectedFractions = { 0.25, 0.5, 0.25 };
            for (int bucket = 0; bucket < expectedFractions.Length; ++bucket)
            {
                Assert.Equal(expectedFractions[bucket], counts[bucket] / (double)SampleCount, 1e-2);
            }
        }
コード例 #15
0
        public void AutomatonNormalizationPerformance3()
        {
            Assert.Timeout(() =>
            {
                // Base automaton: accept at the start or after "aa", with an 'a' self-loop of
                // weight 0.5 on both of those states.
                StringAutomaton automaton = StringAutomaton.Zero();
                var start = automaton.Start;
                start.AddSelfTransition('a', Weight.FromValue(0.5));
                start.EndWeight = Weight.One;

                var afterTwoAs = start.AddTransitionsForSequence("aa");
                afterTwoAs.AddSelfTransition('a', Weight.FromValue(0.5));
                afterTwoAs.EndWeight = Weight.One;

                // Squaring three times yields the product of eight copies of the base automaton.
                for (int round = 0; round < 3; ++round)
                {
                    automaton = automaton.Product(automaton);
                }

                ProfileAction(() => automaton.GetLogNormalizer(), 100);
            }, 120000);
        }
コード例 #16
0
ファイル: SerializableTest.cs プロジェクト: kant2002/infer
            /// <summary>
            /// Asserts that every distribution and estimator field of this instance matches the
            /// corresponding field of <paramref name="that"/>.
            /// </summary>
            /// <param name="that">The instance to compare against (presumably a serialized-then-deserialized copy; confirm at call sites).</param>
            public void AssertEqualTo(MyClass that)
            {
                // Distribution fields: the maximum difference must be exactly zero.
                Assert.Equal(0, this.bernoulli.MaxDiff(that.bernoulli));
                Assert.Equal(0, this.beta.MaxDiff(that.beta));
                Assert.Equal(0, this.binomial.MaxDiff(that.binomial));
                Assert.Equal(0, this.conjugateDirichlet.MaxDiff(that.conjugateDirichlet));
                Assert.Equal(0, this.dirichlet.MaxDiff(that.dirichlet));
                Assert.Equal(0, this.discrete.MaxDiff(that.discrete));
                Assert.Equal(0, this.gamma.MaxDiff(that.gamma));
                Assert.Equal(0, this.gammaPower.MaxDiff(that.gammaPower));
                Assert.Equal(0, this.gaussian.MaxDiff(that.gaussian));
                Assert.Equal(0, this.nonconjugateGaussian.MaxDiff(that.nonconjugateGaussian));
                Assert.Equal(0, this.pointMass.MaxDiff(that.pointMass));
                Assert.Equal(0, this.sparseBernoulliList.MaxDiff(that.sparseBernoulliList));
                Assert.Equal(0, this.sparseBetaList.MaxDiff(that.sparseBetaList));
                Assert.Equal(0, this.sparseGammaList.MaxDiff(that.sparseGammaList));
                Assert.Equal(0, this.truncatedGamma.MaxDiff(that.truncatedGamma));
                Assert.Equal(0, this.truncatedGaussian.MaxDiff(that.truncatedGaussian));
                Assert.Equal(0, this.wrappedGaussian.MaxDiff(that.wrappedGaussian));
                Assert.Equal(0, this.sparseGaussianList.MaxDiff(that.sparseGaussianList));
                Assert.Equal(0, this.unnormalizedDiscrete.MaxDiff(that.unnormalizedDiscrete));
                Assert.Equal(0, this.vectorGaussian.MaxDiff(that.vectorGaussian));
                Assert.Equal(0, this.wishart.MaxDiff(that.wishart));
                Assert.Equal(0, this.pareto.MaxDiff(that.pareto));
                Assert.Equal(0, this.poisson.MaxDiff(that.poisson));
                Assert.Equal(0, ga.MaxDiff(that.ga));
                Assert.Equal(0, vga.MaxDiff(that.vga));
                Assert.Equal(0, ga2D.MaxDiff(that.ga2D));
                Assert.Equal(0, vga2D.MaxDiff(that.vga2D));
                Assert.Equal(0, gaJ.MaxDiff(that.gaJ));
                Assert.Equal(0, vgaJ.MaxDiff(that.vgaJ));
                Assert.Equal(0, this.sparseGp.MaxDiff(that.sparseGp));

                // Quantile-related fields are compared through value equality rather than MaxDiff.
                Assert.True(this.quantileEstimator.ValueEquals(that.quantileEstimator));
                Assert.True(this.innerQuantiles.Equals(that.innerQuantiles));
                Assert.True(this.outerQuantiles.Equals(that.outerQuantiles));

                // The string distributions may be null. Guard each one on its own null check, so a
                // null second field is not dereferenced and a non-null second field is not skipped
                // just because the first one is null.
                if (this.stringDistribution1 != null)
                {
                    Assert.Equal(0, this.stringDistribution1.MaxDiff(that.stringDistribution1));
                }

                if (this.stringDistribution2 != null)
                {
                    Assert.Equal(0, this.stringDistribution2.MaxDiff(that.stringDistribution2));
                }
            }
コード例 #17
0
        public void UniformOf()
        {
            // Any number of lowercase letters: improper, and not uniform over all strings.
            var lowercaseStrings = StringDistribution.ZeroOrMore(DiscreteChar.Lower());

            Assert.False(lowercaseStrings.IsUniform());
            Assert.False(lowercaseStrings.IsProper());
            StringInferenceTestUtilities.TestProbability(lowercaseStrings, 1.0, "hello", "a", string.Empty);
            StringInferenceTestUtilities.TestProbability(lowercaseStrings, 0.0, "123", "!", "Abc");

            // Only the support of the element distribution should matter, not its shape.
            // '1' is given probability zero and so drops out; '2' and '3' stay in the support.
            Vector digitProbs = DiscreteChar.Digit().GetProbs();
            digitProbs['1'] = 0;
            digitProbs['2'] = 0.3;
            digitProbs['3'] = 0.0001;
            var digitStrings = StringDistribution.ZeroOrMore(DiscreteChar.FromVector(digitProbs));

            StringInferenceTestUtilities.TestProbability(digitStrings, 1.0, "0", "234", string.Empty);
            StringInferenceTestUtilities.TestProbability(digitStrings, 0.0, "1", "231", "!", "Abc");
        }
コード例 #18
0
        public void PointMassToString()
        {
            StringDistribution point = StringDistribution.PointMass("ab\"");

            Assert.Equal("ab\"", point.ToString(SequenceDistributionFormats.Friendly));
            Assert.Equal("ab\"", point.ToString(SequenceDistributionFormats.Regexp));

            // Expected GraphViz rendering: one node per state and one edge per character. Every
            // line, including the closing brace, ends with Environment.NewLine.
            string[] expectedLines =
            {
                @"digraph finite_state_machine {",
                @"  rankdir=LR;",
                @"  node [shape = doublecircle; label = ""0\nE=0""]; N0",
                @"  node [shape = circle; label = ""1\nE=0""]; N1",
                @"  node [shape = circle; label = ""2\nE=0""]; N2",
                @"  node [shape = circle; label = ""3\nE=1""]; N3",
                @"  N0 -> N1 [ label = ""W=1\na"" ];",
                @"  N1 -> N2 [ label = ""W=1\nb"" ];",
                @"  N2 -> N3 [ label = ""W=1\n\"""" ];",
                @"}",
            };
            string expectedGraphViz = string.Join(Environment.NewLine, expectedLines) + Environment.NewLine;

            Assert.Equal(expectedGraphViz, point.ToString(SequenceDistributionFormats.GraphViz));
        }
コード例 #19
0
        public void SparseDiscreteNormalise()
        {
            const double Tolerance = 1e-12;
            double commonValue = 1.0;
            double nonCommonValue = 10.0;

            // Unnormalized length-100 vector: commonValue everywhere except indices 20 and 55.
            // Build it from commonValue rather than a duplicated literal, so the expected
            // value computed below stays in sync with the input.
            Discrete d = new Discrete(
                SparseVector.FromSparseValues(100, commonValue, new List<SparseElement>
                {
                    new SparseElement(20, nonCommonValue),
                    new SparseElement(55, nonCommonValue)
                }));

            Vector v = d.GetProbs();

            Assert.Equal(Sparsity.Sparse, d.Sparsity);

            // Normalization divides every entry by the total unnormalized mass.
            double sum = (v.Count - 2) * commonValue + 2 * nonCommonValue;
            SparseVector sv = (SparseVector)v;

            // Expected value first; compare with a tolerance rather than exact floating-point equality.
            Assert.Equal(commonValue / sum, sv.CommonValue, Tolerance);
            Assert.Equal(2, sv.SparseValues.Count);
        }
コード例 #20
0
        public void AutomatonNormalizationPerformance2()
        {
            Assert.Timeout(() =>
            {
                // A chain of three literal segments. After each segment the automaton may stop
                // (end weight one) or repeat a single character through a self-loop.
                string[] segments = { "abc", "efg", "grlkhgn;lk3rng" };
                char[] loopChars = { 'd', 'h', 'h' };
                double[] loopWeights = { 0.1, 0.2, 0.3 };

                var builder = new StringAutomaton.Builder();
                var tail = builder.Start;
                for (int segment = 0; segment < segments.Length; ++segment)
                {
                    tail = tail.AddTransitionsForSequence(segments[segment]);
                    tail.SetEndWeight(Weight.One);
                    tail.AddSelfTransition(loopChars[segment], Weight.FromValue(loopWeights[segment]));
                }

                var automaton = builder.GetAutomaton();

                ProfileAction(() => automaton.GetLogNormalizer(), 100000);
            }, 20000);
        }
コード例 #21
0
        public void AutomatonNormalizationPerformance3()
        {
            Assert.Timeout(() =>
            {
                // Base automaton: accept at the start or after "aa", with an 'a' self-loop of
                // weight 0.5 on both of those states.
                var builder = new StringAutomaton.Builder();
                var start = builder.Start;
                start.AddSelfTransition('a', Weight.FromValue(0.5));
                start.SetEndWeight(Weight.One);

                var afterTwoAs = start.AddTransitionsForSequence("aa");
                afterTwoAs.AddSelfTransition('a', Weight.FromValue(0.5));
                afterTwoAs.SetEndWeight(Weight.One);

                // Squaring three times yields the product of eight copies of the base automaton.
                var automaton = builder.GetAutomaton();
                for (int round = 0; round < 3; ++round)
                {
                    automaton = automaton.Product(automaton);
                }

                ProfileAction(() => automaton.GetLogNormalizer(), 100);
            }, 120000);
        }
コード例 #22
0
 /// <summary>
 /// Checks the message to <c>format</c> computed by each of the four StringFormatOp variants
 /// (with and without argument names, requiring or allowing missing placeholders).
 /// </summary>
 /// <param name="str">The message from <c>str</c>.</param>
 /// <param name="args">The message from <c>args</c>.</param>
 /// <param name="expectedFormatRequireEveryPlaceholder">
 /// Expected message to <c>format</c> when the format string must contain a placeholder for every argument.
 /// </param>
 /// <param name="expectedFormatAllowMissingPlaceholders">
 /// Expected message to <c>format</c> when the format string may omit placeholders for some arguments.
 /// </param>
 private static void TestMessageToFormat(
     StringDistribution str,
     StringDistribution[] args,
     StringDistribution expectedFormatRequireEveryPlaceholder,
     StringDistribution expectedFormatAllowMissingPlaceholders)
 {
     string[] names = GetDefaultArgumentNames(args.Length);

     // Variants that require every argument to appear in the format string.
     var requireWithoutNames = StringFormatOp_RequireEveryPlaceholder_NoArgumentNames.FormatAverageConditional(str, args);
     Assert.Equal(expectedFormatRequireEveryPlaceholder, requireWithoutNames);

     var requireWithNames = StringFormatOp_RequireEveryPlaceholder.FormatAverageConditional(str, args, names);
     Assert.Equal(expectedFormatRequireEveryPlaceholder, requireWithNames);

     // Variants that tolerate arguments with no matching placeholder.
     var allowWithoutNames = StringFormatOp_AllowMissingPlaceholders_NoArgumentNames.FormatAverageConditional(str, args);
     Assert.Equal(expectedFormatAllowMissingPlaceholders, allowWithoutNames);

     var allowWithNames = StringFormatOp_AllowMissingPlaceholders.FormatAverageConditional(str, args, names);
     Assert.Equal(expectedFormatAllowMissingPlaceholders, allowWithNames);
 }
コード例 #23
0
        public static void TestAutomatonPropertyPreservation(StringAutomaton automaton, Func<StringAutomaton, StringAutomaton> testedOperation)
        {
            // Case 1: both properties cleared -> the operation's output must leave them cleared.
            var cleared = automaton.WithLogValueOverride(null).WithPruneStatesWithLogEndWeightLessThan(null);
            var clearedResult = testedOperation(cleared);

            Assert.Null(clearedResult.LogValueOverride);
            Assert.Null(clearedResult.PruneStatesWithLogEndWeightLessThan);

            // Case 2: both properties set -> the operation's output must carry the same values.
            var configured = automaton.WithLogValueOverride(-1).WithPruneStatesWithLogEndWeightLessThan(-128);
            var configuredResult = testedOperation(configured);

            Assert.Equal(configured.LogValueOverride, configuredResult.LogValueOverride);
            Assert.Equal(configured.PruneStatesWithLogEndWeightLessThan, configuredResult.PruneStatesWithLogEndWeightLessThan);
        }
コード例 #24
0
        public void ComputeNormalizerWithManyNonTrivialLoops1()
        {
            // Per-state outgoing weights below assume AddEpsilonLoop's last argument is the weight
            // leaving the state into the loop. AssertStochastic checks that every state sums to one.
            const double Tolerance = 1e-6;
            StringAutomaton automaton = StringAutomaton.Zero();
            var start = automaton.Start;

            // start: 0.2 + 0.3 (epsilon loops) + 0.1 (end) + 0.4 ('a') = 1.
            AddEpsilonLoop(start, 3, 0.2);
            AddEpsilonLoop(start, 5, 0.3);
            start.EndWeight = Weight.FromValue(0.1);
            var afterA = start.AddTransition('a', Weight.FromValue(0.4));

            // afterA: 0.6 (end) + 0.3 (epsilon loop) + 0.1 ('b') = 1.
            afterA.EndWeight = Weight.FromValue(0.6);
            AddEpsilonLoop(afterA, 0, 0.3);
            var afterB = afterA.AddTransition('b', Weight.FromValue(0.1));

            // afterB: 0.9 (epsilon loop) + 0.1 (end) = 1.
            AddEpsilonLoop(afterB, 1, 0.9);
            afterB.EndWeight = Weight.FromValue(0.1);

            AssertStochastic(automaton);
            Assert.Equal(0.0, automaton.GetLogNormalizer(), Tolerance);
            Assert.Equal(0.0, GetLogNormalizerByGetValue(automaton), Tolerance);
            Assert.Equal(0.0, GetLogNormalizerByGetValueWithTransducers(automaton), Tolerance);
        }
コード例 #25
0
        public void ComputeNormalizerSimple3()
        {
            const double Tolerance = 1e-6;
            StringAutomaton automaton = StringAutomaton.Zero();
            var start = automaton.Start;

            // start: 0.7 ('a' loop) + 0.1 (end) + 0.15 ('b') + 0.05 ('c') = 1.
            start.AddSelfTransition('a', Weight.FromValue(0.7));
            start.EndWeight = Weight.FromValue(0.1);

            // after 'b': 0.4 ('a' loop) + 0.6 (end) = 1.
            var afterB = start.AddTransition('b', Weight.FromValue(0.15));
            afterB.AddSelfTransition('a', Weight.FromValue(0.4));
            afterB.EndWeight = Weight.FromValue(0.6);

            // after 'c': accept with weight one.
            var afterC = start.AddTransition('c', Weight.FromValue(0.05));
            afterC.EndWeight = Weight.One;

            AssertStochastic(automaton);
            Assert.Equal(0.0, automaton.GetLogNormalizer(), Tolerance);
            Assert.Equal(0.0, GetLogNormalizerByGetValue(automaton), Tolerance);
            Assert.Equal(0.0, GetLogNormalizerByGetValueWithTransducers(automaton), Tolerance);
        }
コード例 #26
0
        public void ProductWithGroups()
        {
            StringDistribution lhsWithoutGroup = StringDistribution.String("ab");

            // Load the automaton of "ab" into a builder and tag the first transition out of the
            // start state with group 1 (read the transition, modify it, write it back).
            var builder = StringAutomaton.Builder.FromAutomaton(lhsWithoutGroup.GetWorkspaceOrPoint());
            var iterator = builder.Start.TransitionIterator;
            var tagged = iterator.Value;
            tagged.Group = 1;
            iterator.Value = tagged;

            StringDistribution lhs = StringDistribution.FromWeightFunction(builder.GetAutomaton());
            StringDistribution rhs = StringDistribution.OneOf("ab", "ac");

            // Only the left operand uses groups; the product must keep group 1.
            Assert.True(lhs.GetWorkspaceOrPoint().HasGroup(1));
            Assert.False(rhs.GetWorkspaceOrPoint().UsesGroups);

            var product = StringDistribution.Zero();
            product.SetToProduct(lhs, rhs);
            Assert.True(product.GetWorkspaceOrPoint().HasGroup(1));
        }
コード例 #27
0
        public void Capitalized()
        {
            // Sizes of the supports of the lowercase and uppercase character distributions.
            int lowerCount = DiscreteChar.Lower().GetProbs().Count(p => p > 0);
            int upperCount = DiscreteChar.Upper().GetProbs().Count(p => p > 0);

            // Bounded length (3..5): proper. Each allowed string's probability is the uniform
            // probability of its 2..4 lowercase tail divided by the number of uppercase letters.
            var bounded = StringDistribution.Capitalized(minLength: 3, maxLength: 5);
            var expectedBoundedProbability =
                StringInferenceTestUtilities.StringUniformProbability(2, 4, lowerCount) / upperCount;

            Assert.True(bounded.IsProper());
            StringInferenceTestUtilities.TestProbability(bounded, expectedBoundedProbability, "Abc", "Bcde", "Abcde");
            StringInferenceTestUtilities.TestProbability(bounded, 0.0, "A", "abc", "Ab", "Abcdef", string.Empty);

            // Only a minimum length: improper, with unit value on every allowed string.
            var unbounded = StringDistribution.Capitalized(minLength: 3);

            Assert.False(unbounded.IsProper());
            StringInferenceTestUtilities.TestProbability(unbounded, 1.0, "Abc", "Bcde", "Abcde", "Abfjrhfjlrl");
            StringInferenceTestUtilities.TestProbability(unbounded, 0.0, "A", "abc", "Ab", string.Empty);
        }
コード例 #28
0
        /// <summary>
        /// Checks the message to <c>args</c> computed by each of the four StringFormatOp variants
        /// (with and without argument names, requiring or allowing missing placeholders).
        /// </summary>
        /// <param name="str">The message from <c>str</c>.</param>
        /// <param name="format">The message from <c>format</c>.</param>
        /// <param name="args">The message from <c>args</c>.</param>
        /// <param name="expectedArgsRequireEveryPlaceholder">
        /// Expected message to <c>args</c> when the format string must contain a placeholder for every argument.
        /// </param>
        /// <param name="expectedArgsAllowMissingPlaceholders">
        /// Expected message to <c>args</c> when the format string may omit placeholders for some arguments.
        /// </param>
        private static void TestMessageToArgs(
            StringDistribution str,
            StringDistribution format,
            StringDistribution[] args,
            StringDistribution[] expectedArgsRequireEveryPlaceholder,
            StringDistribution[] expectedArgsAllowMissingPlaceholders)
        {
            string[] names = GetDefaultArgumentNames(args.Length);

            // One output buffer is passed as the result argument of all four calls.
            var buffer = new StringDistribution[args.Length];

            // Variants that require every argument to appear in the format string.
            var requireWithoutNames = StringFormatOp_RequireEveryPlaceholder_NoArgumentNames.ArgsAverageConditional(str, format, args, buffer);
            Assert.Equal(expectedArgsRequireEveryPlaceholder, requireWithoutNames);

            var requireWithNames = StringFormatOp_RequireEveryPlaceholder.ArgsAverageConditional(str, format, args, names, buffer);
            Assert.Equal(expectedArgsRequireEveryPlaceholder, requireWithNames);

            // Variants that tolerate arguments with no matching placeholder.
            var allowWithoutNames = StringFormatOp_AllowMissingPlaceholders_NoArgumentNames.ArgsAverageConditional(str, format, args, buffer);
            Assert.Equal(expectedArgsAllowMissingPlaceholders, allowWithoutNames);

            var allowWithNames = StringFormatOp_AllowMissingPlaceholders.ArgsAverageConditional(str, format, args, names, buffer);
            Assert.Equal(expectedArgsAllowMissingPlaceholders, allowWithNames);
        }
コード例 #29
0
        public void LoopyEpsilonClosure1()
        {
            StringAutomaton automaton = StringAutomaton.Zero();
            var start = automaton.Start;

            // Epsilon structure: start -0.5-> start, and start -0.4-> middle -1-> last -1-> start.
            // The start state also has end weight 0.1.
            start.AddEpsilonTransition(Weight.FromValue(0.5), start);
            var middle = start.AddEpsilonTransition(Weight.FromValue(0.4));
            var last = middle.AddEpsilonTransition(Weight.One);
            last.AddEpsilonTransition(Weight.One, start);
            start.EndWeight = Weight.FromValue(0.1);

            AssertStochastic(automaton);

            // Each return to start carries weight 0.5 + 0.4 = 0.9, so start accumulates
            // 1 / (1 - 0.9) = 10. The other two states are reached with 0.4 * 10 = 4, and the
            // closure's end weight is 10 * 0.1 = 1.
            StringAutomaton.EpsilonClosure closure = start.GetEpsilonClosure();
            Assert.Equal(3, closure.Size);
            Assert.Equal(0.0, closure.EndWeight.LogValue, 1e-8);

            for (int index = 0; index < closure.Size; ++index)
            {
                Weight actual = closure.GetStateWeightByIndex(index);
                double expected = closure.GetStateByIndex(index) == start ? 10 : 4;
                Assert.Equal(expected, actual.Value, 1e-8);
            }
        }
コード例 #30
0
        public void ImpossibleBranchTest4()
        {
            Variable<string> first = Variable.StringLower().Named("str1");
            Variable<string> second = Variable.StringLower().Named("str2");
            Variable<string> third = Variable.StringLower().Named("str3");
            Variable<string> text = Variable.New<string>().Named("text");

            // The selector chooses one of three ways of assembling text from the words.
            Variable<int> selector = Variable.DiscreteUniform(3).Named("selector");

            using (Variable.Case(selector, 0))
            {
                // "first second third": two spaces.
                text.SetTo(first + Variable.Constant(" ") + second + Variable.Constant(" ") + third);
            }

            using (Variable.Case(selector, 1))
            {
                // "first third": one space.
                text.SetTo(first + Variable.Constant(" ") + third);
            }

            using (Variable.Case(selector, 2))
            {
                // "first": no space.
                text.SetTo(first);
            }

            // "abc def" has exactly one space. Presumably (lowercase words cannot contain spaces)
            // only branch 1 can produce it.
            text.ObservedValue = "abc def";

            var engine = new InferenceEngine();
            var selectorPosterior = engine.Infer<Discrete>(selector);
            var firstPosterior = engine.Infer<StringDistribution>(first);
            var secondPosterior = engine.Infer<StringDistribution>(second);
            var thirdPosterior = engine.Infer<StringDistribution>(third);

            Assert.True(selectorPosterior.IsPointMass && selectorPosterior.Point == 1);
            Assert.True(firstPosterior.IsPointMass && firstPosterior.Point == "abc");
            // The unused word keeps its prior.
            Assert.Equal(StringDistribution.Lower(), secondPosterior);
            Assert.True(thirdPosterior.IsPointMass && thirdPosterior.Point == "def");
        }