Beispiel #1
0
        /// <summary>
        /// Asks the user on the console whether to print variable importance and, if confirmed,
        /// prints the importance of every input variable of the trained tree as a percentage.
        /// </summary>
        /// <param name="dtree">Trained decision tree; nothing happens when it is null.</param>
        private void PrintVariableImportance(CvDTree dtree)
        {
            if (dtree == null)
            {
                return;
            }

            Mat varImportance0 = dtree.GetVarImportance();
            // Check the Mat before converting it: calling ToCvMat() on a null result would throw
            // NullReferenceException before the error message below could ever be printed.
            CvMat varImportance = varImportance0 == null ? null : varImportance0.ToCvMat();

            if (varImportance == null)
            {
                Console.WriteLine("Error: Variable importance can not be retrieved");
                return;
            }

            Console.Write("Print variable importance information? (y/n) ");
            string input = Console.ReadLine();
            // ReadLine returns null at end of input and "" for an empty line; treat both as "no"
            // rather than throwing on input[0].
            if (string.IsNullOrEmpty(input) || (input[0] != 'y' && input[0] != 'Y'))
            {
                return;
            }

            int total = varImportance.Cols * varImportance.Rows;
            for (int i = 0; i < total; i++)
            {
                double val = varImportance.DataArrayDouble[i];

                // Print only the text before '(' of the description. When a description has no '('
                // print it whole; Substring(0, -1) would throw otherwise.
                string desc = VarDesc[i];
                int len = desc.IndexOf('(');
                Console.Write("{0}", len >= 0 ? desc.Substring(0, len) : desc);
                Console.WriteLine(": {0}%", val * 100.0);
            }
        }
Beispiel #2
0
        /// <summary>
        /// Interactively classifies a mushroom on the console. The method walks the decision tree
        /// from its root and, at each node, asks the user for the value of the split variable.
        /// Answering '?' (unknown) moves on to the node's next split. When a leaf is reached, the
        /// predicted class is printed. The session repeats until the user declines to continue.
        /// </summary>
        /// <param name="dtree">Trained decision tree; nothing happens when it is null.</param>
        private void InteractiveClassification(CvDTree dtree)
        {
            if (dtree == null)
            {
                return;
            }

            CvDTreeNode root = dtree.GetRoot();
            CvDTreeTrainData data = dtree.GetData();
            string input;

            for (; ; )
            {
                CvDTreeNode node;

                Console.Write("Start/Proceed with interactive mushroom classification (y/n): ");
                input = Console.ReadLine();
                // ReadLine yields null at end of input and "" for an empty line; both end the
                // session instead of throwing on input[0].
                if (string.IsNullOrEmpty(input) || (input[0] != 'y' && input[0] != 'Y'))
                {
                    break;
                }
                Console.WriteLine("Enter 1-letter answers, '?' for missing/unknown value...");

                // custom version of predict
                node = root;
                for (; ; )
                {
                    CvDTreeSplit split = node.Split;
                    // -1 = go to the left child, +1 = go to the right child, 0 = no usable answer
                    int dir = 0;

                    // Stop when the node has no child or split. Also stop when its Tn is at or
                    // below the pruned-tree index, which makes the node behave as a leaf.
                    if (node.Left == null || node.Tn <= dtree.GetPrunedTreeIdx() || node.Split == null)
                    {
                        break;
                    }

                    // Ask about the primary split first; '?' advances to split.Next
                    // (presumably the surrogate splits).
                    for (; split != null; )
                    {
                        int j;
                        int vi = split.VarIdx;
                        int count = data.CatCount.DataArrayInt32[vi];

                        Console.Write("{0}: ", VarDesc[vi]);
                        input = Console.ReadLine();

                        // End of input: give up on this sample. dir stays 0, so the sample is
                        // reported as unclassifiable, and the next outer prompt then ends the session.
                        if (input == null)
                        {
                            break;
                        }

                        // Empty line: there is no answer character, so ask the same question again.
                        if (input.Length == 0)
                        {
                            Console.WriteLine("Error: unrecognized value");
                            continue;
                        }

                        if (input[0] == '?')
                        {
                            split = split.Next;
                            continue;
                        }

                        // convert the input character to the normalized value of the variable
                        unsafe
                        {
                            int* map = data.CatMap.DataInt32 + data.CatOfs.DataInt32[vi];
                            for (j = 0; j < count; j++)
                            {
                                if (map[j] == input[0])
                                {
                                    break;
                                }
                            }
                        }

                        if (j < count)
                        {
                            // Subset is a bit mask over category indices. A set bit sends the
                            // sample left, unless the split is inversed.
                            dir = (split.Subset[j >> 5] & (1 << (j & 31))) != 0 ? -1 : 1;
                            if (split.Inversed)
                            {
                                dir = -dir;
                            }
                            break;
                        }
                        else
                        {
                            Console.WriteLine("Error: unrecognized value");
                        }
                    }

                    if (dir == 0)
                    {
                        Console.WriteLine("Impossible to classify the sample");
                        node = null;
                        break;
                    }
                    node = dir < 0 ? node.Left : node.Right;
                }

                if (node != null)
                {
                    Console.Write("Prediction result: the mushroom is {0}\n", node.ClassIdx == 0 ? "EDIBLE" : "POISONOUS");
                }
                Console.Write("\n-----------------------------\n");
            }
        }
Beispiel #3
0
        /// <summary>
        /// Trains a decision tree in which every variable is categorical, then prints the tree's
        /// mis-prediction and false-alarm rates measured on the same training data.
        /// </summary>
        /// <param name="data">Training samples, one sample per row.</param>
        /// <param name="missing">Missing-value mask, read row by row alongside <paramref name="data"/>.</param>
        /// <param name="responses">Per-sample labels, read as single-precision values; 'p' marks a poisonous mushroom.</param>
        /// <param name="pWeight">Second class prior (the first is 1). A bigger value makes the tree more
        /// likely to judge a mushroom poisonous.</param>
        /// <returns>The trained decision tree.</returns>
        private CvDTree MushroomCreateDTree(CvMat data, CvMat missing, CvMat responses, float pWeight)
        {
            // Equivalent of C's FLT_EPSILON, which the original OpenCV sample compares against.
            // .NET's float.Epsilon is something else: the smallest positive subnormal (~1.4e-45).
            const float FltEpsilon = 1.192092896e-07F;

            float[] priors = { 1, pWeight };

            CvMat varType = new CvMat(data.Cols + 1, 1, MatrixType.U8C1);
            CvDTree dtree = new CvDTree();
            try
            {
                Cv.Set(varType, CvScalar.ScalarAll(CvStatModel.CV_VAR_CATEGORICAL)); // all the variables are categorical

                CvDTreeParams p = new CvDTreeParams(8, // max depth
                                                10, // min sample count
                                                0, // regression accuracy: N/A here
                                                true, // compute surrogate split, as we have missing data
                                                15, // max number of categories (use sub-optimal algorithm for larger numbers)
                                                10, // the number of cross-validation folds
                                                true, // use 1SE rule => smaller tree
                                                true, // throw away the pruned tree branches
                                                priors // the array of priors, the bigger p_weight, the more attention
                    // to the poisonous mushrooms
                    // (a mushroom will be judged to be poisonous with bigger chance)
                );

                dtree.Train(data, DTreeDataLayout.RowSample, responses, null, null, varType, missing, p);
            }
            finally
            {
                // varType is only needed for training; release it even if training throws.
                varType.Dispose();
            }

            // compute hit-rate on the training database, demonstrates predict usage.
            int hr1 = 0, hr2 = 0, pTotal = 0;
            for (int i = 0; i < data.Rows; i++)
            {
                CvMat sample, mask;
                Cv.GetRow(data, out sample, i);
                Cv.GetRow(missing, out mask, i);
                double r = dtree.Predict(sample, mask).Value;
                float expected = responses.DataArraySingle[i];

                if (Math.Abs(r - expected) >= FltEpsilon)
                {
                    if (r != 'p')
                    {
                        hr1++; // wrong prediction that is not 'p': a poisonous mushroom was missed
                    }
                    else
                    {
                        hr2++; // wrong 'p' prediction: a false alarm
                    }
                }
                pTotal += (expected == (float)'p') ? 1 : 0;
            }

            // Guard the percentages so an empty class prints 0% instead of NaN/Infinity.
            int otherTotal = data.Rows - pTotal;
            double missRate = pTotal > 0 ? (double)hr1 * 100 / pTotal : 0.0;
            double falseAlarmRate = otherTotal > 0 ? (double)hr2 * 100 / otherTotal : 0.0;

            Console.WriteLine("Results on the training database");
            Console.WriteLine("\tPoisonous mushrooms mis-predicted: {0} ({1}%)", hr1, missRate);
            Console.WriteLine("\tFalse-alarms: {0} ({1}%)", hr2, falseAlarmRate);

            return dtree;
        }