/// <summary>
/// Creates a fused RNN cell whose weights are held in a single variable named "parameters".
/// </summary>
/// <param name="num_hidden">Number of hidden units; stored and forwarded to the <c>FusedRNN</c> initializer.</param>
/// <param name="num_layers">Number of stacked layers; stored and forwarded to the initializer.</param>
/// <param name="mode">Recurrent cell type; stored, forwarded to the initializer, and used for the default prefix.</param>
/// <param name="bidirectional">When true the cell tracks directions "l" and "r"; otherwise only "l".</param>
/// <param name="dropout">Dropout value; only stored by this constructor.</param>
/// <param name="get_next_state">Only stored by this constructor.</param>
/// <param name="forget_bias">Forwarded to the <c>FusedRNN</c> initializer.</param>
/// <param name="prefix">Name prefix; when null, the enum name of <paramref name="mode"/> followed by "_" (e.g. "Lstm_").</param>
/// <param name="params">Parameter container handed to the base class.</param>
/// <remarks>
/// NOTE(review): the default prefix uses the enum's ToString casing ("Lstm_"), which may differ
/// from lower-case MXNet prefixes such as "lstm_". Confirm this is intended before relying on
/// cross-language checkpoint names.
/// NOTE(review): the <c>FusedRNN</c> constructor in this file currently throws
/// NotImplementedException, so as written this constructor cannot complete.
/// </remarks>
public FusedRNNCell(int num_hidden, int num_layers = 1, RNNMode mode = RNNMode.Lstm, bool bidirectional = false, float dropout = 0, bool get_next_state = false, float forget_bias = 1, string prefix = null, RNNParams @params = null)
    : base(prefix ?? mode + "_", @params)
{
    this._num_hidden = num_hidden;
    this._num_layers = num_layers;
    this._mode = mode;
    this._bidirectional = bidirectional;
    this._dropout = dropout;
    this._get_next_state = get_next_state;

    // Forward direction is always present; a backward one is added for bidirectional cells.
    var directionNames = new List<string> { "l" };
    if (bidirectional)
    {
        directionNames.Add("r");
    }
    this._directions = directionNames;

    // All weights live in one "parameters" variable, initialized by the FusedRNN initializer.
    var weightInitializer = new FusedRNN(null, num_hidden, num_layers, mode, bidirectional, forget_bias);
    this._parameter = this.Params.Get("parameters", init: weightInitializer);
}
/// <summary>
/// Builds a symbol for the "RNN" operator from its inputs and hyper-parameters.
/// </summary>
/// <param name="data">Bound to the operator input "data".</param>
/// <param name="parameters">Bound to the operator input "parameters".</param>
/// <param name="state">Bound to the operator input "state".</param>
/// <param name="stateCell">Bound to the operator input "state_cell".</param>
/// <param name="stateSize">Sent as the "state_size" parameter.</param>
/// <param name="numLayers">Sent as the "num_layers" parameter.</param>
/// <param name="mode">Translated to its textual name through <c>RNNModeValues</c> and sent as "mode".</param>
/// <param name="bidirectional">Sent as the "bidirectional" parameter.</param>
/// <param name="p">Sent as the "p" parameter.</param>
/// <param name="stateOutputs">Sent as the "state_outputs" parameter.</param>
/// <returns>The symbol produced by <c>CreateSymbol()</c> on the configured operator.</returns>
public static Symbol RNN(Symbol data, Symbol parameters, Symbol state, Symbol stateCell, uint32_t stateSize, uint32_t numLayers, RNNMode mode, bool bidirectional = false, mx_float p = 0, bool stateOutputs = false)
{
    // The lookup is indexed by the enum's integer value, so RNNModeValues must list
    // names in the same order as the RNNMode members.
    var modeName = RNNModeValues[(int)mode];

    var rnnOperator = new Operator("RNN")
        .SetParam("state_size", stateSize)
        .SetParam("num_layers", numLayers)
        .SetParam("mode", modeName)
        .SetParam("bidirectional", bidirectional)
        .SetParam("p", p)
        .SetParam("state_outputs", stateOutputs)
        .SetInput("data", data)
        .SetInput("parameters", parameters)
        .SetInput("state", state)
        .SetInput("state_cell", stateCell);

    return rnnOperator.CreateSymbol();
}
/// <summary>
/// Initializer intended for the single fused "parameters" variable of a fused RNN cell.
/// </summary>
/// <param name="init">Inner initializer; may be null (FusedRNNCell passes null).</param>
/// <param name="num_hidden">Number of hidden units.</param>
/// <param name="num_layers">Number of stacked layers.</param>
/// <param name="mode">Recurrent cell type.</param>
/// <param name="bidirectional">Whether the cell is bidirectional.</param>
/// <param name="forget_bias">Forget-gate bias value.</param>
/// <exception cref="NotImplementedException">
/// Always thrown: this initializer has not been ported yet. FusedRNNCell's constructor creates one,
/// so constructing a FusedRNNCell currently fails with this exception.
/// </exception>
public FusedRNN(Initializer init, int num_hidden, int num_layers, RNNMode mode, bool bidirectional = false, float forget_bias = 1)
{
    // Keep the exception type callers may already catch, but explain what is missing
    // instead of surfacing a bare "The method or operation is not implemented."
    throw new NotImplementedException(
        "FusedRNN initializer is not implemented yet; FusedRNNCell cannot be constructed until it is.");
}