/// <summary>
/// Describes the two output columns produced by this component: a vector <c>Score</c> column
/// of <see cref="NumberDataViewType.Single"/> and a scalar <c>PredictedLabel</c> column of
/// <see cref="NumberDataViewType.UInt32"/>.
/// </summary>
/// <param name="inputSchema">Shape of the input data; searched for the label column when one is configured.</param>
/// <returns>The Score column shape followed by the PredictedLabel column shape.</returns>
private SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
{
    // PredictedLabel always carries the standard trainer-output annotations.
    var trainerAnnotations = AnnotationUtils.GetTrainerOutputAnnotation();

    // Stays null when no label column is configured; the score-annotation helper receives it as-is.
    SchemaShape.Column? resolvedLabel = null;

    if (LabelColumn.IsValid)
    {
        bool labelFound = inputSchema.TryFindColumn(LabelColumn.Name, out var labelShape);
        // NOTE(review): this only asserts presence; confirm the caller has already validated
        // that the label column exists in inputSchema.
        Contracts.Assert(labelFound);
        resolvedLabel = labelShape;

        // Propagate only the label's KeyValues annotation to the predicted label, appended
        // after the trainer-output annotations.
        var labelKeyValues = labelShape.Annotations
            .Where(annotation => annotation.Name == AnnotationUtils.Kinds.KeyValues);
        trainerAnnotations = trainerAnnotations.Concat(labelKeyValues);
    }

    var scoreColumn = new SchemaShape.Column(
        DefaultColumnNames.Score,
        SchemaShape.Column.VectorKind.Vector,
        NumberDataViewType.Single,
        false,
        new SchemaShape(AnnotationUtils.AnnotationsForMulticlassScoreColumn(resolvedLabel)));

    var predictedLabelColumn = new SchemaShape.Column(
        DefaultColumnNames.PredictedLabel,
        SchemaShape.Column.VectorKind.Scalar,
        NumberDataViewType.UInt32,
        true,
        new SchemaShape(trainerAnnotations));

    return new[] { scoreColumn, predictedLabelColumn };
}
/// <summary>
/// Describes the two output columns produced by this trainer: a vector <c>Score</c> column
/// of <see cref="NumberDataViewType.Single"/> and a scalar <c>PredictedLabel</c> column of
/// <see cref="NumberDataViewType.UInt32"/>, both annotated from the input label column.
/// </summary>
/// <param name="inputSchema">Shape of the input data; must contain the column named by <c>LabelColumn</c>.</param>
/// <returns>The Score column shape followed by the PredictedLabel column shape.</returns>
private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
{
    bool labelFound = inputSchema.TryFindColumn(LabelColumn.Name, out var labelShape);
    // NOTE(review): presence is only asserted here; confirm the caller validates that the
    // label column exists before this method runs.
    Contracts.Assert(labelFound);

    // PredictedLabel gets the label's KeyValues annotation first, then the standard
    // trainer-output annotations.
    var labelKeyValues = labelShape.Annotations
        .Where(annotation => annotation.Name == AnnotationUtils.Kinds.KeyValues);
    var predictedLabelAnnotations = new SchemaShape(
        labelKeyValues.Concat(AnnotationUtils.GetTrainerOutputAnnotation()));

    // Score annotations are derived from the label column's shape.
    var scoreAnnotations = new SchemaShape(AnnotationUtils.AnnotationsForMulticlassScoreColumn(labelShape));

    var outputs = new[]
    {
        new SchemaShape.Column(
            DefaultColumnNames.Score,
            SchemaShape.Column.VectorKind.Vector,
            NumberDataViewType.Single,
            false,
            scoreAnnotations),
        new SchemaShape.Column(
            DefaultColumnNames.PredictedLabel,
            SchemaShape.Column.VectorKind.Scalar,
            NumberDataViewType.UInt32,
            true,
            predictedLabelAnnotations),
    };
    return outputs;
}