コード例 #1
0
ファイル: OnnxUtils.cs プロジェクト: wsf1990/machinelearning
        /// <summary>
        /// Fills <paramref name="typeProto"/> in place with a tensor type of element type
        /// <paramref name="dataType"/> and, when <paramref name="dims"/> is supplied, one shape
        /// dimension per entry of <paramref name="dims"/>.
        /// </summary>
        /// <param name="typeProto">The type to populate; must not be null.</param>
        /// <param name="dataType">Element type written to the tensor type.</param>
        /// <param name="dims">Dimension sizes. When null, the shape is left untouched. When empty, no shape is created.</param>
        /// <param name="dimsParam">
        /// Optional flags. Where <c>dimsParam[i]</c> exists and is true, dimension <c>i</c> is written with
        /// the symbolic parameter "None" instead of the numeric value <c>dims[i]</c>.
        /// </param>
        /// <returns>The same <paramref name="typeProto"/> instance.</returns>
        private static TypeProto MakeType(TypeProto typeProto, TensorProto.Types.DataType dataType,
            List<long> dims, List<bool> dimsParam)
        {
            Contracts.CheckValue(typeProto, nameof(typeProto));

            var tensor = typeProto.TensorType ?? (typeProto.TensorType = new TypeProto.Types.Tensor());
            tensor.ElemType = dataType;

            // A shape is only materialized when there is at least one dimension to add.
            if (dims == null || dims.Count == 0)
            {
                return typeProto;
            }

            var shape = tensor.Shape ?? (tensor.Shape = new TensorShapeProto());
            for (int position = 0; position < dims.Count; position++)
            {
                bool isSymbolic = dimsParam != null && position < dimsParam.Count && dimsParam[position];

                var dimension = new TensorShapeProto.Types.Dimension();
                if (isSymbolic)
                {
                    dimension.DimParam = "None";
                }
                else
                {
                    dimension.DimValue = dims[position];
                }

                shape.Dim.Add(dimension);
            }

            return typeProto;
        }
コード例 #2
0
ファイル: OnnxUtils.cs プロジェクト: sdg002/machinelearning
 /// <summary>
 /// Stores the supplied name, tensor element type, dimensions and per-dimension flags.
 /// </summary>
 /// <param name="name">Name to store.</param>
 /// <param name="dataType">Tensor element type to store.</param>
 /// <param name="dims">Dimension list; assigned as-is (not copied here).</param>
 /// <param name="dimParams">Per-dimension flag list; assigned as-is (not copied here).</param>
 public ModelArgs(string name, TensorProto.Types.DataType dataType, List <long> dims, List <bool> dimParams)
 {
     this.DimParams = dimParams;
     this.Dims      = dims;
     this.DataType  = dataType;
     this.Name      = name;
 }
コード例 #3
0
ファイル: OnnxUtils.cs プロジェクト: yxq9603/machinelearning
        /// <summary>
        /// Creates an attribute named <paramref name="key"/> of type <c>Int</c> that holds the numeric
        /// value of the tensor data type <paramref name="value"/>.
        /// </summary>
        /// <param name="key">Attribute name, passed to the single-argument <c>MakeAttribute</c> overload.</param>
        /// <param name="value">Data type to encode.</param>
        /// <returns>The populated attribute.</returns>
        private static AttributeProto MakeAttribute(string key, TensorProto.Types.DataType value)
        {
            var result = MakeAttribute(key);

            // The enum is stored through its integer value.
            result.Type = AttributeProto.Types.AttributeType.Int;
            result.I = (int)value;

            return result;
        }
コード例 #4
0
            /// <summary>
            /// Maps an ONNX tensor element type to the CLR type used to hold its values.
            /// </summary>
            /// <param name="dataType">The ONNX tensor element type.</param>
            /// <returns>
            /// <see cref="double"/> for <c>Float</c>, <c>Float16</c> and <c>Double</c>.
            /// NOTE(review): widening Float/Float16 to double looks deliberate; confirm callers expect it.
            /// </returns>
            /// <exception cref="NotSupportedException">For any other element type.</exception>
            private static Type ParseElementType(TensorProto.Types.DataType dataType)
            {
                switch (dataType)
                {
                case TensorProto.Types.DataType.Float:
                case TensorProto.Types.DataType.Float16:
                case TensorProto.Types.DataType.Double:
                    return typeof(double);

                default:
                    // Name the offending type so the failure is diagnosable.
                    throw new NotSupportedException("Unsupported ONNX tensor element type: " + dataType + ".");
                }
            }
コード例 #5
0
ファイル: OnnxUtils.cs プロジェクト: wsf1990/machinelearning
        /// <summary>
        /// Sets the name of <paramref name="value"/> and populates its type, via <c>MakeType</c>, as a tensor
        /// of element type <paramref name="dataType"/> with the given dimensions.
        /// </summary>
        /// <param name="value">The value info to populate in place; must not be null.</param>
        /// <param name="name">Name to assign; must be non-empty.</param>
        /// <param name="dataType">Tensor element type.</param>
        /// <param name="dims">Dimension sizes, forwarded to <c>MakeType</c>.</param>
        /// <param name="dimsParam">Per-dimension symbolic flags, forwarded to <c>MakeType</c>.</param>
        /// <returns>The same <paramref name="value"/> instance.</returns>
        private static ValueInfoProto MakeValue(ValueInfoProto value, string name, TensorProto.Types.DataType dataType,
            List<long> dims, List<bool> dimsParam)
        {
            Contracts.CheckValue(value, nameof(value));
            Contracts.CheckNonEmpty(name, nameof(name));

            value.Name = name;

            var typeProto = value.Type ?? (value.Type = new TypeProto());
            MakeType(typeProto, dataType, dims, dimsParam);

            return value;
        }
コード例 #6
0
ファイル: OnnxUtils.cs プロジェクト: sdg002/machinelearning
        /// <summary>
        /// Builds the <see cref="ModelArgs"/> (name, ONNX element type and dimensions) describing a column.
        /// </summary>
        /// <param name="type">
        /// The column type. For vectors the item type's raw kind selects the ONNX element type;
        /// for keys, the key's raw kind does.
        /// </param>
        /// <param name="colName">The column name; must be non-empty.</param>
        /// <param name="dims">
        /// Optional explicit dimensions. When null they are derived from <paramref name="type"/>.
        /// The list is copied and never modified by this method.
        /// </param>
        /// <param name="dimsParams">
        /// Optional per-dimension flags, used only when <paramref name="dims"/> is supplied.
        /// NOTE(review): presumably aligned with the returned dimensions, which include the
        /// prepended batch dimension. Confirm with callers.
        /// </param>
        /// <returns>A <see cref="ModelArgs"/> whose dimension list always starts with a batch dimension of size 1.</returns>
        /// <remarks>Unsupported raw kinds fail through <c>Contracts.Check</c>.</remarks>
        public static ModelArgs GetModelArgs(ColumnType type, string colName,
                                             List <long> dims = null, List <bool> dimsParams = null)
        {
            Contracts.CheckValue(type, nameof(type));
            Contracts.CheckNonEmpty(colName, nameof(colName));

            TensorProto.Types.DataType dataType = TensorProto.Types.DataType.Undefined;
            DataKind rawKind;

            if (type is VectorType vectorType)
            {
                rawKind = vectorType.ItemType.RawKind;
            }
            else if (type is KeyType keyType)
            {
                rawKind = keyType.RawKind;
            }
            else
            {
                rawKind = type.RawKind;
            }

            switch (rawKind)
            {
            // Booleans are exported as floats.
            case DataKind.BL:
                dataType = TensorProto.Types.DataType.Float;
                break;

            case DataKind.TX:
                dataType = TensorProto.Types.DataType.String;
                break;

            case DataKind.I1:
                dataType = TensorProto.Types.DataType.Int8;
                break;

            case DataKind.U1:
                dataType = TensorProto.Types.DataType.Uint8;
                break;

            case DataKind.I2:
                dataType = TensorProto.Types.DataType.Int16;
                break;

            case DataKind.U2:
                dataType = TensorProto.Types.DataType.Uint16;
                break;

            case DataKind.I4:
                dataType = TensorProto.Types.DataType.Int32;
                break;

            // U4 is widened to Int64 rather than mapped to an unsigned 32-bit type.
            case DataKind.U4:
                dataType = TensorProto.Types.DataType.Int64;
                break;

            case DataKind.I8:
                dataType = TensorProto.Types.DataType.Int64;
                break;

            case DataKind.U8:
                dataType = TensorProto.Types.DataType.Uint64;
                break;

            case DataKind.R4:
                dataType = TensorProto.Types.DataType.Float;
                break;

            case DataKind.R8:
                dataType = TensorProto.Types.DataType.Double;
                break;

            default:
                string msg = "Unsupported type: DataKind " + rawKind.ToString();
                Contracts.Check(false, msg);
                break;
            }

            List <long> dimsLocal;
            List <bool> dimsParamLocal = null;

            if (dims != null)
            {
                // Copy so that prepending the batch dimension below does not mutate the caller's list
                // (repeated calls with the same list would otherwise keep growing it).
                dimsLocal      = new List <long>(dims);
                dimsParamLocal = dimsParams;
            }
            else
            {
                dimsLocal = new List <long>();
                if (type.ValueCount == 0) // Unknown size.
                {
                    // Single placeholder dimension. Once the batch dimension is prepended, these flags mark
                    // the batch dimension as fixed (false) and this placeholder as symbolic (true).
                    dimsLocal.Add(1);
                    dimsParamLocal = new List <bool>()
                    {
                        false, true
                    };
                }
                else if (type.ValueCount == 1)
                {
                    dimsLocal.Add(1);
                }
                else if (type.ValueCount > 1)
                {
                    var vec = (VectorType)type;
                    for (int i = 0; i < vec.Dimensions.Length; i++)
                    {
                        dimsLocal.Add(vec.Dimensions[i]);
                    }
                }
            }

            // Prepend the batch dimension (size 1).
            dimsLocal.Insert(0, 1);

            return new ModelArgs(colName, dataType, dimsLocal, dimsParamLocal);
        }
コード例 #7
0
        /// <summary>
        /// Builds the <see cref="ModelArgs"/> (name, ONNX element type and dimensions) describing a column.
        /// </summary>
        /// <param name="type">
        /// The column type. For vectors the item type's raw CLR type selects the ONNX element type;
        /// otherwise the type's own raw CLR type does.
        /// </param>
        /// <param name="colName">The column name; must be non-empty.</param>
        /// <param name="dims">
        /// Optional explicit dimensions. When null they are derived from <paramref name="type"/>.
        /// The list is copied and never modified by this method.
        /// </param>
        /// <param name="dimsParams">
        /// Optional per-dimension flags, used only when <paramref name="dims"/> is supplied.
        /// NOTE(review): presumably aligned with the returned dimensions, which include the
        /// prepended batch dimension. Confirm with callers.
        /// </param>
        /// <returns>A <see cref="ModelArgs"/> whose dimension list always starts with a batch dimension of size 1.</returns>
        /// <remarks>Unsupported raw types fail through <c>Contracts.Check</c>.</remarks>
        public static ModelArgs GetModelArgs(DataViewType type, string colName,
                                             List <long> dims = null, List <bool> dimsParams = null)
        {
            Contracts.CheckValue(type, nameof(type));
            Contracts.CheckNonEmpty(colName, nameof(colName));

            TensorProto.Types.DataType dataType = TensorProto.Types.DataType.Undefined;
            Type rawType;

            if (type is VectorType vectorType)
            {
                rawType = vectorType.ItemType.RawType;
            }
            else
            {
                rawType = type.RawType;
            }

            if (rawType == typeof(bool))
            {
                // Booleans are exported as floats.
                dataType = TensorProto.Types.DataType.Float;
            }
            else if (rawType == typeof(ReadOnlyMemory <char>))
            {
                dataType = TensorProto.Types.DataType.String;
            }
            else if (rawType == typeof(sbyte))
            {
                dataType = TensorProto.Types.DataType.Int8;
            }
            else if (rawType == typeof(byte))
            {
                dataType = TensorProto.Types.DataType.Uint8;
            }
            else if (rawType == typeof(short))
            {
                dataType = TensorProto.Types.DataType.Int16;
            }
            else if (rawType == typeof(ushort))
            {
                dataType = TensorProto.Types.DataType.Uint16;
            }
            else if (rawType == typeof(int))
            {
                dataType = TensorProto.Types.DataType.Int32;
            }
            else if (rawType == typeof(uint))
            {
                // uint is widened to Int64 rather than mapped to an unsigned 32-bit type.
                dataType = TensorProto.Types.DataType.Int64;
            }
            else if (rawType == typeof(long))
            {
                dataType = TensorProto.Types.DataType.Int64;
            }
            else if (rawType == typeof(ulong))
            {
                dataType = TensorProto.Types.DataType.Uint64;
            }
            else if (rawType == typeof(float))
            {
                dataType = TensorProto.Types.DataType.Float;
            }
            else if (rawType == typeof(double))
            {
                dataType = TensorProto.Types.DataType.Double;
            }
            else
            {
                string msg = "Unsupported type: " + type.ToString();
                Contracts.Check(false, msg);
            }

            List <long> dimsLocal;
            List <bool> dimsParamLocal = null;

            if (dims != null)
            {
                // Copy so that prepending the batch dimension below does not mutate the caller's list
                // (repeated calls with the same list would otherwise keep growing it).
                dimsLocal      = new List <long>(dims);
                dimsParamLocal = dimsParams;
            }
            else
            {
                dimsLocal = new List <long>();
                int valueCount = type.GetValueCount();
                if (valueCount == 0) // Unknown size.
                {
                    // Single placeholder dimension. Once the batch dimension is prepended, these flags mark
                    // the batch dimension as fixed (false) and this placeholder as symbolic (true).
                    dimsLocal.Add(1);
                    dimsParamLocal = new List <bool>()
                    {
                        false, true
                    };
                }
                else if (valueCount == 1)
                {
                    dimsLocal.Add(1);
                }
                else if (valueCount > 1)
                {
                    var vec = (VectorType)type;
                    for (int i = 0; i < vec.Dimensions.Length; i++)
                    {
                        dimsLocal.Add(vec.Dimensions[i]);
                    }
                }
            }

            // Prepend the batch dimension (size 1).
            dimsLocal.Insert(0, 1);

            return new ModelArgs(colName, dataType, dimsLocal, dimsParamLocal);
        }
コード例 #8
0
        /// <summary>
        /// Builds the <see cref="ModelArgs"/> (name, ONNX element type and dimensions) describing a column.
        /// </summary>
        /// <param name="type">
        /// The column type. For vectors the item type's raw kind selects the ONNX element type;
        /// for keys, the key's raw kind does.
        /// </param>
        /// <param name="colName">The column name; must be non-empty.</param>
        /// <param name="dims">
        /// Optional explicit dimensions. When null they are derived from <paramref name="type"/>.
        /// The list is copied and never modified by this method.
        /// </param>
        /// <param name="dimsParams">
        /// Optional per-dimension flags, used only when <paramref name="dims"/> is supplied.
        /// NOTE(review): presumably aligned with the returned dimensions, which include the
        /// prepended batch dimension. Confirm with callers.
        /// </param>
        /// <returns>A <see cref="ModelArgs"/> whose dimension list always starts with a batch dimension of size 1.</returns>
        /// <remarks>
        /// Only BL, TX, U4 and R4 raw kinds are mapped. Any other kind fails through <c>Contracts.Check</c>.
        /// </remarks>
        public static ModelArgs GetModelArgs(ColumnType type, string colName,
                                             List <long> dims = null, List <bool> dimsParams = null)
        {
            Contracts.CheckValue(type, nameof(type));
            Contracts.CheckNonEmpty(colName, nameof(colName));

            TensorProto.Types.DataType dataType = TensorProto.Types.DataType.Undefined;
            DataKind rawKind;

            if (type.IsVector)
            {
                rawKind = type.AsVector.ItemType.RawKind;
            }
            else if (type.IsKey)
            {
                rawKind = type.AsKey.RawKind;
            }
            else
            {
                rawKind = type.RawKind;
            }

            switch (rawKind)
            {
            // Booleans are exported as floats.
            case DataKind.BL:
                dataType = TensorProto.Types.DataType.Float;
                break;

            case DataKind.TX:
                dataType = TensorProto.Types.DataType.String;
                break;

            // U4 is widened to Int64 rather than mapped to an unsigned 32-bit type.
            case DataKind.U4:
                dataType = TensorProto.Types.DataType.Int64;
                break;

            case DataKind.R4:
                dataType = TensorProto.Types.DataType.Float;
                break;

            default:
                // Use Check rather than Assert so that an unsupported kind always fails here.
                // With Assert, a build that skips debug checks could silently emit DataType.Undefined.
                string msg = "Unsupported type: DataKind " + rawKind.ToString();
                Contracts.Check(false, msg);
                break;
            }

            List <long> dimsLocal;
            List <bool> dimsParamLocal = null;

            if (dims != null)
            {
                // Copy so that prepending the batch dimension below does not mutate the caller's list
                // (repeated calls with the same list would otherwise keep growing it).
                dimsLocal      = new List <long>(dims);
                dimsParamLocal = dimsParams;
            }
            else
            {
                dimsLocal = new List <long>();
                if (type.ValueCount == 0) // Unknown size.
                {
                    // Single placeholder dimension. Once the batch dimension is prepended, these flags mark
                    // the batch dimension as fixed (false) and this placeholder as symbolic (true).
                    dimsLocal.Add(1);
                    dimsParamLocal = new List <bool>()
                    {
                        false, true
                    };
                }
                else if (type.ValueCount == 1)
                {
                    dimsLocal.Add(1);
                }
                else if (type.ValueCount > 1)
                {
                    var vec = type.AsVector;
                    for (int i = 0; i < vec.DimCount; i++)
                    {
                        dimsLocal.Add(vec.GetDim(i));
                    }
                }
            }

            // Prepend the batch dimension (size 1).
            dimsLocal.Insert(0, 1);

            return new ModelArgs(colName, dataType, dimsLocal, dimsParamLocal);
        }