/// <summary>
 /// Stacking regression ensemble learner.
 /// Merges several base learners into one ensemble model; a top-level (meta) learner is trained
 /// to combine the base learners' outputs.
 /// The base learners generate the meta learner's inputs via cross validation. This overload uses
 /// 5-fold RandomCrossValidation (constructed as <c>RandomCrossValidation&lt;double&gt;(5, 42)</c>).
 /// </summary>
 /// <param name="learners">The base learners that make up the ensemble.</param>
 /// <param name="metaLearner">The top-level learner that combines the outputs of the base learners.</param>
 /// <param name="includeOriginalFeaturesForMetaLearner">When true (the default), the meta learner receives the
 /// original features in addition to the base learners' outputs; when false, it receives only those outputs.</param>
 public RegressionStackingEnsembleLearner(
     IIndexedLearner<double>[] learners,
     ILearner<double> metaLearner,
     bool includeOriginalFeaturesForMetaLearner = true)
     : this(
         learners,
         (observationMatrix, targetValues) => metaLearner.Learn(observationMatrix, targetValues),
         new RandomCrossValidation<double>(5, 42),
         includeOriginalFeaturesForMetaLearner)
 {
 }
 /// <summary>
 /// Stacking classification ensemble learner.
 /// Merges several base learners into one ensemble model; a top-level (meta) learner is trained
 /// to combine the base learners' outputs.
 /// The base learners generate the meta learner's inputs via the supplied cross validation.
 /// </summary>
 /// <param name="learners">The base learners that make up the ensemble.</param>
 /// <param name="metaLearner">The top-level learner that combines the outputs of the base learners.</param>
 /// <param name="crossValidation">Cross validation used to produce the base learners' outputs for the meta learner.</param>
 /// <param name="includeOriginalFeaturesForMetaLearner">When true (the default), the meta learner receives the
 /// original features in addition to the base learners' outputs; when false, it receives only those outputs.</param>
 public ClassificationStackingEnsembleLearner(
     IIndexedLearner<ProbabilityPrediction>[] learners,
     ILearner<ProbabilityPrediction> metaLearner,
     ICrossValidation<ProbabilityPrediction> crossValidation,
     bool includeOriginalFeaturesForMetaLearner = true)
     : this(
         learners,
         (observationMatrix, targetValues) => metaLearner.Learn(observationMatrix, targetValues),
         crossValidation,
         includeOriginalFeaturesForMetaLearner)
 {
 }
示例#3
0
        /// <summary>
        /// Passes <c>trainingInputOutputPairs</c> and <c>batchSize</c> to <c>CreateMiniBatches</c>
        /// and feeds every resulting batch to <c>learner.Learn</c>, one at a time, in enumeration order.
        /// </summary>
        public void TrainModel()
        {
            foreach (var currentBatch in CreateMiniBatches(trainingInputOutputPairs, batchSize))
            {
                learner.Learn(currentBatch);
            }
        }
示例#4
0
 /// <summary>
 /// Forwards a single lesson string to the given learner.
 /// </summary>
 /// <param name="learner">Learner whose <c>Learn</c> method receives the lesson.</param>
 /// <param name="lesson">Lesson text, passed through unchanged.</param>
 static void TeachSomething(ILearner learner, string lesson) => learner.Learn(lesson);
示例#5
0
 /// <summary>
 /// Forwards a <see cref="Lesson"/> to the given learner.
 /// </summary>
 /// <param name="lesson">Lesson passed through unchanged to <c>learner.Learn</c>.</param>
 /// <param name="learner">Learner that receives the lesson.</param>
 public void Teach(Lesson lesson, ILearner learner) => learner.Learn(lesson);
示例#6
0
    /// <summary>
    /// Trains <c>model</c> with <c>learner</c> on a comma-separated file and then saves it.
    /// The file read is <c>root + dataName + ".csv"</c>. Column "f5" supplies the targets;
    /// column "f6" is excluded from the observations; every other column is an observation feature.
    /// </summary>
    /// <param name="dataName">File name, without the ".csv" extension, appended to <c>root</c>.</param>
    /// <param name="savePath">Destination passed to <c>SaveModel</c> once training succeeds.</param>
    /// <remarks>
    /// Exceptions are not propagated: any exception thrown while reading, training or saving is
    /// logged with its full details (type, message, stack trace) and <c>model</c> is reset to null.
    /// </remarks>
    public void Train(string dataName, string savePath)
    {
        const string targetName = "f5";
        // Excluded from the observation features; not used as a target either.
        const string enemyName = "f6";

        try
        {
            // Created inside the try so that every failure goes through the same logging path.
            var parser = new CsvParser(() => new StreamReader(root + dataName + ".csv"), ',');

            // Observations: every column except the target and the excluded column.
            var observations = parser.EnumerateRows(c => c != targetName && c != enemyName)
                                     .ToF64Matrix();

            var targets = parser.EnumerateRows(targetName)
                                .ToF64Vector();

            model = learner.Learn(observations, targets);
            Debug.Log("TRAINED");

            // NOTE(review): if SaveModel throws, the catch below also discards the freshly
            // trained model - confirm that is intended.
            SaveModel(savePath);
        }
        catch (Exception e)
        {
            // Best-effort by design, but log the whole exception rather than only e.Message,
            // which lost the exception type and stack trace.
            Debug.LogException(e);
            model = null;
        }
    }