Example #1
0
 internal void PreCalcFeatureValue(SampleCollection posSamples, SampleCollection negSamples)
 {
     //             foreach (WeakClassifier weak in _weakClassifiers)
     //             {
     //                 weak.PreCalcFeatureValue(posSamples, negSamples);
     //             }
     // Pre-calculate the feature value of every weak classifier in parallel
     Parallel.ForEach(_weakClassifiers,
                      weak =>
     {
         weak.PreCalcFeatureValue(posSamples, negSamples);
     });
 }
Example #2
0
        private PredictResult EvaluateErrorRate(SampleCollection validateSamples, double minHitRate, double maxFalsePositiveRate)
        {
            const MyFloat  FLT_EPSILON = 1.192092896e-07F;       /* smallest such that 1.0+FLT_EPSILON != 1.0 */
            int            numPos = validateSamples.PosCount, numNeg = validateSamples.NegCount;
            int            count = validateSamples.Count;
            int            i, numFalse = 0, numPosTrue = 0;
            List <MyFloat> values = new List <MyFloat>(numPos);

            // Collect the feature values of the positive samples
            foreach (ISample sample in validateSamples)
            {
                if (sample.IsPositive)
                {
                    values.Add(this.PredictGetSum(sample));
                }
            }
            values.Sort();
            // Pick the threshold so that roughly minHitRate of the positive samples score at or above it
            int thresholdIdx = (int)((1.0 - minHitRate) * numPos);

            _threshold = values[thresholdIdx];
            numPosTrue = numPos - thresholdIdx;

            // Positives below the threshold index whose value ties with the threshold still count as hits
            for (i = thresholdIdx - 1; i >= 0; i--)
            {
                if (Math.Abs(values[i] - _threshold) < FLT_EPSILON)
                {
                    numPosTrue++;
                }
            }
            double hitRate = ((double)numPosTrue) / ((double)numPos);

            // Count the classification errors on the negative samples
            foreach (ISample sample in validateSamples)
            {
                if (!sample.IsPositive && this.Predict(sample))
                {
                    numFalse++;
                }
            }
            double falseAlarm = ((double)numFalse) / ((double)numNeg);

            PredictResult result = new PredictResult();

            result.Count             = count;
            result.PosCount          = numPos;
            result.NegCount          = numNeg;
            result.FalsePositiveRate = falseAlarm;
            result.HitRate           = hitRate;
            return(result);
        }
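A quick worked example of the threshold selection above (the numbers are illustrative, not taken from the source): with numPos = 1000 validation positives and minHitRate = 0.995, thresholdIdx = (int)((1.0 - 0.995) * 1000) = 5, so _threshold is the 6th-smallest positive score, numPosTrue starts at 1000 - 5 = 995, and hitRate is at least 0.995; the backwards loop then adds back any of the 5 lower-indexed positives whose score ties with _threshold within FLT_EPSILON.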
Example #3
0
        public void PreCalcFeatureValue(SampleCollection posSamples, SampleCollection negSamples)
        {
            int numPos = posSamples.Count, numNeg = negSamples.Count;
            int count = numNeg + numPos;

            _featureValues = new MyFloat[count];
            int     i, j;
            MyFloat value;

            // Positive samples occupy the first numPos slots of _featureValues
            for (i = 0; i < numPos; i++)
            {
                value             = _feature.GetValue(posSamples[i]);
                _featureValues[i] = value;
            }
            // Negative samples follow, at indices [numPos, count)
            j = numPos;
            for (i = 0; i < numNeg; i++)
            {
                value             = _feature.GetValue(negSamples[i]);
                _featureValues[j] = value;
                j++;
            }
        }
Example #4
0
        /// <summary>
        /// Trains all weak classifiers with the specified sample sets and sample weights
        /// </summary>
        /// <param name="posSamples"></param>
        /// <param name="negSamples"></param>
        /// <param name="sampleWeight"></param>
        internal void Train(SampleCollection posSamples, SampleCollection negSamples, MyFloat[] sampleWeight)
        {
            //             List<WeakClassifier.FeatureValueWithPosFlag> trainTmp = new List<WeakClassifier.FeatureValueWithPosFlag>(sampleWeight.Length);
            //             foreach (WeakClassifier weak in _weakClassifiers)
            //             {
            //                 weak.Train(posSamples, negSamples, sampleWeight, trainTmp);
            //             }

            // Pool of scratch buffers shared by the parallel partitions so they can be reused
            ConcurrentStack <List <WeakClassifier.FeatureValueWithPosFlag> > stack = new ConcurrentStack <List <WeakClassifier.FeatureValueWithPosFlag> >();

            Parallel.ForEach(_weakClassifiers,
                             // localInit: take a scratch buffer from the pool, or allocate a new one
                             () =>
            {
                List <WeakClassifier.FeatureValueWithPosFlag> trainTmp = null;
                if (stack.Count > 0)
                {
                    if (stack.TryPop(out trainTmp) == false)
                    {
                        trainTmp = null;
                    }
                }
                if (trainTmp == null)
                {
                    trainTmp = new List <WeakClassifier.FeatureValueWithPosFlag>(sampleWeight.Length);
                }
                return(trainTmp);
            },
                             // body: train one weak classifier using the thread-local scratch buffer
                             (weak, state, trainTmp) =>
            {
                weak.Train(posSamples, negSamples, sampleWeight, trainTmp);
                return(trainTmp);
            },
                             // localFinally: return the scratch buffer to the pool for reuse
                             (trainTmp) =>
            {
                stack.Push(trainTmp);
            }
                             );
        }
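The Train method above relies on the Parallel.ForEach overload that threads a thread-local value through localInit, body, and localFinally, with a ConcurrentStack acting as a buffer pool so scratch lists are reused rather than reallocated per item. Below is a minimal, self-contained sketch of that same pattern; List<double> and the squaring work stand in for FeatureValueWithPosFlag and WeakClassifier.Train, and every name here is illustrative only, not part of the original source.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading.Tasks;

class BufferPoolDemo
{
    static void Main()
    {
        int[] items = { 1, 2, 3, 4, 5, 6, 7, 8 };
        var pool = new ConcurrentStack<List<double>>();

        Parallel.ForEach(items,
            // localInit: reuse a pooled buffer when one is available, otherwise allocate
            () => pool.TryPop(out List<double> buf) ? buf : new List<double>(1024),
            // body: each partition owns its buffer for the duration, so no locking is needed
            (item, state, buf) =>
            {
                buf.Clear();
                buf.Add(item * item);    // stand-in for the real per-item work
                return buf;
            },
            // localFinally: hand the buffer back to the pool for later partitions
            buf => pool.Push(buf));

        Console.WriteLine("buffers pooled: " + pool.Count);
    }
}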
Example #5
0
        /// <summary>
        /// Trains the strong classifier for one stage of the cascade classifier
        /// </summary>
        /// <param name="posSamples"></param>
        /// <param name="negSamples"></param>
        /// <param name="validateSamples"></param>
        /// <param name="maxFalsePositiveRate"></param>
        /// <param name="minHitRate"></param>
        /// <returns>Statistics of the training run</returns>
        public PredictResult Train(SampleCollection posSamples,
                                   SampleCollection negSamples,
                                   SampleCollection validateSamples,
                                   double maxFalsePositiveRate,
                                   double minHitRate)
        {
            List <WeakClassifier> weakClassifiers = new List <WeakClassifier>(10);
            PredictResult         result          = new PredictResult();

            MyFloat[]             sampleWeight = InitWeight(posSamples.Count, negSamples.Count);
            WeakClassifierManager allWeak      = WeakClassifierManager.Instance;
            Stopwatch             watch        = new Stopwatch();

            watch.Start();
            allWeak.PreCalcFeatureValue(posSamples, negSamples);
            watch.Stop();
            if (DebugMsg.Debug)
            {
                string msg = string.Format("所有弱分类器特征值预计算完成,用时:{0}\r\n",
                                           watch.Elapsed.ToString());
                DebugMsg.AddMessage(msg, 0);
            }

            int trainTime = 0;

            do
            {
                if (++trainTime != 1)
                {
                    NormalizeWeight(sampleWeight);
                }

                if (DebugMsg.Debug)
                {
                    string msg = string.Format("开始训练第{0}个弱分类器\r\n",
                                               trainTime);
                    DebugMsg.AddMessage(msg, 0);
                }
                watch.Reset();
                watch.Start();


                allWeak.Train(posSamples, negSamples, sampleWeight);
                WeakClassifier newBestClassifier = AdaBoost(posSamples.Count, negSamples.Count, sampleWeight);
                //UpdateWeights(newBestClassifier, posSamples.Count, negSamples.Count, sampleWeight);
                weakClassifiers.Add(newBestClassifier);
                _classifiers = weakClassifiers.ToArray();

                result = EvaluateErrorRate(validateSamples, minHitRate, maxFalsePositiveRate);
                watch.Stop();

                if (DebugMsg.Debug)
                {
                    string msg = string.Format("训练完成,花费时间{0}\r\n检测率:\t{1:P5}\t误检率:\t{2:P5}\r\n",
                                               watch.Elapsed.ToString(),
                                               result.HitRate,
                                               result.FalsePositiveRate);
                    DebugMsg.AddMessage(msg, 1);
                }
            } while (result.FalsePositiveRate > maxFalsePositiveRate);

            allWeak.ReleaseTrainData();
            foreach (WeakClassifier weak in _classifiers)
            {
                weak.ReleaseTrainData();
            }
            return(result);
        }
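The loop above normalizes the sample weights each round (NormalizeWeight), retrains every weak classifier, and lets AdaBoost(...) pick the best one; the weight update itself is not shown. For reference, the textbook discrete AdaBoost re-weighting step looks like the self-contained sketch below; the actual AdaBoost()/UpdateWeights() code in this project may use a different variant, and all names and numbers here are illustrative only.

using System;
using System.Linq;

class AdaBoostWeightDemo
{
    static void Main()
    {
        double[] weights = { 0.25, 0.25, 0.25, 0.25 };   // uniform initial weights
        int[] labels     = { +1, +1, -1, -1 };            // ground truth
        int[] predicted  = { +1, -1, -1, -1 };            // one mistake on sample 1

        // Weighted error of the chosen weak classifier
        double err = weights.Where((w, i) => labels[i] != predicted[i]).Sum();
        double alpha = 0.5 * Math.Log((1.0 - err) / err);

        // Re-weight: misclassified samples gain weight, correctly classified ones lose weight
        for (int i = 0; i < weights.Length; i++)
        {
            weights[i] *= Math.Exp(-alpha * labels[i] * predicted[i]);
        }
        double sum = weights.Sum();
        for (int i = 0; i < weights.Length; i++)
        {
            weights[i] /= sum;    // analogous to NormalizeWeight() at the start of the next round
        }

        Console.WriteLine("err = {0}, alpha = {1:F4}", err, alpha);
        Console.WriteLine(string.Join(", ", weights.Select(w => w.ToString("F4"))));
    }
}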
Example #6
0
 private bool Init(SampleCollection posSamples,
                   SampleCollection negSamples)
 {
     throw new NotImplementedException();
 }