Example #1
0
        /// <summary>
        /// Populates the executor input lists for this symbol: one argument array, one gradient
        /// array and one gradient request per entry returned by <c>ListArguments()</c> (in the
        /// order of the shapes produced by <c>InferShape</c>), plus one auxiliary array per
        /// auxiliary state. Caller-supplied arrays are reused where present; missing ones are
        /// allocated from the inferred shapes.
        /// </summary>
        /// <param name="context">
        /// Must not be null. NOTE(review): this parameter is validated but currently not used —
        /// new arrays are created with <c>new NDArray(shape, false)</c> without a context;
        /// confirm whether they should be allocated on <paramref name="context"/>.
        /// </param>
        /// <param name="argArrays">Output list; receives one array per argument.</param>
        /// <param name="gradArrays">Output list; receives one gradient array per argument.</param>
        /// <param name="gradReqs">Output list; receives one gradient request per argument.</param>
        /// <param name="auxArrays">Output list; receives one array per auxiliary state.</param>
        /// <param name="argsMap">
        /// Existing argument arrays by name. Non-null entries supply known shapes to
        /// <c>InferShape</c> and are reused; for names whose lookup yields null a new array is
        /// allocated and filled with uniform random values between 0 and 1.
        /// </param>
        /// <param name="argGradStore">
        /// Existing gradient arrays by name; for names whose lookup yields null a new
        /// (uninitialised) array of the inferred shape is allocated.
        /// </param>
        /// <param name="gradReqType">
        /// Explicit gradient requests by argument name. Arguments without an entry get
        /// <c>OpGradReq.Null</c> when their name ends with "data" or "label", otherwise
        /// <c>OpGradReq.Write</c>.
        /// </param>
        /// <param name="auxMap">
        /// Existing auxiliary-state arrays by name; for names whose lookup yields null a new
        /// array is allocated and filled with uniform random values between 0 and 1.
        /// </param>
        /// <exception cref="ArgumentNullException">Any parameter is null.</exception>
        public void InferExecutorArrays(Context context,
                                        NDArrayList argArrays,
                                        NDArrayList gradArrays,
                                        IList <OpGradReq> gradReqs,
                                        NDArrayList auxArrays,
                                        NDArrayDict argsMap,
                                        NDArrayDict argGradStore,
                                        IDictionary <string, OpGradReq> gradReqType,
                                        NDArrayDict auxMap)
        {
            if (context == null)
            {
                throw new ArgumentNullException(nameof(context));
            }
            if (argArrays == null)
            {
                throw new ArgumentNullException(nameof(argArrays));
            }
            if (gradArrays == null)
            {
                throw new ArgumentNullException(nameof(gradArrays));
            }
            if (gradReqs == null)
            {
                throw new ArgumentNullException(nameof(gradReqs));
            }
            if (auxArrays == null)
            {
                throw new ArgumentNullException(nameof(auxArrays));
            }
            if (argsMap == null)
            {
                throw new ArgumentNullException(nameof(argsMap));
            }
            if (argGradStore == null)
            {
                throw new ArgumentNullException(nameof(argGradStore));
            }
            if (gradReqType == null)
            {
                throw new ArgumentNullException(nameof(gradReqType));
            }
            if (auxMap == null)
            {
                throw new ArgumentNullException(nameof(auxMap));
            }

            // Presumably throws if this instance has already been disposed (defined elsewhere).
            ThrowIfDisposed();

            var argNameList = ListArguments();
            var argShapes   = new Dictionary<string, Shape>();

            // Only arguments the caller already supplied contribute known shapes;
            // InferShape derives the remaining ones.
            foreach (var argName in argNameList)
            {
                var known = argsMap[argName];
                if (known != null)
                {
                    argShapes[argName] = known.Shape;
                }
            }

            // Output shapes are not needed to build the executor inputs.
            var (inShapes, auxShapes, _) = InferShape(argShapes);

            for (var i = 0; i < inShapes.Length; ++i)
            {
                var shape   = inShapes[i];
                var argName = argNameList[i];

                var existingArg = argsMap[argName];
                argArrays.Add(existingArg != null ? existingArg : CreateRandomArray(shape));

                var existingGrad = argGradStore[argName];
                gradArrays.Add(existingGrad != null ? existingGrad : new NDArray(shape, false));

                if (gradReqType.TryGetValue(argName, out var explicitReq))
                {
                    gradReqs.Add(explicitReq);
                }
                // EndsWith (ordinal) instead of LastIndexOf(...) == Length - k: the latter also
                // matched every name shorter than the suffix by one character (-1 == -1),
                // e.g. a 4-character name such as "bias" was treated like a "label" input.
                else if (argName.EndsWith("data", StringComparison.Ordinal) ||
                         argName.EndsWith("label", StringComparison.Ordinal))
                {
                    // Input and label arguments get no gradient by default.
                    gradReqs.Add(OpGradReq.Null);
                }
                else
                {
                    gradReqs.Add(OpGradReq.Write);
                }
            }

            var auxNameList = ListAuxiliaryStates();

            for (var i = 0; i < auxShapes.Length; ++i)
            {
                var auxName     = auxNameList[i];
                var existingAux = auxMap[auxName];
                auxArrays.Add(existingAux != null ? existingAux : CreateRandomArray(auxShapes[i]));
            }

            // Allocates an array of the given shape and fills it with uniform random values
            // between 0 and 1; used for argument/auxiliary arrays the caller did not provide.
            NDArray CreateRandomArray(Shape shape)
            {
                var array = new NDArray(shape, false);
                nd.Random.Uniform(0, 1, array.Shape).CopyTo(array);
                return array;
            }
        }