예제 #1
0
        /// <summary>
        /// Scores <paramref name="parents"/> as a parent set of <paramref name="variable"/>:
        /// the value of <c>llc.Calculate</c> minus the structure penalty
        /// <c>T(variable, parents) * baseComplexityPenalty</c>.
        /// </summary>
        /// <param name="variable">Index of the variable being scored.</param>
        /// <param name="parents">Candidate parent set. Only a private copy is modified.</param>
        /// <param name="cache">Scores stored so far, keyed by <c>Varset.ToULong()</c>.</param>
        /// <returns>The penalised score, or 0 when the pruning test below fires.</returns>
        public override double CalculateScore(int variable, Varset parents, DoubleMap cache)
        {
            // TODO: reject parent sets that violate the constraints (not implemented yet).

            // Work on a copy so the caller's set is never touched.
            Varset candidate = new Varset(parents);
            double penalty = T(variable, candidate);

            // Pruning: look at every subset obtained by dropping exactly one parent.
            for (int dropped = 0; dropped < network.Size(); dropped++)
            {
                if (!candidate.Get(dropped))
                {
                    continue;
                }

                candidate.Set(dropped, false);

                // TODO: skip subsets that were rejected because of constraints.

                // A subset that is not in the cache counts as score 0.
                ulong subsetKey = candidate.ToULong();
                double subsetScore = 0;
                if (cache.ContainsKey(subsetKey))
                {
                    subsetScore = cache[subsetKey];
                }

                if (subsetScore + penalty > 0)
                {
                    return 0;
                }

                candidate.Set(dropped, true);
            }

            double score = llc.Calculate(variable, candidate);

            // structure penalty
            score -= penalty * baseComplexityPenalty;

            return score;
        }
예제 #2
0
        /// <summary>
        /// Reads a delimited record file, computes parent-set scores for every variable of the
        /// resulting network in parallel, writes one intermediate file per variable
        /// (<c>outputFile.&lt;index&gt;</c>) and finally concatenates them, after a META header,
        /// into the output file.
        /// </summary>
        /// <param name="args">Exactly two entries: the input file path and the output file path.</param>
        /// <exception cref="ArgumentException">The scoring function name is not one of bic / fnml / bdeu.</exception>
        /// <exception cref="NotSupportedException">The fNML scoring function is selected; it is not implemented.</exception>
        public void Execute(string[] args)
        {
            // Validate the argument count.
            if (args.Length != 2)
            {
                // Report the problem and stop with a non-zero status so that callers can
                // tell the failure apart from a successful run.
                Console.WriteLine("[エラー] 引数の数が不正です");
                Environment.Exit(1);
            }

            // Options. Provisional hard-coded defaults until option parsing exists.
            int rMin = 5; // The minimum number of records in the AD-tree nodes.
            char delimiter = ',';
            bool hasHeader = true;
            string sf = "BDeu";
            int maxParents = 0;
            string constraintsFile = "";
            int runningTime = -1;
            int threadCount = 20;
            bool prune = true;
            string inputFile = args[0];
            string outputFile = args[1];
            double equivarentSampleSize = 1;

            // Load the CSV file.
            RecordFile recordFile = new RecordFile();
            recordFile.ReadRecord(inputFile, hasHeader, delimiter);

            // Initialise the Bayesian network.
            BayesianNetwork network = new BayesianNetwork(recordFile);

            // Initialise the AD-tree.
            ADTree adTree = new ADTree(rMin, network, recordFile);

            // Configure the scoring function.
            sf = sf.ToLower();
            if (maxParents > network.Size() || maxParents < 1)
            {
                maxParents = network.Size() - 1;
            }

            if (sf == "bic")
            {
                // For "bic" the parent limit is additionally capped by this record-count based value.
                int maxParentCount = (int)Math.Log(2 * recordFile.Size() / Math.Log(recordFile.Size()));
                if (maxParentCount < maxParents)
                {
                    maxParents = maxParentCount;
                }
            }
            else if (sf == "fnml") { }
            else if (sf == "bdeu") { }
            else
            {
                throw new ArgumentException("Invalid scoring function. Options are: 'BIC', 'fNML' or 'BDeu'.");
            }

            // TODO: constraints are only parsed here; the scoring code does not enforce them yet.
            Constraints constraints = null;
            if (constraintsFile.Length > 0)
            {
                constraints = Constraints.ParseConstraints(constraintsFile, network);
            }

            ScoringFunction scoringFunction = null;

            List<double> ilogi = LogLikelihoodCalculator.GetLogCache(recordFile.Size());
            LogLikelihoodCalculator llc = new LogLikelihoodCalculator(adTree, network, ilogi);

            // TODO: implement the regret cache required by fNML.

            if (sf == "bic")
            {
                scoringFunction = new BICScoringFunction(network, recordFile, llc, constraints);
            }
            else if (sf == "fnml")
            {
                // Fail here with a clear message instead of handing a null scoring function
                // to the ScoreCalculator and failing later inside the parallel loop.
                throw new NotSupportedException("The 'fNML' scoring function is not implemented.");
            }
            else if (sf == "bdeu")
            {
                scoringFunction = new BDeuScoringFunction(equivarentSampleSize, network, adTree, constraints);
            }

            ScoreCalculator scoreCalculator = new ScoreCalculator(scoringFunction, maxParents, network.Size(), runningTime, constraints);

            // Each worker i handles the variables whose index is congruent to i modulo threadCount.
            Parallel.For(0, threadCount, i =>
            {
                for (int variable = 0; variable < network.Size(); variable++)
                {
                    if (variable % threadCount != i)
                    {
                        continue;
                    }

                    Console.WriteLine("Thread: " + i + ", Variable: " + variable + ", Time: " + DateTime.Now);

                    DoubleMap sc = new DoubleMap();
                    scoreCalculator.CalculateScores(variable, sc);

                    int size = sc.Count;
                    Console.WriteLine("Thread: " + i + ", Variable: " + variable + ", Size before pruning: " + size + ", Time: " + DateTime.Now);

                    if (prune)
                    {
                        scoreCalculator.Prune(sc);
                        int prunedSize = sc.Count;
                        Console.WriteLine("Thread: " + i + ", Variable: " + variable + ", Size after pruning: " + prunedSize + ", Time: " + DateTime.Now);
                    }

                    string varFilename = outputFile + "." + variable;

                    // 'using' closes the file even when one of the writes throws.
                    using (StreamWriter writer = new StreamWriter(varFilename, false))
                    {
                        Variable current = network.Get(variable);
                        writer.Write("VAR " + current.Name + "\n");
                        writer.Write("META arity=" + current.GetCardinality() + "\n");
                        writer.Write("META values=");
                        for (int k = 0; k < current.GetCardinality(); k++)
                        {
                            writer.Write(current.GetValue(k) + " ");
                        }
                        writer.Write("\n");

                        // One line per stored parent set: the score followed by the parent names.
                        foreach (KeyValuePair<ulong, double> kvp in sc)
                        {
                            Varset parentSet = new Varset(kvp.Key);
                            double s = kvp.Value;

                            writer.Write(s + " ");

                            for (int p = 0; p < network.Size(); p++)
                            {
                                if (parentSet.Get(p))
                                {
                                    writer.Write(network.Get(p).Name + " ");
                                }
                            }
                            writer.Write("\n");
                        }
                        writer.Write("\n");
                    }
                }
            });

            // Concatenate all of the per-variable files into the final output file.
            // (The original carried a note here that this part may need revisiting.)
            using (StreamWriter outFile = new StreamWriter(outputFile, false))
            {
                // first, the header information
                string header = "META pss_version = 0.1\nMETA input_file=" + inputFile + "\nMETA num_records=" + recordFile.Size() + "\n";
                header += "META parent_limit=" + maxParents + "\nMETA score_type=" + sf + "\nMETA ess=" + equivarentSampleSize + "\n\n";

                outFile.Write(header);

                for (int variable = 0; variable < network.Size(); variable++)
                {
                    string varFilename = outputFile + "." + variable;
                    using (StreamReader reader = new StreamReader(varFilename))
                    {
                        outFile.Write(reader.ReadToEnd());
                    }
                }
            }
        }
        /// <summary>
        /// Recursively walks the contingency table rooted at <paramref name="ct"/>. At every leaf it
        /// adds the leaf's value to the count stored for the current parent instantiation and adds
        /// <c>ilogi[ct.Value]</c> to <paramref name="score"/>.
        /// </summary>
        /// <param name="ct">Contingency-table node being visited.</param>
        /// <param name="currentBase">Multiplier applied to the value index of the next parent variable.</param>
        /// <param name="index">Key of the parent instantiation accumulated so far.</param>
        /// <param name="paCounts">Per-instantiation counts, updated in place.</param>
        /// <param name="variable">The scored variable; it does not contribute to <paramref name="index"/>.</param>
        /// <param name="variables">Set of variables the table is built over.</param>
        /// <param name="previousVariable">Variable handled one level up (-1 at the root).</param>
        /// <param name="score">Running score, updated in place.</param>
        private void Calculate(ContingencyTableNode ct, ulong currentBase, ulong index, Dictionary<ulong, int> paCounts, int variable, Varset variables, int previousVariable, ref double score)
        {
            if (ct.IsLeaf())
            {
                // A missing entry leaves 'tally' at 0, i.e. the count starts from zero.
                int tally;
                paCounts.TryGetValue(index, out tally);
                tally += ct.Value;
                paCounts[index] = tally;

                score += ilogi[ct.Value];
                return;
            }

            // Find the next member of 'variables' after the one handled by the caller.
            int thisVariable = previousVariable + 1;
            while (thisVariable < network.Size() && !variables.Get(thisVariable))
            {
                thisVariable++;
            }

            int cardinality = network.GetCardinality(thisVariable);

            // Only parents (everything except the scored variable) take part in the index.
            bool isParent = thisVariable != variable;
            ulong nextBase = isParent ? currentBase * (ulong)cardinality : currentBase;

            for (int value = 0; value < cardinality; value++)
            {
                ContingencyTableNode child = ct.GetChild(value);
                if (child == null)
                {
                    continue;
                }

                ulong childIndex = isParent ? index + currentBase * (ulong)value : index;
                Calculate(child, nextBase, childIndex, paCounts, variable, variables, thisVariable, ref score);
            }
        }
예제 #4
0
        /// <summary>
        /// Scores the empty parent set and then, layer by layer (1 .. <c>maxParents</c> parents),
        /// every set produced by <c>Varset.NextPermutation</c> that does not contain
        /// <paramref name="variable"/>. Negative scores are stored in <paramref name="cache"/>;
        /// anything else is counted as pruned. Stops early once <c>outOfTime</c> is set.
        /// </summary>
        /// <param name="variable">Index of the variable whose parent sets are scored.</param>
        /// <param name="cache">Receives the scores, keyed by <c>Varset.ToULong()</c>.</param>
        private void CalculateScoresInternal(int variable, DoubleMap cache)
        {
            // Initial score: no parents at all.
            // Note carried over from the original comment: the C++ version uses 0 here.
            Varset noParents = new Varset(variableCount + 1);
            double score = scoringFunction.CalculateScore(variable, noParents, cache);
            if (score < 1)
            {
                cache[noParents.ToULong()] = score;
            }

            int prunedCount = 0;

            int layer = 1;
            while (layer <= maxParents && !outOfTime)
            {
                // DEBUG
                Console.WriteLine("layer: " + layer + ", prunedCount: " + prunedCount);

                // First set of the layer: the lowest 'layer' bits.
                // Note carried over from the original comment: the C++ version uses 0 here.
                Varset candidate = new Varset(variableCount + 1);
                for (int bit = 0; bit < layer; bit++)
                {
                    candidate.Set(bit, true);
                }

                // Upper bound for the enumeration: only bit 'variableCount' is set.
                // NOTE(review): built with variableCount rather than variableCount + 1 as above - confirm this is intended.
                Varset limit = new Varset(variableCount);
                limit.Set(variableCount, true);

                for (; candidate.LessThan(limit) && !outOfTime; candidate = candidate.NextPermutation())
                {
                    // A variable cannot be its own parent.
                    if (candidate.Get(variable))
                    {
                        continue;
                    }

                    score = scoringFunction.CalculateScore(variable, candidate, cache);

                    // Only negative scores are stored; everything else counts as pruned.
                    if (score < 0)
                    {
                        cache[candidate.ToULong()] = score;
                    }
                    else
                    {
                        prunedCount++;
                    }
                }

                if (!outOfTime)
                {
                    highestCompletedLayer = layer;
                }

                layer++;
            }
        }
예제 #5
0
        /// <summary>
        /// Scores <paramref name="parents"/> as a parent set of <paramref name="variable"/> using
        /// the log-gamma terms prepared by <c>Lg</c> and the counts gathered by <c>Calculate</c>
        /// from the AD-tree contingency table.
        /// </summary>
        /// <param name="variable">Index of the variable being scored.</param>
        /// <param name="parents">Candidate parent set. It is never modified by this method.</param>
        /// <param name="cache">Scores stored so far, keyed by <c>Varset.ToULong()</c>.</param>
        /// <returns>The score, or 0 when the pruning test fires.</returns>
        public override double CalculateScore(int variable, Varset parents, DoubleMap cache)
        {
            Scratch s = scratchSpace[variable];

            // TODO: reject parent sets that violate the constraints, and skip subsets that were
            // rejected because of constraints (not implemented yet). The loop that used to sit
            // here only cleared and restored each bit of 'parents' and has been removed.

            Lg(parents, ref s);
            Varset variables = new Varset(parents);
            variables.Set(variable, true);

            s.Score = 0;

            ContingencyTableNode ct = adTree.MakeContab(variables);

            Dictionary<ulong, int> paCounts = new Dictionary<ulong, int>();
            Calculate(ct, 1, 0, paCounts, variables, -1, ref s);

            // check constraints (Theorem 9 from de Campos and Ji '11)
            // only bother if the alpha bound is okay
            if (s.Aij <= 0.8349)
            {
                double bound = -1.0 * ct.LeafCount * s.Lri;

                // Check each subset with one parent removed. This works on a private copy:
                // the early return below would otherwise leave the caller's set with a bit
                // cleared, and the caller keeps using that set afterwards.
                Varset subset = new Varset(parents);
                for (int x = 0; x < network.Size(); x++)
                {
                    if (!subset.Get(x))
                    {
                        continue;
                    }

                    subset.Set(x, false);

                    // A subset that is not in the cache counts as score 0.
                    ulong subsetKey = subset.ToULong();
                    double tmp = cache.ContainsKey(subsetKey) ? cache[subsetKey] : 0;

                    // if the score is larger (better) than the bound, then we can prune
                    if (tmp > bound)
                    {
                        return 0;
                    }

                    subset.Set(x, true);
                }
            }

            // One term per parent instantiation that was recorded in paCounts.
            foreach (int instantiationCount in paCounts.Values)
            {
                s.Score += s.Lgij;
                s.Score -= ScoreCalculator.GammaLn(s.Aij + instantiationCount);
            }

            return s.Score;
        }
예제 #6
0
        /// <summary>
        /// Fills the hyper-parameter fields of <paramref name="s"/> for the given parent set:
        /// <c>Aij = ess / r</c> and <c>Aijk = ess / (r * cardinality of s.Variable)</c>, where
        /// <c>r</c> is the product of the parents' cardinalities, plus <c>GammaLn</c> of each.
        /// </summary>
        /// <param name="parents">Parent set whose cardinalities are multiplied together.</param>
        /// <param name="s">Scratch record; <c>Aij</c>, <c>Lgij</c>, <c>Aijk</c> and <c>Lgijk</c> are overwritten.</param>
        private void Lg(Varset parents, ref Scratch s)
        {
            // Accumulate in double: an int product of cardinalities wraps around silently for
            // large parent sets and can even become 0, which would then be used as a divisor.
            // NOTE(review): assumes 'ess' is a floating-point field - confirm.
            double r = 1;
            for (int pa = 0; pa < network.Size(); pa++)
            {
                if (parents.Get(pa))
                {
                    r *= network.GetCardinality(pa);
                }
            }

            s.Aij = ess / r;
            s.Lgij = ScoreCalculator.GammaLn(s.Aij);

            // Extend the product by the scored variable's own cardinality.
            r *= network.GetCardinality(s.Variable);
            s.Aijk = ess / r;
            s.Lgijk = ScoreCalculator.GammaLn(s.Aijk);
        }
예제 #7
0
        /// <summary>
        /// Recursively walks the contingency table rooted at <paramref name="ct"/>. At every leaf it
        /// adds the leaf's value to the count of the current parent instantiation and, when that
        /// count is positive, stores it and updates <c>s.Score</c> by
        /// <c>GammaLn(s.Aijk + ct.Value) - s.Lgijk</c>.
        /// </summary>
        /// <param name="ct">Contingency-table node being visited.</param>
        /// <param name="currentBase">Multiplier applied to the value index of the next parent variable.</param>
        /// <param name="index">Key of the parent instantiation accumulated so far.</param>
        /// <param name="paCounts">Per-instantiation counts, updated in place.</param>
        /// <param name="variables">Set of variables the table is built over.</param>
        /// <param name="previousVariable">Variable handled one level up (-1 at the root).</param>
        /// <param name="s">Scratch record; <c>Score</c> is updated, <c>Variable</c>, <c>Aijk</c> and <c>Lgijk</c> are read.</param>
        private void Calculate(ContingencyTableNode ct, ulong currentBase, ulong index, Dictionary<ulong, int> paCounts, Varset variables, int previousVariable, ref Scratch s)
        {
            // if this is a leaf in the AD-tree
            if (ct.IsLeaf())
            {
                // Single lookup; a missing entry leaves 'count' at 0.
                int count;
                paCounts.TryGetValue(index, out count);
                count += ct.Value;

                if (count > 0)
                {
                    paCounts[index] = count;

                    // update the score for this variable, parent instantiation
                    double temp = ScoreCalculator.GammaLn(s.Aijk + ct.Value);
                    s.Score -= s.Lgijk;
                    s.Score += temp;
                }
                return;
            }

            // which actual variable are we looking at
            int thisVariable = previousVariable + 1;
            for (; thisVariable < network.Size(); thisVariable++)
            {
                if (variables.Get(thisVariable))
                {
                    break;
                }
            }

            // update the base and index if this is part of the parent set
            ulong nextBase = currentBase;
            if (thisVariable != s.Variable)
            {
                nextBase *= (ulong)network.GetCardinality(thisVariable);
            }

            // recurse
            for (int k = 0; k < network.GetCardinality(thisVariable); k++)
            {
                ContingencyTableNode child = ct.GetChild(k);
                if (child != null)
                {
                    ulong newIndex = index;
                    if (thisVariable != s.Variable)
                    {
                        newIndex += currentBase * (ulong)k;
                    }
                    Calculate(child, nextBase, newIndex, paCounts, variables, thisVariable, ref s);
                }
            }
        }
예제 #8
0
 /// <summary>
 /// Returns (cardinality of <paramref name="variable"/> - 1) multiplied by the cardinality
 /// of every member of <paramref name="parents"/>.
 /// </summary>
 /// <param name="variable">Index of the variable being scored.</param>
 /// <param name="parents">Parent set whose cardinalities are multiplied in.</param>
 /// <returns>The product described above, as a double.</returns>
 private double T(int variable, Varset parents)
 {
     double product = network.GetCardinality(variable) - 1;

     int candidate = 0;
     while (candidate < network.Size())
     {
         if (parents.Get(candidate))
         {
             product *= network.GetCardinality(candidate);
         }
         candidate++;
     }

     return product;
 }