/// <summary>
/// Converts a variable declaration by creating definition, marginal and uses channel variables.
/// </summary>
/// <remarks>
/// A non-stochastic declaration is passed to <c>ProcessConstant</c>, given a description attribute
/// and returned. For a stochastic declaration, a <c>VariableToChannelInformation</c> is attached to
/// the declaration; unless the variable factor is suppressed, the marginal and uses channels (and,
/// when Gibbs side channels are enabled, separate samples and conditionals channels) are declared
/// before the current statement and an assignment of the variable factor to the uses channel is
/// either emitted after the current statement (non-array) or stored on the channel information (array).
/// </remarks>
/// <param name="ivde">The variable declaration expression being converted.</param>
/// <returns>The same <paramref name="ivde"/> that was passed in.</returns>
protected override IExpression ConvertVariableDeclExpr(IVariableDeclarationExpression ivde)
{
    IVariableDeclaration varDecl = ivde.Variable;
    VariableInformation varInfo = VariableInformation.GetVariableInformation(context, varDecl);

    // Deterministic variables get no channels: process as a constant and stop.
    if (!varInfo.IsStochastic)
    {
        ProcessConstant(varDecl);
        context.OutputAttributes.Set(varDecl, new DescriptionAttribute("The constant '" + varDecl.Name + "'"));
        return ivde;
    }

    bool suppressVariableFactor = context.InputAttributes.Has<SuppressVariableFactor>(varDecl);
    bool isDerived = context.InputAttributes.Has<DerivedVariable>(varDecl);
    bool isInferred = context.InputAttributes.Has<IsInferred>(varDecl);

    // A variable missing from the usage analysis is treated as having zero uses.
    ChannelAnalysisTransform.UsageInfo usage;
    int useCount = analysis.usageInfo.TryGetValue(varDecl, out usage) ? usage.NumberOfUsesOld : 0;

    // Optional shortcut: a derived, non-inferred variable with exactly one use does not need
    // a variable factor (never applied under Gibbs sampling).
    bool canSkipVariableFactor =
        isDerived
        && !isInferred
        && (useCount == 1)
        && !suppressVariableFactor
        && !(algorithm is Algorithms.GibbsSampling);
    if (canSkipVariableFactor)
    {
        suppressVariableFactor = true;
        context.InputAttributes.Set(varDecl, new SuppressVariableFactor());
    }

    context.InputAttributes.Remove<LoopContext>(varDecl);
    context.InputAttributes.Set(varDecl, new LoopContext(context));

    // Attach the variable-to-channel record to the declaration.
    VariableToChannelInformation channels = new VariableToChannelInformation();
    channels.shareAllUses = (useCount == 1);
    Context.InputAttributes.Set(varDecl, channels);

    // Make sure the marginal prototype is set; an ArgumentException is reported via Error.
    MarginalPrototype prototypeAttr = Context.InputAttributes.Get<MarginalPrototype>(varDecl);
    try
    {
        varInfo.SetMarginalPrototypeFromAttribute(prototypeAttr);
    }
    catch (ArgumentException ex)
    {
        Error(ex.Message);
    }

    // The definition channel reuses the original declaration.
    channels.defChannel = ChannelInfo.DefChannel(varInfo);
    channels.defChannel.decl = varDecl;

    if (!suppressVariableFactor)
    {
        varInfo.DefineAllIndexVars(context);
        IList<IStatement> channelDecls = Builder.StmtCollection();

        // Marginal channel.
        channels.marginalChannel = ChannelInfo.MarginalChannel(varInfo);
        channels.marginalChannel.decl = varInfo.DeriveIndexedVariable(channelDecls, context, varInfo.Name + "_marginal");
        context.InputAttributes.CopyObjectAttributesTo<InitialiseTo>(varInfo.declaration, context.OutputAttributes, channels.marginalChannel.decl);
        context.OutputAttributes.Set(channels.marginalChannel.decl, channels.marginalChannel);
        context.OutputAttributes.Set(channels.marginalChannel.decl, new DescriptionAttribute("marginal of '" + varDecl.Name + "'"));
        SetMarginalPrototype(channels.marginalChannel.decl);

        // Evaluated once here and reused below when the factor expression is built.
        bool useSideChannels = algorithm is GibbsSampling && ((GibbsSampling)algorithm).UseSideChannels;

        if (useSideChannels)
        {
            Type marginalType = MessageTransform.GetDistributionType(
                varInfo.varType,
                varInfo.InnermostElementType,
                varInfo.marginalPrototypeExpression.GetExpressionType(),
                true);
            Type domainType = varDecl.VariableType.DotNetType;

            // Samples channel: its marginal prototype is a new List<domainType>.
            channels.samplesChannel = ChannelInfo.MarginalChannel(varInfo);
            channels.samplesChannel.decl = varInfo.DeriveIndexedVariable(channelDecls, context, varInfo.Name + "_samples");
            context.OutputAttributes.Remove<InitialiseTo>(channels.samplesChannel.decl);
            context.OutputAttributes.Set(channels.samplesChannel.decl, channels.samplesChannel);
            context.OutputAttributes.Set(channels.samplesChannel.decl, new DescriptionAttribute("samples of '" + varDecl.Name + "'"));
            Type samplesListType = typeof(List<>).MakeGenericType(domainType);
            IExpression samplesPrototype = Builder.NewObject(samplesListType);
            VariableInformation samplesInfo = VariableInformation.GetVariableInformation(context, channels.samplesChannel.decl);
            samplesInfo.marginalPrototypeExpression = samplesPrototype;

            // Conditionals channel: its marginal prototype is a new List<marginalType>.
            channels.conditionalsChannel = ChannelInfo.MarginalChannel(varInfo);
            channels.conditionalsChannel.decl = varInfo.DeriveIndexedVariable(channelDecls, context, varInfo.Name + "_conditionals");
            context.OutputAttributes.Remove<InitialiseTo>(channels.conditionalsChannel.decl);
            context.OutputAttributes.Set(channels.conditionalsChannel.decl, channels.conditionalsChannel);
            context.OutputAttributes.Set(channels.conditionalsChannel.decl, new DescriptionAttribute("conditionals of '" + varDecl.Name + "'"));
            Type conditionalsListType = typeof(List<>).MakeGenericType(marginalType);
            IExpression conditionalsPrototype = Builder.NewObject(conditionalsListType);
            VariableInformation conditionalsInfo = VariableInformation.GetVariableInformation(context, channels.conditionalsChannel.decl);
            conditionalsInfo.marginalPrototypeExpression = conditionalsPrototype;
        }
        else
        {
            // Without side channels, samples and conditionals alias the marginal channel.
            channels.samplesChannel = channels.marginalChannel;
            channels.conditionalsChannel = channels.marginalChannel;
        }

        // Uses channel: an array with one element per use.
        channels.usageChannel = ChannelInfo.UseChannel(varInfo);
        channels.usageChannel.decl = varInfo.DeriveArrayVariable(
            channelDecls,
            context,
            varInfo.Name + "_uses",
            Builder.LiteralExpr(useCount),
            Builder.VarDecl("_ind", typeof(int)),
            useLiteralIndices: true);
        context.InputAttributes.CopyObjectAttributesTo<InitialiseTo>(varInfo.declaration, context.OutputAttributes, channels.usageChannel.decl);
        context.OutputAttributes.Set(channels.usageChannel.decl, channels.usageChannel);
        context.OutputAttributes.Set(channels.usageChannel.decl, new DescriptionAttribute("uses of '" + varDecl.Name + "'"));
        SetMarginalPrototype(channels.usageChannel.decl);

        //setAllGroupRoots(context, ivd, false);

        context.AddStatementsBeforeCurrent(channelDecls);

        // References to the def/marginal/uses channels, used as factor arguments.
        IExpression defExpr = Builder.VarRefExpr(varDecl);
        IExpression marginalExpr = Builder.VarRefExpr(channels.marginalChannel.decl);
        IExpression usageExpr = Builder.VarRefExpr(channels.usageChannel.decl);
        IExpression countExpr = Builder.LiteralExpr(useCount);

        // Build the clone factor that ties all of the channels together.
        IMethodInvokeExpression variableFactor;
        Type[] genericArgs = new Type[] { varInfo.varType };
        if (useSideChannels)
        {
            GibbsSampling gibbs = (GibbsSampling)algorithm;
            IExpression burnInExpr = Builder.LiteralExpr(gibbs.BurnIn);
            IExpression thinExpr = Builder.LiteralExpr(gibbs.Thin);
            IExpression samplesExpr = Builder.VarRefExpr(channels.samplesChannel.decl);
            IExpression conditionalsExpr = Builder.VarRefExpr(channels.conditionalsChannel.decl);

            Delegate gibbsFactor;
            if (isDerived)
            {
                gibbsFactor = new FuncOut3<PlaceHolder, int, int, int, PlaceHolder, PlaceHolder, PlaceHolder, PlaceHolder[]>(Factor.ReplicateWithMarginalGibbs);
            }
            else
            {
                gibbsFactor = new FuncOut3<PlaceHolder, int, int, int, PlaceHolder, PlaceHolder, PlaceHolder, PlaceHolder[]>(Factor.UsesEqualDefGibbs);
            }

            variableFactor = Builder.StaticGenericMethod(
                gibbsFactor, genericArgs, defExpr, countExpr, burnInExpr, thinExpr, marginalExpr, samplesExpr, conditionalsExpr);
        }
        else
        {
            Delegate plainFactor = isDerived
                ? new FuncOut<PlaceHolder, int, PlaceHolder, PlaceHolder[]>(Factor.ReplicateWithMarginal<PlaceHolder>)
                : new FuncOut<PlaceHolder, int, PlaceHolder, PlaceHolder[]>(Factor.UsesEqualDef<PlaceHolder>);
            variableFactor = Builder.StaticGenericMethod(plainFactor, genericArgs, defExpr, countExpr, marginalExpr);
        }

        if (isDerived)
        {
            // used by Gibbs
            context.OutputAttributes.Set(variableFactor, new DerivedVariable());
        }

        // Mark this as a pseudo-factor.
        context.OutputAttributes.Set(variableFactor, new IsVariableFactor());

        if (useCount == 1)
        {
            context.OutputAttributes.Set(variableFactor, new DivideMessages(false));
        }
        else
        {
            context.InputAttributes.CopyObjectAttributesTo<DivideMessages>(varDecl, context.OutputAttributes, variableFactor);
        }

        context.InputAttributes.CopyObjectAttributesTo<GivePriorityTo>(varDecl, context.OutputAttributes, variableFactor);

        IAssignExpression usesAssignment = Builder.AssignExpr(usageExpr, variableFactor);

        // Copy attributes across from input to output.
        Context.InputAttributes.CopyObjectAttributesTo<Algorithm>(varDecl, context.OutputAttributes, usesAssignment);
        context.OutputAttributes.Remove<InitialiseTo>(varDecl);

        if (varInfo.ArrayDepth == 0)
        {
            // Insert the UsesEqualDef statement after the declaration.
            // Note the variable will not have been defined yet.
            context.AddStatementAfterCurrent(Builder.ExprStatement(usesAssignment));
        }
        else
        {
            // For an array, the UsesEqualDef statement should be inserted after the array is allocated.
            // Store the statement for later use by ConvertArrayCreate.
            context.InputAttributes.Remove<LoopContext>(varDecl);
            context.InputAttributes.Set(varDecl, new LoopContext(context));
            context.InputAttributes.Remove<Containers>(varDecl);
            context.InputAttributes.Set(varDecl, new Containers(context));
            channels.usesEqualDefsStatements = Builder.StmtCollection();
            channels.usesEqualDefsStatements.Add(Builder.ExprStatement(usesAssignment));
        }
    }

    // These must be set after the above or they will be copied to the other channels.
    context.OutputAttributes.Set(varDecl, channels.defChannel);
    context.OutputAttributes.Set(channels.defChannel.decl, new DescriptionAttribute("definition of '" + varDecl.Name + "'"));
    return ivde;
}
/// <summary>
/// Converts a variable declaration by creating definition, marginal and uses channel variables.
/// </summary>
/// <remarks>
/// A non-stochastic declaration is returned unchanged. Otherwise a
/// <c>VariableToChannelInformation</c> is attached to the declaration and, when the variable is
/// inferred or its variable factor is not suppressed, the channel declarations produced by the
/// <c>Create*Channel</c> helpers are added before the current statement and the variable factor is
/// assigned to the uses channel (emitted after the current statement for non-arrays, stored on the
/// channel information for arrays).
/// </remarks>
/// <param name="ivde">The variable declaration expression being converted.</param>
/// <returns>The same <paramref name="ivde"/> that was passed in.</returns>
protected override IExpression ConvertVariableDeclExpr(IVariableDeclarationExpression ivde)
{
    IVariableDeclaration varDecl = ivde.Variable;
    VariableInformation varInfo = VariableInformation.GetVariableInformation(context, varDecl);
    bool isInferred = context.InputAttributes.Has<IsInferred>(varDecl);

    // If the variable is deterministic, return.
    if (!varInfo.IsStochastic)
    {
        return ivde;
    }

    bool suppressVariableFactor = context.InputAttributes.Has<SuppressVariableFactor>(varDecl);
    bool isDerived = context.InputAttributes.Has<DerivedVariable>(varDecl);
    // NOTE(review): given the early return above this is expected to be true from here on.
    bool isStochastic = varInfo.IsStochastic;

    // A variable missing from the usage analysis is treated as having zero uses.
    ChannelAnalysisTransform.UsageInfo usage;
    int useCount = analysis.usageInfo.TryGetValue(varDecl, out usage) ? usage.NumberOfUsesOld : 0;

    // Optional shortcut: a derived, non-inferred stochastic variable with at most one use
    // does not need a variable factor (never applied under Gibbs sampling).
    bool canSkipVariableFactor =
        isDerived
        && !isInferred
        && (useCount <= 1)
        && !suppressVariableFactor
        && isStochastic
        && !(algorithm is Algorithms.GibbsSampling);
    if (canSkipVariableFactor)
    {
        suppressVariableFactor = true;
        context.InputAttributes.Set(varDecl, new SuppressVariableFactor());
    }

    context.InputAttributes.Remove<LoopContext>(varDecl);
    context.InputAttributes.Set(varDecl, new LoopContext(context));

    // Attach the variable-to-channel record to the declaration.
    VariableToChannelInformation channels = new VariableToChannelInformation();
    channels.shareAllUses = (useCount <= 1);
    Context.InputAttributes.Set(varDecl, channels);

    // The definition channel reuses the original declaration.
    channels.defChannel = ChannelInfo.DefChannel(varInfo);
    channels.defChannel.decl = varDecl;

    bool needsVariableFactor = isInferred || (isStochastic && !suppressVariableFactor);
    if (needsVariableFactor)
    {
        varInfo.DefineAllIndexVars(context);
        IList<IStatement> channelDecls = Builder.StmtCollection();

        // Marginal channel.
        CreateMarginalChannel(varInfo, channels, channelDecls);

        // Evaluated once here and reused below when the factor expression is built.
        bool useSideChannels = algorithm is GibbsSampling && ((GibbsSampling)algorithm).UseSideChannels;

        if (useSideChannels)
        {
            CreateSamplesChannel(varInfo, channels, channelDecls);
            CreateConditionalsChannel(varInfo, channels, channelDecls);
        }
        else
        {
            // Without side channels, samples and conditionals alias the marginal channel.
            channels.samplesChannel = channels.marginalChannel;
            channels.conditionalsChannel = channels.marginalChannel;
        }

        if (isStochastic)
        {
            // Uses channel.
            CreateUsesChannel(varInfo, useCount, channels, channelDecls);
        }

        //setAllGroupRoots(context, ivd, false);

        context.AddStatementsBeforeCurrent(channelDecls);

        // References to the def/marginal channels, used as factor arguments.
        IExpression defExpr = Builder.VarRefExpr(varDecl);
        IExpression marginalExpr = Builder.VarRefExpr(channels.marginalChannel.decl);
        IExpression countExpr = Builder.LiteralExpr(useCount);

        // Build the clone factor that ties all of the channels together.
        IMethodInvokeExpression variableFactorExpr;
        Type[] genericArgs = new Type[] { varInfo.varType };
        if (useSideChannels)
        {
            GibbsSampling gibbs = (GibbsSampling)algorithm;
            IExpression burnInExpr = Builder.LiteralExpr(gibbs.BurnIn);
            IExpression thinExpr = Builder.LiteralExpr(gibbs.Thin);
            IExpression samplesExpr = Builder.VarRefExpr(channels.samplesChannel.decl);
            IExpression conditionalsExpr = Builder.VarRefExpr(channels.conditionalsChannel.decl);

            Delegate gibbsFactor;
            if (isDerived)
            {
                gibbsFactor = new FuncOut3<PlaceHolder, int, int, int, PlaceHolder, PlaceHolder, PlaceHolder, PlaceHolder[]>(Clone.ReplicateWithMarginalGibbs);
            }
            else
            {
                gibbsFactor = new FuncOut3<PlaceHolder, int, int, int, PlaceHolder, PlaceHolder, PlaceHolder, PlaceHolder[]>(Clone.UsesEqualDefGibbs);
            }

            variableFactorExpr = Builder.StaticGenericMethod(
                gibbsFactor, genericArgs, defExpr, countExpr, burnInExpr, thinExpr, marginalExpr, samplesExpr, conditionalsExpr);
        }
        else
        {
            Delegate plainFactor = isDerived
                ? new FuncOut<PlaceHolder, int, PlaceHolder, PlaceHolder[]>(Clone.ReplicateWithMarginal<PlaceHolder>)
                : new FuncOut<PlaceHolder, int, PlaceHolder, PlaceHolder[]>(Clone.UsesEqualDef<PlaceHolder>);
            variableFactorExpr = Builder.StaticGenericMethod(plainFactor, genericArgs, defExpr, countExpr, marginalExpr);
        }

        if (isDerived)
        {
            // used by Gibbs
            context.OutputAttributes.Set(variableFactorExpr, new DerivedVariable());
        }

        // Mark this as a pseudo-factor.
        context.OutputAttributes.Set(variableFactorExpr, new IsVariableFactor());

        if (useCount <= 1)
        {
            context.OutputAttributes.Set(variableFactorExpr, new DivideMessages(false));
        }
        else
        {
            context.InputAttributes.CopyObjectAttributesTo<DivideMessages>(varDecl, context.OutputAttributes, variableFactorExpr);
        }

        context.InputAttributes.CopyObjectAttributesTo<GivePriorityTo>(varDecl, context.OutputAttributes, variableFactorExpr);

        // The factor is only assigned when a uses channel has been set on the channel information.
        if (channels.usageChannel != null)
        {
            IExpression usageExpr = Builder.VarRefExpr(channels.usageChannel.decl);
            IAssignExpression usesAssignment = Builder.AssignExpr(usageExpr, variableFactorExpr);

            // Copy attributes across from input to output.
            Context.InputAttributes.CopyObjectAttributesTo<Algorithm>(varDecl, context.OutputAttributes, usesAssignment);
            context.OutputAttributes.Remove<InitialiseTo>(varDecl);

            if (varInfo.ArrayDepth == 0)
            {
                // Insert the UsesEqualDef statement after the declaration.
                // Note the variable will not have been defined yet.
                context.AddStatementAfterCurrent(Builder.ExprStatement(usesAssignment));
            }
            else
            {
                // For an array, the UsesEqualDef statement should be inserted after the array is allocated.
                // Store the statement for later use by ConvertArrayCreate.
                context.InputAttributes.Remove<LoopContext>(varDecl);
                context.InputAttributes.Set(varDecl, new LoopContext(context));
                context.InputAttributes.Remove<Containers>(varDecl);
                context.InputAttributes.Set(varDecl, new Containers(context));
                channels.usesEqualDefsStatements = Builder.StmtCollection();
                channels.usesEqualDefsStatements.Add(Builder.ExprStatement(usesAssignment));
            }
        }
    }

    // These must be set after the above or they will be copied to the other channels.
    if (isStochastic)
    {
        context.OutputAttributes.Set(varDecl, channels.defChannel);
    }

    // Only add a default description when the input did not already provide one.
    bool hasDescription = context.InputAttributes.Has<DescriptionAttribute>(channels.defChannel.decl);
    if (!hasDescription)
    {
        context.OutputAttributes.Set(channels.defChannel.decl, new DescriptionAttribute("definition of '" + varDecl.Name + "'"));
    }

    return ivde;
}
/// <summary>
/// Builds the "BugsRats" model (y[n,t] Gaussian with mean alpha[n] + beta[n] * x[t] and
/// precision tauC), runs Gibbs sampling on it and checks the inferred marginals of betaC,
/// alpha0 and tauC against reference values from BUGS.
/// </summary>
/// <param name="initialiseAlpha">
/// When true, the alpha array is initialised; the beta array is always initialised together with it.
/// </param>
/// <param name="initialiseAlphaC">
/// When true, alphaC is initialised; betaC is always initialised together with it.
/// </param>
private void BugsRats(bool initialiseAlpha, bool initialiseAlphaC)
{
    Rand.Restart(0);

    // Prior hyper-parameters.
    double precOfGaussianPrior = 1.0E-6;
    double shapeRateOfGammaPrior = 0.02; // smallest choice that will avoid zeros
    double meanOfBetaPrior = 0.0;
    double meanOfAlphaPrior = 0.0;

    // The model
    int numRows = RatsHeightData.GetLength(0);
    int numCols = RatsHeightData.GetLength(1);
    double xbar = 22.0;

    // Shift the x data by xbar.
    double[] xCentred = new double[RatsXData.Length];
    for (int k = 0; k < RatsXData.Length; k++)
    {
        xCentred[k] = RatsXData[k] - xbar;
    }

    Range nRange = new Range(numRows).Named("N");
    Range tRange = new Range(numCols).Named("T");
    VariableArray2D<double> y = Variable.Observed<double>(RatsHeightData, nRange, tRange).Named("y");
    VariableArray<double> x = Variable.Observed<double>(xCentred, tRange).Named("x");

    Variable<double> tauC = Variable.GammaFromShapeAndRate(shapeRateOfGammaPrior, shapeRateOfGammaPrior).Named("tauC");
    Variable<double> alphaC = Variable.GaussianFromMeanAndPrecision(meanOfAlphaPrior, precOfGaussianPrior).Named("alphaC");
    Variable<double> alphaTau = Variable.GammaFromShapeAndRate(shapeRateOfGammaPrior, shapeRateOfGammaPrior).Named("alphaTau");
    Variable<double> betaC = Variable.GaussianFromMeanAndPrecision(meanOfBetaPrior, precOfGaussianPrior).Named("betaC");
    Variable<double> betaTau = Variable.GammaFromShapeAndRate(shapeRateOfGammaPrior, shapeRateOfGammaPrior).Named("betaTau");

    VariableArray<double> alpha = Variable.Array<double>(nRange).Named("alpha");
    alpha[nRange] = Variable.GaussianFromMeanAndPrecision(alphaC, alphaTau).ForEach(nRange);
    VariableArray<double> beta = Variable.Array<double>(nRange).Named("beta");
    beta[nRange] = Variable.GaussianFromMeanAndPrecision(betaC, betaTau).ForEach(nRange);

    VariableArray2D<double> mu = Variable.Array<double>(nRange, tRange).Named("mu");
    VariableArray2D<double> betaX = Variable.Array<double>(nRange, tRange).Named("betax");
    betaX[nRange, tRange] = beta[nRange] * x[tRange];
    mu[nRange, tRange] = alpha[nRange] + betaX[nRange, tRange];
    y[nRange, tRange] = Variable.GaussianFromMeanAndPrecision(mu[nRange, tRange], tauC);

    Variable<double> alpha0 = (alphaC - xbar * betaC).Named("alpha0");

    GibbsSampling gibbs = new GibbsSampling();

    // Initialise both alpha and beta together.
    // Initialising only alpha (or only beta) is not reliable because you could by chance get a large betaTau and small tauC to start,
    // at which point beta and alphaC become garbage, leading to alpha becoming garbage on the next iteration.
    bool initialiseBeta = initialiseAlpha;
    bool initialiseBetaC = initialiseAlphaC;

    if (initialiseAlpha)
    {
        Gaussian[] alphaStart = new Gaussian[numRows];
        for (int k = 0; k < numRows; k++)
        {
            alphaStart[k] = Gaussian.FromMeanAndPrecision(250.0, 1.0);
        }

        alpha.InitialiseTo(Distribution<double>.Array(alphaStart));
    }

    if (initialiseBeta)
    {
        Gaussian[] betaStart = new Gaussian[numRows];
        for (int k = 0; k < numRows; k++)
        {
            betaStart[k] = Gaussian.FromMeanAndPrecision(6.0, 1.0);
        }

        beta.InitialiseTo(Distribution<double>.Array(betaStart));
    }

    if (initialiseAlphaC)
    {
        alphaC.InitialiseTo(Gaussian.FromMeanAndVariance(250.0, 1.0));
    }

    if (initialiseBetaC)
    {
        betaC.InitialiseTo(Gaussian.FromMeanAndVariance(6.0, 1.0));
    }

    // Initialisation of tauC, alphaTau and betaTau (to Gamma.FromMeanAndVariance(1.0, 0.1)) is disabled.

    // Only the run with no initialisation at all gets an explicit burn-in.
    bool anythingInitialised = initialiseAlpha || initialiseBeta || initialiseAlphaC || initialiseBetaC;
    if (!anythingInitialised)
    {
        gibbs.BurnIn = 1000;
    }

    InferenceEngine engine = new InferenceEngine(gibbs);
    engine.ShowProgress = false;
    engine.ModelName = "BugsRats";
    engine.NumberOfIterations = 4000;
    engine.OptimiseForVariables = new List<IVariable>() { alphaC, betaC, alpha0, tauC };

    // Request both marginals and raw samples for the checked variables.
    betaC.AddAttribute(QueryTypes.Marginal);
    betaC.AddAttribute(QueryTypes.Samples);
    alpha0.AddAttribute(QueryTypes.Marginal);
    alpha0.AddAttribute(QueryTypes.Samples);
    tauC.AddAttribute(QueryTypes.Marginal);
    tauC.AddAttribute(QueryTypes.Samples);

    // Inference (the alphaC result is requested but not checked).
    engine.Infer(alphaC);
    Gaussian betaCMarg = engine.Infer<Gaussian>(betaC);
    Gaussian alpha0Marg = engine.Infer<Gaussian>(alpha0);
    Gamma tauCMarg = engine.Infer<Gamma>(tauC);

    // Check results against BUGS
    Gaussian betaCExpected = new Gaussian(6.185, System.Math.Pow(0.1068, 2));
    Gaussian alpha0Expected = new Gaussian(106.6, System.Math.Pow(3.625, 2));
    double sigmaMeanExpected = 6.082;
    double sigmaMean = System.Math.Sqrt(1.0 / tauCMarg.GetMean());

    if (!initialiseAlpha && !initialiseAlphaC)
    {
        Debug.WriteLine("betaC = {0} should be {1}", betaCMarg, betaCExpected);
        Debug.WriteLine("alpha0 = {0} should be {1}", alpha0Marg, alpha0Expected);
    }

    Assert.True(GaussianDiff(betaCExpected, betaCMarg) < 0.1);
    Assert.True(GaussianDiff(alpha0Expected, alpha0Marg) < 0.1);
    Assert.True(MMath.AbsDiff(sigmaMeanExpected, sigmaMean, 0.1) < 0.1);

    // Query the sample lists; only the betaC samples are checked below.
    IList<double> betaCSamples = engine.Infer<IList<double>>(betaC, QueryTypes.Samples);
    engine.Infer<IList<double>>(alpha0, QueryTypes.Samples);
    engine.Infer<IList<double>>(tauC, QueryTypes.Samples);

    // A Gaussian fitted to the betaC samples should agree with the betaC marginal.
    GaussianEstimator estimator = new GaussianEstimator();
    foreach (double draw in betaCSamples)
    {
        estimator.Add(draw);
    }

    Gaussian betaCFromSamples = estimator.GetDistribution(new Gaussian());
    Assert.True(GaussianDiff(betaCMarg, betaCFromSamples) < 0.1);
}