示例#1
0
        /// <summary>
        /// Creates a randomly initialised tensor whose spread is scaled by the
        /// fan-in / fan-out computed from the requested shape.
        /// </summary>
        /// <param name="args">
        /// Requested shape and dtype. When <c>args.DType</c> is
        /// <c>TF_DataType.DtInvalid</c> it is overwritten with this initializer's
        /// own <c>_dtype</c>.
        /// </param>
        /// <returns>
        /// When <c>_uniform</c> is set, a uniform sample in [-limit, limit] with
        /// limit = sqrt(3 * _scale / n); otherwise a truncated-normal sample with
        /// mean 0 and stddev = sqrt(1.3 * _scale / n). Both branches are seeded
        /// with <c>_seed</c>.
        /// </returns>
        public Tensor Apply(InitializerArgs args)
        {
            if (args.DType == TF_DataType.DtInvalid)
            {
                args.DType = this._dtype;
            }

            var (fan_in, fan_out) = _compute_fans(args.Shape);

            // n is the scaling denominator selected by _mode.
            // NOTE(review): a _mode that matches none of the three literals below
            // leaves n at 0, which makes the division below non-finite. Presumably
            // _mode is validated where it is assigned - confirm.
            float n = 0;
            if (_mode == "FAN_IN")
            {
                n = fan_in;
            }
            else if (_mode == "FAN_OUT")
            {
                n = fan_out;
            }
            else if (_mode == "FAN_AVG")
            {
                n = (fan_in + fan_out) / 2.0f;
            }

            if (_uniform)
            {
                var limit = Convert.ToSingle(Math.Sqrt(3.0f * _scale / n));

                // Forward the seed here as well, so that both sampling branches
                // honour the configured _seed.
                return random_ops.random_uniform(args.Shape, -limit, limit, args.DType,
                                                 seed: _seed);
            }

            var trunc_stddev = Convert.ToSingle(Math.Sqrt(1.3f * _scale / n));
            return random_ops.truncated_normal(args.Shape, 0.0f, trunc_stddev, args.DType,
                                               seed: _seed);
        }
示例#2
0
 /// <summary>
 /// Creates a tensor of the requested shape filled with samples from
 /// <c>random_ops.random_normal</c>, using this initializer's <c>mean</c>,
 /// <c>stddev</c> and <c>seed</c>.
 /// </summary>
 /// <param name="args">
 /// Requested shape and dtype. When <c>args.DType</c> is
 /// <c>TF_DataType.DtInvalid</c> it is overwritten with this initializer's
 /// own <c>dtype</c>.
 /// </param>
 /// <returns>The sampled tensor, created with the resolved <c>args.DType</c>.</returns>
 public Tensor Apply(InitializerArgs args)
 {
     if (args.DType == TF_DataType.DtInvalid)
     {
         args.DType = this.dtype;
     }

     // Pass the resolved dtype rather than the field, so that a dtype supplied
     // by the caller is not silently replaced by the initializer's default.
     return random_ops.random_normal(args.Shape, mean, stddev, args.DType, seed: seed);
 }
示例#3
0
 /// <summary>
 /// Creates a tensor of the requested shape filled with samples from
 /// <c>random_ops.truncated_normal</c>, using this initializer's <c>mean</c>,
 /// <c>stddev</c> and <c>seed</c>.
 /// </summary>
 /// <param name="args">
 /// Requested shape and dtype. A dtype other than
 /// <c>TF_DataType.DtInvalid</c> overrides this initializer's own
 /// <c>dtype</c> for this call only.
 /// </param>
 /// <returns>The sampled tensor.</returns>
 public Tensor Apply(InitializerArgs args)
 {
     // Resolve the dtype into a local instead of writing it back to the field:
     // a per-call override must not change what later calls fall back to.
     var resolvedDType = args.DType != TF_DataType.DtInvalid ? args.DType : dtype;

     return random_ops.truncated_normal(args.Shape, mean, stddev, dtype: resolvedDType, seed: seed);
 }
示例#4
0
        /// <summary>
        /// Creates a tensor of the requested shape via <c>array_ops.ones</c>.
        /// </summary>
        /// <param name="args">
        /// Requested shape and dtype. When <c>args.DType</c> is
        /// <c>TF_DataType.DtInvalid</c> it is overwritten with this initializer's
        /// own <c>dtype</c>.
        /// </param>
        /// <returns>The tensor, created with the resolved <c>args.DType</c>.</returns>
        public Tensor Apply(InitializerArgs args)
        {
            if (args.DType == TF_DataType.DtInvalid)
            {
                args.DType = this.dtype;
            }

            // Use the resolved dtype so that a dtype requested by the caller is
            // honoured instead of always falling back to the field.
            return array_ops.ones(args.Shape, args.DType);
        }
示例#5
0
        /// <summary>
        /// Creates a tensor of the requested shape filled with samples from
        /// <c>random_ops.random_uniform</c>, using this initializer's
        /// <c>minval</c>, <c>maxval</c> and <c>seed</c>.
        /// </summary>
        /// <param name="args">
        /// Requested shape and dtype. When <c>args.DType</c> is
        /// <c>TF_DataType.DtInvalid</c> it is overwritten with this initializer's
        /// own <c>dtype</c>.
        /// </param>
        /// <returns>The sampled tensor, created with the resolved <c>args.DType</c>.</returns>
        public Tensor Apply(InitializerArgs args)
        {
            if (args.DType == TF_DataType.DtInvalid)
            {
                args.DType = dtype;
            }

            // Pass the resolved dtype rather than the field, so that a dtype
            // supplied by the caller is not ignored.
            return random_ops.random_uniform(args.Shape,
                                             minval: minval,
                                             maxval: maxval,
                                             dtype: args.DType,
                                             seed: seed);
        }
示例#6
0
        /// <summary>
        /// Creates a tensor via <c>array_ops.zeros</c>.
        /// </summary>
        /// <param name="args">
        /// Requested shape and dtype. When <c>args.DType</c> is
        /// <c>TF_DataType.DtInvalid</c> it is overwritten with this initializer's
        /// own <c>dtype</c>; when <c>args.Shape</c> is null it is overwritten with
        /// this initializer's own <c>shape</c>.
        /// </param>
        /// <returns>The tensor, created with the resolved shape and dtype.</returns>
        public Tensor Apply(InitializerArgs args)
        {
            if (args.DType == TF_DataType.DtInvalid)
            {
                args.DType = dtype;
            }
            if (args.Shape == null)
            {
                args.Shape = shape;
            }

            // Use the resolved dtype so that a dtype requested by the caller is
            // honoured, mirroring how the resolved shape is used.
            return array_ops.zeros(args.Shape, args.DType);
        }
示例#7
0
        /// <summary>
        /// Creates a constant tensor from this initializer's <c>value</c> via
        /// <c>constant_op.constant</c>, named "Const" and with broadcasting disabled.
        /// </summary>
        /// <param name="args">
        /// Requested shape and dtype. When <c>args.DType</c> is
        /// <c>TF_DataType.DtInvalid</c> it is overwritten with this initializer's
        /// own <c>dtype</c>; <c>args.VerifyShape</c> is always overwritten with
        /// <c>_verify_shape</c>.
        /// </param>
        /// <returns>The tensor produced by <c>constant_op.constant</c>.</returns>
        public Tensor Apply(InitializerArgs args)
        {
            // Fall back to the initializer's dtype when the caller gave none.
            var dtypeUnspecified = args.DType == TF_DataType.DtInvalid;
            if (dtypeUnspecified)
            {
                args.DType = this.dtype;
            }

            // The initializer's own setting decides whether the shape is verified.
            args.VerifyShape = _verify_shape;

            var result = constant_op.constant(
                value,
                args.DType,
                args.Shape,
                name: "Const",
                verify_shape: args.VerifyShape,
                allow_broadcast: false);
            return result;
        }
示例#8
0
 /// <summary>
 /// Placeholder: this initializer does not produce a tensor yet.
 /// </summary>
 /// <param name="args">Unused.</param>
 /// <returns>Never returns normally.</returns>
 /// <exception cref="NotImplementedException">Always thrown.</exception>
 public Tensor Apply(InitializerArgs args)
     => throw new NotImplementedException();