/// <summary>
/// Walks the pipeline backwards from <paramref name="end"/> and collects the trailing run of
/// transforms that can be saved as PFA.
/// </summary>
/// <param name="ch">channel used to warn about where the walk had to stop</param>
/// <param name="end">last view of the pipeline; a <c>CompositeDataLoader</c> is unwrapped to its <c>View</c></param>
/// <param name="source">the first view that is not an <c>IDataTransform</c>, or the first transform
/// (walking backwards) that cannot be saved as PFA</param>
/// <param name="trueEnd">the last view of the pipeline after unwrapping a composite loader</param>
/// <param name="transforms">the PFA-savable transforms between <paramref name="source"/> and
/// <paramref name="trueEnd"/>, in pipeline order (earliest first)</param>
private void GetPipe(IChannel ch, IDataView end, out IDataView source, out IDataView trueEnd, out LinkedList <ITransformCanSavePfa> transforms)
        {
            Host.AssertValue(end);

            // A composite loader wraps the real last view; fall back to end itself otherwise.
            var composite = end as CompositeDataLoader;
            trueEnd = composite?.View ?? end;
            source = trueEnd;
            transforms = new LinkedList <ITransformCanSavePfa>();

            var current = source as IDataTransform;
            while (current != null)
            {
                var pfaTransform = current as ITransformCanSavePfa;
                bool savable = pfaTransform != null && pfaTransform.CanSavePfa;
                if (!savable)
                {
                    // The non-savable transform stays as source, i.e. becomes the input of the PFA part.
                    ch.Warning("Had to stop walkback of pipeline at {0} since it cannot save itself as PFA", current.GetType().Name);
                    return;
                }

                // Walking backwards, so prepend to keep pipeline order.
                transforms.AddFirst(pfaTransform);
                source = current.Source;
                current = source as IDataTransform;
            }

            Host.AssertValue(source);
        }
示例#2
0
        /// <summary>
        /// Finalizes the test on a transform.
        /// The method:
        /// <list type="number">
        /// <item>saves the model and reloads it to check that the reloaded transform has the same type;</item>
        /// <item>saves the data produced by <paramref name="transform"/>;</item>
        /// <item>reloads the model again and saves the data produced by the reloaded transform;</item>
        /// <item>checks that both data files contain the same output.</item>
        /// </list>
        /// </summary>
        /// <param name="env">environment</param>
        /// <param name="outModelFilePath">model filename</param>
        /// <param name="transform">transform to test</param>
        /// <param name="source">source (view before applying the transform)</param>
        /// <param name="outData">first data file (output of <paramref name="transform"/>)</param>
        /// <param name="outData2">second data file (output of the transform reloaded from the model)</param>
        /// <param name="startsWith">if true, each line of <paramref name="outData"/> only needs to start with the
        /// corresponding line of <paramref name="outData2"/>; otherwise the lines must be equal</param>
        /// <param name="skipDoubleQuote">if true, lines where either side contains <c>""\t""</c> are not compared</param>
        /// <param name="forceDense">if true, data is saved with the <c>Text{dense=+}</c> saver instead of <c>Text</c></param>
        /// <exception cref="FileNotFoundException">the model file or the first data file was not created</exception>
        /// <exception cref="Exception">the model cannot be reloaded, reloads as a different type, or the outputs differ</exception>
        public static void SerializationTestTransform(IHostEnvironment env,
                                                      string outModelFilePath, IDataTransform transform,
                                                      IDataView source, string outData, string outData2,
                                                      bool startsWith = false, bool skipDoubleQuote = false,
                                                      bool forceDense = false)
        {
            // Saves model.
            var roles = env.CreateExamples(transform, null);

            using (var ch = env.Start("SaveModel"))
            using (var fs = File.Create(outModelFilePath))
            {
                TrainUtils.SaveModel(env, ch, fs, null, roles);
            }
            if (!File.Exists(outModelFilePath))
            {
                throw new FileNotFoundException(outModelFilePath);
            }

            // We load it again and check the reloaded transform has the same concrete type.
            using (var fs = File.OpenRead(outModelFilePath))
            {
                var tr2 = env.LoadTransforms(fs, source);
                if (tr2 == null)
                {
                    throw new Exception(string.Format("Unable to load '{0}'", outModelFilePath));
                }
                if (transform.GetType() != tr2.GetType())
                {
                    throw new Exception(string.Format("Type mismatch {0} != {1}", transform.GetType(), tr2.GetType()));
                }
            }

            // Saves the output of the original transform (all columns).
            var saver   = env.CreateSaver(forceDense ? "Text{dense=+}" : "Text");
            var columns = new int[transform.Schema.Count];
            for (int i = 0; i < columns.Length; ++i)
            {
                columns[i] = i;
            }
            using (var fs2 = File.Create(outData))
            {
                saver.SaveData(fs2, transform, columns);
            }

            // Check the data file that was just written (the model file was already checked above).
            if (!File.Exists(outData))
            {
                throw new FileNotFoundException(outData);
            }

            // Saves the output of a transform reloaded from the model.
            using (var fs = File.OpenRead(outModelFilePath))
            {
                var tr = env.LoadTransforms(fs, source);
                saver = env.CreateSaver(forceDense ? "Text{dense=+}" : "Text");
                using (var fs2 = File.Create(outData2))
                {
                    saver.SaveData(fs2, tr, columns);
                }
            }

            var t1 = File.ReadAllLines(outData);
            var t2 = File.ReadAllLines(outData2);

            if (t1.Length != t2.Length)
            {
                throw new Exception(string.Format("Not the same number of lines: {0} != {1}", t1.Length, t2.Length));
            }

            // A line of the first file ending with this suffix is compared a second time after
            // replacing the suffix with five empty-string columns.
            const string sparseEmptySuffix = "\t5\t0:\"\"";
            const string denseEmptyColumns = "\t\"\"\t\"\"\t\"\"\t\"\"\t\"\"";

            for (int i = 0; i < t1.Length; ++i)
            {
                if (skipDoubleQuote && (t1[i].Contains("\"\"\t\"\"") || t2[i].Contains("\"\"\t\"\"")))
                {
                    continue;
                }

                // Ordinal comparisons: these are machine-written files, not linguistic text.
                bool same = startsWith
                                ? t1[i].StartsWith(t2[i], StringComparison.Ordinal)
                                : t1[i] == t2[i];
                if (same)
                {
                    continue;
                }

                if (!t1[i].EndsWith(sparseEmptySuffix, StringComparison.Ordinal))
                {
                    // The test might fail because one side is dense and the other is sparse.
                    throw new Exception(string.Format("3-Mismatch on line {0}/{3}:\n{1}\n{2}", i, t1[i], t2[i], t1.Length));
                }

                var dense = t1[i].Substring(0, t1[i].Length - sparseEmptySuffix.Length) + denseEmptyColumns;
                bool sameDense = startsWith
                                     ? dense.StartsWith(t2[i], StringComparison.Ordinal)
                                     : dense == t2[i];
                if (!sameDense)
                {
                    throw new Exception(string.Format("2-Mismatch on line {0}/{3}:\n{1}\n{2}", i, t1[i], t2[i], t1.Length));
                }
            }
        }
示例#3
0
        /// <summary>
        /// Walks the pipeline backwards from <paramref name="end"/> and collects the trailing run of
        /// transforms that can be saved as ONNX in <paramref name="ctx"/>.
        /// </summary>
        /// <param name="ctx">ONNX context passed to each transform's <c>CanSaveOnnx</c></param>
        /// <param name="ch">channel used for assertions and to warn where the walk had to stop</param>
        /// <param name="end">last view of the pipeline; a <c>LegacyCompositeDataLoader</c> is unwrapped to its <c>View</c></param>
        /// <param name="source">the first view that is not an <c>IDataTransform</c>. When the walk stops early,
        /// it is rewound all the way to the first non-transform view.
        /// NOTE(review): in that case the transforms between this root and the stopping point are not in
        /// <paramref name="transforms"/>.</param>
        /// <param name="trueEnd">the last view of the pipeline after unwrapping a composite loader</param>
        /// <param name="transforms">the ONNX-savable transforms collected, in pipeline order (earliest first)</param>
        internal static void GetPipe(OnnxContextImpl ctx, IChannel ch, IDataView end, out IDataView source, out IDataView trueEnd, out LinkedList <ITransformCanSaveOnnx> transforms)
        {
            ch.AssertValue(end);

            var loader = end as LegacyCompositeDataLoader;
            trueEnd = loader?.View ?? end;
            source = trueEnd;

            var chain = new LinkedList <ITransformCanSaveOnnx>();
            transforms = chain;

            for (var current = source as IDataTransform; current != null; current = source as IDataTransform)
            {
                var onnxTransform = current as ITransformCanSaveOnnx;
                bool savable = onnxTransform != null && onnxTransform.CanSaveOnnx(ctx);
                if (!savable)
                {
                    ch.Warning("Had to stop walkback of pipeline at {0} since it cannot save itself as ONNX.", current.GetType().Name);

                    // Rewind source to the first view that is not a transform.
                    for (var step = source as IDataTransform; step != null; step = source as IDataTransform)
                    {
                        source = step.Source;
                    }

                    return;
                }

                // Walking backwards, so prepend to keep pipeline order.
                chain.AddFirst(onnxTransform);
                source = current.Source;
            }

            ch.AssertValue(source);
        }
示例#4
0
        /// <summary>
        /// Converts the pipeline ending at <paramref name="trans"/> into an ONNX graph held by a
        /// <c>ScikitOnnxContext</c>.
        /// </summary>
        /// <param name="trans">last transform of the pipeline to export</param>
        /// <param name="inputs">graph input column names; passed to <c>GuessInputs</c>, which may fill or update them</param>
        /// <param name="outputs">graph output column names; passed to <c>GuessOutputs</c>, which may fill or update them</param>
        /// <param name="name">graph name; defaults to the type name of <paramref name="trans"/></param>
        /// <param name="producer">producer name stored in the context</param>
        /// <param name="version">model version stored in the context</param>
        /// <param name="domain">model domain stored in the context</param>
        /// <param name="onnxVersion">ONNX version passed to the context</param>
        /// <param name="begin">views passed to the input-guessing and view-enumeration helpers
        /// (presumably where the exported graph starts — TODO confirm)</param>
        /// <param name="host">environment; when null a temporary <c>DelegateEnvironment</c> is created</param>
        /// <returns>the context containing inputs, transform nodes and identity-mapped outputs</returns>
        /// <exception cref="System.ArgumentNullException"><paramref name="trans"/> is null</exception>
        public static ScikitOnnxContext ToOnnx(IDataTransform trans, ref string[] inputs, ref string[] outputs,
                                               string name             = null, string producer = "Scikit.ML",
                                               long version            = 0, string domain      = "onnx.ai.ml",
                                               OnnxVersion onnxVersion = OnnxVersion.Stable,
                                               IDataView[] begin       = null, IHostEnvironment host = null)
        {
            if (trans == null)
            {
                throw new System.ArgumentNullException(nameof(trans));
            }
            if (host == null)
            {
                // NOTE(review): this environment is disposed as soon as the call returns, yet the
                // returned context was built with it as its host. Confirm the context does not use
                // its host afterwards, or pass an explicit host.
                using (var env = new DelegateEnvironment())
                {
                    return ToOnnx(trans, ref inputs, ref outputs, name, producer, version, domain, onnxVersion, begin, env);
                }
            }
            if (name == null)
            {
                name = trans.GetType().Name;
            }

            GuessOutputs(trans, ref outputs);
            GuessInputs(trans, ref inputs, begin);

            if (inputs == null || inputs.Length == 0)
            {
                throw host.Except("Inputs cannot be empty.");
            }
            if (outputs == null || outputs.Length == 0)
            {
                throw host.Except("Outputs cannot be empty.");
            }

            var assembly    = System.Reflection.Assembly.GetExecutingAssembly();
            var versionInfo = System.Diagnostics.FileVersionInfo.GetVersionInfo(assembly.Location);
            var ctx         = new ScikitOnnxContext(host, name, producer, versionInfo.FileVersion,
                                                    version, domain, onnxVersion);

            var uniqueVars   = EnumerateVariables(trans, begin).ToArray();
            var mapInputType = new Dictionary <string, ColumnType>();
            foreach (var it in uniqueVars.Where(c => c.position == TagHelper.GraphPositionEnum.first))
            {
                mapInputType[it.variableName] = it.variableType;
            }

            foreach (var col in inputs)
            {
                ColumnType inputType;
                if (!mapInputType.TryGetValue(col, out inputType))
                {
                    throw host.Except($"Input column '{col}' was not found among the variables at the beginning of the pipeline.");
                }
                ctx.AddInputVariable(inputType, col);
            }

            var views      = TagHelper.EnumerateAllViews(trans, begin);
            var transforms = views.Where(c => (c.Item1 as IDataTransform) != null)
                             .Select(c => c.Item1)
                             .ToArray();

            // Index loop instead of array.Reverse(): with C# 14 first-class spans, Reverse() on an
            // array binds to MemoryExtensions.Reverse (in place, returns void).
            for (int i = transforms.Length - 1; i >= 0; --i)
            {
                var tr   = transforms[i];
                var tron = tr as ICanSaveOnnx;
                if (tron == null)
                {
                    throw host.ExceptNotSupp($"Transform {tr.GetType()} cannot be saved in Onnx format.");
                }
                if (!tron.CanSaveOnnx(ctx))
                {
                    throw host.ExceptNotSupp($"Transform {tr.GetType()} cannot be saved in ONNX format.");
                }

                // The null check was missing: a transform that only implements ICanSaveOnnx used to
                // fail with a NullReferenceException instead of this message.
                var onnxSaver = tron as ISaveAsOnnx;
                if (onnxSaver == null || !onnxSaver.CanSaveOnnx(ctx))
                {
                    throw host.ExceptNotSupp($"Transform {tr.GetType()} does not implement SaveAsOnnx.");
                }
                onnxSaver.SaveAsOnnx(ctx);
            }

            var mapOutputType = new Dictionary <string, ColumnType>();
            foreach (var it in uniqueVars.Where(c => c.position == TagHelper.GraphPositionEnum.last))
            {
                mapOutputType[it.variableName] = it.variableType;
            }

            foreach (var col in outputs)
            {
                ColumnType outputType;
                if (!mapOutputType.TryGetValue(col, out outputType))
                {
                    throw host.Except($"Output column '{col}' was not found among the variables at the end of the pipeline.");
                }

                // Expose each output under its own name through an Identity node.
                var variableName     = ctx.TryGetVariableName(col);
                var trueVariableName = ctx.AddIntermediateVariable(null, col, true);
                ctx.CreateNode("Identity", variableName, trueVariableName, ctx.GetNodeName("Identity"), "");
                ctx.AddOutputVariable(outputType, trueVariableName);
            }

            return ctx;
        }