Exemplo n.º 1
0
        /// <summary>
        /// Invokes the wrapped operator imperatively through <c>MXImperativeInvoke</c>, using the
        /// stored input arrays (<c>_inputNDArrays</c>) and string parameters (<c>_params</c>).
        /// </summary>
        /// <param name="outputs">
        /// Destination list. When it is non-empty, the handles of its arrays are passed to the
        /// native call as the output buffer and the list is returned unchanged. When it is empty,
        /// the handles reported by the native call are wrapped in new <see cref="NDArray"/>
        /// instances and appended to this list.
        /// </param>
        /// <returns>The same <paramref name="outputs"/> list instance.</returns>
        public NDArrayList Invoke(NDArrayList outputs)
        {
            var paramKeys   = new List<string>();
            var paramValues = new List<string>();

            foreach (var data in _params)
            {
                paramKeys.Add(data.Key);
                paramValues.Add(data.Value);
            }

            var numInputs  = _inputNDArrays.Count;
            var numOutputs = outputs.Count;

            var outputHandles   = outputs.Select(s => s.Handle).ToArray();
            var outputsReceiver = IntPtr.Zero;
            var gcHandle        = default(GCHandle);

            try
            {
                if (outputs.Count > 0)
                {
                    // Pin the handle array so its address stays valid for the duration of the
                    // native call.
                    gcHandle        = GCHandle.Alloc(outputHandles, GCHandleType.Pinned);
                    outputsReceiver = gcHandle.AddrOfPinnedObject();
                }

                CheckCall(_LIB.MXImperativeInvoke(_handle, numInputs, _inputNDArrays.ToArray(), ref numOutputs,
                                                  ref outputsReceiver,
                                                  paramKeys.Count, paramKeys.ToArray(), paramValues.ToArray()));

                if (outputs.Count > 0)
                {
                    return outputs;
                }

                // No outputs were supplied: the native call filled outputsReceiver with a pointer
                // to numOutputs handles; copy them out and wrap each one.
                outputHandles = new NDArrayHandle[numOutputs];
                Marshal.Copy(outputsReceiver, outputHandles, 0, numOutputs);

                foreach (var outputHandle in outputHandles)
                {
                    outputs.Add(new NDArray(outputHandle));
                }

                return outputs;
            }
            finally
            {
                // Single release point. Freeing the handle anywhere else as well would release
                // it twice on the early-return path.
                if (gcHandle.IsAllocated)
                {
                    gcHandle.Free();
                }
            }
        }
Exemplo n.º 2
0
        /// <summary>
        /// Rescales <paramref name="arrays"/> so that their global norm does not exceed
        /// <paramref name="max_norm"/>. The global norm is the square root of the sum of each
        /// array's squared norm. The scale factor is <c>min(max_norm / (total_norm + 1e-8), 1)</c>,
        /// so arrays are only ever scaled down.
        /// </summary>
        /// <param name="arrays">Arrays to rescale; must contain at least one element.</param>
        /// <param name="max_norm">Upper bound for the global norm.</param>
        /// <param name="check_isfinite">
        /// When true, reads the total norm back as a scalar and logs a warning if it is NaN or infinite.
        /// </param>
        /// <returns>The global norm computed before scaling, on the context of <c>arrays[0]</c>.</returns>
        /// <exception cref="ArgumentException"><paramref name="arrays"/> is empty.</exception>
        public static NDArray clip_global_norm(NDArrayList arrays, float max_norm, bool check_isfinite = true)
        {
            // Squared norm of a single array: dense arrays are flattened and dotted with
            // themselves; other storage types go through Norm().Square().
            Func<NDArray, NDArray> squaredNorm = array =>
            {
                if (array.SType == StorageStype.Default)
                {
                    var x = array.Reshape(-1);
                    return nd.Dot(x, x);
                }

                return array.Norm().Square();
            };

            if (arrays.Length == 0)
            {
                throw new ArgumentException("arrays.Length == 0", nameof(arrays));
            }

            var ctx = arrays[0].Context;

            // Sum the per-array squared norms (moved to a common context) and take the root.
            var total_norm = nd.AddN(arrays.Select(x => squaredNorm(x).AsInContext(ctx)).ToArray());
            total_norm = total_norm.Sqrt();

            if (check_isfinite)
            {
                var value = total_norm.AsScalar<float>();
                if (float.IsNaN(value) || float.IsInfinity(value))
                {
                    Logger.Warning("nan or inf is detected. " +
                                   "Clipping results will be undefined.");
                }
            }

            var scale = max_norm / (total_norm + 1e-8f);

            // Clamp the factor to at most 1 so arrays below the limit are left unchanged.
            scale = nd.Min(nd.Concat(new NDArrayList(scale, nd.Ones(new Shape(1), ctx)), 0));
            for (var i = 0; i < arrays.Length; i++)
            {
                // NOTE(review): `*=` rebinds arrays[i] to a new NDArray, assuming operator* returns
                // a fresh array. NDArray objects referenced elsewhere (e.g. parameter gradients)
                // may not be updated in place; confirm whether an in-place multiply is intended.
                arrays[i] *= scale.AsInContext(arrays[i].Context);
            }

            return total_norm;
        }