private static TOut CreatePipelineEnsemble<TOut>(IHostEnvironment env, PredictorModel[] predictors, SchemaBindablePipelineEnsembleBase ensemble)
    where TOut : CommonOutputs.TrainerOutput, new()
{
    var inputSchema = predictors[0].TransformModel.InputSchema;
    var dv = new EmptyDataView(env, inputSchema);

    // The role mappings are specific to the individual predictors.
    var rmd = new RoleMappedData(dv);
    var predictorModel = new PredictorModelImpl(env, rmd, dv, ensemble);

    var output = new TOut { PredictorModel = predictorModel };
    return output;
}
public static CommonOutputs.BinaryClassificationOutput CreateBinaryEnsemble(IHostEnvironment env, ClassifierInput input)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register("CombineModels");
    host.CheckValue(input, nameof(input));
    host.CheckNonEmpty(input.Models, nameof(input.Models));

    GetPipeline(host, input, out IDataView startingData, out RoleMappedData transformedData);

    var args = new EnsembleTrainer.Arguments();
    switch (input.ModelCombiner)
    {
        case ClassifierCombiner.Median:
            args.OutputCombiner = new MedianFactory();
            break;
        case ClassifierCombiner.Average:
            args.OutputCombiner = new AverageFactory();
            break;
        case ClassifierCombiner.Vote:
            args.OutputCombiner = new VotingFactory();
            break;
        default:
            throw host.Except("Unknown combiner kind");
    }

    var trainer = new EnsembleTrainer(host, args);
    var ensemble = trainer.CombineModels(input.Models.Select(pm => pm.Predictor as IPredictorProducing<float>));

    var predictorModel = new PredictorModelImpl(host, transformedData, startingData, ensemble);
    var output = new CommonOutputs.BinaryClassificationOutput { PredictorModel = predictorModel };

    return output;
}
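// Example (sketch, not part of the entry point above): one plausible way to call
// CreateBinaryEnsemble, assuming the caller already holds an IHostEnvironment and an array
// of trained binary PredictorModel instances. The member names Models and ModelCombiner are
// taken from the checks and switch above; the helper name and object-initializer usage are
// illustrative assumptions, not verified API.
private static PredictorModel CombineBinaryModelsByMedian(IHostEnvironment env, PredictorModel[] models)
{
    var input = new ClassifierInput
    {
        Models = models,                          // must be non-empty (see CheckNonEmpty above)
        ModelCombiner = ClassifierCombiner.Median // Average and Vote are the other accepted kinds
    };
    return CreateBinaryEnsemble(env, input).PredictorModel;
}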
protected SchemaBindablePipelineEnsembleBase(IHostEnvironment env, ModelLoadContext ctx, string scoreColumnKind)
{
    Host = env.Register(LoaderSignature);
    Host.AssertNonEmpty(scoreColumnKind);

    _scoreColumnKind = scoreColumnKind;

    // *** Binary format ***
    // int: id of _scoreColumnKind (loaded in the Create method)
    // int: number of predictors
    // The predictor models
    // int: the number of input columns
    // for each input column:
    //   int: id of the column name

    var length = ctx.Reader.ReadInt32();
    Host.CheckDecode(length > 0);

    PredictorModels = new PredictorModel[length];
    for (int i = 0; i < PredictorModels.Length; i++)
    {
        string dir =
            ctx.Header.ModelVerWritten == 0x00010001
                ? "PredictorModels"
                : Path.Combine(ctx.Directory, "PredictorModels");
        using (var ent = ctx.Repository.OpenEntry(dir, $"PredictorModel_{i:000}"))
            PredictorModels[i] = new PredictorModelImpl(Host, ent.Stream);
    }

    length = ctx.Reader.ReadInt32();
    Host.CheckDecode(length >= 0);
    _inputCols = new string[length];
    for (int i = 0; i < length; i++)
        _inputCols[i] = ctx.LoadNonEmptyString();
}
public void SetInputFromPath(GraphRunner runner, string varName, string path, TlcModule.DataKind kind)
{
    _host.CheckUserArg(runner != null, nameof(runner), "Provide a GraphRunner instance.");
    _host.CheckUserArg(!string.IsNullOrWhiteSpace(varName), nameof(varName), "Specify a graph variable name.");
    _host.CheckUserArg(!string.IsNullOrWhiteSpace(path), nameof(path), "Specify a valid file path.");

    switch (kind)
    {
        case TlcModule.DataKind.FileHandle:
            var fh = new SimpleFileHandle(_host, path, false, false);
            runner.SetInput(varName, fh);
            break;
        case TlcModule.DataKind.DataView:
            IDataView loader = new BinaryLoader(_host, new BinaryLoader.Arguments(), path);
            runner.SetInput(varName, loader);
            break;
        case TlcModule.DataKind.PredictorModel:
            PredictorModelImpl pm;
            using (var fs = File.OpenRead(path))
                pm = new PredictorModelImpl(_host, fs);
            runner.SetInput(varName, pm);
            break;
        case TlcModule.DataKind.TransformModel:
            TransformModelImpl tm;
            using (var fs = File.OpenRead(path))
                tm = new TransformModelImpl(_host, fs);
            runner.SetInput(varName, tm);
            break;
        default:
            throw _host.Except("Port type {0} not supported", kind);
    }
}
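// Example (sketch): wiring serialized models and data from disk into graph variables before
// running the graph. Assumes this helper lives in the same class as SetInputFromPath, that
// `runner` was built for the experiment graph, and that the variable names match graph input
// ports; the method name and file paths are purely illustrative.
private void SetModelInputs(GraphRunner runner)
{
    // A predictor model previously serialized to disk; deserialized via PredictorModelImpl above.
    SetInputFromPath(runner, "predictor_model", "model.zip", TlcModule.DataKind.PredictorModel);
    // DataView inputs are read with BinaryLoader, so the file is expected to be in the binary IDV format.
    SetInputFromPath(runner, "data", "data.idv", TlcModule.DataKind.DataView);
}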
private static void RunGraphCore(EnvironmentBlock* penv, IHostEnvironment env, string graphStr, int cdata, DataSourceBlock** ppdata)
{
    Contracts.AssertValue(env);
    var host = env.Register("RunGraph", penv->seed, null);

    JObject graph;
    try
    {
        graph = JObject.Parse(graphStr);
    }
    catch (JsonReaderException ex)
    {
        throw host.Except(ex, "Failed to parse experiment graph: {0}", ex.Message);
    }

    var runner = new GraphRunner(host, graph["nodes"] as JArray);

    var dvNative = new IDataView[cdata];
    try
    {
        for (int i = 0; i < cdata; i++)
            dvNative[i] = new NativeDataView(host, ppdata[i]);

        // Setting inputs.
        var jInputs = graph["inputs"] as JObject;
        if (graph["inputs"] != null && jInputs == null)
            throw host.Except("Unexpected value for 'inputs': {0}", graph["inputs"]);

        int iDv = 0;
        if (jInputs != null)
        {
            foreach (var kvp in jInputs)
            {
                var pathValue = kvp.Value as JValue;
                if (pathValue == null)
                    throw host.Except("Invalid value for input: {0}", kvp.Value);

                var path = pathValue.Value<string>();
                var varName = kvp.Key;
                var type = runner.GetPortDataKind(varName);

                switch (type)
                {
                    case TlcModule.DataKind.FileHandle:
                        var fh = new SimpleFileHandle(host, path, false, false);
                        runner.SetInput(varName, fh);
                        break;
                    case TlcModule.DataKind.DataView:
                        IDataView dv;
                        if (!string.IsNullOrWhiteSpace(path))
                        {
                            var extension = Path.GetExtension(path);
                            if (extension == ".txt")
                                dv = TextLoader.LoadFile(host, new TextLoader.Options(), new MultiFileSource(path));
                            else if (extension == ".dprep")
                                dv = LoadDprepFile(BytesToString(penv->pythonPath), path);
                            else
                                dv = new BinaryLoader(host, new BinaryLoader.Arguments(), path);
                        }
                        else
                        {
                            Contracts.Assert(iDv < dvNative.Length);
                            // Prefetch all columns of the native data view into the cache.
                            dv = dvNative[iDv++];
                            var prefetch = new int[dv.Schema.Count];
                            for (int i = 0; i < prefetch.Length; i++)
                                prefetch[i] = i;
                            dv = new CacheDataView(host, dv, prefetch);
                        }
                        runner.SetInput(varName, dv);
                        break;
                    case TlcModule.DataKind.PredictorModel:
                        PredictorModel pm;
                        if (!string.IsNullOrWhiteSpace(path))
                        {
                            using (var fs = File.OpenRead(path))
                                pm = new PredictorModelImpl(host, fs);
                        }
                        else
                            throw host.Except("Model must be loaded from a file");
                        runner.SetInput(varName, pm);
                        break;
                    case TlcModule.DataKind.TransformModel:
                        TransformModel tm;
                        if (!string.IsNullOrWhiteSpace(path))
                        {
                            using (var fs = File.OpenRead(path))
                                tm = new TransformModelImpl(host, fs);
                        }
                        else
                            throw host.Except("Model must be loaded from a file");
                        runner.SetInput(varName, tm);
                        break;
                    default:
                        throw host.Except("Port type {0} not supported", type);
                }
            }
        }

        runner.RunAll();

        // Reading outputs.
        using (var ch = host.Start("Reading outputs"))
        {
            var jOutputs = graph["outputs"] as JObject;
            if (jOutputs != null)
            {
                foreach (var kvp in jOutputs)
                {
                    var pathValue = kvp.Value as JValue;
                    if (pathValue == null)
                        throw host.Except("Invalid value for output: {0}", kvp.Value);

                    var path = pathValue.Value<string>();
                    var varName = kvp.Key;
                    var type = runner.GetPortDataKind(varName);

                    switch (type)
                    {
                        case TlcModule.DataKind.FileHandle:
                            var fh = runner.GetOutput<IFileHandle>(varName);
                            throw host.ExceptNotSupp("File handle outputs not yet supported.");
                        case TlcModule.DataKind.DataView:
                            var idv = runner.GetOutput<IDataView>(varName);
                            if (path == CSR_MATRIX)
                                SendViewToNativeAsCsr(ch, penv, idv);
                            else if (!string.IsNullOrWhiteSpace(path))
                                SaveIdvToFile(idv, path, host);
                            else
                            {
                                var infos = ProcessColumns(ref idv, penv->maxSlots, host);
                                SendViewToNativeAsDataFrame(ch, penv, idv, infos);
                            }
                            break;
                        case TlcModule.DataKind.PredictorModel:
                            var pm = runner.GetOutput<PredictorModel>(varName);
                            if (!string.IsNullOrWhiteSpace(path))
                                SavePredictorModelToFile(pm, path, host);
                            else
                                throw host.Except("Returning in-memory models is not supported");
                            break;
                        case TlcModule.DataKind.TransformModel:
                            var tm = runner.GetOutput<TransformModel>(varName);
                            if (!string.IsNullOrWhiteSpace(path))
                            {
                                using (var fs = File.OpenWrite(path))
                                    tm.Save(host, fs);
                            }
                            else
                                throw host.Except("Returning in-memory models is not supported");
                            break;
                        case TlcModule.DataKind.Array:
                            var objArray = runner.GetOutput<object[]>(varName);
                            if (objArray is PredictorModel[])
                            {
                                var modelArray = (PredictorModel[])objArray;
                                // Save each model separately; the path acts as a format string indexed by model.
                                for (var i = 0; i < modelArray.Length; i++)
                                {
                                    var modelPath = string.Format(CultureInfo.InvariantCulture, path, i);
                                    SavePredictorModelToFile(modelArray[i], modelPath, host);
                                }
                            }
                            else
                                throw host.Except("DataKind.Array type {0} not supported", objArray.First().GetType());
                            break;
                        default:
                            throw host.Except("Port type {0} not supported", type);
                    }
                }
            }
        }
    }
    finally
    {
        // The native data views are disposable; dispose them so they release their unmanaged
        // raw pointers before we return.
        for (int i = 0; i < dvNative.Length; i++)
        {
            var view = dvNative[i];
            if (view == null)
                continue;
            host.Assert(view is IDisposable);
            var disp = (IDisposable)dvNative[i];
            disp.Dispose();
        }
    }
}
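// Sketch of the graph JSON that RunGraphCore expects, inferred from the parsing above; node
// contents are elided and the variable names and paths are illustrative. "nodes" is the JArray
// handed to GraphRunner, while "inputs" and "outputs" map graph variable names to file paths.
// An empty input path for a data view means it is taken from the native data passed in ppdata,
// and an empty output path means the data view is sent back to the caller as a data frame;
// model ports, by contrast, must always name a file.
//
// {
//   "nodes":   [ ... entry point nodes ... ],
//   "inputs":  { "data": "", "predictor_model": "model.zip" },
//   "outputs": { "output_data": "", "output_model": "combined_model.zip" }
// }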