/// <summary>
/// Fills <paramref name="out"/> with samples drawn by the native <c>_sample_normal</c> operator,
/// passing <paramref name="loc"/>, <paramref name="scale"/> and the output's shape as string parameters.
/// </summary>
/// <param name="loc">Value passed as the operator's "loc" parameter.</param>
/// <param name="scale">Value passed as the operator's "scale" parameter.</param>
/// <param name="out">Destination array; its shape is also sent as the "shape" parameter.</param>
public static void Normal(float loc, float scale, NdArray @out)
{
    FunctionHandle funcHandle;
    NativeMethods.MXGetFunction("_sample_normal", out funcHandle);
    var input = IntPtr.Zero;
    var output = @out.Handle;
    var paramKeys = new string[] { "loc", "scale", "shape" };
    // The native side parses these strings, so they must not depend on the current
    // culture (e.g. "0,5" on de-DE). This matches Uniform.
    var paramVals = new string[]
    {
        loc.ToString(CultureInfo.InvariantCulture),
        scale.ToString(CultureInfo.InvariantCulture),
        @out.GetShape().ToString()
    };
    Util.CallCheck(NativeMethods.MXFuncInvokeEx(
        funcHandle,
        ref input,
        new float[0],
        ref output,
        paramKeys.Length,
        paramKeys,
        paramVals
    ));
}
/// <summary>
/// Invokes the native <c>_sample_uniform</c> operator, writing its samples into
/// <paramref name="out"/>. The bounds are formatted with the invariant culture so the
/// native parser always sees '.' as the decimal separator.
/// </summary>
/// <param name="low">Sent as the "low" parameter.</param>
/// <param name="high">Sent as the "high" parameter.</param>
/// <param name="out">Output array; its shape string is sent as "shape".</param>
public static void Uniform(float low, float high, NdArray @out)
{
    FunctionHandle uniformFn;
    NativeMethods.MXGetFunction("_sample_uniform", out uniformFn);

    // The native call takes the handles by reference, so they need local storage.
    var noInput = IntPtr.Zero;
    var target = @out.Handle;

    var lowText = low.ToString(CultureInfo.InvariantCulture);
    var highText = high.ToString(CultureInfo.InvariantCulture);
    var shapeText = @out.GetShape().ToString();

    string[] keys = { "low", "high", "shape" };
    string[] values = { lowText, highText, shapeText };

    var status = NativeMethods.MXFuncInvokeEx(
        uniformFn,
        ref noInput,
        new float[0],
        ref target,
        keys.Length,
        keys,
        values);
    Util.CallCheck(status);
}
/// <summary>
/// Creates the per-weight optimizer state. The result is <c>null</c> when
/// <c>_momentum</c> is (numerically) zero. Otherwise it is a zero-filled array with
/// the weight's shape, context and dtype (presumably the momentum buffer — confirm
/// against the update rule).
/// </summary>
/// <param name="index">Weight index (not used here).</param>
/// <param name="weight">Weight whose shape/context/dtype the state mirrors.</param>
public override NdArray create_state(int index, NdArray weight)
{
    return Math.Abs(this._momentum) < float.Epsilon
        ? null
        : NdArray.Zeros(weight.GetShape(), weight.GetContext(), weight.GetDtype());
}
/// <summary>
/// Xavier-style initialization. The fill scale is sqrt(_magnitude / factor), where
/// factor is fan-in, fan-out, or their average, as selected by <c>_factorType</c>.
/// The array is filled by <c>_rndType</c>: uniform on [-scale, scale], or Gaussian
/// with mean 0 and scale <c>scale</c>.
/// </summary>
/// <param name="name">Parameter name (not used here).</param>
/// <param name="arr">Array to fill; shape[0] is treated as fan-out and shape[1] as fan-in.</param>
/// <exception cref="ArgumentOutOfRangeException">Unknown factor or random type.</exception>
protected override void InitWeight(string name, NdArray arr)
{
    var shape = arr.GetShape();

    // Dimensions beyond the first two scale both fans.
    // NOTE(review): this assumes Shape.Size() reports the number of dimensions — confirm.
    float hwScale = 1.0f;
    if (shape.Size() > 2)
    {
        hwScale = Util.Prod(shape.Data().Skip(2).ToArray());
    }

    var fanOut = shape[0] * hwScale;
    var fanIn = shape[1] * hwScale;

    float factor = 1.0f;
    switch (_factorType)
    {
        case FactorType.Avg:
            factor = (fanIn + fanOut) / 2.0f;
            break;
        case FactorType.In:
            factor = fanIn;
            break;
        case FactorType.Out:
            factor = fanOut;
            break;
        default:
            throw new ArgumentOutOfRangeException(nameof(FactorType));
    }

    var scale = (float)Math.Sqrt(this._magnitude / factor);

    switch (_rndType)
    {
        case RndType.Uniform:
            Random.Uniform(-scale, scale, arr);
            break;
        case RndType.Gaussian:
            // Normal(loc, scale, out): Xavier Gaussian init is zero-mean. The previous
            // loc of -scale shifted every weight.
            Random.Normal(0.0f, scale, arr);
            break;
        default:
            throw new ArgumentOutOfRangeException(nameof(RndType));
    }
}
/// <summary>
/// Fills <paramref name="arr"/> with bilinear interpolation weights. Each flat index is
/// decomposed into a column (mod shape[3]) and a row ((index / shape[3]) mod shape[2]).
/// The weight is the product of two tent functions centered at <c>center</c>, with
/// width f = ceil(shape[3] / 2).
/// </summary>
/// <param name="name">Parameter name (not used here).</param>
/// <param name="arr">Array to fill; indexed as having at least 4 dimensions.</param>
protected virtual void InitBilinear(string name, NdArray arr)
{
    var dims = arr.GetShape().Data();
    var total = Util.Prod(dims);
    var width = dims[3];
    var height = dims[2];

    var f = Math.Ceiling(width / 2.0);
    var center = (2 * f - 1 - f % 2) / (2.0 * f);

    var values = new float[total];
    for (int idx = 0; idx < total; idx++)
    {
        var col = idx % width;
        // Integer division is intentional: it selects the row within the spatial plane.
        var row = (idx / width) % height;
        var colWeight = 1 - Math.Abs(col / f - center);
        var rowWeight = 1 - Math.Abs(row / f - center);
        values[idx] = (float)(colWeight * rowWeight);
    }

    arr.SyncCopyFromCpu(values);
}
/// <summary>
/// Asserts that the array's first dimension is 6, then copies in the constant
/// [1, 0, 0, 0, 1, 0]. Presumably this is the flattened 2x3 identity affine transform
/// used for a localisation network's bias — confirm with callers.
/// </summary>
/// <param name="name">Parameter name (not used here).</param>
/// <param name="arr">Array to fill; first dimension must equal 6.</param>
protected virtual void InitLocBias(string name, NdArray arr)
{
    const int BiasLength = 6;
    Util.Assert(arr.GetShape()[0] == BiasLength);

    float[] identity = { 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f };
    arr.SyncCopyFromCpu(identity);
}