コード例 #1
0
ファイル: Executor.cs プロジェクト: AvenSun/MxNet.Sharp
        /// <summary>
        /// Binds <paramref name="symbol"/> through the native <c>MXExecutorBindEX</c> call,
        /// stores the resulting executor handle in <c>Handle</c>, and fills <c>Outputs</c>
        /// with the arrays reported by <c>MXExecutorOutputs</c>.
        /// </summary>
        /// <param name="symbol">Symbol whose native handle is bound. Must not be null.</param>
        /// <param name="context">Context whose device type and device id are passed to the bind call. Must not be null.</param>
        /// <param name="argmentArrays">Argument arrays; kept in <c>ArgmentArrays</c>. Must not be null.</param>
        /// <param name="gradientArrays">Gradient arrays; kept in <c>GradientArrays</c>. Must not be null and must contain one entry per argument array.</param>
        /// <param name="gradReqs">Gradient request values, passed to native code as unsigned integers. Must not be null and must contain one entry per argument array.</param>
        /// <param name="auxiliaryArrays">Auxiliary arrays; kept in <c>AuxiliaryArrays</c>. Must not be null.</param>
        /// <param name="groupToCtx">Map from group name to context; flattened into parallel key / device type / device id arrays. Must not be null.</param>
        /// <param name="sharedExec">Optional executor whose handle is forwarded to the bind call; when null, <c>IntPtr.Zero</c> is forwarded instead.</param>
        /// <exception cref="ArgumentNullException">Thrown when any argument other than <paramref name="sharedExec"/> is null.</exception>
        /// <exception cref="ArgumentException">Thrown when <paramref name="gradientArrays"/> or <paramref name="gradReqs"/> does not have the same number of entries as <paramref name="argmentArrays"/>.</exception>
        public Executor(Symbol symbol,
                        Context context,
                        NDArrayList argmentArrays,
                        NDArrayList gradientArrays,
                        IList <OpGradReq> gradReqs,
                        NDArrayList auxiliaryArrays,
                        IDictionary <string, Context> groupToCtx,
                        Executor sharedExec)
        {
            if (symbol == null)
            {
                throw new ArgumentNullException(nameof(symbol));
            }
            if (context == null)
            {
                throw new ArgumentNullException(nameof(context));
            }
            if (argmentArrays == null)
            {
                throw new ArgumentNullException(nameof(argmentArrays));
            }
            if (gradientArrays == null)
            {
                throw new ArgumentNullException(nameof(gradientArrays));
            }
            if (gradReqs == null)
            {
                throw new ArgumentNullException(nameof(gradReqs));
            }
            if (auxiliaryArrays == null)
            {
                throw new ArgumentNullException(nameof(auxiliaryArrays));
            }
            if (groupToCtx == null)
            {
                throw new ArgumentNullException(nameof(groupToCtx));
            }

            // Flatten the managed collections into the plain arrays handed to the native call.
            var argHandles   = argmentArrays.Select(array => array.GetHandle()).ToArray();
            var gradHandles  = gradientArrays.Select(array => array.GetHandle()).ToArray();
            var auxHandles   = auxiliaryArrays.Select(array => array.GetHandle()).ToArray();
            var gradReqsUint = gradReqs.Select(s => (uint)s).ToArray();

            // The bind call below receives a single length (argHandles.Length) for the
            // argument, gradient and gradient-request arrays. Reject mismatched inputs here
            // so native code is never given a buffer shorter than the advertised length.
            if (gradHandles.Length != argHandles.Length)
            {
                throw new ArgumentException(
                    $"Expected {argHandles.Length} gradient arrays (one per argument array) but got {gradHandles.Length}.",
                    nameof(gradientArrays));
            }
            if (gradReqsUint.Length != argHandles.Length)
            {
                throw new ArgumentException(
                    $"Expected {argHandles.Length} gradient requests (one per argument array) but got {gradReqsUint.Length}.",
                    nameof(gradReqs));
            }

            ArgmentArrays   = argmentArrays;
            GradientArrays  = gradientArrays;
            AuxiliaryArrays = auxiliaryArrays;
            _Symbol         = symbol;

            // Split the group-to-context map into three parallel arrays. Walking the
            // key/value pairs directly avoids a second dictionary lookup per key.
            var groupCount = groupToCtx.Count;
            var mapKeys    = new string[groupCount];
            var devTypes   = new int[groupCount];
            var devIds     = new int[groupCount];

            var slot = 0;
            foreach (var pair in groupToCtx)
            {
                mapKeys[slot]  = pair.Key;
                devTypes[slot] = (int)pair.Value.GetDeviceType();
                devIds[slot]   = pair.Value.GetDeviceId();
                slot++;
            }

            // No shared executor is signalled to native code with a zero handle.
            var sharedExecHandle = sharedExec?.Handle ?? IntPtr.Zero;

            Logging.CHECK_EQ(NativeMethods.MXExecutorBindEX(symbol.GetHandle(),
                                                            (int)context.GetDeviceType(),
                                                            context.GetDeviceId(),
                                                            (uint)groupCount,
                                                            mapKeys,
                                                            devTypes,
                                                            devIds,
                                                            (uint)argHandles.Length,
                                                            argHandles,
                                                            gradHandles,
                                                            gradReqsUint,
                                                            (uint)auxHandles.Length,
                                                            auxHandles,
                                                            sharedExecHandle,
                                                            out var handle), NativeMethods.OK);
            Handle = handle;

            // Wrap every output handle reported by the freshly bound executor.
            Outputs = new NDArrayList();
            Logging.CHECK_EQ(NativeMethods.MXExecutorOutputs(Handle, out var outSize, out var outArray), 0);
            var outArrayArray = InteropHelper.ToPointerArray(outArray, outSize);

            for (uint i = 0; i < outSize; ++i)
            {
                Outputs.Add(new NDArray(outArrayArray[i]));
            }
        }