Exemplo n.º 1
0
 /// <summary>
 /// Validates the arguments and builds an OVA model by delegating to the <c>Create</c> overload
 /// that takes an <c>OutputFormula</c>, using <c>OutputFormula.ProbabilityNormalization</c>.
 /// </summary>
 /// <param name="host">Host used for argument checking; must not be null.</param>
 /// <param name="predictors">Scalar predictors to combine; must be non-null and non-empty.</param>
 internal static OvaModelParameters Create(IHost host, TScalarPredictor[] predictors)
 {
     Contracts.CheckValue(host, nameof(host));
     host.CheckNonEmpty(predictors, nameof(predictors));

     var formula = OutputFormula.ProbabilityNormalization;
     return Create(host, formula, predictors);
 }
Exemplo n.º 2
0
        /// <summary>
        /// Builds a transform that feeds the named <paramref name="inputs"/> into the graph of
        /// <paramref name="session"/> and fetches the named <paramref name="outputs"/>.
        /// </summary>
        /// <param name="env">Environment from which this transform's host is registered; must not be null.</param>
        /// <param name="session">TensorFlow session whose graph must contain every input and output operation.</param>
        /// <param name="inputs">Graph operation names to feed. Each must be non-blank, exist in the graph,
        /// and have a type accepted by <c>TensorFlowUtils.IsTypeSupported</c>.</param>
        /// <param name="outputs">Graph operation names to fetch. Each must be non-blank, exist in the graph,
        /// and appear only once.</param>
        /// <remarks>
        /// Validation failures are raised through <c>_host.ExceptParam</c> / the <c>_host.Check*</c> helpers.
        /// </remarks>
        private TensorFlowTransform(IHostEnvironment env, TFSession session, string[] inputs, string[] outputs)
        {
            Contracts.CheckValue(env, nameof(env));
            // Register under the constant's value; nameof(RegistrationName) would yield the
            // identifier text "RegistrationName" instead.
            _host = env.Register(RegistrationName);
            _host.CheckValue(session, nameof(session));
            _host.CheckNonEmpty(inputs, nameof(inputs));
            _host.CheckNonEmpty(outputs, nameof(outputs));
            Session = session;

            foreach (var input in inputs)
            {
                _host.CheckNonWhiteSpace(input, nameof(inputs));
                var inputOperation = Session.Graph[input];
                if (inputOperation == null)
                {
                    throw _host.ExceptParam(nameof(inputs), $"Input column '{input}' does not exist in the model");
                }
                var tfInput = new TFOutput(inputOperation);
                if (!TensorFlowUtils.IsTypeSupported(tfInput.OutputType))
                {
                    throw _host.ExceptParam(nameof(session), $"Input type '{tfInput.OutputType}' of input column '{input}' is not supported in TensorFlow");
                }
            }

            var newNames = new HashSet<string>();
            foreach (var output in outputs)
            {
                _host.CheckNonWhiteSpace(output, nameof(outputs));
                if (!newNames.Add(output))
                {
                    throw _host.ExceptParam(nameof(outputs), $"Output column '{output}' specified multiple times");
                }
                if (Session.Graph[output] == null)
                {
                    throw _host.ExceptParam(nameof(outputs), $"Output column '{output}' does not exist in the model");
                }
            }

            Inputs        = inputs;
            TFInputTypes  = new TFDataType[Inputs.Length];
            TFInputShapes = new TFShape[Inputs.Length];
            for (int i = 0; i < Inputs.Length; i++)
            {
                var tfInput = new TFOutput(Graph[Inputs[i]]);
                TFInputTypes[i] = tfInput.OutputType;

                TFShape inputShape = Graph.GetTensorShape(tfInput);
                // NumDimensions is -1 for unknown rank and 0 for a scalar; neither has a leading
                // dimension to rewrite (indexing [0] on a rank-0 shape would be out of range).
                if (inputShape.NumDimensions > 0)
                {
                    var newShape = new long[inputShape.NumDimensions];
                    // An unknown (-1) leading dimension is replaced with BatchSize.
                    newShape[0] = inputShape[0] == -1 ? BatchSize : inputShape[0];
                    for (int j = 1; j < inputShape.NumDimensions; j++)
                    {
                        newShape[j] = inputShape[j];
                    }
                    inputShape = new TFShape(newShape);
                }
                TFInputShapes[i] = inputShape;
            }

            Outputs       = outputs;
            OutputTypes   = new ColumnType[Outputs.Length];
            TFOutputTypes = new TFDataType[Outputs.Length];
            for (int i = 0; i < Outputs.Length; i++)
            {
                var tfOutput = new TFOutput(Graph[Outputs[i]]);
                var outputShape = Graph.GetTensorShape(tfOutput);
                // When the leading dimension is unknown (-1) it is dropped from the column shape.
                // Exactly one axis is skipped — the number of axes to drop is unrelated to BatchSize.
                int[] dims = outputShape.NumDimensions > 0
                    ? outputShape.ToIntArray().Skip(outputShape[0] == -1 ? 1 : 0).ToArray()
                    : new[] { 0 };
                var type = TensorFlowUtils.Tf2MlNetType(tfOutput.OutputType);
                OutputTypes[i]   = new VectorType(type, dims);
                TFOutputTypes[i] = tfOutput.OutputType;
            }
        }
Exemplo n.º 3
0
 /// <summary>
 /// Creates an OVA predictor from an array of scalar predictors.
 /// </summary>
 /// <param name="host">Host used for argument checking; must not be null.</param>
 /// <param name="predictors">Scalar predictors to combine; must be non-null and non-empty.</param>
 /// <returns>The predictor produced by the three-argument <c>Create</c> overload.</returns>
 public static OvaPredictor Create(IHost host, TScalarPredictor[] predictors)
 {
     Contracts.CheckValue(host, nameof(host));
     host.CheckNonEmpty(predictors, nameof(predictors));

     // NOTE(review): the meaning of the boolean argument is defined by the overload,
     // which is not visible here — confirm what `true` selects.
     return Create(host, true, predictors);
 }
Exemplo n.º 4
0
 /// <summary>
 /// Validates the arguments and forwards them unchanged to <c>PrintOverallResultsCore</c>.
 /// </summary>
 /// <param name="ch">Channel forwarded to the core implementation; must not be null.</param>
 /// <param name="filename">Forwarded as-is; not validated here.</param>
 /// <param name="metrics">One or more metric dictionaries; must be non-null and non-empty.</param>
 public void PrintOverallResults(IChannel ch, string filename, params Dictionary<string, IDataView>[] metrics)
 {
     Host.CheckValue(ch, nameof(ch));
     Host.CheckNonEmpty(metrics, nameof(metrics));

     PrintOverallResultsCore(ch, filename, metrics);
 }