예제 #1
0
        /// <summary>
        /// Builds an LSTM three ways — <c>RNN.RecurrenceLSTM</c>, the <c>LSTMReccurentNN</c> cell
        /// wrapped in <c>SequenceLast</c>, and <c>LSTMSequenceClassifier.LSTMNet</c> — and checks that
        /// their learnable parameters agree in count and naming category, and that the two reference
        /// implementations expose matching arguments, inputs and outputs.
        /// </summary>
        public void lstm_test01()
        {
            //define values, and variables
            Variable x       = Variable.InputVariable(new int[] { 4 }, DataType.Float, "input");
            // NOTE(review): xValues is not used below; kept so the test still builds a batch from mData.
            var      xValues = Value.CreateBatchOfSequences <float>(new int[] { 4 }, mData, device);

            // implementation under test
            var lstm00 = RNN.RecurrenceLSTM(x, 3, 3, DataType.Float, device, false, Activation.TanH, true, true, 1);

            // reference implementation 01
            LSTMReccurentNN lstmNN = new LSTMReccurentNN(1, 1, device);
            var lstmCell = lstmNN.CreateLSTM(x, "output1");
            var lstm01   = CNTKLib.SequenceLast(lstmCell.h);

            // reference implementation 02
            var lstm02 = LSTMSequenceClassifier.LSTMNet(x, 1, device, "output1");

            // Learnable parameters of each model. Each list must be taken from its own model,
            // otherwise the comparisons below compare a model with itself and can never fail.
            var wParams00 = lstm00.Inputs.Where(p => p.Uid.Contains("Parameter")).ToList();
            var wParams01 = lstm01.Inputs.Where(p => p.Uid.Contains("Parameter")).ToList();
            var wParams02 = lstm02.Inputs.Where(p => p.Uid.Contains("Parameter")).ToList();

            //parameter count
            Assert.Equal(wParams00.Count, wParams01.Count);
            Assert.Equal(wParams00.Count, wParams02.Count);

            // structure of parameters: same number of parameters per name category
            // (bias, input weights, recurrent weights, peephole, stabilizer)
            foreach (var tag in new[] { "_b", "_w", "_u", "peep", "stabilize" })
            {
                Assert.Equal(wParams00.Count(p => p.Name.Contains(tag)), wParams01.Count(p => p.Name.Contains(tag)));
            }

            // Compare the graph variables of the two reference implementations position by position.
            // arguments
            Assert.Equal(lstm01.Arguments.Count, lstm02.Arguments.Count);
            for (int i = 0; i < lstm01.Arguments.Count; i++)
            {
                testVariable(lstm01.Arguments[i], lstm02.Arguments[i]);
            }

            // inputs
            Assert.Equal(lstm01.Inputs.Count, lstm02.Inputs.Count);
            for (int i = 0; i < lstm01.Inputs.Count; i++)
            {
                testVariable(lstm01.Inputs[i], lstm02.Inputs[i]);
            }

            // outputs
            Assert.Equal(lstm01.Outputs.Count, lstm02.Outputs.Count);
            for (int i = 0; i < lstm01.Outputs.Count; i++)
            {
                testVariable(lstm01.Outputs[i], lstm02.Outputs[i]);
            }
        }
예제 #2
0
        /// <summary>
        /// Checks the total size of each learnable-parameter category of <c>RNN.RecurrenceLSTM</c>
        /// built on a 2-dimensional input with output/cell dimension 3. The two <c>true</c> flags are
        /// presumably peephole and self-stabilization (per the test name) — confirm against RNN.RecurrenceLSTM.
        /// </summary>
        public void LSTM_Test_Params_Count_with_peep_selfstabilize()
        {
            //define input variable
            Variable x = Variable.InputVariable(new int[] { 2 }, DataType.Float, "input");

            //Number of LSTM parameters
            var lstm1 = RNN.RecurrenceLSTM(x, 3, 3, DataType.Float, device, false, Activation.TanH, true, true, 1);

            // learnable parameters of the whole LSTM graph
            var ft = lstm1.Inputs.Where(l => l.Uid.StartsWith("Parameter")).ToList();

            //bias params
            var totalBs = ft.Where(p => p.Name.Contains("_b")).Sum(v => v.Shape.TotalSize);
            Assert.Equal(12, totalBs);
            //input weights
            var totalWs = ft.Where(p => p.Name.Contains("_w")).Sum(v => v.Shape.TotalSize);
            Assert.Equal(24, totalWs);
            //recurrent (update) weights
            var totalUs = ft.Where(p => p.Name.Contains("_u")).Sum(v => v.Shape.TotalSize);
            Assert.Equal(36, totalUs);
            //peephole
            var totalPh = ft.Where(p => p.Name.Contains("_peep")).Sum(v => v.Shape.TotalSize);
            Assert.Equal(9, totalPh);
            //stabilize
            var totalSt = ft.Where(p => p.Name.Contains("_stabilize")).Sum(v => v.Shape.TotalSize);
            Assert.Equal(6, totalSt);

            // Every parameter must fall into one of the categories above, i.e. the category sizes
            // add up to the full parameter size of the graph.
            var totalP      = totalBs + totalWs + totalUs + totalSt + totalPh;
            var totalParams = ft.Sum(v => v.Shape.TotalSize);
            Assert.Equal(totalP, totalParams);

            // Expected totals for this configuration (derived from the category sizes above):
            //   72 - without peephole and stabilization (12 + 24 + 36)
            //   78 - with stabilization only            (72 + 6)
            //   81 - with peephole only                 (72 + 9)
            //   87 - with peephole and stabilization    (72 + 9 + 6)
            Assert.Equal(87, totalParams);
        }