示例#1
0
        /// <summary>
        /// Round-trips a hand-assembled network (input -> conv -> fully connected -> softmax)
        /// through <c>SaveBinary</c> / <c>SerializationExtensions.LoadBinary</c> and verifies the
        /// deserialized copy matches the original layer by layer, including filter values and gradients.
        /// </summary>
        public void BinaryNetSerializerTest()
        {
            var net = new Net();

            net.AddLayer(new InputLayer(5, 5, 3));
            var conv = new ConvLayer(2, 2, 16);

            net.AddLayer(conv);
            var fullycon = new FullyConnLayer(3);

            net.AddLayer(fullycon);
            net.AddLayer(new SoftmaxLayer(3));

            // Serialize (binary)
            using (var ms = new MemoryStream())
            {
                net.SaveBinary(ms);
                ms.Position = 0;

                // Deserialize (binary)
                Net deserialized = SerializationExtensions.LoadBinary(ms) as Net;

                // Make sure deserialized is identical to serialized.
                // Every layer inspected below must come from `deserialized`, never from `net`;
                // otherwise the test compares the original network with itself and cannot fail.
                Assert.IsNotNull(deserialized);
                Assert.IsNotNull(deserialized.Layers);
                Assert.AreEqual(net.Layers.Count, deserialized.Layers.Count);
                Assert.IsTrue(deserialized.Layers[0] is InputLayer);

                var deserializedConv = deserialized.Layers[1] as ConvLayer;
                Assert.NotNull(deserializedConv);
                // A real round trip must produce a new instance, not hand back the original.
                Assert.AreNotSame(conv, deserializedConv);
                Assert.NotNull(deserializedConv.Filters);
                Assert.AreEqual(16, deserializedConv.Filters.Count);
                for (int i = 0; i < deserializedConv.Filters.Count; i++)
                {
                    // Check lengths first so the element-wise loop cannot index past the original filter.
                    Assert.AreEqual(conv.Filters[i].Length, deserializedConv.Filters[i].Length);
                    for (int k = 0; k < deserializedConv.Filters[i].Length; k++)
                    {
                        Assert.AreEqual(conv.Filters[i].Get(k), deserializedConv.Filters[i].Get(k));
                        Assert.AreEqual(conv.Filters[i].GetGradient(k), deserializedConv.Filters[i].GetGradient(k));
                    }
                }

                var deserializedFullyCon = deserialized.Layers[2] as FullyConnLayer;
                Assert.NotNull(deserializedFullyCon);
                Assert.AreNotSame(fullycon, deserializedFullyCon);
                Assert.NotNull(deserializedFullyCon.Filters);
                Assert.AreEqual(3, deserializedFullyCon.Filters.Count);
                for (int i = 0; i < deserializedFullyCon.Filters.Count; i++)
                {
                    Assert.AreEqual(fullycon.Filters[i].Length, deserializedFullyCon.Filters[i].Length);
                    for (int k = 0; k < deserializedFullyCon.Filters[i].Length; k++)
                    {
                        Assert.AreEqual(fullycon.Filters[i].Get(k), deserializedFullyCon.Filters[i].Get(k));
                        Assert.AreEqual(fullycon.Filters[i].GetGradient(k), deserializedFullyCon.Filters[i].GetGradient(k));
                    }
                }

                var deserializedSoftmax = deserialized.Layers[3] as SoftmaxLayer;
                Assert.NotNull(deserializedSoftmax);
                Assert.AreEqual(3, deserializedSoftmax.ClassCount);
            }
        }
        /// <summary>
        /// Reads the file <c>conv-neural-network-bin</c> from <c>dataFolder</c>, deserializes it with
        /// <c>SerializationExtensions.LoadBinary</c>, and stores the result in <c>convNet</c>.
        /// </summary>
        /// <exception cref="System.InvalidCastException">
        /// Thrown when the deserialized object is not a <c>Net</c> (direct cast).
        /// </exception>
        private void LoadNeuralNetwork()
        {
            // The stream is disposed as soon as deserialization finishes; previously the
            // handle returned by File.OpenRead was never closed and leaked.
            using (FileStream convNetBin = File.OpenRead(String.Format(@"{0}\conv-neural-network-bin", dataFolder)))
            {
                convNet = (Net)SerializationExtensions.LoadBinary(convNetBin);
            }
        }
示例#3
0
    /// <summary>
    /// Deserializes a network from <paramref name="filename"/> using
    /// <c>SerializationExtensions.LoadBinary</c>.
    /// </summary>
    /// <param name="filename">Path of the binary network file.</param>
    /// <returns>
    /// The deserialized network, or <c>null</c> when <paramref name="filename"/> does not exist.
    /// </returns>
    public static INet LoadNet(string filename)
    {
        if (!File.Exists(filename))
        {
            return null;
        }

        // Open read-only. FileStream(path, FileMode.Open) requests ReadWrite access, which
        // fails for read-only files or when the caller lacks write permission.
        using (var fs = File.OpenRead(filename))
        {
            return SerializationExtensions.LoadBinary(fs);
        }
    }
示例#4
0
        /// <summary>
        /// Builds a small network via the fluent API, writes it with <c>SaveBinary</c>, reads it back
        /// with <c>SerializationExtensions.LoadBinary</c>, and checks that the result is a
        /// <c>FluentNet</c> with the same number of input layers.
        /// </summary>
        public void FluentBinaryNetSerializerTest()
        {
            // Arrange: conv -> fully connected -> softmax pipeline.
            var original = FluentNet.Create(5, 5, 3)
                .Conv(2, 2, 16)
                .FullyConn(3)
                .Softmax(3)
                .Build();

            using (var buffer = new MemoryStream())
            {
                // Act: write the network out, rewind, then read it back.
                original.SaveBinary(buffer);
                buffer.Position = 0;
                var roundTripped = SerializationExtensions.LoadBinary(buffer) as FluentNet;

                // Assert: the cast succeeded and the input-layer count survived.
                Assert.IsNotNull(roundTripped);
                Assert.AreEqual(original.InputLayers.Count, roundTripped.InputLayers.Count);

                // TODO: improve test — also compare layer-level state (filters, weights), not just input-layer count.
            }
        }