/// <summary>
/// Asks on the console whether to list variable importance and, if confirmed,
/// prints one line per variable with its importance expressed as a percentage.
/// </summary>
/// <remarks>
/// Each variable is labelled with the text of its <c>VarDesc</c> entry up to (not including)
/// the first '('; if an entry contains no '(' the whole entry is used as the label.
/// Any answer that does not start with 'y' or 'Y' (including an empty line or end of input)
/// skips the listing.
/// </remarks>
/// <param name="dtree">Decision tree whose variable importance is queried.</param>
private void PrintVariableImportance(CvDTree dtree)
{
    Mat varImportance0 = dtree.GetVarImportance();
    // Check the Mat before converting: calling ToCvMat() on a null reference would throw
    // before the error message below could ever be shown.
    CvMat varImportance = (varImportance0 == null) ? null : varImportance0.ToCvMat();
    if (varImportance == null)
    {
        Console.WriteLine("Error: Variable importance can not be retrieved");
        return;
    }

    Console.Write("Print variable importance information? (y/n) ");
    string input = Console.ReadLine();
    // ReadLine() yields null at end of input and "" for a bare Enter; both mean "no".
    if (string.IsNullOrEmpty(input) || (input[0] != 'y' && input[0] != 'Y'))
    {
        return;
    }

    int count = varImportance.Cols * varImportance.Rows;
    for (int i = 0; i < count; i++)
    {
        double val = varImportance.DataArrayDouble[i];

        // Keep only the part of the description before the first '(';
        // IndexOf returns -1 when there is none, which Substring would reject.
        string desc = VarDesc[i];
        int len = desc.IndexOf('(');
        string label = (len >= 0) ? desc.Substring(0, len) : desc;

        Console.Write("{0}", label);
        Console.WriteLine(": {0}%", val * 100.0);
    }
}
/// <summary>
/// Trains a decision tree on the given samples and prints its hit-rate on that same
/// training data to the console.
/// </summary>
/// <param name="data">Training samples, one sample per row (passed to Train as RowSample).</param>
/// <param name="missing">Missing-value mask; passed to Train and, row by row, to Predict.</param>
/// <param name="responses">Expected response per sample, read as single-precision values; the value of 'p' marks a poisonous sample.</param>
/// <param name="pWeight">Second entry of the priors array { 1, pWeight }; a bigger value gives more attention to the poisonous class.</param>
/// <returns>The trained decision tree.</returns>
private CvDTree MushroomCreateDTree(CvMat data, CvMat missing, CvMat responses, float pWeight)
{
    float[] priors = { 1, pWeight };

    // One type entry per column of data plus one more (presumably for the response - TODO confirm).
    CvMat varType = new CvMat(data.Cols + 1, 1, MatrixType.U8C1);
    try
    {
        Cv.Set(varType, CvScalar.ScalarAll(CvStatModel.CV_VAR_CATEGORICAL)); // all the variables are categorical

        CvDTree dtree = new CvDTree();
        CvDTreeParams p = new CvDTreeParams(
            8,      // max depth
            10,     // min sample count
            0,      // regression accuracy: N/A here
            true,   // compute surrogate split, as we have missing data
            15,     // max number of categories (use sub-optimal algorithm for larger numbers)
            10,     // the number of cross-validation folds
            true,   // use 1SE rule => smaller tree
            true,   // throw away the pruned tree branches
            priors  // the array of priors, the bigger p_weight, the more attention
                    // to the poisonous mushrooms
                    // (a mushroom will be judged to be poisonous with bigger chance)
        );

        dtree.Train(data, DTreeDataLayout.RowSample, responses, null, null, varType, missing, p);

        // compute hit-rate on the training database, demonstrates predict usage.
        int hr1 = 0, hr2 = 0, pTotal = 0;
        for (int i = 0; i < data.Rows; i++)
        {
            CvMat sample, mask;
            Cv.GetRow(data, out sample, i);
            Cv.GetRow(missing, out mask, i);

            double r = dtree.Predict(sample, mask).Value;
            bool d = Math.Abs(r - responses.DataArraySingle[i]) >= float.Epsilon;
            if (d)
            {
                // Wrong prediction: a result other than 'p' is counted as a missed poisonous
                // sample, a result of 'p' as a false alarm.
                if (r != 'p')
                {
                    hr1++;
                }
                else
                {
                    hr2++;
                }
            }
            pTotal += (responses.DataArraySingle[i] == (float)'p') ? 1 : 0;
        }

        Console.WriteLine("Results on the training database");
        Console.WriteLine("\tPoisonous mushrooms mis-predicted: {0} ({1}%)", hr1, (double)hr1 * 100 / pTotal);
        Console.WriteLine("\tFalse-alarms: {0} ({1}%)", hr2, (double)hr2 * 100 / (data.Rows - pTotal));

        return dtree;
    }
    finally
    {
        // Release the type matrix even when training or prediction throws.
        varType.Dispose();
    }
}
/// <summary>
/// Runs an interactive console session that classifies samples by walking the tree from the
/// root and asking the user for the value of each split variable along the way.
/// </summary>
/// <remarks>
/// Answering '?' for a variable moves on to the next split in the node's split list; if no
/// split can be answered the sample is reported as impossible to classify. The session repeats
/// until the start prompt is answered with anything not starting with 'y' or 'Y' (an empty line
/// or end of input counts as "no"). End of input in the middle of a classification ends the
/// session immediately.
/// </remarks>
/// <param name="dtree">Trained decision tree; the method returns immediately when it is null.</param>
private void InteractiveClassification(CvDTree dtree)
{
    if (dtree == null)
    {
        return;
    }

    CvDTreeNode root = dtree.GetRoot();
    CvDTreeTrainData data = dtree.GetData();
    string input;

    for (; ; )
    {
        CvDTreeNode node;

        Console.Write("Start/Proceed with interactive mushroom classification (y/n): ");
        input = Console.ReadLine();
        // ReadLine() yields null at end of input and "" for a bare Enter; both mean "no".
        if (string.IsNullOrEmpty(input) || (input[0] != 'y' && input[0] != 'Y'))
        {
            break;
        }
        Console.WriteLine("Enter 1-letter answers, '?' for missing/unknown value...");

        // custom version of predict
        node = root;
        for (; ; )
        {
            CvDTreeSplit split = node.Split;
            int dir = 0;

            // Stop at a node without children, without a split, or at/below the pruning index.
            if (node.Left == null || node.Tn <= dtree.GetPrunedTreeIdx() || node.Split == null)
            {
                break;
            }

            for (; split != null; )
            {
                int j;
                int vi = split.VarIdx;
                int count = data.CatCount.DataArrayInt32[vi];

                Console.Write("{0}: ", VarDesc[vi]);
                input = Console.ReadLine();
                if (input == null)
                {
                    // End of input: no further answers can arrive, so re-prompting would never end.
                    return;
                }
                if (input.Length == 0)
                {
                    // A bare Enter carries no value; ask for the same variable again.
                    Console.WriteLine("Error: unrecognized value");
                    continue;
                }

                if (input[0] == '?')
                {
                    split = split.Next;
                    continue;
                }

                // convert the input character to the normalized value of the variable
                unsafe
                {
                    int* map = data.CatMap.DataInt32 + data.CatOfs.DataInt32[vi];
                    for (j = 0; j < count; j++)
                    {
                        if (map[j] == input[0])
                        {
                            break;
                        }
                    }
                }

                if (j < count)
                {
                    // Bit j of the split's subset mask selects the branch: set => left (-1), clear => right (1).
                    dir = (split.Subset[j >> 5] & (1 << (j & 31))) != 0 ? -1 : 1;
                    if (split.Inversed)
                    {
                        dir = -dir;
                    }
                    break;
                }
                else
                {
                    Console.WriteLine("Error: unrecognized value");
                }
            }

            if (dir == 0)
            {
                Console.WriteLine("Impossible to classify the sample");
                node = null;
                break;
            }
            node = dir < 0 ? node.Left : node.Right;
        }

        if (node != null)
        {
            Console.Write("Prediction result: the mushroom is {0}\n", node.ClassIdx == 0 ? "EDIBLE" : "POISONOUS");
        }
        Console.Write("\n-----------------------------\n");
    }
}