/// <summary>
        /// Schema propagation for this transformer: given an input schema shaped like <paramref name="inputSchema"/>,
        /// returns the schema the transformed data would have. The forecast column is always produced; when an
        /// upper confidence bound column name is configured, the lower and upper bound columns are produced as well
        /// (three output columns in total instead of one).
        /// </summary>
        /// <param name="inputSchema">The input schema; must contain the source column with item type Single.</param>
        /// <returns>The input columns plus the float-vector output column(s).</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));

            var sourceName = _options.Source;
            if (!inputSchema.TryFindColumn(sourceName, out var sourceCol))
            {
                throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", sourceName);
            }
            if (sourceCol.ItemType != NumberDataViewType.Single)
            {
                throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", sourceName, "Single", sourceCol.GetTypeString());
            }

            // Every column this transformer adds is a known-size vector of Single.
            SchemaShape.Column FloatVector(string name) =>
                new SchemaShape.Column(name, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false);

            var columns = inputSchema.ToDictionary(c => c.Name);
            columns[_options.Name] = FloatVector(_options.Name);

            bool wantsConfidenceBounds = !string.IsNullOrEmpty(_options.ConfidenceUpperBoundColumn);
            if (wantsConfidenceBounds)
            {
                columns[_options.ConfidenceLowerBoundColumn] = FloatVector(_options.ConfidenceLowerBoundColumn);
                columns[_options.ConfidenceUpperBoundColumn] = FloatVector(_options.ConfidenceUpperBoundColumn);
            }

            return new SchemaShape(columns.Values);
        }
        /// <summary>
        /// Checks whether this object is consistent with an actual schema shape from a dynamic object,
        /// throwing exceptions if not.
        /// </summary>
        /// <param name="ectx">The context on which to throw exceptions</param>
        /// <param name="shape">The schema shape to check</param>
        public void Check(IExceptionContext ectx, SchemaShape shape)
        {
            Contracts.AssertValue(ectx);
            ectx.AssertValue(shape);

            foreach (var pair in Pairs)
            {
                if (!shape.TryFindColumn(pair.Key, out var col))
                {
                    throw ectx.ExceptParam(nameof(shape), $"Column named '{pair.Key}' was not found");
                }

                var type = GetTypeOrNull(col);

                // When the type is recognized (not null), IsAssignableFromStaticPipeline decides compatibility, which
                // allows for example Key<uint, string> to be considered compatible with Key<uint>.
                // When it is not recognized (null), we cannot verify it directly, but we can at least verify that the
                // statically declared type should not have corresponded to a recognized (standard) type.
                bool incompatible = type != null
                    ? !pair.Value.IsAssignableFromStaticPipeline(type)
                    : IsStandard(ectx, pair.Value);

                // The mismatch is already established above. The previous code re-ran IsAssignableFromStaticPipeline
                // here, which duplicated the test for a non-null type and passed a null type in the other case.
                if (incompatible)
                {
                    throw ectx.ExceptParam(nameof(shape),
                                           $"Column '{pair.Key}' of type '{col.GetTypeString()}' cannot be expressed statically as type '{pair.Value}'.");
                }
            }
        }
        /// <summary>
        ///  Gets the output columns.
        /// </summary>
        /// <param name="inputSchema">The input schema. </param>
        /// <returns>The output <see cref="SchemaShape"/>: every input column, with the columns returned by
        /// <c>GetOutputColumnsCore</c> added or replacing same-named input columns.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            // The label column is validated only when it is valid.
            if (LabelColumn.IsValid)
            {
                if (!inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol))
                {
                    // Report the method parameter; 'labelCol' is a local and meaningless as a parameter name.
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", LabelColumn.Name);
                }

                if (!LabelColumn.IsCompatibleWith(labelCol))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", LabelColumn.Name, LabelColumn.GetTypeString(), labelCol.GetTypeString());
                }
            }

            var outColumns = inputSchema.ToDictionary(x => x.Name);

            // Output columns override any same-named input columns.
            foreach (var col in GetOutputColumnsCore(inputSchema))
            {
                outColumns[col.Name] = col;
            }

            return new SchemaShape(outColumns.Values);
        }
        /// <summary>
        ///  Gets the output columns.
        /// </summary>
        /// <param name="inputSchema">The input schema. </param>
        /// <returns>The output <see cref="SchemaShape"/>: every input column, with the columns returned by
        /// <c>GetOutputColumnsCore</c> added or replacing same-named input columns.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            // The label column is validated only when one is set.
            if (LabelColumn != null)
            {
                if (!inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol))
                {
                    // Report the missing *label* column by its actual name. The previous code used the
                    // PredictedLabel default as both role and column name, and a local as the parameter name.
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", LabelColumn.Name);
                }

                if (!LabelColumn.IsCompatibleWith(labelCol))
                {
                    throw Host.Except($"Label column '{LabelColumn.Name}' is not compatible");
                }
            }

            var outColumns = inputSchema.Columns.ToDictionary(x => x.Name);

            // Output columns override any same-named input columns.
            foreach (var col in GetOutputColumnsCore(inputSchema))
            {
                outColumns[col.Name] = col;
            }

            return new SchemaShape(outColumns.Values);
        }
        /// <summary>
        /// Returns the <see cref="SchemaShape"/> produced by the ONNX transformer for the given input schema.
        /// Every model input must be present as a vector (known-size or variable-size) column whose item type matches
        /// the ML.NET type mapped from the corresponding ONNX input node. Each model output is added, replacing any
        /// same-named input column.
        /// </summary>
        /// <param name="inputSchema">The input schema.</param>
        /// <returns>The input columns plus one column per model output.</returns>
        public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            // A previously built, never-used duplicate of this dictionary was removed.
            var resultDic = inputSchema.ToDictionary(x => x.Name);

            for (var i = 0; i < Transformer.Inputs.Length; i++)
            {
                var input = Transformer.Inputs[i];
                if (!inputSchema.TryFindColumn(input, out var col))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
                }
                if (!(col.Kind == SchemaShape.Column.VectorKind.VariableVector || col.Kind == SchemaShape.Column.VectorKind.Vector))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, nameof(VectorType), col.GetTypeString());
                }

                var inputsInfo = Transformer.Model.ModelInfo.InputsInfo;
                var idx = Transformer.Model.InputNames.IndexOf(input);
                if (idx < 0)
                {
                    throw Host.Except($"Column {input} doesn't match input node names of model.");
                }

                // The column's item type must equal the ML.NET type mapped from the ONNX node's type.
                var inputNodeInfo = inputsInfo[idx];
                var expectedType = OnnxUtils.OnnxToMlNetType(inputNodeInfo.Type);
                if (col.ItemType != expectedType)
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
                }
            }

            for (var i = 0; i < Transformer.Outputs.Length; i++)
            {
                resultDic[Transformer.Outputs[i]] = new SchemaShape.Column(Transformer.Outputs[i],
                    Transformer.OutputTypes[i].IsKnownSizeVector ? SchemaShape.Column.VectorKind.Vector
                    : SchemaShape.Column.VectorKind.VariableVector, Transformer.OutputTypes[i].ItemType, false);
            }
            return new SchemaShape(resultDic.Values);
        }
        // Example #6
        // 0
        /// <summary>
        /// Gets the output <see cref="SchemaShape"/> of the <see cref="IDataView"/> after fitting the calibrator.
        /// Fitting the calibrator adds a scalar float column named "Probability" to the schema. If the input already
        /// has a column with that name, it is replaced in the returned shape.
        /// </summary>
        /// <param name="inputSchema">The input <see cref="SchemaShape"/>.</param>
        /// <returns>The input columns plus the "Probability" column.</returns>
        SchemaShape IEstimator<CalibratorTransformer<TICalibrator>>.GetOutputSchema(SchemaShape inputSchema)
        {
            // Validate the argument up front, as the other estimators do, instead of failing later inside
            // TryFindColumn/ToDictionary.
            Host.CheckValue(inputSchema, nameof(inputSchema));

            // Verifies that a column, if valid, exists in the input schema with a compatible type.
            // Columns whose IsValid is false are skipped.
            void CheckColumnValid(SchemaShape.Column column, string columnRole)
            {
                if (!column.IsValid)
                {
                    return;
                }
                if (!inputSchema.TryFindColumn(column.Name, out var outCol))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), columnRole, column.Name);
                }
                if (!column.IsCompatibleWith(outCol))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), columnRole, column.Name, column.GetTypeString(), outCol.GetTypeString());
                }
            }

            // Check the input schema.
            CheckColumnValid(ScoreColumn, "score");
            CheckColumnValid(WeightColumn, "weight");
            CheckColumnValid(LabelColumn, "label");

            // Create the new Probability column.
            var outColumns = inputSchema.ToDictionary(x => x.Name);

            outColumns[DefaultColumnNames.Probability] = new SchemaShape.Column(DefaultColumnNames.Probability,
                                                                                SchemaShape.Column.VectorKind.Scalar,
                                                                                NumberType.R4,
                                                                                false,
                                                                                new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true)));

            return new SchemaShape(outColumns.Values);
        }
        /// <summary>
        /// <see cref="PretrainedTreeFeaturizationEstimator"/> adds up to three float-vector columns into <paramref name="inputSchema"/>.
        /// Given a feature vector column, the possible added columns are the prediction values of all trees, the leaf IDs the
        /// feature vector falls into, and the paths to those leaves. Each is added only when its column name is not null.
        /// </summary>
        /// <param name="inputSchema">A schema which contains a feature column. Note that feature column name can be specified
        /// by <see cref="OptionsBase.InputColumnName"/>.</param>
        /// <returns>Output <see cref="SchemaShape"/> produced by <see cref="PretrainedTreeFeaturizationEstimator"/>.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Env.CheckValue(inputSchema, nameof(inputSchema));

            // Only the presence of the feature column is verified here.
            if (!inputSchema.TryFindColumn(FeatureColumnName, out _))
            {
                throw Env.ExceptSchemaMismatch(nameof(inputSchema), "input", FeatureColumnName);
            }

            var result = inputSchema.ToDictionary(x => x.Name);

            // Trees, leaves and paths, in that order; a null name means that output is not requested.
            foreach (var outputName in new[] { TreesColumnName, LeavesColumnName, PathsColumnName })
            {
                if (outputName == null)
                {
                    continue;
                }

                result[outputName] = new SchemaShape.Column(outputName,
                    SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false);
            }

            return new SchemaShape(result.Values);
        }
        /// <summary>
        /// Returns the <see cref="SchemaShape"/> produced by the TensorFlow transformer for the given input schema.
        /// Every model input must be present as a vector (known-size or variable-size) column whose item type matches
        /// the ML.NET type mapped from the corresponding TensorFlow input type. Each model output is added, replacing
        /// any same-named input column.
        /// </summary>
        /// <param name="inputSchema">The input schema.</param>
        /// <returns>The input columns plus one column per model output.</returns>
        public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            // A previously built, never-used duplicate of this dictionary was removed.
            var resultDic = inputSchema.Columns.ToDictionary(x => x.Name);

            for (var i = 0; i < Transformer.Inputs.Length; i++)
            {
                var input = Transformer.Inputs[i];
                if (!inputSchema.TryFindColumn(input, out var col))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
                }
                if (!(col.Kind == SchemaShape.Column.VectorKind.VariableVector || col.Kind == SchemaShape.Column.VectorKind.Vector))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, nameof(VectorType), col.GetTypeString());
                }

                // Inputs are matched positionally: TFInputTypes[i] describes Transformer.Inputs[i].
                var expectedType = TensorFlowUtils.Tf2MlNetType(Transformer.TFInputTypes[i]);
                if (col.ItemType != expectedType)
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
                }
            }

            for (var i = 0; i < Transformer.Outputs.Length; i++)
            {
                resultDic[Transformer.Outputs[i]] = new SchemaShape.Column(Transformer.Outputs[i],
                                                                           Transformer.OutputTypes[i].IsKnownSizeVector ? SchemaShape.Column.VectorKind.Vector
                    : SchemaShape.Column.VectorKind.VariableVector, Transformer.OutputTypes[i].ItemType, false);
            }

            return new SchemaShape(resultDic.Values);
        }
        // Example #9
        // 0
        /// <summary>
        /// Returns the schema produced by the text featurizer. Every source column must exist and hold text
        /// (scalar or vector). The output column is a float vector with slot-name metadata, plus an IsNormalized flag
        /// when a vector normalizer is configured. When token output is enabled, a variable-length text vector column
        /// whose name is built from <c>TransformedTextColFormat</c> is added as well.
        /// </summary>
        /// <param name="inputSchema">The input schema.</param>
        /// <returns>The input columns plus the featurizer's output column(s).</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var outputColumns = inputSchema.Columns.ToDictionary(x => x.Name);

            foreach (var srcName in _inputColumns)
            {
                if (!inputSchema.TryFindColumn(srcName, out var srcCol))
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", srcName);
                }
                if (!srcCol.ItemType.IsText)
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", srcName, "scalar or vector of text", srcCol.GetTypeString());
                }
            }

            var featureMetadata = new List<SchemaShape.Column>(2)
            {
                new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false)
            };
            if (AdvancedSettings.VectorNormalizer != TextNormKind.None)
            {
                featureMetadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false));
            }

            outputColumns[OutputColumn] = new SchemaShape.Column(OutputColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false,
                                                                 new SchemaShape(featureMetadata));

            if (AdvancedSettings.OutputTokens)
            {
                var tokensColumnName = string.Format(TransformedTextColFormat, OutputColumn);
                outputColumns[tokensColumnName] = new SchemaShape.Column(tokensColumnName, SchemaShape.Column.VectorKind.VariableVector, TextType.Instance, false);
            }

            return new SchemaShape(outputColumns.Values);
        }
        // Example #10
        // 0
        /// <summary>
        /// Returns the <see cref="SchemaShape"/> produced by the single-input ONNX transformer. The input column must be
        /// a vector (known-size or variable-size) whose item type matches the ML.NET type mapped from the model's first
        /// input node. The output column is a float vector, replacing any same-named input column.
        /// </summary>
        /// <param name="inputSchema">The input schema.</param>
        /// <returns>The input columns plus the model's output column.</returns>
        public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            // A previously built, never-used duplicate of this dictionary was removed.
            var resultDic = inputSchema.Columns.ToDictionary(x => x.Name);

            var input = Transformer.Input;

            if (!inputSchema.TryFindColumn(input, out var col))
            {
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
            }
            if (!(col.Kind == SchemaShape.Column.VectorKind.VariableVector || col.Kind == SchemaShape.Column.VectorKind.Vector))
            {
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, nameof(VectorType), col.GetTypeString());
            }

            // Only the first model input node is consulted; this transformer handles a single input.
            var inputNodeInfo = Transformer.Model.ModelInfo.InputsInfo[0];
            var expectedType  = OnnxUtils.OnnxToMlNetType(inputNodeInfo.Type);

            if (col.ItemType != expectedType)
            {
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
            }

            resultDic[Transformer.Output] = new SchemaShape.Column(Transformer.Output,
                                                                   Transformer.OutputType.IsKnownSizeVector ? SchemaShape.Column.VectorKind.Vector
                : SchemaShape.Column.VectorKind.VariableVector, NumberType.R4, false);

            return new SchemaShape(resultDic.Values);
        }
        // Example #11
        // 0
        /// <summary>
        /// Gets the output <see cref="SchemaShape"/> of the <see cref="IDataView"/> after fitting the calibrator.
        /// Fitting the calibrator adds a scalar float column named "Probability" to the schema. If the input already
        /// has a column with that name, it is replaced in the returned shape.
        /// </summary>
        /// <param name="inputSchema">The input <see cref="SchemaShape"/>.</param>
        /// <returns>The input columns plus the "Probability" column.</returns>
        SchemaShape IEstimator<CalibratorTransformer<TICalibrator>>.GetOutputSchema(SchemaShape inputSchema)
        {
            // Validate the argument up front instead of failing later inside TryFindColumn/ToDictionary.
            Host.CheckValue(inputSchema, nameof(inputSchema));

            // Verifies that a column, if valid, exists in the input schema with a compatible type.
            // Columns whose IsValid is false are skipped.
            void CheckColumnValid(SchemaShape.Column column, string expected)
            {
                if (!column.IsValid)
                {
                    return;
                }
                if (!inputSchema.TryFindColumn(column.Name, out var outCol))
                {
                    throw Host.Except($"{expected} column '{column.Name}' is not found");
                }
                if (!column.IsCompatibleWith(outCol))
                {
                    throw Host.Except($"{expected} column '{column.Name}' is not compatible");
                }
            }

            // Check the input schema.
            CheckColumnValid(ScoreColumn, DefaultColumnNames.Score);
            CheckColumnValid(WeightColumn, DefaultColumnNames.Weight);
            CheckColumnValid(LabelColumn, DefaultColumnNames.Label);
            CheckColumnValid(FeatureColumn, DefaultColumnNames.Features);
            CheckColumnValid(PredictedLabel, DefaultColumnNames.PredictedLabel);

            // Create the new Probability column.
            var outColumns = inputSchema.ToDictionary(x => x.Name);

            outColumns[DefaultColumnNames.Probability] = new SchemaShape.Column(DefaultColumnNames.Probability,
                                                                                SchemaShape.Column.VectorKind.Scalar,
                                                                                NumberType.R4,
                                                                                false,
                                                                                new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true)));

            return new SchemaShape(outColumns.Values);
        }
        // Example #12
        // 0
        /// <summary>
        /// Returns the schema produced by count-based feature selection. Each configured input column must exist and
        /// have a type accepted by <c>CountFeatureSelectionUtils.IsValidColumnType</c>. Each output column keeps the
        /// input's kind and item type. It carries over the input's slot names and categorical slot ranges (when
        /// present) and adds an IsNormalized flag.
        /// </summary>
        /// <param name="inputSchema">The input schema.</param>
        /// <returns>The input columns plus one output column per configured pair.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var outputs = inputSchema.ToDictionary(x => x.Name);

            foreach (var colPair in _columns)
            {
                var srcName = colPair.Input;
                if (!inputSchema.TryFindColumn(srcName, out var srcCol))
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", srcName);
                }
                if (!CountFeatureSelectionUtils.IsValidColumnType(srcCol.ItemType))
                {
                    throw _host.ExceptUserArg(nameof(inputSchema), "Column '{0}' does not have compatible type. Expected types are float, double or string.", srcName);
                }

                // Pass through whichever of these metadata kinds the source has, in this order.
                var passThroughKinds = new[] { MetadataUtils.Kinds.SlotNames, MetadataUtils.Kinds.CategoricalSlotRanges };
                var annotations = new List<SchemaShape.Column>();
                foreach (var kind in passThroughKinds)
                {
                    if (srcCol.Metadata.TryFindColumn(kind, out var meta))
                    {
                        annotations.Add(meta);
                    }
                }
                annotations.Add(new SchemaShape.Column(MetadataUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false));

                outputs[colPair.Output] = new SchemaShape.Column(colPair.Output, srcCol.Kind, srcCol.ItemType, false, new SchemaShape(annotations.ToArray()));
            }

            return new SchemaShape(outputs.Values);
        }
        /// <summary>
        /// Returns the schema produced by the expression transform. For each configured column, every input column must
        /// exist. The expression is parsed and bound against the input types to obtain a primitive result type. When one
        /// input is a vector, the output takes that vector's kind and passes its slot names through; otherwise the
        /// output is scalar.
        /// </summary>
        /// <param name="inputSchema">The input schema.</param>
        /// <returns>The input columns plus one output column per configured expression.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var columnDictionary = inputSchema.ToDictionary(x => x.Name);

            for (int i = 0; i < _columns.Length; i++)
            {
                var column = _columns[i];
                var inputNames = column.InputColumnNames;

                foreach (var inputName in inputNames)
                {
                    if (!inputSchema.TryFindColumn(inputName, out _))
                    {
                        throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputName);
                    }
                }

                // Make sure there is at most one vector valued source column; ivec == -1 means none is a vector.
                var inputTypes = new DataViewType[inputNames.Length];
                var ivec = FindVectorInputColumn(_host, inputNames, inputSchema, inputTypes);
                var node = ParseAndBindLambda(_host, column.Expression, ivec, inputTypes, out _);

                var typeRes = node.ResultType;
                _host.Assert(typeRes is PrimitiveDataViewType);

                var outputVectorKind = SchemaShape.Column.VectorKind.Scalar;
                var metadata = new List<SchemaShape.Column>();
                if (ivec != -1)
                {
                    // The output follows the vector input's kind and inherits its slot names, if any.
                    inputSchema.TryFindColumn(inputNames[ivec], out var vectorCol);
                    outputVectorKind = vectorCol.Kind;
                    if (vectorCol.HasSlotNames())
                    {
                        var found = vectorCol.Annotations.TryFindColumn(AnnotationUtils.Kinds.SlotNames, out var slotNames);
                        _host.Assert(found);
                        metadata.Add(slotNames);
                    }
                }

                columnDictionary[column.Name] = new SchemaShape.Column(column.Name, outputVectorKind, typeRes, false, new SchemaShape(metadata));
            }

            return new SchemaShape(columnDictionary.Values);
        }
        /// <summary>
        /// Returns the schema produced by this transform. The label column must exist. Each configured input must be a
        /// scalar or known-size vector of U4 key type. Each output is a float vector. It gets a slot-names annotation
        /// when the source is scalar, or inherits the source's slot names when it has them.
        /// </summary>
        /// <param name="inputSchema">The input schema.</param>
        /// <returns>The input columns plus one output column per configured column.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var result = inputSchema.ToDictionary(x => x.Name);

            // Only the presence of the label column is checked here.
            if (!inputSchema.TryFindColumn(_labelColumnName, out _))
            {
                throw _host.ExceptSchemaMismatch(nameof(inputSchema), "label", _labelColumnName);
            }

            foreach (var colInfo in _columns)
            {
                var srcName = colInfo.InputColumnName;
                if (!inputSchema.TryFindColumn(srcName, out var srcCol))
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", srcName);
                }
                if (srcCol.Kind == SchemaShape.Column.VectorKind.VariableVector)
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", srcName, "known-size vector or scalar", srcCol.GetTypeString());
                }

                bool isU4Key = srcCol.IsKey && srcCol.ItemType.Equals(NumberDataViewType.UInt32);
                if (!isU4Key)
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", srcName, "vector or scalar of U4 key type", srcCol.GetTypeString());
                }

                // Scalar sources always get slot names; vector sources only when they already carry them.
                var slotNameAnnotations = new List<SchemaShape.Column>();
                if (srcCol.Kind == SchemaShape.Column.VectorKind.Scalar)
                {
                    slotNameAnnotations.Add(new SchemaShape.Column(AnnotationUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false));
                }
                else if (srcCol.Annotations.TryFindColumn(AnnotationUtils.Kinds.SlotNames, out var slotMeta))
                {
                    slotNameAnnotations.Add(slotMeta);
                }

                result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false,
                                                              new SchemaShape(slotNameAnnotations));
            }

            return new SchemaShape(result.Values);
        }
        /// <summary>
        /// Returns the columns this trainer adds: a single scalar float "Score" column with trainer-output metadata.
        /// </summary>
        private SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
        {
            // Sanity check: the label column is expected to be present in the input schema.
            var labelFound = inputSchema.TryFindColumn(LabelName, out _);
            Contracts.Assert(labelFound);

            var scoreMetadata = new SchemaShape(MetadataUtils.GetTrainerOutputMetadata());
            var score = new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, scoreMetadata);

            return new[] { score };
        }
        // Example #16
        // 0
        /// <summary>
        /// Returns the columns a binary classifier adds. These are three scalar columns: a float Score, a float
        /// Probability (with the probability variant of the trainer-output annotations) and a Boolean PredictedLabel.
        /// </summary>
        private SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
        {
            // Sanity check: the label column is expected to be present in the input schema.
            var labelFound = inputSchema.TryFindColumn(LabelColumn.Name, out _);
            Contracts.Assert(labelFound);

            var scalar = SchemaShape.Column.VectorKind.Scalar;
            var score = new SchemaShape.Column(DefaultColumnNames.Score, scalar, NumberDataViewType.Single, false,
                                               new SchemaShape(AnnotationUtils.GetTrainerOutputAnnotation()));
            var probability = new SchemaShape.Column(DefaultColumnNames.Probability, scalar, NumberDataViewType.Single, false,
                                                     new SchemaShape(AnnotationUtils.GetTrainerOutputAnnotation(true)));
            var predictedLabel = new SchemaShape.Column(DefaultColumnNames.PredictedLabel, scalar, BooleanDataViewType.Instance, false,
                                                        new SchemaShape(AnnotationUtils.GetTrainerOutputAnnotation()));

            return new[] { score, probability, predictedLabel };
        }
        /// <summary>
        /// Returns the columns a multiclass trainer adds. The first is a float-vector Score column. The second is a
        /// U4 key PredictedLabel column carrying the label's key values (if any) plus the trainer-output metadata.
        /// </summary>
        protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
        {
            var labelFound = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
            Contracts.Assert(labelFound);

            var labelKeyValues = labelCol.Metadata.Columns.Where(x => x.Name == MetadataUtils.Kinds.KeyValues);
            var predictedLabelMetadata = new SchemaShape(labelKeyValues.Concat(MetadataUtils.GetTrainerOutputMetadata()));

            var score = new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false,
                                               new SchemaShape(MetadataUtils.GetTrainerOutputMetadata()));
            var predictedLabel = new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true,
                                                        predictedLabelMetadata);

            return new[] { score, predictedLabel };
        }
        // Example #18
        // 0
        /// <summary>
        /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
        /// Used for schema propagation and verification in a pipeline.
        /// </summary>
        /// <param name="inputSchema">The input schema.</param>
        /// <returns>The input columns plus one column per model output (replacing same-named input columns).</returns>
        public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            // A previously built, never-used duplicate of this dictionary was removed.
            var resultDic = inputSchema.ToDictionary(x => x.Name);

            // This loop checks that every input column needed by the underlying transformer can be found in inputSchema.
            // Since ML.NET can only produce tensors (scalars are converted to tensors with shape [1] before being fed
            // into ONNXRuntime), the bridge code in the ONNX transformer assumes that all inputs are tensors.
            for (var i = 0; i < Transformer.Inputs.Length; i++)
            {
                // Name of the i-th IDataView input column consumed by the underlying ONNX transformer.
                var input = Transformer.Inputs[i];

                // Make sure inputSchema contains the i-th input column.
                if (!inputSchema.TryFindColumn(input, out var col))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
                }

                // Reject variable-length vectors; scalars and known-size vectors are accepted.
                if (col.Kind == SchemaShape.Column.VectorKind.VariableVector)
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, "vector", col.GetTypeString());
                }

                var inputsInfo = Transformer.Model.ModelInfo.InputsInfo;
                var idx        = Transformer.Model.ModelInfo.InputNames.IndexOf(input);
                if (idx < 0)
                {
                    throw Host.Except($"Column {input} doesn't match input node names of model.");
                }

                // The cast assumes the node's DataViewType is a VectorDataViewType; any other type would make this
                // cast throw InvalidCastException.
                var inputNodeInfo = inputsInfo[idx];
                var expectedType  = ((VectorDataViewType)inputNodeInfo.DataViewType).ItemType;
                if (col.ItemType != expectedType)
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
                }
            }

            for (var i = 0; i < Transformer.Outputs.Length; i++)
            {
                resultDic[Transformer.Outputs[i]] = new SchemaShape.Column(Transformer.Outputs[i],
                                                                           Transformer.OutputTypes[i].IsKnownSizeVector() ? SchemaShape.Column.VectorKind.Vector
                    : SchemaShape.Column.VectorKind.VariableVector, Transformer.OutputTypes[i].GetItemType(), false);
            }

            return new SchemaShape(resultDic.Values);
        }
        /// <summary>
        /// Verifies the trainer's input columns against <paramref name="inputSchema"/>. The feature column is always
        /// required and must be compatible. The weight column is checked the same way when valid. The label column,
        /// when valid, must exist, and its type is then validated by <c>CheckLabelCompatible</c>.
        /// </summary>
        private void CheckInputSchema(SchemaShape inputSchema)
        {
            // Throws unless 'expected' is present in inputSchema with a compatible type; 'role' names it in errors.
            void CheckPresentAndCompatible(SchemaShape.Column expected, string role)
            {
                if (!inputSchema.TryFindColumn(expected.Name, out var actual))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), role, expected.Name);
                }
                if (!expected.IsCompatibleWith(actual))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), role, expected.Name,
                                                    expected.GetTypeString(), actual.GetTypeString());
                }
            }

            CheckPresentAndCompatible(FeatureColumn, "feature");

            if (WeightColumn.IsValid)
            {
                CheckPresentAndCompatible(WeightColumn, "weight");
            }

            // Labels are special: trainers accept different label types, so only presence is checked here and
            // the type check is delegated to CheckLabelCompatible.
            if (LabelColumn.IsValid)
            {
                if (!inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", LabelColumn.Name);
                }
                CheckLabelCompatible(labelCol);
            }
        }
        /// <summary>
        /// Describes the two columns added by this trainer: a single-precision vector score column
        /// annotated for multiclass scoring, and a key-typed (UInt32) predicted-label column.
        /// </summary>
        /// <param name="inputSchema">Input shape; its label column is asserted to be present.</param>
        /// <returns>The score column followed by the predicted-label column.</returns>
        private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
        {
            bool labelFound = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
            Contracts.Assert(labelFound);

            // The predicted label reuses the label's KeyValues annotation (if any) plus the trainer-output annotations.
            var labelKeyValues = labelCol.Annotations.Where(x => x.Name == AnnotationUtils.Kinds.KeyValues);
            var predLabelMetadata = new SchemaShape(labelKeyValues.Concat(AnnotationUtils.GetTrainerOutputAnnotation()));

            var scoreMetadata = new SchemaShape(AnnotationUtils.AnnotationsForMulticlassScoreColumn(labelCol));

            var scoreColumn = new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector,
                                                     NumberDataViewType.Single, false, scoreMetadata);
            var predictedLabelColumn = new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar,
                                                              NumberDataViewType.UInt32, true, predLabelMetadata);

            return new[] { scoreColumn, predictedLabelColumn };
        }
示例#21
0
        /// <summary>
        /// Computes the output schema: all input columns pass through, and a vector-of-R8 column named
        /// <c>_args.Name</c> carrying a text SlotNames annotation is added (replacing any input column of that name).
        /// </summary>
        /// <param name="inputSchema">Shape of the incoming data; must contain an R4 column named <c>_args.Source</c>.</param>
        /// <returns>The propagated output schema shape.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));

            // The source column must exist and hold R4 items.
            if (!inputSchema.TryFindColumn(_args.Source, out var sourceCol))
            {
                throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _args.Source);
            }
            if (sourceCol.ItemType != NumberType.R4)
            {
                throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _args.Source, NumberType.R4.ToString(), sourceCol.GetTypeString());
            }

            var slotNames = new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false);
            var outputAnnotations = new List<SchemaShape.Column> { slotNames };

            var columnsByName = inputSchema.Columns.ToDictionary(c => c.Name);
            columnsByName[_args.Name] = new SchemaShape.Column(
                _args.Name, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(outputAnnotations));

            return new SchemaShape(columnsByName.Values);
        }
示例#22
0
        /// <summary>
        /// Returns the schema that would be produced by the transformation.
        /// Every input column passes through unchanged; each configured output column copies the
        /// vector kind, item type and key-ness of its input column and carries no annotations.
        /// </summary>
        /// <param name="inputSchema">Shape of the incoming data.</param>
        /// <returns>The propagated output schema shape.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var outputColumns = inputSchema.ToDictionary(x => x.Name);

            foreach (var info in _infos)
            {
                if (!inputSchema.TryFindColumn(info.Input, out var inputCol))
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", info.Input);
                }

                // TestColumn yields a non-null explanation when the item type is not supported.
                var failure = VectorWhiteningTransformer.TestColumn(inputCol.ItemType);
                if (failure != null)
                {
                    throw _host.ExceptUserArg(nameof(inputSchema), failure);
                }

                outputColumns[info.Output] = new SchemaShape.Column(info.Output, inputCol.Kind, inputCol.ItemType, inputCol.IsKey, null);
            }

            return new SchemaShape(outputColumns.Values);
        }
        /// <summary>
        /// Validates the label (if set), feature and weight (if set) columns against
        /// <paramref name="inputSchema"/> and returns the input columns extended/overwritten by
        /// the columns produced by <c>GetOutputColumnsCore</c>.
        /// </summary>
        /// <param name="inputSchema">Shape of the data the estimator would be fit on.</param>
        /// <returns>The propagated output schema shape.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            // Throws when the column named by `column` is missing from the input or is present with an
            // incompatible shape. `columnRole` (e.g. DefaultColumnNames.Label) only describes the column in messages.
            void CheckColumnsCompatible(SchemaShape.Column column, string columnRole)
            {
                if (!inputSchema.TryFindColumn(column.Name, out var col))
                {
                    // Report the column's actual (possibly user-chosen) name rather than the role's default name,
                    // and name the real parameter (inputSchema) rather than a local variable.
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), columnRole, column.Name);
                }

                if (!column.IsCompatibleWith(col))
                {
                    throw Host.Except($"{columnRole} column '{column.Name}' is not compatible");
                }
            }

            if (LabelColumn != null)
            {
                CheckColumnsCompatible(LabelColumn, DefaultColumnNames.Label);
            }

            foreach (var feat in FeatureColumns)
            {
                CheckColumnsCompatible(feat, DefaultColumnNames.Features);
            }

            if (WeightColumn != null)
            {
                CheckColumnsCompatible(WeightColumn, DefaultColumnNames.Weight);
            }

            // Input columns pass through; produced columns replace any input column with the same name.
            var outColumns = inputSchema.Columns.ToDictionary(x => x.Name);

            foreach (var col in GetOutputColumnsCore(inputSchema))
            {
                outColumns[col.Name] = col;
            }

            return new SchemaShape(outColumns.Values);
        }
示例#24
0
        /// <summary>
        /// Schema propagation for transformers. Returns the output schema of the data, if
        /// the input schema is like the one provided.
        /// </summary>
        /// <param name="inputSchema">
        /// Shape of the training data; must contain a scalar Single label column and scalar UInt32 key
        /// columns for the matrix column index and matrix row index.
        /// </param>
        /// <returns>All input columns plus the columns produced by <c>GetOutputColumnsCore</c>.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));

            // Throws when the column named by `cachedColumn` is missing or incompatible with the expected shape.
            void CheckColumnsCompatible(SchemaShape.Column cachedColumn, string columnRole)
            {
                if (!inputSchema.TryFindColumn(cachedColumn.Name, out var col))
                {
                    // Use the actual parameter name, consistent with the type-mismatch branch below
                    // (previously this passed nameof(col), the name of a local variable).
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), columnRole, cachedColumn.Name);
                }

                if (!cachedColumn.IsCompatibleWith(col))
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), columnRole, cachedColumn.Name,
                                                     cachedColumn.GetTypeString(), col.GetTypeString());
                }
            }

            // Check if label column is good.
            var labelColumn = new SchemaShape.Column(LabelName, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.Single, false);

            CheckColumnsCompatible(labelColumn, "label");

            // Check if columns of matrix's row and column indexes are good. Note that column of IDataView and column of matrix are two different things.
            var matrixColumnIndexColumn = new SchemaShape.Column(MatrixColumnIndexName, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true);
            var matrixRowIndexColumn = new SchemaShape.Column(MatrixRowIndexName, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true);

            CheckColumnsCompatible(matrixColumnIndexColumn, "matrixColumnIndex");
            CheckColumnsCompatible(matrixRowIndexColumn, "matrixRowIndex");

            // Input columns just pass through so that output column dictionary contains all input columns.
            var outColumns = inputSchema.ToDictionary(x => x.Name);

            // Add columns produced by this estimator.
            foreach (var col in GetOutputColumnsCore(inputSchema))
            {
                outColumns[col.Name] = col;
            }

            return new SchemaShape(outColumns.Values);
        }
示例#25
0
        /// <summary>
        /// Computes the output schema: input columns pass through, and each configured output column
        /// is a known-size vector of R4. Every source column must be a vector (fixed or variable size) of text.
        /// </summary>
        /// <param name="inputSchema">Shape of the incoming data.</param>
        /// <returns>The propagated output schema shape.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var columnsByName = inputSchema.Columns.ToDictionary(x => x.Name);

            foreach (var info in _columns)
            {
                if (!inputSchema.TryFindColumn(info.Input, out var source))
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", info.Input);
                }

                // Accept only text items laid out as a vector (either known or variable size).
                bool isTextVector = source.ItemType is TextType
                    && (source.Kind == SchemaShape.Column.VectorKind.Vector || source.Kind == SchemaShape.Column.VectorKind.VariableVector);
                if (!isTextVector)
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", info.Input, new VectorType(TextType.Instance).ToString(), source.GetTypeString());
                }

                columnsByName[info.Output] = new SchemaShape.Column(info.Output, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
            }

            return new SchemaShape(columnsByName.Values);
        }
示例#26
0
        /// <summary>
        /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
        /// Used for schema propagation and verification in a pipeline.
        /// </summary>
        /// <param name="inputSchema">Shape of the incoming data.</param>
        /// <returns>
        /// All input columns, plus one key-typed (UInt32) output column per configured column, with the same
        /// vector kind as its input and a KeyValues annotation (plus SlotNames when the input has them).
        /// </returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var result = inputSchema.ToDictionary(x => x.Name);

            foreach (var colInfo in _columns)
            {
                if (!inputSchema.TryFindColumn(colInfo.InputColumnName, out var col))
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName);
                }

                if (!col.ItemType.IsStandardScalar())
                {
                    // The column exists but has an unsupported item type. Use the type-mismatch overload so the
                    // error reports the actual type instead of claiming the column could not be found.
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName,
                                                     "standard scalar or vector of standard scalars", col.GetTypeString());
                }

                // In the event that we are transforming something that is of type key, we will get their type of key value
                // metadata, unless it has none or is not vector in which case we back off to having key values over the item type
                // (or over text, when TextKeyValues is requested).
                if (!col.IsKey || !col.Annotations.TryFindColumn(AnnotationUtils.Kinds.KeyValues, out var kv) || kv.Kind != SchemaShape.Column.VectorKind.Vector)
                {
                    kv = new SchemaShape.Column(AnnotationUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector,
                                                colInfo.TextKeyValues ? TextDataViewType.Instance : col.ItemType, col.IsKey);
                }
                Contracts.Assert(kv.IsValid);

                // Carry the input's slot names over when present.
                var metadata = col.Annotations.TryFindColumn(AnnotationUtils.Kinds.SlotNames, out var slotMeta)
                    ? new SchemaShape(new[] { slotMeta, kv })
                    : new SchemaShape(new[] { kv });

                result[colInfo.OutputColumnName] = new SchemaShape.Column(colInfo.OutputColumnName, col.Kind, NumberDataViewType.UInt32, true, metadata);
            }

            return new SchemaShape(result.Values);
        }
示例#27
0
        /// <summary>
        /// Computes the output schema: input columns pass through, and each configured output column is a
        /// U4 key column with the same vector kind as its input, annotated with KeyValues (plus SlotNames
        /// when the input has them).
        /// </summary>
        /// <param name="inputSchema">Shape of the incoming data.</param>
        /// <returns>The propagated output schema shape.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var result = inputSchema.Columns.ToDictionary(x => x.Name);

            foreach (var colInfo in _columns)
            {
                if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input);
                }

                if ((col.ItemType.ItemType.RawKind == default) || !(col.ItemType.IsVector || col.ItemType.IsPrimitive))
                {
                    // The column exists but its type is unsupported. Report the actual type rather than the
                    // misleading "could not find column" message produced by the three-argument overload.
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input,
                                                     "primitive or vector of primitives", col.GetTypeString());
                }

                // In the event that we are transforming something that is of type key, we will get their type of key value
                // metadata, unless it has none or is not vector in which case we back off to having key values over the item type.
                if (!col.IsKey || !col.Metadata.TryFindColumn(MetadataUtils.Kinds.KeyValues, out var kv) || kv.Kind != SchemaShape.Column.VectorKind.Vector)
                {
                    kv = new SchemaShape.Column(MetadataUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector,
                                                col.ItemType, col.IsKey);
                }
                Contracts.AssertValue(kv);

                // Carry the input's slot names over when present.
                var metadata = col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotMeta)
                    ? new SchemaShape(new[] { slotMeta, kv })
                    : new SchemaShape(new[] { kv });

                result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, col.Kind, NumberType.U4, true, metadata);
            }

            return new SchemaShape(result.Values);
        }
        /// <summary>
        /// Describes the columns added by this trainer: an R4 vector score column with text slot names
        /// plus trainer-output metadata, and a U4 key predicted-label column whose metadata is the label's
        /// KeyValues (if any) followed by the trainer-output metadata.
        /// </summary>
        /// <param name="inputSchema">Input shape; its label column is asserted to be present.</param>
        /// <returns>The score column followed by the predicted-label column.</returns>
        protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
        {
            bool labelFound = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
            Contracts.Assert(labelFound);

            var slotNames = new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false);
            var scoreMetadata = new List<SchemaShape.Column> { slotNames };
            scoreMetadata.AddRange(MetadataUtils.GetTrainerOutputMetadata());

            var labelKeyValues = labelCol.Metadata.Columns.Where(m => m.Name == MetadataUtils.Kinds.KeyValues);
            var predLabelMetadata = new SchemaShape(labelKeyValues.Concat(MetadataUtils.GetTrainerOutputMetadata()));

            var scoreColumn = new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector,
                                                     NumberType.R4, false, new SchemaShape(scoreMetadata));
            var predictedLabelColumn = new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar,
                                                              NumberType.U4, true, predLabelMetadata);

            return new[] { scoreColumn, predictedLabelColumn };
        }
        /// <summary>
        /// Describes the score (R4 vector) and predicted-label (U4 key) columns produced by this trainer.
        /// With a valid label column, the score uses multiclass score metadata derived from the label and the
        /// predicted label carries the label's KeyValues plus score metadata; otherwise both use score metadata only.
        /// </summary>
        /// <param name="inputSchema">Input shape; when LabelColumn is valid, its label is asserted to be present.</param>
        /// <returns>The score column followed by the predicted-label column.</returns>
        private SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
        {
            if (!LabelColumn.IsValid)
            {
                // No label column to take key values from: both outputs get the plain score metadata.
                return new[]
                {
                    new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, new SchemaShape(MetadataForScoreColumn())),
                    new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true, new SchemaShape(MetadataForScoreColumn()))
                };
            }

            bool labelFound = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
            Contracts.Assert(labelFound);

            var labelKeyValues = labelCol.Metadata.Where(x => x.Name == MetadataUtils.Kinds.KeyValues);
            var predictedLabelMetadata = new SchemaShape(labelKeyValues.Concat(MetadataForScoreColumn()));
            var scoreMetadata = new SchemaShape(MetadataUtils.MetadataForMulticlassScoreColumn(labelCol));

            return new[]
            {
                new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, scoreMetadata),
                new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true, predictedLabelMetadata)
            };
        }
示例#30
0
        /// <summary>
        /// Schema propagation for transformers: validates the label, feature and (when valid) weight
        /// columns against <paramref name="inputSchema"/> and returns the schema the fitted transformer
        /// would produce.
        /// </summary>
        /// <param name="inputSchema">Shape of the data the estimator would be fit on.</param>
        /// <returns>The input columns, extended/overwritten by the columns from <c>GetOutputColumnsCore</c>.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));

            // Throws unless `required` exists in the input with a compatible shape.
            void EnsureCompatible(SchemaShape.Column required, string role)
            {
                bool found = inputSchema.TryFindColumn(required.Name, out var actual);
                if (!found)
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), role, required.Name);
                }
                if (required.IsCompatibleWith(actual))
                {
                    return;
                }
                throw _host.ExceptSchemaMismatch(nameof(inputSchema), role, required.Name,
                                                 required.GetTypeString(), actual.GetTypeString());
            }

            EnsureCompatible(LabelColumn, "label");

            foreach (var featureColumn in FeatureColumns)
            {
                EnsureCompatible(featureColumn, "feature");
            }

            if (WeightColumn.IsValid)
            {
                EnsureCompatible(WeightColumn, "weight");
            }

            // Start from the input columns, then let produced columns replace same-named ones.
            var schemaByName = inputSchema.ToDictionary(c => c.Name);
            var producedColumns = GetOutputColumnsCore(inputSchema);
            foreach (var produced in producedColumns)
            {
                schemaByName[produced.Name] = produced;
            }

            return new SchemaShape(schemaByName.Values);
        }