/// <summary>
/// Loads parameter values for this block from <paramref name="filename"/>.
/// </summary>
/// <remarks>
/// The file is read with <c>NDArray.Load</c>. If it yields nothing, or none of its keys contain
/// a '.', loading is delegated to <c>CollectParams().Load(...)</c> using this block's
/// <c>Prefix</c> (legacy format). Otherwise each key in the file is matched by name against the
/// dictionary returned by <c>CollectParamsWithPrefix()</c>, and matching parameters are
/// initialised through <c>LoadInit</c>.
/// </remarks>
/// <param name="filename">Path of the parameter file.</param>
/// <param name="ctx">Context(s) forwarded to each parameter's <c>LoadInit</c> (and to the legacy loader).</param>
/// <param name="allow_missing">When false, every parameter of this block must appear in the file.</param>
/// <param name="ignore_extra">When false, a key in the file that this block does not have is an error.</param>
/// <param name="cast_dtype">Forwarded unchanged to the per-parameter / legacy loader.</param>
/// <param name="dtype_source">Forwarded unchanged to the per-parameter / legacy loader.</param>
/// <exception cref="InvalidOperationException">
/// A parameter of this block is missing from the file (when <paramref name="allow_missing"/> is false),
/// or the file contains a parameter this block does not have (when <paramref name="ignore_extra"/> is false).
/// </exception>
public void LoadParameters(string filename, Context[] ctx = null, bool allow_missing = false, bool ignore_extra = false, bool cast_dtype = false, string dtype_source = "current")
{
    var @params = CollectParamsWithPrefix();
    NDArray.Load(filename, out var loaded);

    if (loaded == null && @params == null)
    {
        return;
    }

    // Structured files store keys containing '.'; anything else (including a file that
    // produced no dictionary at all) is handed to the legacy loader, which re-reads the file.
    bool hasStructuredNames = loaded != null && loaded.Keys.Any(key => key.Contains("."));
    if (!hasStructuredNames)
    {
        CollectParams().Load(filename, ctx, allow_missing, ignore_extra, Prefix, cast_dtype, dtype_source);
        return;
    }

    // A null parameter dictionary is treated as empty: nothing can be missing, everything is extra.
    if (!allow_missing && @params != null)
    {
        foreach (var name in @params.Keys())
        {
            if (!loaded.Contains(name))
            {
                throw new InvalidOperationException($"Parameter '{name}' is missing in file '{filename}'");
            }
        }
    }

    foreach (var name in loaded.Keys)
    {
        bool known = @params != null && @params.Contains(name);
        if (!known)
        {
            if (!ignore_extra)
            {
                throw new InvalidOperationException(
                    $"Parameter '{name}' loaded from file {filename} is not present in ParameterDict");
            }

            continue;
        }

        @params[name].LoadInit(loaded[name], ctx, cast_dtype, dtype_source);
    }
}
/// <summary>
/// Runs a forward pass for every output symbol of this function and collects the first
/// output array of each.
/// </summary>
/// <remarks>
/// For each output symbol:
/// <list type="number">
/// <item>Bind values are gathered with <c>MxNetBackend.DfsGetBindValues</c>.</item>
/// <item>The symbol is bound on <c>mx.Cpu()</c> with <c>OpGradReq.Null</c>. Shapes are taken
/// positionally from <paramref name="inputs"/>, but only for this function's input names that are
/// arguments of that symbol.</item>
/// <item>Forward is run with <c>is_train</c>.</item>
/// </list>
/// </remarks>
/// <param name="inputs">Symbols whose shapes are matched positionally to this function's inputs.</param>
/// <returns>One array per output symbol: <c>Outputs[0]</c> of that symbol's executor.</returns>
/// <exception cref="ArgumentNullException"><paramref name="inputs"/> is null.</exception>
/// <exception cref="ArgumentException"><paramref name="inputs"/> has fewer entries than this function has inputs.</exception>
public NDArrayList Call(KerasSymbol[] inputs)
{
    if (inputs == null)
    {
        throw new ArgumentNullException(nameof(inputs));
    }

    if (inputs.Length < this.inputs.Length)
    {
        throw new ArgumentException(
            $"Expected at least {this.inputs.Length} inputs but got {inputs.Length}.", nameof(inputs));
    }

    var ret_outputs = new List<NDArray>();
    foreach (var x in this.output)
    {
        var bind_values = MxNetBackend.DfsGetBindValues(x);

        // Only arguments of this particular output's symbol may be given shapes / grad requests;
        // other outputs can depend on a different subset of the inputs.
        var args = x.Symbol.ListArguments();

        var data = new NDArrayDict();
        var data_shapes = new List<DataDesc>();
        var grad_types = new Dictionary<string, OpGradReq>();
        for (int i = 0; i < this.inputs.Length; i++)
        {
            string name = this.inputs[i].Name;

            var bound = bind_values.Where(a => a.Key == name).FirstOrDefault();
            if (bound.Value != null)
            {
                data[name] = bound.Value;
            }

            if (args.Contains(name))
            {
                data_shapes.Add(new DataDesc(name, inputs[i].Shape));
                grad_types.Add(name, OpGradReq.Null);
            }
        }

        var executor = x.Symbol.SimpleBind(mx.Cpu(), grad_req: grad_types, kwargs: data_shapes.ToArray());
        var arg_dict = executor.ArgmentDictionary();

        // Snapshot the argument names first: assigning through the indexer while enumerating
        // the same dictionary can invalidate the enumerator.
        var arg_names = new List<string>();
        foreach (var v in arg_dict)
        {
            arg_names.Add(v.Key);
        }

        foreach (var name in arg_names)
        {
            if (data.Contains(name))
            {
                // NOTE(review): this replaces the dictionary entry rather than copying into the
                // executor's bound argument array. Confirm whether ArgmentDictionary() returns a
                // live view; if not, an in-place copy is needed for these values to reach Forward.
                arg_dict[name] = data[name];
            }
        }

        executor.Forward(this.is_train);
        ret_outputs.Add(executor.Outputs[0]);
    }

    return ret_outputs;
}