// Verifies that a client draining a large "Connection: close" response at a throttled rate (still far above
// the configured 240 B/s minimum) receives the complete response: no minimum-data-rate violation is logged,
// the connection is stopped exactly once, and the request is never aborted.
public async Task ClientCanReceiveFullConnectionCloseResponseWithoutErrorAtALowDataRate()
        {
            const int chunkSize = 64 * 128 * 1024;
            const int chunkCount = 4;
            var payload = new byte[chunkSize];

            var wasAborted = false;
            var appFinished = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
            var traceMock = new Mock<IKestrelTrace>();

            var serviceContext = new TestServiceContext(LoggerFactory, traceMock.Object);
            serviceContext.ServerOptions.Limits.MinResponseDataRate =
                new MinDataRate(bytesPerSecond: 240, gracePeriod: TimeSpan.FromSeconds(2));
            serviceContext.InitializeHeartbeat();

            // Pin the Date header to a fixed value so the total response size asserted below stays constant.
            var dateHeaders = new DateHeaderValueManager();
            dateHeaders.OnHeartbeat(DateTimeOffset.MinValue);
            serviceContext.DateHeaderValueManager = dateHeaders;

            var endpoint = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0));

            // Writes chunkCount chunks of chunkSize bytes, then signals that the app ran to completion.
            async Task WriteChunks(HttpContext context)
            {
                context.RequestAborted.Register(() => wasAborted = true);

                var written = 0;
                while (written < chunkCount)
                {
                    await context.Response.BodyWriter.WriteAsync(new Memory<byte>(payload), context.RequestAborted);
                    written++;
                }

                appFinished.SetResult(null);
            }

            using (var server = new TestServer(WriteChunks, serviceContext, endpoint))
            {
                using (var connection = server.CreateConnection())
                {
                    // "Connection: close" makes the server end the stream after this response,
                    // which is what lets the stream-completed assertion below finish.
                    await connection.Send(
                        "GET / HTTP/1.1",
                        "Host:",
                        "Connection: close",
                        "",
                        "");

                    await connection.Receive(
                        "HTTP/1.1 200 OK",
                        "Connection: close",
                        $"Date: {dateHeaders.GetDateHeaderValues().String}");

                    // At this rate one chunk takes about four seconds to read, i.e. longer than the 2 second grace period.
                    var throttledBytesPerSecond = chunkSize / 4;

                    // The expected byte count was found empirically (per the original author); it only changes if
                    // response header writing or body chunking changes, because the Date header is constant.
                    await AssertStreamCompletedAtTargetRate(connection.Stream, expectedBytes: 33_553_556, throttledBytesPerSecond);

                    await appFinished.Task.DefaultTimeout();
                }

                await server.StopAsync();
            }

            traceMock.Verify(t => t.ResponseMinimumDataRateNotSatisfied(It.IsAny<string>(), It.IsAny<string>()), Times.Never());
            traceMock.Verify(t => t.ConnectionStop(It.IsAny<string>()), Times.Once());
            Assert.False(wasAborted);
        }
        // Verifies that when an echo app is blocked on both sides (the client keeps uploading a request body
        // but does not read the response until the end of the test), the server logs a minimum response
        // data rate violation, fires RequestAborted, logs ConnectionStop, and the app's CopyToAsync fails
        // with an OperationCanceledException.
        public async Task ConnectionClosedWhenBothRequestAndResponseExperienceBackPressure()
        {
            const int bufferSize   = 65536;
            const int bufferCount  = 100;
            var       responseSize = bufferCount * bufferSize;
            var       buffer       = new byte[bufferSize];

            // Each of these is completed from a server-side callback and awaited by the test body below.
            var responseRateTimeoutMessageLogged = new TaskCompletionSource <object>(TaskCreationOptions.RunContinuationsAsynchronously);
            var connectionStopMessageLogged      = new TaskCompletionSource <object>(TaskCreationOptions.RunContinuationsAsynchronously);
            var requestAborted = new TaskCompletionSource <object>(TaskCreationOptions.RunContinuationsAsynchronously);
            // Despite the "Cts" suffix this is a TaskCompletionSource; it carries the outcome of CopyToAsync in the app.
            var copyToAsyncCts = new TaskCompletionSource <object>(TaskCreationOptions.RunContinuationsAsynchronously);

            var mockKestrelTrace = new Mock <IKestrelTrace>();

            mockKestrelTrace
            .Setup(trace => trace.ResponseMinimumDataRateNotSatisfied(It.IsAny <string>(), It.IsAny <string>()))
            .Callback(() => responseRateTimeoutMessageLogged.SetResult(null));
            mockKestrelTrace
            .Setup(trace => trace.ConnectionStop(It.IsAny <string>()))
            .Callback(() => connectionStopMessageLogged.SetResult(null));

            var testContext = new TestServiceContext(LoggerFactory, mockKestrelTrace.Object)
            {
                ServerOptions =
                {
                    Limits                  =
                    {
                        MinResponseDataRate = new MinDataRate(bytesPerSecond: 1024 * 1024, gracePeriod: TimeSpan.FromSeconds(2)),
                        // The request body sent below is exactly responseSize bytes long.
                        MaxRequestBodySize  = responseSize
                    }
                }
            };

            testContext.InitializeHeartbeat();

            var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0));

            // Echo app: copies the request body to the response body and records how the copy ended.
            async Task App(HttpContext context)
            {
                context.RequestAborted.Register(() =>
                {
                    requestAborted.SetResult(null);
                });

                try
                {
                    await context.Request.Body.CopyToAsync(context.Response.Body);
                }
                catch (Exception ex)
                {
                    // Surface the failure to the test body, then rethrow it to the server.
                    copyToAsyncCts.SetException(ex);
                    throw;
                }
                finally
                {
                    // Do not let the app return before the RequestAborted callback has run.
                    await requestAborted.Task.DefaultTimeout();
                }

                // Only reached when CopyToAsync completed without throwing, which this test treats as a failure.
                copyToAsyncCts.SetException(new Exception("This shouldn't be reached."));
            }

            using (var server = new TestServer(App, testContext, listenOptions))
            {
                using (var connection = server.CreateConnection())
                {
                    // Send only the request line and headers; the body is streamed by sendTask below.
                    // Note that no "Connection: close" header is sent in this test.
                    await connection.Send(
                        "POST / HTTP/1.1",
                        "Host:",
                        $"Content-Length: {responseSize}",
                        "",
                        "");

                    // Upload the body in bufferCount writes of bufferSize bytes with a short pause between writes.
                    // NOTE(review): sendTask is never awaited or otherwise observed, so an exception thrown by
                    // WriteAsync is not surfaced by this test.
                    var sendTask = Task.Run(async() =>
                    {
                        for (var i = 0; i < bufferCount; i++)
                        {
                            await connection.Stream.WriteAsync(buffer, 0, buffer.Length);
                            await Task.Delay(10);
                        }
                    });

                    await requestAborted.Task.DefaultTimeout();

                    await responseRateTimeoutMessageLogged.Task.DefaultTimeout();

                    await connectionStopMessageLogged.Task.DefaultTimeout();

                    // Expect OperationCanceledException instead of IOException because the server initiated the abort due to a response rate timeout.
                    await Assert.ThrowsAnyAsync <OperationCanceledException>(() => copyToAsyncCts.Task).DefaultTimeout();
                    // The response is read here for the first time; responseSize is passed as the byte-count argument.
                    await AssertStreamAborted(connection.Stream, responseSize);
                }
                await server.StopAsync();
            }
        }
        // Verifies that a client reading several responses with very large headers at a throttled rate
        // (still far above the configured 240 B/s minimum) does not trigger a minimum-data-rate violation:
        // nothing is logged for the data rate, the connection is stopped exactly once, and no request is aborted.
        public async Task ConnectionNotClosedWhenClientSatisfiesMinimumDataRateGivenLargeResponseHeaders()
        {
            var headerSize         = 1024 * 1024; // 1 MB for each header value
            var headerCount        = 64;          // 64 MB of headers per response
            var requestCount       = 4;           // Minimum of 256 MB of total response headers
            var headerValue        = new string('a', headerSize);
            var headerStringValues = new StringValues(Enumerable.Repeat(headerValue, headerCount).ToArray());

            var requestAborted   = false;
            var mockKestrelTrace = new Mock<IKestrelTrace>();

            var testContext = new TestServiceContext(LoggerFactory, mockKestrelTrace.Object)
            {
                ServerOptions =
                {
                    Limits =
                    {
                        MinResponseDataRate = new MinDataRate(bytesPerSecond: 240, gracePeriod: TimeSpan.FromSeconds(2))
                    }
                }
            };

            testContext.InitializeHeartbeat();

            // Pin the Date header to a fixed value so the total byte count asserted below stays constant.
            var dateHeaderValueManager = new DateHeaderValueManager();
            dateHeaderValueManager.OnHeartbeat(DateTimeOffset.MinValue);
            testContext.DateHeaderValueManager = dateHeaderValueManager;

            var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0));

            // Responds with an empty body and headerCount copies of the 1 MB header value.
            async Task App(HttpContext context)
            {
                context.RequestAborted.Register(() =>
                {
                    requestAborted = true;
                });

                context.Response.Headers["X-Custom-Header"] = headerStringValues;
                context.Response.ContentLength = 0;

                await context.Response.BodyWriter.FlushAsync();
            }

            using (var server = new TestServer(App, testContext, listenOptions))
            {
                using (var connection = server.CreateConnection())
                {
                    // Send all requests up front on the same connection before reading any response.
                    for (var i = 0; i < requestCount; i++)
                    {
                        await connection.Send(
                            "GET / HTTP/1.1",
                            "Host:",
                            "",
                            "");
                    }

                    await connection.Receive(
                        "HTTP/1.1 200 OK",
                        $"Date: {dateHeaderValueManager.GetDateHeaderValues().String}");

                    var minResponseSize = headerSize * headerCount;

                    // Make sure consuming a single set of response headers exceeds the 2 second timeout.
                    var targetBytesPerSecond = minResponseSize / 4;

                    // expectedBytes was determined by manual testing. A constant Date header is used, so this shouldn't change unless
                    // the response header writing logic itself changes.
                    await AssertBytesReceivedAtTargetRate(connection.Stream, expectedBytes: 268_439_596, targetBytesPerSecond);

                    connection.ShutdownSend();
                    await connection.WaitForConnectionClose();
                }

                await server.StopAsync();
            }

            mockKestrelTrace.Verify(t => t.ResponseMinimumDataRateNotSatisfied(It.IsAny<string>(), It.IsAny<string>()), Times.Never());
            mockKestrelTrace.Verify(t => t.ConnectionStop(It.IsAny<string>()), Times.Once());
            Assert.False(requestAborted);
        }
        // Verifies that when the client sends a request but does not read the response, the server logs a
        // minimum response data rate violation, trips RequestAborted, logs ConnectionStop, and the app's
        // pending write is cancelled with an OperationCanceledException.
        public async Task ConnectionClosedWhenResponseDoesNotSatisfyMinimumDataRate()
        {
            const int chunkSize = 1024;
            const int chunks = 256 * 1024;

            var testLogger = LoggerFactory.CreateLogger($"{typeof(ResponseTests).FullName}.{nameof(ConnectionClosedWhenResponseDoesNotSatisfyMinimumDataRate)}");
            var responseSize = chunks * chunkSize;
            var payload = new byte[chunkSize];

            // Completed from server-side callbacks; awaited by the test body.
            var rateViolationLogged = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
            var connectionStopLogged = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
            var abortObserved = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
            var appOutcome = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);

            var traceMock = new Mock<IKestrelTrace>();
            traceMock
                .Setup(trace => trace.ResponseMinimumDataRateNotSatisfied(It.IsAny<string>(), It.IsAny<string>()))
                .Callback(() => rateViolationLogged.SetResult(null));
            traceMock
                .Setup(trace => trace.ConnectionStop(It.IsAny<string>()))
                .Callback(() => connectionStopLogged.SetResult(null));

            var serviceContext = new TestServiceContext(LoggerFactory, traceMock.Object);
            serviceContext.ServerOptions.Limits.MinResponseDataRate =
                new MinDataRate(bytesPerSecond: 1024 * 1024, gracePeriod: TimeSpan.FromSeconds(2));
            serviceContext.InitializeHeartbeat();

            var appLog = LoggerFactory.CreateLogger("App");

            // Attempts to write the whole response; the expected outcome is cancellation part-way through.
            async Task WriteUntilAborted(HttpContext context)
            {
                appLog.LogInformation("Request received");
                context.RequestAborted.Register(() => abortObserved.SetResult(null));

                context.Response.ContentLength = responseSize;

                // Declared outside the try so the finally block can report how much was written.
                var chunksWritten = 0;

                try
                {
                    while (chunksWritten < chunks)
                    {
                        await context.Response.BodyWriter.WriteAsync(new Memory<byte>(payload), context.RequestAborted);

                        await Task.Yield();
                        chunksWritten++;
                    }

                    // Writing everything means the data-rate limit never kicked in: report that as a failure.
                    appOutcome.SetException(new Exception("This shouldn't be reached."));
                }
                catch (OperationCanceledException)
                {
                    // Expected path: the write was cancelled via RequestAborted.
                    appOutcome.SetResult(null);
                    throw;
                }
                catch (Exception ex)
                {
                    appOutcome.SetException(ex);
                }
                finally
                {
                    appLog.LogInformation("Wrote {total} bytes", chunkSize * chunksWritten);
                    await abortObserved.Task.DefaultTimeout();
                }
            }

            using (var server = new TestServer(WriteUntilAborted, serviceContext))
            {
                using (var connection = server.CreateConnection())
                {
                    testLogger.LogInformation("Sending request");
                    await connection.Send(
                        "GET / HTTP/1.1",
                        "Host:",
                        "",
                        "");
                    testLogger.LogInformation("Sent request");

                    var elapsed = Stopwatch.StartNew();
                    testLogger.LogInformation("Waiting for connection to abort.");

                    await abortObserved.Task.DefaultTimeout();
                    await rateViolationLogged.Task.DefaultTimeout();
                    await connectionStopLogged.Task.DefaultTimeout();
                    await appOutcome.Task.DefaultTimeout();

                    await AssertStreamAborted(connection.Stream, responseSize);

                    elapsed.Stop();
                    testLogger.LogInformation("Connection was aborted after {totalMilliseconds}ms.", elapsed.ElapsedMilliseconds);
                }

                await server.StopAsync();
            }
        }
        // HTTPS variant of the minimum response data rate test: the client completes a TLS handshake, sends a
        // request over the SslStream, and then does not read the response. The test expects RequestAborted
        // to fire, the data-rate violation and ConnectionStop to be logged, and the app's write to be cancelled.
        public async Task HttpsConnectionClosedWhenResponseDoesNotSatisfyMinimumDataRate()
        {
            const int chunkSize = 1024;
            const int chunks    = 256 * 1024;
            var       chunkData = new byte[chunkSize];

            var certificate = TestResources.GetTestCertificate();

            // Each of these is completed from a server-side callback and awaited by the test body below.
            var responseRateTimeoutMessageLogged = new TaskCompletionSource <object>(TaskCreationOptions.RunContinuationsAsynchronously);
            var connectionStopMessageLogged      = new TaskCompletionSource <object>(TaskCreationOptions.RunContinuationsAsynchronously);
            var aborted          = new TaskCompletionSource <object>(TaskCreationOptions.RunContinuationsAsynchronously);
            var appFuncCompleted = new TaskCompletionSource <object>(TaskCreationOptions.RunContinuationsAsynchronously);

            var mockKestrelTrace = new Mock <IKestrelTrace>();

            mockKestrelTrace
            .Setup(trace => trace.ResponseMinimumDataRateNotSatisfied(It.IsAny <string>(), It.IsAny <string>()))
            .Callback(() => responseRateTimeoutMessageLogged.SetResult(null));
            mockKestrelTrace
            .Setup(trace => trace.ConnectionStop(It.IsAny <string>()))
            .Callback(() => connectionStopMessageLogged.SetResult(null));

            var testContext = new TestServiceContext(LoggerFactory, mockKestrelTrace.Object)
            {
                ServerOptions =
                {
                    Limits                  =
                    {
                        MinResponseDataRate = new MinDataRate(bytesPerSecond: 1024 * 1024, gracePeriod: TimeSpan.FromSeconds(2))
                    }
                }
            };

            testContext.InitializeHeartbeat();

            // Enables HTTPS on the test endpoint using the test certificate loaded above.
            void ConfigureListenOptions(ListenOptions listenOptions)
            {
                listenOptions.UseHttps(new HttpsConnectionAdapterOptions {
                    ServerCertificate = certificate
                });
            }

            using (var server = new TestServer(async context =>
            {
                context.RequestAborted.Register(() =>
                {
                    aborted.SetResult(null);
                });

                context.Response.ContentLength = chunks * chunkSize;

                try
                {
                    for (var i = 0; i < chunks; i++)
                    {
                        await context.Response.BodyWriter.WriteAsync(new Memory <byte>(chunkData, 0, chunkData.Length), context.RequestAborted);
                    }
                }
                catch (OperationCanceledException)
                {
                    // Expected path: the write was cancelled through the RequestAborted token.
                    appFuncCompleted.SetResult(null);
                    throw;
                }
                finally
                {
                    // Do not let the app return before the RequestAborted callback has run.
                    await aborted.Task.DefaultTimeout();
                }
                // NOTE(review): if the loop finishes without cancellation (or throws a different exception type),
                // appFuncCompleted is never completed here, so the test would only fail when the awaited
                // appFuncCompleted.Task.DefaultTimeout() below gives up.
            }, testContext, ConfigureListenOptions))
            {
                using (var connection = server.CreateConnection())
                {
                    // The validation callback accepts any server certificate.
                    using (var sslStream = new SslStream(connection.Stream, false, (sender, cert, chain, errors) => true, null))
                    {
                        await sslStream.AuthenticateAsClientAsync("localhost", new X509CertificateCollection(), SslProtocols.Tls12 | SslProtocols.Tls11, false);

                        var request = Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\nHost:\r\n\r\n");
                        await sslStream.WriteAsync(request, 0, request.Length);

                        // Nothing is read from the connection until the abort has been observed.
                        await aborted.Task.DefaultTimeout();

                        await responseRateTimeoutMessageLogged.Task.DefaultTimeout();

                        await connectionStopMessageLogged.Task.DefaultTimeout();

                        await appFuncCompleted.Task.DefaultTimeout();

                        // NOTE(review): the underlying connection.Stream is passed here, not sslStream - confirm this is intended.
                        await AssertStreamAborted(connection.Stream, chunkSize *chunks);
                    }
                }
                await server.StopAsync();
            }
        }
        // Verifies that when a client reads only part of a response and then resets the connection, the app
        // can still run to completion: RequestAborted fires, the transport logs the reset (or a FIN on
        // non-Windows platforms), exactly one connection-stop event is logged, and nothing above Debug level is logged.
        public async Task AppCanHandleClientAbortingConnectionMidResponse(ListenOptions listenOptions)
        {
            const string coreLoggerName = "Microsoft.AspNetCore.Server.Kestrel";
            const string libuvLoggerName = "Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv";
            const string socketsLoggerName = "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets";

            const int connectionStopEventId = 2;
            const int connectionFinEventId = 6;
            const int connectionResetEventId = 19;

            const int segmentSize = 65536;
            const int segmentCount = 100;

            var abortObserved = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
            var appFinished = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);

            // Shared by the app (as the data it writes) and the client (as the target of its single read).
            var sharedBuffer = new byte[segmentSize];

            using (var server = new TestServer(async context =>
            {
                context.RequestAborted.Register(() => abortObserved.SetResult(null));

                var segmentsWritten = 0;
                while (segmentsWritten < segmentCount)
                {
                    await context.Response.Body.WriteAsync(sharedBuffer, 0, sharedBuffer.Length);
                    await Task.Delay(10);
                    segmentsWritten++;
                }

                await abortObserved.Task.DefaultTimeout();
                appFinished.SetResult(null);
            }, new TestServiceContext(LoggerFactory), listenOptions))
            {
                using (var connection = server.CreateConnection())
                {
                    await connection.Send(
                        "GET / HTTP/1.1",
                        "Host:",
                        "",
                        "");

                    // Consume only a portion of the response, then reset the connection.
                    // https://github.com/aspnet/KestrelHttpServer/issues/2554
                    await connection.Stream.ReadAsync(sharedBuffer, 0, sharedBuffer.Length);

                    connection.Reset();
                }

                await abortObserved.Task.DefaultTimeout();

                // Once RequestAborted has fired, the transport should have logged the reset. Per the original
                // note, Linux and macOS sometimes report the close as a FIN despite the LingerState.
                var preShutdownTransportLogs = TestSink.Writes.Where(
                    w => w.LoggerName == libuvLoggerName ||
                         w.LoggerName == socketsLoggerName);
                var resetLogs = preShutdownTransportLogs.Where(
                    w => w.EventId == connectionResetEventId ||
                         (!TestPlatformHelper.IsWindows && w.EventId == connectionFinEventId));

                Assert.NotEmpty(resetLogs);

                // Wait for the app's write loop explicitly; per the original note, the default shutdown
                // timeout is not long enough for it to finish on macOS.
                await appFinished.Task.DefaultTimeout();

                await server.StopAsync();
            }

            var coreLogs = TestSink.Writes.Where(w => w.LoggerName == coreLoggerName);

            Assert.Single(coreLogs.Where(w => w.EventId == connectionStopEventId));

            var kestrelLogs = TestSink.Writes.Where(w => w.LoggerName == coreLoggerName ||
                                                         w.LoggerName == libuvLoggerName ||
                                                         w.LoggerName == socketsLoggerName);

            Assert.Empty(kestrelLogs.Where(w => w.LogLevel > LogLevel.Debug));
        }
        // Verifies that when the client closes the connection while the server is not reading from it (the
        // transport has logged the "connection paused" event), a later response write made with the
        // RequestAborted token fails with an OperationCanceledException and RequestAborted is tripped.
        public async Task WritingToConnectionAfterUnobservedCloseTriggersRequestAbortedToken(ListenOptions listenOptions)
        {
            const string libuvLoggerName = "Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv";
            const string socketsLoggerName = "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets";

            const int connectionPausedEventId = 4;
            const int maxRequestBufferSize = 4096;

            var abortObserved = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
            var readPaused = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
            var clientGone = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
            var writeOutcome = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);

            // Signal as soon as either transport logs the "connection paused" event.
            TestSink.MessageLogged += logEntry =>
            {
                var fromTransport = logEntry.LoggerName == libuvLoggerName ||
                                    logEntry.LoggerName == socketsLoggerName;

                if (fromTransport && logEntry.EventId.Id == connectionPausedEventId)
                {
                    readPaused.TrySetResult(null);
                }
            };

            var traceMock = new Mock<IKestrelTrace>();

            // Small request limits so the request body sent below is larger than the request buffer.
            var serviceContext = new TestServiceContext(LoggerFactory, traceMock.Object);
            var limits = serviceContext.ServerOptions.Limits;
            limits.MaxRequestBufferSize = maxRequestBufferSize;
            limits.MaxRequestLineSize = maxRequestBufferSize;
            limits.MaxRequestHeadersTotalSize = maxRequestBufferSize;

            var payload = new byte[maxRequestBufferSize * 8];

            using (var server = new TestServer(async context =>
            {
                context.RequestAborted.Register(() => abortObserved.SetResult(null));

                // Hold off writing until the test has disposed the client connection.
                await clientGone.Task;

                try
                {
                    var attempt = 0;
                    while (attempt < 1000)
                    {
                        await context.Response.BodyWriter.WriteAsync(new Memory<byte>(payload), context.RequestAborted);
                        await Task.Delay(10);
                        attempt++;
                    }
                }
                catch (Exception ex)
                {
                    // Hand the failure to the test body, then rethrow it to the server.
                    writeOutcome.SetException(ex);
                    throw;
                }
                finally
                {
                    await abortObserved.Task.DefaultTimeout();
                }

                // Only reached if every write succeeded, which this test treats as a failure.
                writeOutcome.SetException(new Exception("This shouldn't be reached."));
            }, serviceContext, listenOptions))
            {
                using (var connection = server.CreateConnection())
                {
                    await connection.Send(
                        "POST / HTTP/1.1",
                        "Host:",
                        $"Content-Length: {payload.Length}",
                        "",
                        "");

                    // Intentionally not awaited: the write is started and its task is left unobserved.
                    var pendingBodyWrite = connection.Stream.WriteAsync(payload, 0, payload.Length);

                    // Wait for the transport to stop reading so the upcoming disconnect goes unobserved.
                    await readPaused.Task.DefaultTimeout();
                }

                clientGone.SetResult(null);

                await Assert.ThrowsAnyAsync<OperationCanceledException>(() => writeOutcome.Task).DefaultTimeout();

                await server.StopAsync();
            }

            traceMock.Verify(t => t.ConnectionStop(It.IsAny<string>()), Times.Once());
            Assert.True(abortObserved.Task.IsCompleted);
        }
        // Verifies that a response write issued with the RequestAborted token, after the client has already
        // closed the connection and the token has fired, fails with an OperationCanceledException
        // (TaskCanceledException included).
        public async Task ThrowsOnWriteWithRequestAbortedTokenAfterRequestIsAborted(ListenOptions listenOptions)
        {
            // This should match _maxBytesPreCompleted in SocketOutput
            // NOTE(review): SocketOutput is not visible from this file - confirm the constant is still current.
            var maxBytesPreCompleted = 65536;

            // Ensure string is long enough to disable write-behind buffering
            var largeString = new string('a', maxBytesPreCompleted + 1);

            var writeTcs         = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
            var requestAbortedWh = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
            var requestStartWh   = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);

            using (var server = new TestServer(async httpContext =>
            {
                // Tell the test body the request reached the app so it can drop the connection.
                requestStartWh.SetResult(null);

                var response = httpContext.Response;
                var lifetime = httpContext.Features.Get<IHttpRequestLifetimeFeature>();

                // Do not attempt the write until the abort has actually been signalled.
                lifetime.RequestAborted.Register(() => requestAbortedWh.SetResult(null));
                await requestAbortedWh.Task.DefaultTimeout();

                try
                {
                    await response.WriteAsync(largeString, cancellationToken: lifetime.RequestAborted);
                }
                catch (Exception ex)
                {
                    // Hand the failure to the test body, then rethrow it to the server.
                    writeTcs.SetException(ex);
                    throw;
                }
                finally
                {
                    await requestAbortedWh.Task.DefaultTimeout();
                }

                // Only reached if the write succeeded, which this test treats as a failure.
                writeTcs.SetException(new Exception("This shouldn't be reached."));
            }, new TestServiceContext(LoggerFactory), listenOptions))
            {
                // The connection is disposed as soon as the app has started handling the request.
                using (var connection = server.CreateConnection())
                {
                    await connection.Send(
                        "POST / HTTP/1.1",
                        "Host:",
                        "Content-Length: 0",
                        "",
                        "");

                    await requestStartWh.Task.DefaultTimeout();
                }

                // Write failed - can throw TaskCanceledException or OperationCanceledException,
                // depending on how far the canceled write goes.
                await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await writeTcs.Task).DefaultTimeout();

                // RequestAborted tripped
                await requestAbortedWh.Task.DefaultTimeout();

                await server.StopAsync();
            }
        }
Exemple #9
0
        // Verifies that the app can call HttpContext.Abort() after the client has closed the connection while
        // the server was not reading from it (the transport has logged the "connection paused" event): the
        // transport logs its "FIN sent" event, the app completes, and ConnectionStop is logged exactly once.
        public async Task ServerCanAbortConnectionAfterUnobservedClose(ListenOptions listenOptions)
        {
            const int connectionPausedEventId  = 4;
            const int connectionFinSentEventId = 7;
            const int maxRequestBufferSize     = 4096;

            var readCallbackUnwired    = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
            var clientClosedConnection = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
            var serverClosedConnection = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
            var appFuncCompleted       = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);

            TestSink.MessageLogged += context =>
            {
                // Only events from the two transport loggers are of interest.
                if (context.LoggerName != "Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv" &&
                    context.LoggerName != "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets")
                {
                    return;
                }

                // Both branches compare EventId.Id and use TrySetResult, so a repeated event cannot
                // throw InvalidOperationException from inside the logging callback.
                if (context.EventId.Id == connectionPausedEventId)
                {
                    readCallbackUnwired.TrySetResult(null);
                }
                else if (context.EventId.Id == connectionFinSentEventId)
                {
                    serverClosedConnection.TrySetResult(null);
                }
            };

            var mockKestrelTrace = new Mock<IKestrelTrace>();

            // Small request limits so the request body sent below is larger than the request buffer.
            var testContext = new TestServiceContext(LoggerFactory, mockKestrelTrace.Object)
            {
                ServerOptions =
                {
                    Limits =
                    {
                        MaxRequestBufferSize       = maxRequestBufferSize,
                        MaxRequestLineSize         = maxRequestBufferSize,
                        MaxRequestHeadersTotalSize = maxRequestBufferSize,
                    }
                }
            };

            var scratchBuffer = new byte[maxRequestBufferSize * 8];

            using (var server = new TestServer(async context =>
            {
                // Hold off until the test has disposed the client connection.
                await clientClosedConnection.Task;

                context.Abort();

                // Wait for the transport's "FIN sent" event before reporting completion.
                await serverClosedConnection.Task;

                appFuncCompleted.SetResult(null);
            }, testContext, listenOptions))
            {
                using (var connection = server.CreateConnection())
                {
                    await connection.Send(
                        "POST / HTTP/1.1",
                        "Host:",
                        $"Content-Length: {scratchBuffer.Length}",
                        "",
                        "");

                    // Intentionally not awaited: the write is started and its task is left unobserved.
                    var ignore = connection.Stream.WriteAsync(scratchBuffer, 0, scratchBuffer.Length);

                    // Wait until the read callback is no longer hooked up so that the connection disconnect isn't observed.
                    await readCallbackUnwired.Task.DefaultTimeout();
                }

                clientClosedConnection.SetResult(null);

                await appFuncCompleted.Task.DefaultTimeout();

                await server.StopAsync();
            }

            mockKestrelTrace.Verify(t => t.ConnectionStop(It.IsAny<string>()), Times.Once());
        }
Exemple #10
0
        // Sends two requests on one connection. The first is answered normally; for the second
        // the application starts copying the (never sent) request body and then aborts the
        // request. The test expects the pending read to fail with TaskCanceledException, the
        // RequestAborted callback to have fired during the second request only, and exactly one
        // log entry with the "application aborted connection" event id.
        //
        // listenOptions: the endpoint configuration the TestServer is started with.
        public async Task RequestsCanBeAbortedMidRead(ListenOptions listenOptions)
        {
            // This needs a timeout.
            const int applicationAbortedConnectionId = 34;

            var testContext = new TestServiceContext(LoggerFactory);

            // readTcs carries the outcome of the body read; it is only ever faulted.
            var readTcs         = new TaskCompletionSource <object>(TaskCreationOptions.RunContinuationsAsynchronously);
            // registrationTcs records the request id current when RequestAborted first fired.
            var registrationTcs = new TaskCompletionSource <int>(TaskCreationOptions.RunContinuationsAsynchronously);
            var requestId       = 0;

            using (var server = new TestServer(async httpContext =>
            {
                requestId++;

                var response = httpContext.Response;
                var request = httpContext.Request;
                var lifetime = httpContext.Features.Get <IHttpRequestLifetimeFeature>();

                // TrySetResult: the callback is registered on every request, but only the first
                // firing is recorded.
                lifetime.RequestAborted.Register(() => registrationTcs.TrySetResult(requestId));

                if (requestId == 1)
                {
                    response.Headers["Content-Length"] = new[] { "5" };

                    await response.WriteAsync("World");
                }
                else
                {
                    // Start the read before aborting so the abort hits a read that is in flight.
                    var readTask = request.Body.CopyToAsync(Stream.Null);

                    lifetime.Abort();

                    try
                    {
                        await readTask;
                    }
                    catch (Exception ex)
                    {
                        // Hand the failure to the test body, then let it propagate to the server.
                        readTcs.SetException(ex);
                        throw;
                    }
                    finally
                    {
                        // Do not leave the request until the abort callback has run.
                        await registrationTcs.Task.DefaultTimeout();
                    }

                    readTcs.SetException(new Exception("This shouldn't be reached."));
                }
            }, testContext, listenOptions))
            {
                using (var connection = server.CreateConnection())
                {
                    // Full request and response
                    await connection.Send(
                        "POST / HTTP/1.1",
                        "Host:",
                        "Content-Length: 5",
                        "",
                        "Hello");

                    await connection.Receive(
                        "HTTP/1.1 200 OK",
                        $"Date: {testContext.DateHeaderValue}",
                        "Content-Length: 5",
                        "",
                        "World");

                    // Never send the body so CopyToAsync always fails.
                    await connection.Send("POST / HTTP/1.1",
                                          "Host:",
                                          "Content-Length: 5",
                                          "",
                                          "");

                    await connection.WaitForConnectionClose();
                }
                await server.StopAsync();
            }

            // Bounded wait: if the application never completed readTcs (for example because the
            // wait in its finally block timed out on the no-exception path), fail instead of hanging.
            await Assert.ThrowsAsync <TaskCanceledException>(async() => await readTcs.Task.DefaultTimeout());

            // The cancellation token for only the last request should be triggered.
            var abortedRequestId = await registrationTcs.Task.DefaultTimeout();

            Assert.Equal(2, abortedRequestId);

            Assert.Single(TestSink.Writes.Where(w => w.LoggerName == "Microsoft.AspNetCore.Server.Kestrel" &&
                                                w.EventId == applicationAbortedConnectionId));
        }