Пример #1
0
 /// <summary>
 /// Creates the "uses" channel of a variable: an array variable named "&lt;name&gt;_uses" with
 /// <paramref name="useCount"/> elements, registered on <paramref name="vtci"/> and tagged with its
 /// channel and description attributes.
 /// </summary>
 /// <param name="vi">Information about the variable whose uses channel is being created.</param>
 /// <param name="useCount">Length of the uses array (emitted as a literal expression).</param>
 /// <param name="vtci">Receives the new usage channel in its <c>usageChannel</c> member.</param>
 /// <param name="stmts">Statement list handed to <c>DeriveArrayVariable</c> for the generated declaration.</param>
 private void CreateUsesChannel(VariableInformation vi, int useCount, VariableToChannelInformation vtci, IList<IStatement> stmts)
 {
     vtci.usageChannel = ChannelInfo.UseChannel(vi);

     // Derive the array declaration first, then hang it on the channel.
     IVariableDeclaration usesDecl = vi.DeriveArrayVariable(
         stmts,
         context,
         vi.Name + "_uses",
         Builder.LiteralExpr(useCount),
         Builder.VarDecl("_ind", typeof(int)),
         useLiteralIndices: true);
     vtci.usageChannel.decl = usesDecl;

     // The uses array inherits any InitialiseTo attribute of the original declaration.
     context.InputAttributes.CopyObjectAttributesTo<InitialiseTo>(vi.declaration, context.OutputAttributes, usesDecl);

     var outputAttributes = context.OutputAttributes;
     outputAttributes.Set(usesDecl, vtci.usageChannel);
     outputAttributes.Set(usesDecl, new DescriptionAttribute($"uses of '{vi.Name}'"));
 }
Пример #2
0
        /// <summary>
        /// Creates the "use" channel for <paramref name="ivd"/>: derives an indexed variable named
        /// "&lt;name&gt;_use", records it in <c>useOfVariable</c>, and attaches its channel, description and
        /// marginal prototype.
        /// </summary>
        /// <param name="ivd">The variable whose use channel is created; also the key in <c>useOfVariable</c>.</param>
        /// <param name="vi">Variable information used to derive the new declaration.</param>
        /// <param name="stmts">Statement list handed to <c>DeriveIndexedVariable</c> for the generated declaration.</param>
        private void CreateUseChannel(IVariableDeclaration ivd, VariableInformation vi, IList<IStatement> stmts)
        {
            IVariableDeclaration useChannelDecl = vi.DeriveIndexedVariable(stmts, context, ivd.Name + "_use");
            useOfVariable[ivd] = useChannelDecl;

            ChannelInfo channel = ChannelInfo.UseChannel(vi);
            channel.decl = useChannelDecl;

            // Carry the initialiser and derivative-message attributes over to the use channel.
            context.InputAttributes.CopyObjectAttributesTo<InitialiseTo>(vi.declaration, context.OutputAttributes, useChannelDecl);
            context.InputAttributes.CopyObjectAttributesTo<DerivMessage>(vi.declaration, context.OutputAttributes, useChannelDecl);

            context.OutputAttributes.Set(useChannelDecl, channel);
            context.OutputAttributes.Set(useChannelDecl, new DescriptionAttribute($"use of '{ivd.Name}'"));

            // InitialiseTo is dropped from the original declaration's output attributes.
            context.OutputAttributes.Remove<InitialiseTo>(vi.declaration);

            SetMarginalPrototype(useChannelDecl);
        }
Пример #3
0
        /// <summary>
        /// Creates the "use" channel for a variable: derives an indexed variable named "&lt;name&gt;_use",
        /// attaches its channel and description attributes, and marks it as stochastic.
        /// </summary>
        /// <param name="vi">Information about the variable whose use channel is created.</param>
        /// <param name="stmts">Statement list handed to <c>DeriveIndexedVariable</c> for the generated declaration.</param>
        /// <returns>The declaration of the new use-channel variable.</returns>
        private IVariableDeclaration CreateUseChannel(VariableInformation vi, IList<IStatement> stmts)
        {
            IVariableDeclaration useChannelDecl = vi.DeriveIndexedVariable(stmts, context, vi.Name + "_use");

            ChannelInfo channel = ChannelInfo.UseChannel(vi);
            channel.decl = useChannelDecl;

            // Carry the initialiser and derivative-message attributes over to the use channel.
            context.InputAttributes.CopyObjectAttributesTo<InitialiseTo>(vi.declaration, context.OutputAttributes, useChannelDecl);
            context.InputAttributes.CopyObjectAttributesTo<DerivMessage>(vi.declaration, context.OutputAttributes, useChannelDecl);

            context.OutputAttributes.Set(useChannelDecl, channel);
            context.OutputAttributes.Set(useChannelDecl, new DescriptionAttribute($"use of '{vi.Name}'"));

            // InitialiseTo is dropped from the original declaration's output attributes.
            context.OutputAttributes.Remove<InitialiseTo>(vi.declaration);

            // AddMarginalStatements needs the use channel to be flagged as stochastic.
            VariableInformation.GetVariableInformation(context, useChannelDecl).IsStochastic = true;

            return useChannelDecl;
        }
Пример #4
0
        /// <summary>
        /// Declares the "_uses" array for a variable and registers it in <c>usesOfVariable</c>.
        /// </summary>
        /// <remarks>
        /// For each of the first <paramref name="usageDepth"/> index brackets of the variable, the derived array
        /// gets a prefix bracket whose sizes are <c>GateAnalysisTransform.AnyIndex</c> calls and whose index
        /// variables are references to the variable's own index variables.
        /// </remarks>
        /// <param name="stmts">Statement list handed to <c>DeriveArrayVariable</c> for the generated declaration.</param>
        /// <param name="ivd">Key under which the result is stored in <c>usesOfVariable</c>.</param>
        /// <param name="vi">Information about the variable whose uses array is declared.</param>
        /// <param name="useCount">Length of the uses array (emitted as a literal expression).</param>
        /// <param name="usageDepth">Number of leading index brackets to prefix onto the uses array.</param>
        /// <returns>The new <c>VariableToChannelInformation</c> describing the uses array.</returns>
        private VariableToChannelInformation DeclareUsesArray(IList<IStatement> stmts, IVariableDeclaration ivd, VariableInformation vi, int useCount, int usageDepth)
        {
            const string useSuffix = "_use";

            // Create AnyIndex expressions up to usageDepth
            List<IList<IExpression>> prefixSizes = new List<IList<IExpression>>();
            List<IList<IExpression>> prefixVars = new List<IList<IExpression>>();

            vi.DefineAllIndexVars(context);
            for (int d = 0; d < usageDepth; d++)
            {
                IList<IExpression> sizeBracket = Builder.ExprCollection();
                IList<IExpression> varBracket = Builder.ExprCollection();
                for (int i = 0; i < vi.sizes[d].Length; i++)
                {
                    sizeBracket.Add(Builder.StaticMethod(new Func<int>(GateAnalysisTransform.AnyIndex)));
                    varBracket.Add(Builder.VarRefExpr(vi.indexVars[d][i]));
                }
                prefixSizes.Add(sizeBracket);
                prefixVars.Add(varBracket);
            }

            // Name the array "<base>_uses" rather than "<base>_use_uses" when the variable is itself a use channel.
            // The suffix is an identifier fragment, so it must be matched ordinally: a culture-sensitive
            // EndsWith can report a match even when the last four characters are not literally "_use",
            // which would make the Substring below cut off the wrong characters.
            string prefix = vi.Name;
            if (prefix.EndsWith(useSuffix, StringComparison.Ordinal))
            {
                prefix = prefix.Substring(0, prefix.Length - useSuffix.Length);
            }
            string arrayName = VariableInformation.GenerateName(context, prefix + "_uses");
            IVariableDeclaration usesDecl = vi.DeriveArrayVariable(stmts, context, arrayName, Builder.LiteralExpr(useCount), Builder.VarDecl("_ind", typeof(int)),
                                                                   prefixSizes, prefixVars, useLiteralIndices: true);

            // Replace any ChannelInfo on the derived declaration with a fresh use-channel attribute.
            context.OutputAttributes.Remove<ChannelInfo>(usesDecl);
            context.OutputAttributes.Add(usesDecl, new DescriptionAttribute($"uses of '{vi.Name}'"));
            ChannelInfo ci = ChannelInfo.UseChannel(vi);
            ci.decl = usesDecl;
            context.OutputAttributes.Set(usesDecl, ci);

            VariableToChannelInformation vtci = new VariableToChannelInformation();
            vtci.usesDecl = usesDecl;
            vtci.usageDepth = usageDepth;
            usesOfVariable[ivd] = vtci;
            return vtci;
        }
Пример #5
0
        /// <summary>
        /// Rewrites a reference to a stochastic variable so that it is explicitly replicated across every
        /// enclosing loop (relative to the variable's declaration) that the reference is not indexed by.
        /// </summary>
        /// <remarks>
        /// For each such loop a "&lt;name&gt;_rep" array is derived, assigned from <c>Clone.Replicate</c>, and the
        /// generated statements are added before the matching ancestor; the returned expression indexes the
        /// replicated array by the loop variable. Replication is chained, so the array produced for one loop
        /// is the thing replicated for the next.
        /// </remarks>
        /// <param name="expr">The expression referencing the variable.</param>
        /// <returns>
        /// <paramref name="expr"/> unchanged when it references no variable, the variable is not stochastic,
        /// its loop context is missing (an error is reported), or the reference is in the same loop context
        /// as the declaration; otherwise the expression indexing the replicated variable(s).
        /// </returns>
        protected IExpression ConvertWithReplication(IExpression expr)
        {
            IVariableDeclaration baseVar = Recognizer.GetVariableDeclaration(expr);
            // Check if this is an index local variable
            if (baseVar == null) return expr;
            // Check if the variable is stochastic
            if (!CodeRecognizer.IsStochastic(context, baseVar)) return expr;

            // Get the loop context for this variable
            LoopContext lc = context.InputAttributes.Get<LoopContext>(baseVar);
            if (lc == null)
            {
                Error("Loop context not found for '" + baseVar.Name + "'.");
                return expr;
            }

            // Get the reference loop context for this expression
            RefLoopContext rlc = lc.GetReferenceLoopContext(context);
            // If the reference is in the same loop context as the declaration, do nothing.
            if (rlc.loops.Count == 0) return expr;

            // the set of loop variables that are constant wrt the expr
            Set<IVariableDeclaration> constantLoopVars = new Set<IVariableDeclaration>();
            constantLoopVars.AddRange(lc.loopVariables);

            // collect set of all loop variable indices in the expression
            Set<int> embeddedLoopIndices = new Set<int>();
            List<IList<IExpression>> brackets = Recognizer.GetIndices(expr);
            foreach (IList<IExpression> bracket in brackets)
            {
                foreach (IExpression index in bracket)
                {
                    IExpression indExpr = index;
                    // For a binary index expression only the left operand is examined.
                    if (indExpr is IBinaryExpression ibe)
                    {
                        indExpr = ibe.Left;
                    }
                    IVariableDeclaration indVar = Recognizer.GetVariableDeclaration(indExpr);
                    if (indVar != null)
                    {
                        if (!constantLoopVars.Contains(indVar))
                        {
                            int loopIndex = rlc.loopVariables.IndexOf(indVar);
                            if (loopIndex != -1)
                            {
                                // indVar is a loop variable
                                constantLoopVars.Add(rlc.loopVariables[loopIndex]);
                            }
                            else
                            {
                                // indVar is not a loop variable
                                AddEmbeddedLoopIndices(indVar, expr, rlc, constantLoopVars, embeddedLoopIndices);
                            }
                        }
                    }
                    else
                    {
                        // The index is a compound expression: inspect every variable it mentions.
                        foreach (var indexPart in Recognizer.GetVariables(indExpr))
                        {
                            if (!constantLoopVars.Contains(indexPart))
                            {
                                AddEmbeddedLoopIndices(indexPart, expr, rlc, constantLoopVars, embeddedLoopIndices);
                            }
                        }
                    }
                }
            }

            // Find loop variables that must be constant due to condition statements.
            List<IStatement> ancestors = context.FindAncestors<IStatement>();
            foreach (IStatement ancestor in ancestors)
            {
                if (!(ancestor is IConditionStatement ics))
                    continue;
                ConditionBinding binding = new ConditionBinding(ics.Condition);
                IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(binding.lhs);
                IVariableDeclaration ivd2 = Recognizer.GetVariableDeclaration(binding.rhs);
                int index = rlc.loopVariables.IndexOf(ivd);
                if (index >= 0 && IsConstantWrtLoops(ivd2, constantLoopVars))
                {
                    constantLoopVars.Add(ivd);
                    continue;
                }
                int index2 = rlc.loopVariables.IndexOf(ivd2);
                if (index2 >= 0 && IsConstantWrtLoops(ivd, constantLoopVars))
                {
                    constantLoopVars.Add(ivd2);
                    continue;
                }
            }

            // Determine if this expression is being defined (is on the LHS of an assignment)
            bool isDef = Recognizer.IsBeingMutated(context, expr);

            // NOTE(review): this initial value is never read; it is overwritten by `new Containers(context)`
            // inside the loop before its first use. The lookup is kept so that behaviour is unchanged.
            Containers containers = context.InputAttributes.Get<Containers>(baseVar);

            IExpression originalExpr = expr;

            for (int currentLoop = 0; currentLoop < rlc.loopVariables.Count; currentLoop++)
            {
                IVariableDeclaration loopVar = rlc.loopVariables[currentLoop];
                if (constantLoopVars.Contains(loopVar))
                    continue;
                IForStatement loop = rlc.loops[currentLoop];
                // must replicate across this loop.
                if (isDef)
                {
                    Error("Cannot re-define a variable in a loop.  Variables on the left hand side of an assignment must be indexed by all containing loops.");
                    continue;
                }
                if (embeddedLoopIndices.Contains(currentLoop))
                {
                    string warningText = "This model will consume excess memory due to the indexing expression {0} inside of a loop over {1}. Try simplifying this expression in your model, perhaps by creating auxiliary index arrays.";
                    Warning(string.Format(warningText, originalExpr, loopVar.Name));
                }
                // split expr into a target and extra indices, where target will be replicated and extra indices will be added later
                var extraIndices = new List<IEnumerable<IExpression>>();
                AddUnreplicatedIndices(loop, expr, extraIndices, out IExpression exprToReplicate);

                VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
                IExpression loopSize = Recognizer.LoopSizeExpression(loop);
                IList<IStatement> stmts = Builder.StmtCollection();
                List<IList<IExpression>> inds = Recognizer.GetIndices(exprToReplicate);
                IVariableDeclaration newIndexVar = loopVar;
                // if loopVar is already an indexVar of varInfo, create a new variable
                if (varInfo.HasIndexVar(loopVar))
                {
                    newIndexVar = VariableInformation.GenerateLoopVar(context, "_a");
                    context.InputAttributes.CopyObjectAttributesTo(loopVar, context.OutputAttributes, newIndexVar);
                }
                IVariableDeclaration repVar = varInfo.DeriveArrayVariable(stmts, context, VariableInformation.GenerateName(context, varInfo.Name + "_rep"),
                                                                          loopSize, newIndexVar, inds, useArrays: true);
                if (!context.InputAttributes.Has<DerivedVariable>(repVar))
                    context.OutputAttributes.Set(repVar, new DerivedVariable());
                if (context.InputAttributes.Has<ChannelInfo>(baseVar))
                {
                    VariableInformation repVarInfo = VariableInformation.GetVariableInformation(context, repVar);
                    ChannelInfo ci = ChannelInfo.UseChannel(repVarInfo);
                    ci.decl = repVar;
                    context.OutputAttributes.Set(repVar, ci);
                }

                // Create replicate factor
                Type returnType = Builder.ToType(repVar.VariableType);
                IMethodInvokeExpression repMethod = Builder.StaticGenericMethod(
                    new Func<PlaceHolder, int, PlaceHolder[]>(Clone.Replicate),
                    new Type[] {returnType.GetElementType()}, exprToReplicate, loopSize);

                IExpression assignExpression = Builder.AssignExpr(Builder.VarRefExpr(repVar), repMethod);
                // Copy attributes across from variable to replication expression
                context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, repMethod);
                context.InputAttributes.CopyObjectAttributesTo<DivideMessages>(baseVar, context.OutputAttributes, repMethod);
                context.InputAttributes.CopyObjectAttributesTo<GivePriorityTo>(baseVar, context.OutputAttributes, repMethod);
                stmts.Add(Builder.ExprStatement(assignExpression));

                // add any containers missing from context.
                containers = new Containers(context);
                // RemoveUnusedLoops will also remove conditionals involving those loop variables.
                // TODO: investigate whether removing these conditionals could cause a problem, e.g. when the condition is a conjunction of many terms.
                containers = Containers.RemoveUnusedLoops(containers, context, repMethod);
                if (context.InputAttributes.Has<DoNotSendEvidence>(baseVar)) containers = Containers.RemoveStochasticConditionals(containers, context);
                //Containers shouldBeEmpty = containers.GetContainersNotInContext(context, context.InputStack.Count);
                //if (shouldBeEmpty.inputs.Count > 0) { Error("Internal: Variable is out of scope"); return expr; }
                if (containers.Contains(loop))
                {
                    Error("Internal: invalid containers for replicating " + baseVar);
                    break;
                }
                int ancIndex = containers.GetMatchingAncestorIndex(context);
                Containers missing = containers.GetContainersNotInContext(context, ancIndex);
                stmts = Containers.WrapWithContainers(stmts, missing.inputs);
                context.OutputAttributes.Set(repVar, containers);
                List<IForStatement> loops = context.FindAncestors<IForStatement>(ancIndex);
                foreach (IStatement container in missing.inputs)
                {
                    if (container is IForStatement ifs) loops.Add(ifs);
                }
                context.OutputAttributes.Set(repVar, new LoopContext(loops));
                // must convert the output since it may contain 'if' conditions
                context.AddStatementsBeforeAncestorIndex(ancIndex, stmts, true);
                // The replicated array becomes the base for replication across the next loop.
                baseVar = repVar;
                expr = Builder.ArrayIndex(Builder.VarRefExpr(repVar), Builder.VarRefExpr(loopVar));
                expr = Builder.JaggedArrayIndex(expr, extraIndices);
            }

            return expr;
        }

        /// <summary>
        /// Records, for an index variable that is not itself a loop variable of the reference, the positions
        /// (within <paramref name="rlc"/>) of the loops that the index variable was declared inside.
        /// </summary>
        /// <param name="indexVar">Variable appearing in an index of <paramref name="expr"/>.</param>
        /// <param name="expr">The expression being converted; used only in the error message.</param>
        /// <param name="rlc">Reference loop context of <paramref name="expr"/>.</param>
        /// <param name="constantLoopVars">Loop variables already known to be constant; these are skipped.</param>
        /// <param name="embeddedLoopIndices">Receives the index of every matching loop in <paramref name="rlc"/>.</param>
        private void AddEmbeddedLoopIndices(IVariableDeclaration indexVar, IExpression expr, RefLoopContext rlc,
                                            Set<IVariableDeclaration> constantLoopVars, Set<int> embeddedLoopIndices)
        {
            LoopContext indexLoopContext = context.InputAttributes.Get<LoopContext>(indexVar);
            if (indexLoopContext == null)
            {
                // Report the missing attribute the same way as for the base variable instead of
                // failing with a NullReferenceException.
                Error("Loop context not found for '" + indexVar.Name + "'.");
                return;
            }
            foreach (var loopVar in indexLoopContext.loopVariables)
            {
                if (constantLoopVars.Contains(loopVar))
                    continue;
                int loopIndex = rlc.loopVariables.IndexOf(loopVar);
                if (loopIndex != -1)
                    embeddedLoopIndices.Add(loopIndex);
                else
                    Error($"Index {loopVar} is not in {rlc} for expression {expr}");
            }
        }
Пример #6
0
        /// <summary>
        /// Processes an assignment whose target refers to a variable. For a stochastic variable this attaches a
        /// definition channel on the first assignment, creates the marginal and use channels on the first
        /// assignment that reaches that step, and (unless one of the early returns below applies) adds a
        /// variable-factor assignment statement after the current statement.
        /// </summary>
        /// <param name="target">Left-hand side of the assignment. Nothing is done when no variable declaration is recognized in it.</param>
        /// <param name="rhs">Right-hand side of the assignment.</param>
        /// <param name="shouldDelete">
        /// Set to true when <paramref name="rhs"/> is a <c>Clone.Copy</c> of a variable carrying the same
        /// <c>MarginalPrototype</c> attribute as the target (and the ancestor-index condition holds).
        /// It is never set to false here.
        /// </param>
        protected void ProcessAssign(IExpression target, IExpression rhs, ref bool shouldDelete)
        {
            IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(target);

            if (ivd == null)
            {
                return;
            }
            // An array allocation with no initializer is not treated as a definition, unless every
            // dimension is the literal 0.
            if (rhs is IArrayCreateExpression)
            {
                IArrayCreateExpression iace = (IArrayCreateExpression)rhs;
                bool zeroLength             = iace.Dimensions.All(dimExpr =>
                                                                  (dimExpr is ILiteralExpression) && ((ILiteralExpression)dimExpr).Value.Equals(0));
                if (!zeroLength && iace.Initializer == null)
                {
                    return; // variable will have assignments to elements
                }
            }
            // True only for the first assignment to ivd that reaches this point.
            bool firstTime = !variablesAssigned.Contains(ivd);

            variablesAssigned.Add(ivd);
            bool isInferred   = context.InputAttributes.Has <IsInferred>(ivd);
            bool isStochastic = CodeRecognizer.IsStochastic(context, ivd);

            // Nothing further is done for non-stochastic variables.
            if (!isStochastic)
            {
                return;
            }
            VariableInformation vi            = VariableInformation.GetVariableInformation(context, ivd);
            // Containers attached to the variable, and the part of them that is not in the current context.
            // NOTE(review): defContainers is dereferenced without a null check - confirm the attribute is always present.
            Containers          defContainers = context.InputAttributes.Get <Containers>(ivd);
            int        ancIndex = defContainers.GetMatchingAncestorIndex(context);
            Containers missing  = defContainers.GetContainersNotInContext(context, ancIndex);
            // definition of a stochastic variable
            IExpression lhs = target;

            // A declaration expression is replaced by a plain reference to the declared variable.
            if (lhs is IVariableDeclarationExpression)
            {
                lhs = Builder.VarRefExpr(ivd);
            }
            IExpression defExpr = lhs;

            // NOTE(review): isStochastic is always true from here on because of the early return above,
            // so this and the later isStochastic tests are redundant.
            if (firstTime && isStochastic)
            {
                // Create a ChannelInfo attribute for use by later transforms, e.g. MessageTransform
                ChannelInfo defChannel = ChannelInfo.DefChannel(vi);
                defChannel.decl = ivd;
                context.OutputAttributes.Set(ivd, defChannel);
            }
            bool       isDerived = context.InputAttributes.Has <DerivedVariable>(ivd);
            // The default algorithm is overridden by an Algorithm attribute on the variable, if any.
            IAlgorithm algorithm = this.algorithmDefault;
            Algorithm  algAttr   = context.InputAttributes.Get <Algorithm>(ivd);

            if (algAttr != null)
            {
                algorithm = algAttr.algorithm;
            }
            // For VMP with derivative messages, a derived variable gets an extra "_deriv" variable on its
            // first assignment; it is recorded on ivd through a DerivMessage attribute.
            if (algorithm is VariationalMessagePassing && ((VariationalMessagePassing)algorithm).UseDerivMessages && isDerived && firstTime)
            {
                vi.DefineAllIndexVars(context);
                IList <IStatement>   stmts     = Builder.StmtCollection();
                IVariableDeclaration derivDecl = vi.DeriveIndexedVariable(stmts, context, ivd.Name + "_deriv");
                context.OutputAttributes.Set(ivd, new DerivMessage(derivDecl));
                ChannelInfo derivChannel = ChannelInfo.DefChannel(vi);
                derivChannel.decl = derivDecl;
                context.OutputAttributes.Set(derivChannel.decl, derivChannel);
                context.OutputAttributes.Set(derivChannel.decl, new DescriptionAttribute("deriv of '" + ivd.Name + "'"));
                // Add the declarations
                stmts = Containers.WrapWithContainers(stmts, missing.outputs);
                context.AddStatementsBeforeAncestorIndex(ancIndex, stmts);
            }
            bool isPointEstimate = context.InputAttributes.Has <PointEstimate>(ivd);

            // Variables that the analysis excludes from getting a variable factor are recorded and
            // mapped to themselves as their own use channel; no factor is added for them here.
            if (this.analysis.variablesExcludingVariableFactor.Contains(ivd))
            {
                this.variablesLackingVariableFactor.Add(ivd);
                // ivd will get a marginal channel in ConvertMethodInvoke
                useOfVariable[ivd] = ivd;
                return;
            }
            // A derived variable gets no variable factor unless it is inferred or a point estimate.
            if (isDerived && !isInferred && !isPointEstimate)
            {
                return;
            }

            // NOTE(review): useExpr2 is only assigned inside the disabled makeClone branch below,
            // so it is always null when it is read at the end of the method.
            IExpression useExpr2 = null;

            if (firstTime)
            {
                // create marginal and use channels
                vi.DefineAllIndexVars(context);
                IList <IStatement> stmts = Builder.StmtCollection();

                CreateMarginalChannel(ivd, vi, stmts);
                if (isStochastic)
                {
                    CreateUseChannel(ivd, vi, stmts);
                    // The use channel is given the same containers as the variable itself.
                    context.InputAttributes.Set(useOfVariable[ivd], defContainers);
                }

                // Add the declarations
                stmts = Containers.WrapWithContainers(stmts, missing.outputs);
                context.AddStatementsBeforeAncestorIndex(ancIndex, stmts);
            }
            if (isStochastic && !useOfVariable.ContainsKey(ivd))
            {
                Error("cannot find use channel of " + ivd);
                return;
            }
            // The marginal and use expressions are the left-hand side with ivd swapped for the respective channel variable.
            // NOTE(review): marginalOfVariable[ivd] is presumably populated by CreateMarginalChannel - confirm;
            // unlike useOfVariable it is indexed without a ContainsKey check.
            IExpression  marginalExpr = Builder.ReplaceVariable(lhs, ivd, marginalOfVariable[ivd]);
            IExpression  useExpr      = isStochastic ? Builder.ReplaceVariable(lhs, ivd, useOfVariable[ivd]) : marginalExpr;
            InitialiseTo it           = context.InputAttributes.Get <InitialiseTo>(ivd);

            // Generic type argument of the variable factor: the expression type of the definition.
            Type[] genArgs = new Type[] { defExpr.GetExpressionType() };
            if (rhs is IMethodInvokeExpression)
            {
                IMethodInvokeExpression imie = (IMethodInvokeExpression)rhs;
                if (Recognizer.IsStaticGenericMethod(imie, new Func <PlaceHolder, PlaceHolder>(Clone.Copy)) && ancIndex < context.InputStack.Count - 2)
                {
                    IExpression          arg  = imie.Arguments[0];
                    IVariableDeclaration ivd2 = Recognizer.GetVariableDeclaration(arg);
                    if (ivd2 != null && context.InputAttributes.Get <MarginalPrototype>(ivd) == context.InputAttributes.Get <MarginalPrototype>(ivd2))
                    {
                        // if a variable is a copy, use the original expression since it will give more precise dependencies.
                        defExpr      = arg;
                        shouldDelete = true;
                        // NOTE(review): makeClone is hard-coded to false, so the block below never executes.
                        bool makeClone = false;
                        if (makeClone)
                        {
                            VariableInformation         vi2      = VariableInformation.GetVariableInformation(context, ivd2);
                            IList <IStatement>          stmts    = Builder.StmtCollection();
                            List <IList <IExpression> > indices  = Recognizer.GetIndices(defExpr);
                            IVariableDeclaration        useDecl2 = vi2.DeriveIndexedVariable(stmts, context, ivd2.Name + "_use", indices);
                            useExpr2 = Builder.VarRefExpr(useDecl2);
                            Containers defContainers2 = context.InputAttributes.Get <Containers>(ivd2);
                            int        ancIndex2      = defContainers2.GetMatchingAncestorIndex(context);
                            Containers missing2       = defContainers2.GetContainersNotInContext(context, ancIndex2);
                            stmts = Containers.WrapWithContainers(stmts, missing2.outputs);
                            context.AddStatementsBeforeAncestorIndex(ancIndex2, stmts);
                            context.InputAttributes.Set(useDecl2, defContainers2);

                            // TODO: call CreateUseChannel
                            ChannelInfo usageChannel = ChannelInfo.UseChannel(vi2);
                            usageChannel.decl = useDecl2;
                            context.InputAttributes.CopyObjectAttributesTo <InitialiseTo>(vi.declaration, context.OutputAttributes, useDecl2);
                            context.InputAttributes.CopyObjectAttributesTo <DerivMessage>(vi.declaration, context.OutputAttributes, useDecl2);
                            context.OutputAttributes.Set(useDecl2, usageChannel);
                            //context.OutputAttributes.Set(useDecl2, new DescriptionAttribute("use of '" + ivd.Name + "'"));
                            context.OutputAttributes.Remove <InitialiseTo>(vi.declaration);

                            IExpression copyExpr = Builder.StaticGenericMethod(
                                new Func <PlaceHolder, PlaceHolder>(Clone.Copy), genArgs, useExpr2);
                            var copyStmt = Builder.AssignStmt(useExpr, copyExpr);
                            context.AddStatementAfterCurrent(copyStmt);
                        }
                    }
                }
            }

            // Add the variable factor
            IExpression variableFactorExpr;
            bool        isGateExitRandom = context.InputAttributes.Has <VariationalMessagePassing.GateExitRandomVariable>(ivd);

            if (isGateExitRandom)
            {
                // Variables marked GateExitRandomVariable use Gate.ExitingVariable as their factor.
                variableFactorExpr = Builder.StaticGenericMethod(
                    new Models.FuncOut <PlaceHolder, PlaceHolder, PlaceHolder>(Gate.ExitingVariable),
                    genArgs, defExpr, marginalExpr);
            }
            else
            {
                // The algorithm picks the factor from (isDerived, has initialiser); a point estimate overrides it.
                Delegate d = algorithm.GetVariableFactor(isDerived, it != null);
                if (isPointEstimate)
                {
                    d = new Models.FuncOut <PlaceHolder, PlaceHolder, PlaceHolder>(Clone.VariablePoint);
                }
                if (it == null)
                {
                    variableFactorExpr = Builder.StaticGenericMethod(d, genArgs, defExpr, marginalExpr);
                }
                else
                {
                    // With an InitialiseTo attribute, the initial-messages expression (substituted for the
                    // variable reference in lhs) is passed as an extra argument.
                    IExpression initExpr = Builder.ReplaceExpression(lhs, Builder.VarRefExpr(ivd), it.initialMessagesExpression);
                    variableFactorExpr = Builder.StaticGenericMethod(d, genArgs, defExpr, initExpr, marginalExpr);
                }
            }
            context.InputAttributes.CopyObjectAttributesTo <GivePriorityTo>(ivd, context.OutputAttributes, variableFactorExpr);
            context.InputAttributes.CopyObjectAttributesTo <Algorithm>(ivd, context.OutputAttributes, variableFactorExpr);
            if (isStochastic)
            {
                context.OutputAttributes.Set(variableFactorExpr, new IsVariableFactor());
            }
            // The factor's result is assigned to the use channel and the statement is placed after the current one.
            var assignStmt = Builder.AssignStmt(useExpr2 == null ? useExpr : useExpr2, variableFactorExpr);

            context.AddStatementAfterCurrent(assignStmt);
        }
Пример #7
0
        /// <summary>
        /// Converts a variable declaration by creating definition, marginal and uses channel variables.
        /// </summary>
        protected override IExpression ConvertVariableDeclExpr(IVariableDeclarationExpression ivde)
        {
            IVariableDeclaration ivd = ivde.Variable;
            VariableInformation  vi  = VariableInformation.GetVariableInformation(context, ivd);

            // If the variable is deterministic, return
            if (!vi.IsStochastic)
            {
                ProcessConstant(ivd);
                context.OutputAttributes.Set(ivd, new DescriptionAttribute("The constant '" + ivd.Name + "'"));
                return(ivde);
            }
            bool suppressVariableFactor = context.InputAttributes.Has <SuppressVariableFactor>(ivd);
            bool isDerived  = context.InputAttributes.Has <DerivedVariable>(ivd);
            bool isConstant = false;
            bool isInferred = context.InputAttributes.Has <IsInferred>(ivd);
            int  useCount;

            ChannelAnalysisTransform.UsageInfo info;
            if (!analysis.usageInfo.TryGetValue(ivd, out info))
            {
                useCount = 0;
            }
            else
            {
                useCount = info.NumberOfUsesOld;
            }
            if (!(algorithm is Algorithms.GibbsSampling) && !isConstant && !suppressVariableFactor && (useCount == 1) && !isInferred && isDerived)
            {
                // this is optional
                suppressVariableFactor = true;
                context.InputAttributes.Set(ivd, new SuppressVariableFactor());
            }
            context.InputAttributes.Remove <LoopContext>(ivd);
            context.InputAttributes.Set(ivd, new LoopContext(context));

            // Create variable-to-channel information for the variable.
            VariableToChannelInformation vtc = new VariableToChannelInformation();

            vtc.shareAllUses = (useCount == 1);
            Context.InputAttributes.Set(ivd, vtc);

            // Ensure the marginal prototype is set.
            MarginalPrototype mpa = Context.InputAttributes.Get <MarginalPrototype>(ivd);

            try
            {
                vi.SetMarginalPrototypeFromAttribute(mpa);
            }
            catch (ArgumentException ex)
            {
                Error(ex.Message);
            }

            // Create the definition channel
            vtc.defChannel      = ChannelInfo.DefChannel(vi);
            vtc.defChannel.decl = ivd;
            // Always create a variable factor for a stochastic variable
            if (!isConstant && !suppressVariableFactor)
            {
                vi.DefineAllIndexVars(context);
                IList <IStatement> stmts = Builder.StmtCollection();

                // Create marginal channel
                vtc.marginalChannel      = ChannelInfo.MarginalChannel(vi);
                vtc.marginalChannel.decl = vi.DeriveIndexedVariable(stmts, context, vi.Name + "_marginal");
                context.InputAttributes.CopyObjectAttributesTo <InitialiseTo>(vi.declaration, context.OutputAttributes, vtc.marginalChannel.decl);
                context.OutputAttributes.Set(vtc.marginalChannel.decl, vtc.marginalChannel);
                context.OutputAttributes.Set(vtc.marginalChannel.decl, new DescriptionAttribute("marginal of '" + ivd.Name + "'"));
                SetMarginalPrototype(vtc.marginalChannel.decl);
                if (algorithm is GibbsSampling && ((GibbsSampling)algorithm).UseSideChannels)
                {
                    Type marginalType = MessageTransform.GetDistributionType(vi.varType, vi.InnermostElementType,
                                                                             vi.marginalPrototypeExpression.GetExpressionType(), true);
                    Type domainType = ivd.VariableType.DotNetType;

                    vtc.samplesChannel      = ChannelInfo.MarginalChannel(vi);
                    vtc.samplesChannel.decl = vi.DeriveIndexedVariable(stmts, context, vi.Name + "_samples");
                    context.OutputAttributes.Remove <InitialiseTo>(vtc.samplesChannel.decl);
                    context.OutputAttributes.Set(vtc.samplesChannel.decl, vtc.samplesChannel);
                    context.OutputAttributes.Set(vtc.samplesChannel.decl, new DescriptionAttribute("samples of '" + ivd.Name + "'"));
                    Type                samplesType = typeof(List <>).MakeGenericType(domainType);
                    IExpression         samples_mpe = Builder.NewObject(samplesType);
                    VariableInformation samples_vi  = VariableInformation.GetVariableInformation(context, vtc.samplesChannel.decl);
                    samples_vi.marginalPrototypeExpression = samples_mpe;

                    vtc.conditionalsChannel      = ChannelInfo.MarginalChannel(vi);
                    vtc.conditionalsChannel.decl = vi.DeriveIndexedVariable(stmts, context, vi.Name + "_conditionals");
                    context.OutputAttributes.Remove <InitialiseTo>(vtc.conditionalsChannel.decl);
                    context.OutputAttributes.Set(vtc.conditionalsChannel.decl, vtc.conditionalsChannel);
                    context.OutputAttributes.Set(vtc.conditionalsChannel.decl, new DescriptionAttribute("conditionals of '" + ivd.Name + "'"));
                    Type                conditionalsType = typeof(List <>).MakeGenericType(marginalType);
                    IExpression         conditionals_mpe = Builder.NewObject(conditionalsType);
                    VariableInformation conditionals_vi  = VariableInformation.GetVariableInformation(context, vtc.conditionalsChannel.decl);
                    conditionals_vi.marginalPrototypeExpression = conditionals_mpe;
                }
                else
                {
                    vtc.samplesChannel      = vtc.marginalChannel;
                    vtc.conditionalsChannel = vtc.marginalChannel;
                }

                // Create uses channel
                vtc.usageChannel      = ChannelInfo.UseChannel(vi);
                vtc.usageChannel.decl = vi.DeriveArrayVariable(stmts, context, vi.Name + "_uses", Builder.LiteralExpr(useCount), Builder.VarDecl("_ind", typeof(int)), useLiteralIndices: true);
                context.InputAttributes.CopyObjectAttributesTo <InitialiseTo>(vi.declaration, context.OutputAttributes, vtc.usageChannel.decl);
                context.OutputAttributes.Set(vtc.usageChannel.decl, vtc.usageChannel);
                context.OutputAttributes.Set(vtc.usageChannel.decl, new DescriptionAttribute("uses of '" + ivd.Name + "'"));
                SetMarginalPrototype(vtc.usageChannel.decl);

                //setAllGroupRoots(context, ivd, false);

                context.AddStatementsBeforeCurrent(stmts);

                // Append usageDepth indices to def/marginal/use expressions
                IExpression defExpr      = Builder.VarRefExpr(ivd);
                IExpression marginalExpr = Builder.VarRefExpr(vtc.marginalChannel.decl);
                IExpression usageExpr    = Builder.VarRefExpr(vtc.usageChannel.decl);
                IExpression countExpr    = Builder.LiteralExpr(useCount);

                // Add clone factor tying together all of the channels
                IMethodInvokeExpression usesEqualDefExpression;
                Type[] genArgs = new Type[] { vi.varType };
                if (algorithm is GibbsSampling && ((GibbsSampling)algorithm).UseSideChannels)
                {
                    GibbsSampling gs               = (GibbsSampling)algorithm;
                    IExpression   burnInExpr       = Builder.LiteralExpr(gs.BurnIn);
                    IExpression   thinExpr         = Builder.LiteralExpr(gs.Thin);
                    IExpression   samplesExpr      = Builder.VarRefExpr(vtc.samplesChannel.decl);
                    IExpression   conditionalsExpr = Builder.VarRefExpr(vtc.conditionalsChannel.decl);
                    if (isDerived)
                    {
                        Delegate d = new FuncOut3 <PlaceHolder, int, int, int, PlaceHolder, PlaceHolder, PlaceHolder, PlaceHolder[]>(Factor.ReplicateWithMarginalGibbs);
                        usesEqualDefExpression = Builder.StaticGenericMethod(d, genArgs, defExpr, countExpr, burnInExpr, thinExpr, marginalExpr, samplesExpr, conditionalsExpr);
                    }
                    else
                    {
                        Delegate d = new FuncOut3 <PlaceHolder, int, int, int, PlaceHolder, PlaceHolder, PlaceHolder, PlaceHolder[]>(Factor.UsesEqualDefGibbs);
                        usesEqualDefExpression = Builder.StaticGenericMethod(d, genArgs, defExpr, countExpr, burnInExpr, thinExpr, marginalExpr, samplesExpr, conditionalsExpr);
                    }
                }
                else
                {
                    Delegate d;
                    if (isDerived)
                    {
                        d = new FuncOut <PlaceHolder, int, PlaceHolder, PlaceHolder[]>(Factor.ReplicateWithMarginal <PlaceHolder>);
                    }
                    else
                    {
                        d = new FuncOut <PlaceHolder, int, PlaceHolder, PlaceHolder[]>(Factor.UsesEqualDef <PlaceHolder>);
                    }
                    usesEqualDefExpression = Builder.StaticGenericMethod(d, genArgs, defExpr, countExpr, marginalExpr);
                }
                if (isDerived)
                {
                    context.OutputAttributes.Set(usesEqualDefExpression, new DerivedVariable());            // used by Gibbs
                }
                // Mark this as a pseudo-factor
                context.OutputAttributes.Set(usesEqualDefExpression, new IsVariableFactor());
                if (useCount == 1)
                {
                    context.OutputAttributes.Set(usesEqualDefExpression, new DivideMessages(false));
                }
                else
                {
                    context.InputAttributes.CopyObjectAttributesTo <DivideMessages>(ivd, context.OutputAttributes, usesEqualDefExpression);
                }
                context.InputAttributes.CopyObjectAttributesTo <GivePriorityTo>(ivd, context.OutputAttributes, usesEqualDefExpression);
                IAssignExpression assignExpr = Builder.AssignExpr(usageExpr, usesEqualDefExpression);

                // Copy attributes across from input to output
                Context.InputAttributes.CopyObjectAttributesTo <Algorithm>(ivd, context.OutputAttributes, assignExpr);
                context.OutputAttributes.Remove <InitialiseTo>(ivd);
                if (vi.ArrayDepth == 0)
                {
                    // Insert the UsesEqualDef statement after the declaration.
                    // Note the variable will not have been defined yet.
                    context.AddStatementAfterCurrent(Builder.ExprStatement(assignExpr));
                }
                else
                {
                    // For an array, the UsesEqualDef statement should be inserted after the array is allocated.
                    // Store the statement for later use by ConvertArrayCreate.
                    context.InputAttributes.Remove <LoopContext>(ivd);
                    context.InputAttributes.Set(ivd, new LoopContext(context));
                    context.InputAttributes.Remove <Containers>(ivd);
                    context.InputAttributes.Set(ivd, new Containers(context));
                    vtc.usesEqualDefsStatements = Builder.StmtCollection();
                    vtc.usesEqualDefsStatements.Add(Builder.ExprStatement(assignExpr));
                }
            }
            // These must be set after the above or they will be copied to the other channels
            context.OutputAttributes.Set(ivd, vtc.defChannel);
            context.OutputAttributes.Set(vtc.defChannel.decl, new DescriptionAttribute("definition of '" + ivd.Name + "'"));
            return(ivde);
        }
Пример #8
0
        /// <summary>
        /// Routes a reference to a stochastic variable through a chain of "_rpt" variables, one per
        /// repeat block listed in the reference repeat context, each assigned from
        /// <c>PowerPlate.Enter</c> applied to the previous link and that block's repeat count.
        /// </summary>
        /// <param name="expr">The variable reference expression to convert.</param>
        /// <returns>
        /// <paramref name="expr"/> unchanged when no variable declaration is found for it, the variable
        /// is not stochastic, it is a cut variable, it has no <c>RepeatContext</c> attribute (an error
        /// is reported in that case), or the reference repeat context has no repeats. Otherwise a
        /// reference to the last replicated variable that was created.
        /// </returns>
        protected IExpression ConvertWithReplication(IExpression expr)
        {
            IVariableDeclaration baseVar = Recognizer.GetVariableDeclaration(expr);

            // Nothing to do when there is no variable declaration behind the expression (e.g. an index
            // local variable), when the variable is not stochastic, or when it is a cut variable.
            if (baseVar == null || !CodeRecognizer.IsStochastic(context, baseVar) || cutVariables.Contains(baseVar))
            {
                return expr;
            }

            // Repeat context attached to the variable.
            RepeatContext declaredContext = context.InputAttributes.Get<RepeatContext>(baseVar);
            if (declaredContext == null)
            {
                Error("Repeat context not found for '" + baseVar.Name + "'.");
                return expr;
            }

            // Repeat context of this particular reference. When it lists no repeats the reference
            // sits in the same context as the declaration and is left untouched.
            var referenceContext = declaredContext.GetReferenceRepeatContext(context);
            if (referenceContext.repeats.Count == 0)
            {
                return expr;
            }

            // True when the expression is being mutated (it is on the LHS of an assignment).
            bool isDefinition = Recognizer.IsBeingMutated(context, expr);

            Containers scope = context.InputAttributes.Get<Containers>(baseVar);

            for (int level = 0; level < referenceContext.repeatCounts.Count; level++)
            {
                IExpression      count      = referenceContext.repeatCounts[level];
                IRepeatStatement repeatStmt = referenceContext.repeats[level];

                // A definition cannot be replicated across this repeat: report it and move on.
                if (isDefinition)
                {
                    Error("Cannot re-define a variable in a repeat block.");
                    continue;
                }

                // Report conditions (arguments of Gate.Cases / Gate.CasesInt) that would be replicated.
                IMethodInvokeExpression enclosingCall = context.FindAncestor<IMethodInvokeExpression>();
                if (enclosingCall != null)
                {
                    if (Recognizer.IsStaticMethod(enclosingCall, typeof(Gate), "Cases"))
                    {
                        Error("'if(" + expr + ")' should be placed outside 'repeat(" + count + ")'");
                    }
                    if (Recognizer.IsStaticMethod(enclosingCall, typeof(Gate), "CasesInt"))
                    {
                        Error("'case(" + expr + ")' or 'switch(" + expr + ")' should be placed outside 'repeat(" + count + ")'");
                    }
                }

                VariableInformation baseInfo = VariableInformation.GetVariableInformation(context, baseVar);
                IList<IStatement>   pending  = Builder.StmtCollection();

                // Declare the replicated variable, indexed like the incoming expression.
                List<IList<IExpression>> indices    = Recognizer.GetIndices(expr);
                string                   rptName    = VariableInformation.GenerateName(context, baseInfo.Name + "_rpt");
                IVariableDeclaration     replicated = baseInfo.DeriveIndexedVariable(pending, context, rptName, indices);
                if (!context.InputAttributes.Has<DerivedVariable>(replicated))
                {
                    context.OutputAttributes.Set(replicated, new DerivedVariable());
                }
                if (context.InputAttributes.Has<ChannelInfo>(baseVar))
                {
                    VariableInformation replicatedInfo = VariableInformation.GetVariableInformation(context, replicated);
                    ChannelInfo         channel        = ChannelInfo.UseChannel(replicatedInfo);
                    channel.decl = replicated;
                    context.OutputAttributes.Set(replicated, channel);
                }

                // The RepeatContext of the replicated variable covers every repeat up to and including
                // this one, so that it doesn't get Entered again.
                List<IRepeatStatement> enteredRepeats = new List<IRepeatStatement>(declaredContext.repeats);
                for (int k = 0; k <= level; k++)
                {
                    enteredRepeats.Add(referenceContext.repeats[k]);
                }
                context.OutputAttributes.Remove<RepeatContext>(replicated);
                context.OutputAttributes.Set(replicated, new RepeatContext(enteredRepeats));

                // Build the replicate factor: replicated = PowerPlate.Enter(expr, count)
                Type returnType = Builder.ToType(replicated.VariableType);
                Func<PlaceHolder, double, PlaceHolder> enterFactor = new Func<PlaceHolder, double, PlaceHolder>(PowerPlate.Enter);
                Type[] typeArgs = new Type[] { returnType };
                IMethodInvokeExpression enterCall = Builder.StaticGenericMethod(enterFactor, typeArgs, expr, count);

                IExpression replicatedRef = Builder.VarRefExpr(replicated);
                IExpression assignment    = Builder.AssignExpr(replicatedRef, enterCall);
                // Carry selected attributes of the variable over to the replication expression.
                context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, enterCall);
                context.InputAttributes.CopyObjectAttributesTo<DivideMessages>(baseVar, context.OutputAttributes, enterCall);
                context.InputAttributes.CopyObjectAttributesTo<GivePriorityTo>(baseVar, context.OutputAttributes, enterCall);
                pending.Add(Builder.ExprStatement(assignment));

                // Start from the containers of the current context and drop the repeats nested
                // inside this one.
                scope = new Containers(context);
                for (int k = level + 1; k < referenceContext.repeatCounts.Count; k++)
                {
                    scope = scope.RemoveOneRepeat(referenceContext.repeats[k]);
                }
                context.OutputAttributes.Set(replicated, scope);

                // The assignment itself goes outside the current repeat and outside unused loops.
                scope = scope.RemoveOneRepeat(repeatStmt);
                scope = Containers.RemoveUnusedLoops(scope, context, enterCall);
                if (context.InputAttributes.Has<DoNotSendEvidence>(baseVar))
                {
                    scope = Containers.RemoveStochasticConditionals(scope, context);
                }
                //Containers shouldBeEmpty = containers.GetContainersNotInContext(context, context.InputStack.Count);
                //if (shouldBeEmpty.inputs.Count > 0) { Error("Internal: Variable is out of scope"); return expr; }
                int        ancestorIndex = scope.GetMatchingAncestorIndex(context);
                Containers notInContext  = scope.GetContainersNotInContext(context, ancestorIndex);
                pending = Containers.WrapWithContainers(pending, notInContext.inputs);
                // The output must be converted since it may contain 'if' conditions.
                context.AddStatementsBeforeAncestorIndex(ancestorIndex, pending, true);

                // The next repeat level replicates the variable that was just created.
                baseVar = replicated;
                expr    = Builder.VarRefExpr(replicated);
            }

            return expr;
        }
        /// <summary>
        /// Redirects a variable reference to a per-indexing-depth clone of that variable. The clone
        /// (named "&lt;variable&gt;_depth&lt;depth&gt;") is declared and filled by copy statements the
        /// first time it is needed, then cached on the <c>IndexInfo</c> for that depth.
        /// </summary>
        /// <param name="expr">The variable reference expression, possibly indexed.</param>
        /// <returns>
        /// <paramref name="expr"/> unchanged when the current expression is being indexed, no variable
        /// declaration or depth information is found, the variable has at most one use or a single
        /// indexing depth, it is an evidence variable and <c>cloneEvidenceVars</c> is false, or no
        /// clone exists yet and the depth equals the definition depth. Otherwise the clone indexed
        /// with the indices of <paramref name="expr"/>.
        /// </returns>
        private IExpression GetClone(IExpression expr)
        {
            // This check must run before base.ConvertArrayIndexer changes the expression.
            if (Recognizer.IsBeingIndexed(context))
            {
                return expr;
            }

            IVariableDeclaration baseVar = Recognizer.GetVariableDeclaration(expr);
            DepthInfo            depthInfo;

            // Skip anything that is not a variable reference (e.g. an indexed argument reference)
            // and variables for which the analysis recorded no depth information.
            if (baseVar == null || !analysis.depthInfos.TryGetValue(baseVar, out depthInfo))
            {
                return expr;
            }
            if (depthInfo.useCount <= 1 || depthInfo.indexInfoOfDepth.Count == 1)
            {
                return expr;
            }
            // Evidence variables are only cloned when cloneEvidenceVars is set.
            if (context.InputAttributes.Has<DoNotSendEvidence>(baseVar) && !cloneEvidenceVars)
            {
                return expr;
            }

            int       depth     = Recognizer.GetIndexingDepth(expr);
            IndexInfo indexInfo = depthInfo.indexInfoOfDepth[depth];

            if (indexInfo.clone == null)
            {
                // References at the definition depth are left as they are.
                if (depth == depthInfo.definitionDepth)
                {
                    return expr;
                }
                VariableInformation baseInfo  = VariableInformation.GetVariableInformation(context, baseVar);
                string              cloneName = baseVar.Name + "_depth" + depth;
                IList<IStatement>   pending   = Builder.StmtCollection();

                // Declare the clone.
                baseInfo.DefineAllIndexVars(context);
                IVariableDeclaration cloneDecl = baseInfo.DeriveIndexedVariable(pending, context, cloneName);
                VariableInformation  cloneInfo = VariableInformation.GetVariableInformation(context, cloneDecl);
                cloneInfo.LiteralIndexingDepth = indexInfo.literalIndexingDepth;
                if (!context.InputAttributes.Has<DerivedVariable>(cloneDecl))
                {
                    context.InputAttributes.Set(cloneDecl, new DerivedVariable());
                }
                if (context.InputAttributes.Has<ChannelInfo>(baseVar))
                {
                    ChannelInfo channel = ChannelInfo.UseChannel(baseInfo);
                    channel.decl = cloneDecl;
                    context.OutputAttributes.Set(cloneDecl, channel);
                }

                // Define the clone, indexing by whichever is larger of depth and definitionDepth.
                // e.g.
                // x[i] = definition
                // x_depth0[i] = Copy(x[i])
                // x_depth2[i][j] = Copy(x[i][j])
                int         copyDepth = System.Math.Max(depth, depthInfo.definitionDepth);
                IExpression target    = Builder.VarRefExpr(cloneDecl);
                // TODO: clone the next lower depth, not always the baseVar
                IExpression origin = Builder.VarRefExpr(baseVar);
                AddCopyStatements(pending, cloneInfo, copyDepth, target, origin);

                // Declaring the clone outside of unnecessary loops reduces memory consumption: the item
                // is cloned outside the loop and then replicated, instead of replicating the entire
                // array and cloning the item.
                Containers scope = indexInfo.containers;
                scope = Containers.RemoveUnusedLoops(scope, context, Builder.VarRefExpr(baseVar));
                if (context.InputAttributes.Has<DoNotSendEvidence>(baseVar))
                {
                    scope = Containers.RemoveStochasticConditionals(scope, context);
                }
                // Find the ancestor that already provides as many of the desired containers as
                // possible, then wrap the declaration with the containers that are still missing.
                int        ancestorIndex = scope.GetMatchingAncestorIndex(context);
                Containers notInContext  = scope.GetContainersNotInContext(context, ancestorIndex);
                pending = Containers.WrapWithContainers(pending, notInContext.outputs);
                context.AddStatementsBeforeAncestorIndex(ancestorIndex, pending);
                context.InputAttributes.Set(cloneDecl, scope);
                indexInfo.clone = Builder.VarRefExpr(cloneDecl);
            }

            // Apply the indices of the original expression to the clone.
            List<IList<IExpression>> originalIndices = Recognizer.GetIndices(expr);
            return Builder.JaggedArrayIndex(indexInfo.clone, originalIndices);
        }