private OnnxTransform(IHostEnvironment env, Arguments args, byte[] modelBytes = null)
        {
            // Creates the transform from args.ModelFile when modelBytes is null, otherwise from the
            // serialized model bytes, and derives OutputType from the ONNX output node named by
            // args.OutputColumn.
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(RegistrationName);
            _host.CheckValue(args, nameof(args));
            _host.CheckNonWhiteSpace(args.InputColumn, nameof(args.InputColumn));
            _host.CheckNonWhiteSpace(args.OutputColumn, nameof(args.OutputColumn));

            if (modelBytes == null)
            {
                // Loading from disk: the file name must be given and the file must exist.
                _host.CheckNonWhiteSpace(args.ModelFile, nameof(args.ModelFile));
                _host.CheckUserArg(File.Exists(args.ModelFile), nameof(args.ModelFile));
                Model = new OnnxModel(args.ModelFile);
            }
            else
            {
                Model = OnnxModel.CreateFromBytes(modelBytes);
            }

            Input  = args.InputColumn;
            Output = args.OutputColumn;

            // Materialize once so the existence check and the lookup see the same sequence.
            var outputsInfo = Model.GetOutputsInfo().ToArray();
            if (!outputsInfo.Any(x => x.Name == args.OutputColumn))
            {
                // Previously First() threw a context-free "Sequence contains no matching element".
                throw _host.Except($"Output column '{args.OutputColumn}' does not exist in the model.");
            }
            var outputNodeInfo = outputsInfo.First(x => x.Name == args.OutputColumn);

            var type  = OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type);
            var shape = outputNodeInfo.Shape;
            // A negative leading dimension is dropped (presumably a dynamic batch dimension — confirm);
            // an empty shape maps to { 0 }.
            var dims  = shape.Count > 0 ? shape.Skip(shape[0] < 0 ? 1 : 0).Select(x => (int)x).ToArray() : new[] { 0 };

            OutputType = new VectorType(type, dims);
            _args      = args;
        }
        /// <summary>
        /// Registers a host for this transform and records the source view together with the
        /// label, score and group column names.
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="input">The source data view; must not be null.</param>
        /// <param name="labelCol">Label column name; must not be null or white space.</param>
        /// <param name="scoreCol">Score column name; must not be null or white space.</param>
        /// <param name="groupCol">Group column name; must not be null or white space.</param>
        /// <param name="registrationName">Name under which the host is registered with <paramref name="env"/>.</param>
        protected PerGroupTransformBase(IHostEnvironment env, IDataView input, string labelCol, string scoreCol, string groupCol, string registrationName)
        {
            Contracts.CheckValue(env, nameof(env));
            this.Host = env.Register(registrationName);

            // Validate everything first; nothing is captured until all checks pass.
            this.Host.CheckValue(input, nameof(input));
            this.Host.CheckNonWhiteSpace(labelCol, nameof(labelCol));
            this.Host.CheckNonWhiteSpace(scoreCol, nameof(scoreCol));
            this.Host.CheckNonWhiteSpace(groupCol, nameof(groupCol));

            this.Source = input;
            this.LabelCol = labelCol;
            this.ScoreCol = scoreCol;
            this.GroupCol = groupCol;
        }
        private OnnxTransform(IHostEnvironment env, Arguments args, byte[] modelBytes = null)
        {
            // Creates the transform from args.ModelFile when modelBytes is null, otherwise from the
            // serialized model bytes. Only models with exactly one input and one output are accepted;
            // OutputType is derived from that single output node.
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(RegistrationName);
            _host.CheckValue(args, nameof(args));
            _host.CheckNonWhiteSpace(args.InputColumn, nameof(args.InputColumn));
            _host.CheckNonWhiteSpace(args.OutputColumn, nameof(args.OutputColumn));

            if (modelBytes == null)
            {
                // Loading from disk: the file name must be given and the file must exist.
                _host.CheckNonWhiteSpace(args.ModelFile, nameof(args.ModelFile));
                _host.CheckUserArg(File.Exists(args.ModelFile), nameof(args.ModelFile));
                Model = new OnnxModel(args.ModelFile);
            }
            else
            {
                Model = OnnxModel.CreateFromBytes(modelBytes);
            }

            var modelInfo = Model.ModelInfo;

            // Throw through the registered host (not env) so the error carries this transform's
            // context. The messages previously contained a stray '$' ("has $2 input(s)").
            if (modelInfo.InputsInfo.Length != 1)
            {
                throw _host.Except($"OnnxTransform supports Onnx models with one input. The provided model has {modelInfo.InputsInfo.Length} input(s).");
            }
            if (modelInfo.OutputsInfo.Length != 1)
            {
                throw _host.Except($"OnnxTransform supports Onnx models with one output. The provided model has {modelInfo.OutputsInfo.Length} output(s).");
            }

            Input  = args.InputColumn;
            Output = args.OutputColumn;

            var outputNodeInfo = modelInfo.OutputsInfo[0];
            var type           = OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type);
            var shape          = outputNodeInfo.Shape;
            // A negative leading dimension is dropped (presumably a dynamic batch dimension — confirm);
            // an empty shape maps to { 0 }.
            var dims           = shape.Count > 0 ? shape.Skip(shape[0] < 0 ? 1 : 0).Select(x => (int)x).ToArray() : new[] { 0 };

            OutputType = new VectorType(type, dims);
            _args      = args;
        }
示例#4
0
        /// <summary>
        /// Registers a host and stores the ensemble's predictor models and score column kind.
        /// The non-hidden input column names of the first predictor become the reference set
        /// (<c>_inputCols</c>). Every other predictor must expose exactly the same non-hidden names.
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="predictors">The predictor models; must be non-empty.</param>
        /// <param name="registrationName">Name under which the host is registered.</param>
        /// <param name="scoreColumnKind">Score column kind; must not be null or white space.</param>
        private SchemaBindablePipelineEnsembleBase(IHostEnvironment env, IPredictorModel[] predictors, string registrationName, string scoreColumnKind)
        {
            Contracts.CheckValue(env, nameof(env));
            Host = env.Register(registrationName);
            Host.CheckNonEmpty(predictors, nameof(predictors));
            Host.CheckNonWhiteSpace(scoreColumnKind, nameof(scoreColumnKind));

            PredictorModels  = predictors;
            _scoreColumnKind = scoreColumnKind;

            // Reference set: visible (non-hidden) input columns of the first predictor.
            var referenceSchema = predictors[0].TransformModel.InputSchema;
            var expectedNames = new HashSet<string>();
            for (int col = 0; col < referenceSchema.ColumnCount; col++)
            {
                if (!referenceSchema.IsHidden(col))
                {
                    expectedNames.Add(referenceSchema.GetColumnName(col));
                }
            }
            _inputCols = expectedNames.ToArray();

            // Each remaining predictor must expose only reference names, and the same number of them.
            for (int p = 1; p < predictors.Length; p++)
            {
                var schema = predictors[p].TransformModel.InputSchema;
                int visibleCount = 0;
                for (int col = 0; col < schema.ColumnCount; col++)
                {
                    if (schema.IsHidden(col))
                    {
                        continue;
                    }

                    var name = schema.GetColumnName(col);
                    if (!expectedNames.Contains(name))
                    {
                        throw Host.Except("Inconsistent schemas: Some schemas do not contain the column '{0}'", name);
                    }
                    visibleCount++;
                }

                Host.Check(visibleCount == _inputCols.Length,
                           "Inconsistent schemas: not all schemas have the same number of columns");
            }
        }
示例#5
0
        /// <summary>
        /// Returns the feature selection scores for each slot of each column.
        /// </summary>
        /// <param name="host">The host; must not be null.</param>
        /// <param name="input">The input dataview; must not be null.</param>
        /// <param name="labelColumnName">The label column; must not be null or white space.</param>
        /// <param name="columns">The columns for which to compute the feature selection scores.
        /// Must be non-empty; each name must be non-blank and appear only once.</param>
        /// <param name="numBins">The number of bins to use for numeric features; must be greater than 1.</param>
        /// <returns>The scores for each column and each slot, as computed by <c>TrainCore</c>.</returns>
        public static Single[][] Train(IHost host, IDataView input, string labelColumnName, string[] columns, int numBins)
        {
            Contracts.CheckValue(host, nameof(host));
            host.CheckValue(input, nameof(input));
            host.CheckNonWhiteSpace(labelColumnName, nameof(labelColumnName));
            host.CheckValue(columns, nameof(columns));
            host.Check(columns.Length > 0, "At least one column must be specified.");
            host.Check(numBins > 1, "numBins must be greater than 1.");

            // Validate each feature column name the same way the label column is validated, and
            // reject duplicates, so bad arguments fail here rather than deep inside TrainCore.
            var colSet = new HashSet<string>();
            foreach (string col in columns)
            {
                host.CheckNonWhiteSpace(col, nameof(columns));
                if (!colSet.Add(col))
                {
                    throw host.Except("Column '{0}' specified multiple times.", col);
                }
            }

            // Buffer handed to TrainCore; its contents are not read here.
            // NOTE(review): presumably TrainCore fills it with per-column sizes — confirm.
            var colSizes = new int[columns.Length];

            return TrainCore(host, input, labelColumnName, columns, numBins, colSizes);
        }
        /// <summary>
        /// Loads a TensorFlow session from <paramref name="modelBytes"/> and records the input and
        /// output graph nodes together with their TensorFlow and ML.NET types and shapes.
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="modelBytes">Serialized TensorFlow model; must not be null.</param>
        /// <param name="inputs">Names of graph operations used as inputs; each must be non-blank,
        /// exist in the graph and have a supported type.</param>
        /// <param name="outputs">Names of graph operations used as outputs; each must be non-blank,
        /// unique and exist in the graph.</param>
        private TensorFlowTransform(IHostEnvironment env, byte[] modelBytes, string[] inputs, string[] outputs)
        {
            Contracts.CheckValue(env, nameof(env));
            // Register under the constant's value. nameof(RegistrationName) would register under the
            // literal identifier "RegistrationName" instead.
            _host = env.Register(RegistrationName);
            _host.CheckValue(modelBytes, nameof(modelBytes));
            // Fail with an argument error rather than a NullReferenceException in the loops below.
            _host.CheckValue(inputs, nameof(inputs));
            _host.CheckValue(outputs, nameof(outputs));
            Session = LoadTFSession(modelBytes);

            foreach (var input in inputs)
            {
                _host.CheckNonWhiteSpace(input, nameof(inputs));
                var inputOp = Session.Graph[input];
                if (inputOp == null)
                {
                    throw _host.ExceptParam(nameof(inputs), $"Input column '{input}' does not exist in the model");
                }
                var tfInput = new TFOutput(inputOp);
                if (!TensorFlowUtils.IsTypeSupported(tfInput.OutputType))
                {
                    throw _host.ExceptParam(nameof(modelBytes), $"Input type '{tfInput.OutputType}' of input column '{input}' is not supported in TensorFlow");
                }
            }

            var newNames = new HashSet<string>();
            foreach (var output in outputs)
            {
                // Same name validation as for inputs.
                _host.CheckNonWhiteSpace(output, nameof(outputs));
                if (!newNames.Add(output))
                {
                    throw _host.ExceptParam(nameof(outputs), $"Output column '{output}' specified multiple times");
                }
                if (Session.Graph[output] == null)
                {
                    throw _host.ExceptParam(nameof(outputs), $"Output column '{output}' does not exist in the model");
                }
            }

            Inputs        = inputs;
            TFInputTypes  = new TFDataType[Inputs.Length];
            TFInputShapes = new TFShape[Inputs.Length];
            for (int i = 0; i < Inputs.Length; i++)
            {
                var tfInput = new TFOutput(Graph[Inputs[i]]);
                TFInputTypes[i]  = tfInput.OutputType;
                TFInputShapes[i] = Graph.GetTensorShape(tfInput);
                // Replace every unknown (-1) dimension with BatchSize.
                var newShape = new long[TFInputShapes[i].NumDimensions];
                for (int j = 0; j < TFInputShapes[i].NumDimensions; j++)
                {
                    newShape[j] = TFInputShapes[i][j] == -1 ? BatchSize : TFInputShapes[i][j];
                }
                TFInputShapes[i] = new TFShape(newShape);
            }

            Outputs       = outputs;
            OutputTypes   = new ColumnType[Outputs.Length];
            TFOutputTypes = new TFDataType[Outputs.Length];
            for (int i = 0; i < Outputs.Length; i++)
            {
                var tfOutput = new TFOutput(Graph[Outputs[i]]);
                var shape    = Graph.GetTensorShape(tfOutput);
                // Drop only the single leading unknown (-1) batch dimension. The previous
                // Skip(BatchSize) dropped BatchSize entries, which is only right when BatchSize is 1.
                // Rank-0 shapes (where shape[0] is not indexable) map to { 0 }.
                int[] dims = shape.NumDimensions > 0
                    ? shape.ToIntArray().Skip(shape[0] == -1 ? 1 : 0).ToArray()
                    : new[] { 0 };
                var type     = TensorFlowUtils.Tf2MlNetType(tfOutput.OutputType);
                OutputTypes[i]   = new VectorType(type, dims);
                TFOutputTypes[i] = tfOutput.OutputType;
            }
        }