Beispiel #1
0
        /// <summary>
        /// Compute sample weights such that the class distribution of y becomes
        /// balanced.
        /// </summary>
        /// <param name="y">
        /// Class label of each sample. Labels are re-encoded with a
        /// <c>LabelEncoder</c> before counting, so their raw values do not matter.
        /// </param>
        /// <returns>
        /// One weight per sample: (count of the smallest class) / (count of that
        /// sample's class). Rarer classes therefore get larger weights. An empty
        /// input yields an empty array.
        /// </returns>
        private static double[] BalanceWeights(int[] y)
        {
            var encoder = new LabelEncoder<int>();

            // Re-encode labels so BinCount yields one bin per class
            // (presumably contiguous codes 0..k-1 with no empty bins — confirm against LabelEncoder).
            y = encoder.FitTransform(y);
            var bins = Np.BinCount(y);

            // No samples -> no weights. This also keeps Min() below from running on an
            // empty sequence; the original per-element lambda was never evaluated here.
            if (y.Length == 0)
            {
                return new double[0];
            }

            // The smallest class size does not depend on the sample: compute it once
            // instead of once per element inside the projection.
            var minCount = bins.Min();

            // bins.ElementsAt(y) looks up each sample's class count (assumes ElementsAt is a gather by index).
            return bins.ElementsAt(y).Select(v => 1.0 / v * minCount).ToArray();
        }
Beispiel #2
0
        public void TestMinSamplesLeaf()
        {
            // Every leaf of every tree flavour must end up with more than 4 samples
            // when the estimator is built with min_samples_leaf = 5.
            foreach (var name in CLF_TREES)
            {
                var classifier = CreateClassifier<double>(name, min_samples_leaf: 5, random: new Random(0));
                classifier.Fit(X, y);

                // Bin the node id each sample is routed to; bins left at zero are
                // inner nodes, so only the non-zero bins are leaf populations.
                var landingNodes = classifier.Tree.apply(X.ToDenseMatrix()).Select(v => (int)v).ToArray();
                var smallestLeaf = Np.BinCount(landingNodes).Where(count => count != 0).Min();

                Assert.IsTrue(smallestLeaf > 4, "Failed with {0}".Frmt(name));
            }

            foreach (var name in RegTrees)
            {
                var regressor = CreateRegressor(name, min_samples_leaf: 5, random: new Random(0));
                regressor.Fit(X, y);

                // Same check as above, for the regression trees.
                var landingNodes = regressor.Tree.apply(X.ToDenseMatrix()).Select(v => (int)v).ToArray();
                var smallestLeaf = Np.BinCount(landingNodes).Where(count => count != 0).Min();

                Assert.IsTrue(smallestLeaf > 4, "Failed with {0}".Frmt(name));
            }
        }
Beispiel #3
0
        public void TestSampleWeight()
        {
            // Part 1: zero-weighted samples must not influence the fit.
            // Rows 0-49 are labelled 0 but carry weight 0; rows 50-99 are labelled 1
            // with weight 1, so the tree should predict 1 for every row.
            var samples = Enumerable.Range(0, 100).ToColumnMatrix();
            var labels = Enumerable.Repeat(0, 50).Concat(Enumerable.Repeat(1, 50)).ToArray();

            var weights = Enumerable.Repeat(1, 100).ToVector();
            weights.SetSubVector(0, 50, Enumerable.Repeat(0, 50).ToVector());

            var model = new DecisionTreeClassifier<int>(random: new Random(0));
            model.Fit(samples, labels, sampleWeight: weights);
            AssertExt.ArrayEqual(model.Predict(samples), Enumerable.Repeat(1, 100).ToArray());

            // Part 2: low-weighted samples are not taken into account at low depth.
            // Feature: row index for rows 0-99, constant 200 for rows 100-199.
            // Labels:  0 for rows 0-49, 1 for rows 50-99, 2 for rows 100-199.
            samples = Enumerable.Range(0, 100).Concat(Enumerable.Repeat(200, 100)).ToColumnMatrix();
            labels = Enumerable.Repeat(0, 50)
                               .Concat(Enumerable.Repeat(1, 50))
                               .Concat(Enumerable.Repeat(2, 100))
                               .ToArray();

            weights = Enumerable.Repeat(1, 200).ToVector();

            // Class '2' rows weighted 0.51: they are still weightier, so the root
            // split separates them (midpoint between 99 and 200).
            weights.SetSubVector(100, 100, Enumerable.Repeat(0.51, 100).ToVector());
            model = new DecisionTreeClassifier<int>(maxDepth: 1, random: new Random(0));
            model.Fit(samples, labels, sampleWeight: weights);
            Assert.AreEqual(149.5, model.Tree.Threshold[0]);

            // Class '2' rows weighted 0.50: no longer weightier, so the root split
            // moves to the midpoint between rows 49 and 50.
            weights.SetSubVector(100, 100, Enumerable.Repeat(0.50, 100).ToVector());
            model = new DecisionTreeClassifier<int>(maxDepth: 1, random: new Random(0));
            model.Fit(samples, labels, sampleWeight: weights);
            Assert.AreEqual(49.5, model.Tree.Threshold[0]);

            // Part 3: weighting a sample by k must match duplicating it k times.
            samples = iris.Data;
            labels = iris.Target;

            var rng = new Random(0);
            var duplicates = Enumerable.Range(0, 200)
                                       .Select(_ => rng.Next(samples.RowCount))
                                       .ToArray();

            var duplicateModel = new DecisionTreeClassifier<int>(random: new Random(1));
            duplicateModel.Fit(samples.RowsAt(duplicates), labels.ElementsAt(duplicates));

            // How many times each row was drawn becomes that row's weight.
            weights = Np.BinCount(duplicates, minLength: samples.RowCount).ToVector();
            var weightedModel = new DecisionTreeClassifier<int>(random: new Random(1));
            weightedModel.Fit(samples, labels, sampleWeight: weights);

            // Compare split thresholds only at the inner (non-leaf) nodes.
            var splitNodes = duplicateModel.Tree.ChildrenLeft.Indices(v => v != Tree._TREE_LEAF);
            AssertExt.AlmostEqual(duplicateModel.Tree.Threshold.ElementsAt(splitNodes),
                                  weightedModel.Tree.Threshold.ElementsAt(splitNodes));
        }