Exemplo n.º 1
0
        /// <summary>
        /// Validates that <paramref name="inputSchema"/> contains the columns this trainer needs:
        /// the feature column always, and the weight and label columns when they are configured.
        /// </summary>
        /// <param name="inputSchema">The schema shape of the incoming data.</param>
        private void CheckInputSchema(SchemaShape inputSchema)
        {
            // The feature column is mandatory: it must exist and match the expected column shape.
            var featureName = FeatureColumn.Name;
            if (!inputSchema.TryFindColumn(featureName, out var actualFeature))
            {
                throw Host.Except($"Feature column '{featureName}' is not found");
            }
            if (!FeatureColumn.IsCompatibleWith(actualFeature))
            {
                throw Host.Except($"Feature column '{featureName}' is not compatible");
            }

            // The weight column is optional; validate it only when one was configured.
            if (WeightColumn != null)
            {
                var weightName = WeightColumn.Name;
                if (!inputSchema.TryFindColumn(weightName, out var actualWeight))
                {
                    throw Host.Except($"Weight column '{weightName}' is not found");
                }
                if (!WeightColumn.IsCompatibleWith(actualWeight))
                {
                    throw Host.Except($"Weight column '{weightName}' is not compatible");
                }
            }

            // No label configured: nothing more to check.
            if (LabelColumn == null)
            {
                return;
            }

            // Labels may come in different types, so only presence is checked here and the
            // compatibility rule is delegated to CheckLabelCompatible, letting trainers define
            // their own label requirements.
            var labelName = LabelColumn.Name;
            if (!inputSchema.TryFindColumn(labelName, out var actualLabel))
            {
                throw Host.Except($"Label column '{labelName}' is not found");
            }
            CheckLabelCompatible(actualLabel);
        }
        /// <summary>
        /// Gets the output schema shape produced for the given input schema.
        /// </summary>
        /// <remarks>
        /// When a valid label column is configured, it must be present in <paramref name="inputSchema"/>
        /// and compatible with the expected label column. All input columns are carried through, and the
        /// columns returned by <c>GetOutputColumnsCore</c> are added, replacing any input column that has
        /// the same name.
        /// </remarks>
        /// <param name="inputSchema">The input schema.</param>
        /// <returns>The output <see cref="SchemaShape"/>.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            if (LabelColumn.IsValid)
            {
                if (!inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol))
                {
                    // Report the argument that is at fault (inputSchema), not the local out-variable.
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", LabelColumn.Name);
                }

                if (!LabelColumn.IsCompatibleWith(labelCol))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", LabelColumn.Name,
                        LabelColumn.GetTypeString(), labelCol.GetTypeString());
                }
            }

            // Start from the input columns; output columns with the same name override them.
            var outColumns = inputSchema.ToDictionary(x => x.Name);

            foreach (var col in GetOutputColumnsCore(inputSchema))
            {
                outColumns[col.Name] = col;
            }

            return new SchemaShape(outColumns.Values);
        }
Exemplo n.º 3
0
        /// <summary>
        /// Gets the output schema shape produced for the given input schema.
        /// </summary>
        /// <remarks>
        /// When a label column is configured, it must be present in <paramref name="inputSchema"/> and
        /// compatible with the expected label column. All input columns are carried through, and the
        /// columns returned by <c>GetOutputColumnsCore</c> are added, replacing any input column that has
        /// the same name.
        /// </remarks>
        /// <param name="inputSchema">The input schema.</param>
        /// <returns>The output <see cref="SchemaShape"/>.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            if (LabelColumn != null)
            {
                if (!inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol))
                {
                    // Name the configured label column in the error. The previous message reported
                    // "PredictedLabel" as both the role and the missing column, which pointed users
                    // at the wrong column.
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", LabelColumn.Name);
                }

                if (!LabelColumn.IsCompatibleWith(labelCol))
                {
                    throw Host.Except($"Label column '{LabelColumn.Name}' is not compatible");
                }
            }

            // Start from the input columns; output columns with the same name override them.
            var outColumns = inputSchema.Columns.ToDictionary(x => x.Name);

            foreach (var col in GetOutputColumnsCore(inputSchema))
            {
                outColumns[col.Name] = col;
            }

            return new SchemaShape(outColumns.Values);
        }
Exemplo n.º 4
0
        /// <summary>
        /// Verifies that the label column found in the input schema is compatible with the expected
        /// <c>LabelColumn</c>. Declared virtual so trainers can impose their own label requirements.
        /// </summary>
        /// <param name="labelCol">The label column located in the input schema.</param>
        protected virtual void CheckLabelCompatible(SchemaShape.Column labelCol)
        {
            Contracts.CheckValue(labelCol, nameof(labelCol));
            Contracts.AssertValue(LabelColumn);

            bool isCompatible = LabelColumn.IsCompatibleWith(labelCol);
            if (isCompatible)
            {
                return;
            }

            throw Host.Except($"Label column '{LabelColumn.Name}' is not compatible");
        }
        /// <summary>
        /// Verifies that the label column found in the input schema is compatible with the expected
        /// <c>LabelColumn</c>. Declared virtual so trainers can impose their own label requirements.
        /// </summary>
        /// <param name="labelCol">The label column located in the input schema; must be valid.</param>
        protected virtual void CheckLabelCompatible(SchemaShape.Column labelCol)
        {
            // Argument check on the caller-supplied column, then an internal invariant on our own label.
            Contracts.CheckParam(labelCol.IsValid, nameof(labelCol), "not initialized properly");
            Host.Assert(LabelColumn.IsValid);

            bool isCompatible = LabelColumn.IsCompatibleWith(labelCol);
            if (isCompatible)
            {
                return;
            }

            throw Host.Except($"Label column '{LabelColumn.Name}' is not compatible");
        }
        /// <summary>
        /// Verifies that the label column found in the input schema is compatible with the expected
        /// <c>LabelColumn</c>. Declared virtual so trainers can impose their own label requirements.
        /// </summary>
        /// <param name="labelCol">The label column located in the input schema; must be valid.</param>
        private protected virtual void CheckLabelCompatible(SchemaShape.Column labelCol)
        {
            Contracts.CheckParam(labelCol.IsValid, nameof(labelCol), "not initialized properly");
            Host.Assert(LabelColumn.IsValid);

            if (!LabelColumn.IsCompatibleWith(labelCol))
            {
                // Report the label column's own name. The previous code used WeightColumn.Name, which
                // named the wrong column, or an empty one when no weight column is configured.
                throw Host.ExceptSchemaMismatch(nameof(labelCol), "label", LabelColumn.Name,
                                                LabelColumn.GetTypeString(), labelCol.GetTypeString());
            }
        }
        /// <summary>
        /// Validates that <paramref name="inputSchema"/> contains the columns this trainer needs:
        /// the feature column always, and the weight and label columns when they are valid.
        /// Mismatches are reported through <c>Host.ExceptSchemaMismatch</c>.
        /// </summary>
        /// <param name="inputSchema">The schema shape of the incoming data.</param>
        private void CheckInputSchema(SchemaShape inputSchema)
        {
            // The feature column is mandatory: it must exist and match the expected column shape.
            var featureName = FeatureColumn.Name;
            if (!inputSchema.TryFindColumn(featureName, out var actualFeature))
            {
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "feature", featureName);
            }
            if (!FeatureColumn.IsCompatibleWith(actualFeature))
            {
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "feature", featureName,
                    FeatureColumn.GetTypeString(), actualFeature.GetTypeString());
            }

            // The weight column is optional; validate it only when it is valid.
            if (WeightColumn.IsValid)
            {
                var weightName = WeightColumn.Name;
                if (!inputSchema.TryFindColumn(weightName, out var actualWeight))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "weight", weightName);
                }
                if (!WeightColumn.IsCompatibleWith(actualWeight))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "weight", weightName,
                        WeightColumn.GetTypeString(), actualWeight.GetTypeString());
                }
            }

            // No valid label configured: nothing more to check.
            if (!LabelColumn.IsValid)
            {
                return;
            }

            // Labels may come in different types, so only presence is checked here and the
            // compatibility rule is delegated to CheckLabelCompatible, letting trainers define
            // their own label requirements.
            var labelName = LabelColumn.Name;
            if (!inputSchema.TryFindColumn(labelName, out var actualLabel))
            {
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", labelName);
            }
            CheckLabelCompatible(actualLabel);
        }