private ImmutableList <MethodDeclarationSyntax> GetOperationContractMethodDeclarations(SemanticModel semanticModel, SyntaxGenerator gen, INamedTypeSymbol serviceInterface, bool includeAttributes, bool includeSourceInterfaceMethods, bool excludeAsyncMethods)
        {
            // Builds the method declarations for a generated contract interface from the
            // operation contracts of serviceInterface. Each contract produces a synchronous
            // signature and, unless excludeAsyncMethods is set, an "...Async" signature.
            // When the source declares both flavours of a contract (per ContractMacthes),
            // each flavour is emitted from its own source method rather than synthesized.
            // NOTE(review): includeSourceInterfaceMethods is not read by this method.
            var infos        = GetOperationContractMethodInfos(semanticModel, serviceInterface);
            var declarations = ImmutableList <MethodDeclarationSyntax> .Empty;

            // Sorted by contract name, with synchronous (IsAsync == false) entries first.
            var orderedInfos = infos.OrderBy(m => m.ContractMethodName).ThenBy(m => m.IsAsync);

            foreach (var info in orderedInfos)
            {
                // Synchronous flavour: an async source method only emits it when no
                // synchronous source method exists for the same contract.
                bool emitSync = !(info.IsAsync && infos.Any(m => m.ContractMacthes(info) && !m.IsAsync));

                if (emitSync)
                {
                    SyntaxNode syncDecl = gen.MethodDeclaration(info.Method);
                    syncDecl = gen.WithType(syncDecl, gen.TypeExpression(info.ContractReturnType));
                    syncDecl = gen.WithName(syncDecl, info.ContractMethodName);

                    if (includeAttributes)
                    {
                        // Merge in the fault contracts of the opposite-flavour counterpart (if any).
                        var counterpart       = infos.FirstOrDefault(m => m.ContractMacthes(info) && m.IsAsync != info.IsAsync);
                        var counterpartFaults = counterpart?.FaultContractAttributes ?? Enumerable.Empty <AttributeData>();
                        var syncAttributes    = info.AllAttributes.Concat(counterpartFaults).Select(a => gen.Attribute(a));

                        syncDecl = gen.AddAttributes(syncDecl, syncAttributes);
                    }

                    declarations = declarations.Add((MethodDeclarationSyntax)syncDecl.AddNewLineTrivia().AddNewLineTrivia());
                }

                if (excludeAsyncMethods)
                {
                    continue;
                }

                // Asynchronous flavour: a synchronous source method only emits it when no
                // async source method exists for the same contract.
                bool emitAsync = info.IsAsync || !infos.Any(m => m.ContractMacthes(info) && m.IsAsync);

                if (!emitAsync)
                {
                    continue;
                }

                SyntaxNode asyncDecl = gen.MethodDeclaration(info.Method);
                asyncDecl = gen.WithType(asyncDecl, gen.TypeExpression(info.AsyncReturnType));
                asyncDecl = gen.WithName(asyncDecl, info.ContractMethodName + "Async");

                if (includeAttributes)
                {
                    asyncDecl = gen.AddAttributes(asyncDecl, info.AdditionalAttributes.Select(a => gen.Attribute(a)));

                    // ==> [OperationContract(..., AsyncPattern = true)]
                    SyntaxNode asyncContractAttribute = gen.AddAttributeArguments(
                        gen.Attribute(info.OperationContractAttribute),
                        new[] { gen.AttributeArgument("AsyncPattern", gen.TrueLiteralExpression()) });

                    asyncDecl = gen.AddAttributes(asyncDecl, asyncContractAttribute);
                }

                declarations = declarations.Add((MethodDeclarationSyntax)asyncDecl.AddNewLineTrivia().AddNewLineTrivia());
            }

            return declarations;
        }
        /// <summary>
        /// Generates a WCF proxy class that implements <paramref name="sourceProxyInterface"/> and derives from
        /// <c>System.ServiceModel.ClientBase&lt;T&gt;</c>, or from <c>System.ServiceModel.DuplexClientBase&lt;T&gt;</c>
        /// when the interface's ServiceContractAttribute specifies a CallbackContract. Every non-private base-class
        /// constructor is mirrored. Every operation contract method forwards its arguments to <c>base.Channel</c>.
        /// </summary>
        /// <param name="semanticModel">Semantic model whose compilation is used to resolve the WCF types.</param>
        /// <param name="generator">Syntax generator used to build the declaration.</param>
        /// <param name="sourceProxyInterface">The service contract interface; must carry ServiceContractAttribute.</param>
        /// <param name="name">Class name, or null to derive it from the interface name (leading "I" dropped, "Proxy" appended).</param>
        /// <param name="accessibility">Accessibility of the generated class.</param>
        /// <param name="suppressWarningComments">When true, no warning comments are added to the generated members.</param>
        /// <param name="constructorAccessibility">Accessibility applied to every generated constructor.</param>
        /// <param name="sourceConstructors">Receives the non-private base-class constructors that were mirrored.</param>
        /// <returns>The generated class declaration.</returns>
        /// <exception cref="CodeGeneratorException">The interface is not decorated with ServiceContractAttribute.</exception>
        public ClassDeclarationSyntax GenerateProxyClass(SemanticModel semanticModel, SyntaxGenerator generator, INamedTypeSymbol sourceProxyInterface, string name, Accessibility accessibility, bool suppressWarningComments, MemberAccessibility constructorAccessibility, out IEnumerable <IMethodSymbol> sourceConstructors)
        {
            if (name == null)
            {
                // Derive "<Name>Proxy" from the interface name, dropping a leading "I".
                // Identifiers are compared ordinally, never culture-sensitively.
                string stem = sourceProxyInterface.Name.StartsWith("I", StringComparison.Ordinal)
                    ? sourceProxyInterface.Name.Substring(1)
                    : sourceProxyInterface.Name;

                name = stem + "Proxy";
            }

            var compilation = semanticModel.Compilation;

            // The interface must be a WCF service contract. The comparison starts from the
            // resolved type because AttributeData.AttributeClass is nullable in Roslyn.
            ITypeSymbol   serviceContractAttributeType = compilation.RequireTypeByMetadataName("System.ServiceModel.ServiceContractAttribute");
            AttributeData serviceContractAttribute     = sourceProxyInterface.GetAttributes().FirstOrDefault(attr => serviceContractAttributeType.Equals(attr.AttributeClass));

            if (serviceContractAttribute == null)
            {
                throw new CodeGeneratorException(sourceProxyInterface, $"The interface {sourceProxyInterface.Name} is not decorated with ServiceContractAttribute.");
            }

            // Resolve the callback contract, if any. A default KeyValuePair (null key)
            // means the CallbackContract named argument is absent.
            var callbackContractArg = serviceContractAttribute.NamedArguments.FirstOrDefault(arg => arg.Key.Equals("CallbackContract"));

            ITypeSymbol callbackContractType = callbackContractArg.Key != null
                ? callbackContractArg.Value.Value as ITypeSymbol
                : null;

            // Duplex contracts need DuplexClientBase<T>; all others use ClientBase<T>.
            string baseTypeMetadataName = callbackContractType != null
                ? "System.ServiceModel.DuplexClientBase`1"
                : "System.ServiceModel.ClientBase`1";

            INamedTypeSymbol baseType = compilation.RequireTypeByMetadataName(baseTypeMetadataName).Construct(sourceProxyInterface);

            // Create class declaration
            SyntaxNode targetClass = generator.ClassDeclaration(name, accessibility: accessibility, baseType: generator.TypeExpression(baseType), interfaceTypes: new[] { generator.TypeExpression(sourceProxyInterface) });

            targetClass = generator.AddWarningCommentIf(!suppressWarningComments, targetClass);

            // Mirror every non-private base-class constructor:
            // ==> ctor(args) : base(args) { }
            sourceConstructors = baseType.Constructors.Where(ctor => ctor.DeclaredAccessibility != Accessibility.Private).ToImmutableArray();

            foreach (var baseCtor in sourceConstructors)
            {
                var targetCtor = generator.ConstructorDeclaration(baseCtor, baseCtor.Parameters.Select(p => generator.Argument(generator.IdentifierName(p.Name))));

                targetCtor  = generator.AddWarningCommentIf(!suppressWarningComments, targetCtor);
                targetCtor  = generator.WithAccessibility(targetCtor, ToAccessibility(constructorAccessibility));
                targetClass = generator.AddMembers(targetClass, targetCtor.AddNewLineTrivia());
            }

            // Forward every operation contract method to the channel:
            // ==> [return] base.Channel.Method(a, ref b, out c);
            foreach (IMethodSymbol sourceMethod in GetOperationContractMethods(compilation, sourceProxyInterface))
            {
                SyntaxNode targetMethod = generator.MethodDeclaration(sourceMethod);

                targetMethod = generator.AddWarningCommentIf(!suppressWarningComments, targetMethod);
                targetMethod = generator.WithModifiers(targetMethod, DeclarationModifiers.None);
                targetMethod = targetMethod.AddNewLineTrivia().AddNewLineTrivia();

                // Each argument carries its parameter's ref kind. A bare identifier is only valid
                // for by-value parameters; ref/out parameters need the matching modifier at the call site.
                SyntaxNode[] arguments = sourceMethod.Parameters
                    .Select(p => generator.Argument(p.RefKind, generator.IdentifierName(p.Name)))
                    .ToArray();

                var invocation = generator.InvocationExpression(
                    generator.MemberAccessExpression(
                        generator.MemberAccessExpression(generator.BaseExpression(), "Channel"),
                        sourceMethod.Name),
                    arguments);

                bool isVoid = sourceMethod.ReturnType.SpecialType == SpecialType.System_Void;

                SyntaxNode statement = isVoid
                    ? generator.ExpressionStatement(invocation)
                    : generator.ReturnStatement(invocation);

                targetMethod = generator.WithStatements(targetMethod, new[] { statement });
                targetClass  = generator.AddMembers(targetClass, targetMethod.AddNewLineTrivia());
            }

            return (ClassDeclarationSyntax)targetClass;
        }
        // Example #3
        // 0
        /// <summary>
        /// Generates a sealed client class that derives from <see cref="MarshalByRefObject"/> and implements
        /// <see cref="IDisposable"/> and <paramref name="proxyInterface"/>. The class holds a proxy-factory delegate
        /// (<c>Func&lt;TProxy&gt;</c>) and a cached-proxy field. Each operation contract method is forwarded to the
        /// proxy inside a try/catch that calls <c>CloseProxy(false)</c> and rethrows. Private helper methods
        /// (get/close/ensure proxy) and dispose methods are appended.
        /// </summary>
        /// <param name="semanticModel">Semantic model whose compilation is used to resolve referenced types.</param>
        /// <param name="gen">Syntax generator used to build the declaration.</param>
        /// <param name="proxyInterface">The proxy (service contract) interface the client implements.</param>
        /// <param name="name">
        /// Class name, or null to derive it from the interface name: a leading "I" is dropped, a trailing
        /// "Proxy" is dropped, and "Client" is appended unless already present.
        /// </param>
        /// <param name="accessibility">Accessibility of the generated class.</param>
        /// <param name="includeCancellableAsyncMethods">
        /// When true, every Task-returning method (per ReturnsTask) also gets an overload taking a CancellationToken.
        /// </param>
        /// <param name="suppressWarningComments">When true, no warning comments are added to generated members.</param>
        /// <param name="constructorAccessibility">Accessibility of the public-facing constructors.</param>
        /// <param name="withInternalProxy">
        /// When true, a private nested proxy class (see GenerateProxyClass) is emitted, the factory constructor is made
        /// private, and one constructor per proxy constructor is generated that chains to it.
        /// </param>
        /// <returns>An already-completed task whose result is the generated class declaration.</returns>
        public Task <ClassDeclarationSyntax> GenerateClientClass(SemanticModel semanticModel, SyntaxGenerator gen, INamedTypeSymbol proxyInterface, string name, Accessibility accessibility, bool includeCancellableAsyncMethods, bool suppressWarningComments, MemberAccessibility constructorAccessibility, bool withInternalProxy)
        {
            if (name == null)
            {
                // Start from the interface name, dropping a leading "I" when present. Falling back to the
                // full name keeps 'name' non-null for interfaces without the prefix (previously a NullReferenceException).
                name = proxyInterface.Name.StartsWith("I", StringComparison.Ordinal)
                    ? proxyInterface.Name.Substring(1)
                    : proxyInterface.Name;

                if (name.EndsWith("Proxy", StringComparison.Ordinal))
                {
                    name = name.Substring(0, name.Length - "Proxy".Length);
                }

                if (!name.EndsWith("Client", StringComparison.Ordinal))
                {
                    name = name + "Client";
                }
            }

            var compilation = semanticModel.Compilation;

            SyntaxNode targetClass = gen.ClassDeclaration(name,
                                                          baseType: gen.TypeExpression(compilation.RequireType <MarshalByRefObject>()),
                                                          accessibility: accessibility,
                                                          modifiers: DeclarationModifiers.Sealed);

            targetClass = gen.AddWarningCommentIf(!suppressWarningComments, targetClass);

            targetClass = gen.AddInterfaceType(targetClass, gen.TypeExpression(compilation.GetSpecialType(SpecialType.System_IDisposable)));
            targetClass = gen.AddInterfaceType(targetClass, gen.TypeExpression(proxyInterface));

            // Materialized once: enumerated for the name table and again to generate the methods.
            IEnumerable <IMethodSymbol> methods = GetOperationContractMethods(compilation, proxyInterface).ToArray();

            // Seeded with the contract method names and the class name; presumably used to choose
            // non-colliding names for the generated members (nameTable[MemberNames.X]).
            GenerationNameTable nameTable = new GenerationNameTable(methods.Select(m => m.Name).Concat(new[] { name }));


            #region Private Fields

            // ==> private IProxy m_cachedProxy;
            SyntaxNode cachedProxyField =
                gen.FieldDeclaration(nameTable[MemberNames.CachedProxyField], gen.TypeExpression(proxyInterface), Accessibility.Private, DeclarationModifiers.None)
                .PrependLeadingTrivia(gen.CreateRegionTrivia("Private Fields"));

            targetClass = gen.AddMembers(targetClass, cachedProxyField);

            // ==> private readonly Func<IProxy> m_proxyFactory;
            SyntaxNode proxyFactoryTypeExpression = gen.TypeExpression(compilation.RequireTypeByMetadataName("System.Func`1").Construct(proxyInterface));

            targetClass = gen.AddMembers(targetClass, gen.FieldDeclaration(nameTable[MemberNames.ProxyFactoryField], proxyFactoryTypeExpression, Accessibility.Private, DeclarationModifiers.ReadOnly)
                                         .AddTrailingTrivia(gen.CreateEndRegionTrivia()).AddNewLineTrivia());

            #endregion


            #region Constructors

            // ==> Client(Func<IProxy> proxyFactory)
            // The factory constructor is private when an internal proxy is generated; the constructors
            // mirrored from the proxy class below are then the ones exposed to callers.
            SyntaxNode constructor = gen.ConstructorDeclaration(
                parameters: new[] { gen.ParameterDeclaration("proxyFactory", proxyFactoryTypeExpression) },
                accessibility: withInternalProxy ? Accessibility.Private : ToAccessibility(constructorAccessibility)
                );

            constructor = gen.AddWarningCommentIf(!suppressWarningComments, constructor);
            constructor = constructor.PrependLeadingTrivia(gen.CreateRegionTrivia("Constructors"));

            constructor = gen.WithStatements(constructor, new[]
            {
                // ==> if (proxyFactory == null)
                // ==>   throw new System.ArgumentNullException("proxyFactory");
                gen.ThrowIfNullStatement("proxyFactory"),

                // ==> this.m_proxyFactory = proxyFactory;
                gen.AssignmentStatement(
                    gen.MemberAccessExpression(
                        gen.ThisExpression(),
                        gen.IdentifierName(nameTable[MemberNames.ProxyFactoryField])),
                    gen.IdentifierName("proxyFactory")
                    )
            }).AddNewLineTrivia();

            // Without an internal proxy this is the only constructor, so it closes the region.
            if (!withInternalProxy)
            {
                constructor = constructor.AddTrailingTrivia(gen.CreateEndRegionTrivia()).AddNewLineTrivia();
            }

            targetClass = gen.AddMembers(targetClass, constructor);

            ClassDeclarationSyntax proxyClass = null;
            if (withInternalProxy)
            {
                IEnumerable <IMethodSymbol> ctors;
                proxyClass = GenerateProxyClass(semanticModel, gen, proxyInterface, nameTable[MemberNames.ProxyClass], Accessibility.Private, suppressWarningComments, MemberAccessibility.Public, out ctors)
                             .PrependLeadingTrivia(gen.CreateRegionTrivia("Proxy Class").Insert(0, gen.NewLine()))
                             .AddTrailingTrivia(gen.CreateEndRegionTrivia());

                // Generate one constructor for each of the proxy's constructors:
                // ==> Client(args) : this(() => new Proxy(args)) { }
                foreach (var ctorEntry in ctors.AsSmartEnumerable())
                {
                    var ctor       = ctorEntry.Value;
                    var targetCtor = gen.ConstructorDeclaration(ctor);

                    var lambda = gen.ValueReturningLambdaExpression(
                        gen.ObjectCreationExpression(gen.IdentifierName(gen.GetName(proxyClass)), ctor.Parameters.Select(p => gen.IdentifierName(p.Name)))
                        );

                    targetCtor = gen.WithThisConstructorInitializer(targetCtor, new[] { lambda });

                    targetCtor = gen.AddWarningCommentIf(!suppressWarningComments, targetCtor);
                    targetCtor = gen.WithAccessibility(targetCtor, ToAccessibility(constructorAccessibility));

                    // The last mirrored constructor closes the "Constructors" region.
                    if (ctorEntry.IsLast)
                    {
                        targetCtor = targetCtor.AddTrailingTrivia(gen.CreateEndRegionTrivia()).AddNewLineTrivia();
                    }

                    targetClass = gen.AddMembers(targetClass, targetCtor.AddNewLineTrivia());
                }
            }

            #endregion

            #region Operation Contract Methods

            // Catch clause shared by every forwarded call:
            // ==> catch
            // ==> {
            // ==>    this.CloseProxy(false);
            // ==>    throw;
            // ==> }
            var catchAndCloseProxyStatement = gen.CatchClause(new SyntaxNode[]
            {
                // ==> this.CloseProxy(false);
                gen.ExpressionStatement(
                    gen.InvocationExpression(
                        gen.MemberAccessExpression(
                            gen.ThisExpression(),
                            nameTable[MemberNames.CloseProxyMethod]
                            ),
                        gen.FalseLiteralExpression()
                        )
                    ),

                // ==> throw;
                gen.ThrowStatement()
            });


            foreach (var sourceMethodEntry in methods.AsSmartEnumerable())
            {
                var sourceMethod = sourceMethodEntry.Value;

                // The method's parameter names are pushed as a scope, presumably so generated names avoid them.
                using (nameTable.PushScope(sourceMethod.Parameters.Select(p => p.Name)))
                {
                    bool isAsync                 = ReturnsTask(compilation, sourceMethod);
                    bool emitCancellableOverload = isAsync && includeCancellableAsyncMethods;

                    SyntaxNode targetMethod = gen.MethodDeclaration(sourceMethod);

                    if (sourceMethodEntry.IsFirst)
                    {
                        targetMethod = targetMethod.PrependLeadingTrivia(gen.CreateRegionTrivia("Contract Methods")).AddLeadingTrivia(gen.NewLine());
                    }

                    targetMethod = gen.AddWarningCommentIf(!suppressWarningComments, targetMethod);

                    targetMethod = gen.WithModifiers(targetMethod, isAsync ? DeclarationModifiers.Async : DeclarationModifiers.None);

                    // ==> try { <proxy variable>; <forwarded call>; } catch { ... }
                    targetMethod = gen.WithStatements(targetMethod, new SyntaxNode[]
                    {
                        gen.TryCatchStatement(new SyntaxNode[]
                        {
                            CreateProxyVaraibleDeclaration(gen, nameTable, isAsync),
                            CreateProxyInvocationStatement(compilation, gen, nameTable, sourceMethod)
                        }, new SyntaxNode[]
                        {
                            catchAndCloseProxyStatement
                        })
                    });

                    targetMethod = targetMethod.AddNewLineTrivia();

                    // The region is closed by whichever member is emitted last for the last method.
                    if (sourceMethodEntry.IsLast && !emitCancellableOverload)
                    {
                        targetMethod = targetMethod.AddTrailingTrivia(gen.CreateEndRegionTrivia()).AddNewLineTrivia();
                    }

                    targetClass = gen.AddMembers(targetClass, targetMethod);

                    if (emitCancellableOverload)
                    {
                        // Same method with an extra trailing CancellationToken parameter.
                        targetMethod = gen.MethodDeclaration(sourceMethod);
                        targetMethod = gen.AddParameters(targetMethod, new[] { gen.ParameterDeclaration(nameTable[MemberNames.CancellationTokenParameter], gen.TypeExpression(compilation.RequireType <CancellationToken>())) });
                        targetMethod = gen.WithModifiers(targetMethod, DeclarationModifiers.Async);

                        targetMethod = gen.WithStatements(targetMethod, new SyntaxNode[]
                        {
                            // ==> try { <proxy variable>; <forwarded cancellable call>; } catch { ... }
                            gen.TryCatchStatement(new SyntaxNode[]
                            {
                                CreateProxyVaraibleDeclaration(gen, nameTable, isAsync),
                                CreateCancellableProxyInvocationStatement(compilation, gen, nameTable, sourceMethod)
                            }, new SyntaxNode[]
                            {
                                catchAndCloseProxyStatement
                            })
                        });

                        targetMethod = gen.AddWarningCommentIf(!suppressWarningComments, targetMethod.AddNewLineTrivia());

                        if (sourceMethodEntry.IsLast)
                        {
                            targetMethod = targetMethod.AddTrailingTrivia(gen.CreateEndRegionTrivia()).AddNewLineTrivia();
                        }

                        targetClass = gen.AddMembers(targetClass, targetMethod);
                    }
                }
            }

            #endregion

            #region Internal Methods

            // Private helpers (emitted inside a "Private Methods" region), followed by the dispose members.
            targetClass = gen.AddMembers(targetClass, gen.AddWarningCommentIf(!suppressWarningComments, CreateGetProxyMethod(compilation, gen, proxyInterface, nameTable, false).AddLeadingTrivia(gen.CreateRegionTrivia("Private Methods")).AddNewLineTrivia()));
            targetClass = gen.AddMembers(targetClass, gen.AddWarningCommentIf(!suppressWarningComments, CreateGetProxyMethod(compilation, gen, proxyInterface, nameTable, true).AddNewLineTrivia()));
            targetClass = gen.AddMembers(targetClass, gen.AddWarningCommentIf(!suppressWarningComments, CreateStaticCloseProxyMethod(compilation, gen, nameTable, false).AddNewLineTrivia()));
            targetClass = gen.AddMembers(targetClass, gen.AddWarningCommentIf(!suppressWarningComments, CreateStaticCloseProxyMethod(compilation, gen, nameTable, true).AddNewLineTrivia()));
            targetClass = gen.AddMembers(targetClass, gen.AddWarningCommentIf(!suppressWarningComments, CreateCloseProxyMethod(compilation, gen, nameTable, false).AddNewLineTrivia()));
            targetClass = gen.AddMembers(targetClass, gen.AddWarningCommentIf(!suppressWarningComments, CreateCloseProxyMethod(compilation, gen, nameTable, true).AddNewLineTrivia()));
            targetClass = gen.AddMembers(targetClass, gen.AddWarningCommentIf(!suppressWarningComments, CreateEnsureProxyMethod(compilation, gen, nameTable, false).AddNewLineTrivia()));
            targetClass = gen.AddMembers(targetClass, gen.AddWarningCommentIf(!suppressWarningComments, CreateEnsureProxyMethod(compilation, gen, nameTable, true).AddTrailingTrivia(gen.CreateEndRegionTrivia()).AddNewLineTrivia()));
            targetClass = gen.AddMembers(targetClass, CreateDisposeMethods(compilation, gen, nameTable, suppressWarningComments));

            if (withInternalProxy)
            {
                targetClass = gen.AddMembers(targetClass, proxyClass);
            }

            #endregion


            targetClass = AddGeneratedCodeAttribute(gen, targetClass);
            return Task.FromResult((ClassDeclarationSyntax)targetClass);
        }