Ejemplo n.º 1
0
        // Round-trips a two-feature DataSourceSet through MsgPackSerializer using an in-memory
        // stream and checks that the shape and data of both features come back unchanged.
        public void TestSerializeDeserialize()
        {
            var a   = DataSourceFactory.Create(new float[] { 1, 2, 3 }, new int[] { 3, 1, 1 });
            var b   = DataSourceFactory.Create(new float[] { 4, 5, 6, 7 }, new int[] { 2, 2, 1 });
            var dss = new DataSourceSet();

            dss.Add("a", a);
            dss.Add("b", b);

            // Dispose the stream deterministically instead of leaving it to the garbage collector.
            using (var stream = new MemoryStream())
            {
                MsgPackSerializer.Serialize(dss, stream);

                // Rewind so that Deserialize starts reading at the first written byte.
                stream.Position = 0;
                var result = MsgPackSerializer.Deserialize(stream);

                Assert.AreEqual(2, result.Features.Count);

                var x = result["a"];

                CollectionAssert.AreEqual(new int[] { 3, 1, 1 }, x.Shape.Dimensions);
                CollectionAssert.AreEqual(new float[] { 1, 2, 3 }, x.Data.ToArray());

                var y = result["b"];

                CollectionAssert.AreEqual(new int[] { 2, 2, 1 }, y.Shape.Dimensions);
                CollectionAssert.AreEqual(new float[] { 4, 5, 6, 7 }, y.Data.ToArray());
            }
        }
Ejemplo n.º 2
0
        // Writes one DataSourceSet into each of two temporary files, opens a single MsgPackSampler
        // over both files and checks the reported sample count and the first minibatch.
        public void TestMsgPackSamplerMultipleFiles()
        {
            var files = new string[2];

            try
            {
                for (var i = 0; i < files.Length; ++i)
                {
                    files[i] = Path.GetTempFileName();
                    var a   = DataSourceFactory.Create(new float[] { i, i + 1, i + 2 }, new int[] { 3, 1 });
                    var dss = new DataSourceSet();
                    dss.Add("a", a);
                    using (var stream = new FileStream(files[i], FileMode.Create, FileAccess.Write))
                    {
                        MsgPackSerializer.Serialize(dss, stream);
                    }
                }

                using (var sampler = new MsgPackSampler(files, 3, false, -1, 10, false, 10))
                {
                    Assert.AreEqual(2, sampler.SampleCountPerEpoch);

                    var batch = sampler.GetNextMinibatch();
                    CollectionAssert.AreEqual(new int[] { 3, 3 }, batch.Features["a"].Shape.Dimensions.ToArray());
                    var values = DataSourceFactory.FromValue(batch.Features["a"]);

                    // Only two samples were written, so the third column of the batch is expected
                    // to repeat the first file's data.
                    CollectionAssert.AreEqual(new float[] { 0, 1, 2, 1, 2, 3, 0, 1, 2 }, values.TypedData);
                }
            }
            finally
            {
                // GetTempFileName creates each file on disk; remove them so repeated runs do not
                // pile up files in the temp directory.
                foreach (var path in files)
                {
                    // An entry is still null if the setup loop failed before reaching it.
                    if (path == null)
                    {
                        continue;
                    }

                    try
                    {
                        File.Delete(path);
                    }
                    catch (IOException)
                    {
                        // Best-effort cleanup: a lingering handle must not mask the test outcome.
                    }
                }
            }
        }
Ejemplo n.º 3
0
        // Appends ten single-sample DataSourceSets to one temporary file, then checks that
        // MsgPackTools.GetTotalSampleCount reports ten and that ReadDataSourceSet(file, total, 4)
        // yields three sets holding 4, 4 and 2 samples in the order they were written.
        public void TestReadDataSourceSet()
        {
            const int NUM_SAMPLES = 10;

            var file = Path.GetTempFileName();

            try
            {
                using (var stream = new FileStream(file, FileMode.Create, FileAccess.Write))
                {
                    for (var i = 0; i < NUM_SAMPLES; ++i)
                    {
                        var a      = DataSourceFactory.Create(new float[] { i, i * 10, i * 100 }, new int[] { 3, 1, 1 });
                        var sample = new DataSourceSet();
                        sample.Add("a", a);
                        MsgPackSerializer.Serialize(sample, stream);
                    }
                }

                var total = MsgPackTools.GetTotalSampleCount(file);

                Assert.AreEqual(NUM_SAMPLES, total);

                // Dispose the enumerator so that anything it holds open is released before the
                // temporary file is deleted.
                using (var reader = MsgPackTools.ReadDataSourceSet(file, total, 4).GetEnumerator())
                {
                    Assert.IsTrue(reader.MoveNext());
                    DataSourceSet dss = reader.Current;
                    Assert.AreEqual(4, dss.SampleCount);
                    CollectionAssert.AreEqual(new int[] { 3, 1, 4 }, dss["a"].Shape.Dimensions);
                    CollectionAssert.AreEqual(new float[] { 0, 0, 0, 1, 10, 100, 2, 20, 200, 3, 30, 300 }, dss["a"].Data.ToArray());

                    Assert.IsTrue(reader.MoveNext());
                    dss = reader.Current;
                    Assert.AreEqual(4, dss.SampleCount);
                    CollectionAssert.AreEqual(new int[] { 3, 1, 4 }, dss["a"].Shape.Dimensions);
                    CollectionAssert.AreEqual(new float[] { 4, 40, 400, 5, 50, 500, 6, 60, 600, 7, 70, 700 }, dss["a"].Data.ToArray());

                    // The last set only holds the two samples that remain.
                    Assert.IsTrue(reader.MoveNext());
                    dss = reader.Current;
                    Assert.AreEqual(2, dss.SampleCount);
                    CollectionAssert.AreEqual(new int[] { 3, 1, 2 }, dss["a"].Shape.Dimensions);
                    CollectionAssert.AreEqual(new float[] { 8, 80, 800, 9, 90, 900 }, dss["a"].Data.ToArray());

                    Assert.IsFalse(reader.MoveNext());
                }
            }
            finally
            {
                // GetTempFileName creates the file on disk; remove it so repeated runs do not
                // pile up files in the temp directory.
                try
                {
                    File.Delete(file);
                }
                catch (IOException)
                {
                    // Best-effort cleanup: a lingering handle must not mask the test outcome.
                }
            }
        }
Ejemplo n.º 4
0
        // Stores 300 DataSourceSets shaped { 1, 1, 3 } and pulls minibatches whose shape is asserted
        // to be { 1, 1, 5 }, checking that each batch continues exactly where the previous one
        // stopped even though batch boundaries do not line up with the stored sets.
        public void TestMsgPackSamplerSlicing2()
        {
            const int NUM_SAMPLES = 300;

            var file = Path.GetTempFileName();

            try
            {
                using (var stream = new FileStream(file, FileMode.Create, FileAccess.Write))
                {
                    for (var i = 0; i < NUM_SAMPLES; ++i)
                    {
                        var a   = DataSourceFactory.Create(new float[] { i, i, i }, new int[] { 1, 1, 3 });
                        var dss = new DataSourceSet();
                        dss.Add("a", a);
                        MsgPackSerializer.Serialize(dss, stream);
                    }
                }

                using (var sampler = new MsgPackSampler(file, 5, false, 7, 10, false, 100))
                {
                    // Ten passes over the file; the expected values restart from zero on each pass.
                    for (var i = 0; i < 10; ++i)
                    {
                        // NUM_SAMPLES * 3 stored values consumed five at a time.
                        for (var j = 0; j < NUM_SAMPLES * 3 / 5; ++j)
                        {
                            var batch = sampler.GetNextMinibatch();
                            Assert.AreEqual(1, batch.Features.Count);
                            Assert.IsTrue(batch.Features.ContainsKey("a"));

                            var value = batch["a"];
                            CollectionAssert.AreEqual(new int[] { 1, 1, 5 }, value.Shape.Dimensions.ToArray());

                            var ds = DataSourceFactory.FromValue(value).ToArray();
                            Debug.WriteLine(string.Join(", ", ds));

                            // Integer division is intentional: every stored set repeats its index
                            // three times, so flat position p is expected to hold p / 3.
                            CollectionAssert.AreEqual(new float[]
                            {
                                j * 5 / 3,
                                (j * 5 + 1) / 3,
                                (j * 5 + 2) / 3,
                                (j * 5 + 3) / 3,
                                (j * 5 + 4) / 3
                            }, ds);
                        }
                    }
                }
            }
            finally
            {
                // GetTempFileName creates the file on disk; remove it so repeated runs do not
                // pile up files in the temp directory.
                try
                {
                    File.Delete(file);
                }
                catch (IOException)
                {
                    // Best-effort cleanup: a lingering handle must not mask the test outcome.
                }
            }
        }
Ejemplo n.º 5
0
        // Writes 30 identical chunks holding the values 0..11, reads them back through a
        // MsgPackSampler constructed with its third argument set to true, and checks for every
        // chunk-sized run of minibatches that the values arrive in a different order than they
        // were stored but are the same values once sorted.
        public void TestMsgPackSamplerRandomize()
        {
            const int NUM_CHUNKS     = 30;
            const int CHUNK_SIZE     = 6;
            const int MINIBATCH_SIZE = 2;

            var data        = new float[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11 };
            int FEATURE_DIM = data.Length / CHUNK_SIZE;

            var file = Path.GetTempFileName();

            try
            {
                using (var stream = new FileStream(file, FileMode.Create, FileAccess.Write))
                {
                    for (var i = 0; i < NUM_CHUNKS; ++i)
                    {
                        var a   = DataSourceFactory.Create(data, new int[] { FEATURE_DIM, CHUNK_SIZE });
                        var dss = new DataSourceSet();
                        dss.Add("a", a);
                        MsgPackSerializer.Serialize(dss, stream);
                    }
                }

                using (var sampler = new MsgPackSampler(file, MINIBATCH_SIZE, true, 10, 100, false, 100))
                {
                    for (var i = 0; i < NUM_CHUNKS; ++i)
                    {
                        // Collect as many minibatches as it takes to cover one chunk's worth of values.
                        var values = new float[data.Length];
                        for (var j = 0; j < data.Length; j += FEATURE_DIM * MINIBATCH_SIZE)
                        {
                            var batch = sampler.GetNextMinibatch();
                            var value = DataSourceFactory.FromValue(batch["a"]);
                            CollectionAssert.AreEqual(new int[] { FEATURE_DIM, MINIBATCH_SIZE }, value.Shape.Dimensions.ToArray());
                            for (var k = 0; k < FEATURE_DIM * MINIBATCH_SIZE; ++k)
                            {
                                values[j + k] = value[k];
                            }
                        }

                        // NOTE(review): this assumes the sampler never happens to reproduce the stored
                        // order; confirm that its shuffling is seeded deterministically.
                        CollectionAssert.AreNotEqual(data, values);
                        var sorted = values.ToList();
                        sorted.Sort();
                        CollectionAssert.AreEqual(data, sorted);
                    }
                }
            }
            finally
            {
                // GetTempFileName creates the file on disk; remove it so repeated runs do not
                // pile up files in the temp directory.
                try
                {
                    File.Delete(file);
                }
                catch (IOException)
                {
                    // Best-effort cleanup: a lingering handle must not mask the test outcome.
                }
            }
        }
Ejemplo n.º 6
0
        // Stores 1000 single-sample DataSourceSets and reads them back one per minibatch over ten
        // passes, checking the feature key, the SweepEnd flag, the shape and the values of every batch.
        public void TestMsgPackSampler()
        {
            const int NUM_SAMPLES = 1000;

            var file = Path.GetTempFileName();

            try
            {
                using (var stream = new FileStream(file, FileMode.Create, FileAccess.Write))
                {
                    for (var i = 0; i < NUM_SAMPLES; ++i)
                    {
                        var a   = DataSourceFactory.Create(new float[] { i, i * 10, i * 100 }, new int[] { 3, 1, 1 });
                        var dss = new DataSourceSet();
                        dss.Add("a", a);
                        MsgPackSerializer.Serialize(dss, stream);
                    }
                }

                using (var sampler = new MsgPackSampler(file, 1, false, 3, 10, false, 100))
                {
                    for (var i = 0; i < 10; ++i)
                    {
                        for (var j = 0; j < NUM_SAMPLES; ++j)
                        {
                            var batch = sampler.GetNextMinibatch();
                            Assert.AreEqual(1, batch.Features.Count);
                            Assert.IsTrue(batch.Features.ContainsKey("a"));

                            // SweepEnd is expected on every third batch, counted across all passes.
                            Assert.AreEqual((i * NUM_SAMPLES + j + 1) % 3 == 0, batch.SweepEnd);

                            var value = batch["a"];
                            CollectionAssert.AreEqual(new int[] { 3, 1, 1 }, value.Shape.Dimensions.ToArray());

                            // The data is expected to restart from sample 0 on every pass, hence j, not i.
                            var ds = DataSourceFactory.FromValue(value).ToArray();
                            CollectionAssert.AreEqual(new float[] { j, j * 10, j * 100 }, ds);
                        }
                    }
                }
            }
            finally
            {
                // GetTempFileName creates the file on disk; remove it so repeated runs do not
                // pile up files in the temp directory.
                try
                {
                    File.Delete(file);
                }
                catch (IOException)
                {
                    // Best-effort cleanup: a lingering handle must not mask the test outcome.
                }
            }
        }
Ejemplo n.º 7
0
        // Stores 100 DataSourceSets whose last dimension is 3 and checks that the sampler reports
        // SampleCountPerEpoch as 3 * 100.
        public void TestMsgPackSamplerSampleCountPerEpoch()
        {
            const int NUM_SAMPLES = 3;
            const int NUM_CHUNKS  = 100;

            var file = Path.GetTempFileName();

            try
            {
                using (var stream = new FileStream(file, FileMode.Create, FileAccess.Write))
                {
                    for (var i = 0; i < NUM_CHUNKS; ++i)
                    {
                        var a   = DataSourceFactory.Create(new float[] { i, i * 10, i * 100 }, new int[] { 1, 1, NUM_SAMPLES });
                        var dss = new DataSourceSet();
                        dss.Add("a", a);
                        MsgPackSerializer.Serialize(dss, stream);
                    }
                }

                using (var sampler = new MsgPackSampler(file, 1, false, -1, 10, false, 100))
                {
                    Assert.AreEqual(NUM_SAMPLES * NUM_CHUNKS, sampler.SampleCountPerEpoch);
                }
            }
            finally
            {
                // GetTempFileName creates the file on disk; remove it so repeated runs do not
                // pile up files in the temp directory.
                try
                {
                    File.Delete(file);
                }
                catch (IOException)
                {
                    // Best-effort cleanup: a lingering handle must not mask the test outcome.
                }
            }
        }
Ejemplo n.º 8
0
        // Stores 1000 single-sample DataSourceSets and draws 10000 minibatches from a sampler
        // constructed with its sixth argument set to true, checking only that every batch carries
        // the "a" feature (i.e. that the sampler keeps delivering well past the stored sample count).
        public void TestMsgPackSamplerReuseSamples()
        {
            var file = Path.GetTempFileName();

            try
            {
                using (var stream = new FileStream(file, FileMode.Create, FileAccess.Write))
                {
                    for (var i = 0; i < 1000; ++i)
                    {
                        var a   = DataSourceFactory.Create(new float[] { i, i * 10, i * 100 }, new int[] { 3, 1, 1 });
                        var dss = new DataSourceSet();
                        dss.Add("a", a);
                        MsgPackSerializer.Serialize(dss, stream);
                    }
                }

                using (var sampler = new MsgPackSampler(file, 1, false, 100, 100, true, 1000))
                {
                    for (var i = 0; i < 10000; ++i)
                    {
                        var batch = sampler.GetNextMinibatch();
                        Assert.IsTrue(batch.Features.ContainsKey("a"));
                    }
                }
            }
            finally
            {
                // GetTempFileName creates the file on disk; remove it so repeated runs do not
                // pile up files in the temp directory.
                try
                {
                    File.Delete(file);
                }
                catch (IOException)
                {
                    // Best-effort cleanup: a lingering handle must not mask the test outcome.
                }
            }
        }