Example #1
0
        /// <summary>
        /// Combines two child cells into a bidirectional cell. <paramref name="l_cell"/> and
        /// <paramref name="r_cell"/> are stored (in that order) as the child cells, and every
        /// parameter entry of both children is merged into this cell's parameter table.
        /// </summary>
        /// <param name="l_cell">Child cell stored first (presumably the left/forward direction).</param>
        /// <param name="r_cell">Child cell stored second (presumably the right/backward direction).</param>
        /// <param name="output_prefix">Stored in <c>_output_prefix</c>; defaults to "bi_".</param>
        /// <param name="params">
        /// Optional parameter container forwarded to the base cell. When non-null, its entries are
        /// copied into both children's parameter tables, and both children are expected to still
        /// own their parameters (verified with <c>Debug.Assert</c>, i.e. in debug builds only).
        /// </param>
        /// <exception cref="System.ArgumentNullException">
        /// <paramref name="l_cell"/> or <paramref name="r_cell"/> is null.
        /// </exception>
        public BidirectionalCell(BaseRNNCell l_cell, BaseRNNCell r_cell, string output_prefix = "bi_",
                                 RNNParams @params = null) : base("", @params)
        {
            if (l_cell == null)
            {
                throw new System.ArgumentNullException(nameof(l_cell));
            }
            if (r_cell == null)
            {
                throw new System.ArgumentNullException(nameof(r_cell));
            }

            this._output_prefix        = output_prefix;
            this._override_cell_params = @params != null;
            if (this._override_cell_params)
            {
                // Parameters may be supplied either here or on the children, not both, so both
                // children must still own theirs. (This previously read `l_cell._own_params != null`,
                // which is always true for a bool, so only the right cell was actually checked.)
                Debug.Assert(l_cell._own_params && r_cell._own_params, "Either specify params for BidirectionalCell or child cells, not both.");

                // Push the shared parameters down into both children.
                foreach (var item in this.Params._params)
                {
                    l_cell.Params._params[item.Key] = item.Value;
                    r_cell.Params._params[item.Key] = item.Value;
                }
            }

            // Merge the children's parameters back into this cell's table; on key collisions the
            // right cell's entry wins because it is copied last.
            foreach (var item in l_cell.Params._params)
            {
                this.Params._params[item.Key] = item.Value;
            }

            foreach (var item in r_cell.Params._params)
            {
                this.Params._params[item.Key] = item.Value;
            }

            this._cells.Add(l_cell);
            this._cells.Add(r_cell);
        }
Example #2
0
 /// <summary>
 /// Builds the cell and fetches its four projection parameters from <c>Params</c>:
 /// "i2h_weight", "i2h_bias", "h2h_weight" and "h2h_bias".
 /// </summary>
 /// <param name="num_hidden">Stored in <c>_num_hidden</c> (presumably the hidden-state size).</param>
 /// <param name="prefix">
 /// Name prefix forwarded to the base cell. NOTE(review): the default "lstm_" looks copied from
 /// LSTMCell; confirm before relying on the generated parameter names.
 /// </param>
 /// <param name="params">Optional parameter container forwarded to the base cell.</param>
 public GRUCell(int num_hidden, string prefix = "lstm_", RNNParams @params = null) : base(prefix, @params)
 {
     _num_hidden = num_hidden;

     // input-to-hidden weight and bias
     _iW = Params.Get("i2h_weight");
     _iB = Params.Get("i2h_bias");

     // hidden-to-hidden weight and bias
     _hW = Params.Get("h2h_weight");
     _hB = Params.Get("h2h_bias");
 }
Example #3
0
 /// <summary>
 /// Builds the cell and fetches its projection parameters from <c>Params</c>. The "i2h_bias"
 /// entry is requested with an <c>LSTMBias</c> initializer built from
 /// <paramref name="forget_bias"/>.
 /// </summary>
 /// <param name="num_hidden">Stored in <c>_num_hidden</c> (presumably the hidden-state size).</param>
 /// <param name="prefix">Name prefix forwarded to the base cell; defaults to "lstm_".</param>
 /// <param name="params">Optional parameter container forwarded to the base cell.</param>
 /// <param name="forget_bias">Value handed to <c>LSTMBias</c> for the input-to-hidden bias; defaults to 1.</param>
 public LSTMCell(int num_hidden, string prefix = "lstm_", RNNParams @params = null, float forget_bias = 1) : base(prefix, @params)
 {
     _num_hidden = num_hidden;

     // input-to-hidden weight and bias (the bias gets a dedicated initializer)
     _iW = Params.Get("i2h_weight");
     _iB = Params.Get("i2h_bias", init: new LSTMBias(forget_bias));

     // hidden-to-hidden weight and bias
     _hW = Params.Get("h2h_weight");
     _hB = Params.Get("h2h_bias");
 }
Example #4
0
        /// <summary>
        /// Initializes state shared by every cell (prefix, parameter container, ownership and
        /// modification flags) and then calls <c>Reset()</c>.
        /// </summary>
        /// <param name="prefix">
        /// Stored in <c>_prefix</c>; also used to create a fresh <c>RNNParams</c> when
        /// <paramref name="params"/> is null.
        /// </param>
        /// <param name="params">
        /// Existing parameter container to use. When null, a new one is created and the cell is
        /// marked as owning its parameters (<c>_own_params = true</c>); otherwise <c>_own_params</c> is false.
        /// </param>
        public BaseRNNCell(string prefix, RNNParams @params = null)
        {
            // Ownership depends only on whether the caller supplied a container.
            bool ownsParams = @params == null;

            _own_params = ownsParams;
            _prefix     = prefix;
            _params     = ownsParams ? new RNNParams(prefix) : @params;
            _modified   = false;

            // NOTE(review): if Reset() is virtual, derived overrides run before the derived
            // constructor body has executed; confirm that is intended.
            Reset();
        }
Example #5
0
        /// <summary>
        /// Configures a fused RNN cell whose weights live in one "parameters" entry, which is
        /// requested from <c>Params</c> with a <c>FusedRNN</c> initializer.
        /// </summary>
        /// <param name="num_hidden">Stored in <c>_num_hidden</c> and passed to the initializer.</param>
        /// <param name="num_layers">Stored in <c>_num_layers</c> and passed to the initializer; defaults to 1.</param>
        /// <param name="mode">Stored in <c>_mode</c>; defaults to <c>RNNMode.Lstm</c>.</param>
        /// <param name="bidirectional">When true, <c>_directions</c> is ["l", "r"]; otherwise ["l"].</param>
        /// <param name="dropout">Stored in <c>_dropout</c>.</param>
        /// <param name="get_next_state">Stored in <c>_get_next_state</c>.</param>
        /// <param name="forget_bias">Passed only to the <c>FusedRNN</c> initializer; defaults to 1.</param>
        /// <param name="prefix">
        /// Name prefix for the base cell. When null, <c>mode + "_"</c> is used.
        /// NOTE(review): that uses the enum member name (e.g. "Lstm_"); confirm the casing is intended.
        /// </param>
        /// <param name="params">Optional parameter container forwarded to the base cell.</param>
        public FusedRNNCell(int num_hidden, int num_layers = 1, RNNMode mode = RNNMode.Lstm, bool bidirectional = false,
                            float dropout = 0, bool get_next_state = false, float forget_bias = 1, string prefix = null, RNNParams @params = null) : base(prefix ?? (mode + "_"), @params)
        {
            _num_hidden     = num_hidden;
            _num_layers     = num_layers;
            _mode           = mode;
            _bidirectional  = bidirectional;
            _dropout        = dropout;
            _get_next_state = get_next_state;

            // The "l" direction is always present; "r" is appended only for bidirectional cells.
            var directions = new List<string> { "l" };
            if (bidirectional)
            {
                directions.Add("r");
            }
            _directions = directions;

            // Build the initializer before touching Params, matching the original evaluation order.
            var fusedInit = new FusedRNN(null, num_hidden, num_layers, mode, bidirectional, forget_bias);
            _parameter = Params.Get("parameters", init: fusedInit);
        }
Example #6
0
 /// <summary>
 /// Placeholder constructor: forwards <paramref name="prefix"/> and <paramref name="params"/>
 /// to the base cell, then always throws.
 /// </summary>
 /// <exception cref="NotImplementedException">Always thrown; this constructor is not implemented.</exception>
 public RNNCell(int num_hidden, ActivationType activation = ActivationType.Tanh, string prefix = "rnn_", RNNParams @params = null) : base(prefix, @params)
 {
     throw new NotImplementedException();
 }
 /// <summary>
 /// Placeholder constructor (string-mode variant): forwards <paramref name="prefix"/> unchanged
 /// (null by default) and <paramref name="params"/> to the base cell, then always throws.
 /// </summary>
 /// <exception cref="NotImplementedException">Always thrown; this constructor is not implemented.</exception>
 public FusedRNNCell(int num_hidden, int num_layers = 1, string mode = "lstm", bool bidirectional = false,
                     float dropout = 0, bool get_next_state          = false, float forget_bias   = 1, string prefix = null, RNNParams @params = null) : base(prefix, @params)
 {
     throw new NotImplementedException();
 }
Example #8
0
 /// <summary>
 /// Builds a simple recurrent cell: records its size and activation, then fetches the
 /// "i2h_weight", "i2h_bias", "h2h_weight" and "h2h_bias" parameters from <c>Params</c>.
 /// </summary>
 /// <param name="num_hidden">Stored in <c>_num_hidden</c> (presumably the hidden-state size).</param>
 /// <param name="activation">Stored in <c>_activation</c>; defaults to <c>ActivationType.Tanh</c>.</param>
 /// <param name="prefix">Name prefix forwarded to the base cell; defaults to "rnn_".</param>
 /// <param name="params">Optional parameter container forwarded to the base cell.</param>
 public RNNCell(int num_hidden, ActivationType activation = ActivationType.Tanh, string prefix = "rnn_", RNNParams @params = null) : base(prefix, @params)
 {
     _num_hidden = num_hidden;
     _activation = activation;

     // input-to-hidden weight and bias
     _iW = Params.Get("i2h_weight");
     _iB = Params.Get("i2h_bias");

     // hidden-to-hidden weight and bias
     _hW = Params.Get("h2h_weight");
     _hB = Params.Get("h2h_bias");
 }
Example #9
0
 /// <summary>
 /// Placeholder constructor: forwards <paramref name="prefix"/> and <paramref name="params"/>
 /// to the base cell, then always throws.
 /// NOTE(review): the default prefix "lstm_" looks copied from LSTMCell; confirm.
 /// </summary>
 /// <exception cref="NotImplementedException">Always thrown; this constructor is not implemented.</exception>
 public GRUCell(int num_hidden, string prefix = "lstm_", RNNParams @params = null) : base(prefix, @params)
 {
     throw new NotImplementedException();
 }
Example #10
0
 /// <summary>
 /// Creates a sequential container with an empty prefix and no child cells yet.
 /// </summary>
 /// <param name="params">
 /// Optional parameter container forwarded to the base cell; when non-null,
 /// <c>_override_cell_params</c> is set to true.
 /// </param>
 public SequentialRNNCell(RNNParams @params = null) : base("", @params)
 {
     // Start with an empty child list.
     _cells = new List<BaseRNNCell>();
     _override_cell_params = @params != null;
 }
 /// <summary>
 /// Placeholder constructor: calls the base cell with an empty prefix and
 /// <paramref name="params"/>, then always throws.
 /// </summary>
 /// <exception cref="NotImplementedException">Always thrown; this constructor is not implemented.</exception>
 public BidirectionalCell(BaseRNNCell l_cell, BaseRNNCell r_cell, string output_prefix = "bi_",
                          RNNParams @params = null) : base("", @params)
 {
     throw new NotImplementedException();
 }
Example #12
0
 /// <summary>
 /// Creates a dropout cell, forwarding <paramref name="prefix"/> and <paramref name="params"/>
 /// to the base cell.
 /// </summary>
 /// <param name="dropout">
 /// Stored in the <c>dropout</c> field (presumably a drop probability; not range-checked here).
 /// </param>
 /// <param name="prefix">Name prefix forwarded to the base cell (no default).</param>
 /// <param name="params">Optional parameter container forwarded to the base cell.</param>
 public DropoutCell(float dropout, string prefix, RNNParams @params = null) : base(prefix, @params)
 {
     // `this.` disambiguates the field from the same-named parameter.
     this.dropout = dropout;
 }
Example #13
0
 /// <summary>
 /// Placeholder constructor: forwards <paramref name="prefix"/> and <paramref name="params"/>
 /// to the base cell, then always throws.
 /// </summary>
 /// <exception cref="NotImplementedException">Always thrown; this constructor is not implemented.</exception>
 public DropoutCell(float dropout, string prefix, RNNParams @params = null) : base(prefix, @params)
 {
     throw new NotImplementedException();
 }
Example #14
0
 /// <summary>
 /// Calls the base cell with an empty prefix and <paramref name="params"/>; the body does no
 /// further initialization.
 /// </summary>
 /// <param name="params">Optional parameter container forwarded to the base cell.</param>
 public SequentialRNNCell(RNNParams @params = null) : base("", @params)
 {
 }
Example #15
0
 /// <summary>
 /// Placeholder constructor: forwards <paramref name="prefix"/> and <paramref name="params"/>
 /// to the base cell, then always throws.
 /// </summary>
 /// <exception cref="NotImplementedException">Always thrown; this constructor is not implemented.</exception>
 public LSTMCell(int num_hidden, string prefix = "lstm_", RNNParams @params = null, float forget_bias = 1) : base(prefix, @params)
 {
     throw new NotImplementedException();
 }