/// <summary>
/// Wraps <paramref name="base_cell"/> and records the zoneout settings for outputs and states.
/// </summary>
/// <param name="base_cell">The recurrent cell being wrapped; forwarded to the base constructor.</param>
/// <param name="zoneout_outputs">Stored in <c>ZoneoutOutputs</c> (presumably a zoneout probability for outputs — confirm).</param>
/// <param name="zoneout_states">Stored in <c>ZoneoutStates</c> (presumably a zoneout probability for states — confirm).</param>
public ZoneoutCell(RecurrentCell base_cell, float zoneout_outputs = 0, float zoneout_states = 0)
    : base(base_cell)
{
    this.ZoneoutOutputs = zoneout_outputs;
    this.ZoneoutStates = zoneout_states;

    // No output has been produced yet.
    this._prev_output = null;
}
/// <summary>
/// Combines two recurrent cells into a bidirectional cell.
/// </summary>
/// <param name="l_cell">Child cell registered under the name "l_cell".</param>
/// <param name="r_cell">Child cell registered under the name "r_cell".</param>
/// <param name="output_prefix">Prefix stored for this cell's outputs; defaults to "bi_".</param>
public BidirectionalCell(RecurrentCell l_cell, RecurrentCell r_cell, string output_prefix = "bi_")
    : base("", null)
{
    // Registration order is preserved: left child first, then right child.
    this.RegisterChild(l_cell, "l_cell");
    this.RegisterChild(r_cell, "r_cell");

    this._output_prefix = output_prefix;
}
/// <summary>
/// Resolves the initial states for <paramref name="cell"/>.
/// Caller-supplied states are returned unchanged. When none are supplied, new states are
/// requested from <see cref="RecurrentCell.BeginState"/> using the "nd.Zeros" initializer
/// for NDArray inputs or "sym.Zeros" for symbolic inputs.
/// </summary>
/// <param name="cell">The cell whose begin states are needed.</param>
/// <param name="begin_state">Explicit initial states, or <c>null</c> to create them.</param>
/// <param name="inputs">The cell inputs; the first element decides NDArray vs. symbol mode and,
/// for NDArrays, the context the states are created on.</param>
/// <param name="batch_size">Batch size forwarded to <c>BeginState</c>.</param>
/// <returns>The states to start unrolling from.</returns>
/// <exception cref="System.ArgumentException">
/// <paramref name="begin_state"/> is <c>null</c> and <paramref name="inputs"/> is null or empty.
/// </exception>
internal static NDArrayOrSymbol[] GetBeginState(RecurrentCell cell, NDArrayOrSymbol[] begin_state, NDArrayOrSymbol[] inputs, int batch_size)
{
    // Explicit states win. The previous `!= null` test was inverted: it overwrote
    // caller-supplied states and returned null when states were actually needed.
    if (begin_state != null)
    {
        return begin_state;
    }

    if (inputs == null || inputs.Length == 0)
    {
        throw new System.ArgumentException("inputs must contain at least one element to infer begin states.", nameof(inputs));
    }

    if (inputs[0].IsNDArray)
    {
        // Create the NDArray states on the same context as the first input.
        var ctx = inputs[0].NdX.Context;
        var args = new FuncArgs();
        args.Add("ctx", ctx);
        return cell.BeginState(batch_size, "nd.Zeros", args);
    }

    return cell.BeginState(batch_size, "sym.Zeros");
}
/// <summary>
/// Appends <paramref name="cell"/> as a child of this cell via <c>RegisterChild</c>.
/// </summary>
/// <param name="cell">The recurrent cell to register.</param>
public void Add(RecurrentCell cell) => this.RegisterChild(cell);
/// <summary>
/// Creates a residual wrapper around <paramref name="base_cell"/>.
/// All initialization is delegated to the base constructor.
/// </summary>
/// <param name="base_cell">The recurrent cell being wrapped.</param>
public ResidualCell(RecurrentCell base_cell)
    : base(base_cell)
{
}