Пример #1
0
        /// <summary>
        /// Builds an ONNX <see cref="ModelProto"/> from a list of ONNX-exportable transforms.
        /// </summary>
        /// <remarks>
        /// The graph is assembled in three passes:
        /// <list type="number">
        /// <item><description>Every column of <paramref name="inputData"/> whose name is not in
        /// <paramref name="inputColumnNamesToDrop"/> is registered as a graph input.</description></item>
        /// <item><description>Each transform in <paramref name="transforms"/> writes itself into
        /// <paramref name="ctx"/> via <c>SaveAsOnnx</c>, in list order.</description></item>
        /// <item><description>Every non-hidden column of <paramref name="outputData"/> whose name is in neither
        /// drop set is exposed as a graph output through an "Identity" node.</description></item>
        /// </list>
        /// If an output column has no corresponding variable in <paramref name="ctx"/>, the
        /// <c>ch.Check</c> call fails with a message explaining that the pipeline is not fully convertible.
        /// </remarks>
        /// <param name="ctx">ONNX context that receives inputs, nodes and outputs, and finally produces the model.</param>
        /// <param name="ch">Channel used for the assertion and the validation check.</param>
        /// <param name="inputData">Data view whose schema defines the candidate graph inputs.</param>
        /// <param name="outputData">Data view whose schema defines the candidate graph outputs.</param>
        /// <param name="transforms">Transforms to export, in the order they are to be written to the graph.</param>
        /// <param name="inputColumnNamesToDrop">Column names excluded from both graph inputs and graph outputs;
        /// <c>null</c> is treated as an empty set.</param>
        /// <param name="outputColumnNamesToDrop">Column names excluded from graph outputs;
        /// <c>null</c> is treated as an empty set.</param>
        /// <returns>The model returned by <c>ctx.MakeModel()</c>.</returns>
        internal static ModelProto ConvertTransformListToOnnxModel(OnnxContextImpl ctx, IChannel ch, IDataView inputData, IDataView outputData,
                                                                   LinkedList<ITransformCanSaveOnnx> transforms, HashSet<string> inputColumnNamesToDrop = null, HashSet<string> outputColumnNamesToDrop = null)
        {
            inputColumnNamesToDrop = inputColumnNamesToDrop ?? new HashSet<string>();
            outputColumnNamesToDrop = outputColumnNamesToDrop ?? new HashSet<string>();

            // Create graph inputs.
            for (int i = 0; i < inputData.Schema.Count; i++)
            {
                string colName = inputData.Schema[i].Name;
                if (inputColumnNamesToDrop.Contains(colName))
                {
                    continue;
                }

                ctx.AddInputVariable(inputData.Schema[i].Type, colName);
            }

            // Create graph nodes, outputs and intermediate values.
            foreach (var trans in transforms)
            {
                // Callers are expected to pass only transforms that can be exported with this context.
                ch.Assert(trans.CanSaveOnnx(ctx));
                trans.SaveAsOnnx(ctx);
            }

            // Add graph outputs.
            for (int i = 0; i < outputData.Schema.Count; ++i)
            {
                if (outputData.Schema[i].IsHidden)
                {
                    continue;
                }

                var idataviewColumnName = outputData.Schema[i].Name;

                // Since the last IDataView also contains columns of the initial IDataView, the last IDataView's
                // columns found in inputColumnNamesToDrop should be removed too.
                if (inputColumnNamesToDrop.Contains(idataviewColumnName) || outputColumnNamesToDrop.Contains(idataviewColumnName))
                {
                    continue;
                }

                var variableName = ctx.TryGetVariableName(idataviewColumnName);
                // A null variable name occurs when an unsupported transform produces an output and a downstream step
                // consumes that output, or when the user accidentally removes a transform whose output is used by
                // other transforms.
                ch.Check(variableName != null, "The targeted pipeline can not be fully converted into a well-defined ONNX model. " +
                         "Please check if all steps in that pipeline are convertible to ONNX " +
                         "and all necessary variables are not dropped (via command line arguments).");

                // Route the internal variable through an Identity node into a variable created from the column name,
                // and expose that variable as the graph output.
                // NOTE(review): the trailing 'true' presumably makes the context keep the column name as-is - confirm
                // against OnnxContextImpl.AddIntermediateVariable.
                var trueVariableName = ctx.AddIntermediateVariable(null, idataviewColumnName, true);
                ctx.CreateNode("Identity", variableName, trueVariableName, ctx.GetNodeName("Identity"), "");
                ctx.AddOutputVariable(outputData.Schema[i].Type, trueVariableName);
            }

            return ctx.MakeModel();
        }
Пример #2
0
        /// <summary>
        /// Records the slot names of <paramref name="column"/> in the ONNX graph as the attributes of a
        /// "LabelEncoder" node.
        /// </summary>
        /// <remarks>
        /// The node is named <c>mlnet.{column.Name}.SlotNames</c>. Its <c>keys_strings</c> attribute holds the
        /// slot names and its <c>values_int64s</c> attribute holds the indices <c>0 .. slotNames.Length - 1</c>.
        /// The node's input is a 1x1 string initializer containing "one", and its Int64 output
        /// (<c>mlnet.{column.Name}.unusedOutput</c>) is registered as a graph output.
        /// </remarks>
        /// <param name="ctx">ONNX context that receives the initializer, the node and the output variable.</param>
        /// <param name="column">Schema column whose slot names are read via <c>GetSlotNames</c>.</param>
        private static void AddSlotNames(OnnxContextImpl ctx, DataViewSchema.Column column)
        {
            VBuffer<ReadOnlyMemory<char>> slotNames = default;
            column.GetSlotNames(ref slotNames);
            IEnumerable<string> slotNamesAsStrings = slotNames.DenseValues().Select(name => name.ToString());

            string opType = "LabelEncoder";
            string labelEncoderOutputName = $"mlnet.{column.Name}.unusedOutput";
            string labelEncoderNodeName = $"mlnet.{column.Name}.SlotNames";

            // A single-element string tensor serves as the node's input; the initializer is registered
            // under the node's name.
            // NOTE(review): judging by the "unused" naming, the input/output values look like placeholders and
            // only the attributes carry information - confirm against the code that reads the model back.
            string[] oneVals = new string[] { "one" };
            long[] dims = new long[] { 1, 1 };
            var one = ctx.AddInitializer(oneVals, dims, labelEncoderNodeName);

            var labelEncoderOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, labelEncoderOutputName, true);
            var node = ctx.CreateNode(opType, one, labelEncoderOutput, labelEncoderNodeName);

            // Map each slot name to its position in the vector.
            node.AddAttribute("keys_strings", slotNamesAsStrings);
            node.AddAttribute("values_int64s", Enumerable.Range(0, slotNames.Length).Select(x => (long)x));

            ctx.AddOutputVariable(NumberDataViewType.Int64, labelEncoderOutput);
        }