/// <summary>
/// Runs one row with two float-vector inputs ("ina", "inb") through the two-input ONNX model
/// and checks the length and element sum of both output columns ("outa", "outb").
/// </summary>
/// <remarks>
/// The body returns without asserting anything when the OS is not Windows.
/// </remarks>
[ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // x86 output differs from Baseline
public void OnnxModelMultiInput()
{
    if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
    {
        return;
    }

    var modelFile = @"twoinput\twoinput.onnx";
    using (var env = new ConsoleEnvironment(seed: 1, conc: 1))
    {
        // NOTE(review): the data view is built with `Env` (declared outside this method) while the
        // transform below uses the local `env` — confirm that mixing the two is intentional.
        var dataView = ComponentCreation.CreateDataView(Env,
            new TestDataMulti[]
            {
                new TestDataMulti()
                {
                    ina = new float[] { 1, 2, 3, 4, 5 },
                    inb = new float[] { 1, 2, 3, 4, 5 }
                }
            });

        var onnx = OnnxTransform.Create(env, dataView, modelFile, new[] { "ina", "inb" }, new[] { "outa", "outb" });

        // Fail loudly if an output column is missing; otherwise the index would silently stay at
        // its default value and the getters below would target the wrong column.
        Assert.True(onnx.Schema.TryGetColumnIndex("outa", out int scoresa), "Output column 'outa' was not found.");
        Assert.True(onnx.Schema.TryGetColumnIndex("outb", out int scoresb), "Output column 'outb' was not found.");

        // Only the two output columns need to be active in the cursor.
        using (var curs = onnx.GetRowCursor(col => col == scoresa || col == scoresb))
        {
            var getScoresa = curs.GetGetter<VBuffer<float>>(scoresa);
            var getScoresb = curs.GetGetter<VBuffer<float>>(scoresb);
            var buffera = default(VBuffer<float>);
            var bufferb = default(VBuffer<float>);

            while (curs.MoveNext())
            {
                getScoresa(ref buffera);
                getScoresb(ref bufferb);

                Assert.Equal(5, buffera.Length);
                Assert.Equal(5, bufferb.Length);
                Assert.Equal(0, buffera.GetValues().ToArray().Sum());
                Assert.Equal(30, bufferb.GetValues().ToArray().Sum());
            }
        }
    }
}
/// <summary>
/// Scores a single sample row ("data_0") through the SqueezeNet ONNX model and checks that the
/// "softmaxout_1" output vector has 1000 entries.
/// </summary>
/// <remarks>
/// The body returns without asserting anything when the OS is not Windows.
/// </remarks>
[ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // x86 output differs from Baseline
public void OnnxModelScenario()
{
    if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
    {
        return;
    }

    var modelFile = "squeezenet/00000001/model.onnx";
    using (var env = new ConsoleEnvironment(seed: 1, conc: 1))
    {
        var samplevector = GetSampleArrayData();

        // NOTE(review): the data view is built with `Env` (declared outside this method) while the
        // transform below uses the local `env` — confirm that mixing the two is intentional.
        var dataView = ComponentCreation.CreateDataView(Env,
            new TestData[]
            {
                new TestData()
                {
                    data_0 = samplevector
                }
            });

        var onnx = OnnxTransform.Create(env, dataView, modelFile, new[] { "data_0" }, new[] { "softmaxout_1" });

        // Fail loudly if the output column is missing; otherwise the index would silently stay at
        // its default value and the getter below would target the wrong column.
        Assert.True(onnx.Schema.TryGetColumnIndex("softmaxout_1", out int scores), "Output column 'softmaxout_1' was not found.");

        // Only the output column needs to be active in the cursor.
        using (var curs = onnx.GetRowCursor(col => col == scores))
        {
            var getScores = curs.GetGetter<VBuffer<float>>(scores);
            var buffer = default(VBuffer<float>);

            while (curs.MoveNext())
            {
                getScores(ref buffer);
                Assert.Equal(1000, buffer.Length);
            }
        }
    }
}
/// <summary>
/// Creates an <see cref="OnnxScoringEstimator"/> that wraps the given ONNX transformer.
/// </summary>
/// <param name="catalog">The transform's catalog; its environment is passed to the new estimator.</param>
/// <param name="transformer">The ONNX transformer.</param>
/// <returns>A new <see cref="OnnxScoringEstimator"/> built from the catalog's environment and <paramref name="transformer"/>.</returns>
public static OnnxScoringEstimator ApplyOnnxModel(this TransformsCatalog catalog, OnnxTransform transformer)
{
    // Resolve the environment from the catalog first, then hand both pieces to the estimator.
    var environment = CatalogUtils.GetEnvironment(catalog);
    return new OnnxScoringEstimator(environment, transformer);
}