/// <summary>
        /// Temporarily replaces one variable of <paramref name="modifiableDataset"/> with
        /// <paramref name="replacementValues"/>, evaluates the model on the modified data and
        /// puts <paramref name="originalValues"/> back before returning.
        /// </summary>
        /// <param name="model">Model that produces the class estimates; if it is an
        /// <c>IDiscriminantFunctionClassificationModel</c> its parameters are recalculated on the modified data first.</param>
        /// <param name="modifiableDataset">Dataset whose variable is swapped for the duration of the call.</param>
        /// <param name="variableName">Name of the variable that is replaced.</param>
        /// <param name="originalValues">Values written back to the variable when the method exits.</param>
        /// <param name="rows">Rows used for parameter recalculation and for estimation.</param>
        /// <param name="replacementValues">Values the variable holds while the model is evaluated.</param>
        /// <param name="targetValues">Target values passed to <c>CalculateQuality</c> together with the estimates.</param>
        /// <returns>The value returned by <c>CalculateQuality</c> for the estimates on the modified dataset.</returns>
        private static double CalculateQualityForReplacement(
            IClassificationModel model,
            ModifiableDataset modifiableDataset,
            string variableName,
            IList originalValues,
            IEnumerable <int> rows,
            IList replacementValues,
            IEnumerable <double> targetValues)
        {
            modifiableDataset.ReplaceVariable(variableName, replacementValues);

            try
            {
                var discModel = model as IDiscriminantFunctionClassificationModel;

                if (discModel != null)
                {
                    var problemData = new ClassificationProblemData(modifiableDataset, modifiableDataset.VariableNames, model.TargetVariable);
                    discModel.RecalculateModelParameters(problemData, rows);
                }

                // ToList is used on purpose: the estimates must be materialized while the replacement
                // values are still in place, otherwise lazy evaluation could yield wrong estimates.
                var estimates = model.GetEstimatedClassValues(modifiableDataset, rows).ToList();

                return(CalculateQuality(targetValues, estimates));
            }
            finally
            {
                // Restore the original column even if recalculation or estimation throws, so that
                // the dataset is not left holding the replacement values.
                modifiableDataset.ReplaceVariable(variableName, originalValues);
            }
        }
コード例 #2
0
        /// <summary>
        /// Checks whether <paramref name="problemData"/> can be used with <paramref name="model"/>.
        /// </summary>
        /// <param name="model">The model to check; must not be null.</param>
        /// <param name="problemData">The problem data to check; must not be null.</param>
        /// <param name="errorMessage">Receives the target-variable mismatch message (if any) followed by the
        /// message produced by <c>model.IsDatasetCompatible</c> (if that check fails); empty when compatible.</param>
        /// <returns><c>true</c> when no error message was produced.</returns>
        /// <exception cref="ArgumentNullException">Either argument is null.</exception>
        public static bool IsProblemDataCompatible(IClassificationModel model, IClassificationProblemData problemData, out string errorMessage)
        {
            if (model == null)
            {
                throw new ArgumentNullException("model", "The provided model is null.");
            }
            if (problemData == null)
            {
                throw new ArgumentNullException("problemData", "The provided problemData is null.");
            }

            // Start with the target-variable check; the message stays empty when both names agree.
            bool targetsDiffer = model.TargetVariable != problemData.TargetVariable;
            errorMessage = targetsDiffer
                ? string.Format("The target variable of the model {0} does not match the target variable of the problemData {1}.", model.TargetVariable, problemData.TargetVariable)
                : string.Empty;

            // Append whatever the dataset check reports.
            string datasetMessage;
            if (!model.IsDatasetCompatible(problemData.Dataset, out datasetMessage))
            {
                errorMessage += datasetMessage;
            }

            return(string.IsNullOrEmpty(errorMessage));
        }
        /// <summary>
        /// Produces the replacement column for <paramref name="variableName"/>, dispatching on whether
        /// the variable holds <c>double</c> or <c>string</c> values.
        /// </summary>
        /// <param name="originalValues">Receives a list copy of the variable's current values.</param>
        /// <returns>The replacement values computed by the type-specific helper.</returns>
        /// <exception cref="NotSupportedException">The variable is neither of type double nor string.</exception>
        private static IList GetReplacementValues(ModifiableDataset modifiableDataset,
                                                  string variableName,
                                                  IClassificationModel model,
                                                  IEnumerable <int> rows,
                                                  IEnumerable <double> targetValues,
                                                  out IList originalValues,
                                                  ReplacementMethodEnum replacementMethod             = ReplacementMethodEnum.Shuffle,
                                                  FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best)
        {
            // Numeric variable: handled by the double-specific helper.
            if (modifiableDataset.VariableHasType <double>(variableName))
            {
                List <double> numericValues = modifiableDataset.GetReadOnlyDoubleValues(variableName).ToList();
                originalValues = numericValues;
                return(GetReplacementValuesForDouble(modifiableDataset, rows, numericValues, replacementMethod));
            }

            // String (factor) variable: handled by the string-specific helper.
            if (modifiableDataset.VariableHasType <string>(variableName))
            {
                List <string> factorValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList();
                originalValues = factorValues;
                return(GetReplacementValuesForString(model, modifiableDataset, variableName, rows, factorValues, targetValues, factorReplacementMethod));
            }

            throw new NotSupportedException("Variable not supported");
        }
        /// <summary>
        /// Calculates the impact of every input variable (allowed inputs of the problem data plus the
        /// variables the model uses for prediction) on the quality of the model.
        /// </summary>
        /// <param name="estimatedClassValues">Estimates compared against the target values to obtain the baseline quality.</param>
        /// <param name="rows">Rows on which the target values are read and the impacts are calculated.</param>
        /// <returns>One (variable name, impact) tuple per input variable.</returns>
        /// <exception cref="InvalidOperationException">The model uses variables that are not part of the dataset.</exception>
        public static IEnumerable <Tuple <string, double> > CalculateImpacts(
            IClassificationModel model,
            IClassificationProblemData problemData,
            IEnumerable <double> estimatedClassValues,
            IEnumerable <int> rows,
            ReplacementMethodEnum replacementMethod             = ReplacementMethodEnum.Shuffle,
            FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best)
        {
            // Guard against a dataset that lacks variables the model needs (e.g. a different dataset was loaded).
            var unavailableInputs = model.VariablesUsedForPrediction.Except(problemData.Dataset.VariableNames);

            if (unavailableInputs.Any())
            {
                var unavailableList = string.Join(", ", unavailableInputs);
                throw new InvalidOperationException(string.Format("Can not calculate variable impacts, because the model uses inputs missing in the dataset ({0})", unavailableList));
            }

            // Baseline quality on the unmodified data.
            IEnumerable <double> targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
            var baselineQuality = CalculateQuality(targetValues, estimatedClassValues);

            var candidateVariables = new HashSet <string>(problemData.AllowedInputVariables.Union(model.VariablesUsedForPrediction));
            var workingDataset     = ((Dataset)(problemData.Dataset).Clone()).ToModifiable();

            // Impacts are computed eagerly, one variable after the other, on the shared working copy.
            var impactByVariable = candidateVariables.ToDictionary(
                name => name,
                name => CalculateImpact(name, model, problemData, workingDataset, rows, replacementMethod, factorReplacementMethod, targetValues, baselineQuality));

            return(impactByVariable.Select(entry => Tuple.Create(entry.Key, entry.Value)));
        }
        /// <summary>
        /// Builds a full replacement column for a string variable according to
        /// <paramref name="factorReplacementMethod"/>.
        /// </summary>
        /// <param name="originalValues">Current values of the variable, indexed by row.</param>
        /// <param name="rows">Rows that are considered when choosing or shuffling values.</param>
        /// <returns>
        /// The replacement column. For <c>Best</c> this stays <c>null</c> when the enumeration of distinct
        /// values yields nothing or no candidate quality exceeds negative infinity.
        /// </returns>
        /// <exception cref="ArgumentException">The replacement method is not one of the handled values.</exception>
        private static IList GetReplacementValuesForString(IClassificationModel model,
                                                           ModifiableDataset modifiableDataset,
                                                           string variableName,
                                                           IEnumerable <int> rows,
                                                           List <string> originalValues,
                                                           IEnumerable <double> targetValues,
                                                           FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Shuffle)
        {
            List <string> result = null;
            IRandom       rng    = new FastRandom(31415);

            switch (factorReplacementMethod)
            {
            case FactorReplacementMethodEnum.Best:
                // Evaluate a constant column for every distinct value and keep the one with the highest quality.
                var topQuality      = double.NegativeInfinity;
                var candidateValues = modifiableDataset.GetStringValues(variableName, rows).Distinct();
                foreach (var candidate in candidateValues)
                {
                    List <string> constantColumn = Enumerable.Repeat(candidate, modifiableDataset.Rows).ToList();
                    // The quality computed here could in principle be reused by the caller, but is
                    // discarded to keep the replacement methods uniform.
                    var candidateQuality = CalculateQualityForReplacement(model, modifiableDataset, variableName, originalValues, rows, constantColumn, targetValues);

                    if (candidateQuality > topQuality)
                    {
                        topQuality = candidateQuality;
                        result     = constantColumn;
                    }
                }
                break;

            case FactorReplacementMethodEnum.Mode:
                // Constant column holding the most frequent value among the selected rows.
                var valuesInRows = rows.Select(row => originalValues[row]);
                var largestGroup = valuesInRows.GroupBy(value => value)
                                   .OrderByDescending(bucket => bucket.Count())
                                   .First();
                result = Enumerable.Repeat(largestGroup.Key, modifiableDataset.Rows).ToList();
                break;

            case FactorReplacementMethodEnum.Shuffle:
                // Same empirical distribution as before, but the relation to the target is broken.
                // Start from a complete column of empty strings ...
                result = Enumerable.Repeat(string.Empty, modifiableDataset.Rows).ToList();
                // ... shuffle only the values of the selected rows ...
                var permuted = rows.Select(row => originalValues[row]).Shuffle(rng).ToList();
                // ... and write them back to the selected rows in order.
                var position = 0;
                foreach (var targetRow in rows)
                {
                    result[targetRow] = permuted[position];
                    position++;
                }
                break;

            default:
                throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", factorReplacementMethod));
            }

            return(result);
        }
コード例 #6
0
 /// <summary>
 /// Removes <paramref name="model"/> from the model collection and resets
 /// <c>TargetVariable</c> to the empty string once no models are left.
 /// </summary>
 /// <param name="model">The model to remove.</param>
 public void Remove(IClassificationModel model)
 {
     models.Remove(model);

     // Keep the target variable as long as at least one model remains.
     if (models.Any())
     {
         return;
     }

     TargetVariable = string.Empty;
 }
コード例 #7
0
 /// <summary>
 /// Adds <paramref name="model"/> to the model collection; if no target variable is set yet,
 /// the target variable of the added model is adopted.
 /// </summary>
 /// <param name="model">The model to add.</param>
 public void Add(IClassificationModel model)
 {
     bool targetVariableMissing = string.IsNullOrEmpty(TargetVariable);

     if (targetVariableMissing)
     {
         TargetVariable = model.TargetVariable;
     }

     models.Add(model);
 }
コード例 #8
0
 /// <summary>
 /// Initializes the solution with a model and problem data and registers the result entries
 /// (training/test accuracy, training/test normalized Gini coefficient, performance measures).
 /// </summary>
 /// <param name="model">Model passed on to the base constructor.</param>
 /// <param name="problemData">Problem data passed on to the base constructor.</param>
 protected ClassificationSolutionBase(IClassificationModel model, IClassificationProblemData problemData)
   : base(model, problemData) {
   Add(new Result(TrainingAccuracyResultName, "Accuracy of the model on the training partition (percentage of correctly classified instances).", new PercentValue()));
   Add(new Result(TestAccuracyResultName, "Accuracy of the model on the test partition (percentage of correctly classified instances).", new PercentValue()));
   Add(new Result(TrainingNormalizedGiniCoefficientResultName, "Normalized Gini coefficient of the model on the training partition.", new DoubleValue()));
   Add(new Result(TestNormalizedGiniCoefficientResultName, "Normalized Gini coefficient of the model on the test partition.", new DoubleValue()));
   // Regular (non-verbatim) literals are used on purpose: inside a verbatim @"..." literal "\n" is
   // not an escape sequence and would show up as a backslash followed by 'n', and a raw line break
   // in the source would drag the source indentation into the description.
   Add(new Result(ClassificationPerformanceMeasuresResultName,
                  "Classification performance measures.\n" +
                  "In a multiclass classification all misclassifications of the negative class will be treated as true negatives except on positive class estimations.",
                  new ClassificationPerformanceMeasuresResultCollection()));
 }
コード例 #9
0
 /// <summary>
 /// Initializes the solution with a model and problem data and registers the result entries
 /// (training/test accuracy, training/test normalized Gini coefficient, performance measures).
 /// </summary>
 /// <param name="model">Model passed on to the base constructor.</param>
 /// <param name="problemData">Problem data passed on to the base constructor.</param>
 protected ClassificationSolutionBase(IClassificationModel model, IClassificationProblemData problemData)
     : base(model, problemData)
 {
     Add(new Result(TrainingAccuracyResultName, "Accuracy of the model on the training partition (percentage of correctly classified instances).", new PercentValue()));
     Add(new Result(TestAccuracyResultName, "Accuracy of the model on the test partition (percentage of correctly classified instances).", new PercentValue()));
     Add(new Result(TrainingNormalizedGiniCoefficientResultName, "Normalized Gini coefficient of the model on the training partition.", new DoubleValue()));
     Add(new Result(TestNormalizedGiniCoefficientResultName, "Normalized Gini coefficient of the model on the test partition.", new DoubleValue()));
     // Regular (non-verbatim) literals are used on purpose: inside a verbatim @"..." literal "\n" is
     // not an escape sequence and would show up as a backslash followed by 'n', and a raw line break
     // in the source would drag the source indentation into the description.
     Add(new Result(ClassificationPerformanceMeasuresResultName,
                    "Classification performance measures.\n" +
                    "In a multiclass classification all misclassifications of the negative class will be treated as true negatives except on positive class estimations.",
                    new ClassificationPerformanceMeasuresResultCollection()));
 }
    /// <summary>
    /// Calculates the average classification penalty of the model's estimates on the given rows.
    /// </summary>
    /// <param name="model">Model that provides the estimated class values.</param>
    /// <param name="problemData">Supplies the dataset, the target variable and the penalty for each (actual, estimated) pair.</param>
    /// <param name="rows">Rows to evaluate.</param>
    /// <returns>
    /// The summed penalty divided by the number of rows, or <c>double.NaN</c> when the model
    /// yields no estimates.
    /// </returns>
    public static double Calculate(IClassificationModel model, IClassificationProblemData problemData, IEnumerable<int> rows) {
      // The enumerator is disposable; the using block makes sure it is released on every exit path.
      using (var estimations = model.GetEstimatedClassValues(problemData.Dataset, rows).GetEnumerator()) {
        if (!estimations.MoveNext()) return double.NaN;

        var penalty = 0.0;
        var count = 0;
        // Estimates are consumed in lock-step with the rows.
        foreach (var r in rows) {
          var actualClass = problemData.Dataset.GetDoubleValue(problemData.TargetVariable, r);
          penalty += problemData.GetClassificationPenalty(actualClass, estimations.Current);
          estimations.MoveNext();
          count++;
        }
        return penalty / count;
      }
    }
コード例 #11
0
        /// <summary>
        /// Calculates the impact of each variable in <paramref name="originalVariableOrdering"/> on the
        /// selected data partition, reporting progress after each variable.
        /// </summary>
        /// <param name="originalVariableOrdering">Variables to process; the result keeps this order.</param>
        /// <param name="progress">Receives the progress value and a status message for every variable.</param>
        /// <returns>
        /// One (variable name, impact) tuple per variable, with impact 0 for variables the model does not use,
        /// or <c>null</c> when cancellation was requested.
        /// </returns>
        /// <remarks>
        /// NOTE(review): the <paramref name="token"/> and <paramref name="estimatedClassValues"/> parameters are
        /// not read here; cancellation is checked via the <c>cancellationToken</c> member and the estimates are
        /// taken from <c>Content</c>. Confirm this is intended.
        /// </remarks>
        private List <Tuple <string, double> > CalculateVariableImpacts(List <string> originalVariableOrdering,
                                                                        IClassificationModel model,
                                                                        IClassificationProblemData problemData,
                                                                        IEnumerable <double> estimatedClassValues,
                                                                        ClassificationSolutionVariableImpactsCalculator.DataPartitionEnum dataPartition,
                                                                        ClassificationSolutionVariableImpactsCalculator.ReplacementMethodEnum replMethod,
                                                                        ClassificationSolutionVariableImpactsCalculator.FactorReplacementMethodEnum factorReplMethod,
                                                                        CancellationToken token,
                                                                        IProgress progress)
        {
            var result            = new List <Tuple <string, double> >();
            int total             = originalVariableOrdering.Count;
            int processed         = 0;
            var workingDataset    = ((Dataset)(problemData.Dataset).Clone()).ToModifiable();
            IEnumerable <int> rows = ClassificationSolutionVariableImpactsCalculator.GetPartitionRows(dataPartition, problemData);

            // Baseline quality of the unmodified solution on the chosen partition.
            IEnumerable <double> partitionTargets   = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
            IEnumerable <double> partitionEstimates = Content.GetEstimatedClassValues(rows);
            var baselineQuality = ClassificationSolutionVariableImpactsCalculator.CalculateQuality(partitionTargets, partitionEstimates);

            // The impact calculation works on a copy of the model.
            var modelCopy = (IClassificationModel)model.Clone();

            foreach (var variableName in originalVariableOrdering)
            {
                if (cancellationToken.Token.IsCancellationRequested)
                {
                    return(null);
                }

                processed++;
                progress.ProgressValue = (double)processed / total;
                progress.Message       = string.Format("Calculating impact for variable {0} ({1} of {2})", variableName, processed, total);

                // A variable the model does not use for prediction has zero impact.
                bool usedByModel = model.VariablesUsedForPrediction.Contains(variableName);
                double impact    = usedByModel
                    ? ClassificationSolutionVariableImpactsCalculator.CalculateImpact(variableName, modelCopy, problemData, workingDataset, rows, replMethod, factorReplMethod, partitionTargets, baselineQuality)
                    : 0.0;

                result.Add(Tuple.Create(variableName, impact));
            }

            return(result);
        }
        /// <summary>
        /// Calculates the average classification penalty of the model's estimates on the given rows.
        /// </summary>
        /// <param name="model">Model that provides the estimated class values.</param>
        /// <param name="problemData">Supplies the dataset, the target variable and the penalty for each (actual, estimated) pair.</param>
        /// <param name="rows">Rows to evaluate.</param>
        /// <returns>
        /// The summed penalty divided by the number of rows, or <c>double.NaN</c> when the model
        /// yields no estimates.
        /// </returns>
        public static double Calculate(IClassificationModel model, IClassificationProblemData problemData, IEnumerable <int> rows)
        {
            // The enumerator is disposable; the using block makes sure it is released on every exit path.
            using (var estimations = model.GetEstimatedClassValues(problemData.Dataset, rows).GetEnumerator())
            {
                if (!estimations.MoveNext())
                {
                    return(double.NaN);
                }

                var penalty = 0.0;
                var count   = 0;

                // Estimates are consumed in lock-step with the rows.
                foreach (var r in rows)
                {
                    var actualClass = problemData.Dataset.GetDoubleValue(problemData.TargetVariable, r);
                    penalty += problemData.GetClassificationPenalty(actualClass, estimations.Current);
                    estimations.MoveNext();
                    count++;
                }
                return(penalty / count);
            }
        }
        /// <summary>
        /// Calculates the impact of a single variable as the difference between the baseline quality
        /// and the quality obtained after replacing the variable's values.
        /// </summary>
        /// <param name="variableName">Variable whose impact is calculated.</param>
        /// <param name="model">Model to evaluate.</param>
        /// <param name="problemData">Supplies the dataset used for the existence check and for default target values.</param>
        /// <param name="modifiableDataset">Dataset in which the variable is temporarily replaced.</param>
        /// <param name="rows">Rows used for evaluation.</param>
        /// <param name="replacementMethod">Replacement strategy for double variables.</param>
        /// <param name="factorReplacementMethod">Replacement strategy for string variables.</param>
        /// <param name="targetValues">Target values; read from <paramref name="problemData"/> when <c>null</c>.</param>
        /// <param name="quality">Baseline quality; computed from the model's estimates when it is NaN.</param>
        /// <returns>Baseline quality minus quality after replacement; 0 when the model does not use the variable.</returns>
        /// <exception cref="InvalidOperationException">The model uses the variable but the dataset does not contain it.</exception>
        public static double CalculateImpact(string variableName,
                                             IClassificationModel model,
                                             IClassificationProblemData problemData,
                                             ModifiableDataset modifiableDataset,
                                             IEnumerable <int> rows,
                                             ReplacementMethodEnum replacementMethod             = ReplacementMethodEnum.Shuffle,
                                             FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
                                             IEnumerable <double> targetValues = null,
                                             double quality = double.NaN)
        {
            if (!model.VariablesUsedForPrediction.Contains(variableName))
            {
                return(0.0);
            }
            if (!problemData.Dataset.VariableNames.Contains(variableName))
            {
                throw new InvalidOperationException(string.Format("Can not calculate variable impact, because the model uses inputs missing in the dataset ({0})", variableName));
            }

            if (targetValues == null)
            {
                targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
            }

            // NaN never compares equal to anything (including itself), so "quality == double.NaN"
            // would always be false; double.IsNaN is required to detect the "not supplied" default.
            if (double.IsNaN(quality))
            {
                // Targets first, estimates second, matching the argument order used when the
                // quality after replacement is computed.
                quality = CalculateQuality(targetValues, model.GetEstimatedClassValues(modifiableDataset, rows));
            }

            IList originalValues;
            IList replacementValues = GetReplacementValues(modifiableDataset, variableName, model, rows, targetValues, out originalValues, replacementMethod, factorReplacementMethod);

            double newValue = CalculateQualityForReplacement(model, modifiableDataset, variableName, originalValues, rows, replacementValues, targetValues);

            return(quality - newValue);
        }
コード例 #14
0
 /// <summary>
 /// Adds <paramref name="model"/> to the model collection.
 /// </summary>
 /// <param name="model">The model to add.</param>
 public void Add(IClassificationModel model) {
   this.models.Add(model);
 }
 /// <summary>
 /// Tells whether <paramref name="currentRow"/> counts as a test row for <paramref name="model"/>.
 /// </summary>
 /// <returns>
 /// <c>true</c> when there are no test partitions, when none is registered for the model, or when
 /// the row lies in the half-open range [Start, End) of the model's partition.
 /// </returns>
 private bool RowIsTestForModel(int currentRow, IClassificationModel model)
 {
     if (testPartitions == null)
     {
         return(true);
     }
     if (!testPartitions.ContainsKey(model))
     {
         return(true);
     }

     var partition = testPartitions[model];
     return(partition.Start <= currentRow && currentRow < partition.End);
 }
コード例 #16
0
 /// <summary>
 /// Constructor.
 /// </summary>
 /// <param name="settingsModel">Instance of the settings model.</param>
 /// <param name="classificationModel">Instance of the classification model.</param>
 public AppModel(ISettingsModel settingsModel, IClassificationModel classificationModel)
 {
     // Store the supplied model instances in the corresponding properties.
     this.SettingsModel = settingsModel;
     this.ClassificationModel = classificationModel;
 }
コード例 #17
0
 /// <summary>
 /// Adds <paramref name="model"/> to the model collection; if no target variable is set yet,
 /// the target variable of the added model is adopted.
 /// </summary>
 /// <param name="model">The model to add.</param>
 public void Add(IClassificationModel model) {
   bool targetVariableMissing = string.IsNullOrEmpty(TargetVariable);
   if (targetVariableMissing) {
     TargetVariable = model.TargetVariable;
   }
   models.Add(model);
 }
コード例 #18
0
 /// <summary>
 /// Tells whether <paramref name="currentRow"/> counts as a test row for <paramref name="model"/>.
 /// </summary>
 /// <returns>
 /// <c>true</c> when there are no test partitions, when none is registered for the model, or when
 /// the row lies in the half-open range [Start, End) of the model's partition.
 /// </returns>
 private bool RowIsTestForModel(int currentRow, IClassificationModel model) {
   if (testPartitions == null) return true;
   if (!testPartitions.ContainsKey(model)) return true;

   var partition = testPartitions[model];
   return partition.Start <= currentRow && currentRow < partition.End;
 }
コード例 #19
0
 /// <summary>
 /// Removes <paramref name="model"/> from the model collection.
 /// </summary>
 /// <param name="model">The model to remove.</param>
 public void Remove(IClassificationModel model)
 {
     this.models.Remove(model);
 }
コード例 #20
0
 /// <summary>
 /// Adds <paramref name="model"/> to the model collection.
 /// </summary>
 /// <param name="model">The model to add.</param>
 public void Add(IClassificationModel model)
 {
     this.models.Add(model);
 }
コード例 #21
0
 /// <summary>
 /// Creates a classification solution for the given model and problem data, sets up the
 /// evaluation cache and calculates the classification results.
 /// </summary>
 /// <param name="model">Model passed on to the base constructor.</param>
 /// <param name="problemData">Problem data passed on to the base constructor; its dataset row count sizes the cache.</param>
 public ClassificationSolution(IClassificationModel model, IClassificationProblemData problemData)
     : base(model, problemData)
 {
     // Pre-size the cache with the number of rows in the dataset.
     var rowCount = problemData.Dataset.Rows;
     evaluationCache = new Dictionary <int, double>(rowCount);

     CalculateClassificationResults();
 }
コード例 #22
0
 /// <summary>
 /// Removes <paramref name="model"/> from the model collection and resets
 /// <c>TargetVariable</c> to the empty string once no models are left.
 /// </summary>
 /// <param name="model">The model to remove.</param>
 public void Remove(IClassificationModel model) {
   models.Remove(model);
   // Keep the target variable as long as at least one model remains.
   if (models.Any()) return;
   TargetVariable = string.Empty;
 }
コード例 #23
0
 /// <summary>
 /// Creates a classification solution for the given model and problem data, sets up the
 /// evaluation cache and calculates the classification results.
 /// </summary>
 /// <param name="model">Model passed on to the base constructor.</param>
 /// <param name="problemData">Problem data passed on to the base constructor; its dataset row count sizes the cache.</param>
 public ClassificationSolution(IClassificationModel model, IClassificationProblemData problemData)
   : base(model, problemData) {
   // Pre-size the cache with the number of rows in the dataset.
   var rowCount = problemData.Dataset.Rows;
   evaluationCache = new Dictionary<int, double>(rowCount);

   CalculateClassificationResults();
 }
コード例 #24
0
 /// <summary>
 /// Removes <paramref name="model"/> from the model collection.
 /// </summary>
 /// <param name="model">The model to remove.</param>
 public void Remove(IClassificationModel model) {
   this.models.Remove(model);
 }