/// <summary>
/// Builds the model necessary to infer marginals for the supplied variables and algorithm.
/// </summary>
/// <param name="engine">The inference algorithm being used</param>
/// <param name="inferOnlySpecifiedVars">If true, inference will be restricted to only the variables given.</param>
/// <param name="vars">Variables to infer.</param>
/// <returns>The type declaration of the generated model.</returns>
/// <exception cref="InvalidOperationException">If a statement block is still open.</exception>
/// <remarks>
/// Algorithm: starting from the variables to infer, we search through the graph to build up a "searched set".
/// Each Variable and MethodInvoke in this set has an associated timestamp.
/// We sort by timestamp, and then generate code.
/// </remarks>
public ITypeDeclaration Build(InferenceEngine engine, bool inferOnlySpecifiedVars, IEnumerable<IVariable> vars)
{
    List<IStatementBlock> openBlocks = StatementBlock.GetOpenBlocks();
    if (openBlocks.Count > 0)
    {
        throw new InvalidOperationException("The block " + openBlocks[0] + " has not been closed.");
    }
    Reset();
    this.inferOnlySpecifiedVars = inferOnlySpecifiedVars;
    // Materialize 'vars' once: it was previously enumerated twice (AddRange and the
    // push loop), which is unsafe for lazily-evaluated sequences that may yield
    // different items (or do expensive work) on each enumeration.
    List<IVariable> varList = new List<IVariable>(vars);
    variablesToInfer.AddRange(varList);
    foreach (IVariable var in varList)
    {
        toSearch.Push(var);
    }
    // Traverse the graph from the inference targets to collect every reachable
    // model expression into the searched set.
    while (toSearch.Count > 0)
    {
        IModelExpression expr = toSearch.Pop();
        SearchExpressionUntyped(expr);
    }
    // Lock in the set of model expressions.
    ModelExpressions = new List<IModelExpression>(searched);
    // Pair each Variable/MethodInvoke with its timestamp so code is generated
    // in creation order.
    List<int> timestamps = new List<int>();
    List<IModelExpression> exprs = new List<IModelExpression>();
    foreach (IModelExpression expr in ModelExpressions)
    {
        if (expr is Variable var)
        {
            exprs.Add(var);
            timestamps.Add(var.timestamp);
        }
        else if (expr is MethodInvoke mi)
        {
            exprs.Add(mi);
            timestamps.Add(mi.timestamp);
        }
    }
    Collection.Sort(timestamps, exprs);
    foreach (IModelExpression expr in exprs)
    {
        BuildExpressionUntyped(expr);
    }
    foreach (IModelExpression expr in exprs)
    {
        FinishExpressionUntyped(expr, engine.Algorithm);
    }
    return modelType;
}
/// <summary>
/// Adds a compiler attribute to this model.
/// </summary>
/// <param name="attr">The attribute to add.</param>
/// <remarks>
/// Engines that have compiled this model are invalidated first, since the new
/// attribute may change the generated inference code.
/// </remarks>
public void AddAttribute(ICompilerAttribute attr)
{
    InferenceEngine.InvalidateAllEngines(this);
    attributes.Add(attr);
}
/// <summary>
/// Online training of a two-feature Gaussian classifier with Infer.NET EP.
/// Reads '|'-separated samples (x0|x1|label) from the given file, processes one
/// sample per iteration, and carries evidence between iterations by accumulating
/// the MarginalDividedByPrior messages for the class-conditional means in
/// natural-parameter form with an exponential forgetting window of size k.
/// Writes per-iteration posteriors to results.csv in the data directory.
/// </summary>
/// <param name="args">args[0] = data directory, args[1] = dataset file name.</param>
static void Main(string[] args)
{
    // Use the invariant culture wherever numbers cross a text boundary so the
    // '|'-separated files parse and render identically on every locale
    // (Convert.ToDouble / string.Format used the current culture before, which
    // mis-parsed "1.5" and emitted ',' decimals on ','-decimal locales).
    var inv = System.Globalization.CultureInfo.InvariantCulture;

    /********* arguments **********/
    string dataDir = args[0];
    string datasetFilename = args[1];

    /************ data ************/
    // Path.Combine tolerates a data directory with or without a trailing
    // separator (plain concatenation required the caller to include it).
    string[] lines = File.ReadAllLines(Path.Combine(dataDir, datasetFilename));
    int numSamples = lines.Length;
    double[] x0Data = new double[numSamples];
    double[] x1Data = new double[numSamples];
    bool[] yData = new bool[numSamples];
    for (int i = 0; i < numSamples; i++)
    {
        string[] strArray = lines[i].Split('|');
        double[] doubleArray = Array.ConvertAll(strArray, s => double.Parse(s, inv));
        x0Data[i] = doubleArray[0];
        x1Data[i] = doubleArray[1];
        yData[i] = doubleArray[2] == 1;  // label column is expected to be exactly 0 or 1
    }

    /********* model setup **********/
    M.Variable<double> x0 = M.Variable.New<double>();
    M.Variable<double> x1 = M.Variable.New<double>();
    double variance = 1.0;  // fixed within-class variance of each feature
    // Broad priors over the four class-conditional feature means.
    M.Variable<double> x0c0Mean = M.Variable.GaussianFromMeanAndVariance(0, 10);
    M.Variable<double> x0c1Mean = M.Variable.GaussianFromMeanAndVariance(0, 10);
    M.Variable<double> x1c0Mean = M.Variable.GaussianFromMeanAndVariance(0, 10);
    M.Variable<double> x1c1Mean = M.Variable.GaussianFromMeanAndVariance(0, 10);
    // Stores the product of all messages sent to the means by the previous batches.
    M.Variable<D.Gaussian> x0c0MeanMessage = M.Variable.Observed<D.Gaussian>(D.Gaussian.Uniform());
    M.Variable<D.Gaussian> x0c1MeanMessage = M.Variable.Observed<D.Gaussian>(D.Gaussian.Uniform());
    M.Variable<D.Gaussian> x1c0MeanMessage = M.Variable.Observed<D.Gaussian>(D.Gaussian.Uniform());
    M.Variable<D.Gaussian> x1c1MeanMessage = M.Variable.Observed<D.Gaussian>(D.Gaussian.Uniform());
    M.Variable.ConstrainEqualRandom(x0c0Mean, x0c0MeanMessage);
    M.Variable.ConstrainEqualRandom(x0c1Mean, x0c1MeanMessage);
    M.Variable.ConstrainEqualRandom(x1c0Mean, x1c0MeanMessage);
    M.Variable.ConstrainEqualRandom(x1c1Mean, x1c1MeanMessage);
    // Both Marginal and MarginalDividedByPrior are queried for each mean below.
    x0c0Mean.AddAttribute(QueryTypes.Marginal);
    x0c0Mean.AddAttribute(QueryTypes.MarginalDividedByPrior);
    x0c1Mean.AddAttribute(QueryTypes.Marginal);
    x0c1Mean.AddAttribute(QueryTypes.MarginalDividedByPrior);
    x1c0Mean.AddAttribute(QueryTypes.Marginal);
    x1c0Mean.AddAttribute(QueryTypes.MarginalDividedByPrior);
    x1c1Mean.AddAttribute(QueryTypes.Marginal);
    x1c1Mean.AddAttribute(QueryTypes.MarginalDividedByPrior);
    D.Gaussian x0c0MeanMarginal = D.Gaussian.Uniform();
    D.Gaussian x0c1MeanMarginal = D.Gaussian.Uniform();
    D.Gaussian x1c0MeanMarginal = D.Gaussian.Uniform();
    D.Gaussian x1c1MeanMarginal = D.Gaussian.Uniform();
    M.Variable<double> cPrior = M.Variable.Beta(1, 1);
    // (The original created y via Variable.New<bool>() and immediately overwrote
    // the reference; the orphaned variable has been removed.)
    M.Variable<bool> y = M.Variable.Bernoulli(cPrior);
    using (M.Variable.IfNot(y))
    {
        x0.SetTo(M.Variable.GaussianFromMeanAndVariance(x0c0Mean, variance));
        x1.SetTo(M.Variable.GaussianFromMeanAndVariance(x1c0Mean, variance));
    }
    using (M.Variable.If(y))
    {
        x0.SetTo(M.Variable.GaussianFromMeanAndVariance(x0c1Mean, variance));
        x1.SetTo(M.Variable.GaussianFromMeanAndVariance(x1c1Mean, variance));
    }

    /******* inference engine *******/
    M.InferenceEngine engine = new M.InferenceEngine(new A.ExpectationPropagation());
    engine.ShowProgress = false;
    // engine.ShowFactorGraph = true;

    // Forgetting window: the smaller k is, the more the prior over the means
    // contributes to the posterior relative to the accumulated data.
    double k = 10.0;
    // Accumulated natural parameters [mean*precision, precision] of the product
    // of per-sample messages for each class-conditional mean.
    double[] x0c0Meannatural = { 0, 0 };
    double[] x0c1Meannatural = { 0, 0 };
    double[] x1c0Meannatural = { 0, 0 };
    double[] x1c1Meannatural = { 0, 0 };
    var results = new StringBuilder();
    results.AppendLine("classPost|meanPost0|meanPost1|meanPost2|meanPost3");
    for (int t = 0; t < numSamples; t++)
    {
        // First pass: infer with uniform mean-messages so MarginalDividedByPrior
        // yields only the current sample's data likelihood for each mean.
        x0c0MeanMessage.ObservedValue = D.Gaussian.Uniform();
        x0c1MeanMessage.ObservedValue = D.Gaussian.Uniform();
        x1c0MeanMessage.ObservedValue = D.Gaussian.Uniform();
        x1c1MeanMessage.ObservedValue = D.Gaussian.Uniform();
        x0.ObservedValue = x0Data[t];
        x1.ObservedValue = x1Data[t];
        y.ObservedValue = yData[t];
        D.Gaussian x0c0MeanDataLikelihood = engine.Infer<D.Gaussian>(x0c0Mean, QueryTypes.MarginalDividedByPrior);
        D.Gaussian x0c1MeanDataLikelihood = engine.Infer<D.Gaussian>(x0c1Mean, QueryTypes.MarginalDividedByPrior);
        D.Gaussian x1c0MeanDataLikelihood = engine.Infer<D.Gaussian>(x1c0Mean, QueryTypes.MarginalDividedByPrior);
        D.Gaussian x1c1MeanDataLikelihood = engine.Infer<D.Gaussian>(x1c1Mean, QueryTypes.MarginalDividedByPrior);
        D.Beta postClass = engine.Infer<D.Beta>(cPrior);
        x0c0MeanDataLikelihood.GetNatural(out double x0c0Meanmb, out double x0c0Meanb);
        x0c1MeanDataLikelihood.GetNatural(out double x0c1Meanmb, out double x0c1Meanb);
        x1c0MeanDataLikelihood.GetNatural(out double x1c0Meanmb, out double x1c0Meanb);
        x1c1MeanDataLikelihood.GetNatural(out double x1c1Meanmb, out double x1c1Meanb);
        if (t > k)
        {
            // Window full: decay the accumulated evidence while folding in the
            // new message (exponential forgetting with factor k/(k+1)).
            x0c0Meannatural[0] = (x0c0Meannatural[0] + x0c0Meanmb) * (k / (k + 1));
            x0c0Meannatural[1] = (x0c0Meannatural[1] + x0c0Meanb) * (k / (k + 1));
            x0c1Meannatural[0] = (x0c1Meannatural[0] + x0c1Meanmb) * (k / (k + 1));
            x0c1Meannatural[1] = (x0c1Meannatural[1] + x0c1Meanb) * (k / (k + 1));
            x1c0Meannatural[0] = (x1c0Meannatural[0] + x1c0Meanmb) * (k / (k + 1));
            x1c0Meannatural[1] = (x1c0Meannatural[1] + x1c0Meanb) * (k / (k + 1));
            x1c1Meannatural[0] = (x1c1Meannatural[0] + x1c1Meanmb) * (k / (k + 1));
            x1c1Meannatural[1] = (x1c1Meannatural[1] + x1c1Meanb) * (k / (k + 1));
        }
        else
        {
            // Still filling the window: plain accumulation, no forgetting.
            x0c0Meannatural[0] = x0c0Meannatural[0] + x0c0Meanmb;
            x0c0Meannatural[1] = x0c0Meannatural[1] + x0c0Meanb;
            x0c1Meannatural[0] = x0c1Meannatural[0] + x0c1Meanmb;
            x0c1Meannatural[1] = x0c1Meannatural[1] + x0c1Meanb;
            x1c0Meannatural[0] = x1c0Meannatural[0] + x1c0Meanmb;
            x1c0Meannatural[1] = x1c0Meannatural[1] + x1c0Meanb;
            x1c1Meannatural[0] = x1c1Meannatural[0] + x1c1Meanmb;
            x1c1Meannatural[1] = x1c1Meannatural[1] + x1c1Meanb;
        }
        // Convert the accumulated natural parameters back to (mean, variance).
        // NOTE(review): divides by the accumulated precision, which is zero until
        // a sample has contributed — same as the original; confirm the first
        // messages are always non-uniform.
        x0c0MeanMessage.ObservedValue = new D.Gaussian(x0c0Meannatural[0] / x0c0Meannatural[1], 1 / x0c0Meannatural[1]);
        x0c1MeanMessage.ObservedValue = new D.Gaussian(x0c1Meannatural[0] / x0c1Meannatural[1], 1 / x0c1Meannatural[1]);
        x1c0MeanMessage.ObservedValue = new D.Gaussian(x1c0Meannatural[0] / x1c0Meannatural[1], 1 / x1c0Meannatural[1]);
        x1c1MeanMessage.ObservedValue = new D.Gaussian(x1c1Meannatural[0] / x1c1Meannatural[1], 1 / x1c1Meannatural[1]);
        // Second pass: posterior distributions over the means given all evidence.
        x0c0MeanMarginal = engine.Infer<D.Gaussian>(x0c0Mean);
        x0c1MeanMarginal = engine.Infer<D.Gaussian>(x0c1Mean);
        x1c0MeanMarginal = engine.Infer<D.Gaussian>(x1c0Mean);
        x1c1MeanMarginal = engine.Infer<D.Gaussian>(x1c1Mean);
        var newLine = string.Format(inv, "{0}|{1}|{2}|{3}|{4}",
            postClass.GetMean(),
            x0c0MeanMarginal.GetMean(), x1c0MeanMarginal.GetMean(),
            x0c1MeanMarginal.GetMean(), x1c1MeanMarginal.GetMean());
        results.AppendLine(newLine);
    }
    File.WriteAllText(Path.Combine(dataDir, "results.csv"), results.ToString());
}
/// <summary>
/// Gets the abstract syntax tree for the generated inference code.
/// </summary>
/// <param name="engine">The engine whose namespace, model name and compiler are used.</param>
/// <returns>The transformed type declarations produced by the model compiler.</returns>
public List<ITypeDeclaration> GetGeneratedSyntax(InferenceEngine engine)
{
    SetModelName(engine.ModelNamespace, engine.ModelName);
    var declarations = engine.Compiler.GetTransformedDeclaration(modelType, null, Attributes);
    return declarations;
}
/// <summary>
/// Update all the SharedVariables registered with this model.
/// </summary>
/// <param name="engine">The inference engine used to compute the outputs.</param>
/// <param name="batchNumber">A number from 0 to BatchCount-1</param>
/// <remarks>
/// Inputs for the batch are set before the outputs are inferred, so the order of
/// the two calls below matters.
/// </remarks>
public void InferShared(InferenceEngine engine, int batchNumber)
{
    SharedVariables.SetInput(this, batchNumber);
    SharedVariables.InferOutput(engine, this, batchNumber);
}
/// <summary>
/// Infer the shared variable's output message for the given model and batch number.
/// </summary>
/// <param name="engine">The inference engine.</param>
/// <param name="modelNumber">The model whose output message is inferred (despite the name, this is a Model instance, not an id).</param>
/// <param name="batchNumber">The batch number.</param>
public abstract void InferOutput(InferenceEngine engine, Model modelNumber, int batchNumber);