/// <summary>EP message to <c>format</c>.</summary>
/// <param name="str">Incoming message from <c>str</c>.</param>
/// <param name="args">Incoming message from <c>args</c>.</param>
/// <returns>The outgoing EP message to the <c>format</c> argument.</returns>
/// <remarks>
/// <para>The outgoing message is a distribution matching the moments of <c>format</c> as the random arguments are varied. The formula is <c>proj[p(format) sum_(str,args) p(str,args) factor(str,format,args)]/p(format)</c>.</para>
/// <para>At least one and at most 9 arguments are currently supported.</para>
/// </remarks>
/// <exception cref="NotImplementedException">Thrown when <paramref name="args"/> has 10 or more elements.</exception>
public static StringDistribution FormatAverageConditional(StringDistribution str, IList<StringDistribution> args)
{
    Argument.CheckIfNotNull(str, "str");
    Argument.CheckIfNotNull(args, "args");
    Argument.CheckIfValid(args.Count > 0, "args", "There must be at least one argument provided.");

    // TODO: relax?
    // The check rejects 10 or more arguments, so the message must say at most 9 are supported.
    if (args.Count >= 10)
    {
        throw new NotImplementedException("Up to 9 arguments currently supported.");
    }

    // Disallow special characters in args
    List<StringAutomaton> escapedArgs = args.Select(DisallowBraceReplacersTransducer.ProjectSource).ToList();

    // Reverse the process defined by StrAverageConditional:
    // first undo placeholder replacement, then undo the per-argument escaping stages.
    StringAutomaton result = GetPlaceholderReplacingTransducer(escapedArgs, true).ProjectSource(str);
    for (int i = 0; i < args.Count; ++i)
    {
        result = GetArgumentEscapingTransducer(i, args.Count, true).ProjectSource(result);
    }

    result = DisallowBraceReplacersTransducer.ProjectSource(result);
    return StringDistribution.FromWorkspace(result);
}
/// <summary>EP message to <c>str</c>.</summary>
/// <param name="format">Incoming message from <c>format</c>.</param>
/// <param name="args">Incoming message from <c>args</c>.</param>
/// <returns>The outgoing EP message to the <c>str</c> argument.</returns>
/// <remarks>
/// <para>The outgoing message is a distribution matching the moments of <c>str</c> as the random arguments are varied. The formula is <c>proj[p(str) sum_(format,args) p(format,args) factor(str,format,args)]/p(str)</c>.</para>
/// <para>At least one and at most 9 arguments are currently supported.</para>
/// </remarks>
/// <exception cref="NotImplementedException">Thrown when <paramref name="args"/> has 10 or more elements.</exception>
public static StringDistribution StrAverageConditional(StringDistribution format, IList<StringDistribution> args)
{
    Argument.CheckIfNotNull(format, "format");
    Argument.CheckIfNotNull(args, "args");
    Argument.CheckIfValid(args.Count > 0, "args", "There must be at least one argument provided.");

    // TODO: relax?
    // The check rejects 10 or more arguments, so the message must say at most 9 are supported.
    if (args.Count >= 10)
    {
        throw new NotImplementedException("Up to 9 arguments currently supported.");
    }

    // Disallow special characters in args or format
    StringAutomaton result = DisallowBraceReplacersTransducer.ProjectSource(format);
    List<StringAutomaton> escapedArgs = args.Select(DisallowBraceReplacersTransducer.ProjectSource).ToList();

    // Check braces for correctness and replace them with special characters.
    // Also, make sure that each argument placeholder is present exactly once.
    // Superposition of transducers is used instead of a single transducer to allow for any order of arguments.
    // TODO: in case of a single argument, argument escaping stage can be skipped
    for (int i = 0; i < args.Count; ++i)
    {
        result = GetArgumentEscapingTransducer(i, args.Count, false).ProjectSource(result);
    }

    // Now replace placeholders with arguments
    result = GetPlaceholderReplacingTransducer(escapedArgs, false).ProjectSource(result);
    return StringDistribution.FromWorkspace(result);
}
/// <summary>EP message to <c>sub</c>.</summary>
/// <param name="str">Incoming message from <c>str</c>.</param>
/// <param name="start">Constant value for <c>start</c>: the number of leading characters of <c>str</c> to skip.</param>
/// <param name="minLength">Constant minimum length of <c>sub</c>.</param>
/// <param name="maxLength">Constant maximum length of <c>sub</c>.</param>
/// <returns>The outgoing EP message to the <c>sub</c> argument.</returns>
/// <remarks>
/// <para>
/// A substring of length <c>L</c> starting at <paramref name="start"/> is only possible when
/// <c>start + L</c> does not exceed the length of the string. Both the point-mass branch and the
/// general transducer-based branch follow this rule. If no length in
/// [<paramref name="minLength"/>, <paramref name="maxLength"/>] fits, the point-mass branch returns
/// a zero distribution instead of throwing.
/// </para>
/// </remarks>
public static StringDistribution SubAverageConditional(StringDistribution str, int start, int minLength, int maxLength)
{
    Argument.CheckIfNotNull(str, "str");
    Argument.CheckIfInRange(start >= 0, "start", "Start index must be non-negative.");
    Argument.CheckIfInRange(minLength >= 0, "minLength", "Min length must be non-negative.");
    Argument.CheckIfInRange(maxLength >= 0, "maxLength", "Max length must be non-negative.");

    if (str.IsPointMass)
    {
        string strPoint = str.Point;

        // The general branch below skips exactly 'start' characters and then copies between
        // 'minLength' and 'maxLength' characters. Lengths that run past the end of the string are
        // therefore impossible rather than truncated. Both operands are non-negative, so the
        // subtraction cannot overflow.
        int longestPossible = Math.Min(maxLength, strPoint.Length - start);
        if (minLength > longestPossible)
        {
            // Covers 'start' beyond the end of the string, too-long 'minLength', and minLength > maxLength.
            return StringDistribution.Zero();
        }

        var alts = new HashSet<string>();
        for (int length = minLength; length <= longestPossible; length++)
        {
            alts.Add(strPoint.Substring(start, length));
        }

        return StringDistribution.OneOf(alts);
    }

    var anyChar = StringAutomaton.ConstantOnElement(1.0, DiscreteChar.Any());
    var transducer = StringTransducer.Consume(StringAutomaton.Repeat(anyChar, minTimes: start, maxTimes: start));
    transducer.AppendInPlace(StringTransducer.Copy(StringAutomaton.Repeat(anyChar, minTimes: minLength, maxTimes: maxLength)));
    transducer.AppendInPlace(StringTransducer.Consume(StringAutomaton.Constant(1.0)));
    return StringDistribution.FromWorkspace(transducer.ProjectSource(str.GetWorkspaceOrPoint()));
}
/// <summary>EP message to <c>str2</c>.</summary>
/// <param name="concat">Incoming message from <c>concat</c>.</param>
/// <param name="str1">Incoming message from <c>str1</c>.</param>
/// <returns>The outgoing EP message to the <c>str2</c> argument.</returns>
/// <remarks>
/// <para>The outgoing message is a distribution matching the moments of <c>str2</c> as the random arguments are varied. The formula is <c>proj[p(str2) sum_(concat,str1) p(concat,str1) factor(concat,str1,str2)]/p(str2)</c>.</para>
/// </remarks>
public static StringDistribution Str2AverageConditional(StringDistribution concat, StringDistribution str1)
{
    Argument.CheckIfNotNull(concat, "concat");
    Argument.CheckIfNotNull(str1, "str1");

    // str2 is what is left of concat once a prefix weighted by str1 has been stripped off:
    // consume such a prefix, then pass the remaining suffix through unchanged.
    StringTransducer stripPrefix = StringTransducer.Consume(str1.GetProbabilityFunction());
    stripPrefix.AppendInPlace(StringTransducer.Copy());

    var suffixes = stripPrefix.ProjectSource(concat);
    return StringDistribution.FromWorkspace(suffixes);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="StringConcatOp"]/message_doc[@name="Str1AverageConditional(StringDistribution, StringDistribution)"]/*'/>
public static StringDistribution Str1AverageConditional(StringDistribution concat, StringDistribution str2)
{
    Argument.CheckIfNotNull(concat, "concat");
    Argument.CheckIfNotNull(str2, "str2");

    // str1 is the part of concat that precedes a suffix weighted by str2:
    // copy a prefix through unchanged, then consume a str2-weighted suffix.
    var keepPrefixDropSuffix = StringTransducer.Copy();
    keepPrefixDropSuffix.AppendInPlace(StringTransducer.Consume(str2.GetWorkspaceOrPoint()));

    var prefixes = keepPrefixDropSuffix.ProjectSource(concat.GetWorkspaceOrPoint());
    return StringDistribution.FromWorkspace(prefixes);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="StringOfLengthOp"]/message_doc[@name="StrAverageConditional(DiscreteChar, Discrete)"]/*'/>
public static StringDistribution StrAverageConditional(DiscreteChar allowedChars, Discrete length)
{
    Argument.CheckIfNotNull(length, "length");
    Argument.CheckIfValid(allowedChars.IsPartialUniform(), "allowedChars", "The set of allowed characters must be passed as a partial uniform distribution.");

    // Each allowed character gets the log-weight given by the log-average of allowedChars with itself.
    // The resulting single-character automaton is repeated according to the length distribution.
    var perCharacterLogWeight = allowedChars.GetLogAverageOf(allowedChars);
    var singleCharacter = StringAutomaton.ConstantOnElementLog(perCharacterLogWeight, allowedChars);
    return StringDistribution.FromWorkspace(StringAutomaton.Repeat(singleCharacter, length.GetWorkspace()));
}
/// <summary>EP message to <c>str</c>.</summary>
/// <param name="sub">Incoming message from <c>sub</c>.</param>
/// <param name="start">Constant value for <c>start</c>.</param>
/// <param name="length">Constant value for <c>length</c>.</param>
/// <returns>The outgoing EP message to the <c>str</c> argument.</returns>
/// <remarks>
/// <para>The outgoing message is a distribution matching the moments of <c>str</c> as the random arguments are varied. The formula is <c>proj[p(str) sum_(sub) p(sub) factor(sub,str,start,length)]/p(str)</c>.</para>
/// </remarks>
public static StringDistribution StrAverageConditional(StringDistribution sub, int start, int length)
{
    Argument.CheckIfNotNull(sub, "sub");
    Argument.CheckIfInRange(start >= 0, "start", "Start index must be non-negative.");
    Argument.CheckIfInRange(length >= 0, "length", "Length must be non-negative.");

    // Embed sub into a longer string: exactly 'start' arbitrary characters before it,
    // exactly 'length' copied characters, and any arbitrary suffix after it.
    var anyCharacter = StringAutomaton.ConstantOnElement(1.0, DiscreteChar.Any());
    var leadingPart = StringAutomaton.Repeat(anyCharacter, minTimes: start, maxTimes: start);
    var copiedPart = StringAutomaton.Repeat(anyCharacter, minTimes: length, maxTimes: length);

    var embedder = StringTransducer.Produce(leadingPart);
    embedder.AppendInPlace(StringTransducer.Copy(copiedPart));
    embedder.AppendInPlace(StringTransducer.Produce(StringAutomaton.Constant(1.0)));

    return StringDistribution.FromWorkspace(embedder.ProjectSource(sub));
}
/// <summary>EP message to <c>format</c>.</summary>
/// <param name="str">Incoming message from <c>str</c>.</param>
/// <param name="args">Incoming message from <c>args</c>.</param>
/// <param name="argNames">The names of the arguments, as used in placeholders.</param>
/// <returns>The outgoing EP message to the <c>format</c> argument.</returns>
/// <remarks>
/// An optimized implementation is tried first. Otherwise placeholder substitution is inverted
/// with a transducer, and only well-formed format strings are kept.
/// </remarks>
public static StringDistribution FormatAverageConditional(StringDistribution str, IReadOnlyList<StringDistribution> args, IReadOnlyList<string> argNames)
{
    Argument.CheckIfNotNull(str, "str");
    ValidateArguments(args, argNames);

    var allowedArgs = args.Select(arg => arg.GetWorkspaceOrPoint()).ToList();

    // Some special cases have a cheaper dedicated implementation.
    StringDistribution optimizedResult;
    if (TryOptimizedFormatAverageConditionalImpl(str, allowedArgs, argNames, out optimizedResult))
    {
        return optimizedResult;
    }

    // General case: reverse the process defined by StrAverageConditional.
    var placeholderReplacer = GetPlaceholderReplacingTransducer(allowedArgs, argNames, true, false);
    StringAutomaton candidateFormats;
    if (str.IsPointMass)
    {
        candidateFormats = placeholderReplacer.ProjectSource(str.Point);
    }
    else
    {
        candidateFormats = placeholderReplacer.ProjectSource(str.GetWorkspaceOrPoint());
    }

    return StringDistribution.FromWorkspace(GetValidatedFormatString(candidateFormats, argNames));
}
/// <summary>
/// The implementation of <see cref="StrAverageConditional(StringDistribution, IList{StringDistribution}, IReadOnlyList{string})"/>.
/// </summary>
/// <param name="format">The message from <c>format</c>.</param>
/// <param name="allowedArgs">The message from <c>args</c>, truncated to allowed values and converted to automata.</param>
/// <param name="argNames">The names of the arguments.</param>
/// <param name="withGroups">Whether the result should mark different arguments with groups.</param>
/// <param name="noValidation">Whether incorrect format string values should not be pruned.</param>
/// <returns>The message to <c>str</c>.</returns>
private static StringDistribution StrAverageConditionalImpl(
    StringDistribution format, IList<StringAutomaton> allowedArgs, IReadOnlyList<string> argNames, bool withGroups, bool noValidation)
{
    // Use the specialized implementation when one applies (it returns null otherwise).
    StringDistribution fastResult = TryOptimizedStrAverageConditionalImpl(format, allowedArgs, argNames, withGroups);
    if (fastResult != null)
    {
        return fastResult;
    }

    // Unless validation is disabled, prune format strings whose braces are incorrect.
    StringAutomaton checkedFormat = noValidation
        ? format.GetWorkspaceOrPoint()
        : GetValidatedFormatString(format.GetWorkspaceOrPoint(), argNames);

    // Substitute arguments for placeholders.
    var placeholderReplacer = GetPlaceholderReplacingTransducer(allowedArgs, argNames, false, withGroups);
    return StringDistribution.FromWorkspace(placeholderReplacer.ProjectSource(checkedFormat));
}
/// <summary>
/// An implementation of <see cref="StrAverageConditional(StringDistribution, IList{StringDistribution}, IReadOnlyList{string})"/>
/// specialized for some cases for performance reasons.
/// </summary>
/// <param name="format">The message from <c>format</c>.</param>
/// <param name="allowedArgs">The message from <c>args</c>, truncated to allowed values and converted to automata.</param>
/// <param name="argNames">The names of the arguments.</param>
/// <param name="withGroups">Whether the result should mark different arguments with groups.</param>
/// <returns>
/// Result distribution if there is an optimized implementation available for the provided parameters.
/// <see langword="null"/> otherwise.
/// </returns>
/// <remarks>
/// Supports the case of point mass <paramref name="format"/>. Every brace in the format must belong to a
/// non-nested <c>{name}</c> placeholder naming a known argument, and each argument may appear at most once.
/// Otherwise the format is invalid and a zero distribution is returned.
/// </remarks>
private static StringDistribution TryOptimizedStrAverageConditionalImpl(
    StringDistribution format, IList<StringAutomaton> allowedArgs, IReadOnlyList<string> argNames, bool withGroups)
{
    if (!format.IsPointMass)
    {
        // Fall back to the general case
        return null;
    }

    string formatPoint = format.Point;

    // Check braces for correctness & replace placeholders with arguments simultaneously
    var result = StringAutomaton.Builder.ConstantOn(Weight.One, string.Empty);
    bool[] argumentSeen = new bool[allowedArgs.Count];
    int openingBraceIndex = formatPoint.IndexOf("{", StringComparison.Ordinal);
    int closingBraceIndex = -1;
    while (openingBraceIndex != -1)
    {
        // Literal text between the previous placeholder (or the start of the format) and this placeholder
        string literal = formatPoint.Substring(closingBraceIndex + 1, openingBraceIndex - closingBraceIndex - 1);

        // A closing brace in literal text has no matching opening brace, so the format is invalid.
        if (literal.IndexOf('}') != -1)
        {
            return StringDistribution.Zero();
        }

        result.Append(StringAutomaton.ConstantOn(1.0, literal));

        // Find next opening and closing braces
        closingBraceIndex = formatPoint.IndexOf("}", openingBraceIndex + 1, StringComparison.Ordinal);
        int nextOpeningBraceIndex = formatPoint.IndexOf("{", openingBraceIndex + 1, StringComparison.Ordinal);

        // Opening brace must be followed by a closing brace (no nesting)
        if (closingBraceIndex == -1 || (nextOpeningBraceIndex != -1 && nextOpeningBraceIndex < closingBraceIndex))
        {
            return StringDistribution.Zero();
        }

        string argumentName = formatPoint.Substring(openingBraceIndex + 1, closingBraceIndex - openingBraceIndex - 1);
        int argumentIndex = argNames.IndexOf(argumentName);

        // Unknown or previously seen argument found
        if (argumentIndex == -1 || argumentSeen[argumentIndex])
        {
            return StringDistribution.Zero();
        }

        // Replace the placeholder by the argument
        result.Append(allowedArgs[argumentIndex], withGroups ? argumentIndex + 1 : 0);

        // Mark the argument as 'seen'
        argumentSeen[argumentIndex] = true;

        openingBraceIndex = nextOpeningBraceIndex;
    }

    // There should be no closing braces after the last opening brace
    if (formatPoint.IndexOf('}', closingBraceIndex + 1) != -1)
    {
        return StringDistribution.Zero();
    }

    if (RequirePlaceholderForEveryArgument && argumentSeen.Any(seen => !seen))
    {
        // Some argument wasn't present although it was required
        return StringDistribution.Zero();
    }

    // Append the part of the format after the last placeholder
    result.Append(StringAutomaton.ConstantOn(1.0, formatPoint.Substring(closingBraceIndex + 1, formatPoint.Length - closingBraceIndex - 1)));

    return StringDistribution.FromWorkspace(result.GetAutomaton());
}
/// <summary>
/// An implementation of <see cref="FormatAverageConditional(StringDistribution, IList{StringDistribution}, IReadOnlyList{string})"/>
/// specialized for some cases for performance reasons.
/// </summary>
/// <param name="str">The message from <c>str</c>.</param>
/// <param name="allowedArgs">The message from <c>args</c>, truncated to allowed values and converted to automata.</param>
/// <param name="argNames">The names of the arguments.</param>
/// <param name="resultDist">The computed result.</param>
/// <returns>
/// <see langword="true"/> if there is an optimized implementation available for the provided parameters,
/// and <paramref name="resultDist"/> has been computed using it.
/// <see langword="false"/> otherwise.
/// </returns>
/// <remarks>
/// Supports the case of point mass <paramref name="str"/> and <paramref name="allowedArgs"/>,
/// where <paramref name="str"/> contains no braces, each of the arguments is present in <paramref name="str"/>
/// at most once, and the occurrences are non-overlapping. Strings containing braces are left to the general
/// path, which validates the reconstructed format strings. This path does not.
/// </remarks>
private static bool TryOptimizedFormatAverageConditionalImpl(
    StringDistribution str, IList<StringAutomaton> allowedArgs, IReadOnlyList<string> argNames, out StringDistribution resultDist)
{
    resultDist = null;

    string[] allowedArgPoints = Util.ArrayInit(allowedArgs.Count, i => allowedArgs[i].TryComputePoint());
    if (!str.IsPointMass ||
        str.Point.IndexOfAny(new[] { '{', '}' }) != -1 ||
        !allowedArgPoints.All(argPoint => argPoint != null && SubstringOccurrencesCount(str.Point, argPoint) <= 1))
    {
        // Fall back to the general case. Braces in 'str' would end up in the literal text of the
        // reconstructed format strings. Those strings need validation, which only the general case performs.
        return false;
    }

    // Obtain arguments present in 'str' (ordered by position)
    var argPositions = allowedArgPoints
        .Select((arg, argIndex) => Tuple.Create(argIndex, str.Point.IndexOf(arg, StringComparison.Ordinal)))
        .Where(t => t.Item2 != -1)
        .OrderBy(t => t.Item2)
        .ToList();

    if (RequirePlaceholderForEveryArgument && argPositions.Count != allowedArgs.Count)
    {
        // Some argument is not in 'str'
        resultDist = StringDistribution.Zero();
        return true;
    }

    StringAutomaton result = StringAutomaton.ConstantOn(1.0, string.Empty);

    // Sentinel "previous argument" at position -1 with length 1, so that it ends at position 0.
    int curArgumentIndex = -1;
    int curArgumentPos = -1;
    int curArgumentLength = 1;
    for (int i = 0; i < argPositions.Count; ++i)
    {
        int prevArgumentIndex = curArgumentIndex;
        int prevArgumentPos = curArgumentPos;
        int prevArgumentLength = curArgumentLength;
        curArgumentIndex = argPositions[i].Item1;
        curArgumentPos = argPositions[i].Item2;
        curArgumentLength = allowedArgPoints[curArgumentIndex].Length;

        if (prevArgumentIndex != -1 && curArgumentPos < prevArgumentPos + prevArgumentLength)
        {
            // It's easier to fall back to the general case in case of overlapping arguments
            return false;
        }

        // Append the contents of 'str' preceding the current argument
        result.AppendInPlace(str.Point.Substring(prevArgumentPos + prevArgumentLength, curArgumentPos - prevArgumentPos - prevArgumentLength));

        // The format may have included either the text or the placeholder
        string argName = "{" + argNames[curArgumentIndex] + "}";
        if (RequirePlaceholderForEveryArgument)
        {
            result.AppendInPlace(StringAutomaton.ConstantOn(1.0, argName));
        }
        else
        {
            result.AppendInPlace(StringAutomaton.ConstantOn(1.0, argName, allowedArgPoints[curArgumentIndex]));
        }
    }

    // Append the rest of 'str'
    result.AppendInPlace(str.Point.Substring(curArgumentPos + curArgumentLength, str.Point.Length - curArgumentPos - curArgumentLength));

    resultDist = StringDistribution.FromWorkspace(result);
    return true;
}
public void WordModel()
{
    // Goal: a reasonably simple StringDistribution acting as a word model, such that
    // (1) a word of moderate length is not significantly less probable than a shorter word, and
    // (2) the probability of a specific word conditioned on its length matches that of
    //     words in the target language.
    // This is achieved by putting non-normalized character distributions on the edges. The
    // StringDistribution is unaware that these are non-normalized, and the resulting
    // StringDistribution is itself non-normalizable.
    const double TargetProb1 = 0.05;
    const double Ratio1 = 0.4;
    const double TargetProb2 = TargetProb1 * Ratio1;
    const double Ratio2 = 0.2;
    const double TargetProb3 = TargetProb2 * Ratio2;
    const double TargetProb4 = TargetProb3 * Ratio2;
    const double TargetProb5 = TargetProb4 * Ratio2;
    const double Ratio3 = 0.999;
    const double TargetProb6 = TargetProb5 * Ratio3;
    const double TargetProb7 = TargetProb6 * Ratio3;
    const double TargetProb8 = TargetProb7 * Ratio3;
    const double Ratio4 = 0.9;
    const double TargetProb9 = TargetProb8 * Ratio4;
    const double TargetProb10 = TargetProb9 * Ratio4;
    double[] targetProbabilitiesPerLength =
    {
        TargetProb1, TargetProb2, TargetProb3, TargetProb4, TargetProb5,
        TargetProb6, TargetProb7, TargetProb8, TargetProb9, TargetProb10
    };

    const string Word = "Abcdefghij";
    const double Eps = 1e-5;

    // Creates a partial uniform over the support of 'support', scaled by 'weight'
    // (the weight is passed to SetToPartialUniformOf in log space).
    DiscreteChar ScaledOver(DiscreteChar support, double weight)
    {
        var scaled = DiscreteChar.Uniform();
        scaled.SetToPartialUniformOf(support, Math.Log(weight));
        return scaled;
    }

    // Asserts that 'model' gives every prefix of Word its target probability times 'factor'.
    void AssertPrefixProbabilities(StringDistribution model, double factor)
    {
        for (var n = 0; n < targetProbabilitiesPerLength.Length; n++)
        {
            var prefixProb = Math.Exp(model.GetLogProb(Word.Substring(0, n + 1)));
            Assert.Equal(factor * targetProbabilitiesPerLength[n], prefixProb, Eps);
        }
    }

    var upperChars = DiscreteChar.Upper();
    var lowerChars = DiscreteChar.Lower();
    var narrowUpperChars = DiscreteChar.OneOf('A', 'B');
    var narrowLowerChars = DiscreteChar.OneOf('a', 'b');

    var firstLetter = ScaledOver(upperChars, TargetProb1);
    var secondLetter = ScaledOver(lowerChars, Ratio1);
    var earlyLetter = ScaledOver(lowerChars, Ratio2);
    var middleLetter = ScaledOver(lowerChars, Ratio3);
    var trailingLetter = ScaledOver(lowerChars, Ratio4);

    var wordModel = StringDistribution.Concatenate(
        new List<DiscreteChar>
        {
            firstLetter,
            secondLetter,
            earlyLetter, earlyLetter, earlyLetter,
            middleLetter, middleLetter, middleLetter,
            trailingLetter
        },
        true,
        true);

    var broadPrefix = StringDistribution.Char(upperChars);
    var narrowPrefix = StringDistribution.Char(narrowUpperChars);
    var narrowWord = "A";
    var expectedNarrowProb = 0.5;
    for (var length = 1; length <= targetProbabilitiesPerLength.Length; length++)
    {
        var target = targetProbabilitiesPerLength[length - 1];
        var prefixWord = Word.Substring(0, length);

        Assert.Equal(target, Math.Exp(wordModel.GetLogProb(prefixWord)), Eps);

        // Exponentiated log-average, i.e. the average itself.
        var averageProb = Math.Exp(wordModel.GetLogAverageOf(broadPrefix));
        Assert.Equal(target, averageProb, Eps);

        var product = StringDistribution.Zero();
        product.SetToProduct(broadPrefix, wordModel);
        Xunit.Assert.True(product.GetWorkspaceOrPoint().HasElementLogValueOverrides);
        Assert.Equal(target, Math.Exp(product.GetLogProb(prefixWord)), Eps);

        product.SetToProduct(narrowPrefix, wordModel);
        Xunit.Assert.False(product.GetWorkspaceOrPoint().HasElementLogValueOverrides);
        Assert.Equal(expectedNarrowProb, Math.Exp(product.GetLogProb(narrowWord)), Eps);

        broadPrefix = broadPrefix.Append(lowerChars);
        narrowPrefix = narrowPrefix.Append(narrowLowerChars);
        narrowWord += "a";
        expectedNarrowProb *= 0.5;
    }

    // Projecting through a copy transducer must leave the per-prefix probabilities unchanged.
    var copiedModel = StringDistribution.FromWorkspace(StringTransducer.Copy().ProjectSource(wordModel.GetWorkspaceOrPoint()));
    AssertPrefixProbabilities(copiedModel, 1.0);

    // Re-weight the leading upper-case character; every prefix probability scales accordingly.
    var scale = 0.5;
    var rescaledFirstLetter = ScaledOver(upperChars, TargetProb1 * scale);
    var reWeightingTransducer = StringTransducer
        .Replace(StringDistribution.Char(upperChars).GetWorkspaceOrPoint(), StringDistribution.Char(rescaledFirstLetter).GetWorkspaceOrPoint())
        .Append(StringTransducer.Copy());
    var reWeightedWordModel = StringDistribution.FromWorkspace(reWeightingTransducer.ProjectSource(wordModel.GetWorkspaceOrPoint()));
    AssertPrefixProbabilities(reWeightedWordModel, scale);
}