/// <summary>
///   Creates a copy of a <see cref="CrossValidationValues{TModel}"/> entry, preserving
///   its training value, validation value and tag.
/// </summary>
///
/// <param name="value">The entry to be copied.</param>
/// <param name="includeModel">
///   If true, the copy carries a clone of the entry's model (obtained through
///   <c>Model.Clone()</c>); if false, the copy carries no model. When the entry
///   has no model, the copy has none either, regardless of this flag.
/// </param>
///
/// <returns>A new entry with the same values and tag as <paramref name="value"/>.</returns>
///
private static CrossValidationValues<TModel> cloneValue(CrossValidationValues<TModel> value, bool includeModel)
{
    TModel model = default(TModel);

    // Guard against a missing model: calling Clone() on null would throw.
    if (includeModel && value.Model != null)
        model = (TModel)value.Model.Clone();

    return new CrossValidationValues<TModel>(model, value.TrainingValue, value.ValidationValue)
    {
        Tag = value.Tag
    };
}
/// <summary>
///   Computes the cross validation algorithm, fitting and evaluating one
///   model per fold (optionally in parallel when <c>RunInParallel</c> is set).
/// </summary>
///
/// <returns>
///   A <see cref="CrossValidationResult{TModel}"/> built from this instance and
///   the per-fold values produced by the fitting function.
/// </returns>
///
/// <exception cref="InvalidOperationException">
///   Thrown when no fitting function has been defined.
/// </exception>
///
public CrossValidationResult<TModel> Compute()
{
    if (Fitting == null)
        throw new InvalidOperationException("Fitting function must have been previously defined.");

    var models = new CrossValidationValues<TModel>[folds.Length];

    // Work for a single fold: split the data, then fit and evaluate the model.
    // Each fold writes only to its own slot of the results array.
    Action<int> runFold = index =>
    {
        int[] trainingIndices, validationIndices;
        CreatePartitions(index, out trainingIndices, out validationIndices);
        models[index] = fitting(index, trainingIndices, validationIndices);
    };

    if (RunInParallel)
    {
        Parallel.For(0, folds.Length, runFold);
    }
    else
    {
        for (int k = 0; k < folds.Length; k++)
            runFold(k);
    }

    // Aggregate the per-fold statistics.
    return new CrossValidationResult<TModel>(this, models);
}
/// <summary>
///   Starts the model training, calling the <see cref="IterationFunction"/>
///   on each iteration.
/// </summary>
///
/// <remarks>
///   Every iteration is recorded in <c>History</c>. With
///   <c>ModelStorageMode.AllModels</c> each history entry carries its own copy of
///   the model; otherwise history entries carry only the training/validation
///   values, and a copy of the model is kept only when the iteration becomes the
///   new minimum or maximum validation value. Training stops early when the
///   relative change in training value between consecutive iterations falls
///   below <c>Tolerance</c>.
/// </remarks>
///
/// <returns>True if the model training has converged, false otherwise.</returns>
///
public bool Compute()
{
    double lastError = Double.PositiveInfinity;

    for (int i = 0; i < MaxIterations; i++)
    {
        CrossValidationValues<TModel> value = IterationFunction(i);
        double currentError = value.TrainingValue;

        // Create a copy of the iteration results. The model must be cloned (when it is
        // kept at all) because it will keep changing in further iterations.
        bool storeAllModels = Mode == ModelStorageMode.AllModels;
        CrossValidationValues<TModel> entry = cloneValue(value, includeModel: storeAllModels);
        History[i] = entry;

        // On the first iteration, the first model becomes both minimum and maximum;
        // afterwards, only strictly better validation values replace them.
        bool isFirst = MinValidationValue.Value == null || MaxValidationValue.Value == null;
        bool isNewMin = isFirst || value.ValidationValue < MinValidationValue.Value.ValidationValue;
        bool isNewMax = isFirst || value.ValidationValue > MaxValidationValue.Value.ValidationValue;

        if (isNewMin || isNewMax)
        {
            // In AllModels mode the history entry already holds a model copy; in the
            // other modes the model is copied only now that it is worth keeping.
            CrossValidationValues<TModel> best = storeAllModels
                ? entry
                : cloneValue(value, includeModel: true);

            KeyValuePair<int, CrossValidationValues<TModel>> record =
                new KeyValuePair<int, CrossValidationValues<TModel>>(i, best);

            if (isNewMin)
                MinValidationValue = record;

            if (isNewMax)
                MaxValidationValue = record;
        }

        // Check for convergence (relative change in the training value). On the
        // first iteration lastError is infinite, so this can never succeed there.
        if (Math.Abs(currentError - lastError) < Tolerance * Math.Abs(lastError))
            return true; // converged

        // Previously missing: without this, lastError stayed infinite forever and
        // the convergence test above could never be satisfied.
        lastError = currentError;
    }

    // Maximum iterations reached.
    // NOTE(review): this returns true whenever Tolerance is non-zero even though
    // convergence was not detected, which contradicts the <returns> contract above.
    // Kept unchanged to avoid breaking callers; confirm the intended semantics.
    return Tolerance != 0;
}