Example #1
0
        /// <summary>
        /// Create a constant tensor of the given shape in which every element is set to <paramref name="value"/>.
        /// Used by Zeros and Ones.
        /// </summary>
        /// <param name="value">Value stored in every element; it is first converted (via TFTensor.FetchSimple) to the .NET type that corresponds to <paramref name="dtype"/>.</param>
        /// <param name="tfshape">Shape of the tensor. It must have at least one dimension, and every dimension must be non-negative and no larger than <see cref="Int32.MaxValue"/>.</param>
        /// <param name="dtype">Element type of the tensor. Default: Double.</param>
        /// <param name="operName">Operation name, optional.</param>
        /// <returns>The output of the Const operation that holds the filled tensor.</returns>
        /// <exception cref="ArgumentException">The shape has no dimensions.</exception>
        /// <exception cref="ArgumentOutOfRangeException">A dimension is negative or larger than <see cref="Int32.MaxValue"/>.</exception>
        /// see https://github.com/tensorflow/tensorflow/blob/r1.1/tensorflow/python/framework/constant_op.py
        public TFOutput Constant(object value, TFShape tfshape, TFDataType dtype = TFDataType.Double, string operName = null)
        {
            // Convert the .NET value to the CLR type matching dtype; that type also
            // becomes the element type of the backing array created below.
            object dtvalue = TFTensor.FetchSimple(dtype, value);

            var shape = tfshape.ToArray();

            // Array.CreateInstance needs at least one dimension, so a rank-0 shape cannot be
            // materialised here. The null test is defensive in case ToArray() yields no
            // dimension array (NOTE(review): confirm whether unknown shapes do this).
            if (shape == null || shape.Length == 0)
            {
                throw new ArgumentException("Shape must have at least one known dimension", "tfshape");
            }

            // Validate every dimension once and build the int[] lengths for Array.CreateInstance.
            // Once every dimension fits in an int, the long[] and int[] overloads of
            // CreateInstance produce the same array, so no separate long-array path is needed.
            var lengths = new int [shape.Length];
            for (int i = 0; i < shape.Length; i++)
            {
                if (shape [i] < 0)
                {
                    throw new ArgumentOutOfRangeException("tfshape", "Shape dimensions must be known and non-negative");
                }
                if (shape [i] > Int32.MaxValue)
                {
                    throw new ArgumentOutOfRangeException("tfshape", "Shape can not be longer than 32 bits");
                }
                lengths [i] = (int)shape [i];
            }

            Array data = Array.CreateInstance(dtvalue.GetType(), lengths);

            // Fill every element. Pass the converted dtvalue rather than the raw value: the
            // array's element type is dtvalue.GetType(), and the raw value may be of a
            // different type (e.g. a double passed for a Float tensor).
            var idx = new int [shape.Length];
            TFTensor.Set(data, dtype, shape, idx, 0, dtvalue);

            TFTensor tensor_value = new TFTensor(data);

            return(Const(tensor_value, tensor_value.TensorType, operName));
        }
Example #2
0
 /// <summary>
 /// Turns the dimensions of a shape into a constant tensor and returns that constant's output.
 /// </summary>
 /// <param name="shape">Shape whose dimensions become the values of the constant.</param>
 /// <returns>
 /// An Int64 constant built from <c>shape.ToArray()</c> when <c>shape.IsLongArray</c> is true,
 /// otherwise an Int32 constant built from <c>shape.ToIntArray()</c>.
 /// </returns>
 TFOutput ShapeTensorOutput(TFShape shape)
 {
     bool needsLongDims = shape.IsLongArray;

     return needsLongDims
         ? Const(shape.ToArray(), TFDataType.Int64)
         : Const(shape.ToIntArray(), TFDataType.Int32);
 }