Example #1
        public void TestMsgPackSamplerMultipleFiles()
        {
            var files = new string[2];

            for (var i = 0; i < 2; ++i)
            {
                files[i] = Path.GetTempFileName();
                var a   = DataSourceFactory.Create(new float[] { i, i + 1, i + 2 }, new int[] { 3, 1 });
                var dss = new DataSourceSet();
                dss.Add("a", a);
                using (var stream = new FileStream(files[i], FileMode.Create, FileAccess.Write))
                {
                    MsgPackSerializer.Serialize(dss, stream);
                }
            }

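            // The positional arguments to MsgPackSampler below presumably map to the cmdlet
            // parameters shown in Example #2: path(s), MinibatchSize, Randomize,
            // SampleCountPerEpoch, QueueSize, ReuseSamples, BufferSize.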
            using (var sampler = new MsgPackSampler(files, 3, false, -1, 10, false, 10))
            {
                Assert.AreEqual(2, sampler.SampleCountPerEpoch);

                var batch = sampler.GetNextMinibatch();
                CollectionAssert.AreEqual(new int[] { 3, 3 }, batch.Features["a"].Shape.Dimensions.ToArray());
                var values = DataSourceFactory.FromValue(batch.Features["a"]);
                CollectionAssert.AreEqual(new float[] { 0, 1, 2, 1, 2, 3, 0, 1, 2 }, values.TypedData);
            }
        }
Example #2
        protected override void EndProcessing()
        {
            Path = Path.Select(p => IO.GetAbsolutePath(this, p)).ToArray();

            var sampler = new MsgPackSampler(Path, MinibatchSize, Randomize, SampleCountPerEpoch, QueueSize, ReuseSamples, BufferSize, TimeoutForAdd, TimeoutForTake);

            WriteObject(sampler);
        }
Example #3
        public void TestMsgPackSamplerSlicing2()
        {
            const int NUM_SAMPLES = 300;

            var file = Path.GetTempFileName();

            using (var stream = new FileStream(file, FileMode.Create, FileAccess.Write))
            {
                for (var i = 0; i < NUM_SAMPLES; ++i)
                {
                    var a   = DataSourceFactory.Create(new float[] { i, i, i }, new int[] { 1, 1, 3 });
                    var dss = new DataSourceSet();
                    dss.Add("a", a);
                    MsgPackSerializer.Serialize(dss, stream);
                }
            }

            using (var sampler = new MsgPackSampler(file, 5, false, 7, 10, false, 100))
            {
                int count = 0;
                for (var i = 0; i < 10; ++i)
                {
                    for (var j = 0; j < NUM_SAMPLES * 3 / 5; ++j)
                    {
                        var batch = sampler.GetNextMinibatch();
                        Assert.AreEqual(1, batch.Features.Count);
                        Assert.IsTrue(batch.Features.ContainsKey("a"));

                        var value = batch["a"];
                        CollectionAssert.AreEqual(new int[] { 1, 1, 5 }, value.Shape.Dimensions.ToArray());

                        var ds = DataSourceFactory.FromValue(value).ToArray();
                        Debug.WriteLine(string.Join(", ", ds));
                        CollectionAssert.AreEqual(new float[]
                        {
                            j * 5 / 3,
                            (j * 5 + 1) / 3,
                            (j * 5 + 2) / 3,
                            (j * 5 + 3) / 3,
                            (j * 5 + 4) / 3
                        }, ds);
                        count += 5;
                    }
                }
            }
        }
Example #4
        public void TestMsgPackSamplerRandomize()
        {
            const int NUM_CHUNKS     = 30;
            const int CHUNK_SIZE     = 6;
            const int MINIBATCH_SIZE = 2;

            var data        = new float[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11 };
            int FEATURE_DIM = data.Length / CHUNK_SIZE;

            var file = Path.GetTempFileName();

            using (var stream = new FileStream(file, FileMode.Create, FileAccess.Write))
            {
                for (var i = 0; i < NUM_CHUNKS; ++i)
                {
                    var a   = DataSourceFactory.Create(data, new int[] { FEATURE_DIM, CHUNK_SIZE });
                    var dss = new DataSourceSet();
                    dss.Add("a", a);
                    MsgPackSerializer.Serialize(dss, stream);
                }
            }

            using (var sampler = new MsgPackSampler(file, MINIBATCH_SIZE, true, 10, 100, false, 100))
            {
                for (var i = 0; i < NUM_CHUNKS; ++i)
                {
                    var values = new float[data.Length];
                    for (var j = 0; j < data.Length; j += FEATURE_DIM * MINIBATCH_SIZE)
                    {
                        var batch = sampler.GetNextMinibatch();
                        var value = DataSourceFactory.FromValue(batch["a"]);
                        CollectionAssert.AreEqual(new int[] { FEATURE_DIM, MINIBATCH_SIZE }, value.Shape.Dimensions.ToArray());
                        for (var k = 0; k < FEATURE_DIM * MINIBATCH_SIZE; ++k)
                        {
                            values[j + k] = value[k];
                        }
                    }

                    CollectionAssert.AreNotEqual(data, values);
                    var sorted = values.ToList();
                    sorted.Sort();
                    CollectionAssert.AreEqual(data, sorted);
                }
            }
        }
Example #5
        public static IEnumerable<DataSourceSet> ReadDataSourceSet(string file, int totalSampleCount, int splitCount)
        {
            var device = DeviceDescriptor.CPUDevice;

            using (var sampler = new MsgPackSampler(new string[] { file }, splitCount, false, totalSampleCount, 1000, false))
            {
                for (var count = 0; count < totalSampleCount;)
                {
                    var dss = sampler.Deque();
                    count += dss.SampleCount;

                    if (count > totalSampleCount)
                    {
                        dss = dss.Subset(0, splitCount - (count - totalSampleCount));
                    }

                    yield return dss;
                }
            }
        }
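A possible consumer of ReadDataSourceSet, sketched as a method in the same class; the file name and the two counts are illustrative only and not taken from the library.

        public static void PrintSampleCounts()
        {
            // Iterate the yielded DataSourceSets; each carries at most splitCount samples,
            // and the last one is trimmed via Subset so the total equals totalSampleCount.
            foreach (var dss in ReadDataSourceSet("samples.msgpack", 100, 10))
            {
                Debug.WriteLine(dss.SampleCount);
            }
        }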
Example #6
        public void TestMsgPackSampler()
        {
            const int NUM_SAMPLES = 1000;

            var file = Path.GetTempFileName();

            using (var stream = new FileStream(file, FileMode.Create, FileAccess.Write))
            {
                for (var i = 0; i < NUM_SAMPLES; ++i)
                {
                    var a   = DataSourceFactory.Create(new float[] { i, i * 10, i * 100 }, new int[] { 3, 1, 1 });
                    var dss = new DataSourceSet();
                    dss.Add("a", a);
                    MsgPackSerializer.Serialize(dss, stream);
                }
            }

            using (var sampler = new MsgPackSampler(file, 1, false, 3, 10, false, 100))
            {
                for (var i = 0; i < 10; ++i)
                {
                    for (var j = 0; j < NUM_SAMPLES; ++j)
                    {
                        var batch = sampler.GetNextMinibatch();
                        Assert.AreEqual(1, batch.Features.Count);
                        Assert.IsTrue(batch.Features.ContainsKey("a"));
                        Assert.AreEqual((i * NUM_SAMPLES + j + 1) % 3 == 0, batch.SweepEnd);

                        var value = batch["a"];
                        CollectionAssert.AreEqual(new int[] { 3, 1, 1 }, value.Shape.Dimensions.ToArray());

                        var ds = DataSourceFactory.FromValue(value).ToArray();
                        // Debug.WriteLine(string.Join(", ", ds));
                        CollectionAssert.AreEqual(new float[] { j, j * 10, j * 100 }, ds);
                    }
                }
            }
        }
Example #7
        public void TestMsgPackSamplerSampleCountPerEpoch()
        {
            const int NUM_SAMPLES = 3;
            const int NUM_CHUNKS  = 100;

            var file = Path.GetTempFileName();

            using (var stream = new FileStream(file, FileMode.Create, FileAccess.Write))
            {
                for (var i = 0; i < NUM_CHUNKS; ++i)
                {
                    var a   = DataSourceFactory.Create(new float[] { i, i * 10, i * 100 }, new int[] { 1, 1, NUM_SAMPLES });
                    var dss = new DataSourceSet();
                    dss.Add("a", a);
                    MsgPackSerializer.Serialize(dss, stream);
                }
            }

            using (var sampler = new MsgPackSampler(file, 1, false, -1, 10, false, 100))
            {
                Assert.AreEqual(NUM_SAMPLES * NUM_CHUNKS, sampler.SampleCountPerEpoch);
            }
        }
Example #8
        public void TestMsgPackSamplerReuseSamples()
        {
            var file = Path.GetTempFileName();

            using (var stream = new FileStream(file, FileMode.Create, FileAccess.Write))
            {
                for (var i = 0; i < 1000; ++i)
                {
                    var a   = DataSourceFactory.Create(new float[] { i, i * 10, i * 100 }, new int[] { 3, 1, 1 });
                    var dss = new DataSourceSet();
                    dss.Add("a", a);
                    MsgPackSerializer.Serialize(dss, stream);
                }
            }

            using (var sampler = new MsgPackSampler(file, 1, false, 100, 100, true, 1000))
            {
                for (var i = 0; i < 10000; ++i)
                {
                    var batch = sampler.GetNextMinibatch();
                    Assert.IsTrue(batch.Features.ContainsKey("a"));
                }
            }
        }
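Examples #1, #3, #4, #6, #7 and #8 all repeat the same setup: serialize one DataSourceSet per sample (or chunk) into a temporary file, then hand that file to MsgPackSampler. A hypothetical helper capturing that pattern might look as follows; WriteSamples and its Func parameter are sketch-only and not part of the library.

        private static string WriteSamples(int count, Func<int, DataSourceSet> makeSampleSet)
        {
            var file = Path.GetTempFileName();
            using (var stream = new FileStream(file, FileMode.Create, FileAccess.Write))
            {
                for (var i = 0; i < count; ++i)
                {
                    // makeSampleSet builds the DataSourceSet for chunk i, e.g. via
                    // DataSourceFactory.Create(...) followed by dss.Add("a", ...).
                    MsgPackSerializer.Serialize(makeSampleSet(i), stream);
                }
            }
            return file;
        }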