/// <summary>
/// Creates a constant tensor of the given shape whose elements are set to <paramref name="value"/>.
/// Used by Zeros and Ones.
/// </summary>
/// <param name="value">Value for the tensor elements.</param>
/// <param name="tfshape">Shape of the tensor. Every dimension must fit in 32 bits.</param>
/// <param name="dtype">Optional type of the value. Default: Double.</param>
/// <param name="operName">Operation name, optional.</param>
/// <returns>The output of the constant operation built from the filled tensor.</returns>
/// <exception cref="ArgumentOutOfRangeException">
/// Thrown when any dimension of <paramref name="tfshape"/> is larger than <see cref="Int32.MaxValue"/>.
/// </exception>
/// <remarks>
/// See https://github.com/tensorflow/tensorflow/blob/r1.1/tensorflow/python/framework/constant_op.py
/// </remarks>
public TFOutput Constant(object value, TFShape tfshape, TFDataType dtype = TFDataType.Double, string operName = null)
{
    // Convert the .NET value to the managed type that corresponds to dtype.
    // Its runtime type becomes the element type of the backing array.
    object dtvalue = TFTensor.FetchSimple(dtype, value);

    var shape = tfshape.ToArray();
    foreach (var dim in shape)
    {
        if (dim > Int32.MaxValue)
        {
            // The single-string ArgumentOutOfRangeException constructor treats its argument
            // as the parameter name, so the parameter name and message are passed separately.
            throw new ArgumentOutOfRangeException(nameof(tfshape), "Shape can not be longer than 32 bits");
        }
    }

    // Index vector handed to TFTensor.Set (presumably its running position while filling data).
    var idx = new int[shape.Length];

    // Reuse the long[] dimensions already fetched above instead of calling ToArray() again.
    Array data = tfshape.IsLongArray
        ? Array.CreateInstance(dtvalue.GetType(), shape)
        : Array.CreateInstance(dtvalue.GetType(), tfshape.ToIntArray());

    TFTensor.Set(data, dtype, shape, idx, 0, value);
    TFTensor tensor_value = new TFTensor(data);
    return Const(tensor_value, tensor_value.TensorType, operName);
}
/// <summary>
/// Converts a shape into a constant tensor and returns it as a <see cref="TFOutput"/>.
/// </summary>
/// <param name="shape">The shape whose dimensions become the tensor contents.</param>
/// <returns>
/// An Int64 constant when the shape reports long dimensions (<c>IsLongArray</c>);
/// otherwise an Int32 constant.
/// </returns>
TFOutput ShapeTensorOutput(TFShape shape)
{
    bool useLongDims = shape.IsLongArray;
    return useLongDims
        ? Const(shape.ToArray(), TFDataType.Int64)
        : Const(shape.ToIntArray(), TFDataType.Int32);
}