コード例 #1
0
        /// <summary>
        /// Emits the C# source for an entry point's input class. The emitted source contains:
        /// the enums and classes nested in the input type, the summary, an <c>[Obsolete]</c> attribute,
        /// and the <c>public sealed partial</c> class itself. That class holds the optional loader
        /// helper, the column-add helpers, the input properties, the nested <c>Output</c> class and
        /// the apply function.
        /// </summary>
        /// <param name="writer">Writer that receives the generated source.</param>
        /// <param name="entryPointInfo">Entry point whose input type is being generated.</param>
        /// <param name="catalog">Component catalog passed through to the nested generators.</param>
        private void GenerateInput(IndentedTextWriter writer, ComponentCatalog.EntryPointInfo entryPointInfo, ComponentCatalog catalog)
        {
            var entryPointMetadata = CSharpGeneratorUtils.GetEntryPointMetadata(entryPointInfo);
            var inputKinds = entryPointInfo.InputKinds;

            // Namespace prefix shared by every nested type generated for this entry point.
            string rootNamespace = _defaultNamespace + entryPointMetadata.Namespace;

            // Only emit a base list when there is at least one kind. An empty (non-null) kinds
            // collection would otherwise produce "class X : " with nothing after the colon,
            // which is not valid C#.
            string classBase = "";
            if (inputKinds != null && inputKinds.Any())
            {
                classBase = $" : {string.Join(", ", inputKinds.Select(CSharpGeneratorUtils.GetCSharpTypeName))}";
                if (inputKinds.Any(t => typeof(ITrainerInput).IsAssignableFrom(t) || typeof(ITransformInput).IsAssignableFrom(t)))
                {
                    classBase += ", Microsoft.ML.Legacy.ILearningPipelineItem";
                }
            }

            GenerateEnums(writer, entryPointInfo.InputType, rootNamespace);
            writer.WriteLineNoTabs();
            GenerateClasses(writer, entryPointInfo.InputType, catalog, rootNamespace);
            CSharpGeneratorUtils.GenerateSummary(writer, entryPointInfo.Description, entryPointInfo.XmlInclude);

            if (entryPointInfo.ObsoleteAttribute != null)
            {
                // The message is embedded in a regular C# string literal in the generated code.
                // Escape backslashes, quotes and line breaks so an arbitrary message cannot break it.
                // A null message still yields an empty literal, as before.
                string message = (entryPointInfo.ObsoleteAttribute.Message ?? "")
                    .Replace("\\", "\\\\")
                    .Replace("\"", "\\\"")
                    .Replace("\r", "\\r")
                    .Replace("\n", "\\n");
                writer.WriteLine($"[Obsolete(\"{message}\")]");
            }
            else
            {
                writer.WriteLine("[Obsolete]");
            }

            writer.WriteLine($"public sealed partial class {entryPointMetadata.ClassName}{classBase}");
            writer.WriteLine("{");
            writer.Indent();
            writer.WriteLineNoTabs();
            if (inputKinds != null && inputKinds.Any(t => typeof(Legacy.ILearningPipelineLoader).IsAssignableFrom(t)))
            {
                CSharpGeneratorUtils.GenerateLoaderAddInputMethod(writer, entryPointMetadata.ClassName);
            }

            GenerateColumnAddMethods(writer, entryPointInfo.InputType, catalog, entryPointMetadata.ClassName, out Type transformType);
            writer.WriteLineNoTabs();
            GenerateInputFields(writer, entryPointInfo.InputType, catalog, rootNamespace);
            writer.WriteLineNoTabs();

            GenerateOutput(writer, entryPointInfo, out HashSet<string> outputVariableNames);
            GenerateApplyFunction(writer, entryPointMetadata.ClassName, transformType, outputVariableNames, inputKinds);
            writer.Outdent();
            writer.WriteLine("}");
        }
コード例 #2
0
 /// <summary>
 /// Emits the C# source for a catalog component. The output is:
 /// the enums and classes nested in its argument type, then an <c>[Obsolete]</c>
 /// <c>public sealed class</c> deriving from the component's kind. That class contains one
 /// property per argument field and an overridden <c>ComponentName</c>.
 /// </summary>
 /// <param name="writer">Writer that receives the generated source.</param>
 /// <param name="component">Component whose wrapper class is being generated.</param>
 /// <param name="catalog">Component catalog passed through to the nested generators.</param>
 private void GenerateComponent(IndentedTextWriter writer, ComponentCatalog.ComponentInfo component, ComponentCatalog catalog)
 {
     // Components are generated with an empty namespace prefix for their nested types.
     const string noNamespacePrefix = "";
     var argumentType = component.ArgumentType;

     // Supporting types first, each group followed by a blank line.
     GenerateEnums(writer, argumentType, noNamespacePrefix);
     writer.WriteLineNoTabs();
     GenerateClasses(writer, argumentType, catalog, noNamespacePrefix);
     writer.WriteLineNoTabs();

     // Class header.
     CSharpGeneratorUtils.GenerateSummary(writer, component.Description);
     writer.WriteLine("[Obsolete]");
     string className = CSharpGeneratorUtils.GetComponentName(component);
     writer.WriteLine($"public sealed class {className} : {component.Kind}");

     // Class body.
     writer.WriteLine("{");
     writer.Indent();
     GenerateInputFields(writer, argumentType, catalog, noNamespacePrefix);
     writer.WriteLine("[Obsolete]");
     writer.WriteLine($"internal override string ComponentName => \"{component.Name}\";");
     writer.Outdent();
     writer.WriteLine("}");

     writer.WriteLineNoTabs();
 }
コード例 #3
0
        /// <summary>
        /// Emits the nested <c>[Obsolete] public sealed class Output</c> for an entry point.
        /// The class gets one auto-property, initialized with <c>new T()</c>, per field of the
        /// output type that carries a <see cref="TlcModule.OutputAttribute"/>.
        /// </summary>
        /// <param name="writer">Writer that receives the generated source.</param>
        /// <param name="entryPointInfo">Entry point whose output type is being generated.</param>
        /// <param name="outputVariableNames">
        /// Receives the capitalized property names of every emitted output.
        /// </param>
        private void GenerateOutput(IndentedTextWriter writer, ComponentCatalog.EntryPointInfo entryPointInfo, out HashSet<string> outputVariableNames)
        {
            outputVariableNames = new HashSet<string>();

            // Only emit a base list when there is at least one kind. An empty (non-null) kinds
            // collection would otherwise produce "class Output : ", which is not valid C#.
            var outputKinds = entryPointInfo.OutputKinds;
            string classBase = outputKinds != null && outputKinds.Any()
                ? $" : {string.Join(", ", outputKinds.Select(CSharpGeneratorUtils.GetCSharpTypeName))}"
                : "";

            writer.WriteLine("[Obsolete]");
            writer.WriteLine($"public sealed class Output{classBase}");
            writer.WriteLine("{");
            writer.Indent();

            // When the output type is CommonOutputs.MacroOutput<T>, emit the fields of T instead.
            var outputType = entryPointInfo.OutputType;
            if (outputType.IsGenericType && outputType.GetGenericTypeDefinition() == typeof(CommonOutputs.MacroOutput<>))
            {
                outputType = outputType.GetGenericTypeArgumentsEx()[0];
            }

            foreach (var fieldInfo in outputType.GetFields())
            {
                // Fields without an [Output] attribute are not part of the generated API.
                if (!(fieldInfo.GetCustomAttributes(typeof(TlcModule.OutputAttribute), false).FirstOrDefault()
                      is TlcModule.OutputAttribute outputAttr))
                {
                    continue;
                }

                CSharpGeneratorUtils.GenerateSummary(writer, outputAttr.Desc);
                var outputTypeString = CSharpGeneratorUtils.GetOutputType(fieldInfo.FieldType);
                string propertyName = CSharpGeneratorUtils.Capitalize(outputAttr.Name ?? fieldInfo.Name);
                outputVariableNames.Add(propertyName);
                writer.WriteLine($"public {outputTypeString} {propertyName} {{ get; set; }} = new {outputTypeString}();");
                writer.WriteLineNoTabs();
            }

            writer.Outdent();
            writer.WriteLine("}");
        }
コード例 #4
0
        /// <summary>
        /// Emits one public auto-property per field of <paramref name="inputType"/> that carries an
        /// <see cref="ArgumentAttribute"/>. Fields are skipped when their visibility is
        /// <c>CmdLineOnly</c> or their type is <c>JObject</c>.
        /// <c>JArray</c> fields become <c>Experiment</c> properties. All other properties may be
        /// preceded by JSON, range and sweepable attributes, and get a default-value initializer
        /// when one is available.
        /// </summary>
        /// <param name="writer">Writer that receives the generated source.</param>
        /// <param name="inputType">Type whose fields are turned into properties; instantiated once to read defaults.</param>
        /// <param name="catalog">Component catalog used to resolve component-typed fields.</param>
        /// <param name="rootNameSpace">Namespace prefix used when naming generated nested types.</param>
        private void GenerateInputFields(IndentedTextWriter writer, Type inputType, ComponentCatalog catalog, string rootNameSpace)
        {
            // A default-constructed instance supplies the initial value of every field.
            var defaultInstance = Activator.CreateInstance(inputType);

            foreach (var field in inputType.GetFields())
            {
                if (!(field.GetCustomAttributes(typeof(ArgumentAttribute), false).FirstOrDefault() is ArgumentAttribute argAttr)
                    || argAttr.Visibility == ArgumentAttribute.VisibilityType.CmdLineOnly
                    || field.FieldType == typeof(JObject))
                {
                    continue;
                }

                CSharpGeneratorUtils.GenerateSummary(writer, argAttr.HelpText);

                // Serialized name: the explicit argument name if present, else the field name.
                string serializedName = argAttr.Name ?? field.Name;

                if (field.FieldType == typeof(JArray))
                {
                    writer.WriteLine("[Obsolete]");
                    writer.WriteLine($"public Experiment {CSharpGeneratorUtils.Capitalize(serializedName)} {{ get; set; }}");
                    writer.WriteLineNoTabs();
                    continue;
                }

                var typeName = CSharpGeneratorUtils.GetInputType(catalog, field.FieldType, _generatedClasses, rootNameSpace);
                if (CSharpGeneratorUtils.IsComponent(field.FieldType))
                {
                    writer.WriteLine("[JsonConverter(typeof(ComponentSerializer))]");
                }

                // The property is PascalCased. Keep the original spelling on the wire when
                // capitalization changed it.
                string propertyName = CSharpGeneratorUtils.Capitalize(serializedName);
                if (propertyName != serializedName)
                {
                    writer.WriteLine($"[JsonProperty(\"{serializedName}\")]");
                }

                var rangeAttr = field.GetCustomAttributes(typeof(TlcModule.RangeAttribute), false).FirstOrDefault() as TlcModule.RangeAttribute;
                if (rangeAttr != null)
                {
                    writer.WriteLine(rangeAttr.ToString());
                }

                var sweepAttr = field.GetCustomAttributes(typeof(TlcModule.SweepableParamAttribute), false).FirstOrDefault() as TlcModule.SweepableParamAttribute;
                if (sweepAttr != null)
                {
                    // An unnamed sweepable parameter takes the C# field name.
                    if (string.IsNullOrEmpty(sweepAttr.Name))
                    {
                        sweepAttr.Name = field.Name;
                    }
                    writer.WriteLine(sweepAttr.ToString());
                }

                writer.WriteLine("[Obsolete]");
                var declaration = $"public {typeName} {propertyName} {{ get; set; }}";
                var initialValue = CSharpGeneratorUtils.GetValue(catalog, field.FieldType, field.GetValue(defaultInstance), _generatedClasses, rootNameSpace);
                writer.WriteLine(initialValue != null ? $"{declaration} = {initialValue};" : declaration);
                writer.WriteLineNoTabs();
            }
        }