/// <summary>
/// Copies every column of <paramref name="input"/> whose name is listed in <paramref name="colNames"/>
/// to a freshly generated temporary column, and reports the (original name, temporary name) pairs
/// through <paramref name="hiddenNames"/>.
/// </summary>
/// <param name="env">Environment handed to <c>CopyColumnsTransform.Create</c>.</param>
/// <param name="input">Data view whose schema is checked for the listed column names.</param>
/// <param name="colNames">Candidate column names; only those found in the schema are aliased.</param>
/// <param name="hiddenNames">
/// Receives one pair per aliased column (key: original name, value: temporary name),
/// or <c>null</c> when none of the listed names exists in <paramref name="input"/>.
/// </param>
/// <returns>
/// <paramref name="input"/> itself when nothing had to be aliased; otherwise the result of the copy transform.
/// </returns>
private static IDataView AliasIfNeeded(IHostEnvironment env, IDataView input, string[] colNames, out KeyValuePair<string, string>[] hiddenNames)
{
    hiddenNames = null;
    var schema = input.Schema;

    // Phase 1: keep only the requested names that actually exist in the schema.
    var present = new List<string>(colNames.Length);
    for (int i = 0; i < colNames.Length; i++)
    {
        int ignoredIndex;
        if (schema.TryGetColumnIndex(colNames[i], out ignoredIndex))
        {
            present.Add(colNames[i]);
        }
    }

    if (present.Count == 0)
    {
        return input;
    }

    // Phase 2: pair each present column with a temporary name and describe the copy original -> temporary.
    var aliases = new KeyValuePair<string, string>[present.Count];
    var copies = new CopyColumnsTransform.Column[present.Count];
    for (int i = 0; i < present.Count; i++)
    {
        string original = present[i];
        string temporary = schema.GetTempColumnName(original);
        aliases[i] = new KeyValuePair<string, string>(original, temporary);
        copies[i] = new CopyColumnsTransform.Column() { Name = temporary, Source = original };
    }

    hiddenNames = aliases;
    var copyArgs = new CopyColumnsTransform.Arguments() { Column = copies };
    return CopyColumnsTransform.Create(env, copyArgs, input);
}
/// <summary>
/// Entry point that applies a <c>CopyColumnsTransform</c> to <c>input.Data</c>.
/// </summary>
/// <param name="env">Host environment; checked for null and used to register the "CopyColumns" host.</param>
/// <param name="input">Transform arguments, including the data to transform; validated before use.</param>
/// <returns>
/// A <c>CommonOutputs.TransformOutput</c> whose <c>OutputData</c> is the created transform and whose
/// <c>Model</c> wraps that transform together with the original input data.
/// </returns>
public static CommonOutputs.TransformOutput CopyColumns(IHostEnvironment env, CopyColumnsTransform.Arguments input)
{
    // Validate the environment first: the host used for the remaining checks is derived from it.
    Contracts.CheckValue(env, nameof(env));
    var registeredHost = env.Register("CopyColumns");
    registeredHost.CheckValue(input, nameof(input));
    EntryPointUtils.CheckInputArgs(registeredHost, input);

    var sourceData = input.Data;
    var copier = CopyColumnsTransform.Create(env, input, sourceData);

    var output = new CommonOutputs.TransformOutput();
    output.Model = new TransformModel(env, copier, sourceData);
    output.OutputData = copier;
    return output;
}
/// <summary>
/// Copies each temporary column recorded in <paramref name="hiddenNames"/> back to its original name
/// and then drops the temporary columns.
/// </summary>
/// <param name="env">Environment handed to the copy and drop transforms.</param>
/// <param name="input">Data view containing the temporary columns.</param>
/// <param name="hiddenNames">
/// Pairs of (original name, temporary name); when null or empty, <paramref name="input"/> is returned unchanged.
/// </param>
/// <returns>The data view after restoring the original names and dropping the temporary columns.</returns>
private static IDataView UnaliasIfNeeded(IHostEnvironment env, IDataView input, KeyValuePair<string, string>[] hiddenNames)
{
    if (Utils.Size(hiddenNames) == 0)
    {
        return input;
    }

    // One pass builds both the "temporary -> original" copy list and the list of temporaries to drop.
    var restores = new CopyColumnsTransform.Column[hiddenNames.Length];
    var temporaries = new string[hiddenNames.Length];
    for (int i = 0; i < hiddenNames.Length; i++)
    {
        var alias = hiddenNames[i];
        restores[i] = new CopyColumnsTransform.Column() { Name = alias.Key, Source = alias.Value };
        temporaries[i] = alias.Value;
    }

    var restoreArgs = new CopyColumnsTransform.Arguments() { Column = restores };
    IDataView restored = CopyColumnsTransform.Create(env, restoreArgs, input);
    return SelectColumnsTransform.CreateDrop(env, restored, temporaries);
}
/// <summary>
/// Runs a frozen TensorFlow model as a featurizer (outputs "Softmax" and "dense/Relu"), trains a
/// LightGBM multiclass trainer on the concatenated outputs, then checks the evaluation metrics and
/// the predicted class of one hand-written sample.
/// </summary>
public void TensorFlowTransformMNISTConvTest()
{
    var modelLocation = "mnist_model/frozen_saved_model.pb";
    using (var env = new ConsoleEnvironment(seed: 1, conc: 1))
    {
        var dataPath = GetDataPath("Train-Tiny-28x28.txt");
        var testDataPath = GetDataPath("MNIST.Test.tiny.txt");

        // Pipeline: tab-separated text with the label in column 0 and the pixels in columns 1..784.
        var loaderArgs = new TextLoader.Arguments()
        {
            Separator = "tab",
            HasHeader = true,
            Column = new[]
            {
                new TextLoader.Column("Label", DataKind.Num, 0),
                new TextLoader.Column("Placeholder", DataKind.Num, new [] { new TextLoader.Range(1, 784) })
            }
        };
        var loader = TextLoader.ReadFile(env, loaderArgs, new MultiFileSource(dataPath));

        // The TensorFlow transform is given two inputs, so the pixel column is duplicated under the second name.
        var copyArgs = new CopyColumnsTransform.Arguments()
        {
            Column = new[] { new CopyColumnsTransform.Column() { Name = "reshape_input", Source = "Placeholder" } }
        };
        IDataView copied = CopyColumnsTransform.Create(env, copyArgs, loader);

        IDataView scored = TensorFlowTransform.Create(env, copied, modelLocation, new[] { "Softmax", "dense/Relu" }, new[] { "Placeholder", "reshape_input" });
        IDataView featurized = new ConcatTransform(env, "Features", "Softmax", "dense/Relu").Transform(scored);

        var trainer = new LightGbmMulticlassTrainer(env, new LightGbmArguments());

        var cached = new CacheDataView(env, featurized, prefetch: null);
        var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");

        var pred = trainer.Train(trainRoles);

        // Get scorer and evaluate the predictions from test data
        IDataScorerTransform testDataScorer = GetScorer(env, featurized, pred, testDataPath);
        var metrics = Evaluate(env, testDataScorer);

        Assert.Equal(0.99, metrics.AccuracyMicro, 2);
        Assert.Equal(1.0, metrics.AccuracyMacro, 2);

        // Create prediction engine and test predictions
        var model = env.CreatePredictionEngine<MNISTData, MNISTPrediction>(testDataScorer);

        // Pixel values of the sample, written 28 values per line.
        var sample1 = new MNISTData()
        {
            Placeholder = new float[]
            {
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 18, 18, 18, 126, 136, 175, 26, 166, 255, 247, 127, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 30, 36, 94, 154, 170, 253, 253, 253, 253, 253, 225, 172, 253, 242, 195, 64, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 49, 238, 253, 253, 253, 253, 253, 253, 253, 253, 251, 93, 82, 82, 56, 39, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 18, 219, 253, 253, 253, 253, 253, 198, 182, 247, 241, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 80, 156, 107, 253, 253, 205, 11, 0, 43, 154, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 14, 1, 154, 253, 90, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 139, 253, 190, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11, 190, 253, 70, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 35, 241, 225, 160, 108, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 81, 240, 253, 253, 119, 25, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 45, 186, 253, 253, 150, 27, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 93, 252, 253, 187, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 249, 253, 249, 64, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 46, 130, 183, 253, 253, 207, 2, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 39, 148, 229, 253, 253, 253, 250, 182, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 24, 114, 221, 253, 253, 253, 253, 201, 78, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 23, 66, 213, 253, 253, 253, 253, 198, 81, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 18, 171, 219, 253, 253, 253, 253, 195, 80, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 55, 172, 226, 253, 253, 253, 253, 244, 133, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 136, 253, 253, 253, 212, 135, 132, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
            }
        };

        var prediction = model.Predict(sample1);

        // Arg-max over the per-class scores; strict '>' keeps the first index on ties.
        var scores = prediction.PredictedLabels;
        float bestScore = -1;
        int bestIndex = -1;
        for (int classIndex = 0; classIndex < scores.Length; classIndex++)
        {
            if (scores[classIndex] > bestScore)
            {
                bestScore = scores[classIndex];
                bestIndex = classIndex;
            }
        }

        Assert.Equal(5, bestIndex);
    }
}