Example #1
0
        /// <summary>
        /// MSBuild entry point. Parses every item in <c>InputFiles</c>, compiles the trees together against
        /// the assemblies that define <see cref="object"/> and <c>CommandBehavior</c>, collects the methods
        /// marked with <c>[GenerateAsync]</c>, then writes one rewritten file per input (to
        /// <c>SourceFile.TransformedName</c>) and exposes the written paths through <c>OutputFiles</c>.
        /// </summary>
        /// <returns>
        /// <c>true</c> on success; <c>false</c> when a marked method is not declared directly inside a class,
        /// or that class is not inside a namespace declaration. An error is logged for every such method,
        /// and no output file is written in that case.
        /// </returns>
        public override bool Execute()
        {
            var files = InputFiles.Select(f => new SourceFile {
                Name = f.ItemSpec,
                SyntaxTree = SyntaxFactory.ParseSyntaxTree(File.ReadAllText(f.ItemSpec))
            }).ToList();

            // Creation of the assembly reference should be done this way:
            // var mscorlib = new MetadataFileReference(typeof(object).Assembly.Location);
            // But this creates a "path not absolute" exception on mono in GetAssemblyOrModuleSymbol() below.
            // See http://stackoverflow.com/questions/26355922/creating-roslyn-metadatafilereference-bombs-on-mono-linux
            //
            // NOTE(review): the FileStreams below are handed to the metadata objects and are never disposed
            // here. Presumably the metadata reads them lazily, so they must not be closed early. Confirm this
            // against the Roslyn version in use before wrapping them in 'using'.

            var mscorlibMetadata = AssemblyMetadata.CreateFromImageStream(new FileStream(typeof(object).Assembly.Location, FileMode.Open, FileAccess.Read));
            var mscorlib = new MetadataImageReference (mscorlibMetadata);

            var datalibMetadata = AssemblyMetadata.CreateFromImageStream(new FileStream(typeof(CommandBehavior).Assembly.Location, FileMode.Open, FileAccess.Read));
            var datalib = new MetadataImageReference (datalibMetadata);

            var compilation = CSharpCompilation.Create(
                "Temp",
                files.Select(f => f.SyntaxTree),
                new[] { mscorlib, datalib }
            );
            foreach (var file in files) {
                file.SemanticModel = compilation.GetSemanticModel(file.SyntaxTree);
            }

            // Types whose members are handed to MethodInvocationRewriter as excluded from rewriting.
            var corlibSymbol = (IAssemblySymbol)compilation.GetAssemblyOrModuleSymbol(mscorlib);
            _excludedTypes = new HashSet<ITypeSymbol> {
                corlibSymbol.GetTypeByMetadataName("System.IO.TextWriter"),
                corlibSymbol.GetTypeByMetadataName("System.IO.MemoryStream")
            };

            // Set when a marked method cannot be placed. The scan continues so that every problem is
            // reported in one build, but nothing is written afterwards.
            var hadErrors = false;

            // First pass: find methods with the [GenerateAsync] attribute
            foreach (var file in files)
            {
                foreach (var m in file.SyntaxTree.GetRoot()
                    .DescendantNodes()
                    .OfType<MethodDeclarationSyntax>()
                )
                {
                    // Syntactically filter out any method without [GenerateAsync] (for performance)
                    if (m.AttributeLists.SelectMany(al => al.Attributes).All(a => a.Name.ToString() != "GenerateAsync")) {
                        continue;
                    }

                    // The async counterpart is grouped under the type that declares the method. Using the
                    // direct parent, rather than the nearest enclosing class, avoids filing a method of a
                    // struct or interface nested in a class under that outer class.
                    var cls = m.Parent as ClassDeclarationSyntax;
                    if (cls == null) {
                        Log.LogError("{0}: method {1} is marked [GenerateAsync] but is not declared directly inside a class",
                                     file.Name, m.Identifier.Text);
                        hadErrors = true;
                        continue;
                    }

                    // Results are keyed by namespace declaration. A class in the global namespace has none,
                    // and a null key cannot be stored.
                    var ns = cls.FirstAncestorOrSelf<NamespaceDeclarationSyntax>();
                    if (ns == null) {
                        Log.LogError("{0}: class {1} contains [GenerateAsync] method {2} but is not inside a namespace declaration",
                                     file.Name, cls.Identifier.Text, m.Identifier.Text);
                        hadErrors = true;
                        continue;
                    }

                    var methodSymbol = file.SemanticModel.GetDeclaredSymbol(m);

                    Dictionary<ClassDeclarationSyntax, HashSet<MethodInfo>> classes;
                    if (!file.NamespaceToClasses.TryGetValue(ns, out classes))
                        classes = file.NamespaceToClasses[ns] = new Dictionary<ClassDeclarationSyntax, HashSet<MethodInfo>>();

                    HashSet<MethodInfo> methods;
                    if (!classes.TryGetValue(cls, out methods))
                        methods = classes[cls] = new HashSet<MethodInfo>();

                    var methodInfo = new MethodInfo
                    {
                        DeclarationSyntax = m,
                        Symbol = methodSymbol,
                        Transformed = m.Identifier.Text + "Async",
                        WithOverride = false
                    };

                    // NOTE(review): assumes the GenerateAsync attribute constructor takes
                    // (string transformedName, bool withOverride). Confirm against its declaration.
                    // A null name keeps the default "<Name>Async".
                    var attr = methodSymbol.GetAttributes().Single(a => a.AttributeClass.Name == "GenerateAsync");
                    if (attr.ConstructorArguments[0].Value != null)
                        methodInfo.Transformed = (string)attr.ConstructorArguments[0].Value;
                    if (((bool) attr.ConstructorArguments[1].Value))
                        methodInfo.WithOverride = true;
                    methods.Add(methodInfo);
                }
            }

            if (hadErrors) {
                return false;
            }

            Log.LogMessage("Found {0} methods marked for async rewriting",
                           files.SelectMany(f => f.NamespaceToClasses.Values).SelectMany(ctm => ctm.Values).SelectMany(m => m).Count());

            // Second pass: transform
            foreach (var f in files)
            {
                Log.LogMessage("Writing out {0}", f.TransformedName);
                File.WriteAllText(f.TransformedName, RewriteFile(f).ToString());
            }

            OutputFiles = files.Select(f => new TaskItem(f.TransformedName)).ToArray();

            return true;
        }
Example #2
0
        /// <summary>
        /// Builds the async counterpart of a method marked with <c>[GenerateAsync]</c>:
        /// <list type="bullet">
        /// <item>invocations in its body are rewritten by <c>MethodInvocationRewriter</c>;</item>
        /// <item>the method is renamed to <c>inMethodInfo.Transformed</c>;</item>
        /// <item>all attribute lists are dropped;</item>
        /// <item>the return type is wrapped in <c>Task</c> / <c>Task&lt;T&gt;</c>;</item>
        /// <item><c>override</c> and <c>new</c> are removed;</item>
        /// <item><c>async</c> is added, unless the method is abstract or extern;</item>
        /// <item><c>override</c> is appended when <c>inMethodInfo.WithOverride</c> is set.</item>
        /// </list>
        /// </summary>
        /// <param name="file">Source file whose semantic model is used to resolve invocations.</param>
        /// <param name="inMethodInfo">The marked method and its rewrite options.</param>
        /// <returns>A new method declaration; the input syntax node is not modified.</returns>
        MethodDeclarationSyntax RewriteMethod(SourceFile file, MethodInfo inMethodInfo)
        {
            var inMethodSyntax = inMethodInfo.DeclarationSyntax;

            Log.LogMessage(MessageImportance.Low, "  Rewriting method {0} to {1}", inMethodInfo.Symbol.Name, inMethodInfo.Transformed);

            // Visit all method invocations inside the method, rewrite them to async if needed
            var rewriter = new MethodInvocationRewriter(Log, file.SemanticModel, _excludedTypes);
            var outMethod = (MethodDeclarationSyntax)rewriter.Visit(inMethodSyntax);

            // Modifiers: start from the original list and drop 'override' and 'new', which refer to the
            // synchronous method's slot. The tokens are filtered by text because
            // SyntaxTokenList.Remove(SyntaxFactory.Token(...)) is a no-op: Remove looks for that exact token
            // instance, and a freshly created token is never an element of the list.
            var modifiers = SyntaxFactory.TokenList(
                inMethodSyntax.Modifiers.Where(t => t.Text != "override" && t.Text != "new"));

            // 'async' is only legal on a method that has a body (CS1994). Abstract and extern methods have
            // none, so they keep their bodiless form and only get the Task-returning signature.
            var canBeAsync = inMethodSyntax.Modifiers.All(t => t.Text != "abstract" && t.Text != "extern");
            if (canBeAsync) {
                modifiers = modifiers.Add(SyntaxFactory.Token(SyntaxKind.AsyncKeyword));
            }

            if (inMethodInfo.WithOverride) {
                modifiers = modifiers.Add(SyntaxFactory.Token(SyntaxKind.OverrideKeyword));
            }

            // Transform return type adding Task<>. ToString() excludes trivia, so "void" compares cleanly.
            var returnType = inMethodSyntax.ReturnType.ToString();
            var asyncReturnType = SyntaxFactory.ParseTypeName(
                returnType == "void" ? "Task" : String.Format("Task<{0}>", returnType));

            return outMethod
                .WithIdentifier(SyntaxFactory.Identifier(inMethodInfo.Transformed))
                .WithAttributeLists(new SyntaxList<AttributeListSyntax>())
                .WithModifiers(modifiers)
                .WithReturnType(asyncReturnType);
        }