/// <summary>
/// Invokes this operator imperatively through <c>MXImperativeInvoke</c>, using the
/// collected input arrays and the key/value parameters in <c>_params</c>.
/// </summary>
/// <param name="outputs">
/// Output arrays. If non-empty, their handles are pinned and passed to the native call
/// so it writes into them. If empty, the handles returned by the native call are wrapped
/// in new <see cref="NDArray"/> instances and appended to this list.
/// </param>
/// <returns>The same <paramref name="outputs"/> list instance.</returns>
public NDArrayList Invoke(NDArrayList outputs)
{
    var paramKeys = new List<string>();
    var paramValues = new List<string>();
    foreach (var data in _params)
    {
        paramKeys.Add(data.Key);
        paramValues.Add(data.Value);
    }

    var numInputs = _inputNDArrays.Count;
    var numOutputs = outputs.Count;
    var callerSuppliedOutputs = outputs.Count > 0;
    var outputHandles = outputs.Select(s => s.Handle).ToArray();
    var outputsReceiver = IntPtr.Zero;
    GCHandle? gcHandle = null;
    try
    {
        if (callerSuppliedOutputs)
        {
            // Pin the caller's handle array so its address stays valid during the native call.
            gcHandle = GCHandle.Alloc(outputHandles, GCHandleType.Pinned);
            outputsReceiver = gcHandle.Value.AddrOfPinnedObject();
        }

        CheckCall(_LIB.MXImperativeInvoke(_handle, numInputs, _inputNDArrays.ToArray(), ref numOutputs,
            ref outputsReceiver, paramKeys.Count, paramKeys.ToArray(), paramValues.ToArray()));

        if (callerSuppliedOutputs)
        {
            return outputs;
        }

        // The native side reported numOutputs handles at outputsReceiver; wrap each one.
        outputHandles = new NDArrayHandle[numOutputs];
        Marshal.Copy(outputsReceiver, outputHandles, 0, numOutputs);
        foreach (var outputHandle in outputHandles)
        {
            outputs.Add(new NDArray(outputHandle));
        }
    }
    finally
    {
        // Single release point. GCHandle? is a nullable struct, so Free() acts on a copy
        // and never clears the stored handle; calling it more than once would free the
        // same handle twice.
        gcHandle?.Free();
    }

    return outputs;
}
/// <summary>
/// Rescales <paramref name="arrays"/> so that the L2 norm of all their elements taken
/// together is at most <paramref name="max_norm"/>.
/// </summary>
/// <param name="arrays">
/// Arrays to clip. Each element is replaced by itself multiplied by the clipping scale.
/// </param>
/// <param name="max_norm">Upper bound for the global L2 norm.</param>
/// <param name="check_isfinite">
/// When true, reads the total norm back as a scalar and logs a warning if it is NaN or infinite.
/// </param>
/// <returns>The global L2 norm computed before clipping, on the first array's context.</returns>
/// <exception cref="ArgumentException"><paramref name="arrays"/> is empty.</exception>
public static NDArray clip_global_norm(NDArrayList arrays, float max_norm, bool check_isfinite = true)
{
    // Squared L2 norm of a single array: dot(x, x) for dense storage, Norm()^2 otherwise.
    NDArray SquaredNorm(NDArray array)
    {
        if (array.SType == StorageStype.Default)
        {
            var x = array.Reshape(-1);
            return nd.Dot(x, x);
        }

        return array.Norm().Square();
    }

    if (arrays.Length == 0)
    {
        throw new ArgumentException("arrays.Length == 0");
    }

    var ctx = arrays[0].Context;

    // Sum the per-array squared norms, not the raw arrays, to get the global squared norm.
    var total_norm = nd.AddN(arrays.Select(x => SquaredNorm(x).AsInContext(ctx)).ToArray());
    total_norm = total_norm.Sqrt();

    if (check_isfinite)
    {
        var value = total_norm.AsScalar<float>();
        // Check NaN as well as infinity, as the warning text promises.
        if (float.IsNaN(value) || float.IsInfinity(value))
        {
            Logger.Warning("nan or inf is detected. " + "Clipping results will be undefined.");
        }
    }

    // scale = min(max_norm / (total_norm + eps), 1): arrays are only ever shrunk, never grown.
    var scale = max_norm / (total_norm + 1e-8f);
    scale = nd.Min(nd.Concat(new NDArrayList(scale, nd.Ones(new Shape(1), ctx)), 0));

    for (var i = 0; i < arrays.Length; i++)
    {
        arrays[i] *= scale.AsInContext(arrays[i].Context);
    }

    return total_norm;
}