/// <summary>
/// Entry point. Fits an <see cref="AmsRandomForest"/> on the first 200,000 rows of the
/// training CSV, sweeps decision thresholds on the following 50,000 rows to find the one
/// with the highest AMS, prints a confusion matrix for that threshold and finally writes
/// the submission file.
/// </summary>
static void Main()
{
    // Set random seed
    Tools.SetupGenerator(12345);

    // Read the training data file
    Console.WriteLine("Reading training data\n");
    const string inputPath = @"..\..\..\Data\training.csv";
    const int trainingCount = 200000;
    const int heldOutCount = 50000;

    // Read everything that is needed in one pass so the file can be closed before the
    // (long-running) training step instead of staying open for the life of the process.
    var allRecords = LoadTrainingRecords(inputPath, trainingCount + heldOutCount);

    // Take 200,000 records for training data
    var trainingRecords = allRecords.Take(trainingCount).ToList();

    // Set up inputs array
    var inputs = trainingRecords.Select(x => new[]
    {
        x.PRI_jet_num,
        x.DER_mass_MMC,
        x.DER_mass_transverse_met_lep,
        x.DER_mass_vis,
        x.DER_pt_h,
        x.DER_deltaeta_jet_jet,
        x.DER_mass_jet_jet,
        x.DER_prodeta_jet_jet,
        x.DER_deltar_tau_lep,
        x.DER_pt_tot,
        x.DER_sum_pt,
        x.DER_pt_ratio_lep_tau,
        x.DER_met_phi_centrality,
        x.DER_lep_eta_centrality
    }).ToArray();
    var labels = trainingRecords.Select(x => x.Label == "s" ? 1 : 0).ToArray();
    var weights = trainingRecords.Select(x => x.Weight).ToArray();

    // Train a forest of 1000 trees
    // NOTE(review): the meaning of the second constructor argument (30) is not visible here.
    Console.WriteLine("Training forest...");
    var cls = new AmsRandomForest(1000, 30);
    cls.Fit(inputs, labels, weights);

    // Take the remaining (up to 50,000) records as held out test set
    var records = allRecords.Skip(trainingCount).ToList();
    inputs = records.Select(x => new[]
    {
        x.PRI_jet_num,
        x.DER_mass_MMC,
        x.DER_mass_transverse_met_lep,
        x.DER_mass_vis,
        x.DER_pt_h,
        x.DER_deltaeta_jet_jet,
        x.DER_mass_jet_jet,
        x.DER_prodeta_jet_jet,
        x.DER_deltar_tau_lep,
        x.DER_pt_tot,
        x.DER_sum_pt,
        x.DER_pt_ratio_lep_tau,
        x.DER_met_phi_centrality,
        x.DER_lep_eta_centrality
    }).ToArray();
    labels = records.Select(x => x.Label == "s" ? 1 : 0).ToArray();
    weights = records.Select(x => x.Weight).ToArray();

    // Make predictions at different thresholds to optimise AMS
    Console.WriteLine("Make predictions and optimise AMS:");
    var bestThreshold = 0.1;
    var bestAms = 0.0;

    // Step with an integer counter: repeatedly adding 0.1 to a double drifts
    // (0.30000000000000004, 0.7999999999999999, ...) and the tenth sum is
    // 0.9999999999999999, which still satisfies "t < 1.0" and would evaluate an
    // unintended extra threshold of ~1.0.
    for (var step = 1; step <= 9; step++)
    {
        var t = step / 10.0;
        var p = cls.Predict(inputs, t);
        var ams = Ams.CalculateAms(labels, p, weights);
        Console.WriteLine("AMS@{0:0.0} : {1:f6}", t, ams);

        if (ams > bestAms)
        {
            bestAms = ams;
            bestThreshold = t;
        }
    }

    // Output confusion matrix: first index is the predicted class, second the real
    // class (0 = "b", 1 = "s").
    var predictions = cls.Predict(inputs, bestThreshold);
    var confusionMatrix = new int[2, 2];
    for (var i = 0; i < predictions.Length; i++)
    {
        if (predictions[i] == 0 && records[i].Label == "b") confusionMatrix[0, 0]++;
        if (predictions[i] == 1 && records[i].Label == "b") confusionMatrix[1, 0]++;
        if (predictions[i] == 0 && records[i].Label == "s") confusionMatrix[0, 1]++;
        if (predictions[i] == 1 && records[i].Label == "s") confusionMatrix[1, 1]++;
    }

    Console.WriteLine();
    Console.WriteLine("Confusion Matrix:\n");
    Console.WriteLine("\tReal\n\tb\ts");
    Console.WriteLine("Pred b\t{0}\t{1}\nPred s\t{2}\t{3}",
        confusionMatrix[0, 0], confusionMatrix[0, 1],
        confusionMatrix[1, 0], confusionMatrix[1, 1]);

    // Write submission file
    WriteSubmission(cls, bestThreshold);
}

/// <summary>
/// Reads at most <paramref name="count"/> records from the start of the training CSV at
/// <paramref name="path"/> and closes the file before returning.
/// </summary>
/// <param name="path">Path of the training CSV file.</param>
/// <param name="count">Maximum number of records to read.</param>
/// <returns>The records read, in file order.</returns>
private static TrainingRecord[] LoadTrainingRecords(string path, int count)
{
    using (var reader = new StreamReader(path))
    {
        // Materialise the records while the reader is still open.
        return new CsvReader(reader).GetRecords<TrainingRecord>().Take(count).ToArray();
    }
}
/// <summary>
/// Scores every record of the test CSV with the trained forest and writes the
/// submission file (<c>EventId,RankOrder,Class</c>) next to the input data.
/// </summary>
/// <param name="cls">The fitted forest used to score and classify the test records.</param>
/// <param name="bestThreshold">Decision threshold passed to <c>Predict</c>.</param>
private static void WriteSubmission(AmsRandomForest cls, double bestThreshold)
{
    // Write output
    Console.WriteLine();
    Console.WriteLine("Writing Outputs");
    const string testPath = @"..\..\..\Data\test.csv";

    // Read all test records inside a using block so the file handle is released as soon
    // as the data is in memory (the original never disposed the reader).
    TestRecord[] records;
    using (var reader = new StreamReader(testPath))
    {
        records = new CsvReader(reader).GetRecords<TestRecord>().ToArray();
    }

    var inputs = records.Select(x => new[]
    {
        x.PRI_jet_num,
        x.DER_mass_MMC,
        x.DER_mass_transverse_met_lep,
        x.DER_mass_vis,
        x.DER_pt_h,
        x.DER_deltaeta_jet_jet,
        x.DER_mass_jet_jet,
        x.DER_prodeta_jet_jet,
        x.DER_deltar_tau_lep,
        x.DER_pt_tot,
        x.DER_sum_pt,
        x.DER_pt_ratio_lep_tau,
        x.DER_met_phi_centrality,
        x.DER_lep_eta_centrality
    }).ToArray();

    var nodeScores = cls.GetNodeScores(inputs);
    var predictions = cls.Predict(inputs, bestThreshold);

    // Pair each event id with its predicted class so both travel with the score
    // when the scores are sorted.
    var m = new int[nodeScores.Length][];
    for (var i = 0; i < nodeScores.Length; i++)
    {
        m[i] = new[] { records[i].EventId, predictions[i] };
    }

    // Sorts ascending by score, reordering m alongside: rank 1 is the lowest score
    // and the highest rank is the highest score.
    Array.Sort(nodeScores, m);

    var output = new string[inputs.Length + 1];
    output[0] = "EventId,RankOrder,Class";
    for (var i = 0; i < inputs.Length; i++)
    {
        output[i + 1] = string.Format("{0},{1},{2}", m[i][0], i + 1, m[i][1] == 1 ? "s" : "b");
    }

    File.WriteAllLines(@"..\..\..\Data\submission.csv", output);
}