Пример #1
0
        /// <summary>
        /// Populates <paramref name="argsMap"/> with an array for every argument whose shape is
        /// produced by shape inference (inferred shapes are paired positionally with
        /// <c>ListArguments()</c>). Shape inference is seeded with the shapes of the arrays in
        /// <paramref name="knownArgs"/>. An argument present in <paramref name="knownArgs"/> is mapped
        /// to that same array. Any other argument gets a new array of the inferred shape on
        /// <paramref name="context"/>, filled via <c>NDArray.SampleGaussian(0, 1, ...)</c>.
        /// </summary>
        /// <param name="context">Context on which arrays for unknown arguments are created.</param>
        /// <param name="argsMap">Destination map; entries are added or overwritten, never removed.</param>
        /// <param name="knownArgs">Arrays the caller already has, keyed by argument name.</param>
        /// <exception cref="ArgumentNullException">Any parameter is <c>null</c>.</exception>
        public void InferArgsMap(Context context,
                                 IDictionary <string, NDArray> argsMap,
                                 IDictionary <string, NDArray> knownArgs)
        {
            _ = context ?? throw new ArgumentNullException(nameof(context));
            _ = argsMap ?? throw new ArgumentNullException(nameof(argsMap));
            _ = knownArgs ?? throw new ArgumentNullException(nameof(knownArgs));

            this.ThrowIfDisposed();

            var names = this.ListArguments();

            // Seed shape inference with the shapes of every array the caller already supplied.
            var providedShapes = new Dictionary <string, IList <mx_uint> >();
            foreach (var name in names)
            {
                if (!knownArgs.TryGetValue(name, out var provided))
                {
                    continue;
                }

                providedShapes[name] = provided.GetShape();
            }

            var inferredIn  = new List <List <mx_uint> >();
            var inferredAux = new List <List <mx_uint> >();
            var inferredOut = new List <List <mx_uint> >();
            this.InferShape(providedShapes, inferredIn, inferredAux, inferredOut);

            // Inferred input shapes line up index-for-index with the argument names.
            var position = 0;
            foreach (var inferred in inferredIn)
            {
                var name = names[position++];

                if (knownArgs.TryGetValue(name, out var existing))
                {
                    argsMap[name] = existing;
                    continue;
                }

                var created = new NDArray(inferred.ToArray(), context, false);
                argsMap[name] = created;
                NDArray.SampleGaussian(0, 1, created);
            }
        }
Пример #2
0
        /// <summary>
        /// Fills <paramref name="array"/> with random values whose spread is derived from the
        /// array's fan-in and fan-out.
        /// </summary>
        /// <remarks>
        /// Let <c>rf</c> be the product of all dimensions after the first two (1 for a 2-D array).
        /// Fan-in is <c>shape[1] * rf</c> and fan-out is <c>shape[0] * rf</c>.
        /// <c>Factor</c> selects the divisor: the average of the two fans, fan-in, or fan-out,
        /// and 1 for any other value. The scale is <c>sqrt(Magnitude / divisor)</c>.
        /// <c>Rand</c> then selects <c>NDArray.SampleUniform(-scale, scale, array)</c> or
        /// <c>NDArray.SampleGaussian(0, scale, array)</c>; for any other value the array is left unchanged.
        /// </remarks>
        /// <param name="name">Name of the array being initialized; used only in error messages here.</param>
        /// <param name="array">Array to fill in place; must have at least 2 dimensions.</param>
        /// <exception cref="ArgumentNullException"><paramref name="array"/> is <c>null</c>.</exception>
        /// <exception cref="ArgumentException"><paramref name="array"/> has fewer than 2 dimensions.</exception>
        public override void Operator(string name, NDArray array)
        {
            if (array == null)
            {
                throw new ArgumentNullException(nameof(array));
            }

            var shape = new Shape(array.GetShape());

            // Both fans read shape[0] and shape[1]. For a 0-D or 1-D array that index is out of range,
            // or yields a zero fan, which makes the scale sqrt(Magnitude / 0) infinite. Reject it up front.
            if (shape.Dimension < 2)
            {
                throw new ArgumentException(
                    $"Cannot initialize '{name}': fan-in/fan-out initialization requires an array with at least 2 dimensions, but it has {shape.Dimension}.",
                    nameof(array));
            }

            // Product of the trailing dimensions; the loop body does not run for a 2-D array.
            var receptiveField = 1.0f;
            for (uint i = 2; i < shape.Dimension; ++i)
            {
                receptiveField *= shape[i];
            }

            var fanIn  = shape[1] * receptiveField;
            var fanOut = shape[0] * receptiveField;
            var factor = 1.0f;

            switch (this.Factor)
            {
            case FactorType.Average:
                factor = (fanIn + fanOut) / 2.0f;
                break;

            case FactorType.In:
                factor = fanIn;
                break;

            case FactorType.Out:
                factor = fanOut;
                break;
            }

            var scale = (float)Math.Sqrt(this.Magnitude / factor);

            switch (this.Rand)
            {
            case RandType.Uniform:
                NDArray.SampleUniform(-scale, scale, array);
                break;

            case RandType.Gaussian:
                NDArray.SampleGaussian(0, scale, array);
                break;
            }
        }
Пример #3
0
        /// <summary>
        /// Builds the positional argument, gradient, gradient-request and auxiliary-state lists for
        /// this symbol. Caller-supplied arrays are used where available; the rest are created.
        /// </summary>
        /// <remarks>
        /// Argument shapes are inferred from the arrays in <paramref name="argsMap"/>. Each inferred
        /// argument shape is paired positionally with <c>ListArguments()</c>, and for each argument:
        /// <list type="bullet">
        /// <item>The argument array is taken from <paramref name="argsMap"/>, or created on
        /// <paramref name="context"/> and filled with <c>NDArray.SampleGaussian(0, 1, ...)</c>.</item>
        /// <item>The gradient array is taken from <paramref name="argGradStore"/>, or created with the
        /// same shape (not sampled).</item>
        /// <item>The gradient request is taken from <paramref name="gradReqType"/>, or defaults to
        /// <c>OpReqType.NullOp</c> for names ending in "data" or "label" and to
        /// <c>OpReqType.WriteTo</c> otherwise.</item>
        /// </list>
        /// Auxiliary states are paired positionally with <c>ListAuxiliaryStates()</c>. Each is taken from
        /// <paramref name="auxMap"/>, or created and filled with <c>SampleGaussian(0, 1, ...)</c>.
        /// All results are appended to the output lists; existing contents are kept.
        /// </remarks>
        /// <exception cref="ArgumentNullException">Any parameter is <c>null</c>.</exception>
        public void InferExecutorArrays(Context context,
                                        IList <NDArray> argArrays,
                                        IList <NDArray> gradArrays,
                                        IList <OpReqType> gradReqs,
                                        IList <NDArray> auxArrays,
                                        IDictionary <string, NDArray> argsMap,
                                        IDictionary <string, NDArray> argGradStore,
                                        IDictionary <string, OpReqType> gradReqType,
                                        IDictionary <string, NDArray> auxMap)
        {
            if (context == null)
            {
                throw new ArgumentNullException(nameof(context));
            }
            if (argArrays == null)
            {
                throw new ArgumentNullException(nameof(argArrays));
            }
            if (gradArrays == null)
            {
                throw new ArgumentNullException(nameof(gradArrays));
            }
            if (gradReqs == null)
            {
                throw new ArgumentNullException(nameof(gradReqs));
            }
            if (auxArrays == null)
            {
                throw new ArgumentNullException(nameof(auxArrays));
            }
            if (argsMap == null)
            {
                throw new ArgumentNullException(nameof(argsMap));
            }
            if (argGradStore == null)
            {
                throw new ArgumentNullException(nameof(argGradStore));
            }
            if (gradReqType == null)
            {
                throw new ArgumentNullException(nameof(gradReqType));
            }
            if (auxMap == null)
            {
                throw new ArgumentNullException(nameof(auxMap));
            }

            this.ThrowIfDisposed();

            var argNameList = this.ListArguments();
            var inShapes    = new List <List <mx_uint> >();
            var auxShapes   = new List <List <mx_uint> >();
            var outShapes   = new List <List <mx_uint> >();
            var argShapes   = new Dictionary <string, IList <mx_uint> >();

            // Seed shape inference with the shapes of the argument arrays the caller supplied.
            foreach (var argName in argNameList)
            {
                if (argsMap.TryGetValue(argName, out var value))
                {
                    argShapes[argName] = value.GetShape();
                }
            }

            this.InferShape(argShapes, inShapes, auxShapes, outShapes);

            for (var i = 0; i < inShapes.Count; ++i)
            {
                var shape   = inShapes[i];
                var argName = argNameList[i];

                if (argsMap.TryGetValue(argName, out var suppliedArg))
                {
                    argArrays.Add(suppliedArg);
                }
                else
                {
                    var createdArg = new NDArray(shape, context, false);
                    argArrays.Add(createdArg);
                    NDArray.SampleGaussian(0, 1, createdArg);
                }

                if (argGradStore.TryGetValue(argName, out var suppliedGrad))
                {
                    gradArrays.Add(suppliedGrad);
                }
                else
                {
                    gradArrays.Add(new NDArray(shape, context, false));
                }

                if (gradReqType.TryGetValue(argName, out var suppliedReq))
                {
                    gradReqs.Add(suppliedReq);
                }
                // Names ending in "data"/"label" (presumably network inputs) get no gradient by default.
                // EndsWith is used on purpose. The former `LastIndexOf(x) == Length - k` test also
                // matched short names that do not contain x at all, e.g. "bias", because -1 == 4 - 5.
                else if (argName.EndsWith("data", StringComparison.Ordinal) ||
                         argName.EndsWith("label", StringComparison.Ordinal))
                {
                    gradReqs.Add(OpReqType.NullOp);
                }
                else
                {
                    gradReqs.Add(OpReqType.WriteTo);
                }
            }

            var auxNameList = this.ListAuxiliaryStates();

            for (var i = 0; i < auxShapes.Count; ++i)
            {
                var shape   = auxShapes[i];
                var auxName = auxNameList[i];

                if (auxMap.TryGetValue(auxName, out var suppliedAux))
                {
                    auxArrays.Add(suppliedAux);
                }
                else
                {
                    var createdAux = new NDArray(shape, context, false);
                    auxArrays.Add(createdAux);
                    NDArray.SampleGaussian(0, 1, createdAux);
                }
            }
        }