Пример #1
0
        /// <summary>
        /// Asserts that the metadata of attribute <paramref name="name"/> on <paramref name="op"/>
        /// matches the expected list-ness, list size, type and total size.
        /// </summary>
        /// <param name="op">Operation whose attribute metadata is inspected.</param>
        /// <param name="name">Name of the attribute to query.</param>
        /// <param name="expectedListSize">
        /// Expected list length. A value &gt;= 0 means the attribute is expected to be a list;
        /// a negative value means it is expected not to be a list (ListSize must still equal this value).
        /// </param>
        /// <param name="expectedType">Expected attribute type.</param>
        /// <param name="expectedTotalSize">Expected value of the metadata's TotalSize field.</param>
        void ExpectMeta(TFOperation op, string name, int expectedListSize, TFAttributeType expectedType, int expectedTotalSize)
        {
            var meta = op.GetAttributeMetadata(name);

            // NOTE(review): IsList is compared as a 0/1 flag, which presumes it is an integral
            // field in the TFAttributeMetadata version this test builds against — confirm.
            Assert(meta.IsList == (expectedListSize >= 0 ? 1 : 0));
            Assert(expectedListSize == meta.ListSize);
            // Compare against the queried metadata; comparing the parameter with itself never fails.
            Assert(expectedTotalSize == meta.TotalSize);
            Assert(expectedType == meta.Type);
        }
Пример #2
0
        /// <summary>
        /// Builds a <see cref="Tensor"/> descriptor (name, shape, value type) for a graph operation
        /// by reading its "shape" and "dtype" attributes. No tensor data is attached.
        /// </summary>
        /// <param name="op">The operation to inspect.</param>
        /// <returns>
        /// A tensor with <c>Data == null</c>, the operation's name, its shape (or <c>null</c> when the
        /// op has no usable "shape" attribute) and its value type (Integer for Int32, otherwise FloatingPoint).
        /// </returns>
        /// <exception cref="NotImplementedException">The "shape" attribute is a list.</exception>
        /// <exception cref="FormatException">The "shape" attribute exists but its value could not be read.</exception>
        private Tensor GetOpMetadata(TFOperation op)
        {
            long[] shape = null;

            // Query the shape
            using (TFStatus status = new TFStatus())
            {
                var shapeAttr = op.GetAttributeMetadata("shape", status);

                // Per the TF C API, TotalSize of a non-list shape attribute is its rank (-1 if unknown),
                // so this branch also covers unknown-rank and rank-0 shapes: they yield a null shape.
                if (!status.Ok || shapeAttr.TotalSize <= 0)
                {
                    Debug.LogWarning("Operation " + op.Name + " does not contain shape attribute or it" +
                                     " doesn't contain valid shape data!");
                }
                else if (shapeAttr.IsList)
                {
                    throw new NotImplementedException("Querying lists is not implemented yet!");
                }
                else
                {
                    long[] dims = new long[shapeAttr.TotalSize];
                    using (TFStatus shapeStatus = new TFStatus())
                    {
                        TF_OperationGetAttrShape(op.Handle, "shape", dims, (int)shapeAttr.TotalSize,
                                                 shapeStatus.Handle);
                        // Check the status of this call, not the metadata query's status above.
                        if (!shapeStatus.Ok)
                        {
                            throw new FormatException("Could not query model for op shape (" + op.Name + ")");
                        }
                    }

                    // -1 marks a dimension of unknown size; we have to use batchsize 1.
                    // Note that every -1 dimension is replaced, not only the first one.
                    shape = Array.ConvertAll<long, long>(dims, d => d == -1 ? 1L : d);
                }
            }

            // Query the data type
            TFDataType typeValue = new TFDataType();

            unsafe
            {
                using (TFStatus typeStatus = new TFStatus())
                {
                    TF_OperationGetAttrType(op.Handle, "dtype", &typeValue, typeStatus.Handle);
                    if (!typeStatus.Ok)
                    {
                        Debug.LogWarning("Operation " + op.Name +
                                         ": error retrieving dtype, assuming float!");
                        typeValue = TFDataType.Float;
                    }
                }
            }

            Tensor.TensorType valueType = Tensor.TensorType.FloatingPoint;
            switch (typeValue)
            {
            case TFDataType.Float:
                valueType = Tensor.TensorType.FloatingPoint;
                break;

            case TFDataType.Int32:
                valueType = Tensor.TensorType.Integer;
                break;

            default:
                // Unsupported dtypes fall back to FloatingPoint.
                Debug.LogWarning("Operation " + op.Name +
                                 " is not a float/integer. Proceed at your own risk!");
                break;
            }

            return new Tensor
            {
                Data      = null,
                Name      = op.Name,
                Shape     = shape,
                ValueType = valueType
            };
        }