Ejemplo n.º 1
0
        /// <summary>
        /// Checks that every input column consumed by the transformer is present in
        /// <paramref name="inputSchema"/> as a vector whose item type matches what the model expects,
        /// and returns the input schema extended with the transformer's output columns.
        /// </summary>
        /// <param name="inputSchema">Schema shape of the incoming data.</param>
        /// <returns>All input columns plus one vector column per transformer output; an output whose
        /// name equals an existing column name replaces that column.</returns>
        public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));
            var resultDic = inputSchema.ToDictionary(x => x.Name);

            for (var i = 0; i < Transformer.Inputs.Length; i++)
            {
                var input = Transformer.Inputs[i];
                if (!inputSchema.TryFindColumn(input, out var col))
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);

                // Only fixed-size and variable-size vector columns are accepted as inputs.
                if (col.Kind != SchemaShape.Column.VectorKind.VariableVector && col.Kind != SchemaShape.Column.VectorKind.Vector)
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, nameof(VectorType), col.GetTypeString());

                // Locate the model input node that carries the same name as the column.
                var inputsInfo = Transformer.Model.ModelInfo.InputsInfo;
                var idx = Transformer.Model.InputNames.IndexOf(input);
                if (idx < 0)
                    throw Host.Except($"Column {input} doesn't match input node names of model.");

                // The column's item type must equal the type derived from the model's input node.
                var inputNodeInfo = inputsInfo[idx];
                var expectedType = OnnxUtils.OnnxToMlNetType(inputNodeInfo.Type);
                if (col.ItemType != expectedType)
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
            }

            for (var i = 0; i < Transformer.Outputs.Length; i++)
            {
                var outputName = Transformer.Outputs[i];
                var outputType = Transformer.OutputTypes[i];
                var outputKind = outputType.IsKnownSizeVector
                    ? SchemaShape.Column.VectorKind.Vector
                    : SchemaShape.Column.VectorKind.VariableVector;

                resultDic[outputName] = new SchemaShape.Column(outputName, outputKind, outputType.ItemType, false);
            }
            return new SchemaShape(resultDic.Values);
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
        /// Used for schema propagation and verification in a pipeline.
        /// </summary>
        /// <param name="inputSchema">Schema shape of the incoming data.</param>
        /// <returns>All input columns plus one vector column per transformer output; an output whose
        /// name equals an existing column name replaces that column.</returns>
        public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));
            var resultDic = inputSchema.ToDictionary(x => x.Name);

            // This loop checks that all input columns needed by the underlying transformer can be found
            // in inputSchema.
            // Since ML.NET can only produce tensors (scalars are converted to tensors with shape [1] before
            // ML.NET feeds them into ONNXRuntime), the bridge code in the ONNX transformer assumes that all
            // inputs are tensors.
            for (var i = 0; i < Transformer.Inputs.Length; i++)
            {
                // Get the i-th IDataView input column's name in the underlying ONNX transformer.
                var input = Transformer.Inputs[i];

                // Make sure inputSchema contains the i-th input column.
                if (!inputSchema.TryFindColumn(input, out var col))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
                }

                // Variable-size vector columns are rejected as inputs.
                if (col.Kind == SchemaShape.Column.VectorKind.VariableVector)
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, "vector", col.GetTypeString());
                }

                // Locate the model input node that carries the same name as the column.
                var inputsInfo = Transformer.Model.ModelInfo.InputsInfo;
                var idx        = Transformer.Model.ModelInfo.InputNames.IndexOf(input);
                if (idx < 0)
                {
                    throw Host.Except($"Column {input} doesn't match input node names of model.");
                }

                // The column's item type must equal the item type of the model node's vector type.
                var inputNodeInfo = inputsInfo[idx];
                var expectedType  = ((VectorDataViewType)inputNodeInfo.DataViewType).ItemType;
                if (col.ItemType != expectedType)
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
                }
            }

            for (var i = 0; i < Transformer.Outputs.Length; i++)
            {
                var outputName = Transformer.Outputs[i];
                var outputType = Transformer.OutputTypes[i];
                var outputKind = outputType.IsKnownSizeVector()
                    ? SchemaShape.Column.VectorKind.Vector
                    : SchemaShape.Column.VectorKind.VariableVector;

                resultDic[outputName] = new SchemaShape.Column(outputName, outputKind, outputType.GetItemType(), false);
            }

            return new SchemaShape(resultDic.Values);
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Gets the output columns: the input columns plus whatever
        /// <c>GetOutputColumnsCore</c> yields, after validating the label column when one is configured.
        /// </summary>
        /// <param name="inputSchema">The input schema. </param>
        /// <returns>The output <see cref="SchemaShape"/></returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            if (LabelColumn.IsValid)
            {
                if (!inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol))
                {
                    // Report the parameter that was actually inspected, the role of the missing
                    // column, and the name that was searched for.
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", LabelColumn.Name);
                }

                if (!LabelColumn.IsCompatibleWith(labelCol))
                {
                    throw Host.Except($"Label column '{LabelColumn.Name}' is not compatible");
                }
            }

            var outColumns = inputSchema.ToDictionary(x => x.Name);

            // Produced columns override input columns that share their name.
            foreach (var col in GetOutputColumnsCore(inputSchema))
            {
                outColumns[col.Name] = col;
            }

            return new SchemaShape(outColumns.Values);
        }
        /// <summary>
        /// <see cref="PretrainedTreeFeaturizationEstimator"/> adds up to three float-vector columns to
        /// <paramref name="inputSchema"/>: one for each of the trees, leaves and paths column names that
        /// is not <see langword="null"/>.
        /// </summary>
        /// <param name="inputSchema">A schema which must contain the feature column. Note that feature column name can be specified
        /// by <see cref="OptionsBase.InputColumnName"/>.</param>
        /// <returns>Output <see cref="SchemaShape"/> produced by <see cref="PretrainedTreeFeaturizationEstimator"/>.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Env.CheckValue(inputSchema, nameof(inputSchema));

            // Only the presence of the feature column is verified here.
            if (!inputSchema.TryFindColumn(FeatureColumnName, out _))
            {
                throw Env.ExceptSchemaMismatch(nameof(inputSchema), "input", FeatureColumnName);
            }

            var columnsByName = inputSchema.ToDictionary(x => x.Name);

            // Registers a vector-of-Single column under the given name, skipping unset names.
            void AddSingleVector(string outputName)
            {
                if (outputName == null)
                {
                    return;
                }

                columnsByName[outputName] = new SchemaShape.Column(
                    outputName, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false);
            }

            AddSingleVector(TreesColumnName);
            AddSingleVector(LeavesColumnName);
            AddSingleVector(PathsColumnName);

            return new SchemaShape(columnsByName.Values);
        }
Ejemplo n.º 5
0
        // Passes every incoming column through and sets one scalar Boolean column named by
        // IsRowImputedColumnName (replacing any incoming column with that name).
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            var columnsByName = inputSchema.ToDictionary(x => x.Name);

            var imputedFlag = new SchemaShape.Column(
                IsRowImputedColumnName,
                SchemaShape.Column.VectorKind.Scalar,
                BooleanDataViewType.Instance,
                false);
            columnsByName[IsRowImputedColumnName] = imputedFlag;

            return new SchemaShape(columnsByName.Values);
        }
Ejemplo n.º 6
0
        /// <summary>
        /// Schema propagation for transformers.
        /// Returns the output schema of the data, if the input schema is like the one provided.
        /// The source column must exist and have item type Single. One vector-of-Single output column
        /// is always produced; the two confidence bound columns are added as well when an upper bound
        /// column name is configured.
        /// </summary>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));

            if (!inputSchema.TryFindColumn(_options.Source, out var sourceColumn))
            {
                throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _options.Source);
            }
            if (sourceColumn.ItemType != NumberDataViewType.Single)
            {
                throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _options.Source, "Single", sourceColumn.GetTypeString());
            }

            var columnsByName = inputSchema.ToDictionary(x => x.Name);

            // Every column produced here is a vector of Single that is not a key.
            SchemaShape.Column SingleVector(string name) =>
                new SchemaShape.Column(name, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false);

            columnsByName[_options.Name] = SingleVector(_options.Name);

            // NOTE(review): only the upper bound name is tested; the lower bound name is presumably
            // always set together with it - confirm against the options validation.
            var wantsConfidenceBounds = !string.IsNullOrEmpty(_options.ConfidenceUpperBoundColumn);
            if (wantsConfidenceBounds)
            {
                columnsByName[_options.ConfidenceLowerBoundColumn] = SingleVector(_options.ConfidenceLowerBoundColumn);
                columnsByName[_options.ConfidenceUpperBoundColumn] = SingleVector(_options.ConfidenceUpperBoundColumn);
            }

            return new SchemaShape(columnsByName.Values);
        }
        /// <summary>
        /// Gets the output columns: the input columns plus whatever
        /// <c>GetOutputColumnsCore</c> yields, after validating the label column when one is configured.
        /// </summary>
        /// <param name="inputSchema">The input schema. </param>
        /// <returns>The output <see cref="SchemaShape"/></returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            if (LabelColumn.IsValid)
            {
                if (!inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol))
                {
                    // The schema that lacks the column is the inputSchema argument, so that is the
                    // parameter name reported, consistent with the compatibility check below.
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", LabelColumn.Name);
                }

                if (!LabelColumn.IsCompatibleWith(labelCol))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", LabelColumn.Name, LabelColumn.GetTypeString(), labelCol.GetTypeString());
                }
            }

            var outColumns = inputSchema.ToDictionary(x => x.Name);

            // Produced columns override input columns that share their name.
            foreach (var col in GetOutputColumnsCore(inputSchema))
            {
                outColumns[col.Name] = col;
            }

            return new SchemaShape(outColumns.Values);
        }
        /// <summary>
        /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
        /// Used for schema propagation and verification in a pipeline.
        /// Each output column keeps the kind and item type of its input column and is marked as not a key.
        /// </summary>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var columnsByName = inputSchema.ToDictionary(x => x.Name);

            foreach (var pair in _columns)
            {
                var sourceName = pair.InputColumnName;

                if (!inputSchema.TryFindColumn(sourceName, out var source))
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", sourceName);
                }
                if (!CountFeatureSelectionUtils.IsValidColumnType(source.ItemType))
                {
                    throw _host.ExceptUserArg(nameof(inputSchema), "Column '{0}' does not have compatible type. Expected types are float, double or string.", sourceName);
                }

                // Carry over slot names and categorical slot ranges when the source has them,
                // and always declare the IsNormalized annotation.
                var annotations = new List<SchemaShape.Column>();
                if (source.Annotations.TryFindColumn(AnnotationUtils.Kinds.SlotNames, out var slotNames))
                {
                    annotations.Add(slotNames);
                }
                if (source.Annotations.TryFindColumn(AnnotationUtils.Kinds.CategoricalSlotRanges, out var categoricalRanges))
                {
                    annotations.Add(categoricalRanges);
                }
                annotations.Add(new SchemaShape.Column(
                    AnnotationUtils.Kinds.IsNormalized,
                    SchemaShape.Column.VectorKind.Scalar,
                    BooleanDataViewType.Instance,
                    false));

                columnsByName[pair.Name] = new SchemaShape.Column(
                    pair.Name, source.Kind, source.ItemType, false, new SchemaShape(annotations.ToArray()));
            }

            return new SchemaShape(columnsByName.Values);
        }
Ejemplo n.º 9
0
        /// <summary>
        /// Verifies that every column of the input schema definition for <c>TSrc</c> is present in
        /// <paramref name="inputSchema"/> with a matching item type, key-ness and vector kind, then
        /// returns the input schema with the transformer's added columns merged in.
        /// </summary>
        /// <param name="inputSchema">Schema shape of the incoming data.</param>
        /// <returns>The input columns plus the added columns; added columns win on name clashes.</returns>
        public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            var addedColumns = DataViewConstructionUtils.GetSchemaColumns(Transformer.AddedSchema);
            var addedShape   = SchemaShape.Create(SchemaExtensions.MakeSchema(addedColumns));

            var columnsByName   = inputSchema.ToDictionary(x => x.Name);
            var inputDefinition = InternalSchemaDefinition.Create(typeof(TSrc), Transformer.InputSchemaDefinition);

            foreach (var expected in inputDefinition.Columns)
            {
                if (!columnsByName.TryGetValue(expected.ColumnName, out var actual))
                {
                    throw Contracts.ExceptSchemaMismatch(nameof(inputSchema), "input", expected.ColumnName);
                }

                SchemaShape.GetColumnTypeShape(expected.ColumnType, out var expectedKind, out var expectedItemType, out var expectedIsKey);

                // Special treatment for vectors: if we expect variable vector, we also allow fixed-size vector.
                bool kindMatches;
                switch (expectedKind)
                {
                    case SchemaShape.Column.VectorKind.Scalar:
                        kindMatches = actual.Kind == SchemaShape.Column.VectorKind.Scalar;
                        break;
                    case SchemaShape.Column.VectorKind.Vector:
                        kindMatches = actual.Kind == SchemaShape.Column.VectorKind.Vector;
                        break;
                    case SchemaShape.Column.VectorKind.VariableVector:
                        kindMatches = actual.Kind != SchemaShape.Column.VectorKind.Scalar;
                        break;
                    default:
                        kindMatches = true;
                        break;
                }

                if (expectedItemType != actual.ItemType || expectedIsKey != actual.IsKey || !kindMatches)
                {
                    throw Contracts.ExceptSchemaMismatch(nameof(inputSchema), "input", expected.ColumnName, expected.ColumnType.ToString(), actual.GetTypeString());
                }
            }

            foreach (var added in addedShape)
            {
                columnsByName[added.Name] = added;
            }

            return new SchemaShape(columnsByName.Values);
        }
        /// <summary>
        /// Gets the output <see cref="SchemaShape"/> of the <see cref="IDataView"/> after fitting the calibrator.
        /// The returned shape contains the input columns plus a scalar Single column named "Probability";
        /// an input column with that name is replaced by the new one.
        /// </summary>
        /// <param name="inputSchema">The input <see cref="SchemaShape"/>.</param>
        SchemaShape IEstimator <CalibratorTransformer <TICalibrator> > .GetOutputSchema(SchemaShape inputSchema)
        {
            // Columns that are not valid are skipped; valid ones must exist in the input and be compatible.
            void CheckColumnValid(SchemaShape.Column column, string columnRole)
            {
                if (!column.IsValid)
                {
                    return;
                }

                if (!inputSchema.TryFindColumn(column.Name, out var found))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), columnRole, column.Name);
                }
                if (!column.IsCompatibleWith(found))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), columnRole, column.Name, column.GetTypeString(), found.GetTypeString());
                }
            }

            // Check the input schema.
            CheckColumnValid(ScoreColumn, "score");
            CheckColumnValid(WeightColumn, "weight");
            CheckColumnValid(LabelColumn, "label");

            // Create the new Probability column.
            var columnsByName = inputSchema.ToDictionary(x => x.Name);
            var probability = new SchemaShape.Column(
                DefaultColumnNames.Probability,
                SchemaShape.Column.VectorKind.Scalar,
                NumberDataViewType.Single,
                false,
                new SchemaShape(AnnotationUtils.GetTrainerOutputAnnotation(true)));
            columnsByName[DefaultColumnNames.Probability] = probability;

            return new SchemaShape(columnsByName.Values);
        }
Ejemplo n.º 11
0
        /// <summary>
        /// Gets the output <see cref="SchemaShape"/> of the <see cref="IDataView"/> after fitting the calibrator.
        /// The returned shape contains the input columns plus a scalar R4 column named "Probability";
        /// an input column with that name is replaced by the new one.
        /// </summary>
        /// <param name="inputSchema">The input <see cref="SchemaShape"/>.</param>
        SchemaShape IEstimator <CalibratorTransformer <TICalibrator> > .GetOutputSchema(SchemaShape inputSchema)
        {
            // Columns that are not valid are skipped; valid ones must exist in the input and be compatible.
            void CheckColumnValid(SchemaShape.Column column, string expected)
            {
                if (!column.IsValid)
                {
                    return;
                }

                if (!inputSchema.TryFindColumn(column.Name, out var found))
                {
                    throw Host.Except($"{expected} column '{column.Name}' is not found");
                }
                if (!column.IsCompatibleWith(found))
                {
                    throw Host.Except($"{expected} column '{column.Name}' is not compatible");
                }
            }

            // Check the input schema.
            CheckColumnValid(ScoreColumn, DefaultColumnNames.Score);
            CheckColumnValid(WeightColumn, DefaultColumnNames.Weight);
            CheckColumnValid(LabelColumn, DefaultColumnNames.Label);
            CheckColumnValid(FeatureColumn, DefaultColumnNames.Features);
            CheckColumnValid(PredictedLabel, DefaultColumnNames.PredictedLabel);

            // Create the new Probability column.
            var columnsByName = inputSchema.ToDictionary(x => x.Name);
            var probability = new SchemaShape.Column(
                DefaultColumnNames.Probability,
                SchemaShape.Column.VectorKind.Scalar,
                NumberType.R4,
                false,
                new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true)));
            columnsByName[DefaultColumnNames.Probability] = probability;

            return new SchemaShape(columnsByName.Values);
        }
        /// <summary>
        /// Returns the input schema with the column named <c>_name</c> set to the column produced by
        /// <c>CheckInputsAndMakeColumn</c> for the configured source.
        /// </summary>
        /// <param name="inputSchema">Schema shape of the incoming data.</param>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));

            var columnsByName = inputSchema.ToDictionary(x => x.Name);
            var produced = CheckInputsAndMakeColumn(inputSchema, _name, _source);
            columnsByName[_name] = produced;

            return new SchemaShape(columnsByName.Values);
        }
        /// <summary>
        /// Gets the output <see cref="SchemaShape"/> of the <see cref="IDataView"/> after fitting the calibrator.
        /// The returned shape contains the input columns plus a scalar Single column named "Probability";
        /// an input column with that name is replaced by the new one.
        /// The same annotation data that would be produced by <see cref="AnnotationUtils.GetTrainerOutputAnnotation(bool)"/> is marked as
        /// being present on the output, if it is present on the input score column.
        /// </summary>
        /// <param name="inputSchema">The input <see cref="SchemaShape"/>.</param>
        SchemaShape IEstimator <CalibratorTransformer <TICalibrator> > .GetOutputSchema(SchemaShape inputSchema)
        {
            // Columns that are not valid are skipped; valid ones must exist in the input and be compatible.
            void CheckColumnValid(SchemaShape.Column column, string columnRole)
            {
                if (!column.IsValid)
                {
                    return;
                }

                if (!inputSchema.TryFindColumn(column.Name, out var found))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), columnRole, column.Name);
                }
                if (!column.IsCompatibleWith(found))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), columnRole, column.Name, column.GetTypeString(), found.GetTypeString());
                }
            }

            // Check the input schema.
            CheckColumnValid(ScoreColumn, "score");
            CheckColumnValid(WeightColumn, "weight");
            CheckColumnValid(LabelColumn, "label");

            bool scoreFound = inputSchema.TryFindColumn(ScoreColumn.Name, out var scoreColumn);
            Host.Assert(scoreFound);

            const SchemaShape.Column.VectorKind scalar = SchemaShape.Column.VectorKind.Scalar;

            // IsNormalized is always declared on the Probability column.
            var annotations = new List<SchemaShape.Column>
            {
                new SchemaShape.Column(AnnotationUtils.Kinds.IsNormalized, scalar, BooleanDataViewType.Instance, false)
            };

            // We only propagate this training column metadata if it looks like it's all there, and all correct.
            var scoreAnnotations = scoreColumn.Annotations;
            if (scoreAnnotations.TryFindColumn(AnnotationUtils.Kinds.ScoreColumnSetId, out var setIdCol)
                && setIdCol.Kind == scalar && setIdCol.IsKey && setIdCol.ItemType == NumberDataViewType.UInt32
                && scoreAnnotations.TryFindColumn(AnnotationUtils.Kinds.ScoreColumnKind, out var kindCol)
                && kindCol.Kind == scalar && kindCol.ItemType is TextDataViewType
                && scoreAnnotations.TryFindColumn(AnnotationUtils.Kinds.ScoreValueKind, out var valueKindCol)
                && valueKindCol.Kind == scalar && valueKindCol.ItemType is TextDataViewType)
            {
                annotations.Add(setIdCol);
                annotations.Add(kindCol);
                annotations.Add(valueKindCol);
            }

            // Create the new Probability column.
            var columnsByName = inputSchema.ToDictionary(x => x.Name);
            columnsByName[DefaultColumnNames.Probability] = new SchemaShape.Column(
                DefaultColumnNames.Probability,
                scalar,
                NumberDataViewType.Single,
                false,
                new SchemaShape(annotations));

            return new SchemaShape(columnsByName.Values);
        }
        /// <summary>
        /// Returns the input schema plus one scalar column per <c>ColumnsProduced</c> enum value.
        /// Each column is named with the configured prefix followed by the enum value's name, and its
        /// item type is derived from the value's raw column type.
        /// </summary>
        /// <param name="inputSchema">Schema shape of the incoming data.</param>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            var columnsByName = inputSchema.ToDictionary(x => x.Name);

            foreach (ColumnsProduced produced in Enum.GetValues(typeof(ColumnsProduced)))
            {
                var outputName = _options.Prefix + produced.ToString();
                var itemType = ColumnTypeExtensions.PrimitiveTypeFromType(produced.GetRawColumnType());

                columnsByName[outputName] = new SchemaShape.Column(
                    outputName, SchemaShape.Column.VectorKind.Scalar, itemType, false, null);
            }

            return new SchemaShape(columnsByName.Values);
        }
Ejemplo n.º 15
0
        /// <summary>
        /// Returns the input schema where, for every configured column, the output name is mapped to a
        /// column with the kind, item type, key-ness and annotations of its source column.
        /// The source is looked up in the working dictionary, so it may itself be an output written
        /// by an earlier entry.
        /// </summary>
        /// <param name="inputSchema">Schema shape of the incoming data.</param>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            var columnsByName = inputSchema.ToDictionary(x => x.Name);

            foreach (var mapping in _options.Columns)
            {
                var source = columnsByName[mapping.Source];
                var copy = new SchemaShape.Column(
                    mapping.Name, source.Kind, source.ItemType, source.IsKey, source.Annotations);

                columnsByName[mapping.Name] = copy;
            }

            return new SchemaShape(columnsByName.Values);
        }
        /// <summary>
        /// Runs <c>CheckInputSchema</c> on <paramref name="inputSchema"/> and returns the input columns
        /// merged with the columns yielded by <c>GetOutputColumnsCore</c>; produced columns override
        /// input columns of the same name.
        /// </summary>
        /// <param name="inputSchema">Schema shape of the incoming data.</param>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));
            CheckInputSchema(inputSchema);

            var columnsByName = inputSchema.ToDictionary(x => x.Name);
            foreach (var produced in GetOutputColumnsCore(inputSchema))
            {
                columnsByName[produced.Name] = produced;
            }

            return new SchemaShape(columnsByName.Values);
        }
Ejemplo n.º 17
0
        /// <summary>
        /// Returns the input schema extended with one column per configured expression. Each output's
        /// item type is the result type of the bound expression; its vector kind is scalar unless one of
        /// the inputs is a vector, in which case the kind (and slot names, if any) of that vector input
        /// are used.
        /// </summary>
        /// <param name="inputSchema">Schema shape of the incoming data.</param>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var columnsByName = inputSchema.ToDictionary(x => x.Name);

            for (int i = 0; i < _columns.Length; i++)
            {
                var inputNames = _columns[i].InputColumnNames;

                // Every input column referenced by the expression must be present.
                for (int j = 0; j < inputNames.Length; j++)
                {
                    if (!inputSchema.TryFindColumn(inputNames[j], out _))
                    {
                        throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputNames[j]);
                    }
                }

                // Make sure there is at most one vector valued source column.
                var inputTypes  = new DataViewType[inputNames.Length];
                var vectorIndex = FindVectorInputColumn(_host, inputNames, inputSchema, inputTypes);
                var boundNode   = ParseAndBindLambda(_host, _columns[i].Expression, vectorIndex, inputTypes, out _);

                var resultType = boundNode.ResultType;
                _host.Assert(resultType is PrimitiveDataViewType);

                // If one of the input columns is a vector column, we pass through the slot names metadata.
                var annotations = new List<SchemaShape.Column>();
                var outputKind  = SchemaShape.Column.VectorKind.Scalar;
                if (vectorIndex != -1)
                {
                    inputSchema.TryFindColumn(inputNames[vectorIndex], out var vectorColumn);
                    outputKind = vectorColumn.Kind;
                    if (vectorColumn.HasSlotNames())
                    {
                        var hasSlotNames = vectorColumn.Annotations.TryFindColumn(AnnotationUtils.Kinds.SlotNames, out var slotNames);
                        _host.Assert(hasSlotNames);
                        annotations.Add(slotNames);
                    }
                }

                columnsByName[_columns[i].Name] = new SchemaShape.Column(
                    _columns[i].Name, outputKind, resultType, false, new SchemaShape(annotations));
            }

            return new SchemaShape(columnsByName.Values);
        }
        /// <summary>
        /// Verifies that the label column exists and that every configured input column is a scalar or
        /// known-size vector of UInt32 keys, then returns the input schema with one vector-of-Single
        /// output column per configured input.
        /// </summary>
        /// <param name="inputSchema">Schema shape of the incoming data.</param>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var columnsByName = inputSchema.ToDictionary(x => x.Name);

            // Only the presence of the label column is verified here.
            if (!inputSchema.TryFindColumn(_labelColumnName, out _))
            {
                throw _host.ExceptSchemaMismatch(nameof(inputSchema), "label", _labelColumnName);
            }

            foreach (var info in _columns)
            {
                var sourceName = info.InputColumnName;

                if (!inputSchema.TryFindColumn(sourceName, out var source))
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", sourceName);
                }
                if (source.Kind == SchemaShape.Column.VectorKind.VariableVector)
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", sourceName, "known-size vector or scalar", source.GetTypeString());
                }

                var isUInt32Key = source.IsKey && source.ItemType.Equals(NumberDataViewType.UInt32);
                if (!isUInt32Key)
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", sourceName, "vector or scalar of U4 key type", source.GetTypeString());
                }

                // We supply slot names if the source is a single-value column, or if it has slot names.
                var annotations = new List<SchemaShape.Column>();
                if (source.Kind == SchemaShape.Column.VectorKind.Scalar)
                {
                    annotations.Add(new SchemaShape.Column(
                        AnnotationUtils.Kinds.SlotNames,
                        SchemaShape.Column.VectorKind.Vector,
                        TextDataViewType.Instance,
                        false));
                }
                else if (source.Annotations.TryFindColumn(AnnotationUtils.Kinds.SlotNames, out var slotNames))
                {
                    annotations.Add(slotNames);
                }

                columnsByName[info.Name] = new SchemaShape.Column(
                    info.Name,
                    SchemaShape.Column.VectorKind.Vector,
                    NumberDataViewType.Single,
                    false,
                    new SchemaShape(annotations));
            }

            return new SchemaShape(columnsByName.Values);
        }
Ejemplo n.º 19
0
        /// <summary>
        /// Returns the schema that would be produced by the transformation.
        /// Each output column copies the kind, item type and key-ness of its input column and carries
        /// no annotations.
        /// </summary>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var columnsByName = inputSchema.ToDictionary(x => x.Name);

            foreach (var info in _infos)
            {
                if (!inputSchema.TryFindColumn(info.Input, out var source))
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", info.Input);
                }

                // TestColumn yields null for an acceptable item type, otherwise the text to report.
                var rejection = VectorWhiteningTransformer.TestColumn(source.ItemType);
                if (rejection != null)
                {
                    throw _host.ExceptUserArg(nameof(inputSchema), rejection);
                }

                columnsByName[info.Output] = new SchemaShape.Column(
                    info.Output, source.Kind, source.ItemType, source.IsKey, null);
            }

            return new SchemaShape(columnsByName.Values);
        }
Ejemplo n.º 20
0
        /// <summary>
        /// Returns the input schema with three scalar columns set: Score (R4), Probability (R4) and
        /// PredictedLabel (Bool). Input columns sharing one of those names are replaced.
        /// </summary>
        /// <param name="inputSchema">Schema shape of the incoming data.</param>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            var columnsByName = inputSchema.ToDictionary(x => x.Name);
            const SchemaShape.Column.VectorKind scalar = SchemaShape.Column.VectorKind.Scalar;

            var score = new SchemaShape.Column(
                DefaultColumnNames.Score, scalar, NumberType.R4, false,
                new SchemaShape(MetadataUtils.GetTrainerOutputMetadata()));
            var probability = new SchemaShape.Column(
                DefaultColumnNames.Probability, scalar, NumberType.R4, false,
                new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true)));
            var predictedLabel = new SchemaShape.Column(
                DefaultColumnNames.PredictedLabel, scalar, BoolType.Instance, false,
                new SchemaShape(MetadataUtils.GetTrainerOutputMetadata()));

            columnsByName[score.Name] = score;
            columnsByName[probability.Name] = probability;
            columnsByName[predictedLabel.Name] = predictedLabel;

            return new SchemaShape(columnsByName.Values);
        }
Ejemplo n.º 21
0
        /// <summary>
        /// Schema propagation for transformers. Returns the output schema of the data, if
        /// the input schema is like the one provided.
        /// </summary>
        /// <param name="inputSchema">
        /// Shape of the incoming data. It must contain a compatible label column and compatible
        /// matrix column-index and matrix row-index columns; otherwise a schema-mismatch
        /// exception created by the host is thrown.
        /// </param>
        /// <returns>
        /// All input columns, plus the columns returned by <c>GetOutputColumnsCore</c>. An input
        /// column sharing a name with a produced column is replaced.
        /// </returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));

            // Throws a schema-mismatch error unless inputSchema has a column named like
            // cachedColumn that is compatible with it.
            void CheckColumnsCompatible(SchemaShape.Column cachedColumn, string columnRole)
            {
                if (!inputSchema.TryFindColumn(cachedColumn.Name, out var col))
                {
                    // Report the method argument that is at fault (inputSchema), consistently with
                    // the incompatibility branch below; 'col' is only a local out variable.
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), columnRole, cachedColumn.Name);
                }

                if (!cachedColumn.IsCompatibleWith(col))
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), columnRole, cachedColumn.Name,
                                                     cachedColumn.GetTypeString(), col.GetTypeString());
                }
            }

            // Check if label column is good: a scalar, non-key Single.
            var labelColumn = new SchemaShape.Column(LabelName, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.Single, false);

            CheckColumnsCompatible(labelColumn, "label");

            // Check if columns of matrix's row and column indexes are good (scalar UInt32 keys).
            // Note that column of IDataView and column of matrix are two different things.
            var matrixColumnIndexColumn = new SchemaShape.Column(MatrixColumnIndexName, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true);
            var matrixRowIndexColumn = new SchemaShape.Column(MatrixRowIndexName, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true);

            CheckColumnsCompatible(matrixColumnIndexColumn, "matrixColumnIndex");
            CheckColumnsCompatible(matrixRowIndexColumn, "matrixRowIndex");

            // Input columns just pass through so that output column dictionary contains all input columns.
            var outColumns = inputSchema.ToDictionary(x => x.Name);

            // Add columns produced by this estimator.
            foreach (var col in GetOutputColumnsCore(inputSchema))
            {
                outColumns[col.Name] = col;
            }

            return new SchemaShape(outColumns.Values);
        }
Ejemplo n.º 22
0
        /// <summary>
        /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
        /// Used for schema propagation and verification in a pipeline.
        /// </summary>
        /// <param name="inputSchema">Shape of the incoming data; validated by the host before use.</param>
        /// <returns>
        /// All input columns, plus (or overwritten by) one UInt32 key column per configured mapping.
        /// Each output column keeps the vector kind of its source and carries key-value metadata,
        /// together with the source's slot names when present.
        /// </returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var columnsByName = inputSchema.ToDictionary(c => c.Name);

            foreach (var mapping in _columns)
            {
                // The source column must exist...
                if (!inputSchema.TryFindColumn(mapping.InputColumnName, out var source))
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", mapping.InputColumnName);
                }

                // ...and its item type must be a standard scalar.
                if (!source.ItemType.IsStandardScalar())
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", mapping.InputColumnName);
                }

                // When the source is itself a key with vector-kind key-value metadata, that metadata is
                // reused as-is. Otherwise fall back to key values over the item type (or text, when the
                // mapping asks for text key values).
                if (!source.IsKey || !source.Annotations.TryFindColumn(AnnotationUtils.Kinds.KeyValues, out var keyValues) || keyValues.Kind != SchemaShape.Column.VectorKind.Vector)
                {
                    var keyValueItemType = mapping.TextKeyValues ? TextDataViewType.Instance : source.ItemType;
                    keyValues = new SchemaShape.Column(AnnotationUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector,
                                                       keyValueItemType, source.IsKey);
                }
                Contracts.Assert(keyValues.IsValid);

                // Slot names, if the source has them, are carried over ahead of the key values.
                SchemaShape annotations = source.Annotations.TryFindColumn(AnnotationUtils.Kinds.SlotNames, out var slotNames)
                    ? new SchemaShape(new[] { slotNames, keyValues })
                    : new SchemaShape(new[] { keyValues });

                columnsByName[mapping.OutputColumnName] = new SchemaShape.Column(
                    mapping.OutputColumnName, source.Kind, NumberDataViewType.UInt32, true, annotations);
            }

            return new SchemaShape(columnsByName.Values);
        }
Ejemplo n.º 23
0
        /// <summary>
        /// Schema propagation for transformers. Validates the label, feature and (when valid) weight
        /// columns against <paramref name="inputSchema"/> and returns the resulting output schema.
        /// </summary>
        /// <param name="inputSchema">Shape of the incoming data; validated by the host before use.</param>
        /// <returns>
        /// All input columns, plus the columns returned by <c>GetOutputColumnsCore</c>. An input
        /// column sharing a name with a produced column is replaced.
        /// </returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));

            // Throws a schema-mismatch error unless inputSchema has a column named like
            // 'expected' that is compatible with it.
            void EnsureCompatible(SchemaShape.Column expected, string role)
            {
                var present = inputSchema.TryFindColumn(expected.Name, out var actual);
                if (!present)
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), role, expected.Name);
                }

                if (expected.IsCompatibleWith(actual))
                {
                    return;
                }

                throw _host.ExceptSchemaMismatch(nameof(inputSchema), role, expected.Name,
                                                 expected.GetTypeString(), actual.GetTypeString());
            }

            EnsureCompatible(LabelColumn, "label");

            foreach (var featureColumn in FeatureColumns)
            {
                EnsureCompatible(featureColumn, "feature");
            }

            // The weight column is only checked when one is configured.
            if (WeightColumn.IsValid)
            {
                EnsureCompatible(WeightColumn, "weight");
            }

            var columnsByName = inputSchema.ToDictionary(c => c.Name);

            foreach (var produced in GetOutputColumnsCore(inputSchema))
            {
                columnsByName[produced.Name] = produced;
            }

            return new SchemaShape(columnsByName.Values);
        }
Ejemplo n.º 24
0
        /// <summary>
        /// Schema propagation for transformers. Validates the label, feature and (when valid) weight
        /// columns against <paramref name="inputSchema"/> and returns the resulting output schema.
        /// </summary>
        /// <param name="inputSchema">Shape of the incoming data; validated by the host before use.</param>
        /// <returns>
        /// All input columns, plus the columns returned by <c>GetOutputColumnsCore</c>. An input
        /// column sharing a name with a produced column is replaced.
        /// </returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            // Throws unless inputSchema has a column named like 'column' that is compatible with it.
            // 'defaultName' is the role of the column (Label / Features / Weight) used in messages.
            void CheckColumnsCompatible(SchemaShape.Column column, string defaultName)
            {
                if (!inputSchema.TryFindColumn(column.Name, out var col))
                {
                    // Blame the actual method argument and report the name that was really looked
                    // up, which may differ from the default role name for custom-named columns.
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), defaultName, column.Name);
                }

                if (!column.IsCompatibleWith(col))
                {
                    throw Host.Except($"{defaultName} column '{column.Name}' is not compatible");
                }
            }

            CheckColumnsCompatible(LabelColumn, DefaultColumnNames.Label);

            foreach (var feat in FeatureColumns)
            {
                CheckColumnsCompatible(feat, DefaultColumnNames.Features);
            }

            // The weight column is only checked when one is configured.
            if (WeightColumn.IsValid)
            {
                CheckColumnsCompatible(WeightColumn, DefaultColumnNames.Weight);
            }

            // Input columns pass through; produced columns are added or replace same-named inputs.
            var outColumns = inputSchema.ToDictionary(x => x.Name);

            foreach (var col in GetOutputColumnsCore(inputSchema))
            {
                outColumns[col.Name] = col;
            }

            return new SchemaShape(outColumns.Values);
        }
Ejemplo n.º 25
0
        /// <summary>
        /// Returns the input schema extended with the transformer's output column.
        /// </summary>
        /// <param name="inputSchema">
        /// Shape of the incoming data. It must contain the transformer's input column with item
        /// type R4; otherwise a schema-mismatch exception created by the host is thrown.
        /// </param>
        /// <returns>
        /// All input columns, plus (or overwritten by) a fixed-size R8 vector column that carries
        /// slot-name metadata.
        /// </returns>
        public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            var sourceName = Transformer.InputColumnName;
            if (!inputSchema.TryFindColumn(sourceName, out var source))
            {
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", sourceName);
            }
            if (source.ItemType != NumberType.R4)
            {
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", sourceName, NumberType.R4.ToString(), source.GetTypeString());
            }

            // The output column advertises text slot names as its only metadata.
            var slotNames = new SchemaShape.Column(
                MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false);
            var outputMetadata = new List<SchemaShape.Column> { slotNames };

            var columnsByName = inputSchema.ToDictionary(c => c.Name);

            var targetName = Transformer.OutputColumnName;
            columnsByName[targetName] = new SchemaShape.Column(
                targetName, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(outputMetadata));

            return new SchemaShape(columnsByName.Values);
        }
Ejemplo n.º 26
0
        /// <summary>
        /// Schema propagation for transformers.
        /// Returns the output schema of the data, if the input schema is like the one provided.
        /// </summary>
        /// <param name="inputSchema">
        /// Shape of the incoming data. It must contain the configured source column with item type
        /// Single; otherwise a schema-mismatch exception created by the host is thrown.
        /// </param>
        /// <returns>
        /// All input columns, plus (or overwritten by) a fixed-size Double vector column that
        /// carries slot-name metadata.
        /// </returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));

            var sourceFound = inputSchema.TryFindColumn(_options.Source, out var source);
            if (!sourceFound)
            {
                throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _options.Source);
            }
            if (source.ItemType != NumberDataViewType.Single)
            {
                throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _options.Source, "Single", source.GetTypeString());
            }

            // The output column advertises text slot names as its only annotation.
            var annotations = new List<SchemaShape.Column>();
            annotations.Add(new SchemaShape.Column(
                AnnotationUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false));

            var columnsByName = inputSchema.ToDictionary(c => c.Name);

            var produced = new SchemaShape.Column(
                _options.Name, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Double, false, new SchemaShape(annotations));
            columnsByName[_options.Name] = produced;

            return new SchemaShape(columnsByName.Values);
        }