Ejemplo n.º 1
0
        /// <summary>
        /// Moves up to <paramref name="num"/> randomly chosen items out of this set and returns them
        /// as a new set. Items are drawn round-robin over label indices 0 .. GetNumOutput()-1
        /// (at most one per label per round), so the split keeps a similar number of samples per label.
        /// </summary>
        /// <param name="num">Number of items to move. Values &lt;= 0 yield an empty set.</param>
        /// <returns>
        /// The removed items. It holds fewer than <paramref name="num"/> items only when every label
        /// ran out of samples first.
        /// </returns>
        public ImageTrainDataSet SplitTestData(int num)
        {
            Random            rand      = new Random();
            int               numOutput = this.GetNumOutput();
            ImageTrainDataSet ret       = new ImageTrainDataSet();

            while (ret.Count < num)
            {
                bool pickedAny = false;

                // One round: at most one item per label; stop as soon as num is reached so the
                // result never exceeds num.
                for (int label = 0; label < numOutput && ret.Count < num; label++)
                {
                    // Reservoir sampling of size 1: the k-th matching item replaces the current
                    // choice with probability 1/k, giving a uniform pick among items with this label.
                    int chosen = -1;
                    int seen   = 0;
                    for (int index = 0; index < this.Count; index++)
                    {
                        if (this[index].labelIndex != label)
                        {
                            continue;
                        }
                        seen++;
                        if (rand.Next(seen) == 0)
                        {
                            chosen = index;
                        }
                    }

                    if (chosen < 0)
                    {
                        // No samples of this label remain; skip it instead of probing forever.
                        continue;
                    }

                    ret.Add(this[chosen]);
                    this.RemoveAt(chosen);
                    pickedAny = true;
                }

                if (!pickedAny)
                {
                    // Every label is exhausted (or there are no labels): nothing more can be moved.
                    break;
                }
            }
            return(ret);
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Returns a new set containing every item of this set plus clones of randomly chosen items
        /// from under-represented labels, so that each label reaches the largest per-label count
        /// reported by GetInfo. This set itself is not modified; the clones are appended after
        /// the original items.
        /// </summary>
        /// <returns>The balanced copy.</returns>
        public ImageTrainDataSet BalanceData()
        {
            Random            rand = new Random();
            ImageTrainDataSet ret  = new ImageTrainDataSet();

            ret.AddRange(this);
            TrainDataInfo[] infos = this.GetInfo(this.GetNumOutput());
            if (infos.Length == 0)
            {
                // No labels to balance; Max() would throw on an empty array.
                return(ret);
            }
            int maxCount = infos.Max(x => x.count);

            for (int i = 0; i < infos.Length; i++)
            {
                // Items carrying label i; oversampling draws uniformly from this pool.
                int label = i;
                var pool  = this.Where(x => x.labelIndex == label).ToArray();
                if (pool.Length == 0)
                {
                    // Nothing to clone for this label; searching for one would never terminate.
                    continue;
                }

                for (int j = infos[i].count; j < maxCount; j++)
                {
                    ret.Add(pool[rand.Next(pool.Length)].Clone());
                }
            }
            return(ret);
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Menu handler: loads the images under testDataPath (one sub-folder per configured class),
        /// logs the classifier's hash of that set, then logs how many of its items are misclassified.
        /// </summary>
        private void testWithTestSetToolStripMenuItem_Click(object sender, EventArgs e)
        {
            var evalSet = FeatureDetector.GetAllTrainImageData(testDataPath, configure.trainFolders);
            evalSet.flgCache = true;

            // Hash first, then the error count from Evaluate.
            logger.logStr(imgClassifier.ComputeHash(evalSet).ToHex());
            logger.logStr(String.Format("test {0}", imgClassifier.Evaluate(evalSet)));
        }
Ejemplo n.º 4
0
 /// <summary>
 /// Menu handler: loads all training images, holds out ~10% of them (drawn round-robin per label by
 /// SplitTestData) as testData, then balances and shuffles the remaining training data and logs the
 /// per-label info of both sets.
 /// The split is taken BEFORE balancing: BalanceData adds clones of existing samples, so splitting
 /// afterwards could put a clone into testData while its original stays in trainData (train/test leakage).
 /// </summary>
 private void loadTrainDataToolStripMenuItem_Click_1(object sender, EventArgs e)
 {
     trainData = FeatureDetector.GetAllTrainImageData(dataPath, configure.trainFolders);
     testData  = trainData.SplitTestData((int)(0.1 * trainData.Count()));
     trainData = trainData.BalanceData();
     // BalanceData appends its clones after the original items, so shuffle afterwards.
     trainData.Shuffle();
     TrainDataInfo[] infos     = trainData.GetInfo(configure.trainFolders.Count());
     TrainDataInfo[] testinfos = testData.GetInfo(configure.trainFolders.Count());
     logger.logStr(Utils.ToJsonString(infos, true));
     logger.logStr(Utils.ToJsonString(testinfos, true));
 }
Ejemplo n.º 5
0
        /// <summary>
        /// Runs every item of <paramref name="trainData"/> through the network and counts misclassifications.
        /// The predicted label is the index reported by <c>answer.Max(out index)</c> (the arg-max of the
        /// network outputs); it is compared with the item's label index.
        /// </summary>
        /// <param name="trainData">Labelled set to evaluate (training, held-out or external test data).</param>
        /// <returns>Number of items whose predicted label differs from the expected one.</returns>
        public int Evaluate(ImageTrainDataSet trainData)
        {
            double[][] features    = trainData.GetFeature(bow, mask);
            int[]      labelIndexs = trainData.GetLabelIndexs();
            int        sampleCount = trainData.Count();
            int        errorCount  = 0;

            for (int i = 0; i < sampleCount; i++)
            {
                double[] answer = network.Compute(features[i]);

                int actual;
                answer.Max(out actual);
                if (actual != labelIndexs[i])
                {
                    errorCount++;
                }
            }
            return(errorCount);
        }
Ejemplo n.º 6
0
        /// <summary>
        /// Builds a labelled data set from image files stored in one sub-folder per class.
        /// Every image returned by GetImagesFromDir for sub-folder <c>subFolders[i]</c> gets
        /// label <c>subFolders[i]</c> and label index <c>i</c>.
        /// </summary>
        /// <param name="folder">Root folder that contains the class sub-folders.</param>
        /// <param name="subFolders">Class sub-folder names; their order defines the label indices.</param>
        /// <param name="imageFilter">Forwarded to GetImagesFromDir (presumably the include pattern — confirm there).</param>
        /// <param name="imageFilterOut">Forwarded to GetImagesFromDir (presumably an exclude pattern — confirm there).</param>
        /// <returns>One ImageTrainData per image, grouped in sub-folder order.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="subFolders"/> is null.</exception>
        public static ImageTrainDataSet GetAllTrainImageData(String folder, String[] subFolders, String imageFilter = "*.png", String imageFilterOut = "mask.png")
        {
            if (subFolders == null)
            {
                throw new ArgumentNullException("subFolders");
            }

            ImageTrainDataSet ret = new ImageTrainDataSet();

            for (int i = 0; i < subFolders.Length; i++)
            {
                String label = subFolders[i];
                // Path.Combine instead of a hard-coded "\" so a trailing separator on folder is handled.
                String          folderName = System.IO.Path.Combine(folder, label);
                ImageFileData[] images     = GetImagesFromDir(folderName, imageFilter, imageFilterOut);
                foreach (ImageFileData img in images)
                {
                    ret.Add(new ImageTrainData()
                    {
                        label      = label,
                        labelIndex = i,
                        fileName   = img.fileName,
                    });
                }
            }
            return(ret);
        }
Ejemplo n.º 7
0
 /// <summary>
 /// Returns the hash produced by <c>trainData.MD5Feature</c> over the features extracted with
 /// this classifier's bow model and mask.
 /// </summary>
 /// <param name="trainData">Data set to hash.</param>
 /// <returns>The raw hash bytes.</returns>
 public byte[] ComputeHash(ImageTrainDataSet trainData)
 {
     byte[] digest = trainData.MD5Feature(bow, mask);
     return digest;
 }
Ejemplo n.º 8
0
        /// <summary>
        /// Menu handler: trains a sigmoid ActivationNetwork (one hidden layer of 20 neurons) on the
        /// BoW features of trainData with ParallelResilientBackpropagationLearning.
        /// Each restart uses a freshly randomized network; it stops after 3 non-improving checks
        /// or when the average epoch error drops below the target. Up to 100 restarts run, ending
        /// early once a network has zero errors on trainData, testData and the external test-folder set.
        /// The network with the lowest average training error across ALL restarts is saved to
        /// dataPath\train-{bow.NumberOfOutputs}.net; a zero-error network is saved there unconditionally.
        /// </summary>
        private void trainToolStripMenuItem_Click_1(object sender, EventArgs e)
        {
            const int    maxRestarts    = 100;     // fresh random networks to try
            const int    hiddenNeurons  = 20;
            const int    epochsPerCheck = 10;      // epochs averaged between evaluations
            const int    maxStalls      = 3;       // non-improving checks before giving up on a network
            const double targetError    = 0.00001;
            const double initialError   = 10000.0; // starting "previous/best" error value

            // trainData/testData are produced by the "load train data" action; without them
            // Evaluate would crash only after epochs of wasted training.
            if (trainData == null || testData == null)
            {
                logger.logStr("Load train data before training");
                return;
            }

            ImageTrainDataSet testDataSet = FeatureDetector.GetAllTrainImageData(testDataPath, configure.trainFolders);

            testDataSet.flgCache = true;

            var bow = Accord.IO.Serializer.Load <BagOfVisualWords>(dataPath + String.Format(@"\train-{0}.bow", bowSize));

            double[][] features  = trainData.GetFeature(bow, mask);
            int        numOutput = trainData.GetNumOutput();
            double[][] outputs   = trainData.GetOutputs(numOutput);
            var        function  = new SigmoidFunction();
            String     netPath   = dataPath + String.Format(@"\train-{0}.net", bow.NumberOfOutputs);

            // Lowest training error saved so far over ALL restarts. Resetting it per restart would let
            // a worse network from a later restart overwrite a better one already on disk.
            double bestSavedError = initialError;

            logger.logStr("Start Training");
            bool flgFound = false;
            int  count    = 0;

            while ((flgFound == false) && (count < maxRestarts))
            {
                count++;
                var network = new ActivationNetwork(function, bow.NumberOfOutputs, hiddenNeurons, numOutput);
                new NguyenWidrow(network).Randomize();
                var teacher = new ParallelResilientBackpropagationLearning(network);

                BowImageClassifier trainImgClassifier = new BowImageClassifier();
                trainImgClassifier.Init(bow, network, mask);

                double avgError   = initialError;
                double prevError  = avgError;
                double bestError  = avgError; // best of this restart only; reported in the "Done" log
                int    errorCount = 0;        // non-improving checks in this restart (never reset)
                while ((errorCount < maxStalls) && (avgError > targetError))
                {
                    double[] errors = new double[epochsPerCheck];
                    for (int i = 0; i < epochsPerCheck; i++)
                    {
                        errors[i] = teacher.RunEpoch(features, outputs);
                    }
                    avgError = errors.Average();
                    if (prevError > avgError)
                    {
                        int trainError   = trainImgClassifier.Evaluate(trainData);
                        int testError    = trainImgClassifier.Evaluate(testData);
                        int testSetError = trainImgClassifier.Evaluate(testDataSet);
                        logger.logStr(String.Format("{0} {1} {2} {3} {4} #{5}", avgError, prevError, trainError, testError, testSetError, errorCount));
                        prevError = avgError;
                        if (bestError > avgError)
                        {
                            bestError = avgError;
                        }
                        if (bestSavedError > avgError)
                        {
                            bestSavedError = avgError;
                            Accord.IO.Serializer.Save(network, netPath);
                        }
                        if (trainError + testError + testSetError == 0)
                        {
                            // Perfect on all three sets: keep this network and stop restarting.
                            flgFound = true;
                            Accord.IO.Serializer.Save(network, netPath);
                            break;
                        }
                    }
                    else
                    {
                        // No improvement: log it, reset prevError so the next check is compared
                        // against the starting value, and count the stall.
                        logger.logStr(String.Format("{0}", avgError));
                        prevError = initialError;
                        errorCount++;
                    }
                    Application.DoEvents();
                }
                logger.logStr("Done " + bestError + " " + count);
            }
        }