/// <summary>
/// Builds one initial state symbol per entry of <c>StateInfo</c> by reflectively
/// invoking a static factory method on <c>sym</c>.
/// </summary>
/// <param name="func">
/// Name of the factory method. A leading "sym." is stripped before the method is
/// looked up on <c>sym</c>.
/// </param>
/// <param name="kwargs">
/// Extra arguments merged into each state's info. May be <c>null</c>, in which case an
/// empty <c>FuncArgs</c> is used. A "name" entry is added to this instance for every state.
/// </param>
/// <returns>A list holding one symbol per <c>StateInfo</c> entry.</returns>
/// <exception cref="InvalidOperationException">
/// The cell has been wrapped by a modifier cell (<c>_modified</c> is set).
/// </exception>
/// <exception cref="ArgumentException">
/// <paramref name="func"/> does not name a public static method on <c>sym</c>.
/// </exception>
public virtual SymbolList BeginState(string func = "sym.Zeros", FuncArgs kwargs = null)
{
    if (_modified)
    {
        throw new InvalidOperationException("After applying modifier cells (e.g. DropoutCell) the base " +
                                            "cell cannot be called directly. Call the modifier cell instead.");
    }

    // The parameter defaults to null but is written to below; substitute an empty bag.
    if (kwargs == null)
    {
        kwargs = new FuncArgs();
    }

    // Resolved lazily on the first iteration and reused afterwards, so that an empty
    // StateInfo still performs no reflection lookup at all.
    MethodInfo factory = null;
    string[] parameterNames = null;

    SymbolList states = new SymbolList();
    for (int i = 0; i < StateInfo.Length; i++)
    {
        var info = StateInfo[i];

        _init_counter++;
        // NOTE(review): "name" is added to the same kwargs instance on every iteration and the
        // caller's kwargs is mutated. If FuncArgs.Add rejects duplicate keys this fails for the
        // second state - confirm FuncArgs semantics and switch to an overwrite/copy if so.
        kwargs.Add("name", $"{_prefix}begin_state_{_init_counter}");

        if (info == null)
        {
            info = new StateInfo(kwargs);
        }
        else
        {
            info.Update(kwargs);
        }

        if (factory == null)
        {
            var methodName = func.Replace("sym.", "");

            // BindingFlags.Static alone never matches anything: a visibility flag
            // (Public or NonPublic) must be combined with it.
            factory = typeof(sym).GetMethod(methodName, BindingFlags.Public | BindingFlags.Static);
            if (factory == null)
            {
                throw new ArgumentException($"No public static method '{methodName}' found on sym.", nameof(func));
            }

            parameterNames = factory.GetParameters().Select(p => p.Name).ToArray();
        }

        var paramArgs = info.GetArgs(parameterNames);

        // Static method: the target object is ignored, so no sym instance is needed.
        states.Add((Symbol)factory.Invoke(null, paramArgs));
    }

    return states;
}
/// <summary>
/// Returns the caller-supplied begin state, or asks <paramref name="cell"/> to build a
/// default one when none was supplied.
/// </summary>
/// <param name="cell">Cell whose <c>BeginState</c> is used to build the default state.</param>
/// <param name="begin_state">Explicit initial state; returned unchanged when not <c>null</c>.</param>
/// <param name="inputs">
/// Inputs to the cell. Only the first element is inspected, to choose between the
/// "nd.Zeros" and "sym.Zeros" factories and to pick up the NDArray context.
/// </param>
/// <param name="batch_size">Batch size forwarded to <c>cell.BeginState</c>.</param>
/// <returns>
/// <paramref name="begin_state"/> if it was provided, otherwise the state produced by the cell.
/// </returns>
internal static NDArrayOrSymbol[] GetBeginState(RecurrentCell cell, NDArrayOrSymbol[] begin_state, NDArrayOrSymbol[] inputs, int batch_size)
{
    // A default is only needed when the caller did not pass a state. The previous
    // "!= null" test was inverted: it overwrote explicit states and returned null otherwise.
    if (begin_state != null)
    {
        return begin_state;
    }

    if (inputs[0].IsNDArray)
    {
        // Pass the first input's context to the factory via the "ctx" argument.
        var args = new FuncArgs();
        args.Add("ctx", inputs[0].NdX.Context);
        return cell.BeginState(batch_size, "nd.Zeros", args);
    }

    return cell.BeginState(batch_size, "sym.Zeros");
}