Beispiel #1
0
 /// <summary>
 /// Returns the bitwise AND of this set and <paramref name="varset"/> as a new Varset.
 /// </summary>
 /// <param name="varset">The other operand. It is not modified.</param>
 /// <returns>A new Varset; this instance is not modified.</returns>
 public Varset And(Varset varset)
 {
     Varset result = new Varset(this);

     // AlignLength resizes whichever operand is shorter. Align against a private
     // copy so the caller's varset is never lengthened as a side effect
     // (Equals and LessThan copy both operands for the same reason).
     Varset other = new Varset(varset);
     result.AlignLength(other);

     // BitArray.And works in place on result.item
     result.item.And(other.item);
     return result;
 }
Beispiel #2
0
 /// <summary>
 /// Builds the whole tree and stores its top node in <c>root</c>.
 /// </summary>
 private void CreateTree()
 {
     // start with every one of the recordCount records selected
     BitArray allRecords = new BitArray(recordCount, true);

     // fresh Varset sized to the number of network variables
     Varset initialVariables = new Varset(network.Size());

     root = MakeADTree(0, allRecords, 0, initialVariables);
 }
Beispiel #3
0
        /// <summary>
        /// Recursively builds the AD-node for the records flagged in <paramref name="recordNums"/>.
        /// </summary>
        /// <param name="i">Index of the first variable this node still has to vary.</param>
        /// <param name="recordNums">One bit per record; set bits select the records of this node.</param>
        /// <param name="depth">Passed through unchanged to <c>MakeVaryNode</c>.</param>
        /// <param name="variables">Variables recorded so far. It is not modified.</param>
        /// <returns>
        /// The new node. When it has fewer than <c>rMin</c> records it carries a copy of
        /// <paramref name="recordNums"/> as its leaf list and gets no children.
        /// </returns>
        private ADNode MakeADTree(int i, BitArray recordNums, int depth, Varset variables)
        {
            // count the records selected for this node
            int count = 0;
            for (int idx = 0; idx < recordNums.Count; idx++)
            {
                if (recordNums[idx])
                {
                    count += 1;
                }
            }

            // since this is index i, there are (variableCount - i) remaining variables,
            // therefore the node has that many children
            ADNode adn = new ADNode(network.Size() - i, count);

            // check if we should just use a leaf list
            if (adn.Count < rMin)
            {
                adn.LeafList = new BitArray(recordNums);
                return adn;
            }

            // Varset is a reference type: work on a private copy so that setting bits
            // here neither changes the caller's object nor leaks into sibling subtrees
            // that were handed the same instance.
            Varset seen = new Varset(variables);

            // for each of the remaining variables
            for (int j = i; j < network.Size(); j++)
            {
                // create a vary node
                seen.Set(j, true);
                Varset newVariables = new Varset(seen);
                VaryNode child = MakeVaryNode(j, recordNums, depth, newVariables);
                adn.SetChild(j - i, child);
            }

            return adn;
        }
Beispiel #4
0
        /// <summary>
        /// Stores the network and caches the values derived from it and from the record file.
        /// </summary>
        /// <param name="network">Network whose variables the tree is built over.</param>
        /// <param name="recordFile">Source of the record count and of the consistent-record sets.</param>
        private void Initialize(BayesianNetwork network, RecordFile recordFile)
        {
            this.network = network;

            // number of records in the data set
            recordCount = recordFile.Size();

            // fresh Varset sized to the network; used as the "no variables left" reference
            zero = new Varset(network.Size());

            // record sets obtained from the network for this record file
            consistentRecords = network.GetConsistentRecords(recordFile);
        }
 /// <summary>
 /// Computes (cardinality of <paramref name="variable"/> - 1) multiplied by the
 /// cardinality of every variable contained in <paramref name="parents"/>.
 /// </summary>
 /// <param name="variable">Index of the child variable.</param>
 /// <param name="parents">Set of parent variable indices.</param>
 /// <returns>The resulting product.</returns>
 private double T(int variable, Varset parents)
 {
     double product = network.GetCardinality(variable) - 1;

     for (int candidate = 0; candidate < network.Size(); candidate++)
     {
         // only members of the parent set contribute a factor
         if (!parents.Get(candidate))
         {
             continue;
         }

         product *= network.GetCardinality(candidate);
     }

     return product;
 }
Beispiel #6
0
 /// <summary>
 /// Adds <paramref name="varset"/> to this set, treating both as binary numbers with
 /// bit 0 as the least significant bit, using repeated XOR (partial sum) and
 /// AND-then-shift (carry).
 /// </summary>
 /// <param name="varset">The value to add.</param>
 /// <returns>A new Varset holding the sum.</returns>
 public Varset Add(Varset varset)
 {
     Varset sum = new Varset(this);
     Varset none = new Varset(0);
     sum.AlignLength(varset);

     // "pending" is what still has to be added: first the operand, then the carries
     Varset pending = varset;
     while (!pending.Equals(none))
     {
         // positions where both have a bit set produce a carry into the next position
         Varset carry = sum.And(pending).LeftShift(1);
         sum = sum.Xor(pending);
         pending = carry;
     }

     return sum;
 }
        /// <summary>
        /// Walks the contingency table depth-first. At every leaf it adds the leaf's value
        /// to the count stored under the current parent-instantiation index and adds
        /// <c>ilogi[value]</c> to <paramref name="score"/>.
        /// </summary>
        /// <param name="ct">Current contingency table node.</param>
        /// <param name="currentBase">Multiplier applied to the value index of the next parent variable.</param>
        /// <param name="index">Parent-instantiation index accumulated so far.</param>
        /// <param name="paCounts">Running counts per parent-instantiation index; updated in place.</param>
        /// <param name="variable">The child variable; it does not contribute to the index.</param>
        /// <param name="variables">The variables present in the table.</param>
        /// <param name="previousVariable">Variable handled one level up, or -1 at the top.</param>
        /// <param name="score">Accumulator updated at each leaf.</param>
        private void Calculate(ContingencyTableNode ct, ulong currentBase, ulong index, Dictionary<ulong, int> paCounts, int variable, Varset variables, int previousVariable, ref double score)
        {
            if (ct.IsLeaf())
            {
                // add this leaf to the count for its parent instantiation (0 when not seen yet)
                int seen;
                paCounts.TryGetValue(index, out seen);
                paCounts[index] = seen + ct.Value;

                // update the score for this variable / parent instantiation
                score += ilogi[ct.Value];
                return;
            }

            // the variable at this level is the next one after previousVariable that is in the set
            int thisVariable = previousVariable + 1;
            while (thisVariable < network.Size() && !variables.Get(thisVariable))
            {
                thisVariable++;
            }

            // only parents (everything except the child variable) extend the index
            bool isParent = thisVariable != variable;
            ulong nextBase = isParent
                ? currentBase * (ulong)network.GetCardinality(thisVariable)
                : currentBase;

            // recurse into every child that exists
            for (int k = 0; k < network.GetCardinality(thisVariable); k++)
            {
                ContingencyTableNode child = ct.GetChild(k);
                if (child == null)
                {
                    continue;
                }

                ulong newIndex = isParent ? index + currentBase * (ulong)k : index;
                Calculate(child, nextBase, newIndex, paCounts, variable, variables, thisVariable, ref score);
            }
        }
        /// <summary>
        /// Removes from <paramref name="cache"/> every entry whose parent set is a superset
        /// of a parent set that comes earlier in the order produced by <c>Comparison</c>.
        /// </summary>
        /// <remarks>
        /// Entries whose cardinality exceeds <c>highestCompletedLayer</c> are never used
        /// as the pruning entry; they are only marked, not removed from the cache.
        /// </remarks>
        /// <param name="cache">Map from parent set (as ulong) to score; modified in place.</param>
        public void Prune(DoubleMap cache)
        {
            // snapshot the cache so it can be sorted and so removals do not disturb iteration
            List<KeyValuePair<ulong, double>> entries = new List<KeyValuePair<ulong, double>>();
            foreach (KeyValuePair<ulong, double> entry in cache)
            {
                entries.Add(entry);
            }
            entries.Sort(Comparison);

            // keep track of the ones that have been pruned
            int total = entries.Count;
            BitArray discarded = new BitArray(total);

            for (int outer = 0; outer < total; outer++)
            {
                if (discarded[outer])
                {
                    continue;
                }

                Varset smaller = new Varset(entries[outer].Key);

                // make sure this variable set is not in an incomplete last layer
                if (smaller.Cardinality() > highestCompletedLayer)
                {
                    discarded[outer] = true;
                    continue;
                }

                for (int inner = outer + 1; inner < total; inner++)
                {
                    if (discarded[inner])
                    {
                        continue;
                    }

                    Varset larger = new Varset(entries[inner].Key);

                    // subset test: smaller AND larger == smaller
                    bool isSubset = smaller.And(larger).Equals(smaller);
                    if (!isSubset)
                    {
                        continue;
                    }

                    // then we can prune the later entry
                    discarded[inner] = true;
                    cache.Remove(larger.ToULong());
                }
            }
        }
        /// <summary>
        /// Builds the contingency table over <paramref name="parents"/> plus
        /// <paramref name="variable"/>, sums <c>ilogi</c> over its leaves, and subtracts
        /// <c>ilogi</c> of every count in <paramref name="paCounts"/>.
        /// </summary>
        /// <param name="variable">Index of the child variable.</param>
        /// <param name="parents">
        /// Parent set. The bit for <paramref name="variable"/> is set for the duration of
        /// the call and cleared again before returning, including when an exception is thrown.
        /// </param>
        /// <param name="paCounts">Receives the count per parent instantiation.</param>
        /// <returns>The accumulated score.</returns>
        public double Calculate(int variable, Varset parents, Dictionary<ulong, int> paCounts)
        {
            // temporarily add the child variable so the table covers it as well
            parents.Set(variable, true);
            try
            {
                double score = 0;

                ContingencyTableNode ct = adTree.MakeContab(parents);

                Calculate(ct, 1, 0, paCounts, variable, parents, -1, ref score);

                foreach (KeyValuePair<ulong, int> pair in paCounts)
                {
                    score -= ilogi[pair.Value];
                }

                return score;
            }
            finally
            {
                // always undo the temporary change to the caller's Varset
                parents.Set(variable, false);
            }
        }
        /// <summary>
        /// Scores <paramref name="parents"/> as the parent set of <paramref name="variable"/>:
        /// the value returned by <c>llc.Calculate</c> minus
        /// <c>T(variable, parents) * baseComplexityPenalty</c>.
        /// </summary>
        /// <remarks>
        /// Before scoring, every subset obtained by removing one parent is looked up in
        /// <paramref name="cache"/> (a missing entry counts as 0). If that value plus the
        /// T value is positive, the method returns 0 without computing a score.
        /// </remarks>
        /// <param name="variable">Index of the child variable.</param>
        /// <param name="parents">Candidate parent set; a copy is used internally.</param>
        /// <param name="cache">Scores already stored for this variable, keyed by parent set.</param>
        /// <returns>The score, or 0 when the pruning check applies.</returns>
        public override double CalculateScore(int variable, Varset parents, DoubleMap cache)
        {
            Varset working = new Varset(parents);

            // TODO check whether this parent set violates the constraints

            // check for pruning
            double tVal = T(variable, working);

            for (int x = 0; x < network.Size(); x++)
            {
                if (!working.Get(x))
                {
                    continue;
                }

                // look at the subset without parent x
                working.Set(x, false);

                // TODO skip subsets that were rejected because of constraints

                ulong subsetKey = working.ToULong();
                double subsetScore = 0;
                if (cache.ContainsKey(subsetKey))
                {
                    subsetScore = cache[subsetKey];
                }

                if (subsetScore + tVal > 0)
                {
                    return 0;
                }

                // restore parent x before trying the next one
                working.Set(x, true);
            }

            double likelihood = llc.Calculate(variable, working);

            // structure penalty
            return likelihood - tVal * baseComplexityPenalty;
        }
Beispiel #11
0
        /// <summary>
        /// Builds the vary node for variable <paramref name="i"/>: splits the selected
        /// records by the value of that variable and creates an AD-node child for every
        /// value except the most common one and those with no records.
        /// </summary>
        /// <param name="i">Index of the variable being varied.</param>
        /// <param name="recordNums">One bit per record; set bits select the records to split.</param>
        /// <param name="depth">Current depth; children are built with depth + 1.</param>
        /// <param name="variables">Passed through to <c>MakeADTree</c>.</param>
        /// <returns>The vary node with <c>Mcv</c> set to the value holding the most records.</returns>
        private VaryNode MakeVaryNode(int i, BitArray recordNums, int depth, Varset variables)
        {
            // this node will have variableCardinalities[i] children
            int cardinality = network.GetCardinality(i);
            VaryNode vn = new VaryNode(cardinality);

            // split into childNums, remembering each child's size so that the
            // record sets are scanned only once
            List<BitArray> childNums = new List<BitArray>();
            int[] childCounts = new int[cardinality];

            int mcv = -1;
            int mcvCount = -1;
            for (int k = 0; k < cardinality; k++)
            {
                // records selected for this node AND consistent with value k of variable i
                // (BitArray.Or / BitArray.And modify the instance in place)
                BitArray records = new BitArray(recordCount);
                records.Or(recordNums);
                records.And(consistentRecords[i][k]);
                childNums.Add(records);

                int count = 0;
                for (int idx = 0; idx < records.Count; idx++)
                {
                    if (records[idx])
                    {
                        count += 1;
                    }
                }
                childCounts[k] = count;

                // also look for the mcv; on ties the first value wins
                if (count > mcvCount)
                {
                    mcv = k;
                    mcvCount = count;
                }
            }

            // update the mcv
            vn.Mcv = mcv;

            // otherwise, recurse
            for (int k = 0; k < cardinality; k++)
            {
                // the most common value and empty values get no child
                if (k == mcv || childCounts[k] == 0)
                {
                    continue;
                }

                ADNode child = MakeADTree(i + 1, childNums[k], depth + 1, variables);
                vn.SetChild(k, child);
            }

            return vn;
        }
Beispiel #12
0
        /// <summary>
        /// Recursively builds the contingency table for <paramref name="remainingVariables"/>
        /// below <paramref name="node"/>.
        /// </summary>
        /// <remarks>
        /// The table for the most common value is not read from the tree: it is obtained by
        /// building the table over the remaining variables at the same node and subtracting
        /// every child table that was built explicitly.
        /// </remarks>
        /// <param name="remainingVariables">Variables still to be expanded.</param>
        /// <param name="node">AD-node to read from.</param>
        /// <param name="nodeIndex">Variable index associated with <paramref name="node"/> (-1 for the root).</param>
        /// <returns>The contingency table node for this level.</returns>
        private ContingencyTableNode MakeContab(Varset remainingVariables, ADNode node, int nodeIndex)
        {
            // base case: nothing left to expand, so this is a leaf carrying the node's count
            if (remainingVariables.Equals(zero))
            {
                return new ContingencyTableNode(node.Count, 0, 1);
            }

            int firstIndex = remainingVariables.FindFirst();
            int n = network.GetCardinality(firstIndex);
            VaryNode vn = node.GetChild(firstIndex - nodeIndex - 1);
            ContingencyTableNode ct = new ContingencyTableNode(0, n, 0);
            Varset newVariables = Varset.ClearCopy(remainingVariables, firstIndex);

            // starts as the table over all records of this node; the explicit children
            // are subtracted from it below
            ContingencyTableNode ctMcv = MakeContab(newVariables, node, nodeIndex);

            for (int k = 0; k < n; k++)
            {
                ADNode adn = vn.GetChild(k);
                if (adn == null)
                {
                    continue;
                }

                // NOTE(review): an empty LeafList is taken to mean "not a leaf-list node";
                // confirm that ADNode initialises LeafList to an empty BitArray.
                ContingencyTableNode child = adn.LeafList.Count == 0
                    ? MakeContab(newVariables, adn, firstIndex)
                    : MakeContabLeafList(newVariables, adn.LeafList);

                ct.SetChild(k, child);
                ct.LeafCount += child.LeafCount;

                ctMcv.Subtract(ct.GetChild(k));
            }

            ct.SetChild(vn.Mcv, ctMcv);
            ct.LeafCount += ctMcv.LeafCount;

            return ct;
        }
Beispiel #13
0
        /// <summary>
        /// Builds a contingency table for <paramref name="variables"/> directly from a set
        /// of records, without using the tree.
        /// </summary>
        /// <param name="variables">Variables still to be expanded; not modified.</param>
        /// <param name="records">One bit per record; set bits select the records to count.</param>
        /// <returns>
        /// A leaf holding the number of selected records when no variables remain;
        /// otherwise a node with one child per value that has at least one record.
        /// </returns>
        private ContingencyTableNode MakeContabLeafList(Varset variables, BitArray records)
        {
            Varset pending = new Varset(variables);

            // base case: count the selected records
            if (pending.Equals(zero))
            {
                int selected = 0;
                for (int pos = 0; pos < records.Count; pos++)
                {
                    if (records[pos])
                    {
                        selected += 1;
                    }
                }
                return new ContingencyTableNode(selected, 0, 1);
            }

            int firstIndex = variables.FindFirst();
            int cardinality = network.GetCardinality(firstIndex);
            ContingencyTableNode ct = new ContingencyTableNode(0, cardinality, 0);

            // the children expand everything except the variable handled here
            pending.Set(firstIndex, false);
            Varset remainingVariables = new Varset(pending);

            for (int k = 0; k < cardinality; k++)
            {
                // selected records that are also consistent with value k of the variable
                // (BitArray.Or / BitArray.And modify the instance in place)
                BitArray matching = new BitArray(recordCount);
                matching.Or(records);
                matching.And(consistentRecords[firstIndex][k]);

                int matches = 0;
                for (int pos = 0; pos < matching.Count; pos++)
                {
                    if (matching[pos])
                    {
                        matches += 1;
                    }
                }

                // values without any record get no child
                if (matches <= 0)
                {
                    continue;
                }

                ContingencyTableNode child = MakeContabLeafList(remainingVariables, matching);
                ct.SetChild(k, child);
                ct.LeafCount += child.LeafCount;
            }

            return ct;
        }
Beispiel #14
0
 /// <summary>
 /// Returns a copy of this set with every bit moved <paramref name="i"/> positions
 /// towards higher indices. The copy is <paramref name="i"/> bits longer, so no bit
 /// is lost; the vacated low positions are cleared.
 /// </summary>
 /// <param name="i">Number of positions to shift; must not be negative.</param>
 /// <returns>A new Varset; this instance is not modified.</returns>
 /// <exception cref="ArgumentOutOfRangeException"><paramref name="i"/> is negative.</exception>
 public Varset LeftShift(int i)
 {
     if (i < 0)
     {
         throw new ArgumentOutOfRangeException("i", "Shift count must not be negative.");
     }

     Varset cp = new Varset(this);
     cp.item.Length += i;

     // move from the top down so that no source bit is overwritten before it is read;
     // stop at k == i because positions below i have no source bit
     for (int k = cp.item.Length - 1; k >= i; k--)
     {
         cp.item[k] = cp.item[k - i];
     }

     // clear all i vacated positions (none when i is 0)
     for (int k = 0; k < i; k++)
     {
         cp.item[k] = false;
     }

     return cp;
 }
Beispiel #15
0
        /// <summary>
        /// Computes parent-set scores for every variable of the data set in
        /// <c>args[0]</c> and writes them to <c>args[1]</c>.
        /// </summary>
        /// <remarks>
        /// The scores of variable N are first written to the file "args[1].N"; afterwards
        /// a META header followed by the contents of all those files is written to
        /// <c>args[1]</c>. The per-variable files are left on disk.
        /// </remarks>
        /// <param name="args">Exactly two entries: the input file and the output file.</param>
        /// <exception cref="ArgumentException">The scoring function name is not recognised.</exception>
        /// <exception cref="NotSupportedException">The fNML scoring function is requested.</exception>
        public void Execute(string[] args)
        {
            // check the arguments
            if (args.Length != 2)
            {
                // print the error message and stop
                Console.WriteLine("[エラー] 引数の数が不正です");
                Environment.Exit(0);
            }

            // options; for now they are fixed to these initial values
            int rMin = 5; // The minimum number of records in the AD-tree nodes.
            char delimiter = ',';
            bool hasHeader = true;
            string sf = "BDeu";
            int maxParents = 0;
            string constraintsFile = "";
            int runningTime = -1;
            int threadCount = 20;
            bool prune = true;
            string inputFile = args[0];
            string outputFile = args[1];
            double equivarentSampleSize = 1;

            // read the csv file
            RecordFile recordFile = new RecordFile();
            recordFile.ReadRecord(inputFile, hasHeader, delimiter);

            // initialise the Bayesian network
            BayesianNetwork network = new BayesianNetwork(recordFile);

            // initialise the AD-tree
            ADTree adTree = new ADTree(rMin, network, recordFile);

            // configure the scoring function
            sf = sf.ToLower();
            if (maxParents > network.Size() || maxParents < 1)
            {
                maxParents = network.Size() - 1;
            }

            if (sf == "bic")
            {
                int maxParentCount = (int)Math.Log(2 * recordFile.Size() / Math.Log(recordFile.Size()));
                if (maxParentCount < maxParents)
                {
                    maxParents = maxParentCount;
                }
            }
            else if (sf == "fnml")
            {
                // No fNML scoring function is constructed below, so continuing would pass
                // a null scoring function to ScoreCalculator. Fail here with a clear message.
                throw new NotSupportedException("The 'fNML' scoring function is not implemented.");
            }
            else if (sf == "bdeu") { }
            else
            {
                throw new ArgumentException("Invalid scoring function. Options are: 'BIC', 'fNML' or 'BDeu'.");
            }

            // TODO implement constraints
            Constraints constraints = null;
            if (constraintsFile.Length > 0)
            {
                constraints = Constraints.ParseConstraints(constraintsFile, network);
            }

            ScoringFunction scoringFunction = null;

            List<double> ilogi = LogLikelihoodCalculator.GetLogCache(recordFile.Size());
            LogLikelihoodCalculator llc = new LogLikelihoodCalculator(adTree, network, ilogi);

            // TODO implement the regret cache (needed for fNML)

            if (sf == "bic")
            {
                scoringFunction = new BICScoringFunction(network, recordFile, llc, constraints);
            }
            else if (sf == "bdeu")
            {
                scoringFunction = new BDeuScoringFunction(equivarentSampleSize, network, adTree, constraints);
            }

            ScoreCalculator scoreCalculator = new ScoreCalculator(scoringFunction, maxParents, network.Size(), runningTime, constraints);

            // worker i handles the variables whose index modulo threadCount equals i
            Parallel.For(0, threadCount, i =>
            {
                for (int variable = 0; variable < network.Size(); variable++)
                {
                    if (variable % threadCount != i)
                    {
                        continue;
                    }

                    Console.WriteLine("Thread: " + i + ", Variable: " + variable + ", Time: " + DateTime.Now);

                    DoubleMap sc = new DoubleMap();
                    scoreCalculator.CalculateScores(variable, sc);

                    int size = sc.Count;
                    Console.WriteLine("Thread: " + i + ", Variable: " + variable + ", Size before pruning: " + size + ", Time: " + DateTime.Now);

                    if (prune)
                    {
                        scoreCalculator.Prune(sc);
                        int prunedSize = sc.Count;
                        Console.WriteLine("Thread: " + i + ", Variable: " + variable + ", Size after pruning: " + prunedSize + ", Time: " + DateTime.Now);
                    }

                    string varFilename = outputFile + "." + variable;

                    // "using" closes the file even when writing throws
                    using (StreamWriter writer = new StreamWriter(varFilename, false))
                    {
                        Variable info = network.Get(variable);
                        writer.Write("VAR " + info.Name + "\n");
                        writer.Write("META arity=" + info.GetCardinality() + "\n");
                        writer.Write("META values=");
                        for (int k = 0; k < info.GetCardinality(); k++)
                        {
                            writer.Write(info.GetValue(k) + " ");
                        }
                        writer.Write("\n");

                        // one line per stored parent set: the score followed by the parent names
                        foreach (KeyValuePair<ulong, double> kvp in sc)
                        {
                            Varset parentSet = new Varset(kvp.Key);
                            double s = kvp.Value;

                            writer.Write(s + " ");

                            for (int p = 0; p < network.Size(); p++)
                            {
                                if (parentSet.Get(p))
                                {
                                    writer.Write(network.Get(p).Name + " ");
                                }
                            }
                            writer.Write("\n");
                        }
                        writer.Write("\n");
                    }
                }
            });

            // first, the header information
            string header = "META pss_version = 0.1\nMETA input_file=" + inputFile + "\nMETA num_records=" + recordFile.Size() + "\n";
            header += "META parent_limit=" + maxParents + "\nMETA score_type=" + sf + "\nMETA ess=" + equivarentSampleSize + "\n\n";

            // concatenate all of the files together
            // NOTE(review): the original author flagged this step as possibly needing revision
            using (StreamWriter outFile = new StreamWriter(outputFile, false))
            {
                outFile.Write(header);

                for (int variable = 0; variable < network.Size(); variable++)
                {
                    string varFilename = outputFile + "." + variable;
                    using (StreamReader reader = new StreamReader(varFilename))
                    {
                        outFile.Write(reader.ReadToEnd());
                    }
                }
            }
        }
Beispiel #16
0
        /// <summary>
        /// Scores the empty parent set and then, layer by layer (layer = number of parents,
        /// from 1 to <c>maxParents</c>), every parent set of <paramref name="variable"/>
        /// produced by <c>NextPermutation</c>, storing the accepted scores in
        /// <paramref name="cache"/>.
        /// </summary>
        /// <remarks>
        /// Stops early when <c>outOfTime</c> becomes true. <c>highestCompletedLayer</c> is
        /// updated after every layer that finished without running out of time.
        /// </remarks>
        /// <param name="variable">Index of the child variable.</param>
        /// <param name="cache">Receives the scores, keyed by parent set as ulong.</param>
        private void CalculateScoresInternal(int variable, DoubleMap cache)
        {
            // calculate initial score
            // note: the C++ original constructs these sets with 0
            Varset noParents = new Varset(variableCount + 1);
            double score = scoringFunction.CalculateScore(variable, noParents, cache);

            if (score < 1)
            {
                cache[noParents.ToULong()] = score;
            }

            int prunedCount = 0;

            int layer = 1;
            while (layer <= maxParents && !outOfTime)
            {
                // DEBUG
                Console.WriteLine("layer: " + layer + ", prunedCount: " + prunedCount);

                // first candidate of the layer: the lowest "layer" bits set
                // note: the C++ original constructs this set with 0
                Varset candidate = new Varset(variableCount + 1);
                for (int bit = 0; bit < layer; bit++)
                {
                    candidate.Set(bit, true);
                }

                // exclusive upper bound: only bit variableCount set
                Varset limit = new Varset(variableCount);
                limit.Set(variableCount, true);

                while (candidate.LessThan(limit) && !outOfTime)
                {
                    // a variable cannot be its own parent
                    if (!candidate.Get(variable))
                    {
                        score = scoringFunction.CalculateScore(variable, candidate, cache);

                        if (score < 0)
                        {
                            cache[candidate.ToULong()] = score;
                        }
                        else
                        {
                            prunedCount++;
                        }
                    }

                    candidate = candidate.NextPermutation();
                }

                if (!outOfTime)
                {
                    highestCompletedLayer = layer;
                }

                layer++;
            }
        }
Beispiel #17
0
 /// <summary>
 /// Adds <paramref name="varset"/> to this set and resizes the result back to this
 /// set's original length.
 /// </summary>
 /// <param name="varset">The value to add.</param>
 /// <returns>A new Varset with the same length as this one.</returns>
 public Varset StaticAdd(Varset varset)
 {
     Varset snapshot = new Varset(this);

     // remember the length before Add has a chance to grow the result
     int width = snapshot.item.Count;

     Varset sum = snapshot.Add(varset);
     return sum.SubVarset(width);
 }
Beispiel #18
0
 /// <summary>
 /// Copy constructor: the new set gets its own BitArray with the same bits as
 /// <paramref name="varset"/>.
 /// </summary>
 /// <param name="varset">The set to copy.</param>
 public Varset(Varset varset)
 {
     BitArray source = varset.item;
     item = new BitArray(source);
 }
Beispiel #19
0
        /// <summary>
        /// Divides this set by <paramref name="varset"/>, treating both as binary numbers
        /// with bit 0 as the least significant bit, using shift-and-subtract long division.
        /// </summary>
        /// <param name="varset">The divisor.</param>
        /// <returns>
        /// The quotient, resized to this set's length. When this set has no bit set the
        /// result is an all-clear set, even if the divisor is zero as well.
        /// </returns>
        /// <exception cref="ArgumentException">
        /// The divisor has no bit set while this set has at least one.
        /// </exception>
        public Varset Divide(Varset varset)
        {
            int length = this.item.Count;
            Varset remainder = new Varset(this);
            Varset divisor = new Varset(varset);

            // mask marks the quotient bit that corresponds to the current divisor shift
            Varset mask = new Varset(remainder.item.Length);
            mask.Set(0, true);
            Varset quotient = new Varset(remainder.item.Count);
            Varset zero = new Varset(remainder.item.Count);

            if (remainder.Equals(zero))
            {
                return zero;
            }

            if (divisor.Equals(zero))
            {
                throw new ArgumentException("Zero Division.");
            }

            // shift the divisor up until it is strictly larger than the dividend
            while (divisor.LessThan(remainder) || divisor.Equals(remainder))
            {
                divisor = divisor.LeftShift(1);
                mask = mask.LeftShift(1);
            }

            Varset one = new Varset(remainder.item.Length);
            one.Set(0, true);

            // walk back down one bit at a time, subtracting wherever the divisor fits
            while (one.LessThan(mask))
            {
                divisor = divisor.RightShift(1);
                mask = mask.RightShift(1);
                if (divisor.LessThan(remainder) || divisor.Equals(remainder))
                {
                    remainder = remainder.Subtract(divisor);
                    quotient = quotient.Or(mask);
                }
            }

            quotient = quotient.SubVarset(length);

            // the remainder is resized as well but is not returned
            remainder = remainder.SubVarset(length);

            return quotient;
        }
Beispiel #20
0
 /// <summary>
 /// Returns a copy of this set with every bit inverted.
 /// </summary>
 /// <returns>A new Varset; this instance is not modified.</returns>
 public Varset Not()
 {
     Varset inverted = new Varset(this);

     // BitArray.Not flips the bits of the copy in place
     inverted.item.Not();
     return inverted;
 }
Beispiel #21
0
 /// <summary>
 /// Computes the successor of this set with the expression
 /// <c>t | ((((t AND -t) / (v AND -v)) shifted right by 1) - 1)</c>, where <c>v</c> is this
 /// set, <c>t = (v | (v - 1)) + 1</c>, and <c>-x</c> is formed as <c>NOT x + 1</c>.
 /// </summary>
 /// <remarks>
 /// NOTE(review): this looks like the usual bit trick for stepping to the next bit
 /// pattern with the same number of set bits - confirm against the callers.
 /// </remarks>
 /// <returns>A new Varset; this instance is not modified.</returns>
 public Varset NextPermutation()
 {
     Varset vs = new Varset(this);
     Varset one = new Varset(0);
     one.Set(0, true);

     // t = (v | (v - 1)) + 1
     Varset tmp = vs.Or(vs.Subtract(one)).StaticAdd(one);

     // t AND -t
     Varset tmpLowest = tmp.And(tmp.Not().StaticAdd(one));

     // v AND -v
     Varset vsLowest = vs.And(vs.Not().StaticAdd(one));

     // ((t AND -t) / (v AND -v)) shifted right by one, minus one
     Varset ratio = tmpLowest.Divide(vsLowest);
     Varset lowBits = ratio.RightShift(1).Subtract(one);

     return tmp.Or(lowBits);
 }
Beispiel #22
0
 /// <summary>
 /// Returns a copy of this set whose length is set to <paramref name="index"/>.
 /// </summary>
 /// <param name="index">The length of the returned set.</param>
 /// <returns>A new Varset; this instance is not modified.</returns>
 private Varset SubVarset(int index)
 {
     Varset resized = new Varset(this);

     // assigning BitArray.Length resizes the copy's bit storage
     resized.item.Length = index;
     return resized;
 }
Beispiel #23
0
 /// <summary>
 /// Makes this set and <paramref name="varset"/> the same length by growing whichever
 /// of the two is shorter. Nothing happens when the lengths already match.
 /// </summary>
 /// <param name="varset">The other set; it is lengthened in place when it is the shorter one.</param>
 private void AlignLength(Varset varset)
 {
     int mine = item.Count;
     int theirs = varset.item.Count;

     // Growing BitArray.Length appends cleared bits, so no further work is needed.
     // (The earlier version additionally OR-ed with a freshly allocated all-false
     // array, which changed nothing.)
     if (mine > theirs)
     {
         varset.item.Length = mine;
     }
     else if (mine < theirs)
     {
         item.Length = theirs;
     }
 }
Beispiel #24
0
 /// <summary>
 /// Subtracts <paramref name="varset"/> from this set, treating both as binary numbers,
 /// by adding the two's complement: <c>this + (NOT varset + 1)</c> with fixed-width adds.
 /// </summary>
 /// <param name="varset">The value to subtract; a copy is used internally.</param>
 /// <returns>A new Varset holding the difference.</returns>
 public Varset Subtract(Varset varset)
 {
     Varset negated = new Varset(varset);

     // widen the copy first when it is shorter than this set
     if (negated.item.Length < item.Length)
     {
         AlignLength(negated);
     }

     Varset unit = new Varset(1);
     unit.Set(0, true);

     // two's complement of the operand, then add this set
     return negated.Not().StaticAdd(unit).StaticAdd(this);
 }
Beispiel #25
0
 /// <summary>
 /// Builds the contingency table for <paramref name="variables"/>, starting at the
 /// root of the tree.
 /// </summary>
 /// <param name="variables">The variables the table should cover.</param>
 /// <returns>The top node of the contingency table.</returns>
 public ContingencyTableNode MakeContab(Varset variables)
 {
     // -1 is the node index used for the root
     ContingencyTableNode table = MakeContab(variables, root, -1);
     return table;
 }
 /// <summary>
 /// Convenience overload: runs the three-argument <c>Calculate</c> with a fresh, empty
 /// parent-count dictionary.
 /// </summary>
 /// <param name="variable">Index of the child variable.</param>
 /// <param name="parents">Parent set.</param>
 /// <returns>The value returned by the three-argument overload.</returns>
 public double Calculate(int variable, Varset parents)
 {
     Dictionary<ulong, int> freshCounts = new Dictionary<ulong, int>();
     double result = Calculate(variable, parents, freshCounts);
     return result;
 }
Beispiel #27
0
        /// <summary>
        /// Compares this set with <paramref name="varset"/> bit by bit after bringing
        /// copies of both to the same length.
        /// </summary>
        /// <param name="varset">The set to compare with; not modified.</param>
        /// <returns>True when every position holds the same bit in both sets.</returns>
        public bool Equals(Varset varset)
        {
            // work on copies so that aligning the lengths does not touch either operand
            Varset left = new Varset(this);
            Varset right = new Varset(varset);
            left.AlignLength(right);

            int position = 0;
            while (position < left.item.Length)
            {
                if (left.item[position] != right.item[position])
                {
                    return false;
                }
                position++;
            }

            return true;
        }
Beispiel #28
0
        /// <summary>
        /// Compares this set with <paramref name="varset"/> as binary numbers, scanning
        /// from the highest index down to index 0.
        /// </summary>
        /// <param name="varset">The set to compare with; not modified.</param>
        /// <returns>
        /// True when, at the highest position where the two differ, this set has the bit
        /// clear and <paramref name="varset"/> has it set; false otherwise, including when
        /// the sets are equal.
        /// </returns>
        public bool LessThan(Varset varset)
        {
            // work on copies so that aligning the lengths does not touch either operand
            Varset left = new Varset(this);
            Varset right = new Varset(varset);
            left.AlignLength(right);

            for (int bit = left.item.Length - 1; bit >= 0; bit--)
            {
                bool mine = left.item[bit];
                bool theirs = right.item[bit];

                // the first difference decides: we are smaller exactly when they hold the set bit
                if (mine != theirs)
                {
                    return theirs;
                }
            }

            return false;
        }
Beispiel #29
0
 /// <summary>
 /// Returns a copy of <paramref name="varset"/> with the bit at
 /// <paramref name="index"/> cleared.
 /// </summary>
 /// <param name="varset">The set to copy; not modified.</param>
 /// <param name="index">Position of the bit to clear in the copy.</param>
 /// <returns>The new Varset.</returns>
 public static Varset ClearCopy(Varset varset, int index)
 {
     Varset withoutBit = new Varset(varset);
     withoutBit.Set(index, false);
     return withoutBit;
 }
Beispiel #30
0
        /// <summary>
        /// Returns a copy of this set with every bit moved <paramref name="i"/> positions
        /// towards index 0. The length stays the same; the vacated high positions are
        /// cleared and bits shifted below index 0 are dropped.
        /// </summary>
        /// <param name="i">
        /// Number of positions to shift; must not be negative. A value of at least the
        /// set's length yields an all-clear set.
        /// </param>
        /// <returns>A new Varset; this instance is not modified.</returns>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="i"/> is negative.</exception>
        public Varset RightShift(int i)
        {
            if (i < 0)
            {
                throw new ArgumentOutOfRangeException("i", "Shift count must not be negative.");
            }

            Varset cp = new Varset(this);
            int length = cp.item.Count;

            // number of low positions that still receive a source bit; clamped to zero
            // so that shifting by more than the length no longer indexes below zero
            int kept = i < length ? length - i : 0;

            for (int k = 0; k < kept; k++)
            {
                cp.item[k] = cp.item[k + i];
            }

            for (int k = kept; k < length; k++)
            {
                cp.item[k] = false;
            }

            return cp;
        }