Пример #1
0
        /// <summary>
        /// Builds the model necessary to infer marginals for the supplied variables and algorithm.
        /// </summary>
        /// <param name="engine">The inference engine; its <c>Algorithm</c> is used when finishing each model expression.</param>
        /// <param name="inferOnlySpecifiedVars">If true, inference will be restricted to only the variables given.</param>
        /// <param name="vars">Variables to infer. The sequence is enumerated exactly once.</param>
        /// <returns>The model type declaration produced by this build (the <c>modelType</c> field).</returns>
        /// <exception cref="ArgumentNullException"><paramref name="engine"/> or <paramref name="vars"/> is null.</exception>
        /// <exception cref="InvalidOperationException">A statement block opened while defining the model has not been closed.</exception>
        /// <remarks>
        /// Algorithm: starting from the variables to infer, we search through the graph to build up a "searched set".
        /// Each Variable and MethodInvoke in this set has an associated timestamp.
        /// We sort by timestamp, and then generate code.
        /// </remarks>
        public ITypeDeclaration Build(InferenceEngine engine, bool inferOnlySpecifiedVars, IEnumerable <IVariable> vars)
        {
            // Validate up front so a bad call fails before any builder state is touched
            // (a null engine previously only failed after the whole model had been built).
            if (engine == null)
            {
                throw new ArgumentNullException(nameof(engine));
            }
            if (vars == null)
            {
                throw new ArgumentNullException(nameof(vars));
            }

            List <IStatementBlock> openBlocks = StatementBlock.GetOpenBlocks();
            if (openBlocks.Count > 0)
            {
                throw new InvalidOperationException("The block " + openBlocks[0] + " has not been closed.");
            }
            Reset();
            this.inferOnlySpecifiedVars = inferOnlySpecifiedVars;

            // Materialize once (after Reset, where the original first enumerated it) so that a lazily
            // evaluated sequence cannot hand different variables to variablesToInfer and to the search.
            List <IVariable> varList = new List <IVariable>(vars);
            variablesToInfer.AddRange(varList);
            foreach (IVariable v in varList)
            {
                toSearch.Push(v);
            }
            while (toSearch.Count > 0)
            {
                IModelExpression expr = toSearch.Pop();
                SearchExpressionUntyped(expr);
            }

            // lock in the set of model expressions.
            ModelExpressions = new List <IModelExpression>(searched);
            List <int> timestamps         = new List <int>();
            List <IModelExpression> exprs = new List <IModelExpression>();

            // Only Variables and MethodInvokes carry a timestamp; other expressions are not emitted here.
            foreach (IModelExpression expr in ModelExpressions)
            {
                if (expr is Variable variable)
                {
                    exprs.Add(variable);
                    timestamps.Add(variable.timestamp);
                }
                else if (expr is MethodInvoke mi)
                {
                    exprs.Add(mi);
                    timestamps.Add(mi.timestamp);
                }
            }
            // Sorts exprs in step with timestamps, so code is generated in timestamp order.
            Collection.Sort(timestamps, exprs);
            foreach (IModelExpression expr in exprs)
            {
                BuildExpressionUntyped(expr);
            }
            // Finishing is a separate pass so that it runs only after every expression has been built.
            foreach (IModelExpression expr in exprs)
            {
                FinishExpressionUntyped(expr, engine.Algorithm);
            }
            return modelType;
        }
Пример #2
0
 /// <summary>
 /// Adds a compiler attribute to this object's attribute list.
 /// </summary>
 /// <param name="attr">The attribute to add; must not be null.</param>
 /// <exception cref="ArgumentNullException"><paramref name="attr"/> is null.</exception>
 /// <remarks>
 /// All inference engines are invalidated for this object before the attribute is stored
 /// (presumably so that previously compiled code is not reused without it).
 /// </remarks>
 public void AddAttribute(ICompilerAttribute attr)
 {
     // Reject null before invalidating engines, so a bad call has no side effects.
     if (attr == null)
     {
         throw new ArgumentNullException(nameof(attr));
     }
     InferenceEngine.InvalidateAllEngines(this);
     attributes.Add(attr);
 }
Пример #3
0
        /// <summary>
        /// Online (one sample at a time) learning of a two-class model in which each class has its own
        /// Gaussian mean for two features. Writes the running posteriors to <c>results.csv</c>.
        /// </summary>
        /// <param name="args">
        /// args[0]: directory containing the dataset; results.csv is written there too.
        /// args[1]: dataset file name. Each non-blank line is "x0|x1|label"; a label equal to 1 means class 1,
        /// anything else means class 0.
        /// </param>
        static void Main(string[] args)
        {
            /********* arguments **********/
            if (args == null || args.Length < 2)
            {
                Console.Error.WriteLine("Usage: <dataDir> <datasetFilename>");
                Environment.ExitCode = 1;
                return;
            }
            string dataDir         = args[0];
            string datasetFilename = args[1];

            // Input and output are machine-readable files, so always use '.' as the decimal separator,
            // independent of the OS locale.
            IFormatProvider invariant = System.Globalization.CultureInfo.InvariantCulture;
            /******************************/

            /************ data ************/
            // Skip blank lines; they would otherwise fail to parse.
            string[] lines = Array.FindAll(
                File.ReadAllLines(System.IO.Path.Combine(dataDir, datasetFilename)),
                line => !string.IsNullOrWhiteSpace(line));
            int numSamples = lines.Length;

            double[] x0Data = new double[numSamples];
            double[] x1Data = new double[numSamples];
            bool[]   yData  = new bool[numSamples];

            for (int i = 0; i < numSamples; i++)
            {
                string[] fields = lines[i].Split('|');
                if (fields.Length < 3)
                {
                    throw new FormatException(string.Format(invariant,
                        "Sample {0} has {1} field(s); expected 'x0|x1|label'.", i, fields.Length));
                }
                x0Data[i] = Convert.ToDouble(fields[0], invariant);
                x1Data[i] = Convert.ToDouble(fields[1], invariant);
                yData[i]  = Convert.ToDouble(fields[2], invariant) == 1;
            }
            /********************************/

            /********* model setup **********/
            const double variance          = 1.0;   // fixed variance of each feature around its class mean
            const double meanPriorMean     = 0.0;
            const double meanPriorVariance = 10.0;

            M.Variable <double> x0c0Mean = M.Variable.GaussianFromMeanAndVariance(meanPriorMean, meanPriorVariance);
            M.Variable <double> x0c1Mean = M.Variable.GaussianFromMeanAndVariance(meanPriorMean, meanPriorVariance);
            M.Variable <double> x1c0Mean = M.Variable.GaussianFromMeanAndVariance(meanPriorMean, meanPriorVariance);
            M.Variable <double> x1c1Mean = M.Variable.GaussianFromMeanAndVariance(meanPriorMean, meanPriorVariance);

            // All per-mean bookkeeping below uses this order, which is also the results.csv column order
            // (meanPost0..meanPost3): x0|class0, x1|class0, x0|class1, x1|class1.
            M.Variable <double>[] means = { x0c0Mean, x1c0Mean, x0c1Mean, x1c1Mean };

            // Each mean gets an observed message carrying the evidence of the samples processed so far.
            M.Variable <D.Gaussian>[] meanMessages = new M.Variable <D.Gaussian>[means.Length];
            for (int j = 0; j < means.Length; j++)
            {
                meanMessages[j] = M.Variable.Observed <D.Gaussian>(D.Gaussian.Uniform());
                M.Variable.ConstrainEqualRandom(means[j], meanMessages[j]);
                means[j].AddAttribute(QueryTypes.Marginal);
                means[j].AddAttribute(QueryTypes.MarginalDividedByPrior);
            }

            M.Variable <double> cPrior = M.Variable.Beta(1, 1);
            M.Variable <bool>   y      = M.Variable.Bernoulli(cPrior);
            M.Variable <double> x0     = M.Variable.New <double>();
            M.Variable <double> x1     = M.Variable.New <double>();

            using (M.Variable.IfNot(y))
            {
                x0.SetTo(M.Variable.GaussianFromMeanAndVariance(x0c0Mean, variance));
                x1.SetTo(M.Variable.GaussianFromMeanAndVariance(x1c0Mean, variance));
            }
            using (M.Variable.If(y))
            {
                x0.SetTo(M.Variable.GaussianFromMeanAndVariance(x0c1Mean, variance));
                x1.SetTo(M.Variable.GaussianFromMeanAndVariance(x1c1Mean, variance));
            }
            /********************************/

            /******* inference engine *******/
            M.InferenceEngine engine = new M.InferenceEngine(new A.ExpectationPropagation());
            engine.ShowProgress = false;
            /********************************/

            // Forgetting horizon: once more than k samples have been seen, the accumulated evidence is
            // multiplied by k/(k+1) at every step. A smaller k forgets old samples faster, so the prior
            // over the means carries relatively more weight.
            const double k     = 10.0;
            double       decay = k / (k + 1);

            // Natural parameters (meanTimesPrecision, precision) of the evidence accumulated from the
            // samples processed so far; (0, 0) is the uniform message.
            double[] accMeanTimesPrecision = new double[means.Length];
            double[] accPrecision          = new double[means.Length];

            var results = new StringBuilder();
            results.AppendLine("classPost|meanPost0|meanPost1|meanPost2|meanPost3");

            for (int t = 0; t < numSamples; t++)
            {
                // Feed in the evidence of samples 0..t-1 through the messages, then observe sample t.
                // Sample t's own evidence therefore enters exactly once (the original code added it to
                // the message AND kept it observed, double-counting it in the reported marginals).
                for (int j = 0; j < means.Length; j++)
                {
                    meanMessages[j].ObservedValue = D.Gaussian.FromNatural(accMeanTimesPrecision[j], accPrecision[j]);
                }
                x0.ObservedValue = x0Data[t];
                x1.ObservedValue = x1Data[t];
                y.ObservedValue  = yData[t];

                // NOTE(review): cPrior has no message carried between iterations, so this is the Beta(1,1)
                // prior updated by sample t alone, not by all samples seen so far.
                D.Beta postClass = engine.Infer <D.Beta>(cPrior);

                double[] posteriorMeans = new double[means.Length];
                for (int j = 0; j < means.Length; j++)
                {
                    // Marginal = prior x evidence(0..t-1) x evidence(t): the posterior after sample t.
                    D.Gaussian marginal = engine.Infer <D.Gaussian>(means[j]);
                    posteriorMeans[j] = marginal.GetMean();

                    // MarginalDividedByPrior = evidence(0..t-1) x evidence(t). Carry it forward, down-weighted
                    // after the first k samples. This is the same recurrence as before:
                    // acc_t = (acc_{t-1} + evidence_t) * w.
                    D.Gaussian evidence = engine.Infer <D.Gaussian>(means[j], QueryTypes.MarginalDividedByPrior);
                    double meanTimesPrecision, precision;
                    evidence.GetNatural(out meanTimesPrecision, out precision);
                    double weight = t > k ? decay : 1.0;
                    accMeanTimesPrecision[j] = meanTimesPrecision * weight;
                    accPrecision[j]          = precision * weight;
                }

                results.AppendLine(string.Format(invariant, "{0}|{1}|{2}|{3}|{4}",
                    postClass.GetMean(), posteriorMeans[0], posteriorMeans[1], posteriorMeans[2], posteriorMeans[3]));
            }

            File.WriteAllText(System.IO.Path.Combine(dataDir, "results.csv"), results.ToString());
        }
Пример #4
0
 /// <summary>
 /// Returns the abstract syntax tree of the code that the engine's compiler generates for this model.
 /// </summary>
 /// <param name="engine">
 /// Engine whose <c>ModelNamespace</c>/<c>ModelName</c> are applied to the model first, and whose
 /// <c>Compiler</c> then produces the transformed declarations.
 /// </param>
 /// <returns>The transformed type declarations, as returned by the compiler.</returns>
 public List <ITypeDeclaration> GetGeneratedSyntax(InferenceEngine engine)
 {
     // The naming has to be in place before the compiler sees the model type.
     SetModelName(engine.ModelNamespace, engine.ModelName);

     var compiler     = engine.Compiler;
     var declarations = compiler.GetTransformedDeclaration(modelType, null, Attributes);
     return declarations;
 }
Пример #5
0
 /// <summary>
 /// Runs the shared-variable update for one batch of this model.
 /// </summary>
 /// <remarks>
 /// Two steps, in order: the shared variables first receive this model's input for the batch, and then
 /// their output is inferred with <paramref name="engine"/>.
 /// </remarks>
 /// <param name="engine">The inference engine handed to the shared variables' output inference.</param>
 /// <param name="batchNumber">A number from 0 to BatchCount-1</param>
 public void InferShared(InferenceEngine engine, int batchNumber)
 {
     // Step 1: hand this batch's input to the shared variables.
     SharedVariables.SetInput(this, batchNumber);

     // Step 2: infer their output for the same model and batch.
     SharedVariables.InferOutput(engine, this, batchNumber);
 }
Пример #6
0
 /// <summary>
 /// Infer the shared variable's output message for the given model and batch number.
 /// </summary>
 /// <param name="engine">The inference engine to use.</param>
 /// <param name="modelNumber">
 /// The model to infer the output for. Despite the parameter name, this is a <see cref="Model"/> instance,
 /// not a numeric id.
 /// </param>
 /// <param name="batchNumber">The batch number.</param>
 public abstract void InferOutput(InferenceEngine engine, Model modelNumber, int batchNumber);