Exemplo n.º 1
0
        /// <summary>
        /// Construct a node and register it as a child of each of its parents.
        /// </summary>
        /// <remarks>
        /// The node is only added to the parents' child lists after the arguments have been
        /// validated and the conditional probability table has been built, so a failed
        /// construction does not leave a half-initialised node attached to the network.
        /// </remarks>
        /// <param name="name">name of the node</param>
        /// <param name="domain">the domain of the node (list of states the node can be in)</param>
        /// <param name="values">the probabilities of those states given all possible parent states;
        /// its length must equal the product of <c>tokens.Length</c> over every parent variable
        /// and this node's own variable</param>
        /// <param name="parents">the parent nodes</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="values"/> or <paramref name="parents"/> is null
        /// </exception>
        /// <exception cref="ArgumentException">
        /// the number of supplied values does not match the expected table size
        /// </exception>
        public BayesianNode(string name, string[] domain, double[] values, params BayesianNode[] parents)
        {
            if (values == null)
            {
                throw new ArgumentNullException("values");
            }
            if (parents == null)
            {
                throw new ArgumentNullException("parents");
            }

            this.var      = new RandomVariable(name, domain);
            this.children = new List <BayesianNode>();

            // Table variables: every parent variable (in the given order) followed by this node's variable.
            RandomVariable[] vars = parents
                                    .Select <BayesianNode, RandomVariable>(p => p.var)
                                    .Concat(new RandomVariable[] { this.var })
                                    .ToArray();

            // Validating the node configuration
            int numberOfValues = vars.Aggregate <RandomVariable, int>(1, (acc, next) => acc * next.tokens.Length);

            if (numberOfValues != values.Length)
            {
                throw new ArgumentException("The expect number of values for node " + name + " is " + numberOfValues +
                                            ". The actual number of value given is " + values.Length);
            }

            this.cpt = new CPT(vars, values);

            // Mutate the parents last: everything above may throw, and the parents must not
            // end up referencing a node whose construction failed.
            foreach (BayesianNode p in parents)
            {
                p.children.Add(this);
            }
        }
Exemplo n.º 2
0
 /// <summary>
 /// Groups the entries of <paramref name="factor"/> by the sub-key obtained from
 /// <c>entry.key.Extract(commVarsIndices)</c> and appends each entry to the matching
 /// list in <paramref name="mapping"/>.
 /// </summary>
 /// <param name="factor">the table whose <c>cpt</c> entries are categorised</param>
 /// <param name="commVarsIndices">indices passed to <c>CptKey.Extract</c> to build the grouping key
 /// (presumably the key positions of the shared variables; confirm against the caller)</param>
 /// <param name="mapping">dictionary that receives the groups; existing groups are extended, missing ones are created</param>
 private void CategoriseEntries(CPT factor, int[] commVarsIndices, Dictionary <CptKey, List <CptEntry> > mapping)
 {
     foreach (CptEntry entry in factor.cpt)
     {
         CptKey key = entry.key.Extract(commVarsIndices);

         // Single lookup instead of ContainsKey + Add + indexer.
         List <CptEntry> bucket;
         if (!mapping.TryGetValue(key, out bucket))
         {
             bucket = new List <CptEntry>();
             mapping.Add(key, bucket);
         }
         bucket.Add(entry);
     }
 }
Exemplo n.º 3
0
        /// <summary>
        /// Builds the product of this table and <paramref name="other"/>.
        /// </summary>
        /// <remarks>
        /// The variables of the result are this table's variables followed by the variables of
        /// <paramref name="other"/> that this table does not contain. Every entry of this table is
        /// paired with every entry of <paramref name="other"/> that yields the same sub-key on the
        /// shared variables; the resulting entry's value is the product of the two values.
        /// </remarks>
        /// <param name="other">the table to multiply with</param>
        /// <returns>a new table over the combined variables</returns>
        /// <exception cref="KeyNotFoundException">
        /// an entry of this table has a shared-variable sub-key for which <paramref name="other"/> has no entry
        /// </exception>
        internal CPT PointWiseProduct(CPT other)
        {
            // Intersect is lazily evaluated; materialise it once so it is not recomputed on every
            // use below. The order (that of this.vars) is preserved by ToArray.
            RandomVariable[]         commonVars   = this.vars.Intersect(other.vars).ToArray();
            HashSet <RandomVariable> commonVarSet = new HashSet <RandomVariable>(commonVars);

            int[] commVarsIndicesInLeft  = commonVars.Select <RandomVariable, int>(v => keyPosMap[v]).ToArray();
            int[] commVarsIndicesInRight = commonVars.Select <RandomVariable, int>(v => other.keyPosMap[v]).ToArray();

            // Result variables: everything on the left, then the right-hand variables that are not shared.
            List <RandomVariable> varsInResultCPT   = new List <RandomVariable>(this.vars);
            List <RandomVariable> rightVarsInResult = new List <RandomVariable>(other.vars);

            rightVarsInResult.RemoveAll(v => commonVarSet.Contains(v));
            varsInResultCPT.AddRange(rightVarsInResult);

            Dictionary <RandomVariable, int> newKeyPosMap = new Dictionary <RandomVariable, int>();

            for (int i = 0; i < varsInResultCPT.Count; i++)
            {
                newKeyPosMap.Add(varsInResultCPT[i], i);
            }

            int newCptSize = varsInResultCPT.Aggregate <RandomVariable, int>(1, (acc, v) => acc * v.tokens.Length);

            CptEntry[] newCptEntries = new CptEntry[newCptSize];

            // Index the right-hand entries by their shared-variable sub-key so each left entry
            // can find its matching partners with one lookup.
            Dictionary <CptKey, List <CptEntry> > mapping = new Dictionary <CptKey, List <CptEntry> >();

            CategoriseEntries(other, commVarsIndicesInRight, mapping);

            int cptIndex = 0;

            foreach (CptEntry leftEntry in this.cpt)
            {
                CptKey          key   = leftEntry.key.Extract(commVarsIndicesInLeft);
                List <CptEntry> right = mapping[key];
                foreach (CptEntry rightEntry in right)
                {
                    // The shared positions are removed from the right key before it is appended
                    // to the left key, matching the variable order built above.
                    CptKey newKey = leftEntry.key.Concat(rightEntry.key.Remove(commVarsIndicesInRight));
                    newCptEntries[cptIndex] = new CptEntry(newKey, leftEntry.value * rightEntry.value);
                    cptIndex++;
                }
            }

            return(new CPT(varsInResultCPT.ToArray(), newKeyPosMap, newCptEntries));
        }
Exemplo n.º 4
0
        /// <summary>
        /// Given a BayesianNode object and an array of Proposition objects as observations,
        /// this function performs inference and returns the inferred distribution of the node variable.
        /// </summary>
        /// <remarks>
        /// Nodes are visited in reverse topological order. A factor is created for each relevant
        /// node; for every node that is neither the query nor an observed variable, all factors
        /// containing its variable are multiplied together and the variable is summed out.
        /// The remaining factors are multiplied and converted to a distribution.
        /// </remarks>
        /// <param name="query">the BayesianNode object being queried</param>
        /// <param name="observations">the array of Proposition objects</param>
        /// <returns>a distribution table for the values of the query node</returns>
        /// <exception cref="Exception">
        /// any failure during inference; the original error is available through
        /// <see cref="Exception.InnerException"/>
        /// </exception>
        public double[] Infer(BayesianNode query, params Proposition[] observations)
        {
            try {
                string[] evidenceNames = observations.Select(o => o.name).ToArray();

                // Materialise the ordering once so the sort is not recomputed each time it is enumerated.
                IEnumerable <BayesianNode> nodes = TopologicalSort(network.GetNodes()).Reverse().ToList();

                HashSet <BayesianNode> relevantNodes = new HashSet <BayesianNode>();
                relevantNodes.Add(query);
                foreach (Proposition p in observations)
                {
                    relevantNodes.Add(network.FindNode(p.name));
                }
                MarkRelevantVariables(relevantNodes, nodes);
                // Remove variables that are not ancestor of a query variable or evidence variable
                nodes = nodes.Intersect(relevantNodes);

                List <CPT> factors = new List <CPT>();
                foreach (BayesianNode node in nodes)
                {
                    factors.Add(node.MakeFactor(observations));

                    // Query and evidence variables are kept; only the remaining variables are summed out.
                    if (node.var.name.Equals(query.var.name) || Array.IndexOf(evidenceNames, node.var.name) >= 0)
                    {
                        continue;
                    }

                    CPT[] factorsToSumOut = factors.Where(f => f.ContainsVar(node.var)).ToArray();
                    factors.RemoveAll(f => Array.IndexOf(factorsToSumOut, f) != -1);

                    // Seedless Aggregate returns the single element unchanged, folds left to right
                    // for several, and throws InvalidOperationException when the sequence is empty.
                    CPT factorToSumOut = factorsToSumOut.Aggregate((acc, f) => acc.PointWiseProduct(f));

                    CPT afterSumOut = factorToSumOut.SumOut(node.var);
                    factors.Add(afterSumOut);
                }

                CPT result = factors.Aggregate((acc, f) => acc.PointWiseProduct(f));
                return(result.Distribution());
            } catch (Exception e)
            {
                // Pass the original exception along so its type and full chain stay available to callers.
                throw new Exception("Unable to perform inference on the network. " +
                                    "Please make sure the network is valid and the propositions are valid. " +
                                    "Contact the author of this library if you suspect it is a problem in the library." +
                                    "Actual Error (Reason: " + e.Message + " " + e.StackTrace + ")",
                                    e
                                    );
            }
        }