/// <summary>
/// Creates a random tensor of <paramref name="shape"/> whose variance is
/// <c>scale / n</c>. Here <c>n</c> is chosen by <c>mode</c>: <c>"fan_in"</c> uses
/// fan_in and <c>"fan_out"</c> uses fan_out. Any other value uses the average of
/// the two. Each of these is floored at 1 so that degenerate shapes cannot
/// divide by zero.
/// </summary>
/// <param name="shape">Shape of the tensor to create; also used to compute the fans.</param>
/// <param name="dtype">Element type forwarded to the sampling op (may be null).</param>
/// <returns>
/// A truncated-normal sample when <c>distribution == "normal"</c>; otherwise a
/// uniform sample on <c>[-limit, limit]</c> with <c>limit = sqrt(3 * scale)</c>.
/// </returns>
public override KerasSymbol Call(Shape shape, DType dtype = null)
{
    // Standard deviation of a unit normal truncated to [-2, 2], i.e.
    // scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.). Dividing by it
    // compensates for the variance lost by truncation (presumably
    // K.TruncatedNormal truncates at +/-2 stddev — confirm against K).
    const float TruncatedNormalStd = 0.8796256610342398f;

    var (fan_in, fan_out) = _compute_fans(shape);
    var scale = this.scale;

    if (this.mode == "fan_in")
    {
        scale /= Math.Max(1, fan_in);
    }
    else if (this.mode == "fan_out")
    {
        scale /= Math.Max(1, fan_out);
    }
    else
    {
        // Average in floating point. With integral fans, "/ 2" would truncate
        // (e.g. (3 + 4) / 2 == 3 rather than 3.5) and skew the scale.
        scale /= Math.Max(1f, (fan_in + fan_out) / 2f);
    }

    if (this.distribution == "normal")
    {
        float stddev = (float)Math.Sqrt(scale) / TruncatedNormalStd;
        return K.TruncatedNormal(shape, 0, stddev, dtype: dtype, seed: this.seed);
    }

    // A uniform distribution on [-a, a] has variance a^2 / 3, so a = sqrt(3 * scale).
    float limit = (float)Math.Sqrt(3.0 * scale);
    return K.RandomUniform(shape, -limit, limit, dtype: dtype, seed: this.seed);
}
/// <summary>
/// Samples a tensor of <paramref name="shape"/> from a uniform distribution
/// bounded by this initializer's <c>minval</c> and <c>maxval</c>, using its
/// configured <c>seed</c>.
/// </summary>
/// <param name="shape">Shape of the tensor to create.</param>
/// <param name="dtype">Element type forwarded to the sampling op (may be null).</param>
/// <returns>The symbol produced by <c>K.RandomUniform</c>.</returns>
public override KerasSymbol Call(Shape shape, DType dtype = null)
    => K.RandomUniform(shape, minval, maxval, dtype: dtype, seed: seed);