/// <summary>
/// Creates the "uses" array channel for a variable and stores it in <paramref name="vtci"/>.
/// </summary>
/// <param name="vi">Information about the variable whose uses array is declared.</param>
/// <param name="useCount">Length of the uses array.</param>
/// <param name="vtci">Channel information to update; its <c>usageChannel</c> is replaced.</param>
/// <param name="stmts">Statement list that receives the declaration of the uses array.</param>
private void CreateUsesChannel(VariableInformation vi, int useCount, VariableToChannelInformation vtci, IList<IStatement> stmts)
{
    var channel = ChannelInfo.UseChannel(vi);
    vtci.usageChannel = channel;

    string arrayName = vi.Name + "_uses";
    var lengthExpr = Builder.LiteralExpr(useCount);
    var indexDecl = Builder.VarDecl("_ind", typeof(int));
    var usesDecl = vi.DeriveArrayVariable(stmts, context, arrayName, lengthExpr, indexDecl, useLiteralIndices: true);
    channel.decl = usesDecl;

    // Carry the user's initialisation over to the new array, then tag it with its channel and a description.
    context.InputAttributes.CopyObjectAttributesTo<InitialiseTo>(vi.declaration, context.OutputAttributes, usesDecl);
    context.OutputAttributes.Set(usesDecl, channel);
    context.OutputAttributes.Set(usesDecl, new DescriptionAttribute($"uses of '{vi.Name}'"));
}
/// <summary>
/// Declares the "use" variable (<c>name_use</c>) for <paramref name="ivd"/> and registers it in <c>useOfVariable</c>.
/// </summary>
/// <param name="ivd">The variable whose use channel is created; used as the key into <c>useOfVariable</c>.</param>
/// <param name="vi">Information about the variable; used to derive the new declaration and its channel.</param>
/// <param name="stmts">Statement list that receives the new declaration.</param>
private void CreateUseChannel(IVariableDeclaration ivd, VariableInformation vi, IList<IStatement> stmts)
{
    IVariableDeclaration useDecl = vi.DeriveIndexedVariable(stmts, context, $"{ivd.Name}_use");
    useOfVariable[ivd] = useDecl;

    var channel = ChannelInfo.UseChannel(vi);
    channel.decl = useDecl;

    var inputs = context.InputAttributes;
    var outputs = context.OutputAttributes;
    var original = vi.declaration;
    inputs.CopyObjectAttributesTo<InitialiseTo>(original, outputs, useDecl);
    inputs.CopyObjectAttributesTo<DerivMessage>(original, outputs, useDecl);
    outputs.Set(useDecl, channel);
    outputs.Set(useDecl, new DescriptionAttribute($"use of '{ivd.Name}'"));
    // The InitialiseTo now lives on the use variable, so drop it from the original declaration's output.
    outputs.Remove<InitialiseTo>(original);

    SetMarginalPrototype(useDecl);
}
/// <summary>
/// Declares a "use" variable (<c>name_use</c>) for <paramref name="vi"/> and attaches a use channel to it.
/// </summary>
/// <param name="vi">Information about the variable whose use channel is created.</param>
/// <param name="stmts">Statement list that receives the new declaration.</param>
/// <returns>The declaration of the new use variable; its <c>VariableInformation</c> is marked stochastic.</returns>
private IVariableDeclaration CreateUseChannel(VariableInformation vi, IList<IStatement> stmts)
{
    IVariableDeclaration useDecl = vi.DeriveIndexedVariable(stmts, context, $"{vi.Name}_use");

    var channel = ChannelInfo.UseChannel(vi);
    channel.decl = useDecl;

    var inputs = context.InputAttributes;
    var outputs = context.OutputAttributes;
    var original = vi.declaration;
    inputs.CopyObjectAttributesTo<InitialiseTo>(original, outputs, useDecl);
    inputs.CopyObjectAttributesTo<DerivMessage>(original, outputs, useDecl);
    outputs.Set(useDecl, channel);
    outputs.Set(useDecl, new DescriptionAttribute($"use of '{vi.Name}'"));
    outputs.Remove<InitialiseTo>(original);

    // AddMarginalStatements needs the use variable to be flagged as stochastic.
    VariableInformation.GetVariableInformation(context, useDecl).IsStochastic = true;
    return useDecl;
}
/// <summary>
/// Declares the "uses" array for a variable, attaches a use channel to it, and registers it in <c>usesOfVariable</c>.
/// </summary>
/// <param name="stmts">Statement list that receives the declaration of the uses array.</param>
/// <param name="ivd">The variable whose uses are declared; used as the key into <c>usesOfVariable</c>.</param>
/// <param name="vi">Information about the variable; supplies its name, sizes and index variables.</param>
/// <param name="useCount">Length of the uses array.</param>
/// <param name="usageDepth">
/// Number of leading index brackets of <paramref name="vi"/> carried over as prefix indices.
/// Every index in those brackets is sized by <c>GateAnalysisTransform.AnyIndex</c>.
/// </param>
/// <returns>A new <see cref="VariableToChannelInformation"/> holding the uses declaration and <paramref name="usageDepth"/>.</returns>
private VariableToChannelInformation DeclareUsesArray(IList<IStatement> stmts, IVariableDeclaration ivd, VariableInformation vi, int useCount, int usageDepth)
{
    // Create AnyIndex expressions up to usageDepth
    List<IList<IExpression>> prefixSizes = new List<IList<IExpression>>();
    List<IList<IExpression>> prefixVars = new List<IList<IExpression>>();
    vi.DefineAllIndexVars(context);
    for (int d = 0; d < usageDepth; d++)
    {
        IList<IExpression> sizeBracket = Builder.ExprCollection();
        IList<IExpression> varBracket = Builder.ExprCollection();
        for (int i = 0; i < vi.sizes[d].Length; i++)
        {
            sizeBracket.Add(Builder.StaticMethod(new Func<int>(GateAnalysisTransform.AnyIndex)));
            varBracket.Add(Builder.VarRefExpr(vi.indexVars[d][i]));
        }
        prefixSizes.Add(sizeBracket);
        prefixVars.Add(varBracket);
    }

    // If vi is itself a use variable ("x_use"), name the array after the base name ("x_uses").
    // Ordinal comparison: this is an identifier suffix, not linguistic text. A culture-sensitive
    // EndsWith can match while ignoring zero-weight characters, and then we would strip the wrong chars.
    const string useSuffix = "_use";
    string prefix = vi.Name;
    if (prefix.EndsWith(useSuffix, StringComparison.Ordinal))
    {
        prefix = prefix.Substring(0, prefix.Length - useSuffix.Length);
    }
    string arrayName = VariableInformation.GenerateName(context, prefix + "_uses");
    IVariableDeclaration usesDecl = vi.DeriveArrayVariable(stmts, context, arrayName, Builder.LiteralExpr(useCount), Builder.VarDecl("_ind", typeof(int)), prefixSizes, prefixVars, useLiteralIndices: true);

    // Remove any ChannelInfo already on usesDecl (presumably copied from vi by DeriveArrayVariable — confirm);
    // a fresh use channel is attached below.
    context.OutputAttributes.Remove<ChannelInfo>(usesDecl);
    context.OutputAttributes.Add(usesDecl, new DescriptionAttribute($"uses of '{vi.Name}'"));
    ChannelInfo ci = ChannelInfo.UseChannel(vi);
    ci.decl = usesDecl;
    context.OutputAttributes.Set(usesDecl, ci);

    VariableToChannelInformation vtci = new VariableToChannelInformation();
    vtci.usesDecl = usesDecl;
    vtci.usageDepth = usageDepth;
    usesOfVariable[ivd] = vtci;
    return vtci;
}
/// <summary>
/// Replicates a reference to a stochastic variable across each enclosing loop that the reference
/// does not already depend on. After replication, the returned expression is indexed by that loop's variable.
/// </summary>
/// <param name="expr">A (possibly indexed) reference to a variable.</param>
/// <returns>
/// <paramref name="expr"/> unchanged when it does not refer to a stochastic variable or needs no replication.
/// Otherwise, a reference of the form <c>rep[loopVar][extraIndices...]</c> into the last <c>_rep</c> variable created.
/// </returns>
/// <remarks>
/// A loop is skipped (no replication) when its loop variable is in <c>constantLoopVars</c>. That set holds:
/// the declaration's own loop variables; loop variables used directly as an index of <paramref name="expr"/>;
/// and loop variables bound by an enclosing condition whose other side satisfies <c>IsConstantWrtLoops</c>.
/// For each remaining loop, a new array variable is declared and defined with <c>Clone.Replicate</c>, and
/// those statements are inserted before a matching ancestor. If the reference is being mutated,
/// an error is reported and that loop is skipped.
/// </remarks>
protected IExpression ConvertWithReplication(IExpression expr)
{
    IVariableDeclaration baseVar = Recognizer.GetVariableDeclaration(expr);
    // Check if this is an index local variable
    if (baseVar == null) return expr;
    // Check if the variable is stochastic
    if (!CodeRecognizer.IsStochastic(context, baseVar)) return expr;
    // Get the loop context for this variable
    LoopContext lc = context.InputAttributes.Get<LoopContext>(baseVar);
    if (lc == null)
    {
        Error("Loop context not found for '" + baseVar.Name + "'.");
        return expr;
    }
    // Get the reference loop context for this expression
    RefLoopContext rlc = lc.GetReferenceLoopContext(context);
    // If the reference is in the same loop context as the declaration, do nothing.
    if (rlc.loops.Count == 0) return expr;
    // the set of loop variables that are constant wrt the expr
    Set<IVariableDeclaration> constantLoopVars = new Set<IVariableDeclaration>();
    constantLoopVars.AddRange(lc.loopVariables);
    // collect set of all loop variable indices in the expression
    // (positions in rlc.loopVariables that appear only indirectly, through a non-loop index variable;
    //  these trigger the excess-memory warning below)
    Set<int> embeddedLoopIndices = new Set<int>();
    List<IList<IExpression>> brackets = Recognizer.GetIndices(expr);
    foreach (IList<IExpression> bracket in brackets)
    {
        foreach (IExpression index in bracket)
        {
            IExpression indExpr = index;
            // For a binary index (e.g. i + 1), only the left operand is inspected.
            if (indExpr is IBinaryExpression ibe)
            {
                indExpr = ibe.Left;
            }
            IVariableDeclaration indVar = Recognizer.GetVariableDeclaration(indExpr);
            if (indVar != null)
            {
                if (!constantLoopVars.Contains(indVar))
                {
                    int loopIndex = rlc.loopVariables.IndexOf(indVar);
                    if (loopIndex != -1)
                    {
                        // indVar is a loop variable
                        constantLoopVars.Add(rlc.loopVariables[loopIndex]);
                    }
                    else
                    {
                        // indVar is not a loop variable
                        // NOTE(review): lc2 is dereferenced without a null check; if indVar has no LoopContext
                        // this throws NullReferenceException — confirm every index variable carries one.
                        LoopContext lc2 = context.InputAttributes.Get<LoopContext>(indVar);
                        foreach (var ivd in lc2.loopVariables)
                        {
                            if (!constantLoopVars.Contains(ivd))
                            {
                                int loopIndex2 = rlc.loopVariables.IndexOf(ivd);
                                if (loopIndex2 != -1)
                                    embeddedLoopIndices.Add(loopIndex2);
                                else
                                    Error($"Index {ivd} is not in {rlc} for expression {expr}");
                            }
                        }
                    }
                }
            }
            else
            {
                // The index is a compound expression: examine every variable it mentions.
                foreach (var ivd in Recognizer.GetVariables(indExpr))
                {
                    if (!constantLoopVars.Contains(ivd))
                    {
                        // copied from above
                        // NOTE(review): same unchecked lc2 dereference as above.
                        LoopContext lc2 = context.InputAttributes.Get<LoopContext>(ivd);
                        foreach (var ivd2 in lc2.loopVariables)
                        {
                            if (!constantLoopVars.Contains(ivd2))
                            {
                                int loopIndex2 = rlc.loopVariables.IndexOf(ivd2);
                                if (loopIndex2 != -1)
                                    embeddedLoopIndices.Add(loopIndex2);
                                else
                                    Error($"Index {ivd2} is not in {rlc} for expression {expr}");
                            }
                        }
                    }
                }
            }
        }
    }
    // Find loop variables that must be constant due to condition statements.
    List<IStatement> ancestors = context.FindAncestors<IStatement>();
    foreach (IStatement ancestor in ancestors)
    {
        if (!(ancestor is IConditionStatement ics))
            continue;
        // If the condition binds a reference loop variable to something constant wrt the loops
        // already collected, that loop variable is fixed inside this branch.
        ConditionBinding binding = new ConditionBinding(ics.Condition);
        IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(binding.lhs);
        IVariableDeclaration ivd2 = Recognizer.GetVariableDeclaration(binding.rhs);
        int index = rlc.loopVariables.IndexOf(ivd);
        if (index >= 0 && IsConstantWrtLoops(ivd2, constantLoopVars))
        {
            constantLoopVars.Add(ivd);
            continue;
        }
        int index2 = rlc.loopVariables.IndexOf(ivd2);
        if (index2 >= 0 && IsConstantWrtLoops(ivd, constantLoopVars))
        {
            constantLoopVars.Add(ivd2);
            continue;
        }
    }
    // Determine if this expression is being defined (is on the LHS of an assignment)
    bool isDef = Recognizer.IsBeingMutated(context, expr);
    Containers containers = context.InputAttributes.Get<Containers>(baseVar);
    // Kept for the warning message, since expr is rewritten on each iteration.
    IExpression originalExpr = expr;
    for (int currentLoop = 0; currentLoop < rlc.loopVariables.Count; currentLoop++)
    {
        IVariableDeclaration loopVar = rlc.loopVariables[currentLoop];
        if (constantLoopVars.Contains(loopVar))
            continue;
        IForStatement loop = rlc.loops[currentLoop];
        // must replicate across this loop.
        if (isDef)
        {
            Error("Cannot re-define a variable in a loop. Variables on the left hand side of an assignment must be indexed by all containing loops.");
            continue;
        }
        if (embeddedLoopIndices.Contains(currentLoop))
        {
            string warningText = "This model will consume excess memory due to the indexing expression {0} inside of a loop over {1}. Try simplifying this expression in your model, perhaps by creating auxiliary index arrays.";
            Warning(string.Format(warningText, originalExpr, loopVar.Name));
        }
        // split expr into a target and extra indices, where target will be replicated and extra indices will be added later
        var extraIndices = new List<IEnumerable<IExpression>>();
        AddUnreplicatedIndices(rlc.loops[currentLoop], expr, extraIndices, out IExpression exprToReplicate);

        // baseVar is reassigned at the end of each iteration, so later loops replicate the previous _rep variable.
        VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
        IExpression loopSize = Recognizer.LoopSizeExpression(loop);
        IList<IStatement> stmts = Builder.StmtCollection();
        List<IList<IExpression>> inds = Recognizer.GetIndices(exprToReplicate);
        IVariableDeclaration newIndexVar = loopVar;
        // if loopVar is already an indexVar of varInfo, create a new variable
        if (varInfo.HasIndexVar(loopVar))
        {
            newIndexVar = VariableInformation.GenerateLoopVar(context, "_a");
            context.InputAttributes.CopyObjectAttributesTo(loopVar, context.OutputAttributes, newIndexVar);
        }
        IVariableDeclaration repVar = varInfo.DeriveArrayVariable(stmts, context, VariableInformation.GenerateName(context, varInfo.Name + "_rep"), loopSize, newIndexVar, inds, useArrays: true);
        if (!context.InputAttributes.Has<DerivedVariable>(repVar))
            context.OutputAttributes.Set(repVar, new DerivedVariable());
        if (context.InputAttributes.Has<ChannelInfo>(baseVar))
        {
            VariableInformation repVarInfo = VariableInformation.GetVariableInformation(context, repVar);
            ChannelInfo ci = ChannelInfo.UseChannel(repVarInfo);
            ci.decl = repVar;
            context.OutputAttributes.Set(repVar, ci);
        }

        // Create replicate factor
        Type returnType = Builder.ToType(repVar.VariableType);
        IMethodInvokeExpression repMethod = Builder.StaticGenericMethod(
            new Func<PlaceHolder, int, PlaceHolder[]>(Clone.Replicate),
            new Type[] {returnType.GetElementType()}, exprToReplicate, loopSize);

        IExpression assignExpression = Builder.AssignExpr(Builder.VarRefExpr(repVar), repMethod);
        // Copy attributes across from variable to replication expression
        context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, repMethod);
        context.InputAttributes.CopyObjectAttributesTo<DivideMessages>(baseVar, context.OutputAttributes, repMethod);
        context.InputAttributes.CopyObjectAttributesTo<GivePriorityTo>(baseVar, context.OutputAttributes, repMethod);
        stmts.Add(Builder.ExprStatement(assignExpression));

        // add any containers missing from context.
        containers = new Containers(context);
        // RemoveUnusedLoops will also remove conditionals involving those loop variables.
        // TODO: investigate whether removing these conditionals could cause a problem, e.g. when the condition is a conjunction of many terms.
        containers = Containers.RemoveUnusedLoops(containers, context, repMethod);
        if (context.InputAttributes.Has<DoNotSendEvidence>(baseVar)) containers = Containers.RemoveStochasticConditionals(containers, context);
        //Containers shouldBeEmpty = containers.GetContainersNotInContext(context, context.InputStack.Count);
        //if (shouldBeEmpty.inputs.Count > 0) { Error("Internal: Variable is out of scope"); return expr; }
        // The replicate statement must sit outside the loop it replicates over.
        if (containers.Contains(loop))
        {
            Error("Internal: invalid containers for replicating " + baseVar);
            break;
        }
        int ancIndex = containers.GetMatchingAncestorIndex(context);
        Containers missing = containers.GetContainersNotInContext(context, ancIndex);
        stmts = Containers.WrapWithContainers(stmts, missing.inputs);
        context.OutputAttributes.Set(repVar, containers);
        // repVar's LoopContext: the for-loops above the insertion point plus any for-loops we wrapped it in.
        List<IForStatement> loops = context.FindAncestors<IForStatement>(ancIndex);
        foreach (IStatement container in missing.inputs)
        {
            if (container is IForStatement ifs) loops.Add(ifs);
        }
        context.OutputAttributes.Set(repVar, new LoopContext(loops));

        // must convert the output since it may contain 'if' conditions
        context.AddStatementsBeforeAncestorIndex(ancIndex, stmts, true);
        baseVar = repVar;
        expr = Builder.ArrayIndex(Builder.VarRefExpr(repVar), Builder.VarRefExpr(loopVar));
        expr = Builder.JaggedArrayIndex(expr, extraIndices);
    }

    return expr;
}
/// <summary>
/// Handles an assignment <c>target = rhs</c> whose target is a stochastic variable.
/// On the first assignment it creates the definition, marginal and use channels.
/// It then inserts, after the current statement, an assignment of the use expression from a variable-factor call.
/// </summary>
/// <param name="target">Left-hand side of the assignment (a variable reference or declaration expression).</param>
/// <param name="rhs">Right-hand side of the assignment.</param>
/// <param name="shouldDelete">
/// Set to true when <paramref name="rhs"/> is <c>Clone.Copy</c> of another variable with the same
/// <c>MarginalPrototype</c> (and <c>ancIndex &lt; InputStack.Count - 2</c>). In that case the copied expression
/// is used directly as the definition.
/// </param>
/// <remarks>
/// Returns without emitting anything when:
/// - the target has no variable declaration;
/// - rhs allocates a non-empty array with no initializer;
/// - the variable is not stochastic;
/// - the variable is in <c>analysis.variablesExcludingVariableFactor</c>;
/// - the variable is derived, not inferred, and not a point estimate;
/// - no use channel can be found (an error is also reported).
/// </remarks>
protected void ProcessAssign(IExpression target, IExpression rhs, ref bool shouldDelete)
{
    IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(target);
    if (ivd == null)
    {
        return;
    }
    if (rhs is IArrayCreateExpression)
    {
        IArrayCreateExpression iace = (IArrayCreateExpression)rhs;
        bool zeroLength = iace.Dimensions.All(dimExpr => (dimExpr is ILiteralExpression) && ((ILiteralExpression)dimExpr).Value.Equals(0));
        if (!zeroLength && iace.Initializer == null)
        {
            return; // variable will have assignments to elements
        }
    }
    bool firstTime = !variablesAssigned.Contains(ivd);
    variablesAssigned.Add(ivd);
    bool isInferred = context.InputAttributes.Has<IsInferred>(ivd);
    bool isStochastic = CodeRecognizer.IsStochastic(context, ivd);
    if (!isStochastic)
    {
        return;
    }
    // NOTE(review): isStochastic is always true from here on, so the later `isStochastic` checks are redundant.
    VariableInformation vi = VariableInformation.GetVariableInformation(context, ivd);
    Containers defContainers = context.InputAttributes.Get<Containers>(ivd);
    int ancIndex = defContainers.GetMatchingAncestorIndex(context);
    Containers missing = defContainers.GetContainersNotInContext(context, ancIndex);
    // definition of a stochastic variable
    IExpression lhs = target;
    if (lhs is IVariableDeclarationExpression)
    {
        lhs = Builder.VarRefExpr(ivd);
    }
    IExpression defExpr = lhs;
    if (firstTime && isStochastic)
    {
        // Create a ChannelInfo attribute for use by later transforms, e.g. MessageTransform
        ChannelInfo defChannel = ChannelInfo.DefChannel(vi);
        defChannel.decl = ivd;
        context.OutputAttributes.Set(ivd, defChannel);
    }
    bool isDerived = context.InputAttributes.Has<DerivedVariable>(ivd);
    // A per-variable Algorithm attribute overrides the default algorithm.
    IAlgorithm algorithm = this.algorithmDefault;
    Algorithm algAttr = context.InputAttributes.Get<Algorithm>(ivd);
    if (algAttr != null)
    {
        algorithm = algAttr.algorithm;
    }
    // VMP with derivative messages: declare a "<name>_deriv" variable and attach it via DerivMessage.
    if (algorithm is VariationalMessagePassing && ((VariationalMessagePassing)algorithm).UseDerivMessages && isDerived && firstTime)
    {
        vi.DefineAllIndexVars(context);
        IList<IStatement> stmts = Builder.StmtCollection();
        IVariableDeclaration derivDecl = vi.DeriveIndexedVariable(stmts, context, ivd.Name + "_deriv");
        context.OutputAttributes.Set(ivd, new DerivMessage(derivDecl));
        ChannelInfo derivChannel = ChannelInfo.DefChannel(vi);
        derivChannel.decl = derivDecl;
        context.OutputAttributes.Set(derivChannel.decl, derivChannel);
        context.OutputAttributes.Set(derivChannel.decl, new DescriptionAttribute("deriv of '" + ivd.Name + "'"));
        // Add the declarations
        stmts = Containers.WrapWithContainers(stmts, missing.outputs);
        context.AddStatementsBeforeAncestorIndex(ancIndex, stmts);
    }
    bool isPointEstimate = context.InputAttributes.Has<PointEstimate>(ivd);
    if (this.analysis.variablesExcludingVariableFactor.Contains(ivd))
    {
        this.variablesLackingVariableFactor.Add(ivd);
        // ivd will get a marginal channel in ConvertMethodInvoke
        useOfVariable[ivd] = ivd;
        return;
    }
    if (isDerived && !isInferred && !isPointEstimate)
    {
        return;
    }
    // Only assigned by the (currently disabled) makeClone branch below.
    IExpression useExpr2 = null;
    if (firstTime)
    {
        // create marginal and use channels
        vi.DefineAllIndexVars(context);
        IList<IStatement> stmts = Builder.StmtCollection();

        CreateMarginalChannel(ivd, vi, stmts);
        if (isStochastic)
        {
            CreateUseChannel(ivd, vi, stmts);
            context.InputAttributes.Set(useOfVariable[ivd], defContainers);
        }

        // Add the declarations
        stmts = Containers.WrapWithContainers(stmts, missing.outputs);
        context.AddStatementsBeforeAncestorIndex(ancIndex, stmts);
    }
    if (isStochastic && !useOfVariable.ContainsKey(ivd))
    {
        Error("cannot find use channel of " + ivd);
        return;
    }
    IExpression marginalExpr = Builder.ReplaceVariable(lhs, ivd, marginalOfVariable[ivd]);
    IExpression useExpr = isStochastic ? Builder.ReplaceVariable(lhs, ivd, useOfVariable[ivd]) : marginalExpr;
    InitialiseTo it = context.InputAttributes.Get<InitialiseTo>(ivd);
    Type[] genArgs = new Type[] { defExpr.GetExpressionType() };
    if (rhs is IMethodInvokeExpression)
    {
        IMethodInvokeExpression imie = (IMethodInvokeExpression)rhs;
        if (Recognizer.IsStaticGenericMethod(imie, new Func<PlaceHolder, PlaceHolder>(Clone.Copy)) && ancIndex < context.InputStack.Count - 2)
        {
            IExpression arg = imie.Arguments[0];
            IVariableDeclaration ivd2 = Recognizer.GetVariableDeclaration(arg);
            if (ivd2 != null && context.InputAttributes.Get<MarginalPrototype>(ivd) == context.InputAttributes.Get<MarginalPrototype>(ivd2))
            {
                // if a variable is a copy, use the original expression since it will give more precise dependencies.
                defExpr = arg;
                shouldDelete = true;
                // Disabled: makeClone is hard-coded to false, so the block below never runs.
                bool makeClone = false;
                if (makeClone)
                {
                    VariableInformation vi2 = VariableInformation.GetVariableInformation(context, ivd2);
                    IList<IStatement> stmts = Builder.StmtCollection();
                    List<IList<IExpression>> indices = Recognizer.GetIndices(defExpr);
                    IVariableDeclaration useDecl2 = vi2.DeriveIndexedVariable(stmts, context, ivd2.Name + "_use", indices);
                    useExpr2 = Builder.VarRefExpr(useDecl2);
                    Containers defContainers2 = context.InputAttributes.Get<Containers>(ivd2);
                    int ancIndex2 = defContainers2.GetMatchingAncestorIndex(context);
                    Containers missing2 = defContainers2.GetContainersNotInContext(context, ancIndex2);
                    stmts = Containers.WrapWithContainers(stmts, missing2.outputs);
                    context.AddStatementsBeforeAncestorIndex(ancIndex2, stmts);
                    context.InputAttributes.Set(useDecl2, defContainers2);

                    // TODO: call CreateUseChannel
                    // NOTE(review): attributes are copied from vi.declaration (ivd), while the channel is built from vi2 — confirm intended.
                    ChannelInfo usageChannel = ChannelInfo.UseChannel(vi2);
                    usageChannel.decl = useDecl2;
                    context.InputAttributes.CopyObjectAttributesTo<InitialiseTo>(vi.declaration, context.OutputAttributes, useDecl2);
                    context.InputAttributes.CopyObjectAttributesTo<DerivMessage>(vi.declaration, context.OutputAttributes, useDecl2);
                    context.OutputAttributes.Set(useDecl2, usageChannel);
                    //context.OutputAttributes.Set(useDecl2, new DescriptionAttribute("use of '" + ivd.Name + "'"));
                    context.OutputAttributes.Remove<InitialiseTo>(vi.declaration);

                    IExpression copyExpr = Builder.StaticGenericMethod(
                        new Func<PlaceHolder, PlaceHolder>(Clone.Copy), genArgs, useExpr2);
                    var copyStmt = Builder.AssignStmt(useExpr, copyExpr);
                    context.AddStatementAfterCurrent(copyStmt);
                }
            }
        }
    }

    // Add the variable factor
    IExpression variableFactorExpr;
    bool isGateExitRandom = context.InputAttributes.Has<VariationalMessagePassing.GateExitRandomVariable>(ivd);
    if (isGateExitRandom)
    {
        variableFactorExpr = Builder.StaticGenericMethod(
            new Models.FuncOut<PlaceHolder, PlaceHolder, PlaceHolder>(Gate.ExitingVariable),
            genArgs, defExpr, marginalExpr);
    }
    else
    {
        Delegate d = algorithm.GetVariableFactor(isDerived, it != null);
        if (isPointEstimate)
        {
            d = new Models.FuncOut<PlaceHolder, PlaceHolder, PlaceHolder>(Clone.VariablePoint);
        }
        if (it == null)
        {
            variableFactorExpr = Builder.StaticGenericMethod(d, genArgs, defExpr, marginalExpr);
        }
        else
        {
            // Rewrite the user's initial-message expression in terms of lhs and pass it as an extra argument.
            IExpression initExpr = Builder.ReplaceExpression(lhs, Builder.VarRefExpr(ivd), it.initialMessagesExpression);
            variableFactorExpr = Builder.StaticGenericMethod(d, genArgs, defExpr, initExpr, marginalExpr);
        }
    }
    context.InputAttributes.CopyObjectAttributesTo<GivePriorityTo>(ivd, context.OutputAttributes, variableFactorExpr);
    context.InputAttributes.CopyObjectAttributesTo<Algorithm>(ivd, context.OutputAttributes, variableFactorExpr);
    if (isStochastic)
    {
        context.OutputAttributes.Set(variableFactorExpr, new IsVariableFactor());
    }
    var assignStmt = Builder.AssignStmt(useExpr2 == null ? useExpr : useExpr2, variableFactorExpr);
    context.AddStatementAfterCurrent(assignStmt);
}
/// <summary>
/// Converts a variable declaration by creating definition, marginal and uses channel variables.
/// </summary>
/// <param name="ivde">The variable declaration expression being converted.</param>
/// <returns><paramref name="ivde"/> itself; the new channel declarations are emitted as side effects on the context.</returns>
/// <remarks>
/// Deterministic variables are passed to <c>ProcessConstant</c> and only receive a description.
/// For a stochastic variable, a <c>VariableToChannelInformation</c> is attached to the input attributes.
/// Unless the variable factor is suppressed, <c>_marginal</c> and <c>_uses</c> variables are declared
/// (plus <c>_samples</c>/<c>_conditionals</c> for Gibbs sampling with side channels), together with a
/// UsesEqualDef/ReplicateWithMarginal statement tying them to the definition.
/// </remarks>
protected override IExpression ConvertVariableDeclExpr(IVariableDeclarationExpression ivde)
{
    IVariableDeclaration ivd = ivde.Variable;
    VariableInformation vi = VariableInformation.GetVariableInformation(context, ivd);
    // If the variable is deterministic, return
    if (!vi.IsStochastic)
    {
        ProcessConstant(ivd);
        context.OutputAttributes.Set(ivd, new DescriptionAttribute("The constant '" + ivd.Name + "'"));
        return(ivde);
    }
    bool suppressVariableFactor = context.InputAttributes.Has<SuppressVariableFactor>(ivd);
    bool isDerived = context.InputAttributes.Has<DerivedVariable>(ivd);
    // NOTE(review): isConstant is never set to true here, so the `!isConstant` tests below always pass.
    bool isConstant = false;
    bool isInferred = context.InputAttributes.Has<IsInferred>(ivd);
    int useCount;
    ChannelAnalysisTransform.UsageInfo info;
    if (!analysis.usageInfo.TryGetValue(ivd, out info))
    {
        useCount = 0;
    }
    else
    {
        useCount = info.NumberOfUsesOld;
    }
    // Outside Gibbs sampling, a derived, non-inferred variable with exactly one use gets its variable factor suppressed.
    if (!(algorithm is Algorithms.GibbsSampling) && !isConstant && !suppressVariableFactor && (useCount == 1) && !isInferred && isDerived)
    {
        // this is optional
        suppressVariableFactor = true;
        context.InputAttributes.Set(ivd, new SuppressVariableFactor());
    }
    // Reset the variable's LoopContext to the current context.
    context.InputAttributes.Remove<LoopContext>(ivd);
    context.InputAttributes.Set(ivd, new LoopContext(context));
    // Create variable-to-channel information for the variable.
    VariableToChannelInformation vtc = new VariableToChannelInformation();
    vtc.shareAllUses = (useCount == 1);
    Context.InputAttributes.Set(ivd, vtc);
    // Ensure the marginal prototype is set.
    MarginalPrototype mpa = Context.InputAttributes.Get<MarginalPrototype>(ivd);
    try
    {
        vi.SetMarginalPrototypeFromAttribute(mpa);
    }
    catch (ArgumentException ex)
    {
        // Report an invalid prototype as a model error rather than aborting the transform.
        Error(ex.Message);
    }
    // Create the definition channel
    vtc.defChannel = ChannelInfo.DefChannel(vi);
    vtc.defChannel.decl = ivd;
    // Create a variable factor unless it has been suppressed (explicitly, or by the single-use rule above)
    if (!isConstant && !suppressVariableFactor)
    {
        vi.DefineAllIndexVars(context);
        IList<IStatement> stmts = Builder.StmtCollection();
        // Create marginal channel
        vtc.marginalChannel = ChannelInfo.MarginalChannel(vi);
        vtc.marginalChannel.decl = vi.DeriveIndexedVariable(stmts, context, vi.Name + "_marginal");
        context.InputAttributes.CopyObjectAttributesTo<InitialiseTo>(vi.declaration, context.OutputAttributes, vtc.marginalChannel.decl);
        context.OutputAttributes.Set(vtc.marginalChannel.decl, vtc.marginalChannel);
        context.OutputAttributes.Set(vtc.marginalChannel.decl, new DescriptionAttribute("marginal of '" + ivd.Name + "'"));
        SetMarginalPrototype(vtc.marginalChannel.decl);
        if (algorithm is GibbsSampling && ((GibbsSampling)algorithm).UseSideChannels)
        {
            // Gibbs side channels: a List<domainType> of samples and a List<marginalType> of conditionals.
            Type marginalType = MessageTransform.GetDistributionType(vi.varType, vi.InnermostElementType, vi.marginalPrototypeExpression.GetExpressionType(), true);
            Type domainType = ivd.VariableType.DotNetType;
            vtc.samplesChannel = ChannelInfo.MarginalChannel(vi);
            vtc.samplesChannel.decl = vi.DeriveIndexedVariable(stmts, context, vi.Name + "_samples");
            context.OutputAttributes.Remove<InitialiseTo>(vtc.samplesChannel.decl);
            context.OutputAttributes.Set(vtc.samplesChannel.decl, vtc.samplesChannel);
            context.OutputAttributes.Set(vtc.samplesChannel.decl, new DescriptionAttribute("samples of '" + ivd.Name + "'"));
            Type samplesType = typeof(List<>).MakeGenericType(domainType);
            IExpression samples_mpe = Builder.NewObject(samplesType);
            VariableInformation samples_vi = VariableInformation.GetVariableInformation(context, vtc.samplesChannel.decl);
            samples_vi.marginalPrototypeExpression = samples_mpe;
            vtc.conditionalsChannel = ChannelInfo.MarginalChannel(vi);
            vtc.conditionalsChannel.decl = vi.DeriveIndexedVariable(stmts, context, vi.Name + "_conditionals");
            context.OutputAttributes.Remove<InitialiseTo>(vtc.conditionalsChannel.decl);
            context.OutputAttributes.Set(vtc.conditionalsChannel.decl, vtc.conditionalsChannel);
            context.OutputAttributes.Set(vtc.conditionalsChannel.decl, new DescriptionAttribute("conditionals of '" + ivd.Name + "'"));
            Type conditionalsType = typeof(List<>).MakeGenericType(marginalType);
            IExpression conditionals_mpe = Builder.NewObject(conditionalsType);
            VariableInformation conditionals_vi = VariableInformation.GetVariableInformation(context, vtc.conditionalsChannel.decl);
            conditionals_vi.marginalPrototypeExpression = conditionals_mpe;
        }
        else
        {
            // Without side channels, samples and conditionals alias the marginal channel.
            vtc.samplesChannel = vtc.marginalChannel;
            vtc.conditionalsChannel = vtc.marginalChannel;
        }
        // Create uses channel
        vtc.usageChannel = ChannelInfo.UseChannel(vi);
        vtc.usageChannel.decl = vi.DeriveArrayVariable(stmts, context, vi.Name + "_uses", Builder.LiteralExpr(useCount), Builder.VarDecl("_ind", typeof(int)), useLiteralIndices: true);
        context.InputAttributes.CopyObjectAttributesTo<InitialiseTo>(vi.declaration, context.OutputAttributes, vtc.usageChannel.decl);
        context.OutputAttributes.Set(vtc.usageChannel.decl, vtc.usageChannel);
        context.OutputAttributes.Set(vtc.usageChannel.decl, new DescriptionAttribute("uses of '" + ivd.Name + "'"));
        SetMarginalPrototype(vtc.usageChannel.decl);
        //setAllGroupRoots(context, ivd, false);
        context.AddStatementsBeforeCurrent(stmts);
        // Plain (unindexed) references to the def, marginal and uses variables
        IExpression defExpr = Builder.VarRefExpr(ivd);
        IExpression marginalExpr = Builder.VarRefExpr(vtc.marginalChannel.decl);
        IExpression usageExpr = Builder.VarRefExpr(vtc.usageChannel.decl);
        IExpression countExpr = Builder.LiteralExpr(useCount);
        // Add clone factor tying together all of the channels
        IMethodInvokeExpression usesEqualDefExpression;
        Type[] genArgs = new Type[] { vi.varType };
        if (algorithm is GibbsSampling && ((GibbsSampling)algorithm).UseSideChannels)
        {
            GibbsSampling gs = (GibbsSampling)algorithm;
            IExpression burnInExpr = Builder.LiteralExpr(gs.BurnIn);
            IExpression thinExpr = Builder.LiteralExpr(gs.Thin);
            IExpression samplesExpr = Builder.VarRefExpr(vtc.samplesChannel.decl);
            IExpression conditionalsExpr = Builder.VarRefExpr(vtc.conditionalsChannel.decl);
            if (isDerived)
            {
                Delegate d = new FuncOut3<PlaceHolder, int, int, int, PlaceHolder, PlaceHolder, PlaceHolder, PlaceHolder[]>(Factor.ReplicateWithMarginalGibbs);
                usesEqualDefExpression = Builder.StaticGenericMethod(d, genArgs, defExpr, countExpr, burnInExpr, thinExpr, marginalExpr, samplesExpr, conditionalsExpr);
            }
            else
            {
                Delegate d = new FuncOut3<PlaceHolder, int, int, int, PlaceHolder, PlaceHolder, PlaceHolder, PlaceHolder[]>(Factor.UsesEqualDefGibbs);
                usesEqualDefExpression = Builder.StaticGenericMethod(d, genArgs, defExpr, countExpr, burnInExpr, thinExpr, marginalExpr, samplesExpr, conditionalsExpr);
            }
        }
        else
        {
            Delegate d;
            if (isDerived)
            {
                d = new FuncOut<PlaceHolder, int, PlaceHolder, PlaceHolder[]>(Factor.ReplicateWithMarginal<PlaceHolder>);
            }
            else
            {
                d = new FuncOut<PlaceHolder, int, PlaceHolder, PlaceHolder[]>(Factor.UsesEqualDef<PlaceHolder>);
            }
            usesEqualDefExpression = Builder.StaticGenericMethod(d, genArgs, defExpr, countExpr, marginalExpr);
        }
        if (isDerived)
        {
            context.OutputAttributes.Set(usesEqualDefExpression, new DerivedVariable()); // used by Gibbs
        }
        // Mark this as a pseudo-factor
        context.OutputAttributes.Set(usesEqualDefExpression, new IsVariableFactor());
        // With a single use, message division is turned off; otherwise any user DivideMessages setting is kept.
        if (useCount == 1)
        {
            context.OutputAttributes.Set(usesEqualDefExpression, new DivideMessages(false));
        }
        else
        {
            context.InputAttributes.CopyObjectAttributesTo<DivideMessages>(ivd, context.OutputAttributes, usesEqualDefExpression);
        }
        context.InputAttributes.CopyObjectAttributesTo<GivePriorityTo>(ivd, context.OutputAttributes, usesEqualDefExpression);
        IAssignExpression assignExpr = Builder.AssignExpr(usageExpr, usesEqualDefExpression);
        // Copy attributes across from input to output
        Context.InputAttributes.CopyObjectAttributesTo<Algorithm>(ivd, context.OutputAttributes, assignExpr);
        context.OutputAttributes.Remove<InitialiseTo>(ivd);
        if (vi.ArrayDepth == 0)
        {
            // Insert the UsesEqualDef statement after the declaration.
            // Note the variable will not have been defined yet.
            context.AddStatementAfterCurrent(Builder.ExprStatement(assignExpr));
        }
        else
        {
            // For an array, the UsesEqualDef statement should be inserted after the array is allocated.
            // Store the statement for later use by ConvertArrayCreate.
            context.InputAttributes.Remove<LoopContext>(ivd);
            context.InputAttributes.Set(ivd, new LoopContext(context));
            context.InputAttributes.Remove<Containers>(ivd);
            context.InputAttributes.Set(ivd, new Containers(context));
            vtc.usesEqualDefsStatements = Builder.StmtCollection();
            vtc.usesEqualDefsStatements.Add(Builder.ExprStatement(assignExpr));
        }
    }
    // These must be set after the above or they will be copied to the other channels
    context.OutputAttributes.Set(ivd, vtc.defChannel);
    context.OutputAttributes.Set(vtc.defChannel.decl, new DescriptionAttribute("definition of '" + ivd.Name + "'"));
    return(ivde);
}
/// <summary>
/// Handles a reference to a stochastic variable that is declared outside one or more repeat blocks enclosing the reference.
/// For each such repeat block, it wraps the reference in <c>PowerPlate.Enter</c>.
/// </summary>
/// <param name="expr">A (possibly indexed) reference to a variable.</param>
/// <returns>
/// <paramref name="expr"/> unchanged when it is not a stochastic variable reference, is in <c>cutVariables</c>,
/// or needs no replication. Otherwise, a reference to the last <c>_rpt</c> variable created.
/// </returns>
/// <remarks>
/// For each repeat, a new variable is declared and assigned <c>PowerPlate.Enter(expr, repeatCount)</c>.
/// Its RepeatContext includes the repeats up to and including the current one.
/// Reports an error (and skips that repeat) if the reference is being mutated. Also reports an error
/// if the nearest enclosing method invocation is <c>Gate.Cases</c> or <c>Gate.CasesInt</c>.
/// </remarks>
protected IExpression ConvertWithReplication(IExpression expr)
{
    IVariableDeclaration baseVar = Recognizer.GetVariableDeclaration(expr);
    // Check if this is an index local variable
    if (baseVar == null)
    {
        return(expr);
    }
    // Check if the variable is stochastic
    if (!CodeRecognizer.IsStochastic(context, baseVar))
    {
        return(expr);
    }
    if (cutVariables.Contains(baseVar))
    {
        return(expr);
    }

    // Get the repeat context for this variable
    RepeatContext lc = context.InputAttributes.Get<RepeatContext>(baseVar);
    if (lc == null)
    {
        Error("Repeat context not found for '" + baseVar.Name + "'.");
        return(expr);
    }

    // Get the reference repeat context for this expression
    var rlc = lc.GetReferenceRepeatContext(context);
    // If the reference is in the same repeat context as the declaration, do nothing.
    if (rlc.repeats.Count == 0)
    {
        return(expr);
    }

    // Determine if this expression is being defined (is on the LHS of an assignment)
    bool isDef = Recognizer.IsBeingMutated(context, expr);

    Containers containers = context.InputAttributes.Get<Containers>(baseVar);

    for (int currentRepeat = 0; currentRepeat < rlc.repeatCounts.Count; currentRepeat++)
    {
        IExpression repeatCount = rlc.repeatCounts[currentRepeat];
        IRepeatStatement repeat = rlc.repeats[currentRepeat];

        // must replicate across this loop.
        if (isDef)
        {
            Error("Cannot re-define a variable in a repeat block.");
            continue;
        }
        // are we replicating the argument of Gate.Cases?
        IMethodInvokeExpression imie = context.FindAncestor<IMethodInvokeExpression>();
        if (imie != null)
        {
            if (Recognizer.IsStaticMethod(imie, typeof(Gate), "Cases"))
            {
                Error("'if(" + expr + ")' should be placed outside 'repeat(" + repeatCount + ")'");
            }
            if (Recognizer.IsStaticMethod(imie, typeof(Gate), "CasesInt"))
            {
                Error("'case(" + expr + ")' or 'switch(" + expr + ")' should be placed outside 'repeat(" + repeatCount + ")'");
            }
        }

        // baseVar is reassigned at the end of each iteration, so later repeats wrap the previous _rpt variable.
        VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
        IList<IStatement> stmts = Builder.StmtCollection();

        List<IList<IExpression>> inds = Recognizer.GetIndices(expr);
        IVariableDeclaration repVar = varInfo.DeriveIndexedVariable(stmts, context, VariableInformation.GenerateName(context, varInfo.Name + "_rpt"), inds);
        if (!context.InputAttributes.Has<DerivedVariable>(repVar))
        {
            context.OutputAttributes.Set(repVar, new DerivedVariable());
        }
        if (context.InputAttributes.Has<ChannelInfo>(baseVar))
        {
            VariableInformation repVarInfo = VariableInformation.GetVariableInformation(context, repVar);
            ChannelInfo ci = ChannelInfo.UseChannel(repVarInfo);
            ci.decl = repVar;
            context.OutputAttributes.Set(repVar, ci);
        }

        // set the RepeatContext of repVar to include all repeats up to this one (so that it doesn't get Entered again)
        List<IRepeatStatement> repeats = new List<IRepeatStatement>(lc.repeats);
        for (int i = 0; i <= currentRepeat; i++)
        {
            repeats.Add(rlc.repeats[i]);
        }
        context.OutputAttributes.Remove<RepeatContext>(repVar);
        context.OutputAttributes.Set(repVar, new RepeatContext(repeats));

        // Create replicate factor
        Type returnType = Builder.ToType(repVar.VariableType);
        IMethodInvokeExpression powerPlateMethod = Builder.StaticGenericMethod(
            new Func<PlaceHolder, double, PlaceHolder>(PowerPlate.Enter),
            new Type[] { returnType }, expr, repeatCount);

        IExpression assignExpression = Builder.AssignExpr(Builder.VarRefExpr(repVar), powerPlateMethod);
        // Copy attributes across from variable to replication expression
        context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, powerPlateMethod);
        context.InputAttributes.CopyObjectAttributesTo<DivideMessages>(baseVar, context.OutputAttributes, powerPlateMethod);
        context.InputAttributes.CopyObjectAttributesTo<GivePriorityTo>(baseVar, context.OutputAttributes, powerPlateMethod);
        stmts.Add(Builder.ExprStatement(assignExpression));

        // add any containers missing from context.
        containers = new Containers(context);
        // remove inner repeats
        for (int i = currentRepeat + 1; i < rlc.repeatCounts.Count; i++)
        {
            containers = containers.RemoveOneRepeat(rlc.repeats[i]);
        }
        // repVar lives inside the current repeat; its defining statement is placed outside it (removed next).
        context.OutputAttributes.Set(repVar, containers);
        containers = containers.RemoveOneRepeat(repeat);
        containers = Containers.RemoveUnusedLoops(containers, context, powerPlateMethod);
        if (context.InputAttributes.Has<DoNotSendEvidence>(baseVar))
        {
            containers = Containers.RemoveStochasticConditionals(containers, context);
        }
        //Containers shouldBeEmpty = containers.GetContainersNotInContext(context, context.InputStack.Count);
        //if (shouldBeEmpty.inputs.Count > 0) { Error("Internal: Variable is out of scope"); return expr; }
        int ancIndex = containers.GetMatchingAncestorIndex(context);
        Containers missing = containers.GetContainersNotInContext(context, ancIndex);
        stmts = Containers.WrapWithContainers(stmts, missing.inputs);
        // must convert the output since it may contain 'if' conditions
        context.AddStatementsBeforeAncestorIndex(ancIndex, stmts, true);
        baseVar = repVar;
        expr = Builder.VarRefExpr(repVar);
    }

    return(expr);
}
/// <summary>
/// Replaces a read of a multiply-used variable with a read of a per-depth clone of that variable,
/// creating the clone (declaration plus copy statements) the first time a given depth is seen.
/// </summary>
/// <param name="expr">A (possibly jagged-indexed) reference to a variable.</param>
/// <returns>
/// <paramref name="expr"/> unchanged when no clone applies; otherwise the clone indexed with the same indices as <paramref name="expr"/>.
/// </returns>
private IExpression GetClone(IExpression expr)
{
    // A variable on the left-hand side of an assignment (a definition) is never redirected.
    // This check must happen before base.ConvertArrayIndexer changes the expression.
    if (Recognizer.IsBeingIndexed(context))
    {
        return expr;
    }

    IVariableDeclaration baseVar = Recognizer.GetVariableDeclaration(expr);
    DepthInfo depthInfo;
    // Leave alone anything that is not an indexed variable reference (e.g. an indexed argument reference),
    // as well as variables for which the analysis recorded no depth information.
    if (baseVar == null || !analysis.depthInfos.TryGetValue(baseVar, out depthInfo))
    {
        return expr;
    }

    // Cloning only pays off when the variable is used more than once at more than one indexing depth.
    // Evidence variables are cloned only when cloneEvidenceVars is enabled.
    if (depthInfo.useCount <= 1 ||
        depthInfo.indexInfoOfDepth.Count == 1 ||
        (context.InputAttributes.Has<DoNotSendEvidence>(baseVar) && !cloneEvidenceVars))
    {
        return expr;
    }

    int depth = Recognizer.GetIndexingDepth(expr);
    IndexInfo info = depthInfo.indexInfoOfDepth[depth];
    if (info.clone == null)
    {
        // Reads at the definition depth keep using the original variable.
        if (depth == depthInfo.definitionDepth)
        {
            return expr;
        }
        info.clone = CreateDepthClone(baseVar, depthInfo, info, depth);
    }
    return Builder.JaggedArrayIndex(info.clone, Recognizer.GetIndices(expr));
}

/// <summary>
/// Declares and defines the clone of <paramref name="baseVar"/> used for reads at indexing depth <paramref name="depth"/>,
/// inserts the resulting statements into the output, and returns a reference to the clone.
/// </summary>
private IExpression CreateDepthClone(IVariableDeclaration baseVar, DepthInfo depthInfo, IndexInfo info, int depth)
{
    VariableInformation sourceInfo = VariableInformation.GetVariableInformation(context, baseVar);
    string cloneName = baseVar.Name + "_depth" + depth;
    IList<IStatement> cloneStmts = Builder.StmtCollection();

    // Declare the clone.
    sourceInfo.DefineAllIndexVars(context);
    IVariableDeclaration cloneDecl = sourceInfo.DeriveIndexedVariable(cloneStmts, context, cloneName);
    VariableInformation cloneInfo = VariableInformation.GetVariableInformation(context, cloneDecl);
    cloneInfo.LiteralIndexingDepth = info.literalIndexingDepth;
    // NOTE(review): DerivedVariable (here) and Containers (below) are written to InputAttributes,
    // while ChannelInfo is written to OutputAttributes. Confirm that this asymmetry is intended.
    if (!context.InputAttributes.Has<DerivedVariable>(cloneDecl))
    {
        context.InputAttributes.Set(cloneDecl, new DerivedVariable());
    }
    if (context.InputAttributes.Has<ChannelInfo>(baseVar))
    {
        ChannelInfo useChannel = ChannelInfo.UseChannel(sourceInfo);
        useChannel.decl = cloneDecl;
        context.OutputAttributes.Set(cloneDecl, useChannel);
    }

    // Define the clone. The copy is indexed to the deeper of the requested depth and the definition depth, e.g.
    //   x[i] = definition
    //   x_depth0[i] = Copy(x[i])
    //   x_depth2[i][j] = Copy(x[i][j])
    int copyDepth = (depth > depthInfo.definitionDepth) ? depth : depthInfo.definitionDepth;
    // TODO: clone the next lower depth, not always the baseVar
    AddCopyStatements(cloneStmts, cloneInfo, copyDepth, Builder.VarRefExpr(cloneDecl), Builder.VarRefExpr(baseVar));

    // Save memory by declaring the clone outside loops it does not need. The item is then cloned
    // once outside the loop and replicated, instead of the whole array being replicated and then cloned.
    Containers containers = Containers.RemoveUnusedLoops(info.containers, context, Builder.VarRefExpr(baseVar));
    if (context.InputAttributes.Has<DoNotSendEvidence>(baseVar))
    {
        containers = Containers.RemoveStochasticConditionals(containers, context);
    }

    // Find the ancestor that already provides as many of the target containers as possible,
    // then wrap the new statements in whichever containers are still missing.
    int ancestorIndex = containers.GetMatchingAncestorIndex(context);
    Containers missing = containers.GetContainersNotInContext(context, ancestorIndex);
    context.AddStatementsBeforeAncestorIndex(ancestorIndex, Containers.WrapWithContainers(cloneStmts, missing.outputs));
    context.InputAttributes.Set(cloneDecl, containers);
    return Builder.VarRefExpr(cloneDecl);
}