/// <summary>
/// Creates a BIC scoring function bound to the given network, log-likelihood
/// calculator and constraints.
/// </summary>
/// <param name="network">The network whose variables are scored.</param>
/// <param name="recordFile">The data set; only its record count is read here.</param>
/// <param name="llc">The log-likelihood calculator stored for later use.</param>
/// <param name="constraints">The constraints stored for later use (the visible caller may pass null).</param>
public BICScoringFunction(
    BayesianNetwork network,
    RecordFile recordFile,
    LogLikelihoodCalculator llc,
    Constraints constraints)
{
    // The base penalty is half the natural logarithm of the number of records.
    double logRecordCount = Math.Log(recordFile.Size());
    baseComplexityPenalty = logRecordCount / 2;

    this.constraints = constraints;
    this.llc = llc;
    this.network = network;
}
/// <summary>
/// Builds, for every variable and every value of that variable, a bit array over
/// the records in which bit <c>r</c> is set when record <c>r</c> holds that value.
/// </summary>
/// <param name="recordFile">The records to index.</param>
/// <returns>
/// A list indexed by variable, then by value index, of bit arrays whose length is
/// the number of records.
/// </returns>
public List<List<BitArray>> GetConsistentRecords(RecordFile recordFile)
{
    List<List<BitArray>> result = new List<List<BitArray>>();

    // Phase 1: allocate one all-false bit array per (variable, value) pair.
    for (int variable = 0; variable < Size(); variable++)
    {
        List<BitArray> perValue = new List<BitArray>();
        result.Add(perValue);

        int cardinality = GetCardinality(variable);
        for (int k = 0; k < cardinality; k++)
        {
            perValue.Add(new BitArray(recordFile.Size()));
        }
    }

    // Phase 2: for each record, flag the bit of the value it holds for each variable.
    for (int row = 0; row < recordFile.Size(); row++)
    {
        Record current = recordFile.Records[row];
        for (int variable = 0; variable < Size(); variable++)
        {
            string rawValue = current[variable];
            int valueIndex = Get(variable).ValueToIndex[rawValue];
            result[variable][valueIndex][row] = true;
        }
    }

    return result;
}
/// <summary>
/// Stores the network and caches the data derived from the record file:
/// the record count, a <c>Varset</c> sized to the network, and the per-value
/// record bit arrays produced by <c>network.GetConsistentRecords</c>.
/// </summary>
/// <param name="network">The network this instance works on.</param>
/// <param name="recordFile">The records to index.</param>
private void Initialize(BayesianNetwork network, RecordFile recordFile)
{
    this.network = network;

    // Per-variable, per-value bit arrays over the records.
    consistentRecords = network.GetConsistentRecords(recordFile);

    // Varset constructed with the network size (presumably the empty set — confirm in Varset).
    zero = new Varset(network.Size());

    recordCount = recordFile.Size();
}
/// <summary>
/// Creates an AD-tree for the given network and records.
/// </summary>
/// <param name="rMin">Value stored in the <c>rMin</c> field before the tree is created.</param>
/// <param name="network">The network the tree is built for.</param>
/// <param name="recordFile">The records the tree is built from.</param>
public ADTree(int rMin, BayesianNetwork network, RecordFile recordFile)
{
    this.rMin = rMin;

    // Cache the network and the record-derived data first, then build the tree.
    Initialize(network, recordFile);
    CreateTree();
}
/// <summary>
/// Reads a CSV data set, computes the local parent-set scores of every variable
/// and writes them to the output file.
/// </summary>
/// <remarks>
/// Scores of each variable are first written to "<c>outputFile.&lt;variable index&gt;</c>";
/// these per-variable files are then concatenated, after a META header, into
/// <c>outputFile</c>. The per-variable files are left on disk.
/// All options other than the two file names are currently hard-coded.
/// </remarks>
/// <param name="args">Exactly two entries: the input CSV path and the output path.</param>
/// <exception cref="ArgumentException">The scoring function name is not recognised.</exception>
/// <exception cref="NotSupportedException">The 'fNML' scoring function is selected (not implemented).</exception>
public void Execute(string[] args)
{
    // Check the arguments.
    if (args == null || args.Length != 2)
    {
        // Show the error message.
        Console.WriteLine("[エラー] 引数の数が不正です");
        // Exit with a non-zero status so callers can detect the failure.
        Environment.Exit(1);
    }

    // Options: provisionally assigned default values.
    int rMin = 5; // The minimum number of records in the AD-tree nodes.
    char delimiter = ',';
    bool hasHeader = true;
    string sf = "BDeu";
    int maxParents = 0;
    string constraintsFile = "";
    int runningTime = -1;
    int threadCount = 20;
    bool prune = true;
    string inputFile = args[0];
    string outputFile = args[1];
    double equivarentSampleSize = 1;

    // The output is machine-readable, so numbers are formatted culture-independently.
    System.Globalization.CultureInfo invariant = System.Globalization.CultureInfo.InvariantCulture;

    // Read the CSV file.
    RecordFile recordFile = new RecordFile();
    recordFile.ReadRecord(inputFile, hasHeader, delimiter);

    // Initialize the Bayesian network.
    BayesianNetwork network = new BayesianNetwork(recordFile);

    // Initialize the AD-tree.
    ADTree adTree = new ADTree(rMin, network, recordFile);

    // Configure the scoring function.
    // ToLowerInvariant: the culture-sensitive ToLower maps "BIC" to "bıc" under Turkish cultures.
    sf = sf.ToLowerInvariant();

    if (maxParents > network.Size() || maxParents < 1)
    {
        maxParents = network.Size() - 1;
    }

    if (sf == "bic")
    {
        // BIC additionally caps the parent count at log(2N / log N), truncated to an int.
        int maxParentCount = (int)Math.Log(2 * recordFile.Size() / Math.Log(recordFile.Size()));
        if (maxParentCount < maxParents)
        {
            maxParents = maxParentCount;
        }
    }
    else if (sf == "fnml")
    {
        // fNML has no implementation yet; fail fast instead of handing a null
        // scoring function to the score calculator.
        throw new NotSupportedException("The 'fNML' scoring function is not implemented yet.");
    }
    else if (sf != "bdeu")
    {
        throw new ArgumentException("Invalid scoring function. Options are: 'BIC', 'fNML' or 'BDeu'.");
    }

    // TODO: implement constraints.
    Constraints constraints = null;
    if (constraintsFile.Length > 0)
    {
        constraints = Constraints.ParseConstraints(constraintsFile, network);
    }

    List<double> ilogi = LogLikelihoodCalculator.GetLogCache(recordFile.Size());
    LogLikelihoodCalculator llc = new LogLikelihoodCalculator(adTree, network, ilogi);

    // TODO: implement the regret cache.
    //std::vector<std::vector<float>*>* regret = scoring::getRegretCache(recordFile.size(), network.getMaxCardinality());

    ScoringFunction scoringFunction;
    if (sf == "bic")
    {
        scoringFunction = new BICScoringFunction(network, recordFile, llc, constraints);
    }
    else
    {
        // Only "bdeu" can reach this point; every other name was rejected above.
        scoringFunction = new BDeuScoringFunction(equivarentSampleSize, network, adTree, constraints);
    }

    ScoreCalculator scoreCalculator = new ScoreCalculator(scoringFunction, maxParents, network.Size(), runningTime, constraints);

    // Worker i handles the variables i, i + threadCount, i + 2 * threadCount, ...
    Parallel.For(0, threadCount, i =>
    {
        for (int variable = i; variable < network.Size(); variable += threadCount)
        {
            Console.WriteLine("Thread: " + i + ", Variable: " + variable + ", Time: " + DateTime.Now);

            DoubleMap sc = new DoubleMap();
            scoreCalculator.CalculateScores(variable, sc);

            int size = sc.Count;
            Console.WriteLine("Thread: " + i + ", Variable: " + variable + ", Size before pruning: " + size + ", Time: " + DateTime.Now);

            if (prune)
            {
                scoreCalculator.Prune(sc);
                int prunedSize = sc.Count;
                Console.WriteLine("Thread: " + i + ", Variable: " + variable + ", Size after pruning: " + prunedSize + ", Time: " + DateTime.Now);
            }

            string varFilename = outputFile + "." + variable;
            Variable target = network.Get(variable);

            // 'using' closes the file even if writing throws.
            using (StreamWriter writer = new StreamWriter(varFilename, false))
            {
                writer.Write("VAR " + target.Name + "\n");
                writer.Write("META arity=" + target.GetCardinality() + "\n");

                writer.Write("META values=");
                for (int k = 0; k < target.GetCardinality(); k++)
                {
                    writer.Write(target.GetValue(k) + " ");
                }
                writer.Write("\n");

                // One line per parent set: the score followed by the parent names.
                foreach (KeyValuePair<ulong, double> kvp in sc)
                {
                    Varset parentSet = new Varset(kvp.Key);
                    double s = kvp.Value;

                    writer.Write(s.ToString(invariant) + " ");
                    for (int p = 0; p < network.Size(); p++)
                    {
                        if (parentSet.Get(p))
                        {
                            writer.Write(network.Get(p).Name + " ");
                        }
                    }
                    writer.Write("\n");
                }

                writer.Write("\n");
            }
        }
    });

    // Concatenate all of the files together.
    // NOTE: the original author flagged this part as possibly needing revision.
    using (StreamWriter outFile = new StreamWriter(outputFile, false))
    {
        // First, the header information.
        string header = "META pss_version = 0.1\nMETA input_file=" + inputFile + "\nMETA num_records=" + recordFile.Size() + "\n";
        header += "META parent_limit=" + maxParents + "\nMETA score_type=" + sf + "\nMETA ess=" + equivarentSampleSize.ToString(invariant) + "\n\n";
        outFile.Write(header);

        for (int variable = 0; variable < network.Size(); variable++)
        {
            string varFilename = outputFile + "." + variable;
            using (StreamReader reader = new StreamReader(varFilename))
            {
                outFile.Write(reader.ReadToEnd());
            }
        }
    }
}
/// <summary>
/// Passes this variable's entry of every record (the entry at <c>index</c>) to
/// <c>AddValue</c>, in record order.
/// </summary>
/// <param name="recordFile">The records to scan.</param>
public void AddValues(RecordFile recordFile)
{
    foreach (Record record in recordFile.Records)
    {
        AddValue(record[index]);
    }
}
/// <summary>
/// Creates one variable per header column of the record file, registers each
/// in <c>variables</c> and <c>nameToIndex</c>, then lets every variable collect
/// its values from the records.
/// </summary>
/// <param name="recordFile">The record file supplying the header and the records.</param>
private void Initialize(RecordFile recordFile)
{
    // Set the variable names.
    int columnCount = recordFile.Header.Count;
    for (int column = 0; column < columnCount; column++)
    {
        Variable created = new Variable(this, column);

        // Use the header text when the file has a header, otherwise a generated name.
        created.Name = recordFile.HasHeader
            ? recordFile.Header[column]
            : "Variable_" + column;

        nameToIndex[created.Name] = column;
        variables.Add(created);
    }

    // Set the values.
    foreach (Variable variable in variables)
    {
        variable.AddValues(recordFile);
    }
}
/// <summary>
/// Creates a network whose variables are initialized from the given record file.
/// </summary>
/// <param name="recordFile">The record file handed to <c>Initialize</c>.</param>
public BayesianNetwork(RecordFile recordFile)
{
    // All set-up work is delegated to Initialize.
    Initialize(recordFile);
}