/// <summary>
        /// Validates <paramref name="inputSchema"/> against the columns this trainer is configured with:
        /// the feature column always, and the weight and label columns only when they are valid.
        /// </summary>
        /// <param name="inputSchema">The schema shape to validate.</param>
        private void CheckInputSchema(SchemaShape inputSchema)
        {
            // Feature column: must be present and compatible with the expected feature column.
            var featureFound = inputSchema.TryFindColumn(FeatureColumn.Name, out var actualFeature);
            if (!featureFound)
            {
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "feature", FeatureColumn.Name);
            }
            if (!FeatureColumn.IsCompatibleWith(actualFeature))
            {
                throw Host.ExceptSchemaMismatch(
                    nameof(inputSchema),
                    "feature",
                    FeatureColumn.Name,
                    FeatureColumn.GetTypeString(),
                    actualFeature.GetTypeString());
            }

            // Weight column: only checked when one has been configured.
            if (WeightColumn.IsValid)
            {
                var weightFound = inputSchema.TryFindColumn(WeightColumn.Name, out var actualWeight);
                if (!weightFound)
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "weight", WeightColumn.Name);
                }
                if (!WeightColumn.IsCompatibleWith(actualWeight))
                {
                    throw Host.ExceptSchemaMismatch(
                        nameof(inputSchema),
                        "weight",
                        WeightColumn.Name,
                        WeightColumn.GetTypeString(),
                        actualWeight.GetTypeString());
                }
            }

            // Label column: nothing left to verify when no label is configured.
            if (!LabelColumn.IsValid)
            {
                return;
            }

            var labelFound = inputSchema.TryFindColumn(LabelColumn.Name, out var actualLabel);
            if (!labelFound)
            {
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", LabelColumn.Name);
            }

            // Labels may legitimately have different types, so the type check is delegated to
            // CheckLabelCompatible rather than compared against LabelColumn here.
            CheckLabelCompatible(actualLabel);
        }
        /// <summary>
        /// Verifies that <paramref name="labelCol"/> is a scalar column that is either a key or has
        /// item type <see cref="NumberDataViewType.Single"/>; throws a schema-mismatch exception otherwise.
        /// </summary>
        /// <param name="labelCol">The label column to validate.</param>
        private protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
        {
            Contracts.Assert(labelCol.IsValid);

            // Short-circuit order matters: the kind is tested first, then the key flag, then the item type.
            if (labelCol.Kind != SchemaShape.Column.VectorKind.Scalar
                || (!labelCol.IsKey && labelCol.ItemType != NumberDataViewType.Single))
            {
                throw Host.ExceptSchemaMismatch(nameof(labelCol), "label", labelCol.Name, "Single or Key", labelCol.GetTypeString());
            }
        }
        /// <summary>
        /// Gets the output columns: every column of <paramref name="inputSchema"/> plus the columns
        /// returned by <c>GetOutputColumnsCore</c>. A returned column replaces an input column that
        /// has the same name.
        /// </summary>
        /// <param name="inputSchema">The input schema.</param>
        /// <returns>The output <see cref="SchemaShape"/>.</returns>
        /// <remarks>
        /// When the label column is valid it must be present in <paramref name="inputSchema"/> and be
        /// compatible with the expected label column; otherwise the exception created by
        /// <c>Host.ExceptSchemaMismatch</c> is thrown.
        /// </remarks>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            if (LabelColumn.IsValid)
            {
                if (!inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol))
                {
                    // Name the method parameter (not the out-local) as the offending argument, so this
                    // agrees with the type-mismatch error reported below.
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", LabelColumn.Name);
                }

                if (!LabelColumn.IsCompatibleWith(labelCol))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", LabelColumn.Name, LabelColumn.GetTypeString(), labelCol.GetTypeString());
                }
            }

            // Start from the input columns keyed by name, then let produced columns overwrite them.
            var outColumns = inputSchema.ToDictionary(x => x.Name);

            foreach (var col in GetOutputColumnsCore(inputSchema))
            {
                outColumns[col.Name] = col;
            }

            return new SchemaShape(outColumns.Values);
        }
        /// <summary>
        /// Verifies that <paramref name="labelCol"/> is a scalar column that is either a key or has
        /// item type <c>NumberType.R4</c>; throws a schema-mismatch exception otherwise.
        /// </summary>
        /// <param name="labelCol">The label column to validate.</param>
        protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
        {
            Contracts.Assert(labelCol.IsValid);

            // Builds the mismatch exception on demand, so nothing is created on the success path.
            Exception Mismatch() =>
                Host.ExceptSchemaMismatch(nameof(labelCol), RoleMappedSchema.ColumnRole.Label.Value, labelCol.Name, "R4 or a Key", labelCol.GetTypeString());

            if (labelCol.Kind != SchemaShape.Column.VectorKind.Scalar)
            {
                throw Mismatch();
            }

            // Any scalar key label is accepted without looking at its item type.
            if (labelCol.IsKey)
            {
                return;
            }

            if (labelCol.ItemType != NumberType.R4)
            {
                throw Mismatch();
            }
        }
Example #5
0
        /// <summary>
        /// Verifies that <paramref name="labelCol"/> is a scalar column that is either a key or has
        /// item type <c>NumberType.R4</c>, <c>NumberType.R8</c> or <c>BoolType</c>; throws a
        /// schema-mismatch exception otherwise.
        /// </summary>
        /// <param name="labelCol">The label column to validate.</param>
        protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
        {
            Contracts.Assert(labelCol.IsValid);

            // Builds the mismatch exception on demand, so nothing is created on the success path.
            Exception Mismatch() =>
                Host.ExceptSchemaMismatch(nameof(labelCol), "label", labelCol.Name, "float, double, bool or KeyType", labelCol.GetTypeString());

            if (labelCol.Kind != SchemaShape.Column.VectorKind.Scalar)
            {
                throw Mismatch();
            }

            // Any scalar key label is accepted without looking at its item type.
            if (labelCol.IsKey)
            {
                return;
            }

            var itemType = labelCol.ItemType;
            if (itemType != NumberType.R4 && itemType != NumberType.R8 && !(itemType is BoolType))
            {
                throw Mismatch();
            }
        }