Ejemplo n.º 1
0
        /// <summary>
        /// (Re)creates the parameter and auxiliary-state arrays of <c>_symbol</c> for the given input shapes.
        /// </summary>
        /// <remarks>
        /// Shapes are inferred with <c>_symbol.InferShape</c>. Every argument that is not a key of
        /// <paramref name="inputShapes"/> is treated as a parameter. Each parameter and auxiliary state
        /// gets a freshly allocated zero-filled array. That array is then either filled by copying the
        /// same-named array already held in <c>ArgParams</c> / <c>AuxParams</c>, or (when no such array
        /// exists, or <paramref name="overwrite"/> is set) initialized through <c>_initializer</c>.
        /// The new dictionaries replace <c>ArgParams</c> and <c>AuxParams</c>.
        /// </remarks>
        /// <param name="inputShapes">Shapes of the input arguments, keyed by argument name.</param>
        /// <param name="overwrite">
        /// When <c>true</c>, existing values are ignored and every array is re-initialized with <c>_initializer</c>.
        /// </param>
        /// <returns>A tuple of (all argument names, parameter names, auxiliary-state names).</returns>
        private Tuple <List <string>, List <string>, List <string> > InitParams(
            Dictionary <string, Shape> inputShapes,
            bool overwrite = false)
        {
            List <uint[]> argShapes = new List <uint[]>();
            List <uint[]> auxShapes = new List <uint[]>();

            this._symbol.InferShape(inputShapes, argShapes, null, auxShapes);

            var argNames   = this._symbol.ListArguments();
            var inputNames = inputShapes.Keys;

            // Materialize once: a deferred Except query would be re-enumerated for every
            // Contains call in the filter below and again for the returned list.
            var paramNames   = argNames.Except(inputNames).ToList();
            var paramNameSet = new HashSet <string>(paramNames);

            var auxNames = this._symbol.ListAuxiliaryStates();

            // Pair each argument name with its inferred shape, keeping only the parameters.
            var paramNameShapes = argNames.Zip(argShapes, Tuple.Create).Where(w => paramNameSet.Contains(w.Item1));

            var argParams = paramNameShapes.ToDictionary(k => k.Item1, s => NdArray.Zeros(new Shape(s.Item2)));
            var auxParams = auxNames.Zip(auxShapes, Tuple.Create).ToDictionary(k => k.Item1, s => NdArray.Zeros(new Shape(s.Item2)));

            foreach (var kv in argParams)
            {
                if (this.ArgParams != null && this.ArgParams.ContainsKey(kv.Key) && !overwrite)
                {
                    // Keep the previously held value for this parameter.
                    this.ArgParams[kv.Key].CopyTo(kv.Value);
                }
                else
                {
                    this._initializer.Call(kv.Key, kv.Value);
                }
            }

            foreach (var kv in auxParams)
            {
                if (this.AuxParams != null && this.AuxParams.ContainsKey(kv.Key) && !overwrite)
                {
                    // Keep the previously held value for this auxiliary state.
                    this.AuxParams[kv.Key].CopyTo(kv.Value);
                }
                else
                {
                    // Initialize the auxiliary array itself: aux names are not keys of argParams,
                    // so looking the name up there would throw or target the wrong array.
                    this._initializer.Call(kv.Key, kv.Value);
                }
            }

            this.ArgParams = argParams;
            this.AuxParams = auxParams;

            return(Tuple.Create(argNames.ToList(), paramNames, auxNames.ToList()));
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Binds <paramref name="sym"/> into an executor on <paramref name="ctx"/>. Parameter, gradient and
        /// auxiliary arrays are borrowed from <paramref name="baseExec"/> when one is given. Data/label
        /// arrays are borrowed from (and published to) <paramref name="sharedDataArrays"/> when it is given.
        /// </summary>
        /// <param name="sym">Symbol to bind.</param>
        /// <param name="ctx">Context on which any new arrays are allocated and the executor is bound.</param>
        /// <param name="inputShapes">Shapes of the input arguments, keyed by name.</param>
        /// <param name="paramNames">Names of model parameters; every other argument is handled as data/label.</param>
        /// <param name="needGradInput">
        /// When true, a write gradient is requested for every argument that is not a key of <paramref name="inputShapes"/>.
        /// </param>
        /// <param name="baseExec">Executor whose parameter, gradient and auxiliary arrays are reused; <c>null</c> to allocate new ones.</param>
        /// <param name="sharedDataArrays">
        /// Optional pool of data arrays. A pooled array is reshaped when it is large enough; otherwise it is
        /// replaced by a newly allocated one. Newly allocated data arrays are stored back into the pool.
        /// </param>
        /// <param name="inputTypes">Types of the inputs; when <c>null</c>, every key of <paramref name="inputShapes"/> maps to <c>float</c>.</param>
        /// <param name="logger">Receives re-allocation warnings; when <c>null</c>, <c>LogManager.GetLogger("")</c> is used.</param>
        /// <returns>The executor produced by <c>sym.Bind</c>.</returns>
        private Executor BindExec(Symbol sym, Context ctx, Dictionary <string, uint[]> inputShapes,
                                  IList <string> paramNames, bool needGradInput, Executor baseExec,
                                  Dictionary <string, NdArray> sharedDataArrays, Dictionary <string, Type> inputTypes = null, ILog logger = null)
        {
            logger = logger ?? LogManager.GetLogger("");

            // Shape inference fills these lists in argument / output / auxiliary order.
            var argShapes = new List <uint[]>();
            var auxShapes = new List <uint[]>();
            var outShapes = new List <uint[]>();
            sym.InferShape(inputShapes, argShapes, outShapes, auxShapes);

            var argTypes = new List <Type>();
            var auxTypes = new List <Type>();
            var outTypes = new List <Type>();
            var typesForInference = inputTypes ?? inputShapes.ToDictionary(entry => entry.Key, entry => typeof(float));
            sym.InferType(typesForInference, argTypes, auxTypes, outTypes);

            var gradArrays = needGradInput ? new Dictionary <string, NdArray>() : null;

            var argNames = sym.ListArguments();

            // With gradients enabled, everything that is not an input gets one.
            var needGrad = needGradInput
                ? new HashSet <string>(argNames.Except(inputShapes.Keys))
                : new HashSet <string>();

            var gradReq = argNames.ToDictionary(
                argName => argName,
                argName => needGrad.Contains(argName) ? OpReqType.KWriteTo : OpReqType.KNullOp);

            var argArrays = new List <NdArray>();

            // create or borrow arguments and gradients
            for (var idx = 0; idx < argNames.Count; idx++)
            {
                var argName = argNames[idx];
                NdArray array;

                if (paramNames.Contains(argName))
                {
                    // model parameter
                    if (baseExec != null)
                    {
                        array = baseExec.ArgDict[argName];
                        Util.Assert(array.GetShape() == new Shape(argShapes[idx]));
                        Util.Assert(array.GetDtype() == argTypes[idx]);

                        if (needGradInput && needGrad.Contains(argName))
                        {
                            gradArrays[argName] = baseExec.GradDict[argName];
                        }
                    }
                    else
                    {
                        array = NdArray.Zeros(new Shape(argShapes[idx]), ctx, dtype: argTypes[idx]);

                        if (needGradInput && needGrad.Contains(argName))
                        {
                            gradArrays[argName] = NdArray.Zeros(new Shape(argShapes[idx]), ctx, dtype: argTypes[idx]);
                        }
                    }
                }
                else if (sharedDataArrays == null || !sharedDataArrays.ContainsKey(argName))
                {
                    // data or label with nothing to borrow: allocate it and publish it to the pool, if any
                    array = NdArray.Zeros(new Shape(argShapes[idx]), ctx, dtype: argTypes[idx]);

                    if (sharedDataArrays != null)
                    {
                        sharedDataArrays[argName] = array;
                    }
                }
                else
                {
                    // data or label already in the pool
                    array = sharedDataArrays[argName];

                    if (Util.Prod(array.GetShape()) >= Util.Prod(argShapes[idx]))
                    {
                        // big enough: share the memory under the new shape
                        Util.Assert(argTypes[idx] == array.GetDtype());
                        array = array.Reshape(new Shape(argShapes[idx]));
                    }
                    else
                    {
                        logger.Warn($"bucketing: data \"{argName}\" has a shape {new Shape(argShapes[idx])}" +
                                    ", which is larger than already allocated " +
                                    $"shape {array.GetShape()}" +
                                    ". Need to re-allocate. Consider putting " +
                                    "default_bucket_key to be the bucket taking the largest " +
                                    "input for better memory sharing.");
                        array = NdArray.Zeros(new Shape(argShapes[idx]), ctx, dtype: argTypes[idx]);

                        // replace existing shared array because the new one is bigger
                        sharedDataArrays[argName] = array;
                    }
                }

                argArrays.Add(array);
            }

            IList <NdArray> auxArrays;

            if (baseExec != null)
            {
                // Borrowed auxiliary states must agree with the freshly inferred shapes and types.
                for (var j = 0; j < baseExec.AuxArrays.Count; j++)
                {
                    var borrowed = baseExec.AuxArrays[j];
                    Util.Assert((new Shape(auxShapes[j])) == borrowed.GetShape());
                    Util.Assert(auxTypes[j] == borrowed.GetDtype());
                }

                auxArrays = baseExec.AuxArrays;
            }
            else
            {
                auxArrays = auxShapes.Zip(auxTypes, (dims, dtype) => NdArray.Zeros(new Shape(dims), ctx, dtype)).ToList();
            }

            return sym.Bind(ctx, argArrays, gradArrays, gradReq, auxArrays, null, baseExec);
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Infers argument and auxiliary shapes/types for this symbol, allocates zero-filled argument,
        /// gradient and auxiliary arrays, and binds them into a new <see cref="Executor"/>.
        /// </summary>
        /// <param name="context">Default context used for allocation and binding.</param>
        /// <param name="inputShapes">Known shapes of input arguments, keyed by argument name.</param>
        /// <param name="gradReq">
        /// Gradient request passed to <c>Bind</c>. Unless it is <see cref="OpReqType.KNullOp"/>, a gradient
        /// array is allocated for every argument.
        /// </param>
        /// <param name="typeDict">Known argument types; when <c>null</c>, every argument from <c>ListArguments()</c> maps to <c>float</c>.</param>
        /// <param name="group2Ctx">
        /// Optional map from context-group name to context. Each argument / auxiliary state whose
        /// <c>&lt;name&gt;_ctx_group</c> attribute names a group present here is allocated on that context.
        /// Everything else is allocated on <paramref name="context"/>.
        /// </param>
        /// <returns>The bound executor.</returns>
        /// <exception cref="InvalidOperationException">Shape or type inference produced no argument entries.</exception>
        public Executor SimpleBind(
            Context context,
            Dictionary <string, uint[]> inputShapes,
            OpReqType gradReq,
            Dictionary <string, Type> typeDict     = null,
            Dictionary <string, Context> group2Ctx = null
            )
        {
            var listArguments = ListArguments();

            if (typeDict == null)
            {
                typeDict = listArguments.ToDictionary(k => k, v => typeof(float));
            }

            var argShapes = new List <uint[]>();
            var auxShapes = new List <uint[]>();
            var outShapes = new List <uint[]>();

            InferShape(inputShapes, argShapes, outShapes, auxShapes);

            var argTypes = new List <Type>();
            var auxTypes = new List <Type>();
            var outTypes = new List <Type>();

            InferType(typeDict, argTypes, auxTypes, outTypes);

            if (argShapes.Count == 0 || argTypes.Count == 0)
            {
                // Specific exception type (still caught by existing catch (Exception) handlers).
                throw new InvalidOperationException("Input node is not complete");
            }

            List <Context> argCtx;
            List <Context> auxCtx;

            if (group2Ctx != null)
            {
                // Map every "*ctx_group" attribute to its device. Attribute keys are identifiers,
                // so compare ordinally rather than with the current culture.
                var listattr = ListAttr(true);
                var attrDict = listattr.Where(w => w.Key.EndsWith("ctx_group", StringComparison.Ordinal))
                               .ToDictionary(k => k.Key, v => group2Ctx.GetValueOrDefault(v.Value, context));
                argCtx = listArguments
                         .Select(name => attrDict.GetValueOrDefault(name + "_ctx_group", context)).ToList();
                auxCtx = ListAuxiliaryStates()
                         .Select(name => attrDict.GetValueOrDefault(name + "_ctx_group", context)).ToList();
            }
            else
            {
                argCtx = Enumerable.Repeat(context, argShapes.Count).ToList();
                auxCtx = Enumerable.Repeat(context, auxShapes.Count).ToList();
            }

            //alloc space
            var argNdarrays = argTypes
                              .Zip(argCtx, argShapes,
                                   (dtype, dev, shape) =>
                                   NdArray.Zeros(new Shape(shape), dev, dtype))
                              .ToList();
            Dictionary <string, NdArray> gradNdarrays = new Dictionary <string, NdArray>();

            if (gradReq != OpReqType.KNullOp)
            {
                // One gradient buffer per argument, with the same shape, context and type as the argument.
                for (int i = 0; i < listArguments.Count; i++)
                {
                    var name  = listArguments[i];
                    var shape = argShapes[i];
                    var dev   = argCtx[i];
                    var dtype = argTypes[i];

                    gradNdarrays[name] = NdArray.Zeros(new Shape(shape), dev, dtype: dtype);
                }
            }

            var auxNdarrays = auxTypes
                              .Zip(auxCtx, auxShapes,
                                   (dtype, dev, shape) =>
                                   NdArray.Zeros(new Shape(shape), dev, dtype))
                              .ToList();

            var executor = Bind(context,
                                argNdarrays,
                                gradNdarrays,
                                gradReq,
                                auxNdarrays,
                                groupToCtx: group2Ctx);

            return(executor);
        }