示例#1
0
 /// <summary>
 /// Wraps <paramref name="base_cell"/> and records the zoneout settings for outputs and states.
 /// </summary>
 /// <param name="base_cell">The recurrent cell being wrapped; forwarded to the base constructor.</param>
 /// <param name="zoneout_outputs">Value stored in <c>ZoneoutOutputs</c> (presumably a probability; confirm).</param>
 /// <param name="zoneout_states">Value stored in <c>ZoneoutStates</c> (presumably a probability; confirm).</param>
 public ZoneoutCell(RecurrentCell base_cell, float zoneout_outputs = 0, float zoneout_states = 0)
     : base(base_cell)
 {
     this.ZoneoutOutputs = zoneout_outputs;
     this.ZoneoutStates = zoneout_states;

     // No previous output exists until the cell has been run.
     this._prev_output = null;
 }
示例#2
0
 /// <summary>
 /// Combines a left cell and a right cell into one bidirectional cell.
 /// </summary>
 /// <param name="l_cell">Cell registered as the child named "l_cell".</param>
 /// <param name="r_cell">Cell registered as the child named "r_cell".</param>
 /// <param name="output_prefix">Stored in <c>_output_prefix</c>; presumably used to name outputs (confirm).</param>
 public BidirectionalCell(RecurrentCell l_cell, RecurrentCell r_cell, string output_prefix = "bi_")
     : base("", null)
 {
     // Register the left child first, then the right child.
     this.RegisterChild(l_cell, "l_cell");
     this.RegisterChild(r_cell, "r_cell");

     this._output_prefix = output_prefix;
 }
示例#3
0
        /// <summary>
        /// Resolves the initial recurrent state for <paramref name="cell"/>.
        /// </summary>
        /// <param name="cell">Cell whose <c>BeginState</c> builds the default state.</param>
        /// <param name="begin_state">Caller-supplied state; returned unchanged when not null.</param>
        /// <param name="inputs">
        /// Step inputs. Only <c>inputs[0]</c> is inspected. It selects the NDArray or Symbol path,
        /// and on the NDArray path it supplies the context.
        /// </param>
        /// <param name="batch_size">Batch size forwarded to <c>BeginState</c>.</param>
        /// <returns>
        /// <paramref name="begin_state"/> if it was supplied. Otherwise, a new state built by
        /// <c>cell.BeginState</c> with the "nd.Zeros" or "sym.Zeros" initializer.
        /// </returns>
        /// <exception cref="System.ArgumentException">
        /// No begin state was supplied and <paramref name="inputs"/> is null or empty.
        /// </exception>
        internal static NDArrayOrSymbol[] GetBeginState(RecurrentCell cell, NDArrayOrSymbol[] begin_state,
                                                        NDArrayOrSymbol[] inputs, int batch_size)
        {
            // Respect a caller-provided state. The previous check was inverted (`!= null`).
            // It replaced user-supplied states with zeros and returned null when none was given.
            if (begin_state != null)
            {
                return begin_state;
            }

            if (inputs == null || inputs.Length == 0)
            {
                throw new System.ArgumentException(
                    "inputs must contain at least one element to build a default begin state.",
                    nameof(inputs));
            }

            if (inputs[0].IsNDArray)
            {
                // Pass the first input's context so the default state is presumably created
                // on the same device as the inputs (depends on BeginState honoring "ctx").
                var args = new FuncArgs();
                args.Add("ctx", inputs[0].NdX.Context);
                return cell.BeginState(batch_size, "nd.Zeros", args);
            }

            return cell.BeginState(batch_size, "sym.Zeros");
        }
 /// <summary>Registers <paramref name="cell"/> as a child of this cell.</summary>
 /// <param name="cell">The recurrent cell to append.</param>
 public void Add(RecurrentCell cell) => RegisterChild(cell);
示例#5
0
 /// <summary>
 /// Wraps <paramref name="base_cell"/> in a residual cell.
 /// </summary>
 /// <param name="base_cell">The recurrent cell being wrapped; forwarded to the base constructor.</param>
 public ResidualCell(RecurrentCell base_cell)
     : base(base_cell)
 {
     // Nothing to initialize beyond what the base constructor does.
 }