/// <summary>
/// Convenience overload that applies one gradient request to every argument.
/// It builds a map from each entry of <c>ListArguments()</c> to <paramref name="gradReq"/>
/// and delegates to the <c>Bind</c> overload that accepts per-argument requests.
/// </summary>
/// <param name="context">Forwarded unchanged to the delegated <c>Bind</c> overload.</param>
/// <param name="argArrays">Forwarded unchanged.</param>
/// <param name="gradDict">Forwarded unchanged.</param>
/// <param name="gradReq">Request assigned to every argument name.</param>
/// <param name="auxArrays">Forwarded unchanged.</param>
/// <param name="groupToCtx">Forwarded unchanged; may be <c>null</c>.</param>
/// <param name="sharedExec">Forwarded unchanged; may be <c>null</c>.</param>
/// <returns>Whatever the delegated <c>Bind</c> overload returns.</returns>
public Executor Bind(
    Context context,
    IList<NdArray> argArrays,
    Dictionary<string, NdArray> gradDict,
    OpReqType gradReq,
    IList<NdArray> auxArrays,
    Dictionary<string, Context> groupToCtx = null,
    Executor sharedExec = null)
{
    var requestPerArgument = new Dictionary<string, OpReqType>();
    foreach (var argumentName in ListArguments())
    {
        // Add (not the indexer) so a duplicate name throws ArgumentException.
        requestPerArgument.Add(argumentName, gradReq);
    }

    return Bind(context, argArrays, gradDict, requestPerArgument, auxArrays, groupToCtx, sharedExec);
}
/// <summary>
/// Infers argument and auxiliary-state shapes and types, allocates the arrays, and binds an executor.
/// The shapes come from <paramref name="inputShapes"/> and the types from <paramref name="typeDict"/>.
/// Argument, gradient and auxiliary arrays are allocated with <c>NdArray.Zeros</c>.
/// </summary>
/// <param name="context">Default device for every allocated array.</param>
/// <param name="inputShapes">Known input shapes keyed by argument name; passed to <c>InferShape</c>.</param>
/// <param name="gradReq">Gradient request applied to every argument. When it equals
/// <c>OpReqType.KNullOp</c>, no gradient arrays are allocated and an empty gradient map is bound.</param>
/// <param name="typeDict">Known input types keyed by argument name. When <c>null</c>, every
/// argument from <c>ListArguments()</c> is typed as <see cref="float"/>.</param>
/// <param name="group2Ctx">Optional map from context-group name to device. An argument or aux state
/// whose <c>"{name}_ctx_group"</c> attribute names a group present here is allocated on that device.
/// All others fall back to <paramref name="context"/>.</param>
/// <returns>The executor returned by the <c>Bind</c> overload that takes a single <see cref="OpReqType"/>.</returns>
/// <exception cref="InvalidOperationException">
/// Shape or type inference produced no argument entries.
/// </exception>
public Executor SimpleBind(
    Context context,
    Dictionary<string, uint[]> inputShapes,
    OpReqType gradReq,
    Dictionary<string, Type> typeDict = null,
    Dictionary<string, Context> group2Ctx = null)
{
    var listArguments = ListArguments();
    if (typeDict == null)
    {
        // No type hints from the caller: default every argument to float.
        typeDict = listArguments.ToDictionary(k => k, v => typeof(float));
    }

    var argShapes = new List<uint[]>();
    var auxShapes = new List<uint[]>();
    var outShapes = new List<uint[]>();
    // NOTE(review): InferShape is given the lists as (arg, out, aux), while InferType below is
    // given (arg, aux, out). Confirm each order against its callee's signature. A mismatch would
    // size the aux arrays from output shapes or types.
    InferShape(inputShapes, argShapes, outShapes, auxShapes);

    var argTypes = new List<Type>();
    var auxTypes = new List<Type>();
    var outTypes = new List<Type>();
    InferType(typeDict, argTypes, auxTypes, outTypes);

    if (argShapes.Count == 0 || argTypes.Count == 0)
    {
        // Specific exception type; still caught by existing catch (Exception) handlers.
        throw new InvalidOperationException("Input node is not complete");
    }

    List<Context> argCtx;
    List<Context> auxCtx;
    if (group2Ctx != null)
    {
        // Map every "...ctx_group" attribute key to the device of the group it names.
        // Unknown groups map to the default context. Ordinal comparison: attribute keys are
        // identifiers, not linguistic text.
        var attrDict = ListAttr(true)
            .Where(w => w.Key.EndsWith("ctx_group", StringComparison.Ordinal))
            .ToDictionary(k => k.Key, v => group2Ctx.GetValueOrDefault(v.Value, context));

        argCtx = listArguments
            .Select(name => attrDict.GetValueOrDefault(name + "_ctx_group", context))
            .ToList();
        auxCtx = ListAuxiliaryStates()
            .Select(name => attrDict.GetValueOrDefault(name + "_ctx_group", context))
            .ToList();
    }
    else
    {
        argCtx = Enumerable.Range(0, argShapes.Count).Select(s => context).ToList();
        auxCtx = Enumerable.Range(0, auxShapes.Count).Select(s => context).ToList();
    }

    // Allocate zero-filled argument arrays on their resolved devices.
    var argNdarrays = argTypes
        .Zip(argCtx, argShapes, (dtype, dev, shape) => NdArray.Zeros(new Shape(shape), dev, dtype))
        .ToList();

    var gradNdarrays = new Dictionary<string, NdArray>();
    if (gradReq != OpReqType.KNullOp)
    {
        // One gradient buffer per argument, matching its shape, device and type.
        for (int i = 0; i < listArguments.Count; i++)
        {
            var name = listArguments[i];
            var shape = argShapes[i];
            var dev = argCtx[i];
            var dtype = argTypes[i];
            gradNdarrays[name] = NdArray.Zeros(new Shape(shape), dev, dtype: dtype);
        }
    }

    var auxNdarrays = auxTypes
        .Zip(auxCtx, auxShapes, (dtype, dev, shape) => NdArray.Zeros(new Shape(shape), dev, dtype))
        .ToList();

    return Bind(context, argNdarrays, gradNdarrays, gradReq, auxNdarrays, groupToCtx: group2Ctx);
}