Exemplo n.º 1
0
        /// <summary>
        /// Builds an LSTM cell as a <see cref="ComputeGraph"/>, compiles the graph and
        /// wraps it in an <see cref="LSTMNetwork"/>.
        /// </summary>
        /// <remarks>
        /// Wiring, per node name:
        /// <list type="bullet">
        /// <item>each gate pre-activation is <c>FC_W(inputs) + FC_U(previousH)</c>;</item>
        /// <item>f, i, o gates go through sigmoid, the candidate ("c1") through tanh;</item>
        /// <item><c>f_c_i_c1_plus = f * previousC + i * c1</c>, <c>ct_tanh = tanh(f_c_i_c1_plus)</c>;</item>
        /// <item><c>ht = ct_tanh * o</c>.</item>
        /// </list>
        /// <c>previousH</c> and <c>previousC</c> are declared as copies of the nodes named
        /// "ht" and "ct_tanh" respectively (presumably taken from the previous time step —
        /// TODO confirm against <c>ComputeGraph.DeclareCopyFrom</c>).
        /// </remarks>
        /// <param name="hiddenNeuronCount">Width of the hidden/cell vectors; must be positive.</param>
        /// <param name="inputDim">Width of the "inputs" node; must be positive.</param>
        /// <returns>A network wrapping the compiled graph.</returns>
        /// <exception cref="System.ArgumentOutOfRangeException">
        /// <paramref name="hiddenNeuronCount"/> or <paramref name="inputDim"/> is zero or negative.
        /// </exception>
        public static LSTMNetwork Create2(int hiddenNeuronCount, int inputDim)
        {
            // Reject non-positive layer widths up front rather than relying on
            // ComputeGraph to notice them during declaration or Compile().
            if (hiddenNeuronCount <= 0)
            {
                throw new System.ArgumentOutOfRangeException(nameof(hiddenNeuronCount), hiddenNeuronCount, "Hidden neuron count must be positive.");
            }

            if (inputDim <= 0)
            {
                throw new System.ArgumentOutOfRangeException(nameof(inputDim), inputDim, "Input dimension must be positive.");
            }

            var graph = new ComputeGraph();

            var inputs    = graph.DeclareInput("inputs", inputDim);
            var previousH = graph.DeclareCopyFrom("previousH", hiddenNeuronCount, "ht");
            // NOTE(review): the recurrent cell state is copied from "ct_tanh", i.e. tanh(c_t),
            // not from the un-squashed cell sum "f_c_i_c1_plus" that a standard LSTM carries
            // forward. Confirm whether this is intentional. Switching the source would likely
            // also need ResultsDim/IsMonitoring set on f_c_i_c1_plus — TODO confirm against ComputeGraph.
            var previousC = graph.DeclareCopyFrom("previousC", hiddenNeuronCount, "ct_tanh");

            // Input (W) and recurrent (U) projections for each gate. Whether DeclareFC
            // adds a bias term is not visible here — TODO confirm.
            var f_gate_W = graph.DeclareFC("f_gate_W", inputDim, hiddenNeuronCount, inputs);
            var f_gate_U = graph.DeclareFC("f_gate_U", hiddenNeuronCount, hiddenNeuronCount, previousH);

            var i_gate_W = graph.DeclareFC("i_gate_W", inputDim, hiddenNeuronCount, inputs);
            var i_gate_U = graph.DeclareFC("i_gate_U", hiddenNeuronCount, hiddenNeuronCount, previousH);

            var o_gate_W = graph.DeclareFC("o_gate_W", inputDim, hiddenNeuronCount, inputs);
            var o_gate_U = graph.DeclareFC("o_gate_U", hiddenNeuronCount, hiddenNeuronCount, previousH);

            var c_gate_W = graph.DeclareFC("c_gate_W", inputDim, hiddenNeuronCount, inputs);
            var c_gate_U = graph.DeclareFC("c_gate_U", hiddenNeuronCount, hiddenNeuronCount, previousH);

            // Gate pre-activations: W*x + U*h_prev.
            var f_value  = graph.DeclarePlus("f_value", f_gate_W, f_gate_U);
            var i_value  = graph.DeclarePlus("i_value", i_gate_W, i_gate_U);
            var o_value  = graph.DeclarePlus("o_value", o_gate_W, o_gate_U);
            var c1_value = graph.DeclarePlus("c1_value", c_gate_W, c_gate_U);

            // Forget/input/output gates squash with sigmoid; the candidate state with tanh.
            var f_sigmoid_value = graph.DeclareSigmoid("f_sigmoid_value", f_value);
            var i_sigmoid_value = graph.DeclareSigmoid("i_sigmoid_value", i_value);
            var o_sigmoid_value = graph.DeclareSigmoid("o_sigmoid_value", o_value);
            var c1_tanh_value   = graph.DeclareTanh("c1_tanh_value", c1_value);

            // New cell state: f * c_prev + i * candidate.
            var f_c_value  = graph.DeclareMultiply("f_c_value", f_sigmoid_value, previousC);
            var i_c1_value = graph.DeclareMultiply("i_c1_value", i_sigmoid_value, c1_tanh_value);

            var f_c_i_c1_plus = graph.DeclarePlus("f_c_i_c1_plus", f_c_value, i_c1_value);
            var ct_tanh       = graph.DeclareTanh("ct_tanh", f_c_i_c1_plus);

            // Hidden output: tanh(c_t) gated by the output gate.
            var ht = graph.DeclareMultiply("ht", ct_tanh, o_sigmoid_value);

            // These two nodes are the sources of the DeclareCopyFrom recurrences above.
            ct_tanh.ResultsDim = hiddenNeuronCount;
            ht.ResultsDim      = hiddenNeuronCount;

            ct_tanh.IsMonitoring = true;
            ht.IsMonitoring      = true;

            graph.Compile();

            return new LSTMNetwork(graph);
        }
Exemplo n.º 2
0
 /// <summary>
 /// Wraps an already-built compute graph.
 /// </summary>
 /// <param name="graph1">The graph backing this network; must not be null.</param>
 /// <exception cref="System.ArgumentNullException"><paramref name="graph1"/> is null.</exception>
 public LSTMNetwork(ComputeGraph graph1)
 {
     // Fail at construction time instead of at the first later use of the field.
     if (graph1 == null)
     {
         throw new System.ArgumentNullException(nameof(graph1));
     }

     this.graph1 = graph1;
 }