コード例 #1
0
        /// <summary>
        /// Exercises a ResNet18-based <c>DnnImageFeaturizerEstimator</c> on a single valid row and checks that it
        /// rejects inputs with a wrong column name, a wrong item type, or a wrong vector size.
        /// </summary>
        void TestDnnImageFeaturizer()
        {
            var samplevector = GetSampleArrayData();

            var dataView = DataViewConstructionUtils.CreateFromList(Env,
                new TestData[] {
                    new TestData()
                    {
                        data_0 = samplevector
                    },
                });

            // Becomes invalidDataWrongNames: the vector lives in column "A" instead of "data_0".
            var xyData = new List<TestDataXY> {
                new TestDataXY()
                {
                    A = new float[inputSize]
                }
            };
            // Becomes invalidDataWrongTypes: "data_0" holds strings.
            var stringData = new List<TestDataDifferntType> {
                new TestDataDifferntType()
                {
                    data_0 = new string[inputSize]
                }
            };
            // Becomes invalidDataWrongVectorSize: "data_0" holds a 2-element array.
            var sizeData = new List<TestDataSize> {
                new TestDataSize()
                {
                    data_0 = new float[2]
                }
            };
            var pipe = new DnnImageFeaturizerEstimator(Env, "output_1", m => m.ModelSelector.ResNet18(m.Environment, m.OutputColumn, m.InputColumn), "data_0");

            var invalidDataWrongNames = ML.Data.ReadFromEnumerable(xyData);
            var invalidDataWrongTypes = ML.Data.ReadFromEnumerable(stringData);
            var invalidDataWrongVectorSize = ML.Data.ReadFromEnumerable(sizeData);

            TestEstimatorCore(pipe, dataView, invalidInput: invalidDataWrongNames);
            TestEstimatorCore(pipe, dataView, invalidInput: invalidDataWrongTypes);

            // The test expects schema propagation to accept the wrong-size input (this call is not guarded);
            // only Fit is expected to reject it.
            pipe.GetOutputSchema(SchemaShape.Create(invalidDataWrongVectorSize.Schema));

            Exception fitException = null;
            try
            {
                pipe.Fit(invalidDataWrongVectorSize);
            }
            catch (Exception ex)
            {
                fitException = ex;
            }

            // Either exception type is accepted. Anything else, including no exception at all, fails with a
            // message that says what actually happened.
            Assert.True(fitException is ArgumentOutOfRangeException || fitException is InvalidOperationException,
                $"Fit on wrong-size input should throw ArgumentOutOfRangeException or InvalidOperationException, but got: {fitException?.ToString() ?? "no exception"}");
        }
コード例 #2
0
        /// <summary>
        /// Declares the columns produced by this trainer:
        /// a scalar R4 <c>Score</c> and a scalar Boolean <c>PredictedLabel</c>.
        /// Each column carries its own instance of the standard trainer-output metadata.
        /// </summary>
        protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
        {
            // A fresh metadata shape per column.
            SchemaShape TrainerMetadata() => new SchemaShape(MetadataUtils.GetTrainerOutputMetadata());

            var score = new SchemaShape.Column(DefaultColumnNames.Score,
                SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, TrainerMetadata());

            var predictedLabel = new SchemaShape.Column(DefaultColumnNames.PredictedLabel,
                SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false, TrainerMetadata());

            return new[] { score, predictedLabel };
        }
コード例 #3
0
        /// <summary>
        /// Computes the output schema for the configured expression columns.
        /// </summary>
        /// <remarks>
        /// For each column:
        /// - every input column must exist;
        /// - the expression is bound against the input types;
        /// - the output column takes the expression's result type.
        /// When one of the inputs is a vector, the output takes that input's vector kind and passes its
        /// slot-name metadata through.
        /// </remarks>
        /// <param name="inputSchema">The shape of the incoming data.</param>
        /// <returns>The input columns plus (or overwritten by) one column per configured expression.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var columnDictionary = inputSchema.ToDictionary(x => x.Name);

            for (int i = 0; i < _columns.Length; i++)
            {
                // Every referenced input column must be present; only existence is checked here.
                for (int j = 0; j < _columns[i].InputColumnNames.Length; j++)
                {
                    if (!inputSchema.TryFindColumn(_columns[i].InputColumnNames[j], out _))
                    {
                        throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _columns[i].InputColumnNames[j]);
                    }
                }

                // Make sure there is at most one vector valued source column.
                // ivec is the index of that vector input, or -1 when every input is scalar.
                var inputTypes = new DataViewType[_columns[i].InputColumnNames.Length];
                var ivec = FindVectorInputColumn(_host, _columns[i].InputColumnNames, inputSchema, inputTypes);
                var node = ParseAndBindLambda(_host, _columns[i].Expression, ivec, inputTypes, out _);

                var typeRes = node.ResultType;
                _host.Assert(typeRes is PrimitiveDataViewType);

                // If one of the input columns is a vector column, we pass through the slot names metadata.
                SchemaShape.Column.VectorKind outputVectorKind;
                var metadata = new List<SchemaShape.Column>();
                if (ivec == -1)
                {
                    outputVectorKind = SchemaShape.Column.VectorKind.Scalar;
                }
                else
                {
                    // The existence loop above already guarantees this lookup succeeds.
                    bool found = inputSchema.TryFindColumn(_columns[i].InputColumnNames[ivec], out var vectorCol);
                    _host.Assert(found);
                    outputVectorKind = vectorCol.Kind;
                    if (vectorCol.HasSlotNames())
                    {
                        var b = vectorCol.Annotations.TryFindColumn(AnnotationUtils.Kinds.SlotNames, out var slotNames);
                        _host.Assert(b);
                        metadata.Add(slotNames);
                    }
                }
                var outputSchemaShapeColumn = new SchemaShape.Column(_columns[i].Name, outputVectorKind, typeRes, false, new SchemaShape(metadata));
                columnDictionary[_columns[i].Name] = outputSchemaShapeColumn;
            }
            return new SchemaShape(columnDictionary.Values);
        }
コード例 #4
0
        /// <summary>
        /// Trains a predictor and wraps it in a transformer.
        /// The training set and the optional validation set are each schema-checked and then role-mapped
        /// before training starts.
        /// </summary>
        /// <param name="trainSet">Data to train on. Its schema is also passed to <c>MakeTransformer</c>.</param>
        /// <param name="validationSet">Optional validation data; <c>null</c> means no validation set.</param>
        /// <param name="initPredictor">Optional predictor forwarded into the training context.</param>
        private protected TTransformer TrainTransformer(IDataView trainSet,
            IDataView validationSet = null, IPredictor initPredictor = null)
        {
            // Each data set goes through the same steps: schema check first, then role mapping.
            RoleMappedData Prepare(IDataView data)
            {
                CheckInputSchema(SchemaShape.Create(data.Schema));
                return MakeRoles(data);
            }

            var trainData = Prepare(trainSet);
            var validData = validationSet == null ? null : Prepare(validationSet);

            var predictor = TrainModelCore(new TrainContext(trainData, validData, null, initPredictor));
            return MakeTransformer(predictor, trainSet.Schema);
        }
コード例 #5
0
        /// <summary>
        /// Validates that the source column exists and has item type R4.
        /// Then adds (or replaces) the output column as a known-size R8 vector carrying text slot-name metadata.
        /// </summary>
        /// <param name="inputSchema">The shape of the incoming data.</param>
        /// <returns>All input columns plus the output column.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));

            var sourceName = _args.Source;
            if (!inputSchema.TryFindColumn(sourceName, out var sourceCol))
                throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", sourceName);
            if (sourceCol.ItemType != NumberType.R4)
                throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", sourceName, NumberType.R4.ToString(), sourceCol.GetTypeString());

            var slotNamesAnnotation = new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false);
            var outputName = _args.Name;

            var columns = inputSchema.Columns.ToDictionary(c => c.Name);
            columns[outputName] = new SchemaShape.Column(outputName, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false,
                new SchemaShape(new List<SchemaShape.Column> { slotNamesAnnotation }));

            return new SchemaShape(columns.Values);
        }
コード例 #6
0
        /// <summary>
        /// Returns the schema produced by this estimator.
        /// It contains every input column, plus (or overwritten by) the columns from <c>GetOutputColumnsCore</c>.
        /// </summary>
        /// <remarks>
        /// When a label column is configured, it must be present in <paramref name="inputSchema"/> and be
        /// compatible with the expected label column; otherwise a schema-mismatch exception is thrown.
        /// </remarks>
        /// <param name="inputSchema">The input schema.</param>
        /// <returns>The output <see cref="SchemaShape"/>.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            if (LabelColumn.IsValid)
            {
                // Attribute the error to the public parameter, not to the local out-variable.
                if (!inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol))
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", LabelColumn.Name);

                if (!LabelColumn.IsCompatibleWith(labelCol))
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", LabelColumn.Name, LabelColumn.GetTypeString(), labelCol.GetTypeString());
            }

            var outColumns = inputSchema.ToDictionary(x => x.Name);
            foreach (var col in GetOutputColumnsCore(inputSchema))
                outColumns[col.Name] = col;

            return new SchemaShape(outColumns.Values);
        }
コード例 #7
0
ファイル: ImagesTests.cs プロジェクト: artemiusgreat/ML-NET
        /// <summary>
        /// Fits an image load → resize → pixel-extract → grayscale pipeline, saves the model, and loads it back.
        /// Checks that the reloaded chain's first transformer (the image loader) has the same column pairs
        /// as the original.
        /// </summary>
        public void TestEstimatorSaveLoad()
        {
            IHostEnvironment env = new MLContext(1);
            var dataFile = GetDataPath("images/images.tsv");
            var imageFolder = Path.GetDirectoryName(dataFile);
            var data = TextLoader.Create(env, new TextLoader.Options()
            {
                Columns = new[]
                {
                    new TextLoader.Column("ImagePath", DataKind.String, 0),
                    new TextLoader.Column("Name", DataKind.String, 1),
                }
            }, new MultiFileSource(dataFile));

            var pipe = new ImageLoadingEstimator(env, imageFolder, ("ImageReal", "ImagePath"))
                .Append(new ImageResizingEstimator(env, "ImageReal", 100, 100, "ImageReal"))
                .Append(new ImagePixelExtractingEstimator(env, "ImagePixels", "ImageReal"))
                .Append(new ImageGrayscalingEstimator(env, ("ImageGray", "ImageReal")));

            // Schema propagation must succeed before fitting.
            pipe.GetOutputSchema(SchemaShape.Create(data.Schema));
            var model = pipe.Fit(data);

            var tempPath = Path.GetTempFileName();

            // NOTE(review): the pipeline is built with the local `env`, but save/load go through the test
            // base's `ML` context. Confirm that mixing the two contexts is intended.
            using (var file = new SimpleFileHandle(env, tempPath, true, true))
            {
                using (var fs = file.CreateWriteStream())
                    ML.Model.Save(model, null, fs);
                ITransformer model2;
                using (var fs = file.OpenReadStream())
                    model2 = ML.Model.Load(fs, out _);

                var transformerChain = model2 as TransformerChain<ITransformer>;
                Assert.NotNull(transformerChain);

                var newCols = ((ImageLoadingTransformer)transformerChain.First()).Columns;
                var oldCols = ((ImageLoadingTransformer)model.First()).Columns;

                // Assert.Equal compares element-wise and also requires equal lengths. Zip stops at the
                // shorter sequence, so a Zip-based check would miss extra or missing columns.
                Assert.Equal(oldCols, newCols);
            }
            Done();
        }
コード例 #8
0
        /// <summary>
        /// Validates the label column and each configured input column.
        /// Each input must be a scalar or known-size vector of U4 keys. Each configured column is mapped to a
        /// known-size Single vector output.
        /// </summary>
        /// <param name="inputSchema">The shape of the incoming data.</param>
        /// <returns>The input columns plus (or overwritten by) one output column per configured column.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var outputColumns = inputSchema.ToDictionary(x => x.Name);

            // The label column only has to exist; its type is not inspected here.
            if (!inputSchema.TryFindColumn(_labelColumnName, out _))
            {
                throw _host.ExceptSchemaMismatch(nameof(inputSchema), "label", _labelColumnName);
            }

            foreach (var info in _columns)
            {
                var inputName = info.InputColumnName;
                if (!inputSchema.TryFindColumn(inputName, out var inputCol))
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputName);
                }
                if (inputCol.Kind == SchemaShape.Column.VectorKind.VariableVector)
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputName, "known-size vector or scalar", inputCol.GetTypeString());
                }

                bool isU4Key = inputCol.IsKey && inputCol.ItemType.Equals(NumberDataViewType.UInt32);
                if (!isU4Key)
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputName, "vector or scalar of U4 key type", inputCol.GetTypeString());
                }

                // Slot names are declared for scalar sources, and passed through for vector sources that
                // already have them.
                var annotations = new List<SchemaShape.Column>();
                if (inputCol.Kind == SchemaShape.Column.VectorKind.Scalar)
                {
                    annotations.Add(new SchemaShape.Column(AnnotationUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false));
                }
                else if (inputCol.Annotations.TryFindColumn(AnnotationUtils.Kinds.SlotNames, out var existingSlotNames))
                {
                    annotations.Add(existingSlotNames);
                }

                outputColumns[info.Name] = new SchemaShape.Column(info.Name, SchemaShape.Column.VectorKind.Vector,
                    NumberDataViewType.Single, false, new SchemaShape(annotations));
            }

            return new SchemaShape(outputColumns.Values);
        }
コード例 #9
0
        /// <summary>
        /// Returns the input schema extended with three scalar columns:
        /// <c>Score</c> (R4), <c>Probability</c> (R4) and <c>PredictedLabel</c> (Boolean).
        /// Input columns with the same names are replaced.
        /// </summary>
        /// <param name="inputSchema">The shape of the incoming data.</param>
        /// <returns>The resulting schema shape.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            var columnsByName = inputSchema.ToDictionary(x => x.Name);
            void Put(SchemaShape.Column column) => columnsByName[column.Name] = column;

            Put(new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false,
                new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())));

            // Only the probability column requests the `true` variant of the trainer-output metadata.
            Put(new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false,
                new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true))));

            Put(new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false,
                new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())));

            return new SchemaShape(columnsByName.Values);
        }
コード例 #10
0
        /// <summary>
        /// Returns the schema that would be produced by the transformation.
        /// Every input column passes through, and each configured output column copies the kind, item type and
        /// key-ness of its input column, with no metadata.
        /// </summary>
        /// <param name="inputSchema">The shape of the incoming data.</param>
        /// <returns>The resulting schema shape.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var outputColumns = inputSchema.ToDictionary(x => x.Name);

            foreach (var info in _infos)
            {
                var inputName = info.Input;
                if (!inputSchema.TryFindColumn(inputName, out var source))
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputName);

                // A non-null result from TestColumn is treated as the reason the item type is rejected.
                var rejection = VectorWhiteningTransformer.TestColumn(source.ItemType);
                if (rejection != null)
                    throw _host.ExceptUserArg(nameof(inputSchema), rejection);

                outputColumns[info.Output] = new SchemaShape.Column(info.Output, source.Kind, source.ItemType, source.IsKey, null);
            }

            return new SchemaShape(outputColumns.Values);
        }
コード例 #11
0
        /// <summary>
        /// Validates the label (if any), feature and weight (if any) columns against <paramref name="inputSchema"/>.
        /// Returns the input columns plus (or overwritten by) the columns from <c>GetOutputColumnsCore</c>.
        /// </summary>
        /// <param name="inputSchema">The shape of the incoming data.</param>
        /// <returns>The resulting schema shape.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            // Ensures `column` exists in the input and is compatible with it. `defaultName` is the role label
            // used in error messages; the column actually looked up is `column.Name`.
            void CheckColumnsCompatible(SchemaShape.Column column, string defaultName)
            {
                if (!inputSchema.TryFindColumn(column.Name, out var col))
                {
                    // Report the column name that was actually searched for (it may differ from the role's
                    // default name), and attribute the error to the public inputSchema parameter.
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), defaultName, column.Name);
                }

                if (!column.IsCompatibleWith(col))
                {
                    throw Host.Except($"{defaultName} column '{column.Name}' is not compatible");
                }
            }

            if (LabelColumn != null)
            {
                CheckColumnsCompatible(LabelColumn, DefaultColumnNames.Label);
            }

            foreach (var feat in FeatureColumns)
            {
                CheckColumnsCompatible(feat, DefaultColumnNames.Features);
            }

            if (WeightColumn != null)
            {
                CheckColumnsCompatible(WeightColumn, DefaultColumnNames.Weight);
            }

            var outColumns = inputSchema.Columns.ToDictionary(x => x.Name);

            foreach (var col in GetOutputColumnsCore(inputSchema))
            {
                outColumns[col.Name] = col;
            }

            return new SchemaShape(outColumns.Values);
        }
コード例 #12
0
        /// <summary>
        /// Validates every model input column against <paramref name="inputSchema"/>.
        /// Each one must exist, be a vector, match a model input node name, and have the item type that node
        /// expects. The model's output columns are then added to the input columns (or overwrite them).
        /// </summary>
        /// <param name="inputSchema">The shape of the incoming data.</param>
        /// <returns>The resulting schema shape.</returns>
        public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));
            var resultDic = inputSchema.Columns.ToDictionary(x => x.Name);

            // Loop-invariant. Indexed by a column's position in Model.InputNames (see the lookup below).
            var inputsInfo = Transformer.Model.ModelInfo.InputsInfo;

            for (var i = 0; i < Transformer.Inputs.Length; i++)
            {
                var input = Transformer.Inputs[i];
                if (!inputSchema.TryFindColumn(input, out var col))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
                }

                // Model inputs must be vectors, of either known or variable size.
                if (!(col.Kind == SchemaShape.Column.VectorKind.VariableVector || col.Kind == SchemaShape.Column.VectorKind.Vector))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, nameof(VectorType), col.GetTypeString());
                }

                var idx = Transformer.Model.InputNames.IndexOf(input);
                if (idx < 0)
                {
                    throw Host.Except($"Column {input} doesn't match input node names of model.");
                }

                var inputNodeInfo = inputsInfo[idx];
                var expectedType = OnnxUtils.OnnxToMlNetType(inputNodeInfo.Type);
                if (col.ItemType != expectedType)
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
                }
            }

            for (var i = 0; i < Transformer.Outputs.Length; i++)
            {
                var outputKind = Transformer.OutputTypes[i].IsKnownSizeVector
                    ? SchemaShape.Column.VectorKind.Vector
                    : SchemaShape.Column.VectorKind.VariableVector;

                // NOTE(review): the item type is always declared as R4, regardless of OutputTypes[i]'s item type.
                // Confirm that the model only produces float outputs.
                resultDic[Transformer.Outputs[i]] = new SchemaShape.Column(Transformer.Outputs[i], outputKind, NumberType.R4, false);
            }
            return new SchemaShape(resultDic.Values);
        }
コード例 #13
0
        /// <summary>
        /// Schema propagation for transformers. Returns the output schema of the data, if
        /// the input schema is like the one provided.
        /// </summary>
        /// <remarks>
        /// Requires:
        /// - a scalar Single label column;
        /// - two scalar UInt32 key columns holding the matrix column and row indexes.
        /// </remarks>
        /// <param name="inputSchema">The shape of the incoming data.</param>
        /// <returns>The input columns plus (or overwritten by) the columns from <c>GetOutputColumnsCore</c>.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));

            // Ensures `cachedColumn` exists in the input and is compatible with it; `columnRole` labels the error.
            void CheckColumnsCompatible(SchemaShape.Column cachedColumn, string columnRole)
            {
                if (!inputSchema.TryFindColumn(cachedColumn.Name, out var col))
                {
                    // Attribute the error to the public inputSchema parameter (not a local), as the
                    // incompatibility branch below does.
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), columnRole, cachedColumn.Name);
                }

                if (!cachedColumn.IsCompatibleWith(col))
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), columnRole, cachedColumn.Name,
                                                     cachedColumn.GetTypeString(), col.GetTypeString());
                }
            }

            // Check if label column is good.
            var labelColumn = new SchemaShape.Column(LabelName, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.Single, false);

            CheckColumnsCompatible(labelColumn, "label");

            // Check if columns of matrix's row and column indexes are good. Note that column of IDataView and column of matrix are two different things.
            var matrixColumnIndexColumn = new SchemaShape.Column(MatrixColumnIndexName, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true);
            var matrixRowIndexColumn = new SchemaShape.Column(MatrixRowIndexName, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true);

            CheckColumnsCompatible(matrixColumnIndexColumn, "matrixColumnIndex");
            CheckColumnsCompatible(matrixRowIndexColumn, "matrixRowIndex");

            // Input columns just pass through so that output column dictionary contains all input columns.
            var outColumns = inputSchema.ToDictionary(x => x.Name);

            // Add columns produced by this estimator.
            foreach (var col in GetOutputColumnsCore(inputSchema))
            {
                outColumns[col.Name] = col;
            }

            return new SchemaShape(outColumns.Values);
        }
コード例 #14
0
        /// <summary>
        /// Requires each configured input to be a text vector of known or variable size.
        /// Maps each one to a known-size R4 vector output column.
        /// </summary>
        /// <param name="inputSchema">The shape of the incoming data.</param>
        /// <returns>The input columns plus (or overwritten by) one output column per configured column.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var outputColumns = inputSchema.Columns.ToDictionary(x => x.Name);

            foreach (var info in _columns)
            {
                if (!inputSchema.TryFindColumn(info.Input, out var source))
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", info.Input);

                bool isText = source.ItemType is TextType;
                bool isVector = source.Kind == SchemaShape.Column.VectorKind.VariableVector
                    || source.Kind == SchemaShape.Column.VectorKind.Vector;
                if (!(isText && isVector))
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", info.Input, new VectorType(TextType.Instance).ToString(), source.GetTypeString());

                outputColumns[info.Output] = new SchemaShape.Column(info.Output, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
            }

            return new SchemaShape(outputColumns.Values);
        }
コード例 #15
0
        /// <summary>
        /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
        /// Used for schema propagation and verification in a pipeline.
        /// </summary>
        /// <remarks>
        /// Each output column:
        /// - keeps its input's vector kind;
        /// - is a UInt32 key;
        /// - carries key-value metadata and, when the input has them, its slot names.
        /// </remarks>
        /// <param name="inputSchema">The shape of the incoming data.</param>
        /// <returns>The input columns plus (or overwritten by) one key column per configured column.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var result = inputSchema.ToDictionary(x => x.Name);

            foreach (var colInfo in _columns)
            {
                if (!inputSchema.TryFindColumn(colInfo.InputColumnName, out var col))
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName);
                }

                if (!col.ItemType.IsStandardScalar())
                {
                    // Include the expected and actual types. Otherwise this error is indistinguishable from the
                    // missing-column error above.
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName, "standard scalar type", col.GetTypeString());
                }
                SchemaShape metadata;

                // In the event that we are transforming something that is of type key, we will get their type of key value
                // metadata, unless it has none or is not vector in which case we back off to having key values over the item type.
                if (!col.IsKey || !col.Annotations.TryFindColumn(AnnotationUtils.Kinds.KeyValues, out var kv) || kv.Kind != SchemaShape.Column.VectorKind.Vector)
                {
                    kv = new SchemaShape.Column(AnnotationUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector,
                                                colInfo.TextKeyValues ? TextDataViewType.Instance : col.ItemType, col.IsKey);
                }
                _host.Assert(kv.IsValid);

                if (col.Annotations.TryFindColumn(AnnotationUtils.Kinds.SlotNames, out var slotMeta))
                {
                    metadata = new SchemaShape(new[] { slotMeta, kv });
                }
                else
                {
                    metadata = new SchemaShape(new[] { kv });
                }
                result[colInfo.OutputColumnName] = new SchemaShape.Column(colInfo.OutputColumnName, col.Kind, NumberDataViewType.UInt32, true, metadata);
            }

            return new SchemaShape(result.Values);
        }
コード例 #16
0
        /// <summary>
        /// Schema propagation for the value-to-key mapping.
        /// </summary>
        /// <remarks>
        /// Each output column:
        /// - keeps its input's vector kind;
        /// - is a U4 key;
        /// - carries key-value metadata and, when the input has them, its slot names.
        /// </remarks>
        /// <param name="inputSchema">The shape of the incoming data.</param>
        /// <returns>The input columns plus (or overwritten by) one key column per configured column.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var result = inputSchema.Columns.ToDictionary(x => x.Name);

            foreach (var colInfo in _columns)
            {
                if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input);
                }

                if ((col.ItemType.ItemType.RawKind == default) || !(col.ItemType.IsVector || col.ItemType.IsPrimitive))
                {
                    // Include the expected and actual types. Otherwise this error is indistinguishable from the
                    // missing-column error above.
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input, "vector or primitive type", col.GetTypeString());
                }
                SchemaShape metadata;

                // In the event that we are transforming something that is of type key, we will get their type of key value
                // metadata, unless it has none or is not vector in which case we back off to having key values over the item type.
                if (!col.IsKey || !col.Metadata.TryFindColumn(MetadataUtils.Kinds.KeyValues, out var kv) || kv.Kind != SchemaShape.Column.VectorKind.Vector)
                {
                    kv = new SchemaShape.Column(MetadataUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector,
                                                col.ItemType, col.IsKey);
                }
                Contracts.AssertValue(kv);

                if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotMeta))
                {
                    metadata = new SchemaShape(new[] { slotMeta, kv });
                }
                else
                {
                    metadata = new SchemaShape(new[] { kv });
                }
                result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, col.Kind, NumberType.U4, true, metadata);
            }

            return new SchemaShape(result.Values);
        }
コード例 #17
0
        /// <summary>
        /// Fits the estimators in order. Each estimator is fitted on the data produced by the transformers
        /// fitted before it. Returns the fitted transformers as a chain.
        /// </summary>
        /// <param name="input">Training data for the first estimator.</param>
        /// <returns>A chain containing one fitted transformer per estimator.</returns>
        public TransformerChain<TLastTransformer> Fit(IDataView input)
        {
            // Run schema propagation first. Its result is discarded, so this is only a validation pass.
            GetOutputSchema(SchemaShape.Create(input.Schema));

            var fitted = new ITransformer[_estimators.Length];
            IDataView data = input;
            int lastIndex = _estimators.Length - 1;

            for (int idx = 0; idx < _estimators.Length; idx++)
            {
                var transformer = _estimators[idx].Fit(data);
                fitted[idx] = transformer;
                data = transformer.Transform(data);

                // Cache only between estimators. Data after the last one is not read again in this method.
                bool cacheHere = _needCacheAfter[idx] && idx < lastIndex;
                if (cacheHere)
                {
                    Contracts.AssertValue(_host);
                    data = new CacheDataView(_host, data, null);
                }
            }

            return new TransformerChain<TLastTransformer>(fitted, _scopes);
        }
コード例 #18
0
        /// <summary>
        /// Declares the <c>Score</c> (known-size Single vector) and <c>PredictedLabel</c> (UInt32 key scalar)
        /// output columns.
        /// When a label column is configured:
        /// - PredictedLabel also inherits the label's key-value annotations;
        /// - Score's annotations are derived from the label column.
        /// </summary>
        private SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
        {
            if (!LabelColumn.IsValid)
            {
                return new[]
                {
                    new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false, new SchemaShape(MetadataForScoreColumn())),
                    new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true, new SchemaShape(MetadataForScoreColumn()))
                };
            }

            // The label column is only asserted here; its presence is presumably validated by the caller.
            bool labelFound = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
            Contracts.Assert(labelFound);

            var keyValueAnnotations = labelCol.Annotations.Where(x => x.Name == AnnotationUtils.Kinds.KeyValues);
            var predictedLabelMetadata = new SchemaShape(keyValueAnnotations.Concat(MetadataForScoreColumn()));

            var score = new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false,
                new SchemaShape(AnnotationUtils.AnnotationsForMulticlassScoreColumn(labelCol)));
            var predictedLabel = new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true,
                predictedLabelMetadata);

            return new[] { score, predictedLabel };
        }
コード例 #19
0
        /// <summary>
        /// Verifies that the required input columns are present and type-compatible.
        /// </summary>
        /// <remarks>
        /// - The feature column is always checked.
        /// - The weight column is checked only when configured.
        /// - The label column is checked for presence only when configured; its compatibility check is
        ///   delegated to <c>CheckLabelCompatible</c>.
        /// </remarks>
        private void CheckInputSchema(SchemaShape inputSchema)
        {
            var featureCol = inputSchema.FindColumn(FeatureColumn.Name)
                ?? throw Host.Except($"Feature column '{FeatureColumn.Name}' is not found");
            if (!FeatureColumn.IsCompatibleWith(featureCol))
                throw Host.Except($"Feature column '{FeatureColumn.Name}' is not compatible");

            if (WeightColumn != null)
            {
                var weightCol = inputSchema.FindColumn(WeightColumn.Name)
                    ?? throw Host.Except($"Weight column '{WeightColumn.Name}' is not found");
                if (!WeightColumn.IsCompatibleWith(weightCol))
                    throw Host.Except($"Weight column '{WeightColumn.Name}' is not compatible");
            }

            // Labels get special treatment: trainers accept different label types, so each one defines its
            // own compatibility rule in CheckLabelCompatible.
            if (LabelColumn == null)
                return;

            var labelCol = inputSchema.FindColumn(LabelColumn.Name)
                ?? throw Host.Except($"Label column '{LabelColumn.Name}' is not found");
            CheckLabelCompatible(labelCol);
        }
コード例 #20
0
        /// <summary>
        /// Declares the two output columns.
        /// - Score: a known-size R4 vector with text slot names plus the trainer-output metadata.
        /// - PredictedLabel: a U4 key scalar that inherits the label column's key values (if any) plus the
        ///   trainer-output metadata.
        /// </summary>
        protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
        {
            // The label column is only asserted here; its presence is presumably validated by the caller.
            bool labelFound = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
            Contracts.Assert(labelFound);

            var scoreMetadata = new List<SchemaShape.Column>
            {
                new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false)
            };
            scoreMetadata.AddRange(MetadataUtils.GetTrainerOutputMetadata());

            var labelKeyValues = labelCol.Metadata.Columns.Where(c => c.Name == MetadataUtils.Kinds.KeyValues);
            var predictedLabelMetadata = new SchemaShape(labelKeyValues.Concat(MetadataUtils.GetTrainerOutputMetadata()));

            var scoreColumn = new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, new SchemaShape(scoreMetadata));
            var predictedLabelColumn = new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true, predictedLabelMetadata);

            return new[] { scoreColumn, predictedLabelColumn };
        }
コード例 #21
0
        /// <summary>
        /// Checks that the input satisfies the trainer's column requirements and returns the resulting schema.
        /// </summary>
        /// <remarks>
        /// Requirements:
        /// - the label column must be present and compatible;
        /// - every feature column must be present and compatible;
        /// - the weight column, when valid, must be present and compatible.
        /// The returned schema contains the input columns plus (or overwritten by) the columns from
        /// <c>GetOutputColumnsCore</c>.
        /// </remarks>
        /// <param name="inputSchema">The shape of the incoming data.</param>
        /// <returns>The resulting schema shape.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));

            // Throws a schema mismatch when `expected` is missing from the input or has an incompatible type.
            void Require(SchemaShape.Column expected, string role)
            {
                if (!inputSchema.TryFindColumn(expected.Name, out var actual))
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), role, expected.Name);
                if (!expected.IsCompatibleWith(actual))
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), role, expected.Name,
                                                     expected.GetTypeString(), actual.GetTypeString());
            }

            Require(LabelColumn, "label");
            foreach (var featureColumn in FeatureColumns)
                Require(featureColumn, "feature");
            if (WeightColumn.IsValid)
                Require(WeightColumn, "weight");

            var output = inputSchema.ToDictionary(x => x.Name);
            foreach (var produced in GetOutputColumnsCore(inputSchema))
                output[produced.Name] = produced;

            return new SchemaShape(output.Values);
        }
コード例 #22
0
        /// <summary>
        /// Builds a concrete schema, reduces it to a SchemaShape, rebuilds a schema from that shape with
        /// FakeSchemaFactory, and checks which column details survive the round trip.
        /// </summary>
        void SimpleTest()
        {
            // Annotation "M" of type Single whose getter yields 484; it is attached to column C below.
            var metadataBuilder = new DataViewSchema.Annotations.Builder();

            metadataBuilder.Add("M", NumberDataViewType.Single, (ref float v) => v = 484f);
            var schemaBuilder = new DataViewSchema.Builder();

            // A: float vector of length 94; B: key over uint with count 17; C: Int32 carrying annotation M.
            schemaBuilder.AddColumn("A", new VectorDataViewType(NumberDataViewType.Single, 94));
            schemaBuilder.AddColumn("B", new KeyDataViewType(typeof(uint), 17));
            schemaBuilder.AddColumn("C", NumberDataViewType.Int32, metadataBuilder.ToAnnotations());

            var shape = SchemaShape.Create(schemaBuilder.ToSchema());

            var fakeSchema = FakeSchemaFactory.Create(shape);

            var columnA = fakeSchema[0];
            var columnB = fakeSchema[1];
            var columnC = fakeSchema[2];

            // The expected sizes are 10, not 94 / 17. The shape appears to record only the vector kind, not the
            // length or key count, so FakeSchemaFactory presumably substitutes a fixed placeholder of 10.
            // TODO confirm against FakeSchemaFactory.
            Assert.Equal("A", columnA.Name);
            Assert.Equal(NumberDataViewType.Single, columnA.Type.GetItemType());
            Assert.Equal(10, columnA.Type.GetValueCount());

            Assert.Equal("B", columnB.Name);
            Assert.Equal(InternalDataKind.U4, columnB.Type.GetRawKind());
            Assert.Equal(10u, columnB.Type.GetKeyCount());

            Assert.Equal("C", columnC.Name);
            Assert.Equal(NumberDataViewType.Int32, columnC.Type);

            var metaC = columnC.Annotations;

            // The annotation survives by name/type: exactly one annotation column is present.
            Assert.Single(metaC.Schema);

            // Start from a non-default sentinel so the final assertion proves GetValue overwrote it.
            float mValue = -1;

            // Expected to read default(float), not 484: presumably fake annotations carry no real values (confirm).
            metaC.GetValue("M", ref mValue);
            Assert.Equal(default, mValue);
            // NOTE(review): this excerpt ends without the method's closing brace; it appears truncated.
コード例 #23
0
        /// <summary>
        /// Builds a concrete schema, reduces it to a SchemaShape, rebuilds a schema from that shape with
        /// FakeSchemaFactory, and checks which column details survive the round trip (older Metadata API).
        /// </summary>
        void SimpleTest()
        {
            // Metadata "M" of type R4 whose getter yields 484; it is attached to column C below.
            var metadataBuilder = new MetadataBuilder();

            metadataBuilder.Add("M", NumberType.R4, (ref float v) => v = 484f);
            var schemaBuilder = new SchemaBuilder();

            // A: R4 vector of length 94; B: key over uint with count 17; C: I4 carrying metadata M.
            schemaBuilder.AddColumn("A", new VectorType(NumberType.R4, 94));
            schemaBuilder.AddColumn("B", new KeyType(typeof(uint), 17));
            schemaBuilder.AddColumn("C", NumberType.I4, metadataBuilder.GetMetadata());

            var shape = SchemaShape.Create(schemaBuilder.GetSchema());

            var fakeSchema = FakeSchemaFactory.Create(shape);

            var columnA = fakeSchema[0];
            var columnB = fakeSchema[1];
            var columnC = fakeSchema[2];

            // The expected sizes are 10, not 94 / 17. The shape appears to record only the vector kind, not the
            // length or key count, so FakeSchemaFactory presumably substitutes a fixed placeholder of 10.
            // TODO confirm against FakeSchemaFactory.
            Assert.Equal("A", columnA.Name);
            Assert.Equal(NumberType.R4, columnA.Type.GetItemType());
            Assert.Equal(10, columnA.Type.GetValueCount());

            Assert.Equal("B", columnB.Name);
            Assert.Equal(DataKind.U4, columnB.Type.GetRawKind());
            Assert.Equal(10u, columnB.Type.GetKeyCount());

            Assert.Equal("C", columnC.Name);
            Assert.Equal(NumberType.I4, columnC.Type);

            var metaC = columnC.Metadata;

            // The metadata survives by name/type: exactly one metadata column is present.
            Assert.Single(metaC.Schema);

            // Start from a non-default sentinel so the final assertion proves GetValue overwrote it.
            float mValue = -1;

            // Expected to read default(float), not 484: presumably fake metadata carries no real values (confirm).
            metaC.GetValue("M", ref mValue);
            Assert.Equal(default, mValue);
            // NOTE(review): this excerpt ends without the method's closing brace; it appears truncated.
コード例 #24
0
        /// <summary>
        /// Checks that a CustomMapping estimator propagates the schema for valid input and rejects
        /// input whose columns do not match the mapping's input type.
        /// </summary>
        public void TestSchemaPropagation()
        {
            string dataPath = GetDataPath("adult.test");
            var source = new MultiFileSource(dataPath);
            var loader = ML.Data.CreateTextLoader(new[] {
                new TextLoader.Column("Float1", DataKind.R4, 0),
                new TextLoader.Column("Float4", DataKind.R4, new[] { new TextLoader.Range(0), new TextLoader.Range(2), new TextLoader.Range(4), new TextLoader.Range(10) }),
                new TextLoader.Column("Text1", DataKind.Text, 0)
            }, hasHeader: true, separatorChar: ',');

            var data = loader.Read(source);

            Action<MyInput, MyOutput> mapping = (input, output) => output.Together = input.Float1.ToString();
            var est = ML.Transforms.CustomMapping(mapping, null);

            // Make sure schema propagation works for valid data.
            est.GetOutputSchema(SchemaShape.Create(data.Schema));

            // Replace Float1 with a copy of Text1, so its type no longer matches what the mapping expects
            // (assumes CopyColumns(source, destination) argument order in this API version).
            var badData1 = ML.Transforms.CopyColumns("Text1", "Float1").Fit(data).Transform(data);

            // Use Assert.ThrowsAny instead of `try { ...; Assert.True(false); } catch (Exception) { }`.
            // xUnit's assertion failure is itself an Exception, so that pattern swallowed its own failure
            // and the test could never fail.
            Assert.ThrowsAny<Exception>(() => est.GetOutputSchema(SchemaShape.Create(badData1.Schema)));

            // Keep only Float1. The test expects this to be rejected, so MyInput presumably declares other
            // columns as well.
            var badData2 = ML.Transforms.SelectColumns(new[] { "Float1" }).Fit(data).Transform(data);

            Assert.ThrowsAny<Exception>(() => est.GetOutputSchema(SchemaShape.Create(badData2.Schema)));

            Done();
        }
コード例 #25
0
        /// <summary>
        /// Schema propagation. Verifies that the matrix row-index and column-index columns, and the label
        /// column when one is configured, are present in <paramref name="inputSchema"/> with compatible types.
        /// It then returns the input columns overlaid with the columns this estimator produces.
        /// </summary>
        /// <param name="inputSchema">Shape of the incoming data.</param>
        /// <returns>Shape of the data after transformation.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            // Throws unless inputSchema has a column named like cachedColumn whose type is compatible with it.
            // columnRole only describes the column in the error message.
            void CheckColumnsCompatible(SchemaShape.Column cachedColumn, string columnRole)
            {
                if (!inputSchema.TryFindColumn(cachedColumn.Name, out var col))
                {
                    // Report the caller's argument (inputSchema), not the local out-variable, as the bad parameter.
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), columnRole, cachedColumn.Name);
                }

                if (!cachedColumn.IsCompatibleWith(col))
                {
                    // Same exception kind as before (Host.Except); the message now also names both types.
                    throw Host.Except($"{columnRole} column '{cachedColumn.Name}' is not compatible: expected type '{cachedColumn.GetTypeString()}', found '{col.GetTypeString()}'");
                }
            }

            // In prediction phase, no label column is expected.
            if (LabelColumn != null)
            {
                CheckColumnsCompatible(LabelColumn, "label");
            }

            // In both training and prediction phases, the matrix column-index and row-index columns are required.
            CheckColumnsCompatible(MatrixColumnIndexColumn, "matrix column index");
            CheckColumnsCompatible(MatrixRowIndexColumn, "matrix row index");

            // Input columns just pass through so that the output column dictionary contains all input columns.
            var outColumns = inputSchema.Columns.ToDictionary(x => x.Name);

            // Add columns produced by this estimator, replacing any input column with the same name.
            foreach (var col in GetOutputColumnsCore(inputSchema))
            {
                outColumns[col.Name] = col;
            }

            return new SchemaShape(outColumns.Values);
        }
コード例 #26
0
        /// <summary>
        /// Describes the columns this multiclass predictor adds:
        /// a float-vector Score column and a scalar UInt32 PredictedLabel column marked as a key.
        /// </summary>
        /// <param name="inputSchema">Input shape; consulted only for the label column and its annotations.</param>
        /// <returns>The Score column followed by the PredictedLabel column.</returns>
        private SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
        {
            var predictedLabelAnnotationCols = AnnotationUtils.GetTrainerOutputAnnotation();
            SchemaShape.Column? labelCol = null;

            if (LabelColumn.IsValid)
            {
                // Only asserted here; the label's presence is presumably validated earlier by the caller.
                bool found = inputSchema.TryFindColumn(LabelColumn.Name, out var inputLabelCol);
                Contracts.Assert(found);
                labelCol = inputLabelCol;

                // Carry the label's KeyValues annotation (if present) over to the predicted label.
                var labelKeyValues = inputLabelCol.Annotations.Where(annotation => annotation.Name == AnnotationUtils.Kinds.KeyValues);
                predictedLabelAnnotationCols = predictedLabelAnnotationCols.Concat(labelKeyValues);
            }

            var scoreAnnotationCols = AnnotationUtils.AnnotationsForMulticlassScoreColumn(labelCol);
            var scoreColumn = new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector,
                                                     NumberDataViewType.Single, false, new SchemaShape(scoreAnnotationCols));
            var predictedLabelColumn = new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar,
                                                              NumberDataViewType.UInt32, true, new SchemaShape(predictedLabelAnnotationCols));

            return new[] { scoreColumn, predictedLabelColumn };
        }
コード例 #27
0
        /// <summary>
        /// Checks whether this object is consistent with an actual schema shape from a dynamic object,
        /// throwing exceptions if not.
        /// </summary>
        /// <remarks>
        /// Each entry of <c>Pairs</c> is a (column name, statically declared <see cref="System.Type"/>) pair.
        /// The named column must exist in <paramref name="shape"/>, and the declared type must be assignable
        /// (via IsAssignableFromStaticPipeline) from the type resolved for that column.
        /// </remarks>
        /// <param name="ectx">The context on which to throw exceptions.</param>
        /// <param name="shape">The schema shape to check.</param>
        public void Check(IExceptionContext ectx, SchemaShape shape)
        {
            // Argument validation uses asserts rather than checks; presumably compiled out in release builds (confirm).
            Contracts.AssertValue(ectx);
            ectx.AssertValue(shape);

            foreach (var pair in Pairs)
            {
                var col = shape.FindColumn(pair.Key);
                if (col == null)
                {
                    throw ectx.ExceptParam(nameof(shape), $"Column named '{pair.Key}' was not found");
                }
                // Presumably null when the column's shape cannot be mapped to a static-pipeline type (see GetTypeOrNull).
                var type = GetTypeOrNull(col);
                if ((type != null && !pair.Value.IsAssignableFromStaticPipeline(type)) || (type == null && IsStandard(ectx, pair.Value)))
                {
                    // When not null, we can use IsAssignableFrom to indicate we could assign to this, so as to allow
                    // for example Key<uint, string> to be considered to be compatible with Key<uint>.

                    // In the null case, while we cannot directly verify an unrecognized type, we can at least verify
                    // that the statically declared type should not have corresponded to a recognized type.
                    // NOTE(review): when `type` is null, the line below calls IsAssignableFromStaticPipeline(null).
                    // The Key<,> branch further down then calls it on a null receiver (presumably an extension
                    // method). Confirm both tolerate null.
                    if (!pair.Value.IsAssignableFromStaticPipeline(type))
                    {
                        // This is generally an error, unless it's the situation where the asserted type is Key<,> but we could
                        // only resolve it so far as Key<>, since for the moment the SchemaShape cannot determine the type of key
                        // value metadata. In which case, we can check if the declared type is a subtype of the key that was determined
                        // from the analysis.
                        if (pair.Value.IsGenericType && pair.Value.GetGenericTypeDefinition() == typeof(Key <,>) &&
                            type.IsAssignableFromStaticPipeline(pair.Value))
                        {
                            continue;
                        }
                        throw ectx.ExceptParam(nameof(shape),
                                               $"Column '{pair.Key}' of type '{col.GetTypeString()}' cannot be expressed statically as type '{pair.Value}'.");
                    }
                }
            }
        }
コード例 #28
0
        /// <summary>
        /// Validates that <paramref name="inputSchema"/> supplies every column this trainer needs. The feature
        /// column is always required; the weight and label columns are required only when configured. Label
        /// type requirements are delegated to CheckLabelCompatible.
        /// </summary>
        /// <param name="inputSchema">Shape of the training data.</param>
        private void CheckInputSchema(SchemaShape inputSchema)
        {
            // Throws unless inputSchema holds a column named like `expected` with a compatible type.
            void RequireCompatible(SchemaShape.Column expected, string role)
            {
                if (!inputSchema.TryFindColumn(expected.Name, out var actual))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), role, expected.Name);
                }

                if (!expected.IsCompatibleWith(actual))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), role, expected.Name,
                                                    expected.GetTypeString(), actual.GetTypeString());
                }
            }

            RequireCompatible(FeatureColumn, "feature");

            if (WeightColumn.IsValid)
            {
                RequireCompatible(WeightColumn, "weight");
            }

            if (!LabelColumn.IsValid)
            {
                return;
            }

            // Labels may legitimately come in several types, so only presence is checked here and each
            // trainer decides compatibility in CheckLabelCompatible.
            if (!inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol))
            {
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", LabelColumn.Name);
            }

            CheckLabelCompatible(labelCol);
        }
コード例 #29
0
        /// <summary>
        /// Schema propagation. Requires the column named by <c>_options.Source</c> to exist with Single item
        /// type. Returns the input columns plus a Double vector column named <c>_options.Name</c> that carries
        /// a SlotNames annotation.
        /// </summary>
        /// <param name="inputSchema">Shape of the incoming data.</param>
        /// <returns>Shape of the data after transformation.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));

            string sourceName = _options.Source;
            if (!inputSchema.TryFindColumn(sourceName, out var sourceCol))
            {
                throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", sourceName);
            }

            if (sourceCol.ItemType != NumberDataViewType.Single)
            {
                throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", sourceName, "Single", sourceCol.GetTypeString());
            }

            // The produced column's only annotation: slot names as a text vector.
            var slotNames = new SchemaShape.Column(AnnotationUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false);
            var outputAnnotations = new SchemaShape(new List<SchemaShape.Column> { slotNames });

            // Input columns pass through; the output column replaces any input column with the same name.
            var columnsByName = inputSchema.ToDictionary(column => column.Name);
            string outputName = _options.Name;
            columnsByName[outputName] = new SchemaShape.Column(
                outputName, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Double, false, outputAnnotations);

            return new SchemaShape(columnsByName.Values);
        }
コード例 #30
0
        /// <summary>
        /// Schema propagation. Requires the transformer's input column to exist with R4 item type. Returns the
        /// input columns plus an R8 vector output column that carries a SlotNames metadata entry.
        /// </summary>
        /// <param name="inputSchema">Shape of the incoming data.</param>
        /// <returns>Shape of the data after transformation.</returns>
        public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            string inputName = Transformer.InputColumnName;
            if (!inputSchema.TryFindColumn(inputName, out var inputCol))
            {
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputName);
            }

            if (inputCol.ItemType != NumberType.R4)
            {
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputName, NumberType.R4.ToString(), inputCol.GetTypeString());
            }

            // The produced column's only metadata: slot names as a text vector.
            var slotNames = new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false);
            var outputMetadata = new SchemaShape(new List<SchemaShape.Column> { slotNames });

            // Input columns pass through; the output column replaces any input column with the same name.
            var columnsByName = inputSchema.ToDictionary(column => column.Name);
            string outputName = Transformer.OutputColumnName;
            columnsByName[outputName] = new SchemaShape.Column(
                outputName, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, outputMetadata);

            return new SchemaShape(columnsByName.Values);
        }