/// <summary>
/// Resolves the input column (the <c>Source</c> of <c>_args.columns[0]</c>) and the time column
/// (<c>_args.timeColumn</c>) in <c>Source.Schema</c> and checks their types.
/// </summary>
/// <param name="indexLabel">Receives the schema index of the input column.</param>
/// <param name="indexTime">Receives the schema index of the time column.</param>
/// <param name="typeLabel">Receives the column type of the input column.</param>
/// <param name="typeTime">Receives the column type of the time column.</param>
/// <remarks>
/// Throws the exception built by <c>Host.Except</c> when either column is absent from the schema.
/// Throws the one built by <c>Host.ExceptNotImpl</c> when the input column is a vector that is not
/// a single dimension of length 1, or when either column's raw kind is not R4.
/// </remarks>
private void ValidateInputs(out int indexLabel, out int indexTime, out ColumnType typeLabel, out ColumnType typeTime)
{
    // ---- Input column ----
    bool labelFound = Source.Schema.TryGetColumnIndex(_args.columns[0].Source, out indexLabel);
    if (!labelFound)
    {
        throw Host.Except("InputColumn does not belong the input schema.");
    }
    typeLabel = Source.Schema.GetColumnType(indexLabel);

    // A vector input is accepted only when it has exactly one dimension whose length is 1.
    if (typeLabel.IsVector())
    {
        var labelVector = typeLabel.AsVector();
        bool holdsSingleValue = labelVector.DimCount() == 1 && labelVector.GetDim(0) == 1;
        if (!holdsSingleValue)
        {
            throw Host.ExceptNotImpl("Not implemented yet for multiple dimensions.");
        }
    }
    if (typeLabel.RawKind() != DataKind.R4)
    {
        throw Host.ExceptNotImpl("InputColumn must be R4.");
    }

    // ---- Time column ----
    bool timeFound = Source.Schema.TryGetColumnIndex(_args.timeColumn, out indexTime);
    if (!timeFound)
    {
        throw Host.Except("Time Column does not belong the input schema.");
    }
    typeTime = Source.Schema.GetColumnType(indexTime);
    if (typeTime.RawKind() != DataKind.R4)
    {
        throw Host.ExceptNotImpl("Time columne must be R4.");
    }
}
/// <summary>
/// Describes how values of <paramref name="type"/> are laid out.
/// </summary>
/// <param name="type">The column type to inspect.</param>
/// <returns>
/// For a non-vector type: <c>(type.RawKind(), ArrayKind.None)</c>.
/// For a vector type: the raw kind of its item type, paired with <c>ArrayKind.Array</c> when the vector
/// has a single dimension of positive length, and <c>ArrayKind.VBuffer</c> otherwise.
/// </returns>
public static Tuple<DataKind, ArrayKind> GetKindArray(ColumnType type)
{
    if (!type.IsVector())
    {
        return new Tuple<DataKind, ArrayKind>(type.RawKind(), ArrayKind.None);
    }

    DataKind itemKind = type.ItemType().RawKind();
    var vectorType = type.AsVector();
    bool singlePositiveDim = vectorType.DimCount() == 1 && vectorType.GetDim(0) > 0;
    ArrayKind layout = singlePositiveDim ? ArrayKind.Array : ArrayKind.VBuffer;
    return new Tuple<DataKind, ArrayKind>(itemKind, layout);
}
/// <summary>
/// Saves a column type into the writer of a model save context.
/// </summary>
/// <param name="ctx">Context whose <c>Writer</c> receives the serialized type.</param>
/// <param name="type">The column type to serialize.</param>
/// <remarks>
/// Layout written: a bool "is vector" flag. For a vector, this is followed by the dimension count,
/// each dimension, and the item type's raw kind as a byte. For a non-vector, it is followed by the
/// raw kind as a byte. A non-vector key type is rejected through <c>Contracts.ExceptNotImpl</c>.
/// NOTE(review): a vector whose items are keys is written using only the item's raw kind, so the key
/// information is not preserved.
/// </remarks>
public static void WriteType(ModelSaveContext ctx, ColumnType type)
{
    bool isVector = type.IsVector();
    ctx.Writer.Write(isVector);

    if (isVector)
    {
        var vectorType = type.AsVector();
        var dimCount = vectorType.DimCount();
        ctx.Writer.Write(dimCount);
        for (int dim = 0; dim < dimCount; ++dim)
        {
            ctx.Writer.Write(vectorType.GetDim(dim));
        }
        ctx.Writer.Write((byte)vectorType.ItemType().RawKind());
        return;
    }

    if (type.IsKey())
    {
        throw Contracts.ExceptNotImpl("Key cannot be serialized yet.");
    }
    ctx.Writer.Write((byte)type.RawKind());
}
/// <summary>
/// Resolves the input column (the <c>Source</c> of <c>_args.columns[0]</c>) and the time column
/// (<c>_args.timeColumn</c>) through <c>SchemaHelper.GetColumnIndex</c> and checks their types.
/// </summary>
/// <param name="indexLabel">Receives the schema index of the input column.</param>
/// <param name="indexTime">Receives the schema index of the time column.</param>
/// <param name="typeLabel">Receives the column type of the input column.</param>
/// <param name="typeTime">Receives the column type of the time column.</param>
/// <remarks>
/// Missing-column handling is left to <c>SchemaHelper.GetColumnIndex</c>; it presumably throws when the
/// column is absent (confirm against the helper). Throws the exception built by <c>Host.ExceptNotImpl</c>
/// when the input column is a vector that is not a single dimension of length 1, or when either
/// column's raw kind is not R4.
/// </remarks>
private void ValidateInputs(out int indexLabel, out int indexTime, out ColumnType typeLabel, out ColumnType typeTime)
{
    // Input column: resolve, then restrict vectors to one dimension of length 1.
    indexLabel = SchemaHelper.GetColumnIndex(Source.Schema, _args.columns[0].Source);
    typeLabel = Source.Schema[indexLabel].Type;
    if (typeLabel.IsVector())
    {
        var vec = typeLabel.AsVector();
        if (vec.DimCount() != 1 || vec.GetDim(0) != 1)
        {
            throw Host.ExceptNotImpl("Not implemented yet for multiple dimensions.");
        }
    }
    bool labelIsR4 = typeLabel.RawKind() == DataKind.R4;
    if (!labelIsR4)
    {
        throw Host.ExceptNotImpl("InputColumn must be R4.");
    }

    // Time column: resolve, then require R4.
    indexTime = SchemaHelper.GetColumnIndex(Source.Schema, _args.timeColumn);
    typeTime = Source.Schema[indexTime].Type;
    bool timeIsR4 = typeTime.RawKind() == DataKind.R4;
    if (!timeIsR4)
    {
        throw Host.ExceptNotImpl("Time columne must be R4.");
    }
}
/// <summary>
/// Builds the <see cref="ModelArgs"/> (name, ONNX element type and shape) that describe a column.
/// </summary>
/// <param name="type">
/// Column type. For vectors, the item type's raw kind is used. For keys, the key's raw kind is used.
/// </param>
/// <param name="colName">Name given to the value; must be non-empty.</param>
/// <param name="dims">
/// Optional explicit shape, without the batch dimension. When null, the shape is derived from
/// <paramref name="type"/>. The list is copied and is never modified by this method.
/// </param>
/// <param name="dimsParams">Optional per-dimension flags; used only when <paramref name="dims"/> is supplied.</param>
/// <returns>Model arguments whose shape always starts with a leading batch dimension of 1.</returns>
public static ModelArgs GetModelArgs(ColumnType type, string colName, List<long> dims = null, List<bool> dimsParams = null)
{
    Contracts.CheckValue(type, nameof(type));
    Contracts.CheckNonEmpty(colName, nameof(colName));

    DataKind rawKind;
    if (type.IsVector())
    {
        rawKind = type.AsVector().ItemType().RawKind();
    }
    else if (type.IsKey())
    {
        rawKind = type.AsKey().RawKind();
    }
    else
    {
        rawKind = type.RawKind();
    }

    // Map the ML.NET raw kind to an ONNX tensor element type. BL -> Float and U4 -> Int64 are kept
    // exactly as they were; they appear to be deliberate export choices.
    TensorProto.Types.DataType dataType = TensorProto.Types.DataType.Undefined;
    switch (rawKind)
    {
        case DataKind.BL:
            dataType = TensorProto.Types.DataType.Float;
            break;
        case DataKind.TX:
            dataType = TensorProto.Types.DataType.String;
            break;
        case DataKind.I1:
            dataType = TensorProto.Types.DataType.Int8;
            break;
        case DataKind.U1:
            dataType = TensorProto.Types.DataType.Uint8;
            break;
        case DataKind.I2:
            dataType = TensorProto.Types.DataType.Int16;
            break;
        case DataKind.U2:
            dataType = TensorProto.Types.DataType.Uint16;
            break;
        case DataKind.I4:
            dataType = TensorProto.Types.DataType.Int32;
            break;
        case DataKind.U4:
            dataType = TensorProto.Types.DataType.Int64;
            break;
        case DataKind.I8:
            dataType = TensorProto.Types.DataType.Int64;
            break;
        case DataKind.U8:
            dataType = TensorProto.Types.DataType.Uint64;
            break;
        case DataKind.R4:
            dataType = TensorProto.Types.DataType.Float;
            break;
        case DataKind.R8:
            dataType = TensorProto.Types.DataType.Double;
            break;
        default:
            string msg = "Unsupported type: DataKind " + rawKind.ToString();
            Contracts.Check(false, msg);
            break;
    }

    List<long> dimsLocal;
    List<bool> dimsParamLocal = null;
    if (dims != null)
    {
        // Copy the caller's shape: the batch dimension is inserted below, and doing that on the
        // caller's own list would make it grow every time the list is reused.
        dimsLocal = new List<long>(dims);
        dimsParamLocal = dimsParams;
    }
    else
    {
        dimsLocal = new List<long>();
        if (type.ValueCount() == 0) // Unknown size.
        {
            dimsLocal.Add(1);
            dimsParamLocal = new List<bool>() { false, true }; // false for batch size, true for dims.
        }
        else if (type.ValueCount() == 1)
        {
            dimsLocal.Add(1);
        }
        else if (type.ValueCount() > 1)
        {
            var vec = type.AsVector();
            for (int i = 0; i < vec.DimCount(); i++)
            {
                dimsLocal.Add(vec.GetDim(i));
            }
        }
    }

    // Batch size. dimsLocal is never null here, so no null-conditional is needed.
    dimsLocal.Insert(0, 1);
    return new ModelArgs(colName, dataType, dimsLocal, dimsParamLocal);
}