/// <summary>
/// Splits an entry point's dotted name into its two components, e.g. "Namespace.ClassName".
/// </summary>
/// <param name="entryPointInfo">Entry point whose <c>Name</c> is split on '.'.</param>
/// <returns>Metadata built from the segment before the dot and the segment after it.</returns>
/// <remarks>Fails a contract check unless the name contains exactly one '.'.</remarks>
public static EntryPointGenerationMetadata GetEntryPointMetadata(ModuleCatalog.EntryPointInfo entryPointInfo)
        {
            var split = entryPointInfo.Name.Split('.');

            // Include the offending name: a bare check gives no clue which entry point is malformed.
            Contracts.Check(split.Length == 2, $"Entry point name '{entryPointInfo.Name}' must contain exactly one '.'.");
            return new EntryPointGenerationMetadata(split[0], split[1]);
        }
Exemple #2
0
        /// <summary>
        /// Opens a namespace named after the entry point, generates the entry point's input class
        /// inside it, then closes the namespace and writes a trailing blank line.
        /// </summary>
        private void GenerateInputOutput(IndentingTextWriter writer, ModuleCatalog.EntryPointInfo entryPointInfo, ModuleCatalog catalog)
        {
            var metadata = CSharpGeneratorUtils.GetEntryPointMetadata(entryPointInfo);
            var namespaceName = metadata.Namespace;

            // Namespace header and opening brace; the body is indented one level.
            writer.WriteLine($"namespace {namespaceName}");
            writer.WriteLine("{");
            writer.Indent();

            GenerateInput(writer, entryPointInfo, catalog);

            // Close the namespace and leave a blank line after it.
            writer.Outdent();
            writer.WriteLine("}");
            writer.WriteLine();
        }
Exemple #3
0
        /// <summary>
        /// Emits the input class for an entry point. It writes the input type's enums and nested
        /// classes, then the class declaration with any input-kind base types, then the loader and
        /// column helper methods. Inside the class it also writes the input fields, the nested
        /// <c>Output</c> class and the apply function.
        /// </summary>
        private void GenerateInput(IndentingTextWriter writer, ModuleCatalog.EntryPointInfo entryPointInfo, ModuleCatalog catalog)
        {
            var entryPointMetadata = CSharpGeneratorUtils.GetEntryPointMetadata(entryPointInfo);
            // Namespace passed to the enum, class and field generators (previously rebuilt three times).
            string fullNamespace = _defaultNamespace + entryPointMetadata.Namespace;
            string classBase = "";

            if (entryPointInfo.InputKinds != null)
            {
                classBase += $" : {string.Join(", ", entryPointInfo.InputKinds.Select(CSharpGeneratorUtils.GetCSharpTypeName))}";
                // Trainer and transform inputs additionally implement ILearningPipelineItem.
                if (entryPointInfo.InputKinds.Any(t => typeof(ITrainerInput).IsAssignableFrom(t) || typeof(ITransformInput).IsAssignableFrom(t)))
                {
                    classBase += ", Microsoft.ML.ILearningPipelineItem";
                }
            }

            GenerateEnums(writer, entryPointInfo.InputType, fullNamespace);
            writer.WriteLine();
            GenerateClasses(writer, entryPointInfo.InputType, catalog, fullNamespace);
            CSharpGeneratorUtils.GenerateSummary(writer, entryPointInfo.Description, entryPointInfo.XmlInclude);

            if (entryPointInfo.ObsoleteAttribute != null)
            {
                // The message lands inside a generated string literal. An unescaped quote,
                // backslash or line break would make the generated code fail to compile.
                writer.WriteLine($"[Obsolete(\"{EscapeStringLiteral(entryPointInfo.ObsoleteAttribute.Message)}\")]");
            }

            writer.WriteLine($"public sealed partial class {entryPointMetadata.ClassName}{classBase}");
            writer.WriteLine("{");
            writer.Indent();
            writer.WriteLine();
            if (entryPointInfo.InputKinds != null && entryPointInfo.InputKinds.Any(t => typeof(ILearningPipelineLoader).IsAssignableFrom(t)))
            {
                CSharpGeneratorUtils.GenerateLoaderAddInputMethod(writer, entryPointMetadata.ClassName);
            }

            GenerateColumnAddMethods(writer, entryPointInfo.InputType, catalog, entryPointMetadata.ClassName, out Type transformType);
            writer.WriteLine();
            GenerateInputFields(writer, entryPointInfo.InputType, catalog, fullNamespace);
            writer.WriteLine();

            GenerateOutput(writer, entryPointInfo, out HashSet<string> outputVariableNames);
            GenerateApplyFunction(writer, entryPointMetadata.ClassName, transformType, outputVariableNames, entryPointInfo.InputKinds);
            writer.Outdent();
            writer.WriteLine("}");

            // Escapes text for embedding in a regular (non-verbatim) C# string literal.
            // A null message yields an empty string, matching the previous interpolation behavior.
            string EscapeStringLiteral(string text)
            {
                if (string.IsNullOrEmpty(text))
                {
                    return "";
                }
                // Backslashes first, so the escapes added below are not doubled.
                return text
                    .Replace("\\", "\\\\")
                    .Replace("\"", "\\\"")
                    .Replace("\r", "\\r")
                    .Replace("\n", "\\n")
                    .Replace("\t", "\\t");
            }
        }
Exemple #4
0
        /// <summary>
        /// Writes the nested <c>Output</c> class for an entry point. Each field of the output type
        /// that carries a <c>TlcModule.OutputAttribute</c> becomes an initialized auto-property.
        /// </summary>
        /// <param name="outputVariableNames">Receives the capitalized names of the generated properties.</param>
        private void GenerateOutput(IndentingTextWriter writer, ModuleCatalog.EntryPointInfo entryPointInfo, out HashSet <string> outputVariableNames)
        {
            var names = new HashSet<string>();
            outputVariableNames = names;

            var kinds = entryPointInfo.OutputKinds;
            string classBase = kinds == null
                ? ""
                : $" : {string.Join(", ", kinds.Select(CSharpGeneratorUtils.GetCSharpTypeName))}";

            writer.WriteLine($"public sealed class Output{classBase}");
            writer.WriteLine("{");
            writer.Indent();

            // MacroOutput<T> wraps the type whose fields actually describe the outputs; unwrap it.
            var fieldsSource = entryPointInfo.OutputType;
            bool isMacroOutput = fieldsSource.IsGenericType
                && fieldsSource.GetGenericTypeDefinition() == typeof(CommonOutputs.MacroOutput <>);
            if (isMacroOutput)
            {
                fieldsSource = fieldsSource.GetGenericTypeArgumentsEx()[0];
            }

            foreach (var field in fieldsSource.GetFields())
            {
                // Only fields marked as outputs produce properties.
                if (!(field.GetCustomAttributes(typeof(TlcModule.OutputAttribute), false).FirstOrDefault() is TlcModule.OutputAttribute attr))
                {
                    continue;
                }

                CSharpGeneratorUtils.GenerateSummary(writer, attr.Desc);
                string propertyType = CSharpGeneratorUtils.GetOutputType(field.FieldType);
                string propertyName = CSharpGeneratorUtils.Capitalize(attr.Name ?? field.Name);
                names.Add(propertyName);
                writer.WriteLine($"public {propertyType} {propertyName} {{ get; set; }} = new {propertyType}();");
                writer.WriteLine();
            }

            writer.Outdent();
            writer.WriteLine("}");
        }
Exemple #5
0
        /// <summary>
        /// Builds the JSON manifest for an entry point. It records the name, description, friendly
        /// and short names, the input and output manifests, and the input/output kind names when present.
        /// </summary>
        /// <param name="ectx">Exception context used for argument checks; may be null.</param>
        /// <param name="entryPointInfo">Entry point to describe; must not be null.</param>
        /// <param name="catalog">Module catalog passed to the input/output manifest builders; must not be null.</param>
        /// <returns>The populated manifest object.</returns>
        public static JObject BuildEntryPointManifest(IExceptionContext ectx, ModuleCatalog.EntryPointInfo entryPointInfo, ModuleCatalog catalog)
        {
            Contracts.CheckValueOrNull(ectx);
            ectx.CheckValue(entryPointInfo, nameof(entryPointInfo));
            ectx.CheckValue(catalog, nameof(catalog));

            var manifest = new JObject
            {
                [FieldNames.Name] = entryPointInfo.Name,
                [FieldNames.Desc] = entryPointInfo.Description,
                [FieldNames.FriendlyName] = entryPointInfo.FriendlyName,
                [FieldNames.ShortName] = entryPointInfo.ShortName,
            };

            // There are supposed to be 2 parameters, env and input.
            manifest[FieldNames.Inputs] = BuildInputManifest(ectx, entryPointInfo.InputType, catalog);
            manifest[FieldNames.Outputs] = BuildOutputManifest(ectx, entryPointInfo.OutputType, catalog);

            // Kind lists are emitted only when the entry point declares them.
            var inputKinds = entryPointInfo.InputKinds;
            if (inputKinds != null)
            {
                var inputKindNames = new JArray();
                foreach (var t in inputKinds)
                {
                    inputKindNames.Add(t.Name);
                }
                manifest[FieldNames.InputKind] = inputKindNames;
            }

            var outputKinds = entryPointInfo.OutputKinds;
            if (outputKinds != null)
            {
                var outputKindNames = new JArray();
                foreach (var t in outputKinds)
                {
                    outputKindNames.Add(t.Name);
                }
                manifest[FieldNames.OutputKind] = outputKindNames;
            }

            return manifest;
        }