/// <summary>
/// Infers the argument and auxiliary shapes of <c>_symbol</c> from
/// <paramref name="inputShapes"/>, allocates a zero-filled array for every
/// non-input argument and every auxiliary state, fills each array, and stores
/// the results in <c>ArgParams</c> and <c>AuxParams</c>.
/// </summary>
/// <param name="inputShapes">Shapes of the input arguments, keyed by argument name.
/// Arguments listed here are treated as inputs, not parameters.</param>
/// <param name="overwrite">When <c>false</c>, a value already present in
/// <c>ArgParams</c>/<c>AuxParams</c> under the same name is copied into the new array;
/// otherwise (or when no previous value exists) <c>_initializer</c> is called.</param>
/// <returns>A tuple of (all argument names, parameter names, auxiliary state names).</returns>
private Tuple<List<string>, List<string>, List<string>> InitParams(
    Dictionary<string, Shape> inputShapes, bool overwrite = false)
{
    var argShapes = new List<uint[]>();
    var auxShapes = new List<uint[]>();
    this._symbol.InferShape(inputShapes, argShapes, null, auxShapes);

    var argNames = this._symbol.ListArguments();
    var auxNames = this._symbol.ListAuxiliaryStates();

    // Materialize once: the original deferred query was re-run for every argument.
    var paramNames = argNames.Except(inputShapes.Keys).ToList();
    var paramNameSet = new HashSet<string>(paramNames);

    var argParams = argNames
        .Zip(argShapes, (name, shape) => Tuple.Create(name, shape))
        .Where(pair => paramNameSet.Contains(pair.Item1))
        .ToDictionary(pair => pair.Item1, pair => NdArray.Zeros(new Shape(pair.Item2)));

    var auxParams = auxNames
        .Zip(auxShapes, (name, shape) => Tuple.Create(name, shape))
        .ToDictionary(pair => pair.Item1, pair => NdArray.Zeros(new Shape(pair.Item2)));

    foreach (var kv in argParams)
    {
        NdArray existing;
        if (this.ArgParams != null && this.ArgParams.TryGetValue(kv.Key, out existing) && !overwrite)
        {
            existing.CopyTo(kv.Value);
        }
        else
        {
            this._initializer.Call(kv.Key, kv.Value);
        }
    }

    foreach (var kv in auxParams)
    {
        NdArray existing;
        if (this.AuxParams != null && this.AuxParams.TryGetValue(kv.Key, out existing) && !overwrite)
        {
            existing.CopyTo(kv.Value);
        }
        else
        {
            // Initialize the auxiliary array itself. The previous code indexed
            // argParams with an auxiliary name, which is not a key of that dictionary.
            this._initializer.Call(kv.Key, kv.Value);
        }
    }

    this.ArgParams = argParams;
    this.AuxParams = auxParams;

    return Tuple.Create(argNames.ToList(), paramNames, auxNames.ToList());
}
/// <summary>
/// Binds <paramref name="sym"/> on <paramref name="ctx"/>, creating or borrowing the
/// argument, gradient and auxiliary arrays the executor needs.
/// </summary>
/// <param name="sym">Symbol to bind.</param>
/// <param name="ctx">Context on which newly allocated arrays are created.</param>
/// <param name="inputShapes">Shapes of the input arguments, keyed by name.</param>
/// <param name="paramNames">Names of arguments treated as parameters; every other
/// argument is handled as data/label.</param>
/// <param name="needGradInput">When <c>true</c>, gradient arrays are set up for every
/// argument that is not a key of <paramref name="inputShapes"/>; when <c>false</c>,
/// no gradient dictionary is created and <c>null</c> is passed to <c>Bind</c>.</param>
/// <param name="baseExec">Optional executor whose parameter, gradient and auxiliary
/// arrays are reused instead of allocating new ones.</param>
/// <param name="sharedDataArrays">Optional pool of data/label arrays keyed by name.
/// Entries are reshaped for reuse when large enough, and replaced or added otherwise.</param>
/// <param name="inputTypes">Element types of the inputs; defaults to <c>float</c> for
/// every key of <paramref name="inputShapes"/>.</param>
/// <param name="logger">Logger for the re-allocation warning; defaults to
/// <c>LogManager.GetLogger("")</c>.</param>
/// <returns>The executor returned by <c>sym.Bind</c>.</returns>
private Executor BindExec(
    Symbol sym,
    Context ctx,
    Dictionary<string, uint[]> inputShapes,
    IList<string> paramNames,
    bool needGradInput,
    Executor baseExec,
    Dictionary<string, NdArray> sharedDataArrays,
    Dictionary<string, Type> inputTypes = null,
    ILog logger = null)
{
    if (logger == null)
    {
        logger = LogManager.GetLogger("");
    }

    var argShapes = new List<uint[]>();
    var auxShapes = new List<uint[]>();
    var outShapes = new List<uint[]>();
    sym.InferShape(inputShapes, argShapes, outShapes, auxShapes);

    var argTypes = new List<Type>();
    var auxType = new List<Type>();
    var outType = new List<Type>();
    if (inputTypes == null)
    {
        inputTypes = inputShapes.ToDictionary(k => k.Key, v => typeof(float));
    }
    sym.InferType(inputTypes, argTypes, auxType, outType);

    var gradArrays = needGradInput ? new Dictionary<string, NdArray>() : null;
    var argNames = sym.ListArguments();

    // Empty when gradients are not requested, so the membership tests below
    // double as the needGradInput check.
    var needGrad = needGradInput
        ? new HashSet<string>(argNames.Except(inputShapes.Keys))
        : new HashSet<string>();

    var gradReq = argNames.ToDictionary(
        name => name,
        v => needGrad.Contains(v) ? OpReqType.KWriteTo : OpReqType.KNullOp);

    // Hash lookup instead of a linear IList scan for every argument.
    var paramNameSet = new HashSet<string>(paramNames);

    var argArrays = new List<NdArray>();

    // create or borrow arguments and gradients
    for (int i = 0; i < argNames.Count; i++)
    {
        var name = argNames[i];
        NdArray argArr;

        if (!paramNameSet.Contains(name))
        {
            // data or label
            if (sharedDataArrays != null && sharedDataArrays.TryGetValue(name, out argArr))
            {
                if (Util.Prod(argArr.GetShape()) >= Util.Prod(argShapes[i]))
                {
                    // The shared buffer has enough elements: reuse it under the new shape.
                    Util.Assert(argTypes[i] == argArr.GetDtype());
                    argArr = argArr.Reshape(new Shape(argShapes[i]));
                }
                else
                {
                    logger.Warn($"bucketing: data \"{name}\" has a shape {new Shape(argShapes[i])}" +
                                ", which is larger than already allocated " +
                                $"shape {argArr.GetShape()}" +
                                ". Need to re-allocate. Consider putting " +
                                "default_bucket_key to be the bucket taking the largest " +
                                "input for better memory sharing.");
                    argArr = NdArray.Zeros(new Shape(argShapes[i]), ctx, dtype: argTypes[i]);
                    // replace existing shared array because the new one is bigger
                    sharedDataArrays[name] = argArr;
                }
            }
            else
            {
                argArr = NdArray.Zeros(new Shape(argShapes[i]), ctx, dtype: argTypes[i]);
                if (sharedDataArrays != null)
                {
                    sharedDataArrays[name] = argArr;
                }
            }
        }
        else if (baseExec == null)
        {
            // Parameter without a base executor: allocate it (and its gradient if needed).
            argArr = NdArray.Zeros(new Shape(argShapes[i]), ctx, dtype: argTypes[i]);
            if (needGrad.Contains(name))
            {
                gradArrays[name] = NdArray.Zeros(new Shape(argShapes[i]), ctx, dtype: argTypes[i]);
            }
        }
        else
        {
            // Parameter borrowed from the base executor; shape and dtype must match.
            argArr = baseExec.ArgDict[name];
            Util.Assert(argArr.GetShape() == new Shape(argShapes[i]));
            Util.Assert(argArr.GetDtype() == argTypes[i]);
            if (needGrad.Contains(name))
            {
                gradArrays[name] = baseExec.GradDict[name];
            }
        }

        argArrays.Add(argArr);
    }

    IList<NdArray> auxArrays;
    if (baseExec == null)
    {
        auxArrays = auxShapes.Zip(auxType, (l, r) => NdArray.Zeros(new Shape(l), ctx, r)).ToList();
    }
    else
    {
        for (int i = 0; i < baseExec.AuxArrays.Count; i++)
        {
            var a = baseExec.AuxArrays[i];
            Util.Assert((new Shape(auxShapes[i])) == a.GetShape());
            Util.Assert(auxType[i] == a.GetDtype());
        }
        auxArrays = baseExec.AuxArrays;
    }

    var executor = sym.Bind(ctx, argArrays, gradArrays, gradReq, auxArrays, null, baseExec);
    return executor;
}
/// <summary>
/// Infers shapes and types for this symbol, allocates zero-filled argument,
/// gradient and auxiliary arrays, and binds them into an executor.
/// </summary>
/// <param name="context">Default context used for allocation and for binding.</param>
/// <param name="inputShapes">Known input shapes passed to <c>InferShape</c>.</param>
/// <param name="gradReq">Gradient request passed to <c>Bind</c>; gradient arrays are
/// allocated for every argument unless it equals <c>OpReqType.KNullOp</c>.</param>
/// <param name="typeDict">Element types passed to <c>InferType</c>; defaults to
/// <c>float</c> for every argument.</param>
/// <param name="group2Ctx">Optional map from context-group name to context. When given,
/// each argument and auxiliary state is placed according to its
/// <c>"&lt;name&gt;_ctx_group"</c> attribute, falling back to <paramref name="context"/>.</param>
/// <returns>The executor returned by <c>Bind</c>.</returns>
/// <exception cref="Exception">Thrown when shape or type inference yields no arguments.</exception>
public Executor SimpleBind(
    Context context,
    Dictionary<string, uint[]> inputShapes,
    OpReqType gradReq,
    Dictionary<string, Type> typeDict = null,
    Dictionary<string, Context> group2Ctx = null)
{
    var argumentNames = ListArguments();
    if (typeDict == null)
    {
        typeDict = argumentNames.ToDictionary(k => k, v => typeof(float));
    }

    var inferredArgShapes = new List<uint[]>();
    var inferredAuxShapes = new List<uint[]>();
    var inferredOutShapes = new List<uint[]>();
    InferShape(inputShapes, inferredArgShapes, inferredOutShapes, inferredAuxShapes);

    var inferredArgTypes = new List<Type>();
    var inferredAuxTypes = new List<Type>();
    var inferredOutTypes = new List<Type>();
    InferType(typeDict, inferredArgTypes, inferredAuxTypes, inferredOutTypes);

    bool inferenceComplete = inferredArgShapes.Count != 0 && inferredArgTypes.Count != 0;
    if (!inferenceComplete)
    {
        throw new Exception("Input node is not complete");
    }

    // Decide on which context each argument / auxiliary state lives.
    List<Context> argContexts;
    List<Context> auxContexts;
    if (group2Ctx == null)
    {
        // No group mapping: everything goes on the default context.
        argContexts = Enumerable.Range(0, inferredArgShapes.Count).Select(unused => context).ToList();
        auxContexts = Enumerable.Range(0, inferredAuxShapes.Count).Select(unused => context).ToList();
    }
    else
    {
        var groupAttributes = ListAttr(true).Where(attr => attr.Key.EndsWith("ctx_group"));
        var contextByAttribute = groupAttributes.ToDictionary(
            attr => attr.Key,
            attr => group2Ctx.GetValueOrDefault(attr.Value, context));

        Func<string, Context> resolveContext =
            name => contextByAttribute.GetValueOrDefault(name + "_ctx_group", context);

        argContexts = argumentNames.Select(resolveContext).ToList();
        auxContexts = ListAuxiliaryStates().Select(resolveContext).ToList();
    }

    // Allocate argument storage.
    var argumentArrays = inferredArgTypes
        .Zip(argContexts, inferredArgShapes,
            (elementType, device, dims) => NdArray.Zeros(new Shape(dims), device, elementType))
        .ToList();

    // Allocate gradient storage for every argument unless gradients are disabled.
    var gradientArrays = new Dictionary<string, NdArray>();
    if (gradReq != OpReqType.KNullOp)
    {
        for (int index = 0; index < argumentNames.Count; index++)
        {
            var argumentName = argumentNames[index];
            var dims = inferredArgShapes[index];
            var device = argContexts[index];
            var elementType = inferredArgTypes[index];
            gradientArrays[argumentName] = NdArray.Zeros(new Shape(dims), device, dtype: elementType);
        }
    }

    // Allocate auxiliary state storage.
    var auxiliaryArrays = inferredAuxTypes
        .Zip(auxContexts, inferredAuxShapes,
            (elementType, device, dims) => NdArray.Zeros(new Shape(dims), device, elementType))
        .ToList();

    return Bind(context, argumentArrays, gradientArrays, gradReq, auxiliaryArrays, groupToCtx: group2Ctx);
}