コード例 #1
0
        /// <summary>
        /// Checks <c>Tape.getFilteredNodesXToY</c> on a two-operation tape:
        /// node0 takes (x0, x1) and outputs <c>intermediate</c>; node1 takes
        /// (x2, intermediate) and outputs <c>y</c>. Filtering from
        /// {x0, x1, x2} to y is expected to return both nodes unchanged.
        /// </summary>
        public void getFilteredNodesXToY_two_operations_x0_x1_x2_intermediate_y()
        {
            // Assign a fresh engine before any tensors are created.
            // NOTE(review): presumably this isolates the test from shared global state — confirm.
            ENV.engine = new Engine();

            var x0           = alb.scalar(1);
            var x1           = alb.scalar(2);
            var x2           = alb.scalar(3);
            var intermediate = alb.scalar(4);
            var y            = alb.scalar(2);

            // Inputs of the first operation: x0 and x1.
            var node0Inputs = new System.Collections.Generic.Dictionary<string, Tensor>
            {
                { "x0", x0 },
                { "x1", x1 }
            };

            // Inputs of the second operation: x2 and the output of the first operation.
            var node1Inputs = new System.Collections.Generic.Dictionary<string, Tensor>
            {
                { "x2", x2 },
                { "intermediate", intermediate }
            };

            TapeNode[] tape = new TapeNode[]
            {
                new TapeNode()
                {
                    gradient = null,
                    id       = 0,
                    inputs   = node0Inputs,
                    name     = "node0",
                    output   = intermediate
                },
                new TapeNode()
                {
                    gradient = null,
                    id       = 1,
                    inputs   = node1Inputs,
                    name     = "node1",
                    output   = y
                }
            };

            var filteredTapeNodes =
                Tape.getFilteredNodesXToY(tape, new Tensor[] { x0, x1, x2 }, y);

            // Expected value goes first so a failure message reports the values correctly.
            Assert.AreEqual(2, filteredTapeNodes.Length);
            AssertTools.TapeIsEqual(filteredTapeNodes, tape);
        }