Ejemplo n.º 1
0
        /// <summary>
        /// Binds this symbol using one gradient request type that is applied uniformly
        /// to every name returned by <c>ListArguments()</c>, then delegates to the
        /// overload that accepts a per-argument request dictionary.
        /// </summary>
        /// <param name="context">Device context, forwarded unchanged.</param>
        /// <param name="argArrays">Argument arrays, forwarded unchanged.</param>
        /// <param name="gradDict">Gradient arrays keyed by argument name, forwarded unchanged.</param>
        /// <param name="gradReq">Gradient request assigned to every argument name.</param>
        /// <param name="auxArrays">Auxiliary-state arrays, forwarded unchanged.</param>
        /// <param name="groupToCtx">Optional context-group to device mapping, forwarded unchanged.</param>
        /// <param name="sharedExec">Optional executor, forwarded unchanged.</param>
        /// <returns>The executor returned by the per-argument-request overload.</returns>
        public Executor Bind(Context context,
                             IList <NdArray> argArrays,
                             Dictionary <string, NdArray> gradDict,
                             OpReqType gradReq,
                             IList <NdArray> auxArrays,
                             Dictionary <string, Context> groupToCtx = null,
                             Executor sharedExec = null)
        {
            var argumentNames = ListArguments();

            // Every argument gets the same request; ToDictionary still throws if a
            // name were ever repeated, exactly as before.
            var requestPerArgument = argumentNames.ToDictionary(argName => argName, _ => gradReq);

            return Bind(context, argArrays, gradDict, requestPerArgument, auxArrays, groupToCtx, sharedExec);
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Infers argument/auxiliary shapes and element types for this symbol, allocates
        /// zero-filled argument, gradient and auxiliary arrays, and binds them into an
        /// <see cref="Executor"/>.
        /// </summary>
        /// <param name="context">Default device used for every allocated array.</param>
        /// <param name="inputShapes">Known input shapes keyed by argument name, passed to <c>InferShape</c>.</param>
        /// <param name="gradReq">
        /// Gradient request applied to every argument. When it is <c>OpReqType.KNullOp</c>
        /// no gradient arrays are allocated and an empty gradient dictionary is bound.
        /// </param>
        /// <param name="typeDict">
        /// Element types keyed by argument name; when <c>null</c>, every argument defaults to <see cref="float"/>.
        /// </param>
        /// <param name="group2Ctx">
        /// Optional mapping from context-group name to device. An argument or auxiliary state
        /// whose <c>&lt;name&gt;_ctx_group</c> attribute value is a key of this map is placed on
        /// the mapped device; everything else uses <paramref name="context"/>.
        /// </param>
        /// <returns>The bound executor.</returns>
        /// <exception cref="InvalidOperationException">
        /// Shape or type inference produced no argument entries.
        /// </exception>
        public Executor SimpleBind(
            Context context,
            Dictionary <string, uint[]> inputShapes,
            OpReqType gradReq,
            Dictionary <string, Type> typeDict     = null,
            Dictionary <string, Context> group2Ctx = null
            )
        {
            var listArguments = ListArguments();

            if (typeDict == null)
            {
                typeDict = listArguments.ToDictionary(name => name, _ => typeof(float));
            }

            var argShapes = new List <uint[]>();
            var auxShapes = new List <uint[]>();
            var outShapes = new List <uint[]>();

            // NOTE(review): InferShape is handed (outShapes, auxShapes) while InferType below is
            // handed (auxTypes, outTypes). Confirm each call matches its own declaration's
            // parameter order; they cannot be checked from this file.
            InferShape(inputShapes, argShapes, outShapes, auxShapes);

            var argTypes = new List <Type>();
            var auxTypes = new List <Type>();
            var outTypes = new List <Type>();

            InferType(typeDict, argTypes, auxTypes, outTypes);

            if (argShapes.Count == 0 || argTypes.Count == 0)
            {
                // A specific exception type instead of System.Exception; existing
                // catch (Exception) handlers still catch it.
                throw new InvalidOperationException("Input node is not complete");
            }

            List <Context> argCtx;
            List <Context> auxCtx;

            if (group2Ctx != null)
            {
                // Recursive attribute listing: keep only "...ctx_group" entries and resolve each
                // group name to a device, falling back to the default context.
                var ctxByAttrKey = ListAttr(true)
                                   .Where(attr => attr.Key.EndsWith("ctx_group", StringComparison.Ordinal))
                                   .ToDictionary(attr => attr.Key,
                                                 attr => group2Ctx.GetValueOrDefault(attr.Value, context));

                argCtx = listArguments
                         .Select(name => ctxByAttrKey.GetValueOrDefault(name + "_ctx_group", context))
                         .ToList();
                auxCtx = ListAuxiliaryStates()
                         .Select(name => ctxByAttrKey.GetValueOrDefault(name + "_ctx_group", context))
                         .ToList();
            }
            else
            {
                argCtx = Enumerable.Repeat(context, argShapes.Count).ToList();
                auxCtx = Enumerable.Repeat(context, auxShapes.Count).ToList();
            }

            // Allocate zero-filled argument arrays on their resolved devices.
            var argNdarrays = argTypes
                              .Zip(argCtx, argShapes,
                                   (dtype, dev, shape) =>
                                   NdArray.Zeros(new Shape(shape), dev, dtype))
                              .ToList();

            // Gradient buffers mirror the argument arrays unless gradients are disabled.
            var gradNdarrays = new Dictionary <string, NdArray>();

            if (gradReq != OpReqType.KNullOp)
            {
                for (int i = 0; i < listArguments.Count; i++)
                {
                    gradNdarrays[listArguments[i]] =
                        NdArray.Zeros(new Shape(argShapes[i]), argCtx[i], argTypes[i]);
                }
            }

            var auxNdarrays = auxTypes
                              .Zip(auxCtx, auxShapes,
                                   (dtype, dev, shape) =>
                                   NdArray.Zeros(new Shape(shape), dev, dtype))
                              .ToList();

            return Bind(context,
                        argNdarrays,
                        gradNdarrays,
                        gradReq,
                        auxNdarrays,
                        groupToCtx: group2Ctx);
        }