/// <summary>
/// Schema propagation for transformers: given the shape of the incoming data, returns the shape this
/// estimator would produce. The forecast column (<c>_options.Name</c>) is always added; when
/// <c>_options.ConfidenceUpperBoundColumn</c> is non-empty, a lower-bound and an upper-bound column are
/// added as well, for three new columns in total.
/// </summary>
/// <param name="inputSchema">Shape of the incoming data. Must contain <c>_options.Source</c> with item type Single.</param>
/// <returns>The input columns plus the forecast column and, optionally, the two confidence-bound columns.</returns>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    _host.CheckValue(inputSchema, nameof(inputSchema));

    var sourceName = _options.Source;
    if (!inputSchema.TryFindColumn(sourceName, out var sourceCol))
    {
        throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", sourceName);
    }
    if (sourceCol.ItemType != NumberDataViewType.Single)
    {
        throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", sourceName, "Single", sourceCol.GetTypeString());
    }

    // Every column produced here is a non-key vector of Single.
    SchemaShape.Column MakeSingleVector(string name) =>
        new SchemaShape.Column(name, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false);

    var columns = inputSchema.ToDictionary(c => c.Name);
    columns[_options.Name] = MakeSingleVector(_options.Name);

    // Only the upper-bound name gates the confidence columns.
    // NOTE(review): the lower-bound name is assumed to be set whenever the upper one is; a null
    // lower-bound name would make the dictionary indexer throw ArgumentNullException.
    if (!string.IsNullOrEmpty(_options.ConfidenceUpperBoundColumn))
    {
        columns[_options.ConfidenceLowerBoundColumn] = MakeSingleVector(_options.ConfidenceLowerBoundColumn);
        columns[_options.ConfidenceUpperBoundColumn] = MakeSingleVector(_options.ConfidenceUpperBoundColumn);
    }

    return new SchemaShape(columns.Values);
}
/// <summary>
/// Checks whether this object is consistent with an actual schema shape from a dynamic object,
/// throwing exceptions if not.
/// </summary>
/// <remarks>
/// Each entry of <c>Pairs</c> is looked up in <paramref name="shape"/> by its key (a column name). The entry's
/// value is the statically declared type that the found column must be expressible as.
/// </remarks>
/// <param name="ectx">The context on which to throw exceptions</param>
/// <param name="shape">The schema shape to check</param>
public void Check(IExceptionContext ectx, SchemaShape shape)
{
    Contracts.AssertValue(ectx);
    ectx.AssertValue(shape);

    foreach (var pair in Pairs)
    {
        // Every statically declared column must be present in the dynamic shape.
        if (!shape.TryFindColumn(pair.Key, out var col))
        {
            throw ectx.ExceptParam(nameof(shape), $"Column named '{pair.Key}' was not found");
        }
        // A null result means the column's type could not be mapped to a static type
        // (presumably an unrecognized type; confirm against GetTypeOrNull).
        var type = GetTypeOrNull(col);
        if ((type != null && !pair.Value.IsAssignableFromStaticPipeline(type)) || (type == null && IsStandard(ectx, pair.Value)))
        {
            // When not null, we can use IsAssignableFrom to indicate we could assign to this, so as to allow
            // for example Key<uint, string> to be considered to be compatible with Key<uint>.
            // In the null case, while we cannot directly verify an unrecognized type, we can at least verify
            // that the statically declared type should not have corresponded to a recognized type.
            // NOTE(review): when type == null this re-check evaluates IsAssignableFromStaticPipeline(null), so
            // whether the exception below is thrown depends on how that helper treats a null argument.
            // Confirm it returns false so the "standard declared type, unrecognized actual type" case is reported.
            if (!pair.Value.IsAssignableFromStaticPipeline(type))
            {
                throw ectx.ExceptParam(nameof(shape), $"Column '{pair.Key}' of type '{col.GetTypeString()}' cannot be expressed statically as type '{pair.Value}'.");
            }
        }
    }
}
/// <summary>
/// Gets the output columns. When a label column is configured, it must exist in
/// <paramref name="inputSchema"/> and be compatible with the expected label shape. The result is the input
/// columns with the trainer's output columns added (replacing any input column of the same name).
/// </summary>
/// <param name="inputSchema">The input schema.</param>
/// <returns>The output <see cref="SchemaShape"/></returns>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    Host.CheckValue(inputSchema, nameof(inputSchema));

    if (LabelColumn.IsValid)
    {
        if (!inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol))
        {
            // The offending argument is inputSchema (not the local labelCol), matching the mismatch case below.
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", LabelColumn.Name);
        }
        if (!LabelColumn.IsCompatibleWith(labelCol))
        {
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", LabelColumn.Name, LabelColumn.GetTypeString(), labelCol.GetTypeString());
        }
    }

    // Input columns pass through; trainer outputs override same-named inputs.
    var outColumns = inputSchema.ToDictionary(x => x.Name);
    foreach (var col in GetOutputColumnsCore(inputSchema))
    {
        outColumns[col.Name] = col;
    }
    return new SchemaShape(outColumns.Values);
}
/// <summary>
/// Gets the output columns. When a label column is configured, it must exist in
/// <paramref name="inputSchema"/> and be compatible with the expected label shape. The result is the input
/// columns with the trainer's output columns added (replacing any input column of the same name).
/// </summary>
/// <param name="inputSchema">The input schema.</param>
/// <returns>The output <see cref="SchemaShape"/></returns>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    Host.CheckValue(inputSchema, nameof(inputSchema));

    if (LabelColumn != null)
    {
        if (!inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol))
        {
            // Report the label role and the configured label column name (not PredictedLabel);
            // the argument that lacks the column is inputSchema.
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", LabelColumn.Name);
        }
        if (!LabelColumn.IsCompatibleWith(labelCol))
        {
            throw Host.Except($"Label column '{LabelColumn.Name}' is not compatible");
        }
    }

    // Input columns pass through; trainer outputs override same-named inputs.
    var outColumns = inputSchema.Columns.ToDictionary(x => x.Name);
    foreach (var col in GetOutputColumnsCore(inputSchema))
    {
        outColumns[col.Name] = col;
    }
    return new SchemaShape(outColumns.Values);
}
/// <summary>
/// Returns the shape of the data produced by the wrapped ONNX transformer.
/// Each model input must exist in <paramref name="inputSchema"/> as a (fixed or variable size) vector whose
/// item type matches the corresponding ONNX input node.
/// </summary>
/// <param name="inputSchema">Shape of the data fed to the transformer.</param>
/// <returns>The input columns plus one column per model output (replacing same-named input columns).</returns>
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    Host.CheckValue(inputSchema, nameof(inputSchema));
    var resultDic = inputSchema.ToDictionary(x => x.Name);

    for (var i = 0; i < Transformer.Inputs.Length; i++)
    {
        var input = Transformer.Inputs[i];
        if (!inputSchema.TryFindColumn(input, out var col))
        {
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
        }
        // Scalars are rejected: the column must be a vector of either kind.
        if (!(col.Kind == SchemaShape.Column.VectorKind.VariableVector || col.Kind == SchemaShape.Column.VectorKind.Vector))
        {
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, nameof(VectorType), col.GetTypeString());
        }

        // Locate the ONNX input node by name to learn the item type the model expects.
        var inputsInfo = Transformer.Model.ModelInfo.InputsInfo;
        var idx = Transformer.Model.InputNames.IndexOf(input);
        if (idx < 0)
        {
            throw Host.Except($"Column {input} doesn't match input node names of model.");
        }
        var inputNodeInfo = inputsInfo[idx];
        var expectedType = OnnxUtils.OnnxToMlNetType(inputNodeInfo.Type);
        if (col.ItemType != expectedType)
        {
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
        }
    }

    // Known-size output vectors map to Vector; all other output types map to VariableVector.
    for (var i = 0; i < Transformer.Outputs.Length; i++)
    {
        resultDic[Transformer.Outputs[i]] = new SchemaShape.Column(Transformer.Outputs[i],
            Transformer.OutputTypes[i].IsKnownSizeVector ? SchemaShape.Column.VectorKind.Vector
            : SchemaShape.Column.VectorKind.VariableVector, Transformer.OutputTypes[i].ItemType, false);
    }
    return new SchemaShape(resultDic.Values);
}
/// <summary>
/// Gets the output <see cref="SchemaShape"/> of the <see cref="IDataView"/> after fitting the calibrator.
/// The score, weight and label columns are validated when configured. A scalar Single column named
/// "Probability" is then added to the shape, replacing any existing column of that name.
/// </summary>
/// <param name="inputSchema">The input <see cref="SchemaShape"/>.</param>
SchemaShape IEstimator<CalibratorTransformer<TICalibrator>>.GetOutputSchema(SchemaShape inputSchema)
{
    // Columns that are not configured (IsValid == false) are skipped. Configured ones must
    // exist in the input and be compatible with the expected shape.
    void EnsureRole(SchemaShape.Column expected, string columnRole)
    {
        if (!expected.IsValid)
        {
            return;
        }
        if (!inputSchema.TryFindColumn(expected.Name, out var actual))
        {
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), columnRole, expected.Name);
        }
        if (!expected.IsCompatibleWith(actual))
        {
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), columnRole, expected.Name, expected.GetTypeString(), actual.GetTypeString());
        }
    }

    EnsureRole(ScoreColumn, "score");
    EnsureRole(WeightColumn, "weight");
    EnsureRole(LabelColumn, "label");

    var columns = inputSchema.ToDictionary(x => x.Name);
    var probabilityMetadata = new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true));
    columns[DefaultColumnNames.Probability] = new SchemaShape.Column(DefaultColumnNames.Probability,
        SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, probabilityMetadata);
    return new SchemaShape(columns.Values);
}
/// <summary>
/// <see cref="PretrainedTreeFeaturizationEstimator"/> adds up to three float-vector columns into <paramref name="inputSchema"/>.
/// Given a feature vector column, the possible added columns are:
/// <list type="bullet">
/// <item><description>the prediction values of all trees;</description></item>
/// <item><description>the leaf IDs the feature vector falls into;</description></item>
/// <item><description>the paths to those leaves.</description></item>
/// </list>
/// Each column is produced only when its output name (<c>TreesColumnName</c>, <c>LeavesColumnName</c> or
/// <c>PathsColumnName</c>) is non-null.
/// </summary>
/// <param name="inputSchema">A schema which contains a feature column. Note that feature column name can be specified
/// by <see cref="OptionsBase.InputColumnName"/>.</param>
/// <returns>Output <see cref="SchemaShape"/> produced by <see cref="PretrainedTreeFeaturizationEstimator"/>.</returns>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    Env.CheckValue(inputSchema, nameof(inputSchema));

    // Only the feature column's presence is verified here; its type is not inspected.
    if (!inputSchema.TryFindColumn(FeatureColumnName, out _))
    {
        throw Env.ExceptSchemaMismatch(nameof(inputSchema), "input", FeatureColumnName);
    }

    var result = inputSchema.ToDictionary(x => x.Name);

    // A null output name means "do not produce this column"; otherwise add a non-key Single vector.
    void AddOutput(string name)
    {
        if (name != null)
        {
            result[name] = new SchemaShape.Column(name, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false);
        }
    }

    AddOutput(TreesColumnName);
    AddOutput(LeavesColumnName);
    AddOutput(PathsColumnName);

    return new SchemaShape(result.Values);
}
/// <summary>
/// Returns the shape of the data produced by the wrapped TensorFlow transformer.
/// Each model input must exist in <paramref name="inputSchema"/> as a (fixed or variable size) vector whose
/// item type matches the TensorFlow input type at the same position.
/// </summary>
/// <param name="inputSchema">Shape of the data fed to the transformer.</param>
/// <returns>The input columns plus one column per model output (replacing same-named input columns).</returns>
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    Host.CheckValue(inputSchema, nameof(inputSchema));
    var resultDic = inputSchema.Columns.ToDictionary(x => x.Name);

    for (var i = 0; i < Transformer.Inputs.Length; i++)
    {
        var input = Transformer.Inputs[i];
        if (!inputSchema.TryFindColumn(input, out var col))
        {
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
        }
        // Scalars are rejected: the column must be a vector of either kind.
        if (!(col.Kind == SchemaShape.Column.VectorKind.VariableVector || col.Kind == SchemaShape.Column.VectorKind.Vector))
        {
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, nameof(VectorType), col.GetTypeString());
        }
        // TFInputTypes is indexed in parallel with Inputs.
        var expectedType = TensorFlowUtils.Tf2MlNetType(Transformer.TFInputTypes[i]);
        if (col.ItemType != expectedType)
        {
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
        }
    }

    // Known-size output vectors map to Vector; all other output types map to VariableVector.
    for (var i = 0; i < Transformer.Outputs.Length; i++)
    {
        resultDic[Transformer.Outputs[i]] = new SchemaShape.Column(Transformer.Outputs[i],
            Transformer.OutputTypes[i].IsKnownSizeVector ? SchemaShape.Column.VectorKind.Vector
            : SchemaShape.Column.VectorKind.VariableVector, Transformer.OutputTypes[i].ItemType, false);
    }
    return new SchemaShape(resultDic.Values);
}
/// <summary>
/// Returns the shape produced by the text featurizer. Every source column must be text (scalar or vector).
/// The featurized output is a Single vector annotated with slot names, plus an IsNormalized flag when a
/// vector normalizer is configured. A variable-length text column with the transformed tokens is also added
/// when <c>OutputTokens</c> is enabled.
/// </summary>
/// <param name="inputSchema">Shape of the incoming data.</param>
/// <returns>The input columns plus the featurizer's output column(s).</returns>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    _host.CheckValue(inputSchema, nameof(inputSchema));
    var result = inputSchema.Columns.ToDictionary(x => x.Name);

    foreach (var name in _inputColumns)
    {
        if (!inputSchema.TryFindColumn(name, out var found))
            throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", name);
        if (!found.ItemType.IsText)
            throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", name, "scalar or vector of text", found.GetTypeString());
    }

    // Slot names are always declared; IsNormalized only when normalization is actually applied.
    var annotations = new List<SchemaShape.Column>(2)
    {
        new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false)
    };
    if (AdvancedSettings.VectorNormalizer != TextNormKind.None)
        annotations.Add(new SchemaShape.Column(MetadataUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false));

    result[OutputColumn] = new SchemaShape.Column(OutputColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, new SchemaShape(annotations));

    if (AdvancedSettings.OutputTokens)
    {
        var tokensColumn = string.Format(TransformedTextColFormat, OutputColumn);
        result[tokensColumn] = new SchemaShape.Column(tokensColumn, SchemaShape.Column.VectorKind.VariableVector, TextType.Instance, false);
    }

    return new SchemaShape(result.Values);
}
/// <summary>
/// Returns the shape of the data produced by the wrapped single-input ONNX transformer.
/// The input column must be a (fixed or variable size) vector whose item type matches the model's first
/// input node.
/// </summary>
/// <param name="inputSchema">Shape of the data fed to the transformer.</param>
/// <returns>The input columns plus the transformer's output column (replacing a same-named input column).</returns>
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    Host.CheckValue(inputSchema, nameof(inputSchema));
    var resultDic = inputSchema.Columns.ToDictionary(x => x.Name);

    var input = Transformer.Input;
    if (!inputSchema.TryFindColumn(input, out var col))
    {
        throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
    }
    // Scalars are rejected: the column must be a vector of either kind.
    if (!(col.Kind == SchemaShape.Column.VectorKind.VariableVector || col.Kind == SchemaShape.Column.VectorKind.Vector))
    {
        throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, nameof(VectorType), col.GetTypeString());
    }

    // NOTE(review): only the model's first input node is consulted; assumes the model has exactly one input.
    var inputNodeInfo = Transformer.Model.ModelInfo.InputsInfo[0];
    var expectedType = OnnxUtils.OnnxToMlNetType(inputNodeInfo.Type);
    if (col.ItemType != expectedType)
    {
        throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
    }

    // Only the vector-ness of OutputType is consulted; the item type is declared as R4.
    resultDic[Transformer.Output] = new SchemaShape.Column(Transformer.Output,
        Transformer.OutputType.IsKnownSizeVector ? SchemaShape.Column.VectorKind.Vector
        : SchemaShape.Column.VectorKind.VariableVector, NumberType.R4, false);

    return new SchemaShape(resultDic.Values);
}
/// <summary>
/// Gets the output <see cref="SchemaShape"/> of the <see cref="IDataView"/> after fitting the calibrator.
/// Every configured role column (score, weight, label, features, predicted label) is validated. A scalar
/// Single column named "Probability" is then added to the shape, replacing any existing column of that name.
/// </summary>
/// <param name="inputSchema">The input <see cref="SchemaShape"/>.</param>
SchemaShape IEstimator<CalibratorTransformer<TICalibrator>>.GetOutputSchema(SchemaShape inputSchema)
{
    // Unconfigured columns (IsValid == false) are skipped. Configured ones must be present and compatible.
    void EnsureColumn(SchemaShape.Column column, string expected)
    {
        if (!column.IsValid)
            return;
        if (!inputSchema.TryFindColumn(column.Name, out var found))
            throw Host.Except($"{expected} column '{column.Name}' is not found");
        if (!column.IsCompatibleWith(found))
            throw Host.Except($"{expected} column '{column.Name}' is not compatible");
    }

    EnsureColumn(ScoreColumn, DefaultColumnNames.Score);
    EnsureColumn(WeightColumn, DefaultColumnNames.Weight);
    EnsureColumn(LabelColumn, DefaultColumnNames.Label);
    EnsureColumn(FeatureColumn, DefaultColumnNames.Features);
    EnsureColumn(PredictedLabel, DefaultColumnNames.PredictedLabel);

    var shapeColumns = inputSchema.ToDictionary(x => x.Name);
    var probabilityMetadata = new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true));
    shapeColumns[DefaultColumnNames.Probability] = new SchemaShape.Column(DefaultColumnNames.Probability,
        SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, probabilityMetadata);
    return new SchemaShape(shapeColumns.Values);
}
/// <summary>
/// Returns the shape produced by count-based feature selection. Each output has the same kind and item type
/// as its source column. It carries over the source's slot-name and categorical-slot-range annotations (when
/// present) and is marked with an IsNormalized annotation.
/// </summary>
/// <param name="inputSchema">Shape of the incoming data.</param>
/// <returns>The input columns plus one selected column per configured pair.</returns>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    _host.CheckValue(inputSchema, nameof(inputSchema));
    var result = inputSchema.ToDictionary(x => x.Name);

    foreach (var colPair in _columns)
    {
        var inputName = colPair.Input;
        if (!inputSchema.TryFindColumn(inputName, out var source))
            throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputName);
        if (!CountFeatureSelectionUtils.IsValidColumnType(source.ItemType))
            throw _host.ExceptUserArg(nameof(inputSchema), "Column '{0}' does not have compatible type. Expected types are float, double or string.", inputName);

        // Propagate the annotation kinds that survive slot selection, in this fixed order.
        var annotations = new List<SchemaShape.Column>();
        foreach (var kind in new[] { MetadataUtils.Kinds.SlotNames, MetadataUtils.Kinds.CategoricalSlotRanges })
        {
            if (source.Metadata.TryFindColumn(kind, out var carried))
                annotations.Add(carried);
        }
        annotations.Add(new SchemaShape.Column(MetadataUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false));

        result[colPair.Output] = new SchemaShape.Column(colPair.Output, source.Kind, source.ItemType, false, new SchemaShape(annotations.ToArray()));
    }

    return new SchemaShape(result.Values);
}
/// <summary>
/// Returns the shape produced by the expression estimator. For each configured column, all referenced
/// input columns must exist. The expression is bound to their types, and the output column takes the
/// expression's (primitive) result type.
/// The output is a scalar when no input is a vector. Otherwise it takes the vector input's kind and
/// inherits its slot names.
/// </summary>
/// <param name="inputSchema">Shape of the incoming data.</param>
/// <returns>The input columns plus one column per configured expression.</returns>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    _host.CheckValue(inputSchema, nameof(inputSchema));
    var columnDictionary = inputSchema.ToDictionary(x => x.Name);

    for (int i = 0; i < _columns.Length; i++)
    {
        var inputNames = _columns[i].InputColumnNames;

        // Every column referenced by the expression must be present.
        for (int j = 0; j < inputNames.Length; j++)
        {
            if (!inputSchema.TryFindColumn(inputNames[j], out _))
            {
                throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputNames[j]);
            }
        }

        // Make sure there is at most one vector valued source column.
        var inputTypes = new DataViewType[inputNames.Length];
        var ivec = FindVectorInputColumn(_host, inputNames, inputSchema, inputTypes);
        var node = ParseAndBindLambda(_host, _columns[i].Expression, ivec, inputTypes, out _);
        var typeRes = node.ResultType;
        _host.Assert(typeRes is PrimitiveDataViewType);

        // If one of the input columns is a vector column, we pass through the slot names metadata.
        SchemaShape.Column.VectorKind outputVectorKind;
        var metadata = new List<SchemaShape.Column>();
        if (ivec == -1)
        {
            outputVectorKind = SchemaShape.Column.VectorKind.Scalar;
        }
        else
        {
            // Existence was verified in the loop above; assert it rather than silently ignore the result.
            bool found = inputSchema.TryFindColumn(inputNames[ivec], out var vectorCol);
            _host.Assert(found);
            outputVectorKind = vectorCol.Kind;
            if (vectorCol.HasSlotNames())
            {
                var b = vectorCol.Annotations.TryFindColumn(AnnotationUtils.Kinds.SlotNames, out var slotNames);
                _host.Assert(b);
                metadata.Add(slotNames);
            }
        }

        var outputSchemaShapeColumn = new SchemaShape.Column(_columns[i].Name, outputVectorKind, typeRes, false, new SchemaShape(metadata));
        columnDictionary[_columns[i].Name] = outputSchemaShapeColumn;
    }

    return new SchemaShape(columnDictionary.Values);
}
/// <summary>
/// Returns the shape produced by this estimator. The label column must exist. Every source column must be a
/// scalar or known-size vector of U4 keys. Each output is a Single vector that gets slot names when the source
/// is a scalar, or inherits the source's slot names when it has them.
/// </summary>
/// <param name="inputSchema">Shape of the incoming data.</param>
/// <returns>The input columns plus one Single-vector column per configured column.</returns>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    _host.CheckValue(inputSchema, nameof(inputSchema));
    var result = inputSchema.ToDictionary(x => x.Name);

    // Only the label column's presence is required; its shape is not inspected here.
    if (!inputSchema.TryFindColumn(_labelColumnName, out _))
        throw _host.ExceptSchemaMismatch(nameof(inputSchema), "label", _labelColumnName);

    foreach (var colInfo in _columns)
    {
        var srcName = colInfo.InputColumnName;
        if (!inputSchema.TryFindColumn(srcName, out var src))
            throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", srcName);
        if (src.Kind == SchemaShape.Column.VectorKind.VariableVector)
            throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", srcName, "known-size vector or scalar", src.GetTypeString());
        if (!src.IsKey || !src.ItemType.Equals(NumberDataViewType.UInt32))
            throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", srcName, "vector or scalar of U4 key type", src.GetTypeString());

        var slotAnnotations = new List<SchemaShape.Column>();
        if (src.Kind == SchemaShape.Column.VectorKind.Scalar)
            slotAnnotations.Add(new SchemaShape.Column(AnnotationUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false));
        else if (src.Annotations.TryFindColumn(AnnotationUtils.Kinds.SlotNames, out var inherited))
            slotAnnotations.Add(inherited);

        result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, SchemaShape.Column.VectorKind.Vector,
            NumberDataViewType.Single, false, new SchemaShape(slotAnnotations));
    }

    return new SchemaShape(result.Values);
}
/// <summary>
/// Describes the columns this trainer adds: a single scalar "Score" column of R4 with the standard
/// trainer-output annotations.
/// </summary>
/// <param name="inputSchema">Shape of the training data; expected to contain the label column.</param>
/// <returns>The added output columns.</returns>
private SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
{
    // The label is presumably validated by the caller; here its presence is only asserted.
    bool labelFound = inputSchema.TryFindColumn(LabelName, out _);
    Contracts.Assert(labelFound);

    var scoreAnnotations = new SchemaShape(MetadataUtils.GetTrainerOutputMetadata());
    var score = new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, scoreAnnotations);
    return new[] { score };
}
/// <summary>
/// Describes the columns a binary classifier adds:
/// <list type="bullet">
/// <item><description>"Score" (scalar Single);</description></item>
/// <item><description>"Probability" (scalar Single);</description></item>
/// <item><description>"PredictedLabel" (scalar Boolean).</description></item>
/// </list>
/// Each carries the trainer-output annotations; Probability uses the <c>true</c> variant.
/// </summary>
/// <param name="inputSchema">Shape of the training data; expected to contain the label column.</param>
/// <returns>The added output columns, in the order Score, Probability, PredictedLabel.</returns>
private SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
{
    bool labelFound = inputSchema.TryFindColumn(LabelColumn.Name, out _);
    Contracts.Assert(labelFound);

    var score = new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar,
        NumberDataViewType.Single, false, new SchemaShape(AnnotationUtils.GetTrainerOutputAnnotation()));
    var probability = new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar,
        NumberDataViewType.Single, false, new SchemaShape(AnnotationUtils.GetTrainerOutputAnnotation(true)));
    var predictedLabel = new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar,
        BooleanDataViewType.Instance, false, new SchemaShape(AnnotationUtils.GetTrainerOutputAnnotation()));

    return new[] { score, probability, predictedLabel };
}
/// <summary>
/// Describes the columns a multiclass trainer adds:
/// <list type="bullet">
/// <item><description>a "Score" vector of R4;</description></item>
/// <item><description>a "PredictedLabel" U4 key scalar, which inherits the label's KeyValues annotation (if any)
/// plus the standard trainer-output annotations.</description></item>
/// </list>
/// </summary>
/// <param name="inputSchema">Shape of the training data; expected to contain the label column.</param>
/// <returns>The added output columns, Score first.</returns>
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
{
    bool labelFound = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
    Contracts.Assert(labelFound);

    var labelKeyValues = labelCol.Metadata.Columns.Where(x => x.Name == MetadataUtils.Kinds.KeyValues);
    var predictedLabelAnnotations = new SchemaShape(labelKeyValues.Concat(MetadataUtils.GetTrainerOutputMetadata()));

    var score = new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector,
        NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata()));
    var predictedLabel = new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar,
        NumberType.U4, true, predictedLabelAnnotations);

    return new[] { score, predictedLabel };
}
/// <summary>
/// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
/// Used for schema propagation and verification in a pipeline.
/// </summary>
/// <param name="inputSchema">Shape of the data that will be fed to the ONNX model.</param>
/// <returns>The input columns plus one column per model output (replacing same-named input columns).</returns>
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    Host.CheckValue(inputSchema, nameof(inputSchema));
    var resultDic = inputSchema.ToDictionary(x => x.Name);

    // Check that every input column required by the underlying transformer exists in inputSchema.
    // ML.NET only produces tensors (scalars are converted to tensors of shape [1] before being fed into
    // ONNX Runtime), so the bridge code in the ONNX transformer assumes that all inputs are tensors.
    for (var i = 0; i < Transformer.Inputs.Length; i++)
    {
        // Name of the i-th IDataView input column consumed by the underlying ONNX transformer.
        var input = Transformer.Inputs[i];

        if (!inputSchema.TryFindColumn(input, out var col))
        {
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
        }

        // Variable-length vectors are rejected; scalars and fixed-size vectors are accepted.
        if (col.Kind == SchemaShape.Column.VectorKind.VariableVector)
        {
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, "vector", col.GetTypeString());
        }

        var inputsInfo = Transformer.Model.ModelInfo.InputsInfo;
        var idx = Transformer.Model.ModelInfo.InputNames.IndexOf(input);
        if (idx < 0)
        {
            throw Host.Except($"Column {input} doesn't match input node names of model.");
        }

        // NOTE(review): the cast assumes every ONNX input node is described by a VectorDataViewType;
        // any other DataViewType would surface here as an InvalidCastException.
        var inputNodeInfo = inputsInfo[idx];
        var expectedType = ((VectorDataViewType)inputNodeInfo.DataViewType).ItemType;
        if (col.ItemType != expectedType)
        {
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
        }
    }

    // Known-size output vectors map to Vector; all other output types map to VariableVector.
    for (var i = 0; i < Transformer.Outputs.Length; i++)
    {
        resultDic[Transformer.Outputs[i]] = new SchemaShape.Column(Transformer.Outputs[i],
            Transformer.OutputTypes[i].IsKnownSizeVector() ? SchemaShape.Column.VectorKind.Vector
            : SchemaShape.Column.VectorKind.VariableVector, Transformer.OutputTypes[i].GetItemType(), false);
    }
    return new SchemaShape(resultDic.Values);
}
/// <summary>
/// Verifies that the trainer's input columns are present in <paramref name="inputSchema"/> with compatible shapes:
/// <list type="bullet">
/// <item><description>the feature column is always required;</description></item>
/// <item><description>the weight column is checked only when configured;</description></item>
/// <item><description>the label column, when configured, must exist, and its compatibility is delegated to
/// <c>CheckLabelCompatible</c>, since trainers accept different label types.</description></item>
/// </list>
/// </summary>
/// <param name="inputSchema">Shape of the training data.</param>
private void CheckInputSchema(SchemaShape inputSchema)
{
    // The column must exist and be compatible with the expected shape.
    void Require(SchemaShape.Column expected, string role)
    {
        if (!inputSchema.TryFindColumn(expected.Name, out var actual))
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), role, expected.Name);
        if (!expected.IsCompatibleWith(actual))
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), role, expected.Name, expected.GetTypeString(), actual.GetTypeString());
    }

    Require(FeatureColumn, "feature");
    if (WeightColumn.IsValid)
        Require(WeightColumn, "weight");

    // Labels are special: each trainer defines its own compatibility rule.
    if (!LabelColumn.IsValid)
        return;
    if (!inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol))
        throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", LabelColumn.Name);
    CheckLabelCompatible(labelCol);
}
/// <summary>
/// Describes the columns a multiclass trainer adds:
/// <list type="bullet">
/// <item><description>a "Score" Single vector annotated for multiclass scores (derived from the label column);</description></item>
/// <item><description>a "PredictedLabel" UInt32 key scalar, which inherits the label's KeyValues annotation (if any)
/// plus the trainer-output annotations.</description></item>
/// </list>
/// </summary>
/// <param name="inputSchema">Shape of the training data; expected to contain the label column.</param>
/// <returns>The added output columns, Score first.</returns>
private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
{
    bool labelFound = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
    Contracts.Assert(labelFound);

    var keyValueAnnotations = labelCol.Annotations.Where(x => x.Name == AnnotationUtils.Kinds.KeyValues);
    var predLabelAnnotations = new SchemaShape(keyValueAnnotations.Concat(AnnotationUtils.GetTrainerOutputAnnotation()));
    var scoreAnnotations = new SchemaShape(AnnotationUtils.AnnotationsForMulticlassScoreColumn(labelCol));

    var score = new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector,
        NumberDataViewType.Single, false, scoreAnnotations);
    var predictedLabel = new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar,
        NumberDataViewType.UInt32, true, predLabelAnnotations);

    return new[] { score, predictedLabel };
}
/// <summary>
/// Returns the shape produced by this transform. The source column must have item type R4. The output
/// column (<c>_args.Name</c>) is a vector of R8 annotated with slot names.
/// </summary>
/// <param name="inputSchema">Shape of the incoming data.</param>
/// <returns>The input columns plus the output column (replacing a same-named input column).</returns>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    _host.CheckValue(inputSchema, nameof(inputSchema));

    var sourceName = _args.Source;
    if (!inputSchema.TryFindColumn(sourceName, out var source))
    {
        throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", sourceName);
    }
    if (source.ItemType != NumberType.R4)
    {
        throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", sourceName, NumberType.R4.ToString(), source.GetTypeString());
    }

    var slotNames = new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false);
    var annotations = new List<SchemaShape.Column> { slotNames };

    var columns = inputSchema.Columns.ToDictionary(x => x.Name);
    columns[_args.Name] = new SchemaShape.Column(_args.Name, SchemaShape.Column.VectorKind.Vector,
        NumberType.R8, false, new SchemaShape(annotations));
    return new SchemaShape(columns.Values);
}
/// <summary>
/// Returns the schema that would be produced by the transformation. Each source column must pass
/// <c>VectorWhiteningTransformer.TestColumn</c>. Each output mirrors its source's kind, item type and key-ness
/// and declares no annotations.
/// </summary>
/// <param name="inputSchema">Shape of the incoming data.</param>
/// <returns>The input columns plus one whitened column per configured pair.</returns>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    _host.CheckValue(inputSchema, nameof(inputSchema));
    var result = inputSchema.ToDictionary(x => x.Name);

    foreach (var info in _infos)
    {
        if (!inputSchema.TryFindColumn(info.Input, out var src))
            throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", info.Input);

        // A non-null result from TestColumn is treated as the rejection reason.
        var rejection = VectorWhiteningTransformer.TestColumn(src.ItemType);
        if (rejection != null)
            throw _host.ExceptUserArg(nameof(inputSchema), rejection);

        result[info.Output] = new SchemaShape.Column(info.Output, src.Kind, src.ItemType, src.IsKey, null);
    }

    return new SchemaShape(result.Values);
}
/// <summary>
/// Validates the configured role columns, then returns the input columns with the trainer's output columns
/// added (replacing any input column of the same name). The validated role columns are the label (when set),
/// every feature column, and the weight (when set).
/// </summary>
/// <param name="inputSchema">Shape of the training data.</param>
/// <returns>The output <see cref="SchemaShape"/>.</returns>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    Host.CheckValue(inputSchema, nameof(inputSchema));

    // defaultName is the role label used in error messages; column.Name is the configured column name,
    // which may differ from the default.
    void CheckColumnsCompatible(SchemaShape.Column column, string defaultName)
    {
        if (!inputSchema.TryFindColumn(column.Name, out var col))
        {
            // Report the missing column's actual name and the parameter that lacks it (inputSchema).
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), defaultName, column.Name);
        }
        if (!column.IsCompatibleWith(col))
        {
            throw Host.Except($"{defaultName} column '{column.Name}' is not compatible");
        }
    }

    if (LabelColumn != null)
    {
        CheckColumnsCompatible(LabelColumn, DefaultColumnNames.Label);
    }
    foreach (var feat in FeatureColumns)
    {
        CheckColumnsCompatible(feat, DefaultColumnNames.Features);
    }
    if (WeightColumn != null)
    {
        CheckColumnsCompatible(WeightColumn, DefaultColumnNames.Weight);
    }

    var outColumns = inputSchema.Columns.ToDictionary(x => x.Name);
    foreach (var col in GetOutputColumnsCore(inputSchema))
    {
        outColumns[col.Name] = col;
    }
    return new SchemaShape(outColumns.Values);
}
/// <summary>
/// Schema propagation for transformers. Returns the output schema of the data, if
/// the input schema is like the one provided.
/// Requires the following columns:
/// <list type="bullet">
/// <item><description>a scalar Single label column;</description></item>
/// <item><description>scalar UInt32 key columns for the matrix column and row indexes.</description></item>
/// </list>
/// </summary>
/// <param name="inputSchema">Shape of the training data.</param>
/// <returns>The input columns plus the columns produced by this estimator.</returns>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    _host.CheckValue(inputSchema, nameof(inputSchema));

    // Verifies that inputSchema has a column named like cachedColumn whose shape is compatible with it.
    void CheckColumnsCompatible(SchemaShape.Column cachedColumn, string columnRole)
    {
        if (!inputSchema.TryFindColumn(cachedColumn.Name, out var col))
        {
            // The argument lacking the column is inputSchema (not the local 'col').
            throw _host.ExceptSchemaMismatch(nameof(inputSchema), columnRole, cachedColumn.Name);
        }
        if (!cachedColumn.IsCompatibleWith(col))
        {
            throw _host.ExceptSchemaMismatch(nameof(inputSchema), columnRole, cachedColumn.Name, cachedColumn.GetTypeString(), col.GetTypeString());
        }
    }

    // Check if label column is good.
    var labelColumn = new SchemaShape.Column(LabelName, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.Single, false);
    CheckColumnsCompatible(labelColumn, "label");

    // Check if columns of matrix's row and column indexes are good. Note that column of IDataView and column of matrix are two different things.
    var matrixColumnIndexColumn = new SchemaShape.Column(MatrixColumnIndexName, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true);
    var matrixRowIndexColumn = new SchemaShape.Column(MatrixRowIndexName, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true);
    CheckColumnsCompatible(matrixColumnIndexColumn, "matrixColumnIndex");
    CheckColumnsCompatible(matrixRowIndexColumn, "matrixRowIndex");

    // Input columns just pass through so that output column dictionary contains all input columns.
    var outColumns = inputSchema.ToDictionary(x => x.Name);

    // Add columns produced by this estimator.
    foreach (var col in GetOutputColumnsCore(inputSchema))
    {
        outColumns[col.Name] = col;
    }
    return new SchemaShape(outColumns.Values);
}
/// <summary>
/// Returns the shape produced by this transform. Every source column must be a text vector (fixed or
/// variable size), and each output is a fixed-size vector of R4 without annotations.
/// </summary>
/// <param name="inputSchema">Shape of the incoming data.</param>
/// <returns>The input columns plus one R4-vector column per configured column.</returns>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    _host.CheckValue(inputSchema, nameof(inputSchema));
    var result = inputSchema.Columns.ToDictionary(x => x.Name);

    foreach (var colInfo in _columns)
    {
        var srcName = colInfo.Input;
        if (!inputSchema.TryFindColumn(srcName, out var src))
            throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", srcName);

        // Accept only text items laid out as a vector (either kind); scalars and non-text are rejected.
        bool isTextVector = src.ItemType is TextType
            && (src.Kind == SchemaShape.Column.VectorKind.VariableVector || src.Kind == SchemaShape.Column.VectorKind.Vector);
        if (!isTextVector)
            throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", srcName, new VectorType(TextType.Instance).ToString(), src.GetTypeString());

        result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
    }

    return new SchemaShape(result.Values);
}
/// <summary>
/// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
/// Used for schema propagation and verification in a pipeline.
/// Each output is a UInt32 key column with the source's kind. It carries a KeyValues annotation and, when
/// the source has them, the source's slot names.
/// </summary>
/// <param name="inputSchema">Shape of the incoming data.</param>
/// <returns>The input columns plus one key column per configured column.</returns>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    _host.CheckValue(inputSchema, nameof(inputSchema));
    var result = inputSchema.ToDictionary(x => x.Name);
    foreach (var colInfo in _columns)
    {
        if (!inputSchema.TryFindColumn(colInfo.InputColumnName, out var col))
        {
            throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName);
        }
        if (!col.ItemType.IsStandardScalar())
        {
            // The column exists but has an unsupported item type. Use the expected/actual overload so the
            // error describes a type mismatch instead of claiming the column was not found.
            throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName,
                "scalar or vector of a standard scalar type", col.GetTypeString());
        }

        SchemaShape metadata;
        // In the event that we are transforming something that is of type key, we will get their type of key value
        // metadata, unless it has none or is not vector in which case we back off to having key values over the item type.
        if (!col.IsKey || !col.Annotations.TryFindColumn(AnnotationUtils.Kinds.KeyValues, out var kv) || kv.Kind != SchemaShape.Column.VectorKind.Vector)
        {
            kv = new SchemaShape.Column(AnnotationUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector,
                colInfo.TextKeyValues ? TextDataViewType.Instance : col.ItemType, col.IsKey);
        }
        Contracts.Assert(kv.IsValid);

        if (col.Annotations.TryFindColumn(AnnotationUtils.Kinds.SlotNames, out var slotMeta))
        {
            metadata = new SchemaShape(new[] { slotMeta, kv });
        }
        else
        {
            metadata = new SchemaShape(new[] { kv });
        }
        result[colInfo.OutputColumnName] = new SchemaShape.Column(colInfo.OutputColumnName, col.Kind, NumberDataViewType.UInt32, true, metadata);
    }
    return new SchemaShape(result.Values);
}
/// <summary>
/// Returns the shape produced by value-to-key mapping. Each output is a U4 key column with the source's kind.
/// It carries a KeyValues annotation and, when the source has them, the source's slot names.
/// </summary>
/// <param name="inputSchema">Shape of the incoming data.</param>
/// <returns>The input columns plus one key column per configured column.</returns>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    _host.CheckValue(inputSchema, nameof(inputSchema));
    var result = inputSchema.Columns.ToDictionary(x => x.Name);
    foreach (var colInfo in _columns)
    {
        if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
        {
            throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input);
        }
        if ((col.ItemType.ItemType.RawKind == default) || !(col.ItemType.IsVector || col.ItemType.IsPrimitive))
        {
            // The column exists but its type is unsupported. Use the expected/actual overload so the
            // error describes a type mismatch instead of claiming the column was not found.
            throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input,
                "scalar or vector of a primitive type", col.GetTypeString());
        }

        SchemaShape metadata;
        // In the event that we are transforming something that is of type key, we will get their type of key value
        // metadata, unless it has none or is not vector in which case we back off to having key values over the item type.
        if (!col.IsKey || !col.Metadata.TryFindColumn(MetadataUtils.Kinds.KeyValues, out var kv) || kv.Kind != SchemaShape.Column.VectorKind.Vector)
        {
            kv = new SchemaShape.Column(MetadataUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector, col.ItemType, col.IsKey);
        }
        Contracts.AssertValue(kv);

        if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotMeta))
        {
            metadata = new SchemaShape(new[] { slotMeta, kv });
        }
        else
        {
            metadata = new SchemaShape(new[] { kv });
        }
        result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, col.Kind, NumberType.U4, true, metadata);
    }
    return new SchemaShape(result.Values);
}
/// <summary>
/// Describes the columns a multiclass trainer adds:
/// <list type="bullet">
/// <item><description>a "Score" vector of R4 annotated with slot names and the trainer-output annotations;</description></item>
/// <item><description>a "PredictedLabel" U4 key scalar, which inherits the label's KeyValues annotation (if any)
/// plus the trainer-output annotations.</description></item>
/// </list>
/// </summary>
/// <param name="inputSchema">Shape of the training data; expected to contain the label column.</param>
/// <returns>The added output columns, Score first.</returns>
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
{
    bool labelFound = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
    Contracts.Assert(labelFound);

    var slotNames = new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false);
    var scoreAnnotations = new List<SchemaShape.Column> { slotNames };
    scoreAnnotations.AddRange(MetadataUtils.GetTrainerOutputMetadata());

    var labelKeyValues = labelCol.Metadata.Columns.Where(x => x.Name == MetadataUtils.Kinds.KeyValues);
    var predLabelAnnotations = new SchemaShape(labelKeyValues.Concat(MetadataUtils.GetTrainerOutputMetadata()));

    var score = new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector,
        NumberType.R4, false, new SchemaShape(scoreAnnotations));
    var predictedLabel = new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar,
        NumberType.U4, true, predLabelAnnotations);

    return new[] { score, predictedLabel };
}
/// <summary>
/// Describes the columns a multiclass trainer adds: a "Score" R4 vector and a "PredictedLabel" U4 key scalar.
/// When a label column is configured, the Score annotations are derived from the label column, and
/// PredictedLabel inherits the label's KeyValues annotation (if any). Without a label, both columns
/// get the generic score-column annotations.
/// </summary>
/// <param name="inputSchema">Shape of the training data.</param>
/// <returns>The added output columns, Score first.</returns>
private SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
{
    if (!LabelColumn.IsValid)
    {
        return new[]
        {
            new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false,
                new SchemaShape(MetadataForScoreColumn())),
            new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true,
                new SchemaShape(MetadataForScoreColumn()))
        };
    }

    bool labelFound = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
    Contracts.Assert(labelFound);

    var labelKeyValues = labelCol.Metadata.Where(x => x.Name == MetadataUtils.Kinds.KeyValues);
    var predictedLabelAnnotations = new SchemaShape(labelKeyValues.Concat(MetadataForScoreColumn()));

    var score = new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false,
        new SchemaShape(MetadataUtils.MetadataForMulticlassScoreColumn(labelCol)));
    var predictedLabel = new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true,
        predictedLabelAnnotations);

    return new[] { score, predictedLabel };
}
/// <summary>
/// Schema propagation for transformers: returns the output shape of the data when the input looks like
/// <paramref name="inputSchema"/>.
/// The following columns must be present and compatible:
/// <list type="bullet">
/// <item><description>the label column;</description></item>
/// <item><description>every feature column;</description></item>
/// <item><description>the weight column, when configured.</description></item>
/// </list>
/// The trainer's output columns are added, replacing same-named input columns.
/// </summary>
/// <param name="inputSchema">Shape of the training data.</param>
/// <returns>The output <see cref="SchemaShape"/>.</returns>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    _host.CheckValue(inputSchema, nameof(inputSchema));

    // The named column must exist and match the expected shape; columnRole is used only in error messages.
    void Require(SchemaShape.Column expected, string columnRole)
    {
        if (!inputSchema.TryFindColumn(expected.Name, out var actual))
            throw _host.ExceptSchemaMismatch(nameof(inputSchema), columnRole, expected.Name);
        if (!expected.IsCompatibleWith(actual))
            throw _host.ExceptSchemaMismatch(nameof(inputSchema), columnRole, expected.Name, expected.GetTypeString(), actual.GetTypeString());
    }

    Require(LabelColumn, "label");
    foreach (var feature in FeatureColumns)
        Require(feature, "feature");
    if (WeightColumn.IsValid)
        Require(WeightColumn, "weight");

    var shapeByName = inputSchema.ToDictionary(x => x.Name);
    foreach (var produced in GetOutputColumnsCore(inputSchema))
        shapeByName[produced.Name] = produced;

    return new SchemaShape(shapeByName.Values);
}