Example #1
0
 /// <summary>
 /// Initializes an instance for a type that is awaited through a coercion step.
 /// </summary>
 /// <param name="coercerExpression">Value stored in <c>CoercerExpression</c>.</param>
 /// <param name="coercerResultType">Value stored in <c>CoercerResultType</c>.</param>
 /// <param name="coercedAwaitableInfo">Value stored in <c>AwaitableInfo</c>.</param>
 public CoercedAwaitableInfo(
     Expression coercerExpression,
     Type coercerResultType,
     AwaitableInfo coercedAwaitableInfo)
 {
     AwaitableInfo = coercedAwaitableInfo;
     CoercerResultType = coercerResultType;
     CoercerExpression = coercerExpression;
 }
    /// <summary>
    /// Determines whether <paramref name="type"/> can be awaited, either directly or after being coerced.
    /// </summary>
    /// <param name="type">The type to inspect.</param>
    /// <param name="info">
    /// On success, the description of how to await the type; otherwise the default value.
    /// </param>
    /// <returns><c>true</c> when the type is awaitable directly or via coercion; otherwise <c>false</c>.</returns>
    public static bool IsTypeAwaitable(
        [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicMethods | DynamicallyAccessedMemberTypes.NonPublicMethods | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.NonPublicProperties)] Type type,
        out CoercedAwaitableInfo info)
    {
        // First choice: the type is awaitable as-is.
        if (AwaitableInfo.IsTypeAwaitable(type, out var direct))
        {
            info = new CoercedAwaitableInfo(direct);
            return true;
        }

        // Second choice: coerce it into something awaitable. Only FSharpAsync<T> is supported for now.
        var canCoerce = ObjectMethodExecutorFSharpSupport.TryBuildCoercerFromFSharpAsyncToAwaitable(
            type,
            out var coercer,
            out var coercedType);

        if (canCoerce && AwaitableInfo.IsTypeAwaitable(coercedType, out var coercedInfo))
        {
            info = new CoercedAwaitableInfo(coercer, coercedType, coercedInfo);
            return true;
        }

        info = default;
        return false;
    }
    /// <summary>
    /// Reports whether <paramref name="type"/> is awaitable according to <c>AwaitableInfo.IsTypeAwaitable</c>.
    /// </summary>
    /// <param name="type">The type to test; <c>null</c> and <c>void</c> are never awaitable.</param>
    /// <returns><c>true</c> when the type is awaitable; otherwise <c>false</c>.</returns>
    public static bool IsAwaitable(this Type type)
    {
        // Only real, non-void types are worth inspecting.
        var isCandidate = type != null && type != typeof(void);

        return isCandidate && AwaitableInfo.IsTypeAwaitable(type, out _);
    }
Example #4
0
        /// <summary>
        /// Determines whether <paramref name="type"/> is directly awaitable.
        /// </summary>
        /// <param name="type">The type to inspect.</param>
        /// <param name="info">
        /// On success, wraps the awaitable description of <paramref name="type"/>; otherwise the default value.
        /// </param>
        /// <returns><c>true</c> when the type is awaitable; otherwise <c>false</c>.</returns>
        public static bool IsTypeAwaitable(Type type, out CoercedAwaitableInfo info)
        {
            var awaitable = AwaitableInfo.IsTypeAwaitable(type, out var direct);

            info = awaitable
                ? new CoercedAwaitableInfo(direct)
                : default(CoercedAwaitableInfo);

            return awaitable;
        }
Example #5
0
        /// <summary>
        /// Visits the operand, the type and the awaitable-pattern symbols of an await expression and
        /// returns the node updated with the visited parts.
        /// </summary>
        /// <param name="node">The await expression to visit.</param>
        /// <returns>The result of <c>node.Update</c> with the visited operand, awaitable info and type.</returns>
        public override BoundNode VisitAwaitExpression(BoundAwaitExpression node)
        {
            BoundExpression visitedOperand = (BoundExpression)this.Visit(node.Expression);
            TypeSymbol visitedType = this.VisitType(node.Type);

            // Visit the three pattern members in the same order they are passed to Update.
            AwaitableInfo original = node.AwaitableInfo;
            var getAwaiter = VisitMethodSymbol(original.GetAwaiter);
            var isCompleted = VisitPropertySymbol(original.IsCompleted);
            var getResult = VisitMethodSymbol(original.GetResult);
            var visitedInfo = original.Update(getAwaiter, isCompleted, getResult);

            return node.Update(visitedOperand, visitedInfo, visitedType);
        }
    /// <summary>
    /// Lets the action's return type contribute endpoint metadata to every selector of the action,
    /// provided that type implements <see cref="IEndpointMetadataProvider"/>.
    /// </summary>
    /// <param name="action">The action whose return type and selectors are inspected.</param>
    private void ApplyReturnTypeMetadata(ActionModel action)
    {
        // For awaitable return types, look at the awaited result type instead.
        var effectiveType = action.ActionMethod.ReturnType;
        if (AwaitableInfo.IsTypeAwaitable(effectiveType, out var awaitable))
        {
            effectiveType = awaitable.ResultType;
        }

        if (effectiveType is null || !typeof(IEndpointMetadataProvider).IsAssignableFrom(effectiveType))
        {
            return;
        }

        // The single-element argument array is allocated on first use and reused for every selector.
        object?[]? arguments = null;

        for (var index = 0; index < action.Selectors.Count; index++)
        {
            var context = new EndpointMetadataContext(action.ActionMethod, action.Selectors[index].EndpointMetadata, _serviceProvider);
            arguments ??= new object[1];
            arguments[0] = context;
            PopulateMetadataForEndpointMethod.MakeGenericMethod(effectiveType).Invoke(null, arguments);
        }
    }
Example #7
0
        /// <summary>
        /// Emits the source of an async wrapper method for <paramref name="method"/>. The wrapper creates the
        /// handler, declares one local per handler parameter, calls the handler method and, for non-void
        /// results, emits a call that executes the result against the <c>httpContext</c>.
        /// </summary>
        /// <param name="method">Model of the handler method the wrapper is generated for.</param>
        private void Generate(MethodModel method)
        {
            // [DebuggerStepThrough]
            WriteLine($"[{S(typeof(DebuggerStepThroughAttribute))}]");
            WriteLine($"public async {S(typeof(Task))} {method.MethodInfo.Name}({S(typeof(HttpContext))} httpContext)");
            WriteLine("{");
            Indent();
            var ctors = _model.HandlerType.GetConstructors();

            // Only a single parameterless public constructor can be called directly. Every other shape,
            // including "no public constructor" (where indexing ctors[0] would throw), defers to the factory.
            if (ctors.Length != 1 || ctors[0].GetParameters().Length > 0)
            {
                // Lazy, defer to DI system if
                WriteLine($"var handler = ({S(_model.HandlerType)})_factory(httpContext.RequestServices, {S(typeof(Array))}.Empty<{S(typeof(object))}>());");
            }
            else
            {
                WriteLine($"var handler = new {S(_model.HandlerType)}();");
            }

            // Declare locals
            var hasFromBody = false;
            var hasFromForm = false;

            foreach (var parameter in method.Parameters)
            {
                // Underscores are doubled so the "arg_" prefix cannot collide with a parameter name.
                var parameterName = "arg_" + parameter.Name.Replace("_", "__");
                if (parameter.ParameterType.Equals(typeof(HttpContext)))
                {
                    WriteLine($"var {parameterName} = httpContext;");
                }
                else if (parameter.ParameterType.Equals(typeof(IFormCollection)))
                {
                    WriteLine($"var {parameterName} = await httpContext.Request.ReadFormAsync();");
                }
                else if (parameter.FromRoute != null)
                {
                    GenerateConvert(parameterName, parameter.ParameterType, parameter.FromRoute, "httpContext.Request.RouteValues", nullable: true);
                }
                else if (parameter.FromQuery != null)
                {
                    GenerateConvert(parameterName, parameter.ParameterType, parameter.FromQuery, "httpContext.Request.Query");
                }
                else if (parameter.FromHeader != null)
                {
                    GenerateConvert(parameterName, parameter.ParameterType, parameter.FromHeader, "httpContext.Request.Headers");
                }
                else if (parameter.FromServices)
                {
                    WriteLine($"var {parameterName} = httpContext.RequestServices.GetRequiredService<{S(parameter.ParameterType)}>();");
                }
                else if (parameter.FromForm != null)
                {
                    // The form is read once and shared by every form-bound parameter.
                    if (!hasFromForm)
                    {
                        WriteLine($"var formCollection = await httpContext.Request.ReadFormAsync();");
                        hasFromForm = true;
                    }
                    GenerateConvert(parameterName, parameter.ParameterType, parameter.FromForm, "formCollection");
                }
                else if (parameter.FromBody)
                {
                    // The request reader is resolved once and shared by every body-bound parameter.
                    if (!hasFromBody)
                    {
                        WriteLine($"var reader = httpContext.RequestServices.GetService<{S(typeof(IHttpRequestReader))}>() ?? _requestReader;");
                        hasFromBody = true;
                    }

                    if (!parameter.ParameterType.Equals(typeof(JsonElement)))
                    {
                        FromBodyTypes.Add(parameter.ParameterType);
                    }

                    WriteLine($"var {parameterName} = ({S(parameter.ParameterType)})await reader.ReadAsync(httpContext, typeof({S(parameter.ParameterType)}));");
                }
                else
                {
                    // No binding source: still declare the local, because the call emitted below always
                    // passes one argument per parameter and would otherwise reference an undeclared name.
                    WriteLine($"{S(parameter.ParameterType)} {parameterName} = default;");
                }
            }

            AwaitableInfo awaitableInfo = default;

            // Populate locals
            if (method.MethodInfo.ReturnType.Equals(typeof(void)))
            {
                Write("");
            }
            else
            {
                if (AwaitableInfo.IsTypeAwaitable(method.MethodInfo.ReturnType, out awaitableInfo))
                {
                    if (awaitableInfo.ResultType.Equals(typeof(void)))
                    {
                        Write("await ");
                    }
                    else
                    {
                        Write("var result = await ");
                    }
                }
                else
                {
                    Write("var result = ");
                }
            }
            WriteNoIndent($"handler.{method.MethodInfo.Name}(");
            bool first = true;

            foreach (var parameter in method.Parameters)
            {
                var parameterName = "arg_" + parameter.Name.Replace("_", "__");
                if (!first)
                {
                    WriteNoIndent(", ");
                }
                WriteNoIndent(parameterName);
                first = false;
            }
            WriteLineNoIndent(");");

            // ResultType is null when the return type was not awaitable (awaitableInfo is still default).
            var unwrappedType = awaitableInfo.ResultType ?? method.MethodInfo.ReturnType;

            if (_metadataLoadContext.Resolve <Result>().IsAssignableFrom(unwrappedType))
            {
                WriteLine("await result.ExecuteAsync(httpContext);");
            }
            else if (!unwrappedType.Equals(typeof(void)))
            {
                WriteLine($"await new {S(typeof(ObjectResult))}(result).ExecuteAsync(httpContext);");
            }
            Unindent();
            WriteLine("}");
            WriteLine("");
        }
 /// <summary>
 /// Initializes an instance for a type that is awaited without any coercion.
 /// </summary>
 /// <param name="awaitableInfo">Value stored in <c>AwaitableInfo</c>.</param>
 public CoercedAwaitableInfo(AwaitableInfo awaitableInfo)
 {
     // No coercer is involved, so both coercer members are left null.
     CoercerResultType = null;
     CoercerExpression = null;
     AwaitableInfo = awaitableInfo;
 }
        // Expression tree impl
        /// <summary>
        /// Builds one compiled <see cref="RequestDelegate"/> per routable method of
        /// <paramref name="handlerType"/> and maps it on <paramref name="routes"/>.
        /// </summary>
        /// <param name="handlerType">Handler type whose methods become endpoints.</param>
        /// <param name="routes">
        /// Route builder the endpoints are mapped on; an <see cref="IHttpRequestReader"/> is resolved from its
        /// service provider.
        /// </param>
        /// <exception cref="InvalidOperationException">
        /// A method has more than one FromBody parameter, or combines a FromBody parameter with form-bound
        /// parameters (in either declaration order).
        /// </exception>
        /// <exception cref="NotSupportedException">
        /// A method returns an awaitable that is not <c>Task</c>, <c>Task&lt;T&gt;</c> or <c>ValueTask&lt;T&gt;</c>.
        /// </exception>
        internal static void Build(Type handlerType, IEndpointRouteBuilder routes)
        {
            var model = HttpModel.FromType(handlerType);

            ObjectFactory factory = null;

            // REVIEW: Should this be lazy?
            var httpRequestReader = routes.ServiceProvider.GetRequiredService<IHttpRequestReader>();

            foreach (var method in model.Methods)
            {
                // Nothing to route to
                if (method.RoutePattern == null)
                {
                    continue;
                }

                var needForm = false;
                var needBody = false;
                Type bodyType = null;
                // Non void return type

                // Task Invoke(HttpContext httpContext)
                // {
                //     // The type is activated via DI if it has args
                //     return ExecuteResultAsync(new THttpHandler(...).Method(..), httpContext);
                // }

                // void return type

                // Task Invoke(HttpContext httpContext)
                // {
                //     new THttpHandler(...).Method(..)
                //     return Task.CompletedTask;
                // }

                var httpContextArg = Expression.Parameter(typeof(HttpContext), "httpContext");
                // This argument represents the deserialized body returned from IHttpRequestReader
                // when the method has a FromBody attribute declared
                var deserializedBodyArg = Expression.Parameter(typeof(object), "bodyValue");

                var requestServicesExpr = Expression.Property(httpContextArg, nameof(HttpContext.RequestServices));

                // Fast path: We can skip the activator if there's only a default ctor with 0 args
                var ctors = handlerType.GetConstructors();

                // Stays null for static methods, which Expression.Call treats as "no instance".
                Expression httpHandlerExpression = null;

                if (method.MethodInfo.IsStatic)
                {
                    // Do nothing
                }
                else if (ctors.Length == 1 && ctors[0].GetParameters().Length == 0)
                {
                    httpHandlerExpression = Expression.New(ctors[0]);
                }
                else
                {
                    // Create a factory lazily for this handlerType
                    if (factory == null)
                    {
                        factory = ActivatorUtilities.CreateFactory(handlerType, Type.EmptyTypes);
                    }

                    // This invokes the cached factory to create the instance then casts it to the target type
                    var invokeFactoryExpr = Expression.Invoke(Expression.Constant(factory), requestServicesExpr, Expression.Constant(null, typeof(object[])));
                    httpHandlerExpression = Expression.Convert(invokeFactoryExpr, handlerType);
                }

                var args = new List<Expression>();

                var httpRequestExpr = Expression.Property(httpContextArg, nameof(HttpContext.Request));
                foreach (var parameter in method.Parameters)
                {
                    // Parameters without a recognised source are passed as default(T).
                    Expression parameterExpression = Expression.Default(parameter.ParameterType);

                    if (parameter.FromQuery != null)
                    {
                        var queryProperty = Expression.Property(httpRequestExpr, nameof(HttpRequest.Query));
                        parameterExpression = BindArgument(queryProperty, parameter, parameter.FromQuery);
                    }
                    else if (parameter.FromHeader != null)
                    {
                        var headersProperty = Expression.Property(httpRequestExpr, nameof(HttpRequest.Headers));
                        parameterExpression = BindArgument(headersProperty, parameter, parameter.FromHeader);
                    }
                    else if (parameter.FromRoute != null)
                    {
                        var routeValuesProperty = Expression.Property(httpRequestExpr, nameof(HttpRequest.RouteValues));
                        parameterExpression = BindArgument(routeValuesProperty, parameter, parameter.FromRoute);
                    }
                    else if (parameter.FromCookie != null)
                    {
                        var cookiesProperty = Expression.Property(httpRequestExpr, nameof(HttpRequest.Cookies));
                        parameterExpression = BindArgument(cookiesProperty, parameter, parameter.FromCookie);
                    }
                    else if (parameter.FromServices)
                    {
                        parameterExpression = Expression.Call(GetRequiredServiceMethodInfo.MakeGenericMethod(parameter.ParameterType), requestServicesExpr);
                    }
                    else if (parameter.FromForm != null)
                    {
                        // Same rule as in the FromBody branch, checked here too so that the
                        // declaration order of the parameters does not matter.
                        if (needBody)
                        {
                            throw new InvalidOperationException(method.MethodInfo.Name + " cannot mix FromBody and FromForm on the same method.");
                        }

                        needForm = true;

                        var formProperty = Expression.Property(httpRequestExpr, nameof(HttpRequest.Form));
                        parameterExpression = BindArgument(formProperty, parameter, parameter.FromForm);
                    }
                    else if (parameter.FromBody)
                    {
                        if (needBody)
                        {
                            throw new InvalidOperationException(method.MethodInfo.Name + " cannot have more than one FromBody attribute.");
                        }

                        if (needForm)
                        {
                            throw new InvalidOperationException(method.MethodInfo.Name + " cannot mix FromBody and FromForm on the same method.");
                        }

                        needBody = true;
                        bodyType = parameter.ParameterType;
                        parameterExpression = Expression.Convert(deserializedBodyArg, bodyType);
                    }
                    else
                    {
                        if (parameter.ParameterType == typeof(IFormCollection))
                        {
                            // An IFormCollection parameter needs the form just like FromForm does,
                            // so it is subject to the same exclusivity rule.
                            if (needBody)
                            {
                                throw new InvalidOperationException(method.MethodInfo.Name + " cannot mix FromBody and FromForm on the same method.");
                            }

                            needForm = true;

                            parameterExpression = Expression.Property(httpRequestExpr, nameof(HttpRequest.Form));
                        }
                        else if (parameter.ParameterType == typeof(HttpContext))
                        {
                            parameterExpression = httpContextArg;
                        }
                    }

                    args.Add(parameterExpression);
                }

                Expression body = null;

                var methodCall = Expression.Call(httpHandlerExpression, method.MethodInfo, args);

                // Exact request delegate match
                if (method.MethodInfo.ReturnType == typeof(void))
                {
                    var bodyExpressions = new List<Expression>
                    {
                        methodCall,
                        Expression.Property(null, (PropertyInfo)CompletedTaskMemberInfo)
                    };

                    body = Expression.Block(bodyExpressions);
                }
                else if (AwaitableInfo.IsTypeAwaitable(method.MethodInfo.ReturnType, out _))
                {
                    if (method.MethodInfo.ReturnType == typeof(Task))
                    {
                        body = methodCall;
                    }
                    else if (method.MethodInfo.ReturnType.IsGenericType &&
                             method.MethodInfo.ReturnType.GetGenericTypeDefinition() == typeof(Task<>))
                    {
                        var typeArg = method.MethodInfo.ReturnType.GetGenericArguments()[0];

                        if (typeof(Result).IsAssignableFrom(typeArg))
                        {
                            body = Expression.Call(
                                ExecuteTaskResultOfTMethodInfo.MakeGenericMethod(typeArg),
                                methodCall,
                                httpContextArg);
                        }
                        else
                        {
                            // ExecuteTask<T>(handler.Method(..), httpContext);
                            body = Expression.Call(
                                ExecuteTaskOfTMethodInfo.MakeGenericMethod(typeArg),
                                methodCall,
                                httpContextArg);
                        }
                    }
                    else if (method.MethodInfo.ReturnType.IsGenericType &&
                             method.MethodInfo.ReturnType.GetGenericTypeDefinition() == typeof(ValueTask<>))
                    {
                        var typeArg = method.MethodInfo.ReturnType.GetGenericArguments()[0];

                        if (typeof(Result).IsAssignableFrom(typeArg))
                        {
                            body = Expression.Call(
                                ExecuteValueResultTaskOfTMethodInfo.MakeGenericMethod(typeArg),
                                methodCall,
                                httpContextArg);
                        }
                        else
                        {
                            // ExecuteTask<T>(handler.Method(..), httpContext);
                            body = Expression.Call(
                                ExecuteValueTaskOfTMethodInfo.MakeGenericMethod(typeArg),
                                methodCall,
                                httpContextArg);
                        }
                    }
                    else
                    {
                        // TODO: Handle custom awaitables
                        throw new NotSupportedException("Unsupported return type " + method.MethodInfo.ReturnType);
                    }
                }
                else if (typeof(Result).IsAssignableFrom(method.MethodInfo.ReturnType))
                {
                    body = Expression.Call(methodCall, ResultExecuteAsync, httpContextArg);
                }
                else
                {
                    var newObjectResult = Expression.New(ObjectResultCtor, methodCall);
                    body = Expression.Call(newObjectResult, ObjectResultExecuteAsync, httpContextArg);
                }

                RequestDelegate requestDelegate = null;

                if (needBody)
                {
                    // We need to generate the code for reading from the body before calling into the
                    // delegate
                    var lambda = Expression.Lambda<Func<HttpContext, object, Task>>(body, httpContextArg, deserializedBodyArg);
                    var invoker = lambda.Compile();

                    requestDelegate = async httpContext =>
                    {
                        var bodyValue = await httpRequestReader.ReadAsync(httpContext, bodyType);

                        await invoker(httpContext, bodyValue);
                    };
                }
                else if (needForm)
                {
                    var lambda = Expression.Lambda<RequestDelegate>(body, httpContextArg);
                    var invoker = lambda.Compile();

                    requestDelegate = async httpContext =>
                    {
                        // Generating async code would just be insane so if the method needs the form populate it here
                        // so the within the method it's cached
                        await httpContext.Request.ReadFormAsync();

                        await invoker(httpContext);
                    };
                }
                else
                {
                    var lambda = Expression.Lambda<RequestDelegate>(body, httpContextArg);
                    var invoker = lambda.Compile();

                    requestDelegate = invoker;
                }

                // Re-create each attribute recorded in method.Metadata and attach it to the endpoint.
                routes.Map(method.RoutePattern, requestDelegate).Add(b =>
                {
                    foreach (CustomAttributeData item in method.Metadata)
                    {
                        var attr = item.Constructor.Invoke(item.ConstructorArguments.Select(a => a.Value).ToArray());
                        b.Metadata.Add(attr);
                    }
                });
            }
        }
Example #10
0
        /// <summary>
        /// Emits the source of a wrapper method for <paramref name="method"/>. The wrapper creates the handler
        /// (unless the method is static), declares one local per handler parameter, calls the handler method and
        /// emits the call that writes or executes the result. The <c>async</c> modifier is stripped from the
        /// emitted signature when the wrapper contains no <c>await</c>.
        /// </summary>
        /// <param name="method">Model of the handler method the wrapper is generated for.</param>
        private void Generate(MethodModel method)
        {
            // [DebuggerStepThrough]
            WriteLine($"[{typeof(DebuggerStepThroughAttribute)}]");

            // Position of the signature text within the builder, past the current indentation;
            // used further down to cut " async" out of the already-written signature.
            var methodStartIndex = _codeBuilder.Length + 4 * _indent;

            WriteLine($"public async {typeof(Task)} {method.UniqueName}({typeof(HttpContext)} httpContext)");
            WriteLine("{");
            Indent();
            var ctors = _model.HandlerType.GetConstructors();

            if (!method.MethodInfo.IsStatic)
            {
                if (ctors.Length > 1 || ctors[0].GetParameters().Length > 0)
                {
                    // Lazy, defer to DI system if
                    WriteLine($"var handler = ({S(_model.HandlerType)})_factory(httpContext.RequestServices, {typeof(Array)}.Empty<{typeof(object)}>());");
                }
                else
                {
                    WriteLine($"var handler = new {S(_model.HandlerType)}();");
                }
            }

            // Declare locals
            var hasAwait    = false;
            var hasFromForm = false;

            foreach (var parameter in method.Parameters)
            {
                // Underscores are doubled so the "arg_" prefix cannot collide with a parameter name.
                var parameterName = "arg_" + parameter.Name.Replace("_", "__");
                if (parameter.ParameterType.Equals(typeof(HttpContext)))
                {
                    WriteLine($"var {parameterName} = httpContext;");
                }
                else if (parameter.ParameterType.Equals(typeof(IFormCollection)))
                {
                    WriteLine($"var {parameterName} = await httpContext.Request.ReadFormAsync();");
                    hasAwait = true;
                }
                else if (parameter.FromRoute != null)
                {
                    GenerateConvert(parameterName, parameter.ParameterType, parameter.FromRoute, "httpContext.Request.RouteValues", nullable: true);
                }
                else if (parameter.FromQuery != null)
                {
                    GenerateConvert(parameterName, parameter.ParameterType, parameter.FromQuery, "httpContext.Request.Query");
                }
                else if (parameter.FromHeader != null)
                {
                    GenerateConvert(parameterName, parameter.ParameterType, parameter.FromHeader, "httpContext.Request.Headers");
                }
                else if (parameter.FromServices)
                {
                    WriteLine($"var {parameterName} = httpContext.RequestServices.GetRequiredService<{S(parameter.ParameterType)}>();");
                }
                else if (parameter.FromForm != null)
                {
                    // The form is read once and shared by every form-bound parameter.
                    if (!hasFromForm)
                    {
                        WriteLine($"var formCollection = await httpContext.Request.ReadFormAsync();");
                        hasAwait    = true;
                        hasFromForm = true;
                    }
                    GenerateConvert(parameterName, parameter.ParameterType, parameter.FromForm, "formCollection");
                }
                else if (parameter.FromBody)
                {
                    if (!parameter.ParameterType.Equals(typeof(JsonElement)))
                    {
                        FromBodyTypes.Add(parameter.ParameterType);
                    }

                    WriteLine($"var {parameterName} = await httpContext.Request.ReadFromJsonAsync<{S(parameter.ParameterType)}>();");
                    hasAwait = true;
                }
                else
                {
                    // No binding source: declare the local anyway so the call below can reference it.
                    WriteLine($"{S(parameter.ParameterType)} {parameterName} = default;");
                }
            }

            AwaitableInfo awaitableInfo = default;

            // Populate locals
            if (method.MethodInfo.ReturnType.Equals(typeof(void)))
            {
                Write("");
            }
            else
            {
                if (AwaitableInfo.IsTypeAwaitable(method.MethodInfo.ReturnType, out awaitableInfo))
                {
                    if (awaitableInfo.ResultType.Equals(typeof(void)))
                    {
                        // Handing the awaitable straight back only compiles when it is assignable to the
                        // wrapper's Task return type; ValueTask or custom awaitables have to be awaited.
                        // NOTE(review): assumes Resolve<Task>() yields a Task type comparable with ReturnType,
                        // the same way the Resolve<IResult>() check below is used - confirm.
                        if (hasAwait || !_metadataLoadContext.Resolve<Task>().IsAssignableFrom(method.MethodInfo.ReturnType))
                        {
                            Write("await ");
                            hasAwait = true;
                        }
                        else
                        {
                            Write("return ");
                        }
                    }
                    else
                    {
                        Write("var result = await ");
                        hasAwait = true;
                    }
                }
                else
                {
                    Write("var result = ");
                }
            }
            WriteNoIndent($"{(method.MethodInfo.IsStatic ? S(_model.HandlerType) : "handler")}.{method.MethodInfo.Name}(");
            bool first = true;

            foreach (var parameter in method.Parameters)
            {
                var parameterName = "arg_" + parameter.Name.Replace("_", "__");
                if (!first)
                {
                    WriteNoIndent(", ");
                }
                WriteNoIndent(parameterName);
                first = false;
            }
            WriteLineNoIndent(");");

            if (!hasAwait)
            {
                // Remove " async" from method signature.
                // "public" is 6 characters and " async" is the 6 characters that follow it.
                _codeBuilder.Remove(methodStartIndex + 6, 6);
            }

            // Emits the final statement either awaited (async wrapper) or returned (non-async wrapper).
            void AwaitOrReturn(string executeAsync)
            {
                if (hasAwait)
                {
                    Write("await ");
                }
                else
                {
                    Write("return ");
                }

                WriteLineNoIndent(executeAsync);
            }

            // ResultType is null when the return type was not awaitable (awaitableInfo is still default).
            var unwrappedType = awaitableInfo.ResultType ?? method.MethodInfo.ReturnType;

            if (_metadataLoadContext.Resolve <IResult>().IsAssignableFrom(unwrappedType))
            {
                AwaitOrReturn("result.ExecuteAsync(httpContext);");
            }
            else if (unwrappedType.Equals(typeof(string)))
            {
                AwaitOrReturn($"httpContext.Response.WriteAsync(result);");
            }
            else if (!unwrappedType.Equals(typeof(void)))
            {
                AwaitOrReturn($"httpContext.Response.WriteAsJsonAsync(result);");
            }
            else if (!hasAwait && method.MethodInfo.ReturnType.Equals(typeof(void)))
            {
                // If awaitableInfo.ResultType is void, we've already returned the awaitable directly.
                WriteLine($"return {typeof(Task)}.CompletedTask;");
            }

            Unindent();
            WriteLine("}");
            WriteLine("");
        }
Example #11
0
    /// <summary>
    /// Checks, via runtime reflection, whether <paramref name="type"/> exposes the members required to be awaited.
    /// </summary>
    /// <param name="type">The candidate awaitable type.</param>
    /// <param name="awaitableInfo">
    /// On success, the reflected members of the awaitable pattern; on failure, a default value.
    /// </param>
    /// <returns><c>true</c> when every required member was found; otherwise <c>false</c>.</returns>
    public static bool IsTypeAwaitable(Type type, out AwaitableInfo? awaitableInfo)
    {
        // Based on Roslyn code: http://source.roslyn.io/#Microsoft.CodeAnalysis.Workspaces/Shared/Extensions/ISymbolExtensions.cs,db4d48ba694b9347

        // Shape shared by both notification hooks: "void <name>(Action action)", name compared ignoring case.
        static bool IsContinuationHook(MethodInfo candidate, string name)
        {
            return candidate.Name.Equals(name, StringComparison.OrdinalIgnoreCase) &&
                   candidate.ReturnType == typeof(void) &&
                   candidate.GetParameters().Length == 1 &&
                   candidate.GetParameters()[0].ParameterType == typeof(Action);
        }

        // Step 1: the awaitable needs a parameterless "GetAwaiter()".
        var getAwaiter = type.GetRuntimeMethods()
            .Where(m => m.Name.Equals("GetAwaiter", StringComparison.OrdinalIgnoreCase))
            .Where(m => m.GetParameters().Length == 0)
            .FirstOrDefault(m => m.ReturnType != null);

        if (getAwaiter is null)
        {
            awaitableInfo = default;
            return false;
        }

        var awaiter = getAwaiter.ReturnType;

        // Step 2: the awaiter needs a readable "bool IsCompleted" property.
        var isCompleted = awaiter.GetRuntimeProperties()
            .Where(p => p.Name.Equals("IsCompleted", StringComparison.OrdinalIgnoreCase))
            .Where(p => p.PropertyType == typeof(bool))
            .FirstOrDefault(p => p.GetMethod != null);

        if (isCompleted is null)
        {
            awaitableInfo = default(AwaitableInfo);
            return false;
        }

        // Step 3: the awaiter has to implement INotifyCompletion.
        var implemented = awaiter.GetInterfaces();

        if (!implemented.Any(i => i == typeof(INotifyCompletion)))
        {
            awaitableInfo = default(AwaitableInfo);
            return false;
        }

        // INotifyCompletion itself provides "void OnCompleted(Action action)".
        var onCompleted = typeof(INotifyCompletion)
            .GetRuntimeMethods()
            .Single(m => IsContinuationHook(m, "OnCompleted"));

        // Step 4 (optional): ICriticalNotifyCompletion provides "void UnsafeOnCompleted(Action action)".
        MethodInfo? unsafeOnCompleted = implemented.Any(i => i == typeof(ICriticalNotifyCompletion))
            ? typeof(ICriticalNotifyCompletion)
                .GetRuntimeMethods()
                .Single(m => IsContinuationHook(m, "UnsafeOnCompleted"))
            : null;

        // Step 5: the awaiter needs a parameterless "GetResult()" (this name is matched exactly).
        var getResult = awaiter.GetRuntimeMethods()
            .Where(m => m.Name.Equals("GetResult"))
            .FirstOrDefault(m => m.GetParameters().Length == 0);

        if (getResult is null)
        {
            awaitableInfo = default;
            return false;
        }

        awaitableInfo = new AwaitableInfo(
            awaiter,
            isCompleted,
            getResult,
            onCompleted,
            unsafeOnCompleted,
            getResult.ReturnType,
            getAwaiter);
        return true;
    }
Example #12
0
    /// <summary>
    /// Checks whether <paramref name="type"/> exposes the members required to be awaited.
    /// </summary>
    /// <param name="type">The candidate awaitable type.</param>
    /// <param name="awaitableInfo">
    /// On success, the reflected members of the awaitable pattern; on failure, the default value.
    /// </param>
    /// <returns><c>true</c> when every required member was found; otherwise <c>false</c>.</returns>
    public static bool IsTypeAwaitable(
        Type type,
        out AwaitableInfo awaitableInfo)
    {
        // Based on Roslyn code: http://source.roslyn.io/#Microsoft.CodeAnalysis.Workspaces/Shared/Extensions/ISymbolExtensions.cs,db4d48ba694b9347

        // Single exit used by every "not awaitable" outcome.
        static bool NotAwaitable(out AwaitableInfo result)
        {
            result = default(AwaitableInfo);
            return false;
        }

        // The awaitable needs a parameterless "GetAwaiter()".
        var getAwaiter = type.GetMethod("GetAwaiter", Everything, Type.EmptyTypes);
        if (getAwaiter is null)
        {
            return NotAwaitable(out awaitableInfo);
        }

        var awaiter = getAwaiter.ReturnType;

        // The awaiter needs a non-indexed "bool IsCompleted" property.
        var isCompleted = awaiter.GetProperty("IsCompleted", Everything, binder: null, returnType: typeof(bool), types: Type.EmptyTypes, modifiers: null);
        if (isCompleted is null)
        {
            return NotAwaitable(out awaitableInfo);
        }

        // The awaiter has to implement INotifyCompletion ...
        if (!typeof(INotifyCompletion).IsAssignableFrom(awaiter))
        {
            return NotAwaitable(out awaitableInfo);
        }

        // ... which supplies "void OnCompleted(Action action)".
        var onCompleted = INotifyCompletion_OnCompleted;

        // ICriticalNotifyCompletion is optional; when present it supplies "void UnsafeOnCompleted(Action action)".
        MethodInfo? unsafeOnCompleted = typeof(ICriticalNotifyCompletion).IsAssignableFrom(awaiter)
            ? ICriticalNotifyCompletion_UnsafeOnCompleted
            : null;

        // The awaiter needs a parameterless "GetResult()", returning void or a value.
        var getResult = awaiter.GetMethod("GetResult", Everything, Type.EmptyTypes);
        if (getResult is null)
        {
            return NotAwaitable(out awaitableInfo);
        }

        awaitableInfo = new AwaitableInfo(
            awaiter,
            isCompleted,
            getResult,
            onCompleted,
            unsafeOnCompleted,
            getResult.ReturnType,
            getAwaiter);
        return true;
    }