/// <summary>
/// Creates a method invocation located inside the given statement blocks.
/// </summary>
/// <param name="containers">The statement blocks that contain this invocation. They are copied into <c>Containers</c>.</param>
/// <param name="method">The method being invoked.</param>
/// <param name="args">The argument expressions of the invocation. No element may be null.</param>
/// <exception cref="ArgumentNullException">An element of <paramref name="args"/> is null.</exception>
/// <exception cref="InvalidOperationException">
/// An argument is an unobserved <c>Variable</c> that has a <c>ConditionBlock</c> container
/// which is not one of <paramref name="containers"/>.
/// </exception>
internal MethodInvoke(IEnumerable<IStatementBlock> containers, MethodInfo method, params IModelExpression[] args)
{
    this.timestamp = GetTimestamp();
    this.method = method;
    this.args.AddRange(args);

    // Materialize the containers exactly once; the same snapshot is reused below so that a
    // lazily-evaluated sequence is never enumerated a second time.
    List<IStatementBlock> containerList = new List<IStatementBlock>(containers);
    this.Containers = containerList;

    foreach (IModelExpression arg in args)
    {
        // Reference comparison, so that any overloaded == on expression types is not involved.
        if (ReferenceEquals(arg, null))
        {
            throw new ArgumentNullException(nameof(args), "An argument of the method invocation is null.");
        }

        // Observed variables are exempt from the container check.
        if (arg is Variable v && !v.IsObserved)
        {
            foreach (ConditionBlock cb in v.GetContainers<ConditionBlock>())
            {
                if (!this.Containers.Contains(cb))
                {
                    throw new InvalidOperationException($"{arg} was created in condition {cb} and cannot be used outside. " + $"To give {arg} a conditional definition, use SetTo inside {cb} rather than assignment (=). " + $"If you are using GetCopyFor, make sure you call GetCopyFor outside of conflicting conditional statements.");
                }
            }
        }
    }

    // Register this invocation as a constraint on the condition variable of every enclosing condition block.
    foreach (ConditionBlock cb in StatementBlock.EnumerateBlocks<ConditionBlock>(containerList))
    {
        cb.ConditionVariableUntyped.constraints.Add(this);
    }
}
/// <summary>
/// Closes every statement block that is currently open, so that the model state can be
/// recovered after an exception.
/// </summary>
/// <remarks>
/// The blocks are closed in the reverse of the order reported by
/// <c>StatementBlock.GetOpenBlocks()</c>.
/// </remarks>
internal static void CloseAllBlocks()
{
    // Work on a private copy of the open-block list rather than the list returned by GetOpenBlocks().
    List<IStatementBlock> snapshot = new List<IStatementBlock>(StatementBlock.GetOpenBlocks());
    for (int i = snapshot.Count - 1; i >= 0; i--)
    {
        StatementBlock current = (StatementBlock)snapshot[i];
        current.CloseBlock();
    }
}
/// <summary>
/// Builds the model necessary to infer marginals for the supplied variables and algorithm.
/// </summary>
/// <param name="engine">The inference engine being used; its <c>Algorithm</c> is passed to the finishing step of each expression.</param>
/// <param name="inferOnlySpecifiedVars">If true, inference will be restricted to only the variables given.</param>
/// <param name="vars">Variables to infer. The sequence is enumerated exactly once.</param>
/// <returns>The <c>modelType</c> type declaration, after every expression has been built and finished.</returns>
/// <exception cref="InvalidOperationException">A statement block is still open.</exception>
/// <remarks>
/// Algorithm: starting from the variables to infer, we search through the graph to build up a "searched set".
/// Each Variable and MethodInvoke in this set has an associated timestamp.
/// We sort by timestamp, and then generate code.
/// </remarks>
public ITypeDeclaration Build(InferenceEngine engine, bool inferOnlySpecifiedVars, IEnumerable<IVariable> vars)
{
    List<IStatementBlock> openBlocks = StatementBlock.GetOpenBlocks();
    if (openBlocks.Count > 0)
    {
        throw new InvalidOperationException("The block " + openBlocks[0] + " has not been closed.");
    }

    Reset();
    this.inferOnlySpecifiedVars = inferOnlySpecifiedVars;

    // Snapshot the caller's sequence so that it is enumerated only once: a lazy or
    // non-repeatable enumerable would otherwise be able to hand different items to
    // variablesToInfer and to the search stack.
    List<IVariable> requestedVars = new List<IVariable>(vars);
    variablesToInfer.AddRange(requestedVars);
    foreach (IVariable requested in requestedVars)
    {
        toSearch.Push(requested);
    }

    // Drain the search stack, processing each popped expression.
    while (toSearch.Count > 0)
    {
        IModelExpression expr = toSearch.Pop();
        SearchExpressionUntyped(expr);
    }

    // lock in the set of model expressions.
    ModelExpressions = new List<IModelExpression>(searched);

    // Collect (timestamp, expression) pairs; expressions that are neither a Variable nor a
    // MethodInvoke are left out of code generation.
    List<int> timestamps = new List<int>();
    List<IModelExpression> exprs = new List<IModelExpression>();
    foreach (IModelExpression expr in ModelExpressions)
    {
        if (expr is Variable variable)
        {
            exprs.Add(variable);
            timestamps.Add(variable.timestamp);
        }
        else if (expr is MethodInvoke invoke)
        {
            exprs.Add(invoke);
            timestamps.Add(invoke.timestamp);
        }
    }

    // Order the expressions by timestamp before generating code.
    Collection.Sort(timestamps, exprs);

    // Two passes: every expression is built before any expression is finished.
    foreach (IModelExpression expr in exprs)
    {
        BuildExpressionUntyped(expr);
    }
    foreach (IModelExpression expr in exprs)
    {
        FinishExpressionUntyped(expr, engine.Algorithm);
    }

    return modelType;
}
/// <summary>
/// Get a random variable representing an item of an array.
/// </summary>
/// <param name="array">The array being indexed.</param>
/// <param name="itemPrototype">
/// The prototype from which a new item is made: it is cloned, or, when it is itself an
/// <c>IVariableArray</c>, copied with the array's ranges replaced by the index expressions.
/// </param>
/// <param name="index">The index expressions, one per range of <paramref name="array"/>.</param>
/// <returns>
/// The item already stored in the array's item dictionary under <paramref name="index"/>, if there is one;
/// otherwise a newly created item on which <c>MakeItem(array, index)</c> has been called.
/// </returns>
/// <exception cref="ArgumentException">
/// The number of indices differs from the number of ranges of the array, or an index is
/// equal to the range of a currently open switch block.
/// </exception>
internal static TItem GetItem(VariableArrayBase<TItem, TArray> array, TItem itemPrototype, params IModelExpression[] index)
{
    IList<Range> ranges = array.Ranges;
    if (index.Length != ranges.Count)
    {
        throw new ArgumentException("Provided " + index.Length + " indices to an array of rank " + ranges.Count);
    }

    for (int i = 0; i < ranges.Count; i++)
    {
        ranges[i].CheckCompatible(index[i], array);
        // Indexing by the range of an open switch block is rejected.
        foreach (SwitchBlock block in StatementBlock.EnumerateOpenBlocks<SwitchBlock>())
        {
            if (block.Range.Equals(index[i]))
            {
                throw new ArgumentException("Cannot index by '" + index[i] + "' in a switch block over '" + block.ConditionVariable + "'");
            }
        }
    }

    // Reuse an item that was previously created for the same index, if any.
    Dictionary<IReadOnlyList<IModelExpression>, IVariable> itemVariables = ((HasItemVariables)array).GetItemsUntyped();
    if (itemVariables.TryGetValue(index, out IVariable item))
    {
        return (TItem)item;
    }

    // the item must be in the same containers as the array (not the currently open containers)
    TItem v;
    if (itemPrototype is IVariableArray)
    {
        // No range-to-range replacements are needed; each range of the array is replaced by
        // the corresponding index expression instead.
        Dictionary<Range, Range> replacements = new Dictionary<Range, Range>();
        Dictionary<IModelExpression, IModelExpression> expressionReplacements = new Dictionary<IModelExpression, IModelExpression>();
        for (int i = 0; i < ranges.Count; i++)
        {
            expressionReplacements.Add(ranges[i], index[i]);
        }
        IVariable result = ((IVariableArray)itemPrototype).ReplaceRanges(replacements, expressionReplacements, deepCopy: false);
        v = (TItem)result;
    }
    else
    {
        v = (TItem)itemPrototype.Clone();
    }

    v.MakeItem(array, index);
    return v;
}
/// <summary>
/// Creates a new <c>VariableArray</c> whose range and item prototype are the results of
/// applying the given replacements to this array's range and item prototype.
/// </summary>
/// <param name="rangeReplacements">Range-to-range replacements.</param>
/// <param name="expressionReplacements">Expression-to-expression replacements.</param>
/// <param name="deepCopy">Forwarded unchanged when the item prototype is itself an <c>IVariableArray</c>; not otherwise read here.</param>
/// <returns>A new array built from the replaced item prototype and the replaced range.</returns>
IVariableArray IVariableArray.ReplaceRanges(Dictionary<Range, Range> rangeReplacements, Dictionary<IModelExpression, IModelExpression> expressionReplacements, bool deepCopy)
{
    // The range has to be replaced before the item prototype, because the outcome
    // influences how the item prototype is replaced.
    Range replacedRange = Range.Replace(rangeReplacements, expressionReplacements);

    TItem prototype = (TItem)((IVariableJaggedArray)this).ItemPrototype;
    if (!(prototype is IVariableArray))
    {
        // Leaf prototype: clone it and place the clone in the currently open blocks.
        prototype = (TItem)prototype.Clone();
        prototype.containers = StatementBlock.GetOpenBlocks();
    }
    else
    {
        // Nested array prototype: apply the same replacements recursively.
        IVariableArray innerArray = (IVariableArray)prototype;
        IVariable replacedInner = innerArray.ReplaceRanges(rangeReplacements, expressionReplacements, deepCopy);
        prototype = (TItem)replacedInner;
    }

    return new VariableArray<TItem, TArray>(prototype, replacedRange);
}
/// <summary>
/// Creates a method invocation whose containers are the statement blocks returned by
/// <c>StatementBlock.GetOpenBlocks()</c> at the time of construction.
/// </summary>
/// <param name="method">The method being invoked.</param>
/// <param name="args">The argument expressions of the invocation.</param>
internal MethodInvoke(MethodInfo method, params IModelExpression[] args)
    : this(StatementBlock.GetOpenBlocks(), method, args)
{
}