// Round-trips a DataSourceSet with two differently-shaped features through
// MsgPack serialization and verifies names, shapes, and data survive intact.
public void TestSerializeDeserialize()
{
    var a = DataSourceFactory.Create(new float[] { 1, 2, 3 }, new int[] { 3, 1, 1 });
    var b = DataSourceFactory.Create(new float[] { 4, 5, 6, 7 }, new int[] { 2, 2, 1 });
    var dss = new DataSourceSet();
    dss.Add("a", a);
    dss.Add("b", b);

    // Dispose the stream deterministically; the original leaked it.
    using (var stream = new MemoryStream())
    {
        MsgPackSerializer.Serialize(dss, stream);
        // Rewind so Deserialize reads from the start of what was just written.
        stream.Position = 0;
        var result = MsgPackSerializer.Deserialize(stream);

        Assert.AreEqual(2, result.Features.Count);

        var x = result["a"];
        CollectionAssert.AreEqual(new int[] { 3, 1, 1 }, x.Shape.Dimensions);
        CollectionAssert.AreEqual(new float[] { 1, 2, 3 }, x.Data.ToArray());

        var y = result["b"];
        CollectionAssert.AreEqual(new int[] { 2, 2, 1 }, y.Shape.Dimensions);
        CollectionAssert.AreEqual(new float[] { 4, 5, 6, 7 }, y.Data.ToArray());
    }
}
// Verifies that MsgPackSampler can read samples spread across multiple files
// and concatenates them into minibatches (batch size 3 over 2 one-sample files,
// so the epoch wraps: expected values are sample0, sample1, sample0 again).
public void TestMsgPackSamplerMultipleFiles()
{
    var files = new string[2];
    try
    {
        for (var i = 0; i < 2; ++i)
        {
            files[i] = Path.GetTempFileName();
            var a = DataSourceFactory.Create(new float[] { i, i + 1, i + 2 }, new int[] { 3, 1 });
            var dss = new DataSourceSet();
            dss.Add("a", a);
            using (var stream = new FileStream(files[i], FileMode.Create, FileAccess.Write))
            {
                MsgPackSerializer.Serialize(dss, stream);
            }
        }

        using (var sampler = new MsgPackSampler(files, 3, false, -1, 10, false, 10))
        {
            Assert.AreEqual(2, sampler.SampleCountPerEpoch);
            var batch = sampler.GetNextMinibatch();
            CollectionAssert.AreEqual(new int[] { 3, 3 }, batch.Features["a"].Shape.Dimensions.ToArray());
            var values = DataSourceFactory.FromValue(batch.Features["a"]);
            CollectionAssert.AreEqual(new float[] { 0, 1, 2, 1, 2, 3, 0, 1, 2 }, values.TypedData);
        }
    }
    finally
    {
        // Path.GetTempFileName() creates the file on disk; remove it so repeated
        // test runs don't accumulate garbage in the temp directory.
        foreach (var f in files)
        {
            if (!string.IsNullOrEmpty(f) && File.Exists(f))
            {
                File.Delete(f);
            }
        }
    }
}
// Writes 10 single-sample records to one file, then reads them back through
// ReadDataSourceSet with a chunk size of 4 and checks the chunking: two full
// chunks of 4 samples, a final partial chunk of 2, then end of enumeration.
public void TestReadDataSourceSet()
{
    const int NUM_SAMPLES = 10;
    var file = Path.GetTempFileName();
    try
    {
        DataSourceSet dss;
        using (var stream = new FileStream(file, FileMode.Create, FileAccess.Write))
        {
            for (var i = 0; i < NUM_SAMPLES; ++i)
            {
                var a = DataSourceFactory.Create(new float[] { i, i * 10, i * 100 }, new int[] { 3, 1, 1 });
                dss = new DataSourceSet();
                dss.Add("a", a);
                MsgPackSerializer.Serialize(dss, stream);
            }
        }

        var total = MsgPackTools.GetTotalSampleCount(file);
        Assert.AreEqual(NUM_SAMPLES, total);

        var reader = MsgPackTools.ReadDataSourceSet(file, total, 4).GetEnumerator();

        // First chunk: samples 0..3 stacked along the last (sample) axis.
        Assert.AreEqual(true, reader.MoveNext());
        dss = reader.Current;
        Assert.AreEqual(4, dss.SampleCount);
        CollectionAssert.AreEqual(new int[] { 3, 1, 4 }, dss["a"].Shape.Dimensions);
        CollectionAssert.AreEqual(new float[] { 0, 0, 0, 1, 10, 100, 2, 20, 200, 3, 30, 300 }, dss["a"].Data.ToArray());

        // Second chunk: samples 4..7.
        Assert.AreEqual(true, reader.MoveNext());
        dss = reader.Current;
        Assert.AreEqual(4, dss.SampleCount);
        CollectionAssert.AreEqual(new int[] { 3, 1, 4 }, dss["a"].Shape.Dimensions);
        CollectionAssert.AreEqual(new float[] { 4, 40, 400, 5, 50, 500, 6, 60, 600, 7, 70, 700 }, dss["a"].Data.ToArray());

        // Final partial chunk: the remaining 2 samples (8 and 9).
        Assert.AreEqual(true, reader.MoveNext());
        dss = reader.Current;
        Assert.AreEqual(2, dss.SampleCount);
        CollectionAssert.AreEqual(new int[] { 3, 1, 2 }, dss["a"].Shape.Dimensions);
        CollectionAssert.AreEqual(new float[] { 8, 80, 800, 9, 90, 900 }, dss["a"].Data.ToArray());

        // Enumeration must be exhausted after all samples are consumed.
        Assert.AreEqual(false, reader.MoveNext());
    }
    finally
    {
        // Clean up the temp file created by Path.GetTempFileName().
        File.Delete(file);
    }
}
// Verifies minibatch slicing when the batch size (5) does not divide the
// per-record sample count (3): minibatches must straddle record boundaries.
// Each record stores its index three times, so the expected element at global
// position p is p / 3 (integer division).
public void TestMsgPackSamplerSlicing2()
{
    const int NUM_SAMPLES = 300;
    var file = Path.GetTempFileName();
    try
    {
        using (var stream = new FileStream(file, FileMode.Create, FileAccess.Write))
        {
            for (var i = 0; i < NUM_SAMPLES; ++i)
            {
                var a = DataSourceFactory.Create(new float[] { i, i, i }, new int[] { 1, 1, 3 });
                var dss = new DataSourceSet();
                dss.Add("a", a);
                MsgPackSerializer.Serialize(dss, stream);
            }
        }

        using (var sampler = new MsgPackSampler(file, 5, false, 7, 10, false, 100))
        {
            // 10 epochs; each epoch is NUM_SAMPLES * 3 values / 5 per batch.
            for (var i = 0; i < 10; ++i)
            {
                for (var j = 0; j < NUM_SAMPLES * 3 / 5; ++j)
                {
                    var batch = sampler.GetNextMinibatch();
                    Assert.AreEqual(1, batch.Features.Count);
                    Assert.IsTrue(batch.Features.ContainsKey("a"));
                    var value = batch["a"];
                    CollectionAssert.AreEqual(new int[] { 1, 1, 5 }, value.Shape.Dimensions.ToArray());
                    var ds = DataSourceFactory.FromValue(value).ToArray();
                    Debug.WriteLine(string.Join(", ", ds));
                    CollectionAssert.AreEqual(new float[] { j * 5 / 3, (j * 5 + 1) / 3, (j * 5 + 2) / 3, (j * 5 + 3) / 3, (j * 5 + 4) / 3 }, ds);
                }
            }
        }
    }
    finally
    {
        // Clean up the temp file created by Path.GetTempFileName().
        // (The original also accumulated an unused `count` local; removed.)
        File.Delete(file);
    }
}
// Verifies randomization: every chunk's worth of minibatches must contain
// exactly the original data values (checked by sorting) but NOT in the
// original order (checked by AreNotEqual on the raw sequence).
public void TestMsgPackSamplerRandomize()
{
    const int NUM_CHUNKS = 30;
    const int CHUNK_SIZE = 6;
    const int MINIBATCH_SIZE = 2;
    var data = new float[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11 };
    int FEATURE_DIM = data.Length / CHUNK_SIZE;
    var file = Path.GetTempFileName();
    try
    {
        using (var stream = new FileStream(file, FileMode.Create, FileAccess.Write))
        {
            for (var i = 0; i < NUM_CHUNKS; ++i)
            {
                var a = DataSourceFactory.Create(data, new int[] { FEATURE_DIM, CHUNK_SIZE });
                var dss = new DataSourceSet();
                dss.Add("a", a);
                MsgPackSerializer.Serialize(dss, stream);
            }
        }

        // Third argument `true` enables randomization.
        using (var sampler = new MsgPackSampler(file, MINIBATCH_SIZE, true, 10, 100, false, 100))
        {
            for (var i = 0; i < NUM_CHUNKS; ++i)
            {
                // Collect one chunk's worth of values from consecutive minibatches.
                var values = new float[data.Length];
                for (var j = 0; j < data.Length; j += FEATURE_DIM * MINIBATCH_SIZE)
                {
                    var batch = sampler.GetNextMinibatch();
                    var value = DataSourceFactory.FromValue(batch["a"]);
                    CollectionAssert.AreEqual(new int[] { FEATURE_DIM, MINIBATCH_SIZE }, value.Shape.Dimensions.ToArray());
                    for (var k = 0; k < FEATURE_DIM * MINIBATCH_SIZE; ++k)
                    {
                        values[j + k] = value[k];
                    }
                }

                // Shuffled order must differ from the written order...
                CollectionAssert.AreNotEqual(data, values);
                // ...but sorting must recover exactly the original multiset.
                var sorted = values.ToList();
                sorted.Sort();
                CollectionAssert.AreEqual(data, sorted);
            }
        }
    }
    finally
    {
        // Clean up the temp file created by Path.GetTempFileName().
        File.Delete(file);
    }
}
// Sequential (non-randomized) sampling: 1000 single-sample records, batch
// size 1, over 10 epochs. Checks per-batch shape, values, and that SweepEnd
// fires every 3rd minibatch (the 4th MsgPackSampler argument is 3).
public void TestMsgPackSampler()
{
    const int NUM_SAMPLES = 1000;
    var file = Path.GetTempFileName();
    try
    {
        using (var stream = new FileStream(file, FileMode.Create, FileAccess.Write))
        {
            for (var i = 0; i < NUM_SAMPLES; ++i)
            {
                var a = DataSourceFactory.Create(new float[] { i, i * 10, i * 100 }, new int[] { 3, 1, 1 });
                var dss = new DataSourceSet();
                dss.Add("a", a);
                MsgPackSerializer.Serialize(dss, stream);
            }
        }

        using (var sampler = new MsgPackSampler(file, 1, false, 3, 10, false, 100))
        {
            for (var i = 0; i < 10; ++i)
            {
                for (var j = 0; j < NUM_SAMPLES; ++j)
                {
                    var batch = sampler.GetNextMinibatch();
                    Assert.AreEqual(1, batch.Features.Count);
                    Assert.IsTrue(batch.Features.ContainsKey("a"));
                    // SweepEnd is raised on every 3rd minibatch overall.
                    Assert.AreEqual((i * NUM_SAMPLES + j + 1) % 3 == 0, batch.SweepEnd);
                    var value = batch["a"];
                    CollectionAssert.AreEqual(new int[] { 3, 1, 1 }, value.Shape.Dimensions.ToArray());
                    var ds = DataSourceFactory.FromValue(value).ToArray();
                    CollectionAssert.AreEqual(new float[] { j, j * 10, j * 100 }, ds);
                }
            }
        }
    }
    finally
    {
        // Clean up the temp file created by Path.GetTempFileName().
        File.Delete(file);
    }
}
// Verifies that SampleCountPerEpoch counts individual samples, not records:
// 100 records each holding 3 samples along the last axis must yield 300.
public void TestMsgPackSamplerSampleCountPerEpoch()
{
    const int NUM_SAMPLES = 3;
    const int NUM_CHUNKS = 100;
    var file = Path.GetTempFileName();
    try
    {
        using (var stream = new FileStream(file, FileMode.Create, FileAccess.Write))
        {
            for (var i = 0; i < NUM_CHUNKS; ++i)
            {
                var a = DataSourceFactory.Create(new float[] { i, i * 10, i * 100 }, new int[] { 1, 1, NUM_SAMPLES });
                var dss = new DataSourceSet();
                dss.Add("a", a);
                MsgPackSerializer.Serialize(dss, stream);
            }
        }

        using (var sampler = new MsgPackSampler(file, 1, false, -1, 10, false, 100))
        {
            Assert.AreEqual(NUM_SAMPLES * NUM_CHUNKS, sampler.SampleCountPerEpoch);
        }
    }
    finally
    {
        // Clean up the temp file created by Path.GetTempFileName().
        File.Delete(file);
    }
}
// Smoke test for the reuse-samples mode (6th MsgPackSampler argument = true):
// draws 10x more minibatches than there are samples and only checks the
// feature keeps being served without error.
public void TestMsgPackSamplerReuseSamples()
{
    var file = Path.GetTempFileName();
    try
    {
        using (var stream = new FileStream(file, FileMode.Create, FileAccess.Write))
        {
            for (var i = 0; i < 1000; ++i)
            {
                var a = DataSourceFactory.Create(new float[] { i, i * 10, i * 100 }, new int[] { 3, 1, 1 });
                var dss = new DataSourceSet();
                dss.Add("a", a);
                MsgPackSerializer.Serialize(dss, stream);
            }
        }

        using (var sampler = new MsgPackSampler(file, 1, false, 100, 100, true, 1000))
        {
            for (var i = 0; i < 10000; ++i)
            {
                var batch = sampler.GetNextMinibatch();
                Assert.IsTrue(batch.Features.ContainsKey("a"));
            }
        }
    }
    finally
    {
        // Clean up the temp file created by Path.GetTempFileName().
        File.Delete(file);
    }
}