/// <summary>
/// Draws initial values whose scale is derived from the tensor's fan-in / fan-out,
/// according to the configured mode (<c>FAN_IN</c>, <c>FAN_OUT</c> or <c>FAN_AVG</c>).
/// </summary>
/// <param name="args">
/// Requested shape and dtype. When <c>args.DType</c> is <c>DtInvalid</c> it is
/// replaced by this initializer's own dtype before sampling.
/// </param>
/// <returns>
/// A uniformly distributed tensor in [-limit, limit] when <c>_uniform</c> is set,
/// otherwise a truncated-normal tensor with mean 0.
/// </returns>
public Tensor Apply(InitializerArgs args)
{
    if (args.DType == TF_DataType.DtInvalid)
        args.DType = this._dtype;

    var (fan_in, fan_out) = _compute_fans(args.Shape);

    float n = 0;
    if (_mode == "FAN_IN")
        n = fan_in;
    else if (_mode == "FAN_OUT")
        n = fan_out;
    else if (_mode == "FAN_AVG")
        n = (fan_in + fan_out) / 2.0f;

    // Clamp to at least 1: a zero fan (or a mode that matched no branch above)
    // would otherwise divide by zero and produce an infinite limit / stddev.
    n = Math.Max(1.0f, n);

    if (_uniform)
    {
        var limit = Convert.ToSingle(Math.Sqrt(3.0f * _scale / n));
        // Pass the seed here too so uniform sampling is as reproducible as the
        // truncated-normal branch below.
        return random_ops.random_uniform(args.Shape, -limit, limit, args.DType, seed: _seed);
    }

    var trunc_stddev = Convert.ToSingle(Math.Sqrt(1.3f * _scale / n));
    return random_ops.truncated_normal(args.Shape, 0.0f, trunc_stddev, args.DType, seed: _seed);
}
/// <summary>
/// Samples initial values from a normal distribution using this initializer's
/// <c>mean</c>, <c>stddev</c> and <c>seed</c>.
/// </summary>
/// <param name="args">
/// Requested shape and dtype. When <c>args.DType</c> is <c>DtInvalid</c> it is
/// replaced by this initializer's own dtype.
/// </param>
/// <returns>A tensor of shape <c>args.Shape</c> with dtype <c>args.DType</c>.</returns>
public Tensor Apply(InitializerArgs args)
{
    if (args.DType == TF_DataType.DtInvalid)
        args.DType = this.dtype;

    // Use the resolved args.DType so a dtype requested by the caller is honored
    // instead of always falling back to the initializer's field.
    return random_ops.random_normal(args.Shape, mean, stddev, args.DType, seed: seed);
}
/// <summary>
/// Samples initial values from a truncated normal distribution using this
/// initializer's <c>mean</c>, <c>stddev</c> and <c>seed</c>.
/// </summary>
/// <param name="args">
/// Requested shape and dtype. When <c>args.DType</c> is <c>DtInvalid</c> it is
/// replaced by this initializer's own dtype.
/// </param>
/// <returns>A tensor of shape <c>args.Shape</c> with dtype <c>args.DType</c>.</returns>
public Tensor Apply(InitializerArgs args)
{
    // Resolve the dtype into the per-call args rather than overwriting the
    // initializer's own dtype field, so one call's dtype cannot leak into later
    // calls that do not specify a dtype.
    if (args.DType == TF_DataType.DtInvalid)
        args.DType = dtype;

    return random_ops.truncated_normal(args.Shape, mean, stddev, dtype: args.DType, seed: seed);
}
/// <summary>
/// Produces a tensor filled with ones.
/// </summary>
/// <param name="args">
/// Requested shape and dtype. When <c>args.DType</c> is <c>DtInvalid</c> it is
/// replaced by this initializer's own dtype.
/// </param>
/// <returns>A ones tensor of shape <c>args.Shape</c> with dtype <c>args.DType</c>.</returns>
public Tensor Apply(InitializerArgs args)
{
    if (args.DType == TF_DataType.DtInvalid)
        args.DType = this.dtype;

    // Use the resolved args.DType so a caller-requested dtype is not ignored.
    return array_ops.ones(args.Shape, args.DType);
}
/// <summary>
/// Samples initial values uniformly from [<c>minval</c>, <c>maxval</c>) using this
/// initializer's <c>seed</c>.
/// </summary>
/// <param name="args">
/// Requested shape and dtype. When <c>args.DType</c> is <c>DtInvalid</c> it is
/// replaced by this initializer's own dtype.
/// </param>
/// <returns>A tensor of shape <c>args.Shape</c> with dtype <c>args.DType</c>.</returns>
public Tensor Apply(InitializerArgs args)
{
    if (args.DType == TF_DataType.DtInvalid)
        args.DType = dtype;

    // Use the resolved args.DType so a caller-requested dtype is not ignored.
    return random_ops.random_uniform(args.Shape, minval: minval, maxval: maxval, dtype: args.DType, seed: seed);
}
/// <summary>
/// Produces a tensor filled with zeros.
/// </summary>
/// <param name="args">
/// Requested shape and dtype. When <c>args.DType</c> is <c>DtInvalid</c> it is
/// replaced by this initializer's own dtype. When <c>args.Shape</c> is null it is
/// replaced by this initializer's own <c>shape</c>.
/// </param>
/// <returns>A zeros tensor of shape <c>args.Shape</c> with dtype <c>args.DType</c>.</returns>
public Tensor Apply(InitializerArgs args)
{
    if (args.DType == TF_DataType.DtInvalid)
        args.DType = dtype;

    if (args.Shape == null)
        args.Shape = shape;

    // Use the resolved args.DType so a caller-requested dtype is not ignored.
    return array_ops.zeros(args.Shape, args.DType);
}
/// <summary>
/// Materializes this initializer's stored <c>value</c> as a constant tensor named "Const".
/// </summary>
/// <param name="args">
/// Requested shape and dtype. An unset (<c>DtInvalid</c>) dtype is filled in from this
/// initializer, and <c>args.VerifyShape</c> is overwritten with <c>_verify_shape</c>.
/// </param>
/// <returns>The constant tensor; broadcasting of <c>value</c> is disabled.</returns>
public Tensor Apply(InitializerArgs args)
{
    var unset = args.DType == TF_DataType.DtInvalid;
    if (unset)
    {
        args.DType = this.dtype;
    }

    args.VerifyShape = _verify_shape;

    return constant_op.constant(
        value,
        args.DType,
        args.Shape,
        name: "Const",
        verify_shape: args.VerifyShape,
        allow_broadcast: false);
}
/// <summary>
/// Not yet supported by this initializer.
/// </summary>
/// <param name="args">Ignored.</param>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public Tensor Apply(InitializerArgs args) => throw new NotImplementedException();