Code example #1
 public static Tensor<float> Pow(Tensor<float> x, float a)
 {
     // Fast path: x^1 is x itself.
     if (a == 1)
     {
         return x;
     }
     // Fast path: x^2 dispatches to the dedicated Square operator.
     if (a == 2)
     {
         return Op.Square(x);
     }
     // General case: broadcast the scalar exponent to x's shape and apply
     // the element-wise Pow operator.
     return Apply(x, ConstLike(a, x), (_x, _y) => new Operators.Scalars.Pow(_x, _y));
 }
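
A quick, hedged usage sketch of the three dispatch paths (it assumes a Tensor<float> named x already built elsewhere with this library; only Pow itself comes from the example above):

 var y1 = Pow(x, 1f);   // exponent 1: returns x unchanged
 var y2 = Pow(x, 2f);   // exponent 2: routed to Op.Square(x)
 var y3 = Pow(x, 0.5f); // any other exponent: element-wise Pow operator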
Code example #2
        public void RnnXorHasCorrectGradient()
        {
            NN.Random.Seed(12345);
            int nh = 10; // hidden layer size

            var Wbit   = T.Shared(0.2f * NN.Random.Uniform(-1.0f, 1.0f, nh, 1).As<float>(), "Wbit");
            var Wstate = T.Shared(NN.Eye<float>(nh), "Wstate");
            var Wout   = T.Shared(0.2f * NN.Random.Uniform(-1.0f, 1.0f, 1, nh).As<float>(), "Wout");
            var b      = T.Shared(0.2f * NN.Random.Uniform(-1.0f, 1.0f, nh, 1).As<float>(), "b");

            var state0 = T.Shared(NN.Zeros<float>(nh, 1), "state0");

            var bits     = T.Tensor3<float>("bits");          // n x 1
            var expected = T.Matrix<float>("expected");       // 1 x 1

            // One RNN step: state_t = tanh(Wbit * bit_t + Wstate * state_{t-1} + b)
            Func<Tensor<float>, Tensor<float>, Tensor<float>> recurrence = (bit, oldState) =>
            {
                return T.Tanh(T.Dot(Wbit, bit) + T.Dot(Wstate, oldState) + b);
            };

            var states = T.Scan(fn: recurrence, sequence: bits, outputsInfo: state0);

            var output    = T.Tanh(T.Dot(Wout, states[(Slice)(-1)]));
            var error     = 0.5f * T.Norm2(output - expected);
            var classify  = T.Function(bits, output);
            var gradients = T.Grad(error);
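            // T.Grad differentiates error symbolically; indexing the result
            // with a shared variable (as below) retrieves that variable's
            // gradient expression.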

            var gradWstate = gradients[Wstate];

            Assert.IsNotNull(gradWstate);
            var gradWstateIsReshape = gradWstate as Reshaping<float>;

            Assert.IsNotNull(gradWstateIsReshape);
            var gradWstateIsSum = gradWstateIsReshape.x as Sum<float>;

            Assert.IsNotNull(gradWstateIsSum);
            var dfor     = gradWstateIsSum.x as Tensor<float>.For;
            var backLoop = dfor.Loop;

            Assert.AreEqual(3, backLoop.Sequences.Count); // bit, states, delta
            Assert.AreEqual(6, backLoop.Fors.Count);      // dbit, dstate, dWstate, db, dWbit, dstate_p1
            Assert.AreEqual(3, dfor.Index);
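            // The backward loop built by T.Grad iterates the forward sequences
            // (bit, states) together with the incoming delta, and accumulates
            // one For output per gradient listed in the comments above.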

            // TODO: check why a recursive was expected
            //var dWstate_ = dfor.RecursiveVariable;
            //Assert.AreEqual("dWstate_", dWstate_.Name);

            var variables = backLoop.Variables.Cast<Tensor<float>>().ToList();
            var bit_      = variables[0];

            Assert.AreEqual("bit_", bit_.Name);
            var oldState_ = variables[1];

            Assert.AreEqual("oldState_", oldState_.Name);
            var delta_oldState_ = variables[2];

            Assert.AreEqual("delta_oldState_", delta_oldState_.Name);
            var dbit_ = variables[3];

            Assert.AreEqual("dbit_", dbit_.Name);
            var doldState_ = variables[4];

            Assert.AreEqual("doldState_", doldState_.Name);
            var oldState_tp1_ = variables[5];

            Assert.AreEqual("oldState_tp1", oldState_tp1_.Name);

            // Backprop through tanh: d = (incoming delta + recurrent gradient) * (1 - state^2)
            var d = T.Sum((delta_oldState_ + doldState_) * (1f - T.Square(oldState_tp1_)), axis: 1, keepDims: true);

            var doldState = (Tensor<float>)backLoop.Fors[1].Expression;

            (T.Dot(Wstate, d, transposeX: true)).AssertEqual(doldState);

            var dWstate    = (Tensor<float>)backLoop.Fors[3].Expression;
            var dWstateExp = T.Dot(d, oldState_, transposeY: true);

            dWstateExp.AssertEqual(dWstate);

            var dbit = (Tensor<float>)backLoop.Fors[0].Expression;

            (T.Dot(Wbit, d, transposeX: true)).AssertEqual(dbit);

            var oldState_tp1 = (Tensor<float>)backLoop.Fors[5].Expression;

            oldState_tp1.AssertEqual(oldState_);
        }
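
For reference, the expressions asserted above are the standard backpropagation-through-time rules for the step s_t = tanh(Wbit x_t + Wstate s_{t-1} + b). A minimal derivation sketch (notation mine; delta_t stands for delta_oldState_ + doldState_, the total gradient reaching s_t; the extra Sum over axis 1 in the test appears to only preserve the nh x 1 column shape):

\[
\begin{aligned}
d_t &= \delta_t \odot (1 - s_t^{2}) && \text{(the tensor d above)} \\
\partial L/\partial s_{t-1} &= W_{state}^{\top} d_t && \text{(doldState)} \\
\partial L/\partial W_{state} &= \textstyle\sum_t d_t\, s_{t-1}^{\top} && \text{(dWstate, accumulated over steps)} \\
\partial L/\partial x_t &= W_{bit}^{\top} d_t && \text{(dbit)}
\end{aligned}
\]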