/// <summary>
/// Binds the parent ensemble to <paramref name="schema"/>: resolves the indices of the
/// parent's input columns and, for each predictor model, creates a bound row mapper,
/// records the index of its score column and builds its transform pipeline as a
/// row-to-row mapper.
/// </summary>
/// <param name="parent">The ensemble whose predictor models are being bound.</param>
/// <param name="schema">The role-mapped input schema to bind against.</param>
public BoundBase(SchemaBindablePipelineEnsembleBase parent, RoleMappedSchema schema)
            {
                Parent = parent;
                InputRoleMappedSchema = schema;
                OutputSchema = ScoreSchemaFactory.Create(Parent.ScoreType, Parent._scoreColumnKind);

                // Resolve every input column name of the parent to its index in the input schema.
                _inputColIndices = new HashSet<int>();
                foreach (var inputName in Parent._inputCols)
                {
                    var inputCol = InputRoleMappedSchema.Schema.GetColumnOrNull(inputName);
                    if (!inputCol.HasValue)
                    {
                        throw Parent.Host.ExceptSchemaMismatch(nameof(InputRoleMappedSchema), "input", inputName);
                    }
                    _inputColIndices.Add(inputCol.Value.Index);
                }

                var host       = Parent.Host;
                var models     = Parent.PredictorModels;
                int modelCount = models.Length;

                Mappers        = new ISchemaBoundRowMapper[modelCount];
                BoundPipelines = new IRowToRowMapper[modelCount];
                ScoreCols      = new int[modelCount];

                for (int modelIndex = 0; modelIndex < modelCount; modelIndex++)
                {
                    var model = models[modelIndex];

                    // Run PrepareData over an empty view of the input schema to obtain the
                    // RoleMappedSchema that is handed to the predictor.
                    var schemaOnlyView = new EmptyDataView(host, schema.Schema);
                    model.PrepareData(host, schemaOnlyView, out RoleMappedData roleMappedData, out IPredictor predictor);

                    // Wrap the model's predictor as a bindable mapper and bind it to that schema.
                    var bindableMapper = ScoreUtils.GetSchemaBindableMapper(host, model.Predictor);
                    var rowMapper      = bindableMapper.Bind(host, roleMappedData.Schema) as ISchemaBoundRowMapper;
                    Mappers[modelIndex] = rowMapper;
                    if (rowMapper == null)
                    {
                        throw host.Except("Predictor {0} is not a row to row mapper", modelIndex);
                    }

                    // The bound mapper must expose a score column; remember where it is.
                    var scoreColumn = rowMapper.OutputSchema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score);
                    if (!scoreColumn.HasValue)
                    {
                        throw host.Except("Predictor {0} does not contain a score column", modelIndex);
                    }
                    ScoreCols[modelIndex] = scoreColumn.Value.Index;

                    // Apply the model's transform model on top of an empty view of the input
                    // schema and keep the result as a row-to-row mapper.
                    var emptyInput    = new EmptyDataView(host, schema.Schema);
                    var emptyModel    = new TransformModelImpl(host, emptyInput, emptyInput);
                    var boundPipeline = model.TransformModel.Apply(host, emptyModel).AsRowToRowMapper(host);
                    BoundPipelines[modelIndex] = boundPipeline;
                    if (boundPipeline == null)
                    {
                        throw host.Except("Transform pipeline {0} contains transforms that do not implement IRowToRowMapper", modelIndex);
                    }
                }
            }
Esempio n. 2
0
        /// <summary>
        /// Creates an input object of the requested <paramref name="kind"/> from
        /// <paramref name="path"/> and assigns it to the graph variable
        /// <paramref name="varName"/> on <paramref name="runner"/>.
        /// </summary>
        /// <param name="runner">The graph runner that receives the input; must not be null.</param>
        /// <param name="varName">Name of the graph variable to set; must not be blank.</param>
        /// <param name="path">Path of the file backing the input; must not be blank.</param>
        /// <param name="kind">
        /// The kind of input to create. File handles, data views (read with the binary loader),
        /// predictor models and transform models are supported; any other kind throws.
        /// </param>
        public void SetInputFromPath(GraphRunner runner, string varName, string path, TlcModule.DataKind kind)
        {
            _host.CheckUserArg(runner != null, nameof(runner), "Provide a GraphRunner instance.");
            _host.CheckUserArg(!string.IsNullOrWhiteSpace(varName), nameof(varName), "Specify a graph variable name.");
            _host.CheckUserArg(!string.IsNullOrWhiteSpace(path), nameof(path), "Specify a valid file path.");

            switch (kind)
            {
                case TlcModule.DataKind.FileHandle:
                {
                    var fileHandle = new SimpleFileHandle(_host, path, false, false);
                    runner.SetInput(varName, fileHandle);
                    break;
                }

                case TlcModule.DataKind.DataView:
                {
                    // Handed to the runner typed as IDataView rather than as the concrete loader.
                    IDataView binaryView = new BinaryLoader(_host, new BinaryLoader.Arguments(), path);
                    runner.SetInput(varName, binaryView);
                    break;
                }

                case TlcModule.DataKind.PredictorModel:
                {
                    PredictorModelImpl predictorModel;
                    using (var modelStream = File.OpenRead(path))
                    {
                        predictorModel = new PredictorModelImpl(_host, modelStream);
                    }
                    runner.SetInput(varName, predictorModel);
                    break;
                }

                case TlcModule.DataKind.TransformModel:
                {
                    TransformModelImpl transformModel;
                    using (var modelStream = File.OpenRead(path))
                    {
                        transformModel = new TransformModelImpl(_host, modelStream);
                    }
                    runner.SetInput(varName, transformModel);
                    break;
                }

                default:
                    throw _host.Except("Port type {0} not supported", kind);
            }
        }
Esempio n. 3
0
        /// <summary>
        /// Parses the experiment graph in <paramref name="graphStr"/>, wires its inputs, runs
        /// all nodes and then writes or sends back each declared output.
        /// </summary>
        /// <param name="penv">
        /// Native environment block. Its <c>seed</c> is used to register the host, its
        /// <c>pythonPath</c> is used when loading <c>.dprep</c> inputs, and its <c>maxSlots</c>
        /// is used when a data view is sent back as a data frame.
        /// </param>
        /// <param name="env">The host environment used to register the "RunGraph" host.</param>
        /// <param name="graphStr">
        /// JSON text of the graph. The "nodes" array is passed to the runner; the optional
        /// "inputs" and "outputs" objects map graph variable names to path strings.
        /// </param>
        /// <param name="cdata">Number of native data sources in <paramref name="ppdata"/>.</param>
        /// <param name="ppdata">
        /// Native data source blocks. They are consumed, in order, by data-view inputs whose
        /// path is blank.
        /// </param>
        private static void RunGraphCore(EnvironmentBlock *penv, IHostEnvironment env, string graphStr, int cdata, DataSourceBlock **ppdata)
        {
            Contracts.AssertValue(env);

            var     host = env.Register("RunGraph", penv->seed, null);
            JObject graph;

            try
            {
                graph = JObject.Parse(graphStr);
            }
            catch (JsonReaderException ex)
            {
                throw host.Except(ex, "Failed to parse experiment graph: {0}", ex.Message);
            }

            var runner = new GraphRunner(host, graph["nodes"] as JArray);

            var dvNative = new IDataView[cdata];

            try
            {
                for (int i = 0; i < cdata; i++)
                {
                    dvNative[i] = new NativeDataView(host, ppdata[i]);
                }

                // Setting inputs.
                var jInputs = graph["inputs"] as JObject;
                if (graph["inputs"] != null && jInputs == null)
                {
                    throw host.Except("Unexpected value for 'inputs': {0}", graph["inputs"]);
                }

                // Index of the next native data view to hand to a data-view input with a blank path.
                int iDv = 0;
                if (jInputs != null)
                {
                    foreach (var kvp in jInputs)
                    {
                        var pathValue = kvp.Value as JValue;
                        if (pathValue == null)
                        {
                            throw host.Except("Invalid value for input: {0}", kvp.Value);
                        }

                        var path    = pathValue.Value<string>();
                        var varName = kvp.Key;
                        var type    = runner.GetPortDataKind(varName);

                        switch (type)
                        {
                        case TlcModule.DataKind.FileHandle:
                            var fh = new SimpleFileHandle(host, path, false, false);
                            runner.SetInput(varName, fh);
                            break;

                        case TlcModule.DataKind.DataView:
                            IDataView dv;
                            if (!string.IsNullOrWhiteSpace(path))
                            {
                                // The loader is picked from the file extension; anything that is
                                // neither .txt nor .dprep is read with the binary loader.
                                var extension = Path.GetExtension(path);
                                if (extension == ".txt")
                                {
                                    dv = TextLoader.LoadFile(host, new TextLoader.Options(), new MultiFileSource(path));
                                }
                                else if (extension == ".dprep")
                                {
                                    dv = LoadDprepFile(BytesToString(penv->pythonPath), path);
                                }
                                else
                                {
                                    dv = new BinaryLoader(host, new BinaryLoader.Arguments(), path);
                                }
                            }
                            else
                            {
                                // A blank path means "use the next native data view".
                                Contracts.Assert(iDv < dvNative.Length);
                                // prefetch all columns
                                dv = dvNative[iDv++];
                                var prefetch = new int[dv.Schema.Count];
                                for (int i = 0; i < prefetch.Length; i++)
                                {
                                    prefetch[i] = i;
                                }
                                dv = new CacheDataView(host, dv, prefetch);
                            }
                            runner.SetInput(varName, dv);
                            break;

                        case TlcModule.DataKind.PredictorModel:
                            PredictorModel pm;
                            if (!string.IsNullOrWhiteSpace(path))
                            {
                                using (var fs = File.OpenRead(path))
                                {
                                    pm = new PredictorModelImpl(host, fs);
                                }
                            }
                            else
                            {
                                throw host.Except("Model must be loaded from a file");
                            }
                            runner.SetInput(varName, pm);
                            break;

                        case TlcModule.DataKind.TransformModel:
                            TransformModel tm;
                            if (!string.IsNullOrWhiteSpace(path))
                            {
                                using (var fs = File.OpenRead(path))
                                {
                                    tm = new TransformModelImpl(host, fs);
                                }
                            }
                            else
                            {
                                throw host.Except("Model must be loaded from a file");
                            }
                            runner.SetInput(varName, tm);
                            break;

                        default:
                            throw host.Except("Port type {0} not supported", type);
                        }
                    }
                }
                runner.RunAll();

                // Reading outputs.
                using (var ch = host.Start("Reading outputs"))
                {
                    var jOutputs = graph["outputs"] as JObject;
                    if (jOutputs != null)
                    {
                        foreach (var kvp in jOutputs)
                        {
                            var pathValue = kvp.Value as JValue;
                            if (pathValue == null)
                            {
                                // This loop validates outputs, so the message must say so.
                                throw host.Except("Invalid value for output: {0}", kvp.Value);
                            }
                            var path    = pathValue.Value<string>();
                            var varName = kvp.Key;
                            var type    = runner.GetPortDataKind(varName);

                            switch (type)
                            {
                            case TlcModule.DataKind.FileHandle:
                                // The output is still fetched from the runner before reporting
                                // that file handle outputs are unsupported.
                                var fh = runner.GetOutput<IFileHandle>(varName);
                                throw host.ExceptNotSupp("File handle outputs not yet supported.");

                            case TlcModule.DataKind.DataView:
                                var idv = runner.GetOutput<IDataView>(varName);
                                if (path == CSR_MATRIX)
                                {
                                    SendViewToNativeAsCsr(ch, penv, idv);
                                }
                                else if (!string.IsNullOrWhiteSpace(path))
                                {
                                    SaveIdvToFile(idv, path, host);
                                }
                                else
                                {
                                    var infos = ProcessColumns(ref idv, penv->maxSlots, host);
                                    SendViewToNativeAsDataFrame(ch, penv, idv, infos);
                                }
                                break;

                            case TlcModule.DataKind.PredictorModel:
                                var pm = runner.GetOutput<PredictorModel>(varName);
                                if (!string.IsNullOrWhiteSpace(path))
                                {
                                    SavePredictorModelToFile(pm, path, host);
                                }
                                else
                                {
                                    throw host.Except("Returning in-memory models is not supported");
                                }
                                break;

                            case TlcModule.DataKind.TransformModel:
                                var tm = runner.GetOutput<TransformModel>(varName);
                                if (!string.IsNullOrWhiteSpace(path))
                                {
                                    // File.Create truncates an existing file. File.OpenWrite would
                                    // leave stale trailing bytes when the existing file is longer
                                    // than the newly saved model.
                                    using (var fs = File.Create(path))
                                    {
                                        tm.Save(host, fs);
                                    }
                                }
                                else
                                {
                                    throw host.Except("Returning in-memory models is not supported");
                                }
                                break;

                            case TlcModule.DataKind.Array:
                                var objArray = runner.GetOutput<object[]>(varName);
                                if (objArray is PredictorModel[])
                                {
                                    var modelArray = (PredictorModel[])objArray;
                                    // Save each model separately; the path is a format string
                                    // that receives the model index.
                                    for (var i = 0; i < modelArray.Length; i++)
                                    {
                                        var modelPath = string.Format(CultureInfo.InvariantCulture, path, i);
                                        SavePredictorModelToFile(modelArray[i], modelPath, host);
                                    }
                                }
                                else
                                {
                                    // Report the element type when there is one, otherwise the
                                    // array type. A null or empty array, or a null first element,
                                    // must not mask this error with an unrelated exception.
                                    var unsupportedType = objArray?.FirstOrDefault()?.GetType() ?? objArray?.GetType();
                                    throw host.Except("DataKind.Array type {0} not supported", unsupportedType);
                                }
                                break;

                            default:
                                throw host.Except("Port type {0} not supported", type);
                            }
                        }
                    }
                }
            }
            finally
            {
                // The raw data view is disposable so it lets go of unmanaged raw pointers before we return.
                for (int i = 0; i < dvNative.Length; i++)
                {
                    var view = dvNative[i];
                    if (view == null)
                    {
                        // Construction failed before this slot was filled.
                        continue;
                    }
                    host.Assert(view is IDisposable);
                    var disp = (IDisposable)dvNative[i];
                    disp.Dispose();
                }
            }
        }
Esempio n. 4
0
 /// <summary>
 /// Creates a prediction model by reading a transform model from
 /// <paramref name="stream"/> inside a newly created <c>MLContext</c>.
 /// </summary>
 /// <param name="stream">The stream the transform model is read from.</param>
 internal PredictionModel(Stream stream)
 {
     _env = new MLContext();
     // Assemblies are registered with the new environment before the model is loaded.
     // NOTE(review): presumably so components referenced by the saved model can be resolved — confirm.
     AssemblyRegistration.RegisterAssemblies(_env);
     PredictorModel = new TransformModelImpl(_env, stream);
 }