Exemplo n.º 1
0
        public void TestMinibatch()
        {
            // Ten values with shape { 2, 1, -1 } form five samples of shape { 2, 1 }:
            // [0,1] [2,3] [4,5] [6,7] [8,9]. Minibatches of two samples are drawn.
            var features = DataSourceFactory.Create(new float[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }, new int[] { 2, 1, -1 });

            var sources = new Dictionary<string, IDataSource<float>>()
            {
                { "input", features }
            };
            var sampler = new DataSourceSampler(sources, 2, false, true);

            // First minibatch: samples 0 and 1; the sweep is not finished yet.
            {
                var batch = sampler.GetNextMinibatch();
                Assert.AreEqual(2, batch.SampleCount, "minibatch 1: sample count");
                Assert.IsFalse(batch.SweepEnd, "minibatch 1 must not end the sweep");

                var data = batch.Features["input"];
                var batchData = DataSourceFactory.FromValue(data);
                CollectionAssert.AreEqual(new float[] { 0, 1, 2, 3 }, batchData.TypedData, "minibatch 1: data");
                CollectionAssert.AreEqual(new int[] { 2, 1, 2 }, data.Shape.Dimensions.ToArray(), "minibatch 1: shape");
            }

            // Second minibatch: samples 2 and 3; no further full minibatch fits, so the sweep ends here.
            {
                var batch = sampler.GetNextMinibatch();
                Assert.AreEqual(2, batch.SampleCount, "minibatch 2: sample count");
                Assert.IsTrue(batch.SweepEnd, "minibatch 2 must end the sweep");

                var data = batch.Features["input"];
                var batchData = DataSourceFactory.FromValue(data);
                CollectionAssert.AreEqual(new float[] { 4, 5, 6, 7 }, batchData.TypedData, "minibatch 2: data");
                CollectionAssert.AreEqual(new int[] { 2, 1, 2 }, data.Shape.Dimensions.ToArray(), "minibatch 2: shape");
            }

            // When not randomized, remnant data smaller than the minibatch size (sample 4, values 8 and 9)
            // is ignored, so the next minibatch starts a new sweep from sample 0.
            {
                var batch = sampler.GetNextMinibatch();
                Assert.AreEqual(2, batch.SampleCount, "minibatch 3: sample count");
                Assert.IsFalse(batch.SweepEnd, "minibatch 3 must not end the sweep");

                var data = batch.Features["input"];
                var batchData = DataSourceFactory.FromValue(data);
                CollectionAssert.AreEqual(new float[] { 0, 1, 2, 3 }, batchData.TypedData, "minibatch 3: data");
                CollectionAssert.AreEqual(new int[] { 2, 1, 2 }, data.Shape.Dimensions.ToArray(), "minibatch 3: shape");
            }
        }
Exemplo n.º 2
0
        public void TestTrainingSession2()
        {
            // Training data: eight 2-D points. The first four (coordinates 0/1) are labeled [0, 1],
            // the last four (coordinates 3-5) are labeled [1, 0], so the classes are linearly separable.
            var features = DataSourceFactory.Create(new float[] { 0, 0, 0, 1, 1, 0, 1, 1, 3, 4, 3, 5, 4, 4, 4, 5 }, new int[] { 2, 1, -1 });
            var labels = DataSourceFactory.Create(new float[] { 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0 }, new int[] { 2, 1, -1 });

            var sources = new Dictionary<string, IDataSource<float>>()
            {
                { "input", features },
                { "label", labels }
            };
            var sampler = new DataSourceSampler(sources, 2);

            // Model: two Composite.Dense layers (100 outputs with "relu", then 2 outputs with "sigmoid").
            var input = CNTKLib.InputVariable(new int[] { 2 }, false, DataType.Float, "input");
            var hidden = Composite.Dense(input, new int[] { 100 }, CNTKLib.HeNormalInitializer(), true, null, false, 4, "relu", DeviceDescriptor.UseDefaultDevice(), "");
            var output = Composite.Dense(hidden, new int[] { 2 }, CNTKLib.GlorotNormalInitializer(), true, null, false, 4, "sigmoid", DeviceDescriptor.UseDefaultDevice(), "");

            var label = CNTKLib.InputVariable(new int[] { 2 }, DataType.Float, "label");

            // Loss (binary cross-entropy) and evaluation (classification error) functions.
            var lossFunction = CNTKLib.BinaryCrossEntropy(output, label);
            var errorFunction = CNTKLib.ClassificationError(output, label);

            // Momentum SGD: learning-rate schedule 0.01, momentum schedule 0.9.
            var learningRate = new TrainingParameterScheduleDouble(.01);
            var momentum = new TrainingParameterScheduleDouble(.9);
            var learner = Learner.MomentumSGDLearner(output.Parameters(), learningRate, momentum, true);

            var session = new TrainingSession(output, lossFunction, errorFunction, learner, null, sampler, null);
            var enumerator = session.GetEnumerator();

            const int IterationCount = 1000;
            for (int step = 0; step < IterationCount; step++)
            {
                // Presumably each MoveNext runs one training step; the current item and the
                // validation metric are read but deliberately not checked.
                enumerator.MoveNext();
                var current = enumerator.Current;
                var validationMetric = session.GetValidationMetric();
            }

            // After training, the session's metric must be below 0.1.
            Assert.IsTrue(session.Metric < 0.1);
        }
Exemplo n.º 3
0
        public void TestSequence()
        {
            // Eight values with shape { 2, 2, -1 } form two samples of shape { 2, 2 }.
            var features = DataSourceFactory.Create(new float[] { 0, 1, 2, 3, 4, 5, 6, 7 }, new int[] { 2, 2, -1 });
            var sampler = new DataSourceSampler(
                new Dictionary<string, IDataSource<float>>() { { "input", features } },
                2, false, true);

            // A two-sample minibatch consumes the whole data set, so the first batch already ends the sweep.
            var batch = sampler.GetNextMinibatch();
            Assert.AreEqual(2, batch.SampleCount);
            Assert.IsTrue(batch.SweepEnd);

            var inputValue = batch.Features["input"];
            var expectedData = new float[] { 0, 1, 2, 3, 4, 5, 6, 7 };
            var actualData = DataSourceFactory.FromValue(inputValue).TypedData;
            CollectionAssert.AreEqual(expectedData, actualData);
            CollectionAssert.AreEqual(new int[] { 2, 2, 2 }, inputValue.Shape.Dimensions.ToArray());
        }