Example #1
0
        /// <summary>
        /// Analyzing a single method declaration should yield a library with exactly one tagged tree,
        /// and that tree should contain the method node itself plus every one of its descendant nodes.
        /// </summary>
        public void TestDefaultAnalyzeSubNodes()
        {
            // NOTE(review): the parameter list in this snippet is not valid C#; presumably the parser's error
            // recovery still yields a MethodDeclarationSyntax (First() below requires one). Confirm intent.
            string     code1 = @"class TestClass { public void TestMethod( int test = 0; test = 1;){}}";
            SyntaxTree tree1 = CSharpSyntaxTree.ParseText(code1);

            // GetRoot() is the synchronous accessor; there is no need to block on GetRootAsync().Result.
            MethodDeclarationSyntax method = tree1.GetRoot().DescendantNodes().OfType <MethodDeclarationSyntax>().First();

            List <SyntaxTree> trees1 = new List <SyntaxTree> {
                tree1
            };
            Compilation comp1 = CSharpCompilation.Create("TestCompilation1", trees1);

            ScriptAnalyzer analyzer = this.createAnalyzer(comp1);

            TaggedSyntaxLibrary lib = analyzer.AnalyzeNode(method);

            Assert.IsNotNull(lib, "Library defined");
            Assert.AreEqual(1, lib.TaggedSyntaxTrees.Count(), "Has one tree");

            TaggedSyntaxTree tree = lib.TaggedSyntaxTrees.First();

            // Every node of the analyzed method, including the method node itself, must have been tagged.
            foreach (SyntaxNode node in method.DescendantNodesAndSelf())
            {
                CollectionAssert.Contains(tree.TaggedNodes, node, "SubNode is added to tree");
            }
        }
Example #2
0
        /// <summary>
        /// Substitutes the constrained built-in placeholders in <paramref name="mtd"/> with the bridge
        /// enumerable/enumerator of <paramref name="builtIn"/>, after binding their out item to <paramref name="outTypeStr"/>.
        /// </summary>
        /// <param name="mtd">Method that may mention CONSTRAINED_BUILTIN_ENUMERABLE_NAME / CONSTRAINED_BUILTIN_ENUMERATOR_NAME.</param>
        /// <param name="builtIn">Supplies the bridge enumerable/enumerator and the name of their out item.</param>
        /// <param name="outTypeStr">Source text of the type that replaces every mention of the out item.</param>
        /// <returns>
        /// <paramref name="mtd"/> unchanged unless it mentions BOTH placeholders; otherwise the rewritten method,
        /// annotated with DO_NOT_PARAMETERIZE.
        /// </returns>
        static MethodDeclarationSyntax ExpandConstrainedBuiltIns(MethodDeclarationSyntax mtd, EnumerableDetails builtIn, string outTypeStr)
        {
            List <SimpleNameSyntax> enumerablePlaceholders = mtd.DescendantNodesAndSelf()
                                                             .OfType <SimpleNameSyntax>()
                                                             .Where(name => name.Identifier.ValueText == CONSTRAINED_BUILTIN_ENUMERABLE_NAME)
                                                             .ToList();
            List <SimpleNameSyntax> enumeratorPlaceholders = mtd.DescendantNodesAndSelf()
                                                             .OfType <SimpleNameSyntax>()
                                                             .Where(name => name.Identifier.ValueText == CONSTRAINED_BUILTIN_ENUMERATOR_NAME)
                                                             .ToList();

            // The expansion only applies when both placeholders are present.
            bool mentionsBoth = enumerablePlaceholders.Count > 0 && enumeratorPlaceholders.Count > 0;
            if (!mentionsBoth)
            {
                return mtd;
            }

            TypeSyntax outType = SyntaxFactory.ParseTypeName(outTypeStr);

            // Bind the bridge types: every mention of the out item becomes the concrete out type.
            var boundEnumerable = builtIn.BridgeEnumerable.ReplaceNodes(
                builtIn.BridgeEnumerable.DescendantNodesAndSelf().OfType <SimpleNameSyntax>().Where(name => name.Identifier.ValueText == builtIn.OutItem).ToList(),
                (old, _) => outType.WithTriviaFrom(old));
            var boundEnumerator = builtIn.BridgeEnumerator.ReplaceNodes(
                builtIn.BridgeEnumerator.DescendantNodesAndSelf().OfType <SimpleNameSyntax>().Where(name => name.Identifier.ValueText == builtIn.OutItem).ToList(),
                (old, _) => outType.WithTriviaFrom(old));

            // Enumerator entries are written second, so they win if a node were ever in both lists.
            var substitutions = new Dictionary <SyntaxNode, SyntaxNode>();
            foreach (SimpleNameSyntax placeholder in enumerablePlaceholders)
            {
                substitutions[placeholder] = boundEnumerable;
            }
            foreach (SimpleNameSyntax placeholder in enumeratorPlaceholders)
            {
                substitutions[placeholder] = boundEnumerator;
            }

            MethodDeclarationSyntax expanded = mtd.ReplaceNodes(substitutions.Keys, (old, _) => substitutions[old].WithTriviaFrom(old));

            return expanded.WithAdditionalAnnotations(DO_NOT_PARAMETERIZE);
        }
Example #3
0
        /// <summary>
        /// Returns the first property from <c>halfSynchronizedClass.UnsynchronizedPropertiesInSynchronizedMethods</c>
        /// whose identifier text appears as an <see cref="IdentifierNameSyntax"/> inside
        /// <paramref name="methodWithHalfSynchronizedProperties"/>.
        /// </summary>
        /// <param name="halfSynchronizedClass">Class representation that supplies the candidate properties.</param>
        /// <param name="methodWithHalfSynchronizedProperties">Method whose identifiers are searched.</param>
        /// <returns>The first matching property declaration.</returns>
        /// <exception cref="InvalidOperationException">No candidate property is referenced by the method.</exception>
        public static PropertyDeclarationSyntax GetHalSynchronizedPropertyUsed(
            HalfSynchronizedClassRepresentation halfSynchronizedClass,
            MethodDeclarationSyntax methodWithHalfSynchronizedProperties)
        {
            // Walk the method once. The previous lazy query was re-enumerated, and so re-walked the
            // whole syntax tree, for every candidate property checked.
            var identifiersInMethod = new HashSet <string>(
                methodWithHalfSynchronizedProperties.DescendantNodesAndSelf()
                .OfType <IdentifierNameSyntax>()
                .Select(e => e.Identifier.Text));

            // Enumerable.First throws InvalidOperationException when nothing matches, exactly as before.
            return(halfSynchronizedClass.UnsynchronizedPropertiesInSynchronizedMethods
                   .First(e => identifiersInMethod.Contains(e.Identifier.Text)));
        }
        /// <summary>
        /// Visits every invocation anywhere inside <paramref name="method"/> and hands the invoked name to
        /// <c>IsAccessingThreadDotSleep</c>. For <c>a.B()</c> that is the member name <c>B</c>; for <c>B()</c> it is
        /// the identifier itself. Other invocation shapes are ignored.
        /// </summary>
        /// <param name="method">Method declaration whose descendant invocations are inspected.</param>
        /// <param name="context">Analysis context forwarded to the check.</param>
        /// <param name="isAsync">Flag forwarded unchanged to the check.</param>
        private void AnalyzeMembers(MethodDeclarationSyntax method, SyntaxNodeAnalysisContext context, bool isAsync)
        {
            IEnumerable <InvocationExpressionSyntax> calls = method.DescendantNodesAndSelf().OfType <InvocationExpressionSyntax>();

            foreach (InvocationExpressionSyntax call in calls)
            {
                ExpressionSyntax target = call.Expression;

                if (target is MemberAccessExpressionSyntax access)
                {
                    IsAccessingThreadDotSleep(access.Name, context, isAsync);
                }
                else if (target is IdentifierNameSyntax simpleName)
                {
                    IsAccessingThreadDotSleep(simpleName, context, isAsync);
                }
            }
        }
        /// <summary>
        /// Rewrites every <c>REF_PARAM_PLACEHOLDER(x)</c> call inside <paramref name="updatedMtd"/> into the
        /// expression <c>ref x</c>, keeping the leading/trailing trivia of the original call.
        /// </summary>
        /// <param name="updatedMtd">Method to rewrite; on return it holds the rewritten method.</param>
        protected static void ReplaceRefParamCalls(ref MethodDeclarationSyntax updatedMtd)
        {
            var refParamUses =
                updatedMtd
                .DescendantNodesAndSelf()
                .OfType <InvocationExpressionSyntax>()
                .Where(e => (e.Expression as IdentifierNameSyntax)?.Identifier.ValueText == REF_PARAM_PLACEHOLDER)
                .ToList();

            // All placeholders must be replaced in ONE ReplaceNodes pass. The nodes in refParamUses belong to the
            // tree as it was before any rewrite. After a single ReplaceNode, updatedMtd is a new tree and those
            // nodes are no longer part of it, so later per-node replacements could not match them.
            updatedMtd = updatedMtd.ReplaceNodes(
                refParamUses,
                (original, rewritten) =>
                {
                    // Take the argument from the already-rewritten node so that a placeholder nested inside
                    // another placeholder's argument keeps its replacement.
                    var parameter = rewritten.ArgumentList.Arguments[0].Expression;

                    var withRef    = SyntaxFactory.RefExpression(parameter);
                    var refKeyword = withRef.RefKeyword.WithTrailingTrivia(SyntaxFactory.Whitespace(" "));
                    withRef = withRef.WithRefKeyword(refKeyword);

                    return withRef.WithTriviaFrom(original);
                });
        }
Example #6
0
        /// <summary>
        /// Applies only when <paramref name="methodSymbol"/> is an override named <c>Equals</c>/<c>GetHashCode</c>,
        /// or an interface implementation named <c>Equals</c>. In that case it reports a diagnostic for every
        /// <c>field.Equals(...)</c> / <c>field.GetHashCode()</c> call in the method whose field type is a struct
        /// with default Equals/GetHashCode implementations.
        /// </summary>
        /// <param name="methodSymbol">Symbol of the method being analyzed.</param>
        /// <param name="syntax">Declaration of the same method; all of its invocations are inspected.</param>
        /// <param name="semanticModel">Model used to resolve call receivers and invoked members.</param>
        /// <param name="diagnosticReporter">Receives each diagnostic that is created.</param>
        private void TryAnalyzeEqualsOrGetHashCode(IMethodSymbol methodSymbol, MethodDeclarationSyntax syntax,
                                                   SemanticModel semanticModel, Action <Diagnostic> diagnosticReporter)
        {
            bool overridesEqualityMember =
                methodSymbol.IsOverride &&
                (methodSymbol.Name == nameof(Equals) || methodSymbol.Name == nameof(GetHashCode));

            // The interface-implementation check only runs when the override check failed (same short-circuit).
            if (!overridesEqualityMember &&
                !(methodSymbol.IsInterfaceImplementation() && methodSymbol.Name == nameof(Equals)))
            {
                return;
            }

            foreach (InvocationExpressionSyntax invocation in syntax.DescendantNodesAndSelf().OfType <InvocationExpressionSyntax>())
            {
                // Only calls shaped like 'a.b()' are relevant.
                if (!(invocation.Expression is MemberAccessExpressionSyntax memberAccess))
                {
                    continue;
                }

                // The receiver 'a' must be a field of a struct type that keeps the default Equals/GetHashCode...
                ISymbol receiver = semanticModel.GetSymbolInfo(memberAccess.Expression).Symbol;
                if (!(receiver is IFieldSymbol fieldSymbol) ||
                    !fieldSymbol.Type.IsStruct() ||
                    !fieldSymbol.Type.HasDefaultEqualsOrHashCodeImplementations(out _))
                {
                    continue;
                }

                // ...and the invoked member 'b' must be Equals or GetHashCode.
                if (!(semanticModel.GetSymbolInfo(memberAccess).Symbol is IMethodSymbol invokedMethod) ||
                    (invokedMethod.Name != nameof(Equals) && invokedMethod.Name != nameof(GetHashCode)))
                {
                    continue;
                }

                string invokedName = invokedMethod.Name;
                diagnosticReporter(Diagnostic.Create(
                                       Rule,
                                       memberAccess.Name.GetLocation(),
                                       invokedName,
                                       $"{methodSymbol.ContainingType.ToDisplayString()}.{invokedMethod.Name}"));
            }
        }
        /// <summary>
        /// Turns each <c>REF_LOCAL_PLACEHOLDER(x)</c> call in <paramref name="updatedMtd"/> into a <c>ref x</c>
        /// expression carrying the call's trivia. All replacements are applied in a single pass.
        /// </summary>
        /// <param name="updatedMtd">Method to rewrite; receives the rewritten method.</param>
        protected static void ReplaceRefLocalCalls(ref MethodDeclarationSyntax updatedMtd)
        {
            List <InvocationExpressionSyntax> placeholderCalls = updatedMtd
                                                                 .DescendantNodesAndSelf()
                                                                 .OfType <InvocationExpressionSyntax>()
                                                                 .Where(call => (call.Expression as IdentifierNameSyntax)?.Identifier.ValueText == REF_LOCAL_PLACEHOLDER)
                                                                 .ToList();

            updatedMtd = updatedMtd.ReplaceNodes(
                placeholderCalls,
                (original, _) =>
                {
                    ExpressionSyntax target = original.ArgumentList.Arguments[0].Expression;

                    // Build 'ref <target>' with a single space after the keyword.
                    RefExpressionSyntax refExpression = SyntaxFactory.RefExpression(target);
                    SyntaxToken spacedRefKeyword = refExpression.RefKeyword.WithTrailingTrivia(SyntaxFactory.Whitespace(" "));

                    return refExpression.WithRefKeyword(spacedRefKeyword).WithTriviaFrom(original);
                });
        }
        /// <summary>
        /// Produces the transformed form of <paramref name="methodNode"/> according to
        /// <c>result.AnalyzationResult.Conversion</c> and records everything on <paramref name="result"/>:
        /// whitespace trivia, <c>Transformed</c>, extra methods (<c>AddMethod</c>), <c>TransformedFunctions</c>
        /// and <c>TransformedFunctionReferences</c>.
        /// </summary>
        /// <param name="methodNode">
        /// The syntax node to transform. It may differ from <c>methodResult.Node</c>; spans are mapped between the
        /// two via <c>startMethodSpan</c>. NOTE(review): presumably a possibly pre-modified copy of the analyzed node — confirm.
        /// </param>
        /// <param name="canCopy">When false, the Copy flag is cleared from the conversion before anything else.</param>
        /// <param name="result">Result object that is filled in and returned.</param>
        /// <param name="typeMetadata">Supplies the type's leading whitespace used to compute the method indent.</param>
        /// <param name="namespaceMetadata">Forwarded to child-function and reference transformations.</param>
        /// <returns><paramref name="result"/>, populated.</returns>
        private MethodTransformationResult TransformMethod(MethodDeclarationSyntax methodNode, bool canCopy, MethodTransformationResult result, ITypeTransformationMetadata typeMetadata,
                                                           INamespaceTransformationMetadata namespaceMetadata)
        {
            //var result = new MethodTransformationResult(methodResult);
            var methodResult     = result.AnalyzationResult;
            var methodConversion = methodResult.Conversion;

            if (!canCopy)
            {
                methodConversion &= ~MethodConversion.Copy;
            }
            //var methodNode = customNode ?? methodResult.Node;
            var methodBodyNode = methodResult.GetBodyNode();

            // Calculate whitespace method trivias.
            // This runs before any early return, so the trivia is populated even for ignored methods.
            result.EndOfLineTrivia             = methodNode.GetEndOfLine();
            result.LeadingWhitespaceTrivia     = methodNode.GetLeadingWhitespace();
            result.IndentTrivia                = methodNode.GetIndent(result.LeadingWhitespaceTrivia, typeMetadata.LeadingWhitespaceTrivia);
            result.BodyLeadingWhitespaceTrivia = Whitespace(result.LeadingWhitespaceTrivia.ToFullString() + result.IndentTrivia.ToFullString());

            if (methodConversion == MethodConversion.Ignore)
            {
                return(result);
            }

            // No body (e.g. abstract/interface/extern-style declarations): nothing inside to rewrite.
            if (methodBodyNode == null)
            {
                if (methodConversion.HasFlag(MethodConversion.ToAsync))
                {
                    result.Transformed = methodNode;
                    if (methodConversion.HasFlag(MethodConversion.Copy))
                    {
                        // The original analyzed node is emitted as an additional method alongside the transformed one.
                        result.AddMethod(methodResult.Node);
                    }
                    return(result);
                }
                if (methodConversion.HasFlag(MethodConversion.Copy))
                {
                    result.Transformed = methodResult.Node;
                }
                // Note: with neither ToAsync nor Copy, Transformed is left unset.
                return(result);
            }
            // startMethodSpan becomes the offset that maps positions in the analyzed document onto positions
            // inside methodNode; every span lookup below subtracts it.
            var startMethodSpan = methodResult.Node.Span.Start;

            methodNode       = methodNode.WithAdditionalAnnotations(new SyntaxAnnotation(result.Annotation));
            startMethodSpan -= methodNode.SpanStart;

            // First we need to annotate nodes that will be modified in order to find them later on.
            // We cannot rely on spans after the first modification as they will change
            var typeReferencesAnnotations = new List <string>();

            // References to types converted to a new type get a unique annotation so they can be renamed later.
            foreach (var typeReference in methodResult.TypeReferences.Where(o => o.TypeAnalyzationResult.Conversion == TypeConversion.NewType))
            {
                var reference  = typeReference.ReferenceLocation;
                var startSpan  = reference.Location.SourceSpan.Start - startMethodSpan;
                var nameNode   = methodNode.GetSimpleName(startSpan, reference.Location.SourceSpan.Length);
                var annotation = Guid.NewGuid().ToString();
                methodNode = methodNode.ReplaceNode(nameNode, nameNode.WithAdditionalAnnotations(new SyntaxAnnotation(annotation)));
                typeReferencesAnnotations.Add(annotation);
            }

            // For copied methods we need just to replace type references
            if (methodConversion.HasFlag(MethodConversion.Copy))
            {
                var copiedMethod = methodNode;
                // Modify references: append "Async" to each annotated type name.
                foreach (var refAnnotation in typeReferencesAnnotations)
                {
                    var nameNode = copiedMethod.GetAnnotatedNodes(refAnnotation).OfType <SimpleNameSyntax>().First();
                    copiedMethod = copiedMethod
                                   .ReplaceNode(nameNode, nameNode.WithIdentifier(Identifier(nameNode.Identifier.Value + "Async").WithTriviaFrom(nameNode.Identifier)));
                }
                // Copy only: the copy is the final result.
                if (!methodConversion.HasFlag(MethodConversion.ToAsync))
                {
                    result.Transformed = copiedMethod;
                    return(result);
                }
                // Copy + ToAsync: keep the copy as an extra method (without the result annotation) and continue
                // with the async transformation of methodNode.
                result.AddMethod(copiedMethod.WithoutAnnotations(result.Annotation));
            }

            // Transform each non-ignored child function now, and annotate its node in methodNode so the
            // transformed version can be substituted after the body formatting fixup.
            foreach (var childFunction in methodResult.ChildFunctions.Where(o => o.Conversion != MethodConversion.Ignore))
            {
                var functionNode   = childFunction.GetNode();
                var functionKind   = functionNode.Kind();
                var typeSpanStart  = functionNode.SpanStart - startMethodSpan;
                var typeSpanLength = functionNode.Span.Length;
                var funcNode       = methodNode.DescendantNodesAndSelf()
                                     .First(o => o.IsKind(functionKind) && o.SpanStart == typeSpanStart && o.Span.Length == typeSpanLength);
                var transformFuncResult = TransformFunction(childFunction, result, typeMetadata, namespaceMetadata);
                result.TransformedFunctions.Add(transformFuncResult);
                methodNode = methodNode.ReplaceNode(funcNode, funcNode.WithAdditionalAnnotations(new SyntaxAnnotation(transformFuncResult.Annotation)));
            }

            // Annotate every function reference that must become async so it can be rewritten later.
            foreach (var referenceResult in methodResult.FunctionReferences
                     .Where(o => o.GetConversion() == ReferenceConversion.ToAsync))
            {
                var transfromReference = new FunctionReferenceTransformationResult(referenceResult);
                var isCref             = referenceResult.IsCref;
                var reference          = referenceResult.ReferenceLocation;
                var startSpan          = reference.Location.SourceSpan.Start - startMethodSpan;
                var nameNode           = methodNode.GetSimpleName(startSpan, reference.Location.SourceSpan.Length, isCref);
                methodNode = methodNode.ReplaceNode(nameNode, nameNode.WithAdditionalAnnotations(new SyntaxAnnotation(transfromReference.Annotation)));
                result.TransformedFunctionReferences.Add(transfromReference);

                // The Task-wrapping annotation below is only needed when the method omits the async keyword,
                // and never for cref or nameof references.
                if (isCref || referenceResult.IsNameOf || !methodResult.OmitAsync)
                {
                    continue;
                }
                // We need to annotate the reference node (InvocationExpression, IdentifierName) in order to know if we need to wrap the node in a Task.FromResult
                var refNode       = referenceResult.ReferenceNode;
                var bodyReference = (IBodyFunctionReferenceAnalyzationResult)referenceResult;
                if (bodyReference.UseAsReturnValue || refNode.IsReturned())
                {
                    startSpan = refNode.SpanStart - startMethodSpan;
                    var referenceNode = methodNode.DescendantNodes().First(o => o.SpanStart == startSpan && o.Span.Length == refNode.Span.Length);
                    methodNode = methodNode.ReplaceNode(referenceNode, referenceNode.WithAdditionalAnnotations(new SyntaxAnnotation(Annotations.TaskReturned)));
                }
            }
            // Before modifying, fix up the method body formatting in order to prevent weird formatting when adding additional code.
            // From here on, nodes are located only through annotations, never through spans.
            methodNode = FixupBodyFormatting(methodNode, result);

            // Modify references: append "Async" to each annotated type name.
            foreach (var refAnnotation in typeReferencesAnnotations)
            {
                var nameNode = methodNode.GetAnnotatedNodes(refAnnotation).OfType <SimpleNameSyntax>().First();
                methodNode = methodNode
                             .ReplaceNode(nameNode, nameNode.WithIdentifier(Identifier(nameNode.Identifier.Value + "Async").WithTriviaFrom(nameNode.Identifier)));
            }

            // Substitute each annotated child function with its transformed version.
            foreach (var transformFunction in result.TransformedFunctions)
            {
                var funcNode = methodNode.GetAnnotatedNodes(transformFunction.Annotation).First();
                methodNode = methodNode
                             .ReplaceNode(funcNode, transformFunction.Transformed);
            }

            // We have to order by OriginalStartSpan in order to have consistent formatting when adding awaits
            foreach (var transfromReference in result.TransformedFunctionReferences.OrderByDescending(o => o.OriginalStartSpan))
            {
                methodNode = TransformFunctionReference(methodNode, methodResult, transfromReference, typeMetadata, namespaceMetadata);
            }

            result.Transformed = methodNode;

            return(result);
        }
Example #9
0
        /// <summary>
        /// Reports a diagnostic on the <c>static</c> modifier of a static method declared in a non-static class,
        /// unless one of the following holds:
        /// <list type="bullet">
        /// <item><description>the method is <c>extern</c>;</description></item>
        /// <item><description>it is named <c>Main</c>;</description></item>
        /// <item><description>its return type text is <c>IEnumerable&lt;object[]&gt;</c> (a DynamicData source);</description></item>
        /// <item><description>it references a static member of its own type;</description></item>
        /// <item><description>it constructs an instance of its own type.</description></item>
        /// </list>
        /// </summary>
        /// <param name="context">Node analysis context; nodes other than method declarations are ignored.</param>
        private void Analyze(SyntaxNodeAnalysisContext context)
        {
            if (!(context.Node is MethodDeclarationSyntax methodDeclarationSyntax))
            {
                return;
            }

            // Only analyzing static method declarations
            if (!methodDeclarationSyntax.Modifiers.Any(SyntaxKind.StaticKeyword))
            {
                return;
            }

            // If the method is marked "extern", let it go.
            if (methodDeclarationSyntax.Modifiers.Any(SyntaxKind.ExternKeyword))
            {
                return;
            }

            // If the class is static, we need to let it go.
            ClassDeclarationSyntax classDeclarationSyntax = context.Node.FirstAncestorOrSelf <ClassDeclarationSyntax>();

            if (classDeclarationSyntax == null || classDeclarationSyntax.Modifiers.Any(SyntaxKind.StaticKeyword))
            {
                return;
            }

            // The Main entrypoint to the program must be static.
            // Identifier is a SyntaxToken (a struct), so the former "!= null" guard was always true.
            if (methodDeclarationSyntax.Identifier.ValueText == @"Main")
            {
                return;
            }

            // Check if this method is being used for DynamicData; if so, let it go.
            // This purely syntactic check runs before the semantic-model work below. All checks are
            // side-effect-free early returns, so the order does not change the outcome.
            // Ordinal comparison: this is source text, not culture-sensitive user text.
            string returnType = methodDeclarationSyntax.ReturnType.ToString();

            if (string.Equals(returnType, "IEnumerable<object[]>", StringComparison.OrdinalIgnoreCase))
            {
                return;
            }

            // Hunt for static members
            INamedTypeSymbol us = context.SemanticModel.GetDeclaredSymbol(classDeclarationSyntax);

            if (us == null)
            {
                return;
            }

            foreach (IdentifierNameSyntax identifierNameSyntax in methodDeclarationSyntax.DescendantNodesAndSelf().OfType <IdentifierNameSyntax>())
            {
                ISymbol symbol = context.SemanticModel.GetSymbolInfo(identifierNameSyntax).Symbol;
                if (symbol == null)
                {
                    continue;
                }
                if (symbol.IsStatic && !symbol.IsExtern)
                {
                    // We found a static thing being used in this method.  Is the thing ours?
                    if (SymbolEqualityComparer.Default.Equals(symbol.ContainingType, us))
                    {
                        // This method must be static because it references something static of ours.  We are done.
                        return;
                    }
                }
            }

            // Hunt for evidence that this is a factory method
            foreach (ObjectCreationExpressionSyntax objectCreationExpressionSyntax in methodDeclarationSyntax.DescendantNodesAndSelf().OfType <ObjectCreationExpressionSyntax>())
            {
                ISymbol objectCreationSymbol = context.SemanticModel.GetSymbolInfo(objectCreationExpressionSyntax).Symbol;
                if (SymbolEqualityComparer.Default.Equals(objectCreationSymbol?.ContainingType, us))
                {
                    return;
                }
            }

            // The static modifier is guaranteed to exist here (checked at the top).
            Diagnostic diagnostic = Diagnostic.Create(Rule, methodDeclarationSyntax.Modifiers.First(t => t.Kind() == SyntaxKind.StaticKeyword).GetLocation());

            context.ReportDiagnostic(diagnostic);
        }
        /// <summary>
        /// Expands <paramref name="template"/> once per element of <paramref name="enumerables"/>. In each copy the
        /// placeholder enumerable/enumerator names are replaced with that element's concrete types, whose out item
        /// is bound to the single out type used by the template's generic placeholder mentions.
        /// </summary>
        /// <param name="template">Method containing the placeholder names.</param>
        /// <param name="enumerables">One expanded method is produced per element.</param>
        /// <param name="placeHolderEnumerableName">Identifier text marking the placeholder enumerable.</param>
        /// <param name="placeHolderEnumeratorName">Identifier text marking the placeholder enumerator.</param>
        /// <param name="includeReturnTypes">
        /// When false, placeholder mentions inside the method's return type are excluded from the initial scan.
        /// On the non-DoNotInject path they are also excluded from replacement.
        /// (The DoNotInject path re-scans the method without this filter.)
        /// </param>
        /// <returns>
        /// The expanded methods, each with the template's trivia. If the template has no eligible placeholder
        /// mention, a single-element list containing <paramref name="template"/> unchanged.
        /// </returns>
        /// <exception cref="Exception">The generic placeholder mentions use more than one distinct out type.</exception>
        /// <remarks>
        /// NOTE(review): if placeholders are mentioned but none of the mentions is a generic name, <c>outTypes</c>
        /// is empty and <c>Single()</c> throws <see cref="InvalidOperationException"/>.
        /// </remarks>
        static protected IEnumerable <MethodDeclarationSyntax> ExpandMethodFromPlaceholders(
            MethodDeclarationSyntax template,
            IEnumerable <EnumerableDetails> enumerables,
            string placeHolderEnumerableName,
            string placeHolderEnumeratorName,
            bool includeReturnTypes
            )
        {
            var ret = new List <MethodDeclarationSyntax>();

            // Declared before assignment so the lambda can call itself recursively.
            Func <SyntaxNode, bool> inReturn = null;

            // True when the node lies inside the ReturnType of the nearest enclosing MethodDeclarationSyntax.
            // It walks up the parents until it reaches the direct child of a method declaration.
            inReturn =
                node =>
            {
                if (node.Parent == null)
                {
                    return(false);
                }

                var isPartOfMethod = node.Parent is MethodDeclarationSyntax;
                if (!isPartOfMethod)
                {
                    return(inReturn(node.Parent));
                }

                var parentMethod = (MethodDeclarationSyntax)node.Parent;

                if (parentMethod.ReturnType == null)
                {
                    return(false);
                }

                // hit the containing method, so it's make or break time
                return(node == parentMethod.ReturnType);
            };

            var mentionsOfPlaceholderEnumerable =
                template.DescendantNodesAndSelf().OfType <SimpleNameSyntax>().Where(s => s.Identifier.ValueText == placeHolderEnumerableName).ToList();
            var mentionsOfPlaceholderEnumerator =
                template.DescendantNodesAndSelf().OfType <SimpleNameSyntax>().Where(s => s.Identifier.ValueText == placeHolderEnumeratorName).ToList();

            if (!includeReturnTypes)
            {
                var inReturnEnumerables = mentionsOfPlaceholderEnumerable.Where(n => inReturn(n)).ToList();
                var inReturnEnumerators = mentionsOfPlaceholderEnumerator.Where(n => inReturn(n)).ToList();

                mentionsOfPlaceholderEnumerable = mentionsOfPlaceholderEnumerable.Except(inReturnEnumerables).ToList();
                mentionsOfPlaceholderEnumerator = mentionsOfPlaceholderEnumerator.Except(inReturnEnumerators).ToList();
            }

            // no changes to be made, leave it alone
            if (mentionsOfPlaceholderEnumerable.Count == 0 && mentionsOfPlaceholderEnumerator.Count == 0)
            {
                ret.Add(template);
                return(ret);
            }

            // The out type is the first type argument of every generic placeholder mention; all must agree.
            var outTypes =
                mentionsOfPlaceholderEnumerable
                .OfType <GenericNameSyntax>()
                .Concat(mentionsOfPlaceholderEnumerator.OfType <GenericNameSyntax>())
                .Select(g => g.TypeArgumentList.Arguments.ElementAt(0))
                .OfType <TypeSyntax>()
                .Select(t => t.ToString())
                .Distinct()
                .ToList();

            // NOTE(review): throws the base Exception type; a more specific exception would be preferable.
            if (outTypes.Count > 1)
            {
                throw new Exception("Expected only a single out type in extension method placeholder usage");
            }

            var outTypeStr = outTypes.Single();
            var outType    = SyntaxFactory.ParseTypeName(outTypeStr);

            foreach (var pair in enumerables)
            {
                var updatedMtd = template;

                // replace all the uses of the out item with whatever is bound in the template
                var enumerableOutTypeUses = pair.Enumerable.DescendantNodesAndSelf().OfType <SimpleNameSyntax>().Where(t => t.Identifier.ValueText == pair.OutItem).ToList();
                var enumeratorOutTypeUses =
                    pair.Enumerator != null?
                    pair.Enumerator.DescendantNodesAndSelf().OfType <SimpleNameSyntax>().Where(t => t.Identifier.ValueText == pair.OutItem).ToList() :
                        new List <SimpleNameSyntax>();

                // rework enumerable and enumerator to bind to the appropriate type
                // (boundEnumerator stays null when the pair has no enumerator)
                var boundEnumerable = pair.Enumerable.ReplaceNodes(enumerableOutTypeUses, (old, _) => outType.WithTriviaFrom(old));
                var boundEnumerator = pair.Enumerator?.ReplaceNodes(enumeratorOutTypeUses, (old, _) => outType.WithTriviaFrom(old));

                // A [DoNotInject] attribute is stripped from the output and switches to the bridge-type path below.
                var dontInjectIntoCommon = false;
                var attrs    = updatedMtd.AttributeLists.SelectMany(a => a.Attributes).ToList();
                var dnpAttrs = attrs.Where(a => (a.Name as IdentifierNameSyntax)?.Identifier.ValueText == "DoNotInject").ToList();
                if (dnpAttrs.Any())
                {
                    var attrKeeps = new List <AttributeSyntax>(attrs);
                    foreach (var attr in dnpAttrs)
                    {
                        attrKeeps.Remove(attr);
                    }

                    if (attrKeeps.Count == 0)
                    {
                        updatedMtd = updatedMtd.RemoveNodes(updatedMtd.AttributeLists, SyntaxRemoveOptions.KeepLeadingTrivia);
                    }
                    else
                    {
                        // All remaining attributes are merged into one new attribute list.
                        // NOTE(review): the new list has no target specifier, so e.g. "[return: X]" would lose its target.
                        var attrListSyntax = SyntaxFactory.AttributeList().AddAttributes(attrKeeps.ToArray());
                        var list           = SyntaxFactory.List(new[] { attrListSyntax });

                        updatedMtd = updatedMtd.WithAttributeLists(list);
                    }

                    dontInjectIntoCommon = true;
                }

                if (dontInjectIntoCommon)
                {
                    // Node reference equality: matches only when a parameter's type IS the placeholder name node.
                    Func <SimpleNameSyntax, bool> inParameterList =
                        p =>
                    {
                        var pList = updatedMtd.ParameterList;

                        return(pList.Parameters.Any(x => x.Type.Equals(p)));
                    };

                    var replace = new Dictionary <SimpleNameSyntax, SyntaxNode>();

                    // NOTE(review): BridgeEnumerable/BridgeEnumerator are dereferenced without the null check
                    // that pair.Enumerator gets above; confirm they are always set.
                    var bridingeEnumerableOutTypeUses = pair.BridgeEnumerable.DescendantNodesAndSelf().OfType <SimpleNameSyntax>().Where(t => t.Identifier.ValueText == pair.OutItem).ToList();
                    var bridingeEnumeratorOutTypeUses = pair.BridgeEnumerator.DescendantNodesAndSelf().OfType <SimpleNameSyntax>().Where(t => t.Identifier.ValueText == pair.OutItem).ToList();

                    var bridgingEnumerable = pair.BridgeEnumerable.ReplaceNodes(bridingeEnumerableOutTypeUses, (old, _) => outType.WithTriviaFrom(old));
                    var bridgingEnumerator = pair.BridgeEnumerator.ReplaceNodes(bridingeEnumeratorOutTypeUses, (old, _) => outType.WithTriviaFrom(old));

                    // Re-scan updatedMtd (its attributes may have changed, so template nodes no longer belong to it).
                    var toReplaceEnumerables = updatedMtd.DescendantNodesAndSelf().OfType <SimpleNameSyntax>().Where(s => s.Identifier.ValueText == placeHolderEnumerableName).ToList();
                    var toReplaceEnumerators = updatedMtd.DescendantNodesAndSelf().OfType <SimpleNameSyntax>().Where(s => s.Identifier.ValueText == placeHolderEnumeratorName).ToList();

                    foreach (var e in toReplaceEnumerables)
                    {
                        replace[e] = bridgingEnumerable;
                    }

                    foreach (var e in toReplaceEnumerators)
                    {
                        replace[e] = bridgingEnumerator;
                    }

                    // Parameter types get the concrete bound enumerable instead of the bridge type.
                    var inParams = new List <SimpleNameSyntax>();
                    foreach (var kv in replace)
                    {
                        if (inParameterList(kv.Key))
                        {
                            inParams.Add(kv.Key);
                        }
                    }

                    foreach (var p in inParams)
                    {
                        replace[p] = boundEnumerable;
                    }

                    updatedMtd = updatedMtd.ReplaceNodes(replace.Keys, (old, _) => replace[old].WithTriviaFrom(old));
                    updatedMtd = updatedMtd.WithAdditionalAnnotations(DO_NOT_PARAMETERIZE);
                }
                else
                {
                    // replace the old enumerable and enumerator references
                    // (updatedMtd is still the unmodified template here, so the template's nodes can be used directly)
                    updatedMtd = updatedMtd.ReplaceNodes(mentionsOfPlaceholderEnumerable, (old, _) => boundEnumerable.WithTriviaFrom(old));

                    // The tree changed, so enumerator mentions must be located again.
                    var updatedMentionsOfPlaceholderEnumerator = updatedMtd.DescendantNodesAndSelf().OfType <SimpleNameSyntax>().Where(s => s.Identifier.ValueText == placeHolderEnumeratorName).ToList();

                    if (!includeReturnTypes)
                    {
                        var enumeratorsInReturn = updatedMentionsOfPlaceholderEnumerator.Where(e => inReturn(e)).ToList();
                        updatedMentionsOfPlaceholderEnumerator = updatedMentionsOfPlaceholderEnumerator.Except(enumeratorsInReturn).ToList();
                    }

                    // Without an enumerator for this pair, the enumerator mentions are removed outright.
                    if (boundEnumerator != null)
                    {
                        updatedMtd = updatedMtd.ReplaceNodes(updatedMentionsOfPlaceholderEnumerator, (old, _) => boundEnumerator.WithTriviaFrom(old));
                    }
                    else
                    {
                        updatedMtd = updatedMtd.RemoveNodes(updatedMentionsOfPlaceholderEnumerator, SyntaxRemoveOptions.KeepNoTrivia);
                    }
                }

                // rewrite any type constraints so that they refer to the new out item too
                // NOTE(review): the base clauses come from `template`, not `updatedMtd`, so any placeholder
                // replacement made above inside constraint clauses is discarded by WithConstraintClauses below.
                var updatedConstraints = new List <TypeParameterConstraintClauseSyntax>();
                updatedConstraints.AddRange(template.ConstraintClauses);

                foreach (var constraint in pair.Constraints)
                {
                    var constraitOutTypeUses = constraint.DescendantNodesAndSelf().OfType <SimpleNameSyntax>().Where(t => t.Identifier.ValueText == pair.OutItem).ToList();
                    var updatedConstraint    = constraint.ReplaceNodes(constraitOutTypeUses, (old, _) => outType.WithTriviaFrom(old));
                    updatedConstraints.Add(updatedConstraint);
                }

                updatedMtd = updatedMtd.WithConstraintClauses(SyntaxFactory.List(updatedConstraints));

                // slam all the generic types that need to be known into place
                var typeList = new List <TypeParameterSyntax>();
                foreach (var param in pair.GenericArgs)
                {
                    typeList.Add(SyntaxFactory.TypeParameter(param));
                }

                if (typeList.Count > 0)
                {
                    updatedMtd = updatedMtd.AddTypeParameterListParameters(typeList.ToArray());
                }

                if (pair.IsBridgeType && !dontInjectIntoCommon)
                {
                    // bridge types are handled with lots of specificly parameterized methods in CommonImplementation, so remove
                    //   any of the direct mentions in type argument lists in the body
                    updatedMtd = updatedMtd.WithAdditionalAnnotations(METHOD_ON_BRIDGE_TYPE);

                    // NOTE(review): if the method has neither Body nor ExpressionBody, the ?? yields null and
                    // OfType throws ArgumentNullException.
                    var bodyGenericTypeArgs =
                        (updatedMtd.Body?.DescendantNodesAndSelf() ?? updatedMtd.ExpressionBody?.DescendantNodesAndSelf())
                        .OfType <TypeArgumentListSyntax>()
                        .SelectMany(t => t.Arguments)
                        .ToList();

                    var needRemoval = bodyGenericTypeArgs.Where(b => b.IsEquivalentTo(boundEnumerable)).ToList();

                    updatedMtd = updatedMtd.RemoveNodes(needRemoval, SyntaxRemoveOptions.KeepNoTrivia);

                    // Generic names left with an empty type-argument list collapse to plain identifiers (Foo<> -> Foo).
                    var replacements = new Dictionary <SyntaxNode, SyntaxNode>();

                    foreach (var withEmptyTypeArgs in (updatedMtd.Body?.DescendantNodesAndSelf() ?? updatedMtd.ExpressionBody?.DescendantNodesAndSelf()).OfType <TypeArgumentListSyntax>().Where(t => t.Arguments.Count == 0))
                    {
                        var parent     = (GenericNameSyntax)withEmptyTypeArgs.Parent;
                        var simpleName = SyntaxFactory.IdentifierName(parent.Identifier);

                        replacements[parent] = simpleName.WithTriviaFrom(parent);
                    }

                    updatedMtd = updatedMtd.ReplaceNodes(replacements.Keys, (old, _) => replacements[old]);

                    // Earlier one-node-at-a-time version of the loop above, kept commented out:
                    //while (true)
                    //{
                    //    var withEmptyTypeArgs =
                    //        (updatedMtd.Body?.DescendantNodesAndSelf() ?? updatedMtd.ExpressionBody?.DescendantNodesAndSelf())
                    //            .OfType<TypeArgumentListSyntax>()
                    //            .FirstOrDefault(t => t.Arguments.Count == 0);

                    //    if (withEmptyTypeArgs == null) break;

                    //    var parent = (GenericNameSyntax)withEmptyTypeArgs.Parent;
                    //    var simpleName = SyntaxFactory.IdentifierName(parent.Identifier);

                    //    updatedMtd = updatedMtd.ReplaceNode(parent, simpleName.WithTriviaFrom(parent));
                    //}
                }

                ret.Add(updatedMtd.WithTriviaFrom(template));
            }

            return(ret);
        }
Example #11
0
        /// <summary>
        /// Scans <paramref name="mtd"/> for mentions of the built-in, placeholder and constrained built-in
        /// enumerable/enumerator names and buckets those syntax nodes by position.
        /// </summary>
        /// <remarks>
        /// Type argument lists: only lists whose parent is a generic name other than <c>Func</c> and which
        /// directly contain at least one enumerable name are considered. Within each such list (tracked per list),
        /// the first enumerable argument goes to <paramref name="firstMentionsEnumerables"/> and later ones to
        /// <paramref name="secondMentionsEnumerables"/>. Enumerator arguments in those same lists are split the
        /// same way into <paramref name="firstMentionsEnumerators"/> / <paramref name="secondMentionsEnumerators"/>.
        ///
        /// Parameters: tracked across the whole parameter list. The first parameter whose type is itself an
        /// enumerable name goes to <paramref name="firstParameterEnumerables"/>, and later ones to
        /// <paramref name="secondParameterEnumerables"/>. Enumerable type arguments of a <c>Func&lt;...&gt;</c>
        /// parameter type always go to <paramref name="secondParameterEnumerables"/> and do not affect which
        /// parameter counts as first.
        /// </remarks>
        /// <param name="mtd">The template method to inspect.</param>
        /// <param name="firstMentionsEnumerables">First enumerable argument in each relevant type argument list.</param>
        /// <param name="secondMentionsEnumerables">Subsequent enumerable arguments in each relevant type argument list.</param>
        /// <param name="firstMentionsEnumerators">First enumerator argument in each relevant type argument list.</param>
        /// <param name="secondMentionsEnumerators">Subsequent enumerator arguments in each relevant type argument list.</param>
        /// <param name="firstParameterEnumerables">Type of the first parameter that is directly an enumerable name.</param>
        /// <param name="secondParameterEnumerables">Later enumerable parameter types, plus enumerable <c>Func</c> type arguments.</param>
        static void ExtractEnumerablesAndEnumeratorsInFirstAndSecondPosition(
            MethodDeclarationSyntax mtd,
            out List <SyntaxNode> firstMentionsEnumerables,
            out List <SyntaxNode> secondMentionsEnumerables,
            out List <SyntaxNode> firstMentionsEnumerators,
            out List <SyntaxNode> secondMentionsEnumerators,
            out List <SyntaxNode> firstParameterEnumerables,
            out List <SyntaxNode> secondParameterEnumerables
            )
        {
            // every spelling of "the enumerable" / "the enumerator" that a template may use
            var enumerableNames = new[] { BUILTIN_ENUMERABLE_NAME, PLACEHOLDER_ENUMERABLE_NAME, CONSTRAINED_BUILTIN_ENUMERABLE_NAME };
            var enumeratorNames = new[] { BUILTIN_ENUMERATOR_NAME, PLACEHOLDER_ENUMERATOR_NAME, CONSTRAINED_BUILTIN_ENUMERATOR_NAME };

            var typeArgsMentioningBuiltIn =
                mtd.DescendantNodesAndSelf()
                .OfType <TypeArgumentListSyntax>()
                .Where(t => t.Parent is GenericNameSyntax)
                .Where(t => ((GenericNameSyntax)t.Parent).Identifier.ValueText != "Func")       // technically we should have a proper whitelist, but ehhh
                .Where(l => l.Arguments.Any(a => enumerableNames.Contains((a as SimpleNameSyntax)?.Identifier.ValueText)))
                .ToList();

            var parameterTypesMentioningBuiltIn =
                mtd.ParameterList.Parameters
                .Select(p => p.Type)
                .Where(t => t != null)     // Roslyn permits ParameterSyntax.Type to be null; nothing to inspect then
                .Where(t => t.DescendantNodesAndSelf().OfType <SimpleNameSyntax>().Any(x => enumerableNames.Contains(x.Identifier.ValueText)))
                .ToList();

            firstMentionsEnumerables  = new List <SyntaxNode>();
            secondMentionsEnumerables = new List <SyntaxNode>();

            firstMentionsEnumerators  = new List <SyntaxNode>();
            secondMentionsEnumerators = new List <SyntaxNode>();

            firstParameterEnumerables  = new List <SyntaxNode>();
            secondParameterEnumerables = new List <SyntaxNode>();

            foreach (var argList in typeArgsMentioningBuiltIn)
            {
                // "first" is tracked independently for each type argument list
                var seenEnumerable = false;
                var seenEnumerator = false;

                foreach (var arg in argList.Arguments)
                {
                    var name = (arg as SimpleNameSyntax)?.Identifier.ValueText;

                    if (enumerableNames.Contains(name))
                    {
                        if (seenEnumerable)
                        {
                            secondMentionsEnumerables.Add(arg);
                        }
                        else
                        {
                            firstMentionsEnumerables.Add(arg);
                            seenEnumerable = true;
                        }
                    }
                    else if (enumeratorNames.Contains(name))
                    {
                        if (seenEnumerator)
                        {
                            secondMentionsEnumerators.Add(arg);
                        }
                        else
                        {
                            firstMentionsEnumerators.Add(arg);
                            seenEnumerator = true;
                        }
                    }
                }
            }

            // "first" is tracked across the whole parameter list
            var seenParameterEnumerable = false;

            foreach (var pType in parameterTypesMentioningBuiltIn)
            {
                // should be more general than this, but meh
                var funcType = pType as GenericNameSyntax;
                if (funcType != null && funcType.Identifier.ValueText == "Func")
                {
                    // enumerables flowing through a Func are always bucketed as "second"
                    foreach (var funcArg in funcType.TypeArgumentList.Arguments)
                    {
                        if (enumerableNames.Contains((funcArg as SimpleNameSyntax)?.Identifier.ValueText))
                        {
                            secondParameterEnumerables.Add(funcArg);
                        }
                    }

                    continue;
                }

                var name = (pType as SimpleNameSyntax)?.Identifier.ValueText;
                if (!enumerableNames.Contains(name))
                {
                    continue;
                }

                if (seenParameterEnumerable)
                {
                    secondParameterEnumerables.Add(pType);
                }
                else
                {
                    firstParameterEnumerables.Add(pType);
                    seenParameterEnumerable = true;
                }
            }
        }