예제 #1
0
 /// <summary>
 /// Records a call to <paramref name="method"/> with the given arguments, made inside <paramref name="containers"/>.
 /// </summary>
 /// <param name="containers">The statement blocks enclosing this call. The sequence is enumerated exactly once.</param>
 /// <param name="method">The method being invoked.</param>
 /// <param name="args">The arguments to the method. No element may be null.</param>
 /// <exception cref="ArgumentNullException"><paramref name="args"/> is null or contains a null element.</exception>
 /// <exception cref="InvalidOperationException">
 /// An unobserved <see cref="Variable"/> argument was created in a <see cref="ConditionBlock"/> that is not among <paramref name="containers"/>.
 /// </exception>
 internal MethodInvoke(IEnumerable <IStatementBlock> containers, MethodInfo method, params IModelExpression[] args)
 {
     this.timestamp = GetTimestamp();
     this.method    = method;
     if (args == null)
     {
         throw new ArgumentNullException(nameof(args));
     }
     this.args.AddRange(args);
     // Snapshot the containers once. 'containers' may be a lazy sequence, and it is needed twice below
     // (membership checks and condition registration); enumerating it again could observe different blocks.
     List <IStatementBlock> containerList = new List <IStatementBlock>(containers);
     this.Containers = containerList;
     foreach (IModelExpression arg in args)
     {
         if (ReferenceEquals(arg, null))
         {
             throw new ArgumentNullException(nameof(args), "Method arguments cannot be null.");
         }
         // Observed variables are exempt from the container check.
         if (arg is Variable v && !v.IsObserved)
         {
             // Every condition block the argument was created in must also enclose this call.
             foreach (ConditionBlock cb in v.GetContainers <ConditionBlock>())
             {
                 if (!containerList.Contains(cb))
                 {
                     throw new InvalidOperationException($"{arg} was created in condition {cb} and cannot be used outside.  " +
                                                         $"To give {arg} a conditional definition, use SetTo inside {cb} rather than assignment (=).  " +
                                                         $"If you are using GetCopyFor, make sure you call GetCopyFor outside of conflicting conditional statements.");
                 }
             }
         }
     }
     // Register this call as a constraint on the condition variable of every enclosing condition block.
     foreach (ConditionBlock cb in StatementBlock.EnumerateBlocks <ConditionBlock>(containerList))
     {
         cb.ConditionVariableUntyped.constraints.Add(this);
     }
 }
예제 #2
0
        /// <summary>
        /// Closes every statement block that is currently open, so that block state can be
        /// recovered after an exception. Blocks are closed in the reverse of the order in which
        /// <c>StatementBlock.GetOpenBlocks()</c> returns them (presumably innermost first).
        /// </summary>
        internal static void CloseAllBlocks()
        {
            // Work on a copy: closing a block may alter the list that GetOpenBlocks exposes.
            List <IStatementBlock> snapshot = new List <IStatementBlock>(StatementBlock.GetOpenBlocks());

            for (int position = snapshot.Count - 1; position >= 0; position--)
            {
                StatementBlock current = (StatementBlock)snapshot[position];
                current.CloseBlock();
            }
        }
예제 #3
0
        /// <summary>
        /// Builds the model necessary to infer marginals for the supplied variables and algorithm.
        /// </summary>
        /// <param name="engine">The inference engine; its <c>Algorithm</c> is passed to every expression when it is finished.</param>
        /// <param name="inferOnlySpecifiedVars">If true, inference will be restricted to only the variables given.</param>
        /// <param name="vars">Variables to infer. The sequence is enumerated exactly once.</param>
        /// <returns>The generated model type declaration.</returns>
        /// <exception cref="InvalidOperationException">A statement block is still open.</exception>
        /// <remarks>
        /// Algorithm: starting from the variables to infer, we search through the graph to build up a "searched set".
        /// Each Variable and MethodInvoke in this set has an associated timestamp.
        /// We sort by timestamp, and then generate code.
        /// </remarks>
        public ITypeDeclaration Build(InferenceEngine engine, bool inferOnlySpecifiedVars, IEnumerable <IVariable> vars)
        {
            List <IStatementBlock> openBlocks = StatementBlock.GetOpenBlocks();

            if (openBlocks.Count > 0)
            {
                throw new InvalidOperationException("The block " + openBlocks[0] + " has not been closed.");
            }
            Reset();
            this.inferOnlySpecifiedVars = inferOnlySpecifiedVars;
            // Materialize the caller's sequence once: it is consumed twice below, and a lazy
            // sequence could otherwise yield different variables on each pass.
            List <IVariable> varList = new List <IVariable>(vars);
            variablesToInfer.AddRange(varList);
            foreach (IVariable variable in varList)
            {
                toSearch.Push(variable);
            }
            // Depth-first search from the requested variables to collect every reachable expression.
            while (toSearch.Count > 0)
            {
                SearchExpressionUntyped(toSearch.Pop());
            }
            // lock in the set of model expressions.
            ModelExpressions = new List <IModelExpression>(searched);
            List <int> timestamps         = new List <int>();
            List <IModelExpression> exprs = new List <IModelExpression>();

            // Only Variables and MethodInvokes carry timestamps; any other expression is skipped here.
            foreach (IModelExpression expr in ModelExpressions)
            {
                if (expr is Variable variable)
                {
                    exprs.Add(variable);
                    timestamps.Add(variable.timestamp);
                }
                else if (expr is MethodInvoke invoke)
                {
                    exprs.Add(invoke);
                    timestamps.Add(invoke.timestamp);
                }
            }
            // Reorder exprs by creation timestamp so code is generated in definition order.
            Collection.Sort(timestamps, exprs);
            foreach (IModelExpression expr in exprs)
            {
                BuildExpressionUntyped(expr);
            }
            // A second pass runs only after every expression has been built.
            foreach (IModelExpression expr in exprs)
            {
                FinishExpressionUntyped(expr, engine.Algorithm);
            }
            return modelType;
        }
예제 #4
0
        /// <summary>
        /// Get a random variable representing an item of an array.
        /// </summary>
        /// <param name="array">The array being indexed.</param>
        /// <param name="itemPrototype">
        /// Template for a new item. If it is itself an array, a copy with the array's ranges replaced by the
        /// index expressions is used; otherwise it is cloned.
        /// </param>
        /// <param name="index">One index expression per dimension of <paramref name="array"/>.</param>
        /// <returns>
        /// The item variable already registered in <paramref name="array"/> for <paramref name="index"/>, if any;
        /// otherwise a newly created item.
        /// </returns>
        /// <exception cref="ArgumentException">
        /// The number of indices differs from the array's rank, or an index equals the range of an open switch block.
        /// </exception>
        internal static TItem GetItem(VariableArrayBase <TItem, TArray> array, TItem itemPrototype, params IModelExpression[] index)
        {
            IList <Range> ranges = array.Ranges;

            if (index.Length != ranges.Count)
            {
                throw new ArgumentException("Provided " + index.Length + " indices to an array of rank " + ranges.Count);
            }
            for (int i = 0; i < ranges.Count; i++)
            {
                ranges[i].CheckCompatible(index[i], array);
                // Indexing by the range of a currently open switch block is rejected.
                foreach (SwitchBlock block in StatementBlock.EnumerateOpenBlocks <SwitchBlock>())
                {
                    if (block.Range.Equals(index[i]))
                    {
                        throw new ArgumentException("Cannot index by '" + index[i] + "' in a switch block over '" + block.ConditionVariable + "'");
                    }
                }
            }
            Dictionary <IReadOnlyList <IModelExpression>, IVariable> itemVariables = ((HasItemVariables)array).GetItemsUntyped();

            // Reuse an existing item for the same index tuple.
            if (itemVariables.TryGetValue(index, out IVariable existing))
            {
                return (TItem)existing;
            }
            // the item must be in the same containers as the array (not the currently open containers)
            TItem item;
            if (itemPrototype is IVariableArray)
            {
                // Substitute each of the array's ranges with the corresponding index expression.
                Dictionary <Range, Range> rangeReplacements = new Dictionary <Range, Range>();
                Dictionary <IModelExpression, IModelExpression> expressionReplacements = new Dictionary <IModelExpression, IModelExpression>();
                for (int i = 0; i < ranges.Count; i++)
                {
                    expressionReplacements.Add(ranges[i], index[i]);
                }
                IVariable result = ((IVariableArray)itemPrototype).ReplaceRanges(rangeReplacements, expressionReplacements, deepCopy: false);
                item = (TItem)result;
            }
            else
            {
                item = (TItem)itemPrototype.Clone();
            }
            item.MakeItem(array, index);
            return item;
        }
예제 #5
0
        /// <summary>
        /// Creates a new array over the replaced range, with an item prototype whose ranges are also replaced.
        /// </summary>
        /// <param name="rangeReplacements">Range substitutions to apply.</param>
        /// <param name="expressionReplacements">Expression substitutions to apply.</param>
        /// <param name="deepCopy">
        /// Forwarded to the item prototype's own ReplaceRanges when the prototype is itself an array; not used otherwise.
        /// </param>
        /// <returns>A new <c>VariableArray</c> built from the replaced prototype and range.</returns>
        IVariableArray IVariableArray.ReplaceRanges(Dictionary <Range, Range> rangeReplacements, Dictionary <IModelExpression, IModelExpression> expressionReplacements, bool deepCopy)
        {
            // The range is replaced before the item prototype, since replacing it influences how the prototype is replaced.
            Range replacedRange = Range.Replace(rangeReplacements, expressionReplacements);
            TItem originalPrototype = (TItem)((IVariableJaggedArray)this).ItemPrototype;
            TItem newPrototype;

            if (!(originalPrototype is IVariableArray))
            {
                // Leaf item: clone it and place the clone in the currently open containers.
                newPrototype            = (TItem)originalPrototype.Clone();
                newPrototype.containers = StatementBlock.GetOpenBlocks();
            }
            else
            {
                // Nested array: recurse so its own ranges are replaced too.
                IVariable replaced = ((IVariableArray)originalPrototype).ReplaceRanges(rangeReplacements, expressionReplacements, deepCopy);
                newPrototype = (TItem)replaced;
            }
            return new VariableArray <TItem, TArray>(newPrototype, replacedRange);
        }
예제 #6
0
 /// <summary>
 /// Records a call to <paramref name="method"/> made inside the statement blocks that are currently open.
 /// </summary>
 /// <param name="method">The method being invoked.</param>
 /// <param name="args">The arguments to the method.</param>
 internal MethodInvoke(MethodInfo method, params IModelExpression[] args)
     : this(containers: StatementBlock.GetOpenBlocks(), method: method, args: args)
 {
 }