Ejemplo n.º 1
0
        /// <summary>
        /// Builds a network named "test5" that contains a residual block and then calls
        /// <c>PersistGraph()</c> on the resulting builder.
        /// The <c>left</c> branch holds one conv / batch-norm / scale unit.
        /// The <c>right</c> branch holds two such units in sequence.
        /// The block is followed by an eltwise layer ("res5"), a pooling layer and a softmax layer.
        /// </summary>
        public void TestAddResidualBlock()
        {
            // Network input.
            var dataInput = VertexFactory.InputLayer("data5");

            // Left branch: a single conv -> batch-norm -> scale unit.
            var leftConv      = VertexFactory.ConvLayer("conv5", VertexFactory.ConvLayerParam(7, 64));
            var leftBatchNorm = VertexFactory.BatchNormLayer("bn5");
            var leftScale     = VertexFactory.ScaleLayer("scale5");

            // Right branch, first unit.
            var rightConvA      = VertexFactory.ConvLayer("conv5-2", VertexFactory.ConvLayerParam(7, 64));
            var rightBatchNormA = VertexFactory.BatchNormLayer("bn5-2");
            var rightScaleA     = VertexFactory.ScaleLayer("scale5-2");

            // Right branch, second unit (chained after the first).
            var rightConvB      = VertexFactory.ConvLayer("conv5-3", VertexFactory.ConvLayerParam(7, 64));
            var rightBatchNormB = VertexFactory.BatchNormLayer("bn5-3");
            var rightScaleB     = VertexFactory.ScaleLayer("scale5-3");

            // Layers appended after the residual block.
            var residualSum = VertexFactory.EltwiseLayer("res5");
            var pooling     = VertexFactory.PoolLayer("pool5", VertexFactory.PoolLayerParam(5));
            var classifier  = VertexFactory.SoftmaxLayer("softmax5", 5);

            var network = new NetworkBuilder("test5")
                .AddInputLayer(dataInput)
                .AddResidualBlock(
                    left: residual => residual
                        .AddLayerBlock(unit => unit
                            .AddLayer(leftConv)
                            .AddBatchNorm(leftBatchNorm)
                            .AddScale(leftScale)),
                    right: residual => residual
                        .AddLayerBlock(unit => unit
                            .AddLayer(rightConvA)
                            .AddBatchNorm(rightBatchNormA)
                            .AddScale(rightScaleA))
                        .AddLayerBlock(unit => unit
                            .AddLayer(rightConvB)
                            .AddBatchNorm(rightBatchNormB)
                            .AddScale(rightScaleB)))
                .AddEltwise(residualSum)
                .AddLayer(pooling)
                .AddLayer(classifier);

            network.PersistGraph();

            // NOTE(review): this test makes no assertions; it only fails if building or
            // PersistGraph() throws. The intended topology checks were commented out because
            // they referenced an undeclared `graph` variable. Restore them once the built
            // graph is reachable from here. Expected edges, by layer name:
            //   present: data5 -> conv5,      data5 -> conv5-2,
            //            conv5 -> bn5,        conv5 -> scale5,
            //            conv5-2 -> bn5-2,    conv5-2 -> scale5-2,
            //            conv5-2 -> conv5-3,
            //            conv5-3 -> bn5-3,    conv5-3 -> scale5-3,
            //            conv5-3 -> res5,     conv5 -> res5,
            //            res5 -> pool5
            //   absent:  data5 -> conv5-3,    conv5 -> conv5-2,
            //            data5 -> bn5,        data5 -> scale5,
            //            conv5-3 -> conv5
        }