Exemplo n.º 1
0
        /// <summary>
        /// Infers shapes and element types for this symbol, allocates zero-filled argument,
        /// gradient and auxiliary-state arrays, and binds them into an <see cref="Executor"/>.
        /// </summary>
        /// <param name="context">Default device for every allocated array (also the fallback for unmapped context groups).</param>
        /// <param name="inputShapes">Known input shapes keyed by argument name; forwarded to <c>InferShape</c>.</param>
        /// <param name="gradReq">Gradient request used for the whole bind. When <c>OpReqType.KNullOp</c>, no gradient arrays are allocated.</param>
        /// <param name="typeDict">Element type per argument; when null every argument defaults to <see cref="float"/>.</param>
        /// <param name="group2Ctx">Optional map from context-group name (the value of a <c>*ctx_group</c> attribute) to a device.</param>
        /// <returns>The executor returned by <c>Bind</c>.</returns>
        /// <exception cref="InvalidOperationException">Shape or type inference produced no argument information.</exception>
        public Executor SimpleBind(
            Context context,
            Dictionary <string, uint[]> inputShapes,
            OpReqType gradReq,
            Dictionary <string, Type> typeDict     = null,
            Dictionary <string, Context> group2Ctx = null
            )
        {
            var listArguments = ListArguments();

            // Without explicit types, every argument is treated as float.
            if (typeDict == null)
            {
                typeDict = listArguments.ToDictionary(k => k, v => typeof(float));
            }

            var argShapes = new List <uint[]>();
            var auxShapes = new List <uint[]>();
            var outShapes = new List <uint[]>();
            InferShape(inputShapes, argShapes, outShapes, auxShapes);

            var argTypes = new List <Type>();
            var auxTypes = new List <Type>();
            var outTypes = new List <Type>();
            InferType(typeDict, argTypes, auxTypes, outTypes);

            if (argShapes.Count == 0 || argTypes.Count == 0)
            {
                // More specific than System.Exception; existing catch (Exception) handlers still match.
                throw new InvalidOperationException("Input node is not complete");
            }

            List <Context> argCtx;
            List <Context> auxCtx;

            if (group2Ctx != null)
            {
                // Resolve every "*ctx_group" attribute to a device (falling back to `context`), then
                // place each argument / aux state on the device named by its "<name>_ctx_group" attribute.
                var attrDict = ListAttr(true)
                               .Where(w => w.Key.EndsWith("ctx_group", StringComparison.Ordinal))
                               .ToDictionary(k => k.Key, v => group2Ctx.GetValueOrDefault(v.Value, context));
                argCtx = listArguments
                         .Select(name => attrDict.GetValueOrDefault(name + "_ctx_group", context)).ToList();
                auxCtx = ListAuxiliaryStates()
                         .Select(name => attrDict.GetValueOrDefault(name + "_ctx_group", context)).ToList();
            }
            else
            {
                // Everything lives on the default device.
                argCtx = Enumerable.Repeat(context, argShapes.Count).ToList();
                auxCtx = Enumerable.Repeat(context, auxShapes.Count).ToList();
            }

            // Allocate zero-filled argument arrays on their assigned devices.
            var argNdarrays = argTypes
                              .Zip(argCtx, argShapes,
                                   (dtype, dev, shape) =>
                                   NdArray.Zeros(new Shape(shape), dev, dtype))
                              .ToList();

            // Gradient buffers are only needed when gradients are actually requested.
            var gradNdarrays = new Dictionary <string, NdArray>();
            if (gradReq != OpReqType.KNullOp)
            {
                for (int i = 0; i < listArguments.Count; i++)
                {
                    var name  = listArguments[i];
                    var shape = argShapes[i];
                    var dev   = argCtx[i];
                    var dtype = argTypes[i];

                    gradNdarrays[name] = NdArray.Zeros(new Shape(shape), dev, dtype: dtype);
                }
            }

            var auxNdarrays = auxTypes
                              .Zip(auxCtx, auxShapes,
                                   (dtype, dev, shape) =>
                                   NdArray.Zeros(new Shape(shape), dev, dtype))
                              .ToList();

            return Bind(context,
                        argNdarrays,
                        gradNdarrays,
                        gradReq,
                        auxNdarrays,
                        groupToCtx: group2Ctx);
        }
Exemplo n.º 2
0
        /// <summary>
        /// Pushes a single <paramref name="val"/> for <paramref name="key"/> through <c>MXKVStorePush</c>;
        /// the native status code is validated by <c>Util.CallCheck</c>.
        /// </summary>
        /// <param name="key">Key to push.</param>
        /// <param name="val">Array whose native handle is pushed.</param>
        /// <param name="priority">Forwarded unchanged to the native call.</param>
        public void Push(int key, NdArray val, int priority)
        {
            // One-element key/value arrays: this overload always pushes exactly one pair.
            var handles = new NDArrayHandle[] { val.Handle };
            var keys    = new int[] { key };

            var status = NativeMethods.MXKVStorePush(_blobPtr.Handle, 1, keys, handles, priority);
            Util.CallCheck(status);
        }
Exemplo n.º 3
0
        /// <summary>
        /// Fills the argument, gradient, gradient-request and auxiliary-state lists needed to bind an
        /// executor. Caller-supplied arrays are reused where present; missing ones are allocated.
        /// </summary>
        /// <param name="context">Device for any newly allocated array.</param>
        /// <param name="argArrays">Output: one array per inferred argument, in <c>ListArguments()</c> order.</param>
        /// <param name="gradArrays">Output: one gradient array per inferred argument.</param>
        /// <param name="gradReqs">Output: one gradient request per inferred argument (default <c>KWriteTo</c>).</param>
        /// <param name="auxArrays">Output: one array per inferred auxiliary state, in <c>ListAuxiliaryStates()</c> order.</param>
        /// <param name="argsMap">Known argument arrays; their shapes seed shape inference.</param>
        /// <param name="argGradStore">Known gradient arrays keyed by argument name.</param>
        /// <param name="gradReqType">Explicit gradient requests keyed by argument name.</param>
        /// <param name="auxMap">Known auxiliary-state arrays keyed by name.</param>
        void InferExecutorArrays(
            Context context, IList <NdArray> argArrays,
            IList <NdArray> gradArrays, IList <OpReqType> gradReqs,
            IList <NdArray> auxArrays,
            Dictionary <string, NdArray> argsMap,
            Dictionary <string, NdArray> argGradStore,
            Dictionary <string, OpReqType> gradReqType,
            Dictionary <string, NdArray> auxMap)
        {
            var argNameList = ListArguments();
            var inShapes    = new List <uint[]>();
            var auxShapes   = new List <uint[]>();
            var outShapes   = new List <uint[]>();
            var argShapes   = new Dictionary <string, uint[]>();

            // Seed shape inference with the shapes of the arrays the caller already provided.
            foreach (var argName in argNameList)
            {
                NdArray known;
                if (argsMap.TryGetValue(argName, out known))
                {
                    argShapes[argName] = known.GetShape();
                }
            }

            InferShape(argShapes, inShapes, outShapes, auxShapes);

            for (int i = 0; i < inShapes.Count; ++i)
            {
                var shape   = inShapes[i];
                var argName = argNameList[i];

                NdArray arg;
                if (argsMap.TryGetValue(argName, out arg))
                {
                    argArrays.Add(arg);
                }
                else
                {
                    // Missing arguments are allocated and filled via SampleGaussian(0, 1, ...).
                    var temp = new NdArray(shape, context, false);
                    argArrays.Add(temp);
                    NdArray.SampleGaussian(0, 1, temp);
                }

                NdArray grad;
                gradArrays.Add(argGradStore.TryGetValue(argName, out grad)
                                   ? grad
                                   : new NdArray(shape, context, false));

                OpReqType req;
                gradReqs.Add(gradReqType.TryGetValue(argName, out req)
                                 ? req
                                 : OpReqType.KWriteTo);
            }

            var auxNameList = ListAuxiliaryStates();

            for (int i = 0; i < auxShapes.Count; ++i)
            {
                var shape   = auxShapes[i];
                var auxName = auxNameList[i];

                NdArray aux;
                if (auxMap.TryGetValue(auxName, out aux))
                {
                    auxArrays.Add(aux);
                }
                else
                {
                    // Missing auxiliary states are allocated and filled via SampleGaussian(0, 1, ...).
                    var temp = new NdArray(shape, context, false);
                    auxArrays.Add(temp);
                    NdArray.SampleGaussian(0, 1, temp);
                }
            }
        }
Exemplo n.º 4
0
        /// <summary>
        /// Initializes a single <paramref name="key"/> with <paramref name="val"/> through <c>MXKVStoreInit</c>;
        /// the native status code is validated by <c>Util.CallCheck</c>.
        /// </summary>
        /// <param name="key">Key to initialize.</param>
        /// <param name="val">Array whose native handle supplies the initial value.</param>
        public void Init(int key, NdArray val)
        {
            // One-element key/value arrays: this overload always initializes exactly one pair.
            var handles = new NDArrayHandle[] { val.Handle };
            var keys    = new int[] { key };

            var status = NativeMethods.MXKVStoreInit(_blobPtr.Handle, 1, keys, handles);
            Util.CallCheck(status);
        }
Exemplo n.º 5
0
        /// <summary>
        /// Pulls the value stored under <paramref name="key"/> into <paramref name="out"/> through
        /// <c>MXKVStorePull</c>; the native status code is validated by <c>Util.CallCheck</c>.
        /// </summary>
        /// <param name="key">Key to pull.</param>
        /// <param name="out">Destination array whose native handle is passed to the native call.</param>
        /// <param name="priority">Forwarded unchanged to the native call.</param>
        public void Pull(int key, NdArray @out, int priority)
        {
            // One-element key/destination arrays: this overload always pulls exactly one key.
            var destinations = new NDArrayHandle[] { @out.Handle };
            var keys         = new int[] { key };

            var status = NativeMethods.MXKVStorePull(_blobPtr.Handle, 1, keys, destinations, priority);
            Util.CallCheck(status);
        }
Exemplo n.º 6
0
        /// <summary>
        /// Allocates zero-filled parameter and auxiliary-state arrays for the wrapped symbol, then fills
        /// them either from the currently held <c>ArgParams</c>/<c>AuxParams</c> or via the initializer.
        /// The new dictionaries replace <c>ArgParams</c> and <c>AuxParams</c>.
        /// </summary>
        /// <param name="inputShapes">Shapes of the data inputs; arguments not listed here are treated as parameters.</param>
        /// <param name="overwrite">
        /// When false, values already present in <c>ArgParams</c>/<c>AuxParams</c> are copied into the new arrays;
        /// when true, every array is (re)initialized by the initializer.
        /// </param>
        /// <returns>(all argument names, parameter names, auxiliary-state names).</returns>
        private Tuple <List <string>, List <string>, List <string> > InitParams(
            Dictionary <string, Shape> inputShapes,
            bool overwrite = false)
        {
            List <uint[]> argShapes = new List <uint[]>();
            List <uint[]> auxShapes = new List <uint[]>();

            this._symbol.InferShape(inputShapes, argShapes, null, auxShapes);

            var argNames = this._symbol.ListArguments();

            // Parameters are the arguments that are not data inputs (order of argNames is kept).
            // Materialized once: a deferred Except would be re-evaluated for every membership test.
            var paramNames   = argNames.Except(inputShapes.Keys).ToList();
            var paramNameSet = new HashSet <string>(paramNames);

            var auxNames = this._symbol.ListAuxiliaryStates();

            var argParams = argNames.Zip(argShapes, Tuple.Create)
                                    .Where(w => paramNameSet.Contains(w.Item1))
                                    .ToDictionary(k => k.Item1, s => NdArray.Zeros(new Shape(s.Item2)));
            var auxParams = auxNames.Zip(auxShapes, Tuple.Create)
                                    .ToDictionary(k => k.Item1, s => NdArray.Zeros(new Shape(s.Item2)));

            foreach (var kv in argParams)
            {
                if (this.ArgParams != null && this.ArgParams.ContainsKey(kv.Key) && !overwrite)
                {
                    this.ArgParams[kv.Key].CopyTo(kv.Value);
                }
                else
                {
                    this._initializer.Call(kv.Key, kv.Value);
                }
            }

            foreach (var kv in auxParams)
            {
                if (this.AuxParams != null && this.AuxParams.ContainsKey(kv.Key) && !overwrite)
                {
                    this.AuxParams[kv.Key].CopyTo(kv.Value);
                }
                else
                {
                    // Initialize the auxiliary array itself (previously indexed argParams by an aux name by mistake).
                    this._initializer.Call(kv.Key, kv.Value);
                }
            }

            this.ArgParams = argParams;
            this.AuxParams = auxParams;

            return Tuple.Create(argNames.ToList(), paramNames, auxNames.ToList());
        }
Exemplo n.º 7
0
 /// <summary>
 /// Placeholder statistic: currently returns <paramref name="x"/> unchanged.
 /// </summary>
 // TODO  return ndarray.norm(x)/Math.Sqrt(x.size());
 NdArray asum_stat(NdArray x) => x;
Exemplo n.º 8
0
        /// <summary>
        /// Binds <paramref name="sym"/> into an executor. Data/label arrays are borrowed from (or registered in)
        /// <paramref name="sharedDataArrays"/>; parameter arrays and gradients are borrowed from
        /// <paramref name="baseExec"/> when given, otherwise freshly zero-allocated.
        /// </summary>
        /// <param name="sym">Symbol to bind.</param>
        /// <param name="ctx">Device for every newly allocated array.</param>
        /// <param name="inputShapes">Shapes of the data inputs, keyed by name.</param>
        /// <param name="paramNames">Names of the model parameters; every other argument is treated as data/label.</param>
        /// <param name="needGradInput">When true, every non-input argument gets a <c>KWriteTo</c> gradient request.</param>
        /// <param name="baseExec">Optional executor whose parameter, gradient and auxiliary arrays are reused.</param>
        /// <param name="sharedDataArrays">Optional pool of data arrays shared across binds; updated in place.</param>
        /// <param name="inputTypes">Input dtypes; when null every input defaults to <c>Dtype.Float32</c>.</param>
        /// <param name="logger">Logger for reallocation warnings; defaults to <c>LogManager.GetLogger("")</c>.</param>
        /// <returns>The executor produced by <c>sym.Bind</c>.</returns>
        private Executor BindExec(Symbol sym, Context ctx, Dictionary <string, uint[]> inputShapes,
                                  IList <string> paramNames, bool needGradInput, Executor baseExec,
                                  Dictionary <string, NdArray> sharedDataArrays, Dictionary <string, Dtype> inputTypes = null, ILog logger = null)
        {
            logger = logger ?? LogManager.GetLogger("");

            // Shape inference.
            var argShapes = new List <uint[]>();
            var auxShapes = new List <uint[]>();
            var outShapes = new List <uint[]>();
            sym.InferShape(inputShapes, argShapes, outShapes, auxShapes);

            // Type inference; inputs default to float32.
            var argTypes = new List <Dtype>();
            var auxTypes = new List <Dtype>();
            var outTypes = new List <Dtype>();
            if (inputTypes == null)
            {
                inputTypes = inputShapes.ToDictionary(entry => entry.Key, entry => Dtype.Float32);
            }
            sym.InferType(inputTypes, argTypes, auxTypes, outTypes);

            Dictionary <string, NdArray> gradArrays = null;
            if (needGradInput)
            {
                gradArrays = new Dictionary <string, NdArray>();
            }

            var argNames = sym.ListArguments();

            // Gradients are requested for every non-input argument, and only when needGradInput is set.
            var needGrad = needGradInput
                           ? new HashSet <string>(argNames.Except(inputShapes.Keys))
                           : new HashSet <string>();

            var gradReq = argNames.ToDictionary(
                argName => argName,
                argName => needGrad.Contains(argName) ? OpReqType.KWriteTo : OpReqType.KNullOp);

            var boundArgs = new List <NdArray>();

            for (int idx = 0; idx < argNames.Count; idx++)
            {
                var     name = argNames[idx];
                NdArray argArr;

                if (paramNames.Contains(name))
                {
                    // Model parameter: reuse the base executor's storage when available.
                    var wantsGrad = needGradInput && needGrad.Contains(name);

                    if (baseExec != null)
                    {
                        argArr = baseExec.ArgDict[name];
                        Util.Assert(argArr.GetShape() == new Shape(argShapes[idx]));
                        Util.Assert(argArr.GetDtype() == argTypes[idx]);
                        if (wantsGrad)
                        {
                            gradArrays[name] = baseExec.GradDict[name];
                        }
                    }
                    else
                    {
                        argArr = NdArray.Zeros(new Shape(argShapes[idx]), ctx, dtype: argTypes[idx]);
                        if (wantsGrad)
                        {
                            gradArrays[name] = NdArray.Zeros(new Shape(argShapes[idx]), ctx, dtype: argTypes[idx]);
                        }
                    }
                }
                else
                {
                    // Data or label: borrow from the shared pool when it already holds this name.
                    NdArray shared;
                    if (sharedDataArrays != null && sharedDataArrays.TryGetValue(name, out shared))
                    {
                        if (Util.Prod(shared.GetShape()) < Util.Prod(argShapes[idx]))
                        {
                            // Pooled array is too small: allocate a bigger one and make it the shared copy.
                            logger.Warn($"bucketing: data \"{name}\" has a shape {new Shape(argShapes[idx])}" +
                                        ", which is larger than already allocated " +
                                        $"shape {shared.GetShape()}" +
                                        ". Need to re-allocate. Consider putting " +
                                        "default_bucket_key to be the bucket taking the largest " +
                                        "input for better memory sharing.");
                            argArr = NdArray.Zeros(new Shape(argShapes[idx]), ctx, dtype: argTypes[idx]);
                            sharedDataArrays[name] = argArr;
                        }
                        else
                        {
                            Util.Assert(argTypes[idx] == shared.GetDtype());
                            argArr = shared.Reshape(new Shape(argShapes[idx]));
                        }
                    }
                    else
                    {
                        argArr = NdArray.Zeros(new Shape(argShapes[idx]), ctx, dtype: argTypes[idx]);
                        if (sharedDataArrays != null)
                        {
                            sharedDataArrays[name] = argArr;
                        }
                    }
                }

                boundArgs.Add(argArr);
            }

            IList <NdArray> auxArrays;

            if (baseExec != null)
            {
                // Reuse the base executor's auxiliary states after checking shape and dtype.
                for (int j = 0; j < baseExec.AuxArrays.Count; j++)
                {
                    var existing = baseExec.AuxArrays[j];
                    Util.Assert(new Shape(auxShapes[j]) == existing.GetShape());
                    Util.Assert(auxTypes[j] == existing.GetDtype());
                }
                auxArrays = baseExec.AuxArrays;
            }
            else
            {
                auxArrays = auxShapes.Zip(auxTypes, (shape, dtype) => NdArray.Zeros(new Shape(shape), ctx, dtype)).ToList();
            }

            return sym.Bind(ctx, boundArgs, gradArrays,
                            gradReq, auxArrays, null, baseExec);
        }