Example #1
0
        /// <summary>
        /// Builds a linear discriminant analysis solution for the problem returned by
        /// <c>LoadIrisProblem()</c> and verifies its variable impacts against the values
        /// from <c>GetExpectedValuesForIrisLDAModel()</c>.
        /// </summary>
        public void LDAIrisVariableImpactTest()
        {
            IClassificationProblemData irisData    = LoadIrisProblem();
            IClassificationSolution    ldaSolution = LinearDiscriminantAnalysis.CreateLinearDiscriminantAnalysisSolution(irisData);

            // The result of this first calculation is discarded; CheckDefaultAsserts
            // calculates the impacts again before comparing them.
            ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(ldaSolution);

            CheckDefaultAsserts(ldaSolution, GetExpectedValuesForIrisLDAModel());
        }
Example #2
0
        /// <summary>
        /// Builds a nearest neighbour classification solution (second argument 3) for the
        /// problem returned by <c>LoadIrisProblem()</c> and verifies its variable impacts
        /// against the values from <c>GetExpectedValuesForIrisKNNModel()</c>.
        /// </summary>
        public void KNNIrisVariableImpactTest()
        {
            IClassificationProblemData irisData    = LoadIrisProblem();
            IClassificationSolution    knnSolution = NearestNeighbourClassification.CreateNearestNeighbourClassificationSolution(irisData, 3);

            // The result of this first calculation is discarded; CheckDefaultAsserts
            // calculates the impacts again before comparing them.
            ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(knnSolution);

            CheckDefaultAsserts(knnSolution, GetExpectedValuesForIrisKNNModel());
        }
Example #3
0
        /// <summary>
        /// Calculates variable impacts for a nearest neighbour solution built on the problem
        /// from <c>LoadIrisProblem()</c>, then replaces the solution's problem data with the
        /// one from <c>LoadMammographyProblem()</c> and calculates the impacts again.
        /// </summary>
        /// <remarks>
        /// The body contains no assertion. NOTE(review): presumably the test relies on an
        /// expected-exception attribute declared outside this view - confirm.
        /// </remarks>
        public void WrongDataSetVariableImpactClassificationTest()
        {
            IClassificationProblemData problemData = LoadIrisProblem();
            IClassificationSolution    solution    = NearestNeighbourClassification.CreateNearestNeighbourClassificationSolution(problemData, 3);

            // First calculation runs against the data the solution was created with.
            ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(solution);

            // Swap in a different problem and calculate again on the mismatching data.
            solution.ProblemData = LoadMammographyProblem();
            ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(solution);
        }
        /// <summary>
        /// Calculates variable impacts for a classification solution by delegating to the
        /// model-based <c>CalculateImpacts</c> overload with a clone of the solution's model.
        /// </summary>
        /// <param name="solution">The solution whose model and problem data are used.</param>
        /// <param name="replacementMethod">Forwarded unchanged to the model-based overload.</param>
        /// <param name="factorReplacementMethod">Forwarded unchanged to the model-based overload.</param>
        /// <param name="dataPartition">Selects the rows via <c>GetPartitionRows</c>.</param>
        /// <returns>The (name, value) tuples produced by the model-based overload.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="solution"/> is null.</exception>
        public static IEnumerable <Tuple <string, double> > CalculateImpacts(
            IClassificationSolution solution,
            ReplacementMethodEnum replacementMethod             = ReplacementMethodEnum.Shuffle,
            FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
            DataPartitionEnum dataPartition = DataPartitionEnum.Training)
        {
            // Fail with a descriptive exception instead of a NullReferenceException below.
            if (solution == null)
            {
                throw new ArgumentNullException("solution");
            }

            var problemData = solution.ProblemData;

            IEnumerable <int>    rows = GetPartitionRows(dataPartition, problemData);
            IEnumerable <double> estimatedClassValues = solution.GetEstimatedClassValues(rows);
            var model = (IClassificationModel)solution.Model.Clone(); //mkommend: clone of model is necessary, because the thresholds for IDiscriminantClassificationModels are updated

            return(CalculateImpacts(model, problemData, estimatedClassValues, rows, replacementMethod, factorReplacementMethod));
        }
        /// <summary>
        /// Removes the model of <paramref name="solution"/> from <c>Model</c> together with its
        /// training and test partition entries, then clears all evaluation caches.
        /// </summary>
        /// <param name="solution">The solution whose model is removed.</param>
        /// <exception cref="ArgumentException">
        /// The solution's model is not contained in <c>Model.Models</c>.
        /// </exception>
        private void RemoveClassificationSolution(IClassificationSolution solution)
        {
            if (!Model.Models.Contains(solution.Model))
            {
                throw new ArgumentException("The model of the given solution is not contained in Model.Models and cannot be removed.", "solution");
            }
            Model.Remove(solution.Model);
            trainingPartitions.Remove(solution.Model);
            testPartitions.Remove(solution.Model);

            // The set of models changed, so the caches are cleared.
            trainingEvaluationCache.Clear();
            testEvaluationCache.Clear();
            evaluationCache.Clear();
        }
        /// <summary>
        /// Adds the model of <paramref name="solution"/> to <c>Model</c>, records the training
        /// and test partitions of its problem data for that model, then clears all
        /// evaluation caches.
        /// </summary>
        /// <param name="solution">The solution whose model is added.</param>
        /// <exception cref="ArgumentException">
        /// The solution's model is already contained in <c>Model.Models</c>.
        /// </exception>
        private void AddClassificationSolution(IClassificationSolution solution)
        {
            if (Model.Models.Contains(solution.Model))
            {
                throw new ArgumentException("The model of the given solution is already contained in Model.Models.", "solution");
            }
            Model.Add(solution.Model);
            trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
            testPartitions[solution.Model]     = solution.ProblemData.TestPartition;

            // The set of models changed, so the caches are cleared.
            trainingEvaluationCache.Clear();
            testEvaluationCache.Clear();
            evaluationCache.Clear();
        }
Example #7
0
        /// <summary>
        /// Calculates variable impacts via the solution-based and the model-based overload of
        /// <c>CalculateImpacts</c>, asserts that both yield the same sequence, and asserts that
        /// the result matches <paramref name="expectedImpacts"/>.
        /// </summary>
        /// <param name="solution">The solution under test.</param>
        /// <param name="expectedImpacts">Expected impact per name (tuple <c>Item1</c>).</param>
        private void CheckDefaultAsserts(IClassificationSolution solution, Dictionary <string, double> expectedImpacts)
        {
            IClassificationProblemData problemData     = solution.ProblemData;
            IEnumerable <double>       estimatedValues = solution.GetEstimatedClassValues(problemData.TrainingIndices);

            // Materialize each result once: the sequences are consumed by several assertions
            // below and must not be enumerated repeatedly.
            var solutionImpacts = ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(solution).ToList();
            var modelImpacts    = ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(solution.Model, problemData, estimatedValues, problemData.TrainingIndices).ToList();

            //Both ways should return equal results
            Assert.IsTrue(solutionImpacts.SequenceEqual(modelImpacts));

            //Check if impacts are as expected (expected value first, actual value second)
            Assert.AreEqual(expectedImpacts.Count, modelImpacts.Count);
            Assert.IsTrue(modelImpacts.All(v => v.Item2.IsAlmost(expectedImpacts[v.Item1])));
        }
Example #8
0
        /// <summary>
        /// Times one <c>CalculateImpacts</c> call for a nearest neighbour solution built on a
        /// problem from <c>CreateDefaultProblem(1500, 77)</c> and writes the throughput in
        /// cells per millisecond to the test context.
        /// </summary>
        public void PerformanceVariableImpactClassificationTest()
        {
            int rows    = 1500;
            int columns = 77;
            IClassificationProblemData problemData = CreateDefaultProblem(rows, columns);
            IClassificationSolution    solution    = NearestNeighbourClassification.CreateNearestNeighbourClassificationSolution(problemData, 3);

            Stopwatch watch = Stopwatch.StartNew();

            // NOTE(review): the returned sequence is not enumerated here; if CalculateImpacts
            // were lazily evaluated the timing would not cover the calculation - confirm.
            ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(solution);

            watch.Stop();

            // Divide by the fractional elapsed time: the integral ElapsedMilliseconds can be 0
            // for a fast run (DivideByZeroException) and truncates the reported throughput.
            double elapsedMilliseconds = watch.Elapsed.TotalMilliseconds;

            TestContext.WriteLine("");
            TestContext.WriteLine("Calculated cells per millisecond: {0}.", (rows * columns) / elapsedMilliseconds);
        }
        /// <summary>
        /// Creates a random forest classification model for <c>Problem.ProblemData</c>, adds the
        /// four error measures returned by the model creation to <c>Results</c>, and, depending
        /// on <c>ModelCreation</c>, adds a classification solution built from the model or from
        /// a <c>RandomForestModelSurrogate</c>.
        /// </summary>
        /// <param name="cancellationToken">Not consulted by this method body.</param>
        protected override void Run(CancellationToken cancellationToken)
        {
            double rmsError, relClassificationError, outOfBagRmsError, outOfBagRelClassificationError;

            if (SetSeedRandomly)
            {
                Seed = Random.RandomSeedGenerator.GetSeed();
            }

            var model = CreateRandomForestClassificationModel(Problem.ProblemData, NumberOfTrees, R, M, Seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);

            // The descriptions refer to the classification solution; result names are unchanged.
            Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest classification solution on the training set.", new DoubleValue(rmsError)));
            Results.Add(new Result("Relative classification error", "Relative classification error of the random forest classification solution on the training set.", new PercentValue(relClassificationError)));
            Results.Add(new Result("Root mean square error (out-of-bag)", "The out-of-bag root of the mean of squared errors of the random forest classification solution.", new DoubleValue(outOfBagRmsError)));
            Results.Add(new Result("Relative classification error (out-of-bag)", "The out-of-bag relative classification error of the random forest classification solution.", new PercentValue(outOfBagRelClassificationError)));


            // Stays null for any ModelCreation value other than Model or SurrogateModel,
            // in which case no solution result is added.
            IClassificationSolution solution = null;

            if (ModelCreation == ModelCreation.Model)
            {
                solution = model.CreateClassificationSolution(Problem.ProblemData);
            }
            else if (ModelCreation == ModelCreation.SurrogateModel)
            {
                var problemData    = Problem.ProblemData;
                var surrogateModel = new RandomForestModelSurrogate(model, problemData.TargetVariable, problemData, Seed, NumberOfTrees, R, M, problemData.ClassValues.ToArray());

                solution = surrogateModel.CreateClassificationSolution(problemData);
            }

            if (solution != null)
            {
                Results.Add(new Result(RandomForestClassificationModelResultName, "The random forest classification solution.", solution));
            }
        }
    /// <summary>
    /// Removes the model of <paramref name="solution"/> from <c>Model</c> together with its
    /// training and test partition entries, then clears all evaluation caches.
    /// </summary>
    /// <param name="solution">The solution whose model is removed.</param>
    /// <exception cref="ArgumentException">
    /// The solution's model is not contained in <c>Model.Models</c>.
    /// </exception>
    private void RemoveClassificationSolution(IClassificationSolution solution) {
      if (!Model.Models.Contains(solution.Model)) throw new ArgumentException("The model of the given solution is not contained in Model.Models and cannot be removed.", "solution");
      Model.Remove(solution.Model);
      trainingPartitions.Remove(solution.Model);
      testPartitions.Remove(solution.Model);

      // The set of models changed, so the caches are cleared.
      trainingEvaluationCache.Clear();
      testEvaluationCache.Clear();
      evaluationCache.Clear();
    }
    /// <summary>
    /// Adds the model of <paramref name="solution"/> to <c>Model</c>, records the training
    /// and test partitions of its problem data for that model, then clears all
    /// evaluation caches.
    /// </summary>
    /// <param name="solution">The solution whose model is added.</param>
    /// <exception cref="ArgumentException">
    /// The solution's model is already contained in <c>Model.Models</c>.
    /// </exception>
    private void AddClassificationSolution(IClassificationSolution solution) {
      if (Model.Models.Contains(solution.Model)) throw new ArgumentException("The model of the given solution is already contained in Model.Models.", "solution");
      Model.Add(solution.Model);
      trainingPartitions[solution.Model] = solution.ProblemData.TrainingPartition;
      testPartitions[solution.Model] = solution.ProblemData.TestPartition;

      // The set of models changed, so the caches are cleared.
      trainingEvaluationCache.Clear();
      testEvaluationCache.Clear();
      evaluationCache.Clear();
    }
 //mkommend: named Calculate because CalculateImpacts would clash with the static method; better naming suggestions are welcome
 /// <summary>
 /// Calculates the variable impacts of <paramref name="solution"/> by forwarding to the
 /// static <c>CalculateImpacts</c> with this instance's <c>ReplacementMethod</c>,
 /// <c>FactorReplacementMethod</c> and <c>DataPartition</c> members.
 /// </summary>
 public IEnumerable <Tuple <string, double> > Calculate(IClassificationSolution solution)
 {
     IEnumerable <Tuple <string, double> > impacts = CalculateImpacts(
         solution,
         ReplacementMethod,
         FactorReplacementMethod,
         DataPartition);

     return impacts;
 }