/// <summary>
/// Collects the Range indices of an array-indexer expression as a list of brackets, one bracket per
/// indexing level, with the array's own brackets first and this element's bracket last.
/// Index expressions that are not Ranges are not placed in a bracket; instead their own brackets are
/// stored in <paramref name="dict"/> under that index expression.
/// </summary>
/// <param name="arg">The expression to inspect.</param>
/// <param name="dict">Receives the brackets of every non-Range index expression encountered.</param>
/// <returns>The brackets of <paramref name="arg"/>; an empty list if it is not an array element variable.</returns>
internal static List<List<Range>> GetRangeBrackets(IModelExpression arg, IDictionary<IModelExpression, List<List<Range>>> dict)
{
    if (!(arg is Variable variable) || !variable.IsArrayElement)
    {
        return new List<List<Range>>();
    }
    // The array's brackets must come before this element's bracket.
    List<List<Range>> result = GetRangeBrackets(variable.ArrayVariable, dict);
    List<Range> bracket = new List<Range>();
    foreach (IModelExpression index in variable.indices)
    {
        if (index is Range range)
        {
            bracket.Add(range);
        }
        else
        {
            dict[index] = GetRangeBrackets(index, dict);
        }
    }
    result.Add(bracket);
    return result;
}
/// <summary>
/// Invokes <paramref name="action"/> on every Range referenced by <paramref name="arg"/>.
/// A Range is visited directly. For a Variable, the loop range of a loop index is visited first.
/// For an array element, the ranges of its array are then visited recursively, followed by those of each index.
/// </summary>
/// <param name="arg">The expression to traverse.</param>
/// <param name="action">Callback applied to each Range found (duplicates are not filtered).</param>
internal static void ForEachRange(IModelExpression arg, Action<Range> action)
{
    switch (arg)
    {
        case Range range:
            action(range);
            break;
        case Variable variable:
            if (variable.IsLoopIndex)
            {
                action(variable.loopRange);
            }
            if (variable.IsArrayElement)
            {
                // the array's indices must be visited before this element's own indices
                ForEachRange(variable.ArrayVariable, action);
                foreach (IModelExpression index in variable.indices)
                {
                    ForEachRange(index, action);
                }
            }
            break;
    }
}
/// <summary>
/// Calls the generic FinishExpression&lt;T&gt; for <paramref name="expr"/>.
/// T is taken from the IModelExpression&lt;T&gt; interface that the expression's runtime type implements.
/// MethodInvoke expressions are ignored.
/// </summary>
/// <param name="expr">The model expression to finish.</param>
/// <param name="alg">The inference algorithm passed through to FinishExpression&lt;T&gt;.</param>
/// <exception cref="ArgumentException">The expression does not implement IModelExpression&lt;&gt;.</exception>
protected void FinishExpressionUntyped(IModelExpression expr, IAlgorithm alg)
{
    if (expr is MethodInvoke)
    {
        return;
    }
    // The first generic interface whose definition is IModelExpression<> supplies the domain type.
    Type typedFace = Array.Find(
        expr.GetType().GetInterfaces(),
        face => face.IsGenericType && face.GetGenericTypeDefinition() == typeof(IModelExpression<>));
    if (typedFace == null)
    {
        throw new ArgumentException("Expression: " + expr + " does not implement IModelExpression<>");
    }
    Type domainType = typedFace.GetGenericArguments()[0];
    MethodInfo openFinish = new Action<IModelExpression<object>, IAlgorithm>(this.FinishExpression<object>).Method.GetGenericMethodDefinition();
    MethodInfo closedFinish = openFinish.MakeGenericMethod(domainType);
    Util.Invoke(closedFinish, this, expr, alg);
}
/// <summary>
/// Marks a factor invocation as searched.
/// It then pushes the invocation's return value (if any) and arguments onto the search stack, and searches its containers.
/// Does nothing if the invocation was already searched.
/// </summary>
/// <param name="method">The factor invocation to search.</param>
/// <exception cref="InvalidOperationException">The return value was consumed by Variable.SetTo().</exception>
private void SearchMethodInvoke(MethodInvoke method)
{
    if (searched.Contains(method))
    {
        return;
    }
    searched.Add(method);
    IModelExpression target = method.returnValue;
    if (target != null)
    {
        // A target produced by the RemovedBySetTo marker method has been consumed by SetTo.
        bool consumedBySetTo =
            method.method.DeclaringType == target.GetType() &&
            method.method.Name == new Func<bool>(Variable<bool>.RemovedBySetTo).Method.Name;
        if (consumedBySetTo)
        {
            throw new InvalidOperationException("Variable '" + target + "' was consumed by variable.SetTo(). It can no longer be used or inferred. Perhaps you meant Variable.ConstrainEqual instead of SetTo.");
        }
        toSearch.Push(target);
    }
    foreach (IModelExpression argument in method.args)
    {
        toSearch.Push(argument);
    }
    SearchContainers(method.Containers);
}
/// <summary>
/// Adds a labelled edge from the node of <paramref name="from"/> to the node of <paramref name="to"/>.
/// Nodes are created via GetNode if necessary, and the target node is recorded in childNodes.
/// </summary>
/// <param name="g">The graph being written.</param>
/// <param name="from">Expression at the edge's source.</param>
/// <param name="to">Expression at the edge's target.</param>
/// <param name="name">Edge label.</param>
/// <returns>The new edge.</returns>
protected Edge AddEdge(GraphWriter g, IModelExpression from, IModelExpression to, string name)
{
    // Source must be resolved before target so node creation order is unchanged.
    Node fromNode = GetNode(g, from);
    Node toNode = GetNode(g, to);
    childNodes.Add(toNode);
    Edge edge = AddEdge(g, fromNode, toNode, name);
    return edge;
}
/// <summary>
/// Adds the graph edges of a factor node:
/// - one edge per argument, pointing away from the factor for out parameters and toward it otherwise;
/// - an unlabelled edge to the return value;
/// - when UseContainers is false, a "condition" edge from the condition variable of each enclosing ConditionBlock to the factor's target.
/// Blocks that also contain the target variable are skipped, and no condition edge is repeated for the same target.
/// </summary>
/// <param name="g">The graph being written.</param>
/// <param name="mi">The factor invocation.</param>
protected void AddFactorEdges(GraphWriter g, MethodInvoke mi)
{
    ParameterInfo[] parameters = mi.method.GetParameters();
    for (int i = 0; i < mi.args.Count; i++)
    {
        ParameterInfo parameter = parameters[i];
        var argument = mi.args[i];
        if (parameter.IsOut)
        {
            AddEdge(g, mi, argument, parameter.Name);
        }
        else
        {
            AddEdge(g, argument, mi, parameter.Name);
        }
    }
    if (mi.returnValue != null)
    {
        AddEdge(g, mi, mi.returnValue, "");
    }
    if (UseContainers)
    {
        return;
    }
    // Condition edges attach to the return value if present, otherwise to the factor itself.
    IModelExpression target = mi.returnValue ?? mi;
    // if target is in the ConditionBlock, then don't connect with the condition variable
    Set<IStatementBlock> excluded = new Set<IStatementBlock>();
    if (target is Variable targetVariable)
    {
        excluded.AddRange(targetVariable.Containers);
    }
    foreach (IStatementBlock block in mi.Containers)
    {
        if (excluded.Contains(block) || !(block is ConditionBlock conditionBlock))
        {
            continue;
        }
        Variable condition = conditionBlock.ConditionVariableUntyped;
        if (!conditionVariables.TryGetValue(target, out List<Variable> alreadyDrawn))
        {
            alreadyDrawn = new List<Variable>();
            conditionVariables[target] = alreadyDrawn;
        }
        if (!alreadyDrawn.Contains(condition))
        {
            AddEdge(g, condition, target, "condition");
            alreadyDrawn.Add(condition);
        }
    }
}
/// <summary>
/// Builds a typed model expression. Only Variable&lt;T&gt; expressions are supported.
/// </summary>
/// <typeparam name="T">Domain type of the variable expression</typeparam>
/// <param name="expr">The variable expression</param>
/// <exception cref="InferCompilerException"><paramref name="expr"/> is not a Variable&lt;T&gt;.</exception>
private void BuildExpression<T>(IModelExpression<T> expr)
{
    if (!(expr is Variable<T> variable))
    {
        throw new InferCompilerException("Unhandled model expression type: " + expr.GetType());
    }
    BuildVariable<T>(variable);
}
/// <summary>
/// Finishes a typed model expression. Only Variable&lt;T&gt; expressions are supported.
/// </summary>
/// <typeparam name="T">Domain type of the variable expression</typeparam>
/// <param name="expr">The variable expression</param>
/// <param name="alg">The inference algorithm passed to FinishVariable.</param>
/// <exception cref="InferCompilerException"><paramref name="expr"/> is not a Variable&lt;T&gt;.</exception>
private void FinishExpression<T>(IModelExpression<T> expr, IAlgorithm alg)
{
    switch (expr)
    {
        case Variable<T> variable:
            FinishVariable<T>(variable, alg);
            break;
        default:
            throw new InferCompilerException("Unhandled model expression type: " + expr.GetType());
    }
}
/// <summary>
/// Searches a typed model expression. Only Variable&lt;T&gt; expressions are supported.
/// </summary>
/// <typeparam name="T">Domain type of the variable expression</typeparam>
/// <param name="var">The variable expression</param>
/// <exception cref="InferCompilerException"><paramref name="var"/> is not a Variable&lt;T&gt;.</exception>
public void SearchExpression<T>(IModelExpression<T> var)
{
    if (var is Variable<T> variable)
    {
        SearchVariable<T>(variable);
        return;
    }
    throw new InferCompilerException("Unhandled model expression type: " + var.GetType());
}
/// <summary>
/// Applies <paramref name="replacements"/> to the size expression of a range.
/// If the size is unchanged the original range is returned. Otherwise a new range with the replaced size
/// is created, its Parent is set to <paramref name="r"/>, and it is recorded in <paramref name="replacements"/>.
/// </summary>
/// <param name="r">The range to rewrite.</param>
/// <param name="replacements">Expression substitutions; receives an entry for <paramref name="r"/> when a new range is made.</param>
/// <returns><paramref name="r"/> itself, or the new range.</returns>
private static Range ReplaceExpressions(Range r, Dictionary<IModelExpression, IModelExpression> replacements)
{
    var replacedSize = (IModelExpression<int>)ReplaceExpressions(r.Size, replacements);
    if (ReferenceEquals(replacedSize, r.Size))
    {
        return r;
    }
    var rewritten = new Range(replacedSize) { Parent = r };
    replacements.Add(r, rewritten);
    return rewritten;
}
/// <summary>
/// Determines whether <paramref name="arg"/> references every range in <paramref name="ranges"/>,
/// using the ranges visited by ForEachRange.
/// </summary>
/// <param name="arg">The expression to inspect.</param>
/// <param name="ranges">Ranges that must all appear in <paramref name="arg"/>.</param>
/// <returns>True if each range in <paramref name="ranges"/> is referenced by <paramref name="arg"/>.</returns>
internal static bool IsIndexedByAll(IModelExpression arg, ICollection<Range> ranges)
{
    Set<Range> referenced = new Set<Range>();
    ForEachRange(arg, r => referenced.Add(r));
    bool allPresent = true;
    foreach (Range required in ranges)
    {
        if (!referenced.Contains(required))
        {
            allPresent = false;
            break;
        }
    }
    return allPresent;
}
/// <summary>
/// Builds the model necessary to infer marginals for the supplied variables and algorithm.
/// </summary>
/// <param name="engine">The inference engine; its Algorithm is passed to each FinishExpressionUntyped call.</param>
/// <param name="inferOnlySpecifiedVars">If true, inference will be restricted to only the variables given.</param>
/// <param name="vars">Variables to infer. The sequence is enumerated exactly once.</param>
/// <returns>The generated model type declaration.</returns>
/// <exception cref="InvalidOperationException">A statement block is still open.</exception>
/// <remarks>
/// Algorithm: starting from the variables to infer, we search through the graph to build up a "searched set".
/// Each Variable and MethodInvoke in this set has an associated timestamp.
/// We sort by timestamp, and then generate code.
/// </remarks>
public ITypeDeclaration Build(InferenceEngine engine, bool inferOnlySpecifiedVars, IEnumerable<IVariable> vars)
{
    List<IStatementBlock> openBlocks = StatementBlock.GetOpenBlocks();
    if (openBlocks.Count > 0)
    {
        throw new InvalidOperationException("The block " + openBlocks[0] + " has not been closed.");
    }
    Reset();
    this.inferOnlySpecifiedVars = inferOnlySpecifiedVars;
    // Materialize once: vars is used twice below, and a lazy sequence would otherwise be re-evaluated.
    List<IVariable> varList = new List<IVariable>(vars);
    variablesToInfer.AddRange(varList);
    foreach (IVariable v in varList)
    {
        toSearch.Push(v);
    }
    while (toSearch.Count > 0)
    {
        IModelExpression expr = toSearch.Pop();
        SearchExpressionUntyped(expr);
    }
    // lock in the set of model expressions.
    ModelExpressions = new List<IModelExpression>(searched);
    List<int> timestamps = new List<int>();
    List<IModelExpression> exprs = new List<IModelExpression>();
    foreach (IModelExpression expr in ModelExpressions)
    {
        // Only Variables and MethodInvokes carry timestamps and generate code; other searched
        // expressions (presumably Ranges) are skipped.
        if (expr is Variable variable)
        {
            exprs.Add(variable);
            timestamps.Add(variable.timestamp);
        }
        else if (expr is MethodInvoke mi)
        {
            exprs.Add(mi);
            timestamps.Add(mi.timestamp);
        }
    }
    // Presumably sorts exprs in step with timestamps (creation order) — see Collection.Sort.
    Collection.Sort(timestamps, exprs);
    foreach (IModelExpression expr in exprs)
    {
        BuildExpressionUntyped(expr);
    }
    foreach (IModelExpression expr in exprs)
    {
        FinishExpressionUntyped(expr, engine.Algorithm);
    }
    return modelType;
}
/// <summary>
/// Throws an exception if an index expression is not valid for subscripting an array.
/// </summary>
/// <param name="index">Index expression</param>
/// <param name="array">Array that the expression is indexing</param>
/// <exception cref="ArgumentException"><paramref name="index"/> is not compatible with this range.</exception>
/// <exclude/>
internal void CheckCompatible(IModelExpression index, IVariableArray array)
{
    if (!IsCompatibleWith(index))
    {
        string message = StringUtil.TypeToString(array.GetType()) + " " + array + " cannot be indexed by " + index + ".";
        if (index is Range)
        {
            // Add a hint for the common mistake of an incompatible Range index.
            message += " Perhaps you omitted " + index + " as an argument to the constructor?";
        }
        throw new ArgumentException(message, nameof(index));
    }
}
/// <summary>
/// Searches a model expression unless it has already been searched.
/// MethodInvoke and Range expressions are handled directly. Any other expression is dispatched to
/// SearchExpression&lt;T&gt;, where T comes from the IModelExpression&lt;T&gt; interface it implements.
/// </summary>
/// <param name="expr">The expression to search.</param>
/// <exception cref="NullReferenceException"><paramref name="expr"/> is null.</exception>
/// <exception cref="ArgumentException">The expression does not implement IModelExpression&lt;&gt;.</exception>
protected void SearchExpressionUntyped(IModelExpression expr)
{
    if (expr == null)
    {
        throw new NullReferenceException("Model expression was null.");
    }
    if (searched.Contains(expr))
    {
        return;
    }
    switch (expr)
    {
        case MethodInvoke methodInvoke:
            SearchMethodInvoke(methodInvoke);
            return;
        case Range range:
            SearchRange(range);
            return;
    }
    Type typedFace = Array.Find(
        expr.GetType().GetInterfaces(),
        face => face.IsGenericType && face.GetGenericTypeDefinition() == typeof(IModelExpression<>));
    if (typedFace == null)
    {
        throw new ArgumentException("Expression: " + expr + " does not implement IModelExpression<>");
    }
    Type domainType = typedFace.GetGenericArguments()[0];
    MethodInfo openSearch = new Action<IModelExpression<object>>(this.SearchExpression<object>).Method.GetGenericMethodDefinition();
    Util.Invoke(openSearch.MakeGenericMethod(domainType), this, expr);
}
/// <summary>
/// True if index is compatible with this range.
/// - A Range is compatible when it shares this range's root.
/// - A Variable is compatible when it has no value range, or when its value range is compatible.
/// - Anything else is always compatible.
/// </summary>
/// <param name="index">Index expression</param>
/// <returns>True if <paramref name="index"/> may be used where this range is expected.</returns>
/// <exclude/>
internal bool IsCompatibleWith(IModelExpression index)
{
    switch (index)
    {
        case Range range:
            return range.GetRoot() == GetRoot();
        case Variable indexVar:
            Range valueRange = indexVar.GetValueRange(false);
            return valueRange == null || IsCompatibleWith(valueRange);
        default:
            return true;
    }
}
/// <summary>
/// Forces a model rebuild on every live inference engine whose last-built model contains <paramref name="expr"/>.
/// For each such engine the model builder is reset and the compiled algorithms are invalidated.
/// Entries whose engine has been garbage collected are removed from allEngineInstances.
/// </summary>
/// <param name="expr">The model expression that has changed.</param>
internal static void InvalidateAllEngines(IModelExpression expr)
{
    foreach (WeakReference weakRef in allEngineInstances.Keys)
    {
        if (!(weakRef.Target is InferenceEngine engine))
        {
            // The engine has been freed, so we can remove it from the dictionary.
            allEngineInstances.TryRemove(weakRef, out _);
            continue;
        }
        var modelExpressions = engine.mb.ModelExpressions;
        if (modelExpressions == null || !modelExpressions.Contains(expr))
        {
            continue;
        }
        engine.mb.Reset(); // must rebuild the model
        engine.InvalidateCompiledAlgorithms();
    }
}
/// <summary>
/// Rewrites a model expression using <paramref name="replacements"/>.
/// - A direct replacement is returned if one exists.
/// - A Range has its size rewritten.
/// - An array element whose array or indices change is re-indexed via the new array's get_Item.
/// Any other expression is returned unchanged.
/// </summary>
/// <param name="expr">The expression to rewrite.</param>
/// <param name="replacements">Expression substitutions; may be extended when ranges are rewritten.</param>
/// <returns>The rewritten expression, or <paramref name="expr"/> if nothing changed.</returns>
private static IModelExpression ReplaceExpressions(IModelExpression expr, Dictionary<IModelExpression, IModelExpression> replacements)
{
    if (replacements.TryGetValue(expr, out IModelExpression replacement))
    {
        return replacement;
    }
    if (expr is Range range)
    {
        return ReplaceExpressions(range, replacements);
    }
    if (!(expr is Variable element) || !element.IsArrayElement)
    {
        return expr;
    }
    IVariableArray newArray = (IVariableArray)ReplaceExpressions(element.ArrayVariable, replacements);
    bool changed = !ReferenceEquals(newArray, element.ArrayVariable);
    IModelExpression[] newIndices = new IModelExpression[element.indices.Count];
    for (int i = 0; i < newIndices.Length; i++)
    {
        newIndices[i] = ReplaceExpressions(element.indices[i], replacements);
        changed |= !ReferenceEquals(newIndices[i], element.indices[i]);
    }
    if (!changed)
    {
        return expr;
    }
    // Re-index the (possibly replaced) array with the (possibly replaced) indices through its indexer.
    return (IModelExpression)Invoker.InvokeMember(
        newArray.GetType(),
        "get_Item",
        BindingFlags.Public | BindingFlags.Instance | BindingFlags.InvokeMethod,
        newArray,
        newIndices);
}
/// <summary>
/// Builds a model expression. MethodInvoke expressions are built directly.
/// Any other expression is dispatched to BuildExpression&lt;T&gt;, where T comes from the
/// IModelExpression&lt;T&gt; interface it implements.
/// </summary>
/// <param name="var">The expression to build.</param>
/// <exception cref="NullReferenceException"><paramref name="var"/> is null.</exception>
/// <exception cref="ArgumentException">The expression does not implement IModelExpression&lt;&gt;.</exception>
protected void BuildExpressionUntyped(IModelExpression var)
{
    if (var == null)
    {
        throw new NullReferenceException("Model expression was null.");
    }
    if (var is MethodInvoke methodInvoke)
    {
        BuildMethodInvoke(methodInvoke);
        return;
    }
    Type typedFace = Array.Find(
        var.GetType().GetInterfaces(),
        face => face.IsGenericType && face.GetGenericTypeDefinition() == typeof(IModelExpression<>));
    if (typedFace == null)
    {
        throw new ArgumentException("Expression: " + var + " does not implement IModelExpression<>");
    }
    Type domainType = typedFace.GetGenericArguments()[0];
    // Close the generic BuildExpression<T> over the discovered domain type and invoke it.
    MethodInfo openBuild = new Action<IModelExpression<object>>(this.BuildExpression<object>).Method.GetGenericMethodDefinition();
    Util.Invoke(openBuild.MakeGenericMethod(domainType), this, var);
}
/// <summary>
/// Create a new DistributedSchedule attribute
/// </summary>
/// <param name="commExpression">Variable whose value is the ICommunicator to use.</param>
public DistributedSchedule(Variable<ICommunicator> commExpression) => this.commExpression = commExpression;
/// <summary>
/// Create a new ParallelSchedule attribute
/// </summary>
/// <param name="scheduleExpression">An observed variable of type int[][][], whose dimensions are [thread][block][item]. Each thread must have the same number of blocks, but blocks can be different sizes. Must have at least one thread.</param>
public ParallelSchedule(Variable<int[][][]> scheduleExpression) => this.scheduleExpression = scheduleExpression;
/// <summary>
/// Sets/Gets element in array given by index expression.
/// This explicit interface implementation forwards to the class's own indexer.
/// </summary>
/// <param name="index">Index expression.</param>
/// <returns>The element selected by <paramref name="index"/>.</returns>
TItem IJaggedVariableArray<TItem>.this[IModelExpression index]
{
    get => this[index];
    set => this[index] = value;
}
/// <summary>
/// Returns the graph node for <paramref name="expr"/>, creating and styling it on first request.
/// - New nodes are labelled with expr.ToString() at font size 9.
/// - Observed variables get no shape; observed base variables whose value is null or a value type are labelled with that value.
/// - Non-read-only variables are drawn in blue at size 10.
/// - MethodInvoke nodes are white-on-black boxes labelled with their operator, or else their method name.
/// When UseContainers is set, the node is grouped under the node of its condition context.
/// </summary>
/// <param name="g">The graph being written.</param>
/// <param name="expr">The model expression to represent.</param>
/// <returns>The (possibly cached) node.</returns>
protected Node GetNode(GraphWriter g, IModelExpression expr)
{
    if (nodeOfExpr.TryGetValue(expr, out var cached))
    {
        return cached;
    }
    Node node = g.AddNode("node" + (Count++));
    nodeOfExpr[expr] = node;
    node.Label = expr.ToString();
    node.FontSize = 9;
    switch (expr)
    {
        case Variable variable:
            if (variable.IsObserved)
            {
                node.Shape = ShapeStyle.None;
                if (variable.IsBase)
                {
                    // if the observed value is a ValueType, display it directly rather than the variable name
                    object observed = ((HasObservedValue)variable).ObservedValue;
                    if (observed is null)
                    {
                        node.Label = "null";
                    }
                    else if (observed.GetType().IsValueType)
                    {
                        node.Label = observed.ToString();
                    }
                }
            }
            if (!variable.IsReadOnly)
            {
                node.FontSize = 10;
                node.FontColor = Color.Blue;
            }
            if (UseContainers && variable.Containers.Count > 0)
            {
                var context = ConditionContext.GetContext(variable.Containers);
                if (context != null)
                {
                    AddGroupEdge(g, GetNode(g, context), node);
                }
            }
            break;
        case MethodInvoke invoke:
            node.FillColor = Color.Black;
            node.FontColor = Color.White;
            node.Shape = ShapeStyle.Box;
            node.FontSize = 8;
            string methodName = invoke.method.Name;
            node.Label = (invoke.op == null) ? methodName : invoke.op.ToString();
            if (UseContainers && invoke.Containers.Count > 0)
            {
                var context = ConditionContext.GetContext(invoke.Containers);
                if (context != null)
                {
                    AddGroupEdge(g, GetNode(g, context), node);
                }
            }
            break;
    }
    return node;
}
/// <summary>
/// Add a statement of the form x = f(...) to the MSL.
/// </summary>
/// <param name="method">Stores the method to call, the argument variables, and target variable.</param>
/// <exception cref="NotImplementedException">An out parameter of the method is bound to an observed variable.</exception>
/// <remarks>
/// If any variable in the statement is an item variable, then we surround the statement with a loop over its range.
/// Since there may be multiple item variables, and each item may depend on multiple ranges, we may end up with multiple loops.
/// </remarks>
private void BuildMethodInvoke(MethodInvoke method)
{
    // Nothing is emitted here when the return value is an inline variable.
    if (method.ReturnValue is Variable returnVariable && returnVariable.Inline)
    {
        return;
    }
    // Open containing blocks
    List<IStatementBlock> stBlocks = method.Containers;
    List<Range> localRanges = new List<Range>();
    // each argument of method puts a partial order on the ranges.
    // e.g. array[i,j][k] requires i < k, j < k but says nothing about i and j
    // we assemble these constraints into a total order.
    Dictionary<IModelExpression, List<List<Range>>> dict = MethodInvoke.GetRangeBrackets(method.returnValueAndArgs());
    foreach (IModelExpression arg in method.returnValueAndArgs())
    {
        MethodInvoke.ForEachRange(arg, r =>
        {
            if (!localRanges.Contains(r))
            {
                localRanges.Add(r);
            }
        });
    }
    ParameterInfo[] pis = method.method.GetParameters();
    for (int i = 0; i < pis.Length; i++)
    {
        IModelExpression arg = method.Arguments[i];
        ParameterInfo pi = pis[i];
        if (pi.IsOut && arg is HasObservedValue observedArg && observedArg.IsObserved)
        {
            throw new NotImplementedException(string.Format("Out parameter '{0}' of {1} cannot be observed. Use ConstrainEqual or observe a copy of the variable.", pi.Name, method));
        }
    }
    // Ranges already provided by an enclosing container are not opened again as local loops.
    foreach (IStatementBlock b in method.Containers)
    {
        if (b is HasRange br)
        {
            localRanges.Remove(br.Range);
        }
    }
    localRanges.Sort((a, b) => MethodInvoke.CompareRanges(dict, a, b));
    // convert from List<Range> to List<IStatementBlock>
    List<IStatementBlock> localRangeBlocks = new List<IStatementBlock>(localRanges.Select(r => r));
    BuildStatementBlocks(stBlocks, true);
    BuildStatementBlocks(localRangeBlocks, true);
    // Invoke method
    IExpression methodExpr = method.GetExpression();
    IStatement st = Builder.ExprStatement(methodExpr);
    // An assignment whose target is observed is marked as a Constraint.
    if (methodExpr is IAssignExpression && method.ReturnValue is HasObservedValue observedTarget && observedTarget.IsObserved)
    {
        Attributes.Set(st, new Constraint());
    }
    AddStatement(st);
    foreach (ICompilerAttribute attr in method.attributes)
    {
        Attributes.Add(methodExpr, attr);
    }
    // Close blocks in the reverse order they were opened.
    BuildStatementBlocks(localRangeBlocks, false);
    BuildStatementBlocks(stBlocks, false);
}
/// <summary>
/// Add the definition of a random variable to the MSL, inside of the necessary containers.
/// </summary>
/// <typeparam name="T">Domain type of the variable.</typeparam>
/// <param name="variable">The variable whose declaration (and initialisation attributes) is emitted.</param>
/// <remarks>
/// A scalar variable is declared and defined in one line such as: <c>int x = factor(...);</c>.
/// An array variable is first declared with an initializer such as: <c>int[] array = new int[4];</c>.
/// Then it is defined either with a bulk factor such as: <c>array = factor(...);</c>,
/// or it is defined via its item variable.
/// An item variable is defined by 'for' loop whose body is: <c>array[i] = factor(...);</c>.
/// For an array element, this method only transfers any initialiseTo / initialiseBackwardTo
/// expression (wrapped in placeholder array-creates) onto the base array's declaration and emits no statement.
/// </remarks>
/// <exception cref="InferCompilerException">The variable is not defined.</exception>
/// <exception cref="Exception">An initialised array element is indexed by something other than a Range.</exception>
/// <exception cref="InvalidOperationException">The same range is used twice among the variable's loops and array indices.</exception>
protected void BuildRandVar<T>(Variable<T> variable)
{
    if (!variable.IsDefined)
    {
        throw new InferCompilerException("Variable '" + variable + "' has no definition");
    }
    if (variable.IsArrayElement)
    {
        // initType 0 handles initialiseTo; initType 1 handles initialiseBackwardTo.
        for (int initType = 0; initType < 2; initType++)
        {
            IModelExpression init = (initType == 0) ? variable.initialiseTo : variable.initialiseBackwardTo;
            if (init != null)
            {
                IExpression initExpr = init.GetExpression();
                // find the base variable
                // Each step outwards wraps the initialiser in a placeholder array-create over that level's index variables.
                Variable parent = variable;
                while (parent.ArrayVariable != null)
                {
                    IVariableDeclaration[] indexVars = new IVariableDeclaration[parent.indices.Count];
                    for (int i = 0; i < indexVars.Length; i++)
                    {
                        IModelExpression expr = parent.indices[i];
                        if (!(expr is Range))
                        {
                            throw new Exception(parent + ".InitializeTo is not allowed since the indices are not ranges");
                        }
                        indexVars[i] = ((Range)expr).GetIndexDeclaration();
                    }
                    initExpr = VariableInformation.MakePlaceHolderArrayCreate(initExpr, indexVars);
                    parent = (Variable)parent.ArrayVariable;
                }
                // Attach the wrapped initialiser to the outermost array's declaration.
                IVariableDeclaration parentDecl = (IVariableDeclaration)parent.GetDeclaration();
                ICompilerAttribute attr;
                if (initType == 0)
                {
                    attr = new InitialiseTo(initExpr);
                }
                else
                {
                    attr = new InitialiseBackwardTo(initExpr);
                }
                Attributes.Set(parentDecl, attr);
            }
        }
        return;
    }
    IVariableDeclaration ivd = (IVariableDeclaration)variable.GetDeclaration();
    if (variable.initialiseTo != null)
    {
        Attributes.Set(ivd, new InitialiseTo(variable.initialiseTo.GetExpression()));
    }
    if (variable.initialiseBackwardTo != null)
    {
        Attributes.Set(ivd, new InitialiseBackwardTo(variable.initialiseBackwardTo.GetExpression()));
    }
    List<IStatementBlock> stBlocks = new List<IStatementBlock>();
    stBlocks.AddRange(variable.Containers);
    IVariableDeclarationExpression ivde = Builder.VarDeclExpr(ivd);
    if (variable is IVariableArray iva)
    {
        // Arrays are declared with a jagged array-create statement collection rather than ivde.
        IList<IStatement> sc = Builder.StmtCollection();
        IList<IVariableDeclaration[]> jaggedIndexVars;
        IList<IExpression[]> jaggedSizes;
        GetJaggedArrayIndicesAndSizes(iva, out jaggedIndexVars, out jaggedSizes);
        // check that containers are all unique and distinct from jaggedIndexVars
        Set<IVariableDeclaration> loopVars = new Set<IVariableDeclaration>();
        foreach (IStatementBlock stBlock in stBlocks)
        {
            if (stBlock is ForEachBlock fb)
            {
                IVariableDeclaration loopVar = fb.Range.GetIndexDeclaration();
                if (loopVars.Contains(loopVar))
                {
                    throw new InvalidOperationException("Variable '" + ivd.Name + "' uses range '" + loopVar.Name + "' twice. Use a cloned range instead.");
                }
                loopVars.Add(loopVar);
            }
        }
        // NOTE(review): index variables are only checked against container loop variables, not against
        // each other (they are never added to loopVars) — confirm duplicates across brackets cannot occur.
        foreach (IVariableDeclaration[] bracket in jaggedIndexVars)
        {
            foreach (IVariableDeclaration indexVar in bracket)
            {
                if (loopVars.Contains(indexVar))
                {
                    throw new InvalidOperationException("Variable '" + ivd.Name + "' uses range '" + indexVar.Name + "' twice. Use a cloned range instead.");
                }
            }
        }
        Builder.NewJaggedArray(sc, ivd, jaggedIndexVars, jaggedSizes);
        if (!variable.Inline)
        {
            // Emit the array-creation statements inside the variable's containers.
            BuildStatementBlocks(stBlocks, true);
            foreach (IStatement stmt in sc)
            {
                AddStatement(stmt);
            }
            BuildStatementBlocks(stBlocks, false);
        }
        ivde = null; // prevent re-declaration
    }
    if (ivde != null)
    {
        // Scalar variable: emit a bare declaration statement inside its containers.
        if (!variable.Inline)
        {
            BuildStatementBlocks(stBlocks, true);
            AddStatement(Builder.ExprStatement(ivde));
            BuildStatementBlocks(stBlocks, false);
        }
        ivde = null;
    }
    // NOTE(review): unreachable — both branches above leave ivde null, so this check never throws.
    if (ivde != null)
    {
        throw new InferCompilerException("Variable '" + variable + "' has no definition");
    }
}
/// <summary>
/// Create a new DistributedSchedule attribute with an explicit per-thread schedule.
/// </summary>
/// <param name="commExpression">Variable whose value is the ICommunicator to use.</param>
/// <param name="schedulePerThreadExpression">An observed variable of type int[][][][], whose dimensions are [distributedStage][thread][block][item]. Each thread must have the same number of blocks, but blocks can be different sizes. Must have at least one thread.</param>
public DistributedSchedule(Variable<ICommunicator> commExpression, Variable<int[][][][]> schedulePerThreadExpression)
    : this(commExpression)
{
    this.schedulePerThreadExpression = schedulePerThreadExpression;
}
/// <summary>
/// Creates a new DistributedCommunication attribute from the expressions giving the array indices
/// to send and to receive.
/// </summary>
/// <param name="arrayIndicesToSendExpression">Expression for the array indices to send.</param>
/// <param name="arrayIndicesToReceiveExpression">Expression for the array indices to receive.</param>
public DistributedCommunication(IModelExpression arrayIndicesToSendExpression, IModelExpression arrayIndicesToReceiveExpression)
{
    (this.arrayIndicesToSendExpression, this.arrayIndicesToReceiveExpression) =
        (arrayIndicesToSendExpression, arrayIndicesToReceiveExpression);
}
/// <summary>
/// Constructs a range whose size is given by an integer-value expression.
/// The range is given a default name of the form "index{n}", where n comes from a global counter.
/// </summary>
/// <param name="size">An expression giving the size of the range</param>
public Range(IModelExpression<int> size)
{
    long ordinal = globalCounter.GetNext();
    this.name = "index" + ordinal;
    this.Size = size;
}