示例#1
0
        /// <summary>
        ///
        /// </summary>
        /// <param name="trainData"></param>
        /// <param name="subsampleIdx"></param>
        /// <param name="forest"></param>
        /// <returns></returns>
#else
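        /// <summary>
        /// Trains the native object wrapped by this instance by calling the native
        /// ml_CvForestTree_train function.
        /// </summary>
        /// <param name="trainData">Training data passed to the native trainer; must not be null.</param>
        /// <param name="subsampleIdx">Subsample index matrix passed to the native trainer (presumably selects which samples are used); must not be null.</param>
        /// <param name="forest">Random trees model passed to the native trainer; must not be null.</param>
        /// <returns>true if the native call returned a non-zero value; otherwise false.</returns>
        /// <exception cref="ArgumentNullException">Any of the arguments is null.</exception>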
#endif
        public virtual bool Train(CvDTreeTrainData trainData, CvMat subsampleIdx, CvRTrees forest)
        {
            // Guard clauses in parameter order, so the first null argument is the one reported.
            if (trainData == null) { throw new ArgumentNullException("trainData"); }
            if (subsampleIdx == null) { throw new ArgumentNullException("subsampleIdx"); }
            if (forest == null) { throw new ArgumentNullException("forest"); }

            // The native result is mapped to bool: any non-zero value means true.
            bool succeeded = NativeMethods.ml_CvForestTree_train(
                ptr, trainData.CvPtr, subsampleIdx.CvPtr, forest.CvPtr) != 0;
            return succeeded;
        }
示例#2
0
        /// <summary>
        ///
        /// </summary>
        /// <param name="fs"></param>
        /// <param name="node"></param>
        /// <param name="forest"></param>
        /// <param name="data"></param>
#else
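        /// <summary>
        /// Reads the native object wrapped by this instance from a file storage node by
        /// calling the native ml_CvForestTree_read function.
        /// </summary>
        /// <param name="fs">File storage to read from; must not be null.</param>
        /// <param name="node">File node to read from; must not be null.</param>
        /// <param name="forest">Random trees model passed to the native reader; must not be null.</param>
        /// <param name="data">Training data passed to the native reader; must not be null.</param>
        /// <exception cref="ArgumentNullException">Any of the arguments is null.</exception>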
#endif
        public virtual void Read(CvFileStorage fs, CvFileNode node, CvRTrees forest, CvDTreeTrainData data)
        {
            // Validate every wrapper, in parameter order, before any CvPtr value is read.
            if (fs == null) { throw new ArgumentNullException("fs"); }
            if (node == null) { throw new ArgumentNullException("node"); }
            if (forest == null) { throw new ArgumentNullException("forest"); }
            if (data == null) { throw new ArgumentNullException("data"); }

            // Forward this instance's native pointer together with the argument pointers.
            NativeMethods.ml_CvForestTree_read(ptr, fs.CvPtr, node.CvPtr, forest.CvPtr, data.CvPtr);
        }
示例#3
0
        /// <summary>
        /// Random Trees (RTrees) sample. Reads a numeric classification data set. It then either
        /// trains a CvRTrees classifier on the first 80% of the rows or loads a previously saved
        /// classifier. It prints the recognition rates, the variable importance and the proximities
        /// of a few sample pairs, and can optionally save the classifier.
        /// All native objects created here are released on every exit path.
        /// </summary>
        /// <param name="dataFilename">
        /// Data file passed to ReadNumClassData. The second argument, 16, is presumably the
        /// feature count.
        /// </param>
        /// <param name="filenameToSave">If not null, the classifier is saved to this file.</param>
        /// <param name="filenameToLoad">
        /// If not null, the classifier is loaded from this file instead of being trained.
        /// Every row is then counted as test data.
        /// </param>
        private void BuildRtreesClassifier(string dataFilename, string filenameToSave, string filenameToLoad)
        {
            // Equivalent of C's FLT_EPSILON. C#'s float.Epsilon (~1.4e-45) is the smallest positive
            // subnormal Single, not the machine epsilon, so it is the wrong tolerance here.
            const double FltEpsilon = 1.192092896e-07;

            CvMat data = null;
            CvMat responses = null;
            CvMat varType = null;
            CvMat sampleIdx = null;
            CvRTrees forest = new CvRTrees();

            // try/finally releases every native object on all paths, including the early
            // returns below that previously leaked them.
            try
            {
                try
                {
                    ReadNumClassData(dataFilename, 16, out data, out responses);
                }
                catch
                {
                    Console.WriteLine("Could not read the database {0}", dataFilename);
                    return;
                }
                Console.WriteLine("The database {0} is loaded.", dataFilename);

                int nsamplesAll = data.Rows;
                // The first 80% of the rows are training samples; the rest are test samples.
                int ntrainSamples = (int)(nsamplesAll * 0.8);

                // Create or load Random Trees classifier
                if (filenameToLoad != null)
                {
                    // load classifier from the specified file
                    forest.Load(filenameToLoad);
                    // The loaded classifier was not trained here, so every row is test data.
                    ntrainSamples = 0;
                    if (forest.GetTreeCount() == 0)
                    {
                        Console.WriteLine("Could not read the classifier {0}", filenameToLoad);
                        return;
                    }
                    Console.WriteLine("The classifier {0} is loaded.", filenameToLoad);
                }
                else
                {
                    // create classifier by using <data> and <responses>
                    Console.Write("Training the classifier ...");

                    // 1. create type mask: every entry ORDERED except the extra last one
                    //    (index data.Cols, for the response), which is CATEGORICAL.
                    varType = new CvMat(data.Cols + 1, 1, MatrixType.U8C1);
                    varType.Set(CvScalar.ScalarAll(CvStatModel.CV_VAR_ORDERED));
                    varType.SetReal1D(data.Cols, CvStatModel.CV_VAR_CATEGORICAL);

                    // 2. create sample_idx: 1 for the training columns, 0 for the test columns
                    sampleIdx = new CvMat(1, nsamplesAll, MatrixType.U8C1);
                    {
                        CvMat mat;
                        Cv.GetCols(sampleIdx, out mat, 0, ntrainSamples);
                        mat.Set(CvScalar.RealScalar(1));

                        Cv.GetCols(sampleIdx, out mat, ntrainSamples, nsamplesAll);
                        mat.SetZero();
                    }

                    // 3. train classifier
                    forest.Train(
                        data, DTreeDataLayout.RowSample, responses, null, sampleIdx, varType, null,
                        new CvRTParams(10, 10, 0, false, 15, null, true, 4, new CvTermCriteria(100, 0.01f))
                    );
                    Console.WriteLine();
                }

                // compute prediction error on train and test data
                double trainHr = 0, testHr = 0;
                for (int i = 0; i < nsamplesAll; i++)
                {
                    CvMat sample;
                    Cv.GetRow(data, out sample, i);

                    double predicted = forest.Predict(sample);
                    double hit = Math.Abs(predicted - responses.DataArraySingle[i]) <= FltEpsilon ? 1 : 0;

                    if (i < ntrainSamples)
                    {
                        trainHr += hit;
                    }
                    else
                    {
                        testHr += hit;
                    }
                }

                // Avoid 0/0, which would print "NaN". There are no training rows when the
                // classifier was loaded from a file. An empty set is reported as 100%.
                int ntestSamples = nsamplesAll - ntrainSamples;
                trainHr = ntrainSamples > 0 ? trainHr / ntrainSamples : 1.0;
                testHr = ntestSamples > 0 ? testHr / ntestSamples : 1.0;
                Console.WriteLine("Recognition rate: train = {0:F1}%, test = {1:F1}%", trainHr * 100.0, testHr * 100.0);

                Console.WriteLine("Number of trees: {0}", forest.GetTreeCount());

                // Print variable importance
                Mat varImportance0 = forest.GetVarImportance();
                // Check for null BEFORE converting; calling ToCvMat() on a null Mat would throw.
                CvMat varImportance = (varImportance0 != null) ? varImportance0.ToCvMat() : null;
                if (varImportance != null)
                {
                    double rtImpSum = Cv.Sum(varImportance).Val0;
                    Console.WriteLine("var#\timportance (in %):");
                    for (int i = 0; i < varImportance.Cols; i++)
                    {
                        Console.WriteLine("{0}\t{1:F1}", i, 100.0f * varImportance.DataArraySingle[i] / rtImpSum);
                    }
                }

                // Print some proximities (row index pairs; no sentinel row needed)
                Console.WriteLine("Proximities between some samples corresponding to the letter 'T':");
                int[,] pairs = new int[,] { { 0, 103 }, { 0, 106 }, { 106, 103 } };
                for (int i = 0; i < pairs.GetLength(0); i++)
                {
                    CvMat sample1, sample2;
                    Cv.GetRow(data, out sample1, pairs[i, 0]);
                    Cv.GetRow(data, out sample2, pairs[i, 1]);
                    Console.WriteLine("proximity({0},{1}) = {2:F1}%", pairs[i, 0], pairs[i, 1], forest.GetProximity(sample1, sample2) * 100.0);
                }

                // Save Random Trees classifier to file if needed
                if (filenameToSave != null)
                {
                    forest.Save(filenameToSave);
                }

                // Wait for console input before the resources are released.
                Console.Read();
            }
            finally
            {
                if (sampleIdx != null)
                {
                    sampleIdx.Dispose();
                }
                if (varType != null)
                {
                    varType.Dispose();
                }
                if (data != null)
                {
                    data.Dispose();
                }
                if (responses != null)
                {
                    responses.Dispose();
                }
                forest.Dispose();
            }
        }
示例#4
0
        /// <summary>
        /// 
        /// </summary>
        /// <param name="fs"></param>
        /// <param name="node"></param>
        /// <param name="forest"></param>
        /// <param name="data"></param>
#else
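        /// <summary>
        /// Reads the native object wrapped by this instance from a file storage node by
        /// calling the native ml_CvForestTree_read function.
        /// </summary>
        /// <param name="fs">File storage to read from; must not be null.</param>
        /// <param name="node">File node to read from; must not be null.</param>
        /// <param name="forest">Random trees model passed to the native reader; must not be null.</param>
        /// <param name="data">Training data passed to the native reader; must not be null.</param>
        /// <exception cref="ArgumentNullException">Any of the arguments is null.</exception>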
#endif
        public virtual void Read(CvFileStorage fs, CvFileNode node, CvRTrees forest, CvDTreeTrainData data)
        {
            // Reject null wrappers first, in parameter order, before any CvPtr is read.
            if (fs == null)
            {
                throw new ArgumentNullException("fs");
            }
            if (node == null)
            {
                throw new ArgumentNullException("node");
            }
            if (forest == null)
            {
                throw new ArgumentNullException("forest");
            }
            if (data == null)
            {
                throw new ArgumentNullException("data");
            }

            // Hand this instance's pointer and each argument's pointer to the native reader.
            NativeMethods.ml_CvForestTree_read(
                ptr,
                fs.CvPtr,
                node.CvPtr,
                forest.CvPtr,
                data.CvPtr);
        }
示例#5
0
        /// <summary>
        /// 
        /// </summary>
        /// <param name="trainData"></param>
        /// <param name="subsampleIdx"></param>
        /// <param name="forest"></param>
        /// <returns></returns>
#else
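        /// <summary>
        /// Trains the native object wrapped by this instance by calling the native
        /// ml_CvForestTree_train function.
        /// </summary>
        /// <param name="trainData">Training data passed to the native trainer; must not be null.</param>
        /// <param name="subsampleIdx">Subsample index matrix passed to the native trainer (presumably selects which samples are used); must not be null.</param>
        /// <param name="forest">Random trees model passed to the native trainer; must not be null.</param>
        /// <returns>true if the native call returned a non-zero value; otherwise false.</returns>
        /// <exception cref="ArgumentNullException">Any of the arguments is null.</exception>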
#endif
        public virtual bool Train(CvDTreeTrainData trainData, CvMat subsampleIdx, CvRTrees forest)
        {
            // Validate in parameter order so the first null argument is the one reported.
            if (trainData == null)
            {
                throw new ArgumentNullException("trainData");
            }
            if (subsampleIdx == null)
            {
                throw new ArgumentNullException("subsampleIdx");
            }
            if (forest == null)
            {
                throw new ArgumentNullException("forest");
            }

            // Any non-zero native result is reported as true.
            bool trained = NativeMethods.ml_CvForestTree_train(ptr, trainData.CvPtr, subsampleIdx.CvPtr, forest.CvPtr) != 0;
            return trained;
        }