コード例 #1
0
        /// <summary>
        /// Validates <paramref name="args"/>, records the caching options and builds the internal pipeline.
        /// Only the binary saver is supported.
        /// </summary>
        /// <param name="env">Host environment.</param>
        /// <param name="args">Transform options; the checks below enforce the constraints between them.</param>
        /// <param name="input">Input view of the transform.</param>
        public ExtendedCacheTransform(IHostEnvironment env, Arguments args, IDataView input)
            : base(env, RegistrationName, input)
        {
            Host.CheckValue(args, "args");
            Host.CheckUserArg(args.inDataFrame || !string.IsNullOrEmpty(args.cacheFile), "cacheFile cannot be empty if inDataFrame is false.");
            Host.CheckUserArg(!args.async || args.inDataFrame, "inDataFrame must be true if async is true.");
            Host.CheckUserArg(!args.numTheads.HasValue || args.numTheads > 0, "numThread must be > 0 if specified.");

            var saverSettings = args.saverSettings as ICommandLineComponentFactory;
            Host.CheckValue(saverSettings, nameof(saverSettings));

            // Rebuild the saver's command-line form "name{settings}". Every empty "{}" pair is
            // removed, so a saver without settings is written as just its name.
            _saverSettings = string.Format("{0}{{{1}}}", saverSettings.Name, saverSettings.GetSettingsString());
            _saverSettings = _saverSettings.Replace("{}", "");

            // The saver name is an identifier. Compare ordinally and case-insensitively rather than
            // with ToLower() + StartsWith(string), both of which depend on the current culture.
            if (!_saverSettings.StartsWith("binary", System.StringComparison.OrdinalIgnoreCase))
            {
                throw Host.ExceptNotSupp("Only binary format is supported.");
            }

            _inDataFrame = args.inDataFrame;
            _cacheFile   = args.cacheFile;
            _reuse       = args.reuse;
            _async       = args.async;
            _numThreads  = args.numTheads;

            // Within this constructor the saver is only created to check that the settings
            // string can be turned into a saver.
            var saver = ComponentCreation.CreateSaver(Host, _saverSettings);
            if (saver == null)
            {
                throw Host.Except("Cannot parse '{0}'", _saverSettings);
            }

            _pipedTransform = CreatePipeline(env, input);
        }
コード例 #2
0
        /// <summary>
        /// Deserialization constructor. Reads the label kind byte, then builds the ranker
        /// implementation whose generic argument matches that kind.
        /// </summary>
        /// <param name="env">Host environment.</param>
        /// <param name="ctx">Model load context positioned on the serialized predictor.</param>
        private MultiToRankerPredictor(IHostEnvironment env, ModelLoadContext ctx)
            : base(env, ctx, RegistrationName)
        {
            // The label type is stored as a single byte holding a DataKind value.
            byte serializedKind = ctx.Reader.ReadByte();

            // A byte can never be negative, so only the upper bound needs validating.
            env.Check(serializedKind <= 100, "kind");
            var labelKind = (DataKind)serializedKind;

            // Each implementation reads its own state from the same context.
            if (labelKind == DataKind.R4)
            {
                _impl = new ImplRawRanker<float>(ctx, env);
            }
            else if (labelKind == DataKind.U1)
            {
                _impl = new ImplRawRanker<byte>(ctx, env);
            }
            else if (labelKind == DataKind.U2)
            {
                _impl = new ImplRawRanker<ushort>(ctx, env);
            }
            else if (labelKind == DataKind.U4)
            {
                _impl = new ImplRawRanker<uint>(ctx, env);
            }
            else
            {
                throw env.ExceptNotSupp("Not supported label type.");
            }
        }
コード例 #3
0
            /// <summary>
            /// Builds an entry-point experiment that chains every transform of the recipe and, when one
            /// with a pipeline node exists, the first learner.
            /// </summary>
            /// <param name="env">Environment used to create the experiment and to raise errors.</param>
            /// <returns>
            /// A graph definition whose model is the combined transforms + predictor model when both exist,
            /// the learner's model when there are no transforms, or null when no usable learner exists.
            /// The last data output of the chain is always included.
            /// </returns>
            /// <exception cref="System.Exception">Raised via ExceptNotSupp when a transform has no PipelineNode.</exception>
            public AutoInference.EntryPointGraphDef ToEntryPointGraph(IHostEnvironment env)
            {
                // Entry-point export needs a PipelineNode on every transform; report the first one lacking it.
                var missingNode = Transforms
                    .Where(t => t.PipelineNode == null)
                    .Cast<TransformInference.SuggestedTransform?>()
                    .FirstOrDefault();
                if (missingNode != null)
                {
                    throw env.ExceptNotSupp($"All transforms in recipe must have entrypoint support. {missingNode} is not yet supported.");
                }

                var subGraph = env.CreateExperiment();
                var lastOutput = new Var<IDataView>();

                // Feed each transform the previous output and keep its model for the final combination.
                var transformsModels = new List<Var<ITransformModel>>();
                foreach (var transform in Transforms)
                {
                    var node = transform.PipelineNode;
                    node.SetInputData(lastOutput);
                    var added = node.Add(subGraph);
                    transformsModels.Add(added.Model);
                    lastOutput = added.OutData;
                }

                // Without a usable first learner the graph consists of the transforms only.
                if (Learners.Length == 0 || Learners[0].PipelineNode == null)
                {
                    return new AutoInference.EntryPointGraphDef(subGraph, null, lastOutput);
                }

                var learnerNode = Learners[0].PipelineNode;
                learnerNode.SetInputData(lastOutput);
                var learnerAddResult = learnerNode.Add(subGraph);

                if (Transforms.Length == 0)
                {
                    // Nothing to combine: the predictor's model is the whole model.
                    return new AutoInference.EntryPointGraphDef(subGraph, learnerAddResult.Model, lastOutput);
                }

                // Merge the transform models and the predictor into one model for featurizing and scoring.
                var modelCombine = new ML.Legacy.Transforms.ManyHeterogeneousModelCombiner
                {
                    TransformModels = new ArrayVar<ITransformModel>(transformsModels.ToArray()),
                    PredictorModel  = learnerAddResult.Model
                };
                var modelCombineOutput = subGraph.Add(modelCombine);
                return new AutoInference.EntryPointGraphDef(subGraph, modelCombineOutput.PredictorModel, lastOutput);
            }
コード例 #4
0
        /// <summary>
        /// Reads every row of <paramref name="data"/> and turns it into a <see cref="PipelineResultRow"/>.
        /// </summary>
        /// <param name="env">Environment used to raise errors.</param>
        /// <param name="data">View holding the results, one row per pipeline.</param>
        /// <param name="graphColName">Name of the text (DvText) column holding the pipeline graph.</param>
        /// <param name="metricColName">Name of the metric column; it is read as <c>double</c>.</param>
        /// <param name="idColName">Name of the text (DvText) column holding the pipeline id.</param>
        /// <returns>One <see cref="PipelineResultRow"/> per row, in cursor order.</returns>
        /// <exception cref="System.Exception">Raised via ExceptNotSupp when one of the columns is missing.</exception>
        public static PipelineResultRow[] ExtractResults(IHostEnvironment env, IDataView data, string graphColName, string metricColName, string idColName)
        {
            var results = new List<PipelineResultRow>();
            var schema  = data.Schema;

            if (!schema.TryGetColumnIndex(graphColName, out var graphCol))
            {
                throw env.ExceptNotSupp($"Column name {graphColName} not found");
            }
            if (!schema.TryGetColumnIndex(metricColName, out var metricCol))
            {
                throw env.ExceptNotSupp($"Column name {metricColName} not found");
            }
            if (!schema.TryGetColumnIndex(idColName, out var pipelineIdCol))
            {
                throw env.ExceptNotSupp($"Column name {idColName} not found");
            }

            // Only the three columns read below need to be active.
            using (var cursor = data.GetRowCursor(col => col == graphCol || col == metricCol || col == pipelineIdCol))
            {
                // Getters belong to the cursor, not to a row. Create them once instead of on every
                // iteration, and reuse the value buffers across rows.
                var metricGetter     = cursor.GetGetter<double>(metricCol);
                var graphGetter      = cursor.GetGetter<DvText>(graphCol);
                var pipelineIdGetter = cursor.GetGetter<DvText>(pipelineIdCol);

                double metricValue = 0;
                DvText graphJson   = new DvText();
                DvText pipelineId  = new DvText();

                while (cursor.MoveNext())
                {
                    metricGetter(ref metricValue);
                    graphGetter(ref graphJson);
                    pipelineIdGetter(ref pipelineId);
                    // ToString() copies the text, so reusing the buffers on the next row is safe.
                    results.Add(new PipelineResultRow(graphJson.ToString(), metricValue, pipelineId.ToString()));
                }
            }

            return results.ToArray();
        }
コード例 #5
0
        /// <summary>
        /// Wraps <paramref name="predictor"/> between an optional pre-processing and post-processing transform.
        /// </summary>
        /// <param name="env">Host environment; a child host named "PrePostProcessPredictor" is registered from it.</param>
        /// <param name="preProcess">Transform applied before the predictor.</param>
        /// <param name="predictor">Predictor to wrap; it must implement <c>IValueMapper</c>.</param>
        /// <param name="inputColumn">Column fed to the predictor.</param>
        /// <param name="outputColumn">Column receiving the predictor's output.</param>
        /// <param name="postProcess">Transform applied after the predictor.</param>
        /// <exception cref="System.Exception">Raised via ExceptNotSupp when the predictor is not an <c>IValueMapper</c>.</exception>
        public PrePostProcessPredictor(IHostEnvironment env, IDataTransform preProcess, IPredictor predictor,
                                       string inputColumn, string outputColumn, IDataTransform postProcess)
        {
            Contracts.CheckValue(env, "env");
            _host = env.Register("PrePostProcessPredictor");
            _host.CheckValue(predictor, "predictor");

            // The predictor is exposed as a transform through its IValueMapper interface.
            // Cast once and reuse the result.
            var mapper = predictor as IValueMapper;
            if (mapper == null)
            {
                throw _host.ExceptNotSupp("Predictor must implement IValueMapper interface.");
            }

            _preProcess             = preProcess;
            _inputColumn            = inputColumn;
            _outputColumn           = outputColumn;
            _transformFromPredictor = new TransformFromValueMapper(env, mapper, _preProcess, inputColumn, outputColumn);
            _postProcess            = postProcess;
            _predictor              = predictor;
        }
コード例 #6
0
        /// <summary>
        /// Returns a value mapper from <typeparamref name="TSrc"/> to <typeparamref name="TDst"/>.
        /// The source type must be <c>DataFrame</c>. A <c>DataFrame</c> destination is served by
        /// <c>GetMapperRow</c>; any other destination by <c>GetMapperColumn&lt;TDst&gt;</c>.
        /// </summary>
        /// <exception cref="System.Exception">
        /// Raised via Except when <typeparamref name="TSrc"/> is not <c>DataFrame</c>, and via
        /// ExceptNotSupp when the produced mapper does not have the requested signature.
        /// </exception>
        public ValueMapper<TSrc, TDst> GetMapper<TSrc, TDst>()
        {
            bool sourceIsFrame = typeof(TSrc) == typeof(DataFrame);
            if (!sourceIsFrame)
            {
                throw _env.Except($"Cannot create a mapper with input type {typeof(TSrc)} != {typeof(DataFrame)} (expected).");
            }

            bool targetIsFrame = typeof(TDst) == typeof(DataFrame);
            var mapper = targetIsFrame
                ? GetMapperRow() as ValueMapper<TSrc, TDst>
                : GetMapperColumn<TDst>() as ValueMapper<TSrc, TDst>;

            // A null here means the underlying delegate does not have the requested signature.
            if (mapper == null)
            {
                throw _env.ExceptNotSupp($"Unable to create mapper from {typeof(TSrc)} to {typeof(TDst)}.");
            }
            return mapper;
        }
コード例 #7
0
        /// <summary>
        /// Converts the transforms of the pipeline ending at <paramref name="trans"/> into a
        /// <c>ScikitOnnxContext</c> holding the corresponding ONNX graph.
        /// </summary>
        /// <param name="trans">Last transform of the pipeline to export.</param>
        /// <param name="inputs">Graph input column names; passed by ref to GuessInputs, which may fill them in.</param>
        /// <param name="outputs">Graph output column names; passed by ref to GuessOutputs, which may fill them in.</param>
        /// <param name="name">Graph name; defaults to the type name of <paramref name="trans"/>.</param>
        /// <param name="producer">Producer name stored in the context.</param>
        /// <param name="version">Model version stored in the context.</param>
        /// <param name="domain">Model domain stored in the context.</param>
        /// <param name="onnxVersion">ONNX version to target.</param>
        /// <param name="begin">Optional views passed to EnumerateVariables / EnumerateAllViews
        /// (presumably where the pipeline starts — confirm).</param>
        /// <param name="host">Environment; a temporary DelegateEnvironment is used when null.</param>
        /// <returns>The populated context.</returns>
        /// <exception cref="System.Exception">
        /// Raised via host.Except / ExceptNotSupp when inputs or outputs are empty or their types are unknown,
        /// or when a transform cannot be saved in ONNX format.
        /// </exception>
        public static ScikitOnnxContext ToOnnx(IDataTransform trans, ref string[] inputs, ref string[] outputs,
                                               string name             = null, string producer = "Scikit.ML",
                                               long version            = 0, string domain      = "onnx.ai.ml",
                                               OnnxVersion onnxVersion = OnnxVersion.Stable,
                                               IDataView[] begin       = null, IHostEnvironment host = null)
        {
            if (host == null)
            {
                // NOTE(review): the context is constructed with this environment, which is disposed
                // as soon as the call returns. Confirm the context does not use its host afterwards.
                using (var env = new DelegateEnvironment())
                {
                    return ToOnnx(trans, ref inputs, ref outputs, name, producer, version, domain, onnxVersion, begin, env);
                }
            }
            if (name == null)
            {
                name = trans.GetType().Name;
            }

            GuessOutputs(trans, ref outputs);
            GuessInputs(trans, ref inputs, begin);

            if (inputs == null || inputs.Length == 0)
            {
                throw host.Except("Inputs cannot be empty.");
            }
            if (outputs == null || outputs.Length == 0)
            {
                throw host.Except("Outputs cannot be empty.");
            }

            var assembly    = System.Reflection.Assembly.GetExecutingAssembly();
            var versionInfo = System.Diagnostics.FileVersionInfo.GetVersionInfo(assembly.Location);
            var ctx         = new ScikitOnnxContext(host, name, producer, versionInfo.FileVersion,
                                                    version, domain, onnxVersion);

            var uniqueVars = EnumerateVariables(trans, begin).ToArray();

            // Declare the graph inputs, typed by the variables tagged as the first position of the graph.
            var mapInputType = new Dictionary<string, ColumnType>();
            foreach (var it in uniqueVars.Where(c => c.position == TagHelper.GraphPositionEnum.first))
            {
                mapInputType[it.variableName] = it.variableType;
            }

            foreach (var col in inputs)
            {
                if (!mapInputType.TryGetValue(col, out var inputType))
                {
                    throw host.Except("Unable to find the type of input column '{0}'.", col);
                }
                ctx.AddInputVariable(inputType, col);
            }

            var views      = TagHelper.EnumerateAllViews(trans, begin);
            var transforms = views.Where(c => (c.Item1 as IDataTransform) != null)
                                  .Select(c => c.Item1)
                                  .ToArray();

            // Transforms are saved in reverse enumeration order (presumably the pipeline's forward
            // order — confirm against EnumerateAllViews).
            foreach (var tr in transforms.Reverse())
            {
                var tron = tr as ICanSaveOnnx;
                if (tron == null)
                {
                    throw host.ExceptNotSupp($"Transform {tr.GetType()} cannot be saved in Onnx format.");
                }
                if (!tron.CanSaveOnnx(ctx))
                {
                    throw host.ExceptNotSupp($"Transform {tr.GetType()} cannot be saved in ONNX format.");
                }
                // Check the cast result before calling through it, so a transform lacking
                // ISaveAsOnnx reports an explicit error instead of a NullReferenceException.
                var tron2 = tron as ISaveAsOnnx;
                if (tron2 == null || !tron2.CanSaveOnnx(ctx))
                {
                    throw host.ExceptNotSupp($"Transform {tr.GetType()} does not implement SaveAsOnnx.");
                }
                tron2.SaveAsOnnx(ctx);
            }

            // Declare the graph outputs, typed by the variables tagged as the last position of the graph.
            var mapOutputType = new Dictionary<string, ColumnType>();
            foreach (var it in uniqueVars.Where(c => c.position == TagHelper.GraphPositionEnum.last))
            {
                mapOutputType[it.variableName] = it.variableType;
            }

            foreach (var col in outputs)
            {
                if (!mapOutputType.TryGetValue(col, out var outputType))
                {
                    throw host.Except("Unable to find the type of output column '{0}'.", col);
                }
                var variableName = ctx.TryGetVariableName(col);
                if (variableName == null)
                {
                    throw host.Except("Unable to find a variable for output column '{0}'.", col);
                }
                // Expose the column under its own name through an Identity node.
                var trueVariableName = ctx.AddIntermediateVariable(null, col, true);
                ctx.CreateNode("Identity", variableName, trueVariableName, ctx.GetNodeName("Identity"), "");
                ctx.AddOutputVariable(outputType, trueVariableName);
            }

            return ctx;
        }