public ClassificationEnsembleSolution(IEnumerable <IClassificationModel> models, IClassificationProblemData problemData, IEnumerable <IntRange> trainingPartitions, IEnumerable <IntRange> testPartitions)
            : base(new ClassificationEnsembleModel(Enumerable.Empty <IClassificationModel>()), new ClassificationEnsembleProblemData(problemData))
        {
            // Builds one sub-solution per model, each evaluated on its own training/test partition.
            // Throws ArgumentException when the three input sequences have different lengths.
            this.trainingPartitions      = new Dictionary <IClassificationModel, IntRange>();
            this.testPartitions          = new Dictionary <IClassificationModel, IntRange>();
            this.classificationSolutions = new ItemCollection <IClassificationSolution>();

            List <IClassificationSolution> solutions = new List <IClassificationSolution>();
            // Dispose the enumerators deterministically (IEnumerator<T> is IDisposable).
            // The non-short-circuiting '&'/'|' operators deliberately advance ALL three
            // enumerators so a length mismatch can be detected after the loop.
            using (var modelEnumerator = models.GetEnumerator())
            using (var trainingPartitionEnumerator = trainingPartitions.GetEnumerator())
            using (var testPartitionEnumerator = testPartitions.GetEnumerator())
            {
                while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext())
                {
                    // Each sub-solution gets a clone of the problem data restricted to its partition.
                    var p = (IClassificationProblemData)problemData.Clone();
                    p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start;
                    p.TrainingPartition.End   = trainingPartitionEnumerator.Current.End;
                    p.TestPartition.Start     = testPartitionEnumerator.Current.Start;
                    p.TestPartition.End       = testPartitionEnumerator.Current.End;

                    solutions.Add(modelEnumerator.Current.CreateClassificationSolution(p));
                }
                if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext())
                {
                    // At least one sequence still has elements left -> the lengths differ.
                    throw new ArgumentException("The number of models, training partitions, and test partitions must be equal.");
                }
            }

            trainingEvaluationCache = new Dictionary <int, double>(problemData.TrainingIndices.Count());
            testEvaluationCache     = new Dictionary <int, double>(problemData.TestIndices.Count());

            // Register event handlers BEFORE adding so the additions are observed.
            RegisterClassificationSolutionsEventHandler();
            classificationSolutions.AddRange(solutions);
        }
        // keep for compatibility with old API
        public static RandomForestClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
                                                                                                  out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError)
        {
            // Train the forest, then pair it with a private clone of the problem data so the
            // returned solution is decoupled from the caller's instance.
            return new RandomForestClassificationSolution(
                CreateRandomForestClassificationModel(problemData, nTrees, r, m, seed,
                                                      out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError),
                (IClassificationProblemData)problemData.Clone());
        }
        public static IClassificationSolution CreateNearestNeighbourClassificationSolution(IClassificationProblemData problemData, int k, double[] weights = null)
        {
            // Clone first so the trained model and the returned solution share the same copy.
            IClassificationProblemData clonedProblemData = (IClassificationProblemData)problemData.Clone();
            var knnModel = Train(clonedProblemData, k, weights);
            return new NearestNeighbourClassificationSolution(knnModel, clonedProblemData);
        }
Exemplo n.º 4
0
        public static IClassificationSolution CreateLinearDiscriminantAnalysisSolution(IClassificationProblemData problemData)
        {
            // Fits Fisher's linear discriminant on the training partition and returns the linear
            // model wrapped as a symbolic discriminant-function classification solution.
            var    dataset        = problemData.Dataset;
            string targetVariable = problemData.TargetVariable;
            IEnumerable <string> allowedInputVariables = problemData.AllowedInputVariables;
            IEnumerable <int>    rows = problemData.TrainingIndices;
            int nClasses            = problemData.ClassNames.Count();
            // Separate numeric inputs from factor (string) inputs; factors are encoded into
            // extra columns below.
            var doubleVariableNames = allowedInputVariables.Where(dataset.VariableHasType <double>).ToArray();
            var factorVariableNames = allowedInputVariables.Where(dataset.VariableHasType <string>).ToArray();

            // Numeric input columns plus the target variable as the last column.
            double[,] inputMatrix = dataset.ToArray(doubleVariableNames.Concat(new string[] { targetVariable }), rows);

            var factorVariables = dataset.GetFactorVariableValues(factorVariableNames, rows);
            var factorMatrix    = dataset.ToArray(factorVariables, rows);

            // Prepend encoded factor columns; the target remains the last column.
            inputMatrix = factorMatrix.HorzCat(inputMatrix);

            if (inputMatrix.Cast <double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
            {
                throw new NotSupportedException("Linear discriminant analysis does not support NaN or infinity values in the input dataset.");
            }

            // change class values into class index
            // (the last column must hold zero-based class indices for fisherlda)
            int           targetVariableColumn = inputMatrix.GetLength(1) - 1;
            List <double> classValues          = problemData.ClassValues.OrderBy(x => x).ToList();

            for (int row = 0; row < inputMatrix.GetLength(0); row++)
            {
                inputMatrix[row, targetVariableColumn] = classValues.IndexOf(inputMatrix[row, targetVariableColumn]);
            }
            int info;

            double[] w;
            // Feature count is all columns except the trailing class column.
            alglib.fisherlda(inputMatrix, inputMatrix.GetLength(0), inputMatrix.GetLength(1) - 1, nClasses, out info, out w);
            if (info < 1)
            {
                // NOTE(review): presumably info >= 1 signals success — confirm against the alglib docs.
                throw new ArgumentException("Error in calculation of linear discriminant analysis solution");
            }

            // The first nFactorCoeff weights belong to the encoded factor columns, the remaining
            // ones to the numeric variables; convert the linear model to an expression tree.
            var nFactorCoeff = factorMatrix.GetLength(1);
            var tree         = LinearModelToTreeConverter.CreateTree(factorVariables, w.Take(nFactorCoeff).ToArray(),
                                                                     doubleVariableNames, w.Skip(nFactorCoeff).Take(doubleVariableNames.Length).ToArray());

            var model = CreateDiscriminantFunctionModel(tree, new SymbolicDataAnalysisExpressionTreeLinearInterpreter(), problemData, rows);
            SymbolicDiscriminantFunctionClassificationSolution solution = new SymbolicDiscriminantFunctionClassificationSolution(model, (IClassificationProblemData)problemData.Clone());

            return(solution);
        }
Exemplo n.º 5
0
        public static IClassificationSolution CreateNeuralNetworkClassificationSolution(IClassificationProblemData problemData, int nLayers, int nHiddenNodes1, int nHiddenNodes2, double decay, int restarts,
                                                                                        out double rmsError, out double avgRelError, out double relClassError)
        {
            // Trains a multi-layer perceptron classifier with 0, 1, or 2 hidden layers on the
            // training partition and wraps it as a solution over a clone of the problem data.
            // nHiddenNodes1 is used only for nLayers >= 1, nHiddenNodes2 only for nLayers == 2.
            // The out parameters report errors measured on the training data.
            var    dataset        = problemData.Dataset;
            string targetVariable = problemData.TargetVariable;
            IEnumerable <string> allowedInputVariables = problemData.AllowedInputVariables;
            IEnumerable <int>    rows = problemData.TrainingIndices;

            // Input variables plus the target variable as the last column.
            double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows);
            if (inputMatrix.Cast <double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
            {
                throw new NotSupportedException("Neural network classification does not support NaN or infinity values in the input dataset.");
            }

            int nRows     = inputMatrix.GetLength(0);
            int nFeatures = inputMatrix.GetLength(1) - 1;

            double[] classValues = dataset.GetDoubleValues(targetVariable).Distinct().OrderBy(x => x).ToArray();
            int      nClasses    = classValues.Count();
            // map original class values to values [0..nClasses-1]
            // (the last matrix column must hold zero-based class indices for training)
            Dictionary <double, double> classIndices = new Dictionary <double, double>();

            for (int i = 0; i < nClasses; i++)
            {
                classIndices[classValues[i]] = i;
            }
            for (int row = 0; row < nRows; row++)
            {
                inputMatrix[row, nFeatures] = classIndices[inputMatrix[row, nFeatures]];
            }

            // Create a classifier network matching the requested number of hidden layers.
            alglib.multilayerperceptron multiLayerPerceptron = null;
            if (nLayers == 0)
            {
                alglib.mlpcreatec0(allowedInputVariables.Count(), nClasses, out multiLayerPerceptron);
            }
            else if (nLayers == 1)
            {
                alglib.mlpcreatec1(allowedInputVariables.Count(), nHiddenNodes1, nClasses, out multiLayerPerceptron);
            }
            else if (nLayers == 2)
            {
                alglib.mlpcreatec2(allowedInputVariables.Count(), nHiddenNodes1, nHiddenNodes2, nClasses, out multiLayerPerceptron);
            }
            else
            {
                throw new ArgumentException("Number of layers must be zero, one, or two.", "nLayers");
            }
            alglib.mlpreport rep;

            int info;

            // using mlptrainlm instead of mlptraines or mlptrainbfgs because only one parameter is necessary
            alglib.mlptrainlm(multiLayerPerceptron, inputMatrix, nRows, decay, restarts, out info, out rep);
            if (info != 2)
            {
                // NOTE(review): presumably info == 2 signals success for mlptrainlm — confirm against the alglib docs.
                throw new ArgumentException("Error in calculation of neural network classification solution");
            }

            // Errors are measured on the (class-index-mapped) training matrix.
            rmsError      = alglib.mlprmserror(multiLayerPerceptron, inputMatrix, nRows);
            avgRelError   = alglib.mlpavgrelerror(multiLayerPerceptron, inputMatrix, nRows);
            relClassError = alglib.mlpclserror(multiLayerPerceptron, inputMatrix, nRows) / (double)nRows;

            var problemDataClone = (IClassificationProblemData)problemData.Clone();

            return(new NeuralNetworkClassificationSolution(new NeuralNetworkModel(multiLayerPerceptron, targetVariable, allowedInputVariables, problemDataClone.ClassValues.ToArray()), problemDataClone));
        }
Exemplo n.º 6
0
        public void CustomModelVariableImpactNoInfluenceTest()
        {
            // Build a default problem and a custom expression tree in which x1 has no influence.
            IClassificationProblemData problemData = CreateDefaultProblem();
            ISymbolicExpressionTree tree = CreateCustomExpressionTreeNoInfluenceX1();

            var model = new SymbolicNearestNeighbourClassificationModel(problemData.TargetVariable, 3, tree, new SymbolicDataAnalysisExpressionTreeInterpreter());
            model.RecalculateModelParameters(problemData, problemData.TrainingIndices);

            // Wrap the model in a solution over a cloned problem and compare the computed
            // variable impacts against the expected reference values.
            IClassificationSolution solution = new ClassificationSolution(model, (IClassificationProblemData)problemData.Clone());
            Dictionary<string, double> expectedImpacts = GetExpectedValuesForCustomProblemNoInfluence();
            CheckDefaultAsserts(solution, expectedImpacts);
        }
Exemplo n.º 7
0
        public static IClassificationSolution CreateNearestNeighbourClassificationSolution(IClassificationProblemData problemData, int k)
        {
            // Work on a private clone so training cannot mutate the caller's problem data.
            IClassificationProblemData clone = (IClassificationProblemData)problemData.Clone();
            var knnModel = Train(clone, k);
            return new NearestNeighbourClassificationSolution(clone, knnModel);
        }
Exemplo n.º 8
0
        // BackwardsCompatibility3.4
        #region Backwards compatible code, remove with 3.5
        public static SupportVectorClassificationSolution CreateSupportVectorClassificationSolution(IClassificationProblemData problemData, IEnumerable <string> allowedInputVariables,
                                                                                                    int svmType, int kernelType, double cost, double nu, double gamma, int degree, out double trainingAccuracy, out double testAccuracy, out int nSv)
        {
            // Train the SVM, wrap it in a solution over a cloned problem, and surface the
            // training/test accuracies and support-vector count via the out parameters.
            ISupportVectorMachineModel model;
            Run(problemData, allowedInputVariables, svmType, kernelType, cost, nu, gamma, degree, out model, out nSv);

            var solution = new SupportVectorClassificationSolution((SupportVectorMachineModel)model, (IClassificationProblemData)problemData.Clone());
            trainingAccuracy = solution.TrainingAccuracy;
            testAccuracy = solution.TestAccuracy;
            return solution;
        }
Exemplo n.º 9
0
        public static IClassificationSolution CreateOneRSolution(IClassificationProblemData problemData, int minBucketSize = 6)
        {
            // Builds a OneR solution from the better of two candidate models:
            // model1 uses the best double-valued input variable, model2 the best factor variable.
            // Materialize the target values once — the lazy dataset enumerable would otherwise be
            // re-read from the dataset for each of the two comparisons below.
            var classValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).ToList();
            var model1      = FindBestDoubleVariableModel(problemData, minBucketSize);
            var model2      = FindBestFactorModel(problemData);

            if (model1 == null && model2 == null)
            {
                throw new InvalidProgramException("Could not create OneR solution");
            }
            if (model1 == null)
            {
                return new OneFactorClassificationSolution(model2, (IClassificationProblemData)problemData.Clone());
            }
            if (model2 == null)
            {
                return new OneRClassificationSolution(model1, (IClassificationProblemData)problemData.Clone());
            }

            // Both candidates exist: keep the one that classifies more training rows correctly.
            // Ties favor model2 (the factor model), as in the original comparison.
            var model1EstimatedValues = model1.GetEstimatedClassValues(problemData.Dataset, problemData.TrainingIndices);
            var model1NumCorrect      = classValues.Zip(model1EstimatedValues, (a, b) => a.IsAlmost(b)).Count(e => e);

            var model2EstimatedValues = model2.GetEstimatedClassValues(problemData.Dataset, problemData.TrainingIndices);
            var model2NumCorrect      = classValues.Zip(model2EstimatedValues, (a, b) => a.IsAlmost(b)).Count(e => e);

            if (model1NumCorrect > model2NumCorrect)
            {
                return new OneRClassificationSolution(model1, (IClassificationProblemData)problemData.Clone());
            }
            return new OneFactorClassificationSolution(model2, (IClassificationProblemData)problemData.Clone());
        }
    public ClassificationEnsembleSolution(IEnumerable<IClassificationModel> models, IClassificationProblemData problemData, IEnumerable<IntRange> trainingPartitions, IEnumerable<IntRange> testPartitions)
      : base(new ClassificationEnsembleModel(Enumerable.Empty<IClassificationModel>()), new ClassificationEnsembleProblemData(problemData)) {
      // Builds one sub-solution per model, each evaluated on its own training/test partition.
      // Throws ArgumentException when the three input sequences have different lengths.
      this.trainingPartitions = new Dictionary<IClassificationModel, IntRange>();
      this.testPartitions = new Dictionary<IClassificationModel, IntRange>();
      this.classificationSolutions = new ItemCollection<IClassificationSolution>();

      List<IClassificationSolution> solutions = new List<IClassificationSolution>();
      // Dispose the enumerators deterministically (IEnumerator<T> is IDisposable).
      // The non-short-circuiting '&'/'|' operators deliberately advance ALL three enumerators
      // so a length mismatch between the sequences can be detected after the loop.
      using (var modelEnumerator = models.GetEnumerator())
      using (var trainingPartitionEnumerator = trainingPartitions.GetEnumerator())
      using (var testPartitionEnumerator = testPartitions.GetEnumerator()) {
        while (modelEnumerator.MoveNext() & trainingPartitionEnumerator.MoveNext() & testPartitionEnumerator.MoveNext()) {
          // Each sub-solution gets a clone of the problem data restricted to its partition.
          var p = (IClassificationProblemData)problemData.Clone();
          p.TrainingPartition.Start = trainingPartitionEnumerator.Current.Start;
          p.TrainingPartition.End = trainingPartitionEnumerator.Current.End;
          p.TestPartition.Start = testPartitionEnumerator.Current.Start;
          p.TestPartition.End = testPartitionEnumerator.Current.End;

          solutions.Add(modelEnumerator.Current.CreateClassificationSolution(p));
        }
        if (modelEnumerator.MoveNext() | trainingPartitionEnumerator.MoveNext() | testPartitionEnumerator.MoveNext()) {
          // At least one sequence still has elements left -> the lengths differ.
          throw new ArgumentException("The number of models, training partitions, and test partitions must be equal.");
        }
      }

      trainingEvaluationCache = new Dictionary<int, double>(problemData.TrainingIndices.Count());
      testEvaluationCache = new Dictionary<int, double>(problemData.TestIndices.Count());

      // Register event handlers BEFORE adding so the additions are observed.
      RegisterClassificationSolutionsEventHandler();
      classificationSolutions.AddRange(solutions);
    }
Exemplo n.º 11
0
        protected override void Run()
        {
            // Trains a support vector classifier on the problem's allowed inputs and reports
            // accuracy metrics (and optionally the full solution) in the results collection.
            IClassificationProblemData problemData            = Problem.ProblemData;
            IEnumerable <string>       selectedInputVariables = problemData.AllowedInputVariables;
            int nSv;
            ISupportVectorMachineModel model;

            Run(problemData, selectedInputVariables, GetSvmType(SvmType.Value), GetKernelType(KernelType.Value), Cost.Value, Nu.Value, Gamma.Value, Degree.Value, out model, out nSv);

            if (CreateSolution)
            {
                var solution = new SupportVectorClassificationSolution((SupportVectorMachineModel)model, (IClassificationProblemData)problemData.Clone());
                Results.Add(new Result("Support vector classification solution", "The support vector classification solution.",
                                       solution));
            }

            {
                // calculate classification metrics
                var ds         = problemData.Dataset;
                var trainRows  = problemData.TrainingIndices;
                var testRows   = problemData.TestIndices;
                var yTrain     = ds.GetDoubleValues(problemData.TargetVariable, trainRows);
                var yTest      = ds.GetDoubleValues(problemData.TargetVariable, testRows);
                var yPredTrain = model.GetEstimatedClassValues(ds, trainRows);
                var yPredTest  = model.GetEstimatedClassValues(ds, testRows);

                OnlineCalculatorError error;
                var trainAccuracy = OnlineAccuracyCalculator.Calculate(yPredTrain, yTrain, out error);
                if (error != OnlineCalculatorError.None)
                {
                    trainAccuracy = double.MaxValue; // sentinel: accuracy could not be computed
                }
                var testAccuracy = OnlineAccuracyCalculator.Calculate(yPredTest, yTest, out error);
                if (error != OnlineCalculatorError.None)
                {
                    testAccuracy = double.MaxValue; // sentinel: accuracy could not be computed
                }

                // Fixed descriptions: they previously said "mean of squared errors of the SVR
                // solution" (copied from the regression analogue) although the values are accuracies.
                Results.Add(new Result("Accuracy (training)", "The accuracy of the support vector classification solution on the training partition.", new DoubleValue(trainAccuracy)));
                Results.Add(new Result("Accuracy (test)", "The accuracy of the support vector classification solution on the test partition.", new DoubleValue(testAccuracy)));

                Results.Add(new Result("Number of support vectors", "The number of support vectors of the support vector classification solution.",
                                       new IntValue(nSv)));
            }
        }
 public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData)
 {
     // Pair this model with a private clone of the problem data.
     var clonedProblemData = (IClassificationProblemData)problemData.Clone();
     return new RandomForestClassificationSolution(this, clonedProblemData);
 }
Exemplo n.º 13
0
        public static IClassificationSolution CreateLogitClassificationSolution(IClassificationProblemData problemData, out double rmsError, out double relClassError)
        {
            // Trains a multinomial logit classifier on the training partition and returns it as a
            // classification solution over a clone of the problem data.
            // rmsError/relClassError report errors measured on the training data.
            var    dataset        = problemData.Dataset;
            string targetVariable = problemData.TargetVariable;
            IEnumerable <string> allowedInputVariables = problemData.AllowedInputVariables;
            IEnumerable <int>    rows = problemData.TrainingIndices;

            // Input variables plus the target variable as the last column.
            double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows);
            if (inputMatrix.Cast <double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
            {
                throw new NotSupportedException("Multinomial logit classification does not support NaN or infinity values in the input dataset.");
            }

            alglib.logitmodel lm  = new alglib.logitmodel();
            alglib.mnlreport  rep = new alglib.mnlreport();
            int nRows             = inputMatrix.GetLength(0);
            int nFeatures         = inputMatrix.GetLength(1) - 1;

            double[] classValues = dataset.GetDoubleValues(targetVariable).Distinct().OrderBy(x => x).ToArray();
            int      nClasses    = classValues.Length; // array Length instead of LINQ Count()
            // map original class values to values [0..nClasses-1]
            // (the last matrix column must hold zero-based class indices for training)
            Dictionary <double, double> classIndices = new Dictionary <double, double>();

            for (int i = 0; i < nClasses; i++)
            {
                classIndices[classValues[i]] = i;
            }
            for (int row = 0; row < nRows; row++)
            {
                inputMatrix[row, nFeatures] = classIndices[inputMatrix[row, nFeatures]];
            }
            int info;

            alglib.mnltrainh(inputMatrix, nRows, nFeatures, nClasses, out info, out lm, out rep);
            if (info != 1)
            {
                // NOTE(review): presumably info == 1 signals success for mnltrainh — confirm against the alglib docs.
                throw new ArgumentException("Error in calculation of logit classification solution");
            }

            // Errors are measured on the (class-index-mapped) training matrix.
            rmsError      = alglib.mnlrmserror(lm, inputMatrix, nRows);
            relClassError = alglib.mnlrelclserror(lm, inputMatrix, nRows);

            MultinomialLogitClassificationSolution solution = new MultinomialLogitClassificationSolution((IClassificationProblemData)problemData.Clone(), new MultinomialLogitModel(lm, targetVariable, allowedInputVariables, classValues));

            return(solution);
        }
        public static IClassificationSolution CreateLinearDiscriminantAnalysisSolution(IClassificationProblemData problemData)
        {
            // Fits Fisher's linear discriminant on the training partition and manually builds a
            // symbolic expression tree (a weighted sum of the input variables) from the resulting
            // weight vector, returned as a discriminant-function classification solution.
            var    dataset        = problemData.Dataset;
            string targetVariable = problemData.TargetVariable;
            IEnumerable <string> allowedInputVariables = problemData.AllowedInputVariables;
            IEnumerable <int>    rows = problemData.TrainingIndices;
            int nClasses = problemData.ClassNames.Count();

            // Input variables plus the target variable as the last column.
            double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows);
            if (inputMatrix.Cast <double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
            {
                throw new NotSupportedException("Linear discriminant analysis does not support NaN or infinity values in the input dataset.");
            }

            // change class values into class index
            // (the last column must hold zero-based class indices for fisherlda)
            int           targetVariableColumn = inputMatrix.GetLength(1) - 1;
            List <double> classValues          = problemData.ClassValues.OrderBy(x => x).ToList();

            for (int row = 0; row < inputMatrix.GetLength(0); row++)
            {
                inputMatrix[row, targetVariableColumn] = classValues.IndexOf(inputMatrix[row, targetVariableColumn]);
            }
            int info;

            double[] w;
            // NOTE(review): the feature count is allowedInputVariables.Count() — this assumes
            // PrepareInputMatrix produces exactly one column per allowed input variable
            // (plus the target column); verify for factor/derived variables.
            alglib.fisherlda(inputMatrix, inputMatrix.GetLength(0), allowedInputVariables.Count(), nClasses, out info, out w);
            if (info < 1)
            {
                // NOTE(review): presumably info >= 1 signals success — confirm against the alglib docs.
                throw new ArgumentException("Error in calculation of linear discriminant analysis solution");
            }

            // Build tree skeleton: ProgramRoot -> Start -> Addition.
            ISymbolicExpressionTree     tree      = new SymbolicExpressionTree(new ProgramRootSymbol().CreateTreeNode());
            ISymbolicExpressionTreeNode startNode = new StartSymbol().CreateTreeNode();

            tree.Root.AddSubtree(startNode);
            ISymbolicExpressionTreeNode addition = new Addition().CreateTreeNode();

            startNode.AddSubtree(addition);

            int col = 0;

            // One weighted variable node per input variable; w[col] is the LDA weight.
            foreach (string column in allowedInputVariables)
            {
                VariableTreeNode vNode = (VariableTreeNode) new HeuristicLab.Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode();
                vNode.VariableName = column;
                vNode.Weight       = w[col];
                addition.AddSubtree(vNode);
                col++;
            }

            var model = LinearDiscriminantAnalysis.CreateDiscriminantFunctionModel(tree, new SymbolicDataAnalysisExpressionTreeInterpreter(), problemData, rows);
            SymbolicDiscriminantFunctionClassificationSolution solution = new SymbolicDiscriminantFunctionClassificationSolution(model, (IClassificationProblemData)problemData.Clone());

            return(solution);
        }
        public static IClassificationSolution CreateNeuralNetworkEnsembleClassificationSolution(IClassificationProblemData problemData, int ensembleSize, int nLayers, int nHiddenNodes1, int nHiddenNodes2, double decay, int restarts,
                                                                                                out double rmsError, out double avgRelError, out double relClassError)
        {
            // Trains an ensemble of ensembleSize multi-layer perceptron classifiers (0, 1, or 2
            // hidden layers each) on the training partition and wraps it as a solution over a
            // clone of the problem data. nHiddenNodes1 is used only for nLayers >= 1,
            // nHiddenNodes2 only for nLayers == 2. The out parameters report training errors.
            var    dataset        = problemData.Dataset;
            string targetVariable = problemData.TargetVariable;
            IEnumerable <string> allowedInputVariables = problemData.AllowedInputVariables;
            IEnumerable <int>    rows = problemData.TrainingIndices;

            // Input variables plus the target variable as the last column.
            double[,] inputMatrix = dataset.ToArray(allowedInputVariables.Concat(new string[] { targetVariable }), rows);
            if (inputMatrix.ContainsNanOrInfinity())
            {
                throw new NotSupportedException("Neural network ensemble classification does not support NaN or infinity values in the input dataset.");
            }

            int nRows     = inputMatrix.GetLength(0);
            int nFeatures = inputMatrix.GetLength(1) - 1;

            double[] classValues = dataset.GetDoubleValues(targetVariable).Distinct().OrderBy(x => x).ToArray();
            int      nClasses    = classValues.Count();
            // map original class values to values [0..nClasses-1]
            // (the last matrix column must hold zero-based class indices for training)
            Dictionary <double, double> classIndices = new Dictionary <double, double>();

            for (int i = 0; i < nClasses; i++)
            {
                classIndices[classValues[i]] = i;
            }
            for (int row = 0; row < nRows; row++)
            {
                inputMatrix[row, nFeatures] = classIndices[inputMatrix[row, nFeatures]];
            }

            // Create a classifier ensemble matching the requested number of hidden layers.
            alglib.mlpensemble mlpEnsemble = null;
            if (nLayers == 0)
            {
                alglib.mlpecreatec0(allowedInputVariables.Count(), nClasses, ensembleSize, out mlpEnsemble);
            }
            else if (nLayers == 1)
            {
                alglib.mlpecreatec1(allowedInputVariables.Count(), nHiddenNodes1, nClasses, ensembleSize, out mlpEnsemble);
            }
            else if (nLayers == 2)
            {
                alglib.mlpecreatec2(allowedInputVariables.Count(), nHiddenNodes1, nHiddenNodes2, nClasses, ensembleSize, out mlpEnsemble);
            }
            else
            {
                throw new ArgumentException("Number of layers must be zero, one, or two.", "nLayers");
            }
            alglib.mlpreport rep;

            int info;

            alglib.mlpetraines(mlpEnsemble, inputMatrix, nRows, decay, restarts, out info, out rep);
            if (info != 6)
            {
                // NOTE(review): presumably info == 6 signals success for mlpetraines — confirm against the alglib docs.
                throw new ArgumentException("Error in calculation of neural network ensemble classification solution");
            }

            // Errors are measured on the (class-index-mapped) training matrix.
            rmsError      = alglib.mlpermserror(mlpEnsemble, inputMatrix, nRows);
            avgRelError   = alglib.mlpeavgrelerror(mlpEnsemble, inputMatrix, nRows);
            relClassError = alglib.mlperelclserror(mlpEnsemble, inputMatrix, nRows);
            var problemDataClone = (IClassificationProblemData)problemData.Clone();

            return(new NeuralNetworkEnsembleClassificationSolution(new NeuralNetworkEnsembleModel(mlpEnsemble, targetVariable, allowedInputVariables, problemDataClone.ClassValues.ToArray()), problemDataClone));
        }
Exemplo n.º 16
0
        public static IClassificationSolution CreateOneRSolution(IClassificationProblemData problemData, int minBucketSize = 6)
        {
            // OneR: for every allowed input variable, discretize it into buckets of at least
            // minBucketSize samples, predict the majority class per bucket, and keep the single
            // variable that classifies the most training samples correctly.
            var          bestClassified         = 0;
            List <Split> bestSplits             = null;
            string       bestVariable           = string.Empty;
            double       bestMissingValuesClass = double.NaN;
            // Materialize the target values once — the lazy dataset enumerable would otherwise be
            // re-read from the dataset for every input variable.
            var          classValues            = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).ToList();

            foreach (var variable in problemData.AllowedInputVariables)
            {
                var inputValues = problemData.Dataset.GetDoubleValues(variable, problemData.TrainingIndices);
                // Materialize the sorted samples — the OrderBy sequence is enumerated three times
                // below and would be re-sorted on every enumeration otherwise.
                var samples     = inputValues.Zip(classValues, (i, v) => new Sample(i, v)).OrderBy(s => s.inputValue).ToList();

                // Majority class among samples with a missing (NaN) input value; default pair when none exist.
                var missingValuesDistribution = samples.Where(s => double.IsNaN(s.inputValue)).GroupBy(s => s.classValue).ToDictionary(s => s.Key, s => s.Count()).MaxItems(s => s.Value).FirstOrDefault();

                //calculate class distributions for all distinct inputValues
                List <Dictionary <double, int> > classDistributions = new List <Dictionary <double, int> >();
                List <double> thresholds = new List <double>();
                double        lastValue  = double.NaN;
                foreach (var sample in samples.Where(s => !double.IsNaN(s.inputValue)))
                {
                    if (sample.inputValue > lastValue || double.IsNaN(lastValue))
                    {
                        if (!double.IsNaN(lastValue))
                        {
                            // Threshold halfway between two consecutive distinct input values.
                            thresholds.Add((lastValue + sample.inputValue) / 2);
                        }
                        lastValue = sample.inputValue;
                        classDistributions.Add(new Dictionary <double, int>());
                        foreach (var classValue in problemData.ClassValues)
                        {
                            classDistributions[classDistributions.Count - 1][classValue] = 0;
                        }
                    }
                    classDistributions[classDistributions.Count - 1][sample.classValue]++;
                }
                thresholds.Add(double.PositiveInfinity);

                var distribution = classDistributions[0];
                var threshold    = thresholds[0];
                var splits       = new List <Split>();

                for (int i = 1; i < classDistributions.Count; i++)
                {
                    var samplesInSplit = distribution.Max(d => d.Value);
                    //join splits if there are too few samples in the split or the distributions has the same maximum class value as the current split
                    if (samplesInSplit < minBucketSize ||
                        classDistributions[i].MaxItems(d => d.Value).Select(d => d.Key).Contains(
                            distribution.MaxItems(d => d.Value).Select(d => d.Key).First()))
                    {
                        foreach (var classValue in classDistributions[i])
                        {
                            distribution[classValue.Key] += classValue.Value;
                        }
                        threshold = thresholds[i];
                    }
                    else
                    {
                        splits.Add(new Split(threshold, distribution.MaxItems(d => d.Value).Select(d => d.Key).First()));
                        distribution = classDistributions[i];
                        threshold    = thresholds[i];
                    }
                }
                // The final split is open-ended so every input value falls into some split.
                splits.Add(new Split(double.PositiveInfinity, distribution.MaxItems(d => d.Value).Select(d => d.Key).First()));

                // Count correctly classified training samples for this variable.
                int correctClassified = 0;
                int splitIndex        = 0;
                foreach (var sample in samples.Where(s => !double.IsNaN(s.inputValue)))
                {
                    while (sample.inputValue >= splits[splitIndex].thresholdValue)
                    {
                        splitIndex++;
                    }
                    correctClassified += sample.classValue == splits[splitIndex].classValue ? 1 : 0;
                }
                correctClassified += missingValuesDistribution.Value;

                if (correctClassified > bestClassified)
                {
                    bestClassified         = correctClassified;
                    bestSplits             = splits;
                    bestVariable           = variable;
                    bestMissingValuesClass = missingValuesDistribution.Value == 0 ? double.NaN : missingValuesDistribution.Key;
                }
            }

            if (bestSplits == null)
            {
                // No variable classified a single sample correctly (e.g. no allowed input
                // variables); previously this fell through to a NullReferenceException below.
                throw new InvalidProgramException("Could not create OneR solution");
            }

            //remove neighboring splits with the same class value
            for (int i = 0; i < bestSplits.Count - 1; i++)
            {
                if (bestSplits[i].classValue == bestSplits[i + 1].classValue)
                {
                    bestSplits.Remove(bestSplits[i]);
                    i--;
                }
            }

            var model    = new OneRClassificationModel(bestVariable, bestSplits.Select(s => s.thresholdValue).ToArray(), bestSplits.Select(s => s.classValue).ToArray(), bestMissingValuesClass);
            var solution = new OneRClassificationSolution(model, (IClassificationProblemData)problemData.Clone());

            return(solution);
        }