Пример #1
0
        /// <summary>
        /// Verifies that <c>context.RequestAborted</c> is signalled when the client shuts down
        /// the sending side of its socket while the request delegate is still running.
        /// </summary>
        public async Task RequestAbortedTokenFiredOnClientFIN()
        {
            var appStarted = new SemaphoreSlim(0);
            var requestAborted = new SemaphoreSlim(0);
            var builder = TransportSelector.GetHostBuilder()
                .ConfigureWebHost(webHostBuilder =>
                {
                    webHostBuilder
                        .UseKestrel()
                        .UseUrls("http://127.0.0.1:0")
                        .Configure(app => app.Run(async context =>
                        {
                            // Tell the test body that the request reached the application.
                            appStarted.Release();

                            // Release two permits on abort: one consumed by the wait just below,
                            // one consumed by the test body after it shuts down the socket.
                            var token = context.RequestAborted;
                            token.Register(() => requestAborted.Release(2));
                            await requestAborted.WaitAsync().DefaultTimeout();
                        }));
                })
                .ConfigureServices(AddTestLogging);

            using (var host = builder.Build())
            {
                await host.StartAsync();

                using (var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
                {
                    socket.Connect(new IPEndPoint(IPAddress.Loopback, host.GetPort()));
                    socket.Send(Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\nHost:\r\n\r\n"));

                    // Bound this wait like every other wait in the test so that a request that
                    // never reaches the application fails the test instead of hanging the run.
                    await appStarted.WaitAsync().DefaultTimeout();

                    // Half-close the connection; the server is expected to abort the request.
                    socket.Shutdown(SocketShutdown.Send);
                    await requestAborted.WaitAsync().DefaultTimeout();
                }

                await host.StopAsync();
            }
        }
Пример #2
0
    /// <summary>
    /// Starts Kestrel listening on <paramref name="endPoint"/>, requests
    /// <paramref name="testUrl"/> and checks that the echoed address matches the requested URL.
    /// </summary>
    /// <param name="endPoint">The endpoint passed to <c>options.Listen</c>.</param>
    /// <param name="testUrl">Scheme and host to request; HTTPS is enabled when it starts with "https".</param>
    /// <param name="testPort">Port to request; when 0 the port reported by the host is used.</param>
    private async Task RegisterIPEndPoint_Success(IPEndPoint endPoint, string testUrl, int testPort = 0)
    {
        var builder = TransportSelector.GetHostBuilder()
            .ConfigureWebHost(webHostBuilder =>
            {
                webHostBuilder.UseKestrel(serverOptions =>
                {
                    serverOptions.Listen(endPoint, listenOptions =>
                    {
                        if (testUrl.StartsWith("https", StringComparison.Ordinal))
                        {
                            listenOptions.UseHttps(TestResources.GetTestCertificate());
                        }
                    });
                });
                webHostBuilder.Configure(ConfigureEchoAddress);
            })
            .ConfigureServices(AddTestLogging);

        using var host = builder.Build();

        await host.StartAsync();

        var port = testPort == 0 ? host.GetPort() : testPort;
        var testUrlWithPort = $"{testUrl}:{port}";

        // Exactly one listener must have been registered for the endpoint.
        var optionsAccessor = (IOptions<KestrelServerOptions>)host.Services.GetService(typeof(IOptions<KestrelServerOptions>));
        var kestrelOptions = optionsAccessor.Value;
        Assert.Single(kestrelOptions.ListenOptions);

        var body = await HttpClientSlim.GetStringAsync(testUrlWithPort, validateCertificate: false);

        // Compare against Uri.ToString() rather than the raw URL so that IPv6 addresses
        // carrying a zone index (such as "fe80::3%1") are handled.
        var expected = new Uri(testUrlWithPort).ToString();
        Assert.Equal(expected, body);

        await host.StopAsync();
    }
Пример #3
0
    /// <summary>
    /// Verifies that an <c>http://</c> URL registered through <c>UseUrls</c> is served and
    /// reported as <c>https://</c> when <c>ConfigureEndpointDefaults</c> applies <c>UseHttps</c>.
    /// </summary>
    public async Task RegisterHttpAddress_UpgradedToHttpsByConfigureEndpointDefaults()
    {
        var hostBuilder = TransportSelector.GetHostBuilder()
            .ConfigureWebHost(webHostBuilder =>
            {
                webHostBuilder
                    .UseKestrel(serverOptions =>
                    {
                        serverOptions.ConfigureEndpointDefaults(listenOptions =>
                        {
                            listenOptions.UseHttps(TestResources.GetTestCertificate());
                        });
                    })
                    .UseUrls("http://127.0.0.1:0")
                    .Configure(app =>
                    {
                        var serverAddresses = app.ServerFeatures.Get<IServerAddressesFeature>();
                        app.Run(context =>
                        {
                            // Echo the single bound address back so the test can inspect its scheme.
                            Assert.Single(serverAddresses.Addresses);
                            return context.Response.WriteAsync(serverAddresses.Addresses.First());
                        });
                    });
            })
            .ConfigureServices(AddTestLogging);

        using (var host = hostBuilder.Build())
        {
            // Start asynchronously rather than blocking the test thread with the
            // synchronous Start() wrapper inside an async method.
            await host.StartAsync();

            var expectedUrl = $"https://127.0.0.1:{host.GetPort()}";
            var response = await HttpClientSlim.GetStringAsync(expectedUrl, validateCertificate: false);

            Assert.Equal(expectedUrl, response);

            await host.StopAsync();
        }
    }
Пример #4
0
        /// <summary>
        /// Starts an HTTP/2-only Kestrel server on a loopback port and runs the h2spec
        /// test identified by <paramref name="testCase"/> against it.
        /// </summary>
        /// <param name="testCase">The h2spec case to run; its <c>Https</c> flag decides whether TLS is enabled.</param>
        public async Task RunIndividualTestCase(H2SpecTestCase testCase)
        {
            var builder = TransportSelector.GetWebHostBuilder()
                .UseKestrel(serverOptions =>
                {
                    serverOptions.Listen(IPAddress.Loopback, 0, listener =>
                    {
                        listener.Protocols = HttpProtocols.Http2;

                        if (!testCase.Https)
                        {
                            return;
                        }

                        listener.UseHttps(TestResources.TestCertificatePath, "testPassword");
                    });
                })
                .ConfigureServices(AddTestLogging)
                .Configure(ConfigureHelloWorld);

            using var host = builder.Build();

            await host.StartAsync();

            var port = host.GetPort();
            H2SpecCommands.RunTest(testCase.Id, port, testCase.Https, Logger);
        }
Пример #5
0
        /// <summary>
        /// Verifies that a request served by the host exposes an
        /// <c>IConnectionSocketFeature</c> whose socket is a TCP socket and whose remote
        /// endpoint matches the connection's remote IP address and port.
        /// </summary>
        public async Task SocketTransportExposesSocketsFeature()
        {
            var builder = TransportSelector.GetHostBuilder()
                .ConfigureWebHost(webHostBuilder =>
                {
                    webHostBuilder
                        .UseKestrel()
                        .UseUrls("http://127.0.0.1:0")
                        .Configure(app =>
                        {
                            app.Run(context =>
                            {
                                // The assertions run inside the request delegate; the client-side
                                // EnsureSuccessStatusCode call below is what checks the outcome.
                                var socket = context.Features.Get<IConnectionSocketFeature>().Socket;
                                Assert.NotNull(socket);
                                Assert.Equal(ProtocolType.Tcp, socket.ProtocolType);
                                var ip = (IPEndPoint)socket.RemoteEndPoint;
                                Assert.Equal(ip.Address, context.Connection.RemoteIpAddress);
                                Assert.Equal(ip.Port, context.Connection.RemotePort);

                                return Task.CompletedTask;
                            });
                        });
                })
                .ConfigureServices(AddTestLogging);

            using var host = builder.Build();
            using var client = new HttpClient();

            await host.StartAsync();

            // Dispose the response so its content and connection resources are released.
            using var response = await client.GetAsync($"http://127.0.0.1:{host.GetPort()}/");

            response.EnsureSuccessStatusCode();

            await host.StopAsync();
        }
Пример #6
0
        /// <summary>
        /// Hosts a connection-level echo handler on a Unix domain socket, sends
        /// "Hello World" from a raw client socket and checks that the same bytes come back.
        /// </summary>
        public async Task TestUnixDomainSocket()
        {
            var socketPath = Path.GetTempFileName();

            // Only the unique name is needed; remove the file that GetTempFileName created.
            Delete(socketPath);

            try
            {
                var serverDone = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);

                // Echoes everything it reads until the input completes or a close is requested.
                async Task EchoServer(ConnectionContext connection)
                {
                    // For graceful shutdown
                    var lifetime = connection.Features.Get<IConnectionLifetimeNotificationFeature>();

                    try
                    {
                        var input = connection.Transport.Input;
                        var output = connection.Transport.Output;

                        var readResult = await input.ReadAsync(lifetime.ConnectionClosedRequested);
                        while (!readResult.IsCompleted)
                        {
                            await output.WriteAsync(readResult.Buffer.ToArray());
                            input.AdvanceTo(readResult.Buffer.End);

                            readResult = await input.ReadAsync(lifetime.ConnectionClosedRequested);
                        }

                        Logger.LogDebug("Application receive loop ending for connection {connectionId}.", connection.ConnectionId);
                    }
                    catch (OperationCanceledException)
                    {
                        Logger.LogDebug("Graceful shutdown triggered for {connectionId}.", connection.ConnectionId);
                    }
                    finally
                    {
                        serverDone.TrySetResult();
                    }
                }

                var hostBuilder = TransportSelector.GetHostBuilder()
                    .ConfigureWebHost(webHostBuilder =>
                    {
                        webHostBuilder.UseKestrel(kestrel =>
                        {
                            kestrel.ListenUnixSocket(socketPath, listener => listener.Run(EchoServer));
                        });
                        webHostBuilder.Configure(_ => { });
                    })
                    .ConfigureServices(AddTestLogging);

                using (var host = hostBuilder.Build())
                {
                    await host.StartAsync().DefaultTimeout();

                    using (var client = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified))
                    {
                        await client.ConnectAsync(new UnixDomainSocketEndPoint(socketPath)).DefaultTimeout();

                        var payload = Encoding.ASCII.GetBytes("Hello World");
                        await client.SendAsync(payload, SocketFlags.None).DefaultTimeout();

                        // Keep receiving until the whole payload is echoed or the peer closes.
                        var echoed = new byte[payload.Length];
                        var total = 0;
                        int chunk;
                        do
                        {
                            chunk = await client.ReceiveAsync(echoed.AsMemory(total, echoed.Length - total), SocketFlags.None).DefaultTimeout();
                            total += chunk;
                        }
                        while (chunk > 0 && total < payload.Length);

                        Assert.Equal(payload, echoed);
                    }

                    // Wait for the server to complete the loop because of the FIN
                    await serverDone.Task.DefaultTimeout();

                    await host.StopAsync().DefaultTimeout();
                }
            }
            finally
            {
                Delete(socketPath);
            }
        }
Пример #7
0
        /// <summary>
        /// Binds Kestrel to a Unix domain socket through an <c>http://unix:/…</c> URL, issues a
        /// raw HTTP/1.1 GET over the socket and checks that the response status code is 200.
        /// </summary>
        public async Task TestUnixDomainSocketWithUrl()
        {
            var path = Path.GetTempFileName();
            var url = $"http://unix:/{path}";

            // Only the unique name is needed; remove the file that GetTempFileName created.
            Delete(path);

            try
            {
                var hostBuilder = TransportSelector.GetHostBuilder()
                    .ConfigureWebHost(webHostBuilder =>
                    {
                        webHostBuilder
                            .UseUrls(url)
                            .UseKestrel()
                            .Configure(app =>
                            {
                                app.Run(async context =>
                                {
                                    await context.Response.WriteAsync("Hello World");
                                });
                            });
                    })
                    .ConfigureServices(AddTestLogging);

                using (var host = hostBuilder.Build())
                {
                    await host.StartAsync().DefaultTimeout();

                    // https://github.com/dotnet/corefx/issues/5999
                    // .NET Core HttpClient does not support unix sockets, and parsing raw response
                    // data is awkward, so the status code is extracted by hand below.
                    using (var socket = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified))
                    {
                        await socket.ConnectAsync(new UnixDomainSocketEndPoint(path)).DefaultTimeout();

                        var httpRequest = Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\nHost:\r\nConnection: close\r\n\r\n");
                        await socket.SendAsync(httpRequest, SocketFlags.None).DefaultTimeout();

                        // "Connection: close" makes the server end the connection, so read until
                        // a receive returns no bytes.
                        var readBuffer = new byte[512];
                        var read = 0;
                        while (true)
                        {
                            var bytesReceived = await socket.ReceiveAsync(readBuffer.AsMemory(read), SocketFlags.None).DefaultTimeout();

                            read += bytesReceived;
                            if (bytesReceived <= 0)
                            {
                                break;
                            }
                        }

                        // The status code sits between the first and second space of the status line.
                        var httpResponse = Encoding.ASCII.GetString(readBuffer, 0, read);
                        int httpStatusStart = httpResponse.IndexOf(' ') + 1;
                        Assert.False(httpStatusStart == 0, $"Space not found in '{httpResponse}'.");
                        int httpStatusEnd = httpResponse.IndexOf(' ', httpStatusStart);
                        Assert.False(httpStatusEnd == -1, $"Second space not found in '{httpResponse}'.");

                        var httpStatus = int.Parse(httpResponse.Substring(httpStatusStart, httpStatusEnd - httpStatusStart), CultureInfo.InvariantCulture);

                        // Expected value goes first so a failure reports expected/actual correctly.
                        Assert.Equal(StatusCodes.Status200OK, httpStatus);
                    }

                    await host.StopAsync().DefaultTimeout();
                }
            }
            finally
            {
                Delete(path);
            }
        }
Пример #8
0
        /// <summary>
        /// Uploads <paramref name="contentLength"/> bytes to the server in 1 MiB chunks over a
        /// raw socket and checks that the server reports having read exactly that many bytes.
        /// </summary>
        /// <param name="contentLength">Total body size; must be a multiple of the 1 MiB chunk size.</param>
        /// <param name="checkBytes">
        /// When true the body is a repeating 0..255 byte pattern and the server validates every byte.
        /// </param>
        public async Task LargeUpload(long contentLength, bool checkBytes)
        {
            const int bufferLength = 1024 * 1024;

            Assert.True(contentLength % bufferLength == 0, $"{nameof(contentLength)} sent must be evenly divisible by {bufferLength}.");
            Assert.True(bufferLength % 256 == 0, $"{nameof(bufferLength)} must be evenly divisible by 256");

            var hostBuilder = TransportSelector.GetHostBuilder()
                .ConfigureWebHost(webHostBuilder =>
                {
                    webHostBuilder.UseKestrel(kestrel =>
                    {
                        kestrel.Limits.MaxRequestBodySize = contentLength;
                        kestrel.Limits.MinRequestBodyDataRate = null;
                    });
                    webHostBuilder.UseUrls("http://127.0.0.1:0/");
                    webHostBuilder.Configure(app => app.Run(async context =>
                    {
                        // Drain the entire request body, optionally validating the byte pattern.
                        var chunk = new byte[bufferLength];
                        long total = 0;

                        while (true)
                        {
                            var count = await context.Request.Body.ReadAsync(chunk, 0, chunk.Length);
                            if (count <= 0)
                            {
                                break;
                            }

                            if (checkBytes)
                            {
                                for (var offset = 0; offset < count; offset++)
                                {
                                    // Do not use Assert.Equal here, it is too slow for this hot path
                                    var expected = (byte)((total + offset) % 256);
                                    Assert.True(expected == chunk[offset], "Data received is incorrect");
                                }
                            }

                            total += count;
                        }

                        await context.Response.WriteAsync($"bytesRead: {total}");
                    }));
                })
                .ConfigureServices(AddTestLogging);

            using var host = hostBuilder.Build();

            await host.StartAsync();

            using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
            {
                client.Connect(new IPEndPoint(IPAddress.Loopback, host.GetPort()));
                client.Send(Encoding.ASCII.GetBytes("POST / HTTP/1.1\r\nHost: \r\n"));
                client.Send(Encoding.ASCII.GetBytes($"Content-Length: {contentLength}\r\n\r\n"));

                var contentBytes = new byte[bufferLength];

                if (checkBytes)
                {
                    // Truncating the index to a byte produces the repeating 0..255 pattern.
                    for (var index = 0; index < contentBytes.Length; index++)
                    {
                        contentBytes[index] = (byte)index;
                    }
                }

                var chunkCount = contentLength / contentBytes.Length;
                for (var i = 0; i < chunkCount; i++)
                {
                    client.Send(contentBytes);
                }

                using var stream = new NetworkStream(client);
                await AssertStreamContains(stream, $"bytesRead: {contentLength}");
            }

            await host.StopAsync();
        }
Пример #9
0
    /// <summary>
    /// Builds and starts a host whose request delegate reads a body of
    /// <c>expectedBody.Length</c> bytes and responds with the number of bytes read.
    /// </summary>
    /// <param name="maxRequestBufferSize">
    /// Value assigned to <c>Limits.MaxRequestBufferSize</c> and passed to the host builder; when it
    /// is smaller than the request-line or total-headers limit, those limits are lowered to match.
    /// </param>
    /// <param name="expectedBody">The body the client is expected to send; only its length is used here.</param>
    /// <param name="useConnectionAdapter">When true, <c>UsePassThrough</c> is applied to the listener.</param>
    /// <param name="startReadingRequestBody">Awaited by the request delegate before it starts reading the body.</param>
    /// <param name="clientFinishedSendingRequestBody">
    /// Awaited by the request delegate after the body is read, before it checks for extra bytes.
    /// </param>
    /// <param name="memoryPoolFactory">Optional memory pool factory passed to the host builder.</param>
    /// <returns>The started host; the caller is responsible for stopping and disposing it.</returns>
    private async Task<IHost> StartHost(long? maxRequestBufferSize,
                                        byte[] expectedBody,
                                        bool useConnectionAdapter,
                                        TaskCompletionSource startReadingRequestBody,
                                        TaskCompletionSource clientFinishedSendingRequestBody,
                                        Func<MemoryPool<byte>> memoryPoolFactory = null)
    {
        var host = TransportSelector.GetHostBuilder(memoryPoolFactory, maxRequestBufferSize)
            .ConfigureWebHost(webHostBuilder =>
            {
                webHostBuilder
                    .UseKestrel(options =>
                    {
                        options.Listen(new IPEndPoint(IPAddress.Loopback, 0), listenOptions =>
                        {
                            if (useConnectionAdapter)
                            {
                                listenOptions.UsePassThrough();
                            }
                        });

                        options.Limits.MaxRequestBufferSize = maxRequestBufferSize;

                        // Lower the request-line and headers limits when they exceed the buffer size.
                        if (maxRequestBufferSize.HasValue &&
                            maxRequestBufferSize.Value < options.Limits.MaxRequestLineSize)
                        {
                            options.Limits.MaxRequestLineSize = (int)maxRequestBufferSize;
                        }

                        if (maxRequestBufferSize.HasValue &&
                            maxRequestBufferSize.Value < options.Limits.MaxRequestHeadersTotalSize)
                        {
                            options.Limits.MaxRequestHeadersTotalSize = (int)maxRequestBufferSize;
                        }

                        options.Limits.MinRequestBodyDataRate = null;

                        options.Limits.MaxRequestBodySize = _dataLength;
                    })
                    .UseContentRoot(Directory.GetCurrentDirectory())
                    .Configure(app => app.Run(async context =>
                    {
                        await startReadingRequestBody.Task.TimeoutAfter(TimeSpan.FromSeconds(120));

                        var buffer = new byte[expectedBody.Length];
                        var bytesRead = 0;
                        while (bytesRead < buffer.Length)
                        {
                            var read = await context.Request.Body.ReadAsync(buffer, bytesRead, buffer.Length - bytesRead);
                            if (read == 0)
                            {
                                // The body ended early. Stop instead of spinning forever; the
                                // short count is reported in the response below.
                                break;
                            }

                            bytesRead += read;
                        }

                        await clientFinishedSendingRequestBody.Task.TimeoutAfter(TimeSpan.FromSeconds(120));

                        // Verify client didn't send extra bytes
                        if (await context.Request.Body.ReadAsync(new byte[1], 0, 1) != 0)
                        {
                            context.Response.StatusCode = StatusCodes.Status500InternalServerError;
                            await context.Response.WriteAsync("Client sent more bytes than expectedBody.Length");
                            return;
                        }

                        await context.Response.WriteAsync($"bytesRead: {bytesRead}");
                    }));
            })
            .ConfigureServices(AddTestLogging)
            .Build();

        try
        {
            await host.StartAsync();
        }
        catch
        {
            // The caller never receives the host when startup fails, so release it here.
            host.Dispose();
            throw;
        }

        return host;
    }
Пример #10
0
    /// <summary>
    /// Verifies that starting a host on <c>http://localhost:{port}</c> throws an
    /// <see cref="IOException"/> carrying the "endpoint already in use" message for the loopback
    /// address of <paramref name="addressFamily"/> when that address and port are already bound
    /// by another socket. The scenario is retried until it succeeds or either retry counter
    /// reaches 10.
    /// </summary>
    /// <param name="addressFamily">
    /// The address family whose loopback address is occupied before the host starts;
    /// <c>InterNetwork</c> selects <c>IPAddress.Loopback</c>, anything else selects
    /// <c>IPAddress.IPv6Loopback</c>.
    /// </param>
    private void ThrowsWhenBindingLocalhostToAddressInUse(AddressFamily addressFamily)
    {
        // NOTE(review): presumably stops the test infrastructure from failing the test on
        // critical-level logs; the member is declared outside this block, so confirm there.
        ThrowOnCriticalErrors = false;

        // Retry counters for the two conditions that cause an iteration to be repeated.
        var addressInUseCount = 0;
        var wrongMessageCount = 0;

        var address            = addressFamily == AddressFamily.InterNetwork ? IPAddress.Loopback : IPAddress.IPv6Loopback;
        var otherAddressFamily = addressFamily == AddressFamily.InterNetwork ? AddressFamily.InterNetworkV6 : AddressFamily.InterNetwork;

        while (addressInUseCount < 10 && wrongMessageCount < 10)
        {
            int port;

            // Let the OS pick a port, remember it, and release the socket again.
            using (var socket = new Socket(AddressFamily.InterNetworkV6, SocketType.Stream, ProtocolType.Tcp))
            {
                // Bind first to IPv6Any to ensure both the IPv4 and IPv6 ports are available.
                socket.Bind(new IPEndPoint(IPAddress.IPv6Any, 0));
                socket.Listen(0);
                port = ((IPEndPoint)socket.LocalEndPoint).Port;
            }

            // Occupy the chosen port on the loopback address of the family under test; this
            // socket stays open while the host tries to start.
            using (var socket = new Socket(addressFamily, SocketType.Stream, ProtocolType.Tcp))
            {
                try
                {
                    socket.Bind(new IPEndPoint(address, port));
                    socket.Listen(0);
                }
                catch (SocketException)
                {
                    // The port could not be bound again; count it and retry with a new port.
                    addressInUseCount++;
                    continue;
                }

                var hostBuilder = TransportSelector.GetHostBuilder()
                                  .ConfigureWebHost(webHostBuilder =>
                {
                    webHostBuilder
                    .UseKestrel()
                    .UseUrls($"http://localhost:{port}")
                    .Configure(ConfigureEchoAddress);
                })
                                  .ConfigureServices(AddTestLogging);

                using (var host = hostBuilder.Build())
                {
                    // Starting the host must fail because the port is held by the socket above.
                    var exception = Assert.Throws <IOException>(() => host.Start());

                    // Address strings for the family under test and for the opposite family.
                    var thisAddressString  = $"http://{(addressFamily == AddressFamily.InterNetwork ? "127.0.0.1" : "[::1]")}:{port}";
                    var otherAddressString = $"http://{(addressFamily == AddressFamily.InterNetworkV6 ? "127.0.0.1" : "[::1]")}:{port}";

                    if (exception.Message == CoreStrings.FormatEndpointAlreadyInUse(otherAddressString))
                    {
                        // Don't fail immediately, because it's possible that something else really did bind to the
                        // same port for the other address family between the IPv6Any bind above and now.
                        wrongMessageCount++;
                        continue;
                    }

                    Assert.Equal(CoreStrings.FormatEndpointAlreadyInUse(thisAddressString), exception.Message);

                    // No Critical log entry without an exception may end with the same message.
                    Assert.Equal(0, LogMessages.Count(log => log.LogLevel == LogLevel.Critical &&
                                                      log.Exception is null &&
                                                      log.Message.EndsWith(CoreStrings.FormatEndpointAlreadyInUse(thisAddressString), StringComparison.Ordinal)));
                    break;
                }
            }
        }

        // Fail explicitly when the loop ended because a retry budget was exhausted.
        if (addressInUseCount >= 10)
        {
            Assert.True(false, $"The corresponding {otherAddressFamily} address was already in use 10 times.");
        }

        if (wrongMessageCount >= 10)
        {
            Assert.True(false, $"An error for a conflict with {otherAddressFamily} was thrown 10 times.");
        }
    }