예제 #1
0
        /// <summary>
        /// Occupies a port on <paramref name="address"/> with a listening socket, then asserts that starting Kestrel on
        /// <c>http://localhost:{port}</c> throws an <see cref="IOException"/> whose message reports the loopback
        /// endpoint of <paramref name="addressFamily"/> as already in use.
        /// </summary>
        /// <param name="addressFamily">Address family of the occupying socket; selects the endpoint expected in the
        /// error message (127.0.0.1 for IPv4, [::1] otherwise).</param>
        /// <param name="address">Address the occupying socket is bound to.</param>
        private void ThrowsWhenBindingLocalhostToAddressInUse(AddressFamily addressFamily, IPAddress address)
        {
            using (var socket = new Socket(addressFamily, SocketType.Stream, ProtocolType.Tcp))
            {
                socket.Bind(new IPEndPoint(address, 0));

                // Listen so the port is unambiguously in use. A socket that is only bound may not block a second
                // bind when SO_REUSEADDR is set (e.g. on Linux), which would let Kestrel start and make this test
                // fail spuriously.
                socket.Listen(0);
                var port = ((IPEndPoint)socket.LocalEndPoint).Port;

                var hostBuilder = TransportSelector.GetWebHostBuilder()
                    .ConfigureLogging(_configureLoggingDelegate)
                    .UseKestrel()
                    .UseUrls($"http://localhost:{port}")
                    .Configure(ConfigureEchoAddress);

                using (var host = hostBuilder.Build())
                {
                    var exception = Assert.Throws<IOException>(() => host.Start());

                    // "localhost" is reported as the loopback address of the occupied family.
                    var expectedHost = addressFamily == AddressFamily.InterNetwork ? "127.0.0.1" : "[::1]";
                    Assert.Equal(
                        CoreStrings.FormatEndpointAlreadyInUse($"http://{expectedHost}:{port}"),
                        exception.Message);
                }
            }
        }
예제 #2
0
        /// <summary>
        /// Starts Kestrel on <paramref name="addressInput"/> and checks that every URL in <paramref name="testUrls"/>,
        /// suffixed with a port, is answered with that URL in its normalized <see cref="Uri"/> form.
        /// </summary>
        /// <param name="addressInput">Value passed to <c>UseUrls</c>.</param>
        /// <param name="testUrls">Base URLs, without a port, to request.</param>
        /// <param name="testPort">Port appended to each URL; 0 means "use the port the host reports".</param>
        private async Task RegisterAddresses_Success(string addressInput, string[] testUrls, int testPort = 0)
        {
            var builder = TransportSelector.GetWebHostBuilder()
                .UseKestrel()
                .ConfigureLogging(_configureLoggingDelegate)
                .UseUrls(addressInput)
                .Configure(ConfigureEchoAddress);

            using (var webHost = builder.Build())
            {
                webHost.Start();

                foreach (var baseUrl in testUrls)
                {
                    var port = testPort == 0 ? webHost.GetPort() : testPort;
                    var requestUrl = $"{baseUrl}:{port}";

                    var body = await HttpClientSlim.GetStringAsync(requestUrl, validateCertificate: false);

                    // Compare against Uri.ToString() instead of the raw URL so that IPv6 addresses with a
                    // zone index (e.g. "fe80::3%1") are normalized the same way on both sides.
                    var expected = new Uri(requestUrl).ToString();
                    Assert.Equal(expected, body);
                }
            }
        }
예제 #3
0
        /// <summary>
        /// Starts Kestrel on <c>http://{registerAddress}:0</c>, requests it through <paramref name="requestAddress"/>,
        /// and asserts the server observed <paramref name="expectAddress"/> as the remote IP and a non-empty remote port.
        /// </summary>
        /// <param name="registerAddress">Host part of the URL Kestrel listens on.</param>
        /// <param name="requestAddress">Host part of the URL the client connects to.</param>
        /// <param name="expectAddress">Remote IP address the server is expected to report.</param>
        private async Task TestRemoteIPAddress(string registerAddress, string requestAddress, string expectAddress)
        {
            var builder = TransportSelector.GetWebHostBuilder()
                .UseKestrel()
                .UseUrls($"http://{registerAddress}:0")
                .ConfigureServices(AddTestLogging)
                .Configure(app => app.Run(async context =>
                {
                    // Echo back, as JSON, what the server sees about this connection.
                    var connection = context.Connection;
                    var payload = JsonConvert.SerializeObject(new
                    {
                        RemoteIPAddress = connection.RemoteIpAddress?.ToString(),
                        RemotePort = connection.RemotePort,
                        LocalIPAddress = connection.LocalIpAddress?.ToString(),
                        LocalPort = connection.LocalPort
                    });
                    await context.Response.WriteAsync(payload);
                }));

            using (var host = builder.Build())
            using (var client = new HttpClient())
            {
                host.Start();

                var requestUri = $"http://{requestAddress}:{host.GetPort()}/";
                var response = await client.GetAsync(requestUri);
                response.EnsureSuccessStatusCode();

                var json = await response.Content.ReadAsStringAsync();
                Assert.NotEmpty(json);

                var facts = JsonConvert.DeserializeObject<JObject>(json);
                Assert.Equal(expectAddress, facts["RemoteIPAddress"].Value<string>());
                Assert.NotEmpty(facts["RemotePort"].Value<string>());
            }
        }
        /// <summary>
        /// Verifies that the same IPv4 and IPv6 loopback endpoints can be bound again by a new host after the
        /// previous host listening on them has been stopped and disposed.
        /// </summary>
        public async Task CanRebindToMultipleEndPoints()
        {
            var port = GetNextPort();
            var ipv4endPointAddress = $"http://127.0.0.1:{port}/";
            var ipv6endPointAddress = $"http://[::1]:{port}/";

            // Run two identically configured hosts back to back on the same port. Both register test logging so
            // failures in either are captured, and each host is stopped explicitly (and awaited) before disposal
            // so its listeners are released before the next host binds.
            for (var hostNumber = 1; hostNumber <= 2; hostNumber++)
            {
                var hostBuilder = TransportSelector.GetWebHostBuilder()
                    .ConfigureServices(AddTestLogging)
                    .UseKestrel(options =>
                    {
                        options.Listen(IPAddress.Loopback, port);
                        options.Listen(IPAddress.IPv6Loopback, port);
                    })
                    .Configure(ConfigureEchoAddress);

                using (var host = hostBuilder.Build())
                {
                    host.Start();

                    Assert.Equal(ipv4endPointAddress, await HttpClientSlim.GetStringAsync(ipv4endPointAddress));
                    Assert.Equal(ipv6endPointAddress, await HttpClientSlim.GetStringAsync(ipv6endPointAddress));

                    await host.StopAsync();
                }
            }
        }
예제 #5
0
        /// <summary>
        /// Starts a host, spawns a child process while it is listening, stops the host, and verifies that the port no
        /// longer accepts connections. If the child had inherited the listen handle, the connection would succeed.
        /// </summary>
        public async Task SpawnChildProcess_DoesNotInheritListenHandle()
        {
            var hostBuilder = TransportSelector.GetWebHostBuilder()
                .UseKestrel()
                .ConfigureServices(AddTestLogging)
                .UseUrls("http://127.0.0.1:0")
                .Configure(app => app.Run(context => context.Response.WriteAsync("Hello World")));

            using (var host = hostBuilder.Build())
            {
                await host.StartAsync();

                var processInfo = new ProcessStartInfo
                {
                    FileName = "cmd.exe",
                    CreateNoWindow = true,
                };

                using (var process = Process.Start(processInfo))
                {
                    try
                    {
                        var port = host.GetPort();
                        await host.StopAsync();

                        // We should not be able to connect if the handle was correctly closed and not inherited by
                        // the child process.
                        using (var client = new TcpClient())
                        {
                            await Assert.ThrowsAnyAsync<SocketException>(() => client.ConnectAsync("127.0.0.1", port));
                        }
                    }
                    finally
                    {
                        // Always terminate the child, even when the assertion above fails, so the process is not
                        // leaked. Skip Kill if it already exited so a failure here cannot mask the test's own error.
                        if (!process.HasExited)
                        {
                            process.Kill();
                        }
                    }
                }
            }
        }
예제 #6
0
        /// <summary>
        /// With <c>PreferHostingUrls(false)</c>, the HTTPS endpoint configured through <c>UseKestrel</c> must be the
        /// only bound address (instead of the <c>UseUrls</c> address), a single override warning must be logged, and
        /// the HTTPS endpoint must serve requests.
        /// </summary>
        public async Task DoesNotOverrideDirectConfigurationWithIServerAddressesFeature_IfPreferHostingUrlsFalse()
        {
            var useUrlsAddress = "http://127.0.0.1:0";

            var hostBuilder = TransportSelector.GetWebHostBuilder()
                .ConfigureServices(AddTestLogging)
                .UseKestrel(options =>
                {
                    options.Listen(new IPEndPoint(IPAddress.Loopback, 0), listenOptions =>
                    {
                        listenOptions.UseHttps(TestResources.TestCertificatePath, "testPassword");
                    });
                })
                // Use the same variable the warning assertion is built from so the two cannot drift apart.
                .UseUrls(useUrlsAddress)
                .PreferHostingUrls(false)
                .Configure(ConfigureEchoAddress);

            using (var host = hostBuilder.Build())
            {
                await host.StartAsync();

                var port = host.GetPort();
                var endPointAddress = $"https://127.0.0.1:{port}";

                // If this isn't working properly, we'll get the HTTP endpoint defined in UseUrls
                // instead of the HTTPS endpoint defined in UseKestrel.
                var serverAddresses = host.ServerFeatures.Get<IServerAddressesFeature>().Addresses;
                var boundAddress = Assert.Single(serverAddresses);
                Assert.Equal(endPointAddress, boundAddress);

                Assert.Single(TestApplicationErrorLogger.Messages, log => log.LogLevel == LogLevel.Warning &&
                    string.Equals(CoreStrings.FormatOverridingWithKestrelOptions(useUrlsAddress, "UseKestrel()"),
                        log.Message, StringComparison.Ordinal));

                Assert.Equal(new Uri(endPointAddress).ToString(), await HttpClientSlim.GetStringAsync(endPointAddress, validateCertificate: false));
                await host.StopAsync();
            }
        }
예제 #7
0
        /// <summary>
        /// With <c>PreferHostingUrls(true)</c>, the HTTP address from <c>UseUrls</c> must replace the HTTPS endpoint
        /// configured through <c>UseKestrel</c>, a single informational override message must be logged, and the
        /// HTTP address must serve requests.
        /// </summary>
        public async Task OverrideDirectConfigurationWithIServerAddressesFeature_Succeeds()
        {
            var useUrlsAddress = "http://127.0.0.1:0";
            var hostBuilder = TransportSelector.GetWebHostBuilder()
                .UseKestrel(options =>
                {
                    options.Listen(new IPEndPoint(IPAddress.Loopback, 0), listenOptions =>
                    {
                        listenOptions.UseHttps(TestResources.GetTestCertificate());
                    });
                })
                .UseUrls(useUrlsAddress)
                .PreferHostingUrls(true)
                .ConfigureServices(AddTestLogging)
                .Configure(ConfigureEchoAddress);

            using (var host = hostBuilder.Build())
            {
                await host.StartAsync();

                var port = host.GetPort();
                var useUrlsAddressWithPort = $"http://127.0.0.1:{port}";

                // If this isn't working properly, we'll get the HTTPS endpoint defined in UseKestrel
                // instead of the HTTP endpoint defined in UseUrls.
                var serverAddresses = host.ServerFeatures.Get<IServerAddressesFeature>().Addresses;
                var boundAddress = Assert.Single(serverAddresses);
                Assert.Equal(useUrlsAddressWithPort, boundAddress);

                Assert.Single(TestApplicationErrorLogger.Messages, log => log.LogLevel == LogLevel.Information &&
                    string.Equals(CoreStrings.FormatOverridingWithPreferHostingUrls(nameof(IServerAddressesFeature.PreferHostingUrls), useUrlsAddress),
                        log.Message, StringComparison.Ordinal));

                Assert.Equal(new Uri(useUrlsAddressWithPort).ToString(), await HttpClientSlim.GetStringAsync(useUrlsAddressWithPort));

                await host.StopAsync();
            }
        }
예제 #8
0
        /// <summary>
        /// Occupies an IPv6 loopback port with a listening socket and asserts that starting Kestrel on that port throws
        /// an <see cref="IOException"/> carrying the "endpoint already in use" message for that URL.
        /// </summary>
        public void ThrowsWhenBindingToIPv6AddressInUse()
        {
            // Starting the host is expected to fail with an IOException; add it to the test error logger's
            // ignore list.
            TestApplicationErrorLogger.IgnoredExceptions.Add(typeof(IOException));

            using (var blockingSocket = new Socket(AddressFamily.InterNetworkV6, SocketType.Stream, ProtocolType.Tcp))
            {
                blockingSocket.Bind(new IPEndPoint(IPAddress.IPv6Loopback, 0));
                blockingSocket.Listen(0);

                var occupiedPort = ((IPEndPoint)blockingSocket.LocalEndPoint).Port;
                var occupiedUrl = $"http://[::1]:{occupiedPort}";

                var hostBuilder = TransportSelector.GetWebHostBuilder()
                    .ConfigureServices(AddTestLogging)
                    .UseKestrel()
                    .UseUrls(occupiedUrl)
                    .Configure(ConfigureEchoAddress);

                using (var host = hostBuilder.Build())
                {
                    var exception = Assert.Throws<IOException>(() => host.Start());
                    Assert.Equal(CoreStrings.FormatEndpointAlreadyInUse(occupiedUrl), exception.Message);
                }
            }
        }
        /// <summary>
        /// Verifies that a client offering only TLS 1.0 fails the handshake with an <see cref="IOException"/> and that
        /// the server logs event id 1 at <see cref="LogLevel.Debug"/>.
        /// </summary>
        public async Task ClientAttemptingToUseUnsupportedProtocolIsLoggedAsDebug()
        {
            var loggerProvider = new HandshakeErrorLoggerProvider();
            LoggerFactory.AddProvider(loggerProvider);

            var hostBuilder = TransportSelector.GetWebHostBuilder()
                .UseKestrel(options =>
                {
                    options.Listen(new IPEndPoint(IPAddress.Loopback, 0), listenOptions =>
                    {
                        listenOptions.UseHttps(TestResources.TestCertificatePath, "testPassword");
                    });
                })
                .ConfigureServices(AddTestLogging)
                .Configure(app => app.Run(httpContext => Task.CompletedTask));

            using (var host = hostBuilder.Build())
            {
                host.Start();

                using (var socket = await HttpClientSlim.GetSocket(new Uri($"https://127.0.0.1:{host.GetPort()}/")))
                using (var stream = new NetworkStream(socket, ownsSocket: false))
                using (var sslStream = new SslStream(stream, true, (sender, certificate, chain, errors) => true))
                {
                    // SslProtocols.Tls is TLS 1.0 which isn't supported by Kestrel by default.
                    await Assert.ThrowsAsync<IOException>(() =>
                        sslStream.AuthenticateAsClientAsync("127.0.0.1", clientCertificates: null,
                            enabledSslProtocols: SslProtocols.Tls,
                            checkCertificateRevocation: false));
                }
            }

            await loggerProvider.FilterLogger.LogTcs.Task.TimeoutAfter(TestConstants.DefaultTimeout);

            // Compare the numeric Id explicitly instead of relying on an implicit int conversion of the event id.
            Assert.Equal(1, loggerProvider.FilterLogger.LastEventId.Id);
            Assert.Equal(LogLevel.Debug, loggerProvider.FilterLogger.LastLogLevel);
        }
        /// <summary>
        /// Registers connection logging both before and after the HTTPS adapter on one endpoint and verifies that an
        /// HTTPS request still receives the full "Hello World!" response within 10 seconds.
        /// </summary>
        public async Task LoggingConnectionAdapterCanBeAddedBeforeAndAfterHttpsAdapter()
        {
            var hostBuilder = TransportSelector.GetWebHostBuilder()
                .ConfigureLogging(builder =>
                {
                    builder.SetMinimumLevel(LogLevel.Trace);
                    builder.AddXunit(_output);
                })
                .UseKestrel(options =>
                {
                    options.Listen(new IPEndPoint(IPAddress.Loopback, 0), listenOptions =>
                    {
                        // Logging wraps the connection on both sides of the TLS layer.
                        listenOptions.UseConnectionLogging();
                        listenOptions.UseHttps(TestResources.TestCertificatePath, "testPassword");
                        listenOptions.UseConnectionLogging();
                    });
                })
                .Configure(app => app.Run(context =>
                {
                    context.Response.ContentLength = 12;
                    return context.Response.WriteAsync("Hello World!");
                }));

            using (var host = hostBuilder.Build())
            {
                await host.StartAsync();

                var url = $"https://localhost:{host.GetPort()}/";
                var body = await HttpClientSlim.GetStringAsync(url, validateCertificate: false)
                    .TimeoutAfter(TimeSpan.FromSeconds(10));

                Assert.Equal("Hello World!", body);
            }
        }
예제 #11
0
        /// <summary>
        /// Builds and starts a Kestrel web host serving <paramref name="app"/> that reuses the server options instance
        /// held by <paramref name="context"/>.
        /// </summary>
        /// <param name="app">Request delegate served by the host.</param>
        /// <param name="context">Supplies the memory pool factory, logger factory, logger and server options.</param>
        /// <param name="configureKestrel">Configures Kestrel; it must register at least one listen endpoint, because the
        /// first one is captured into <c>_listenOptions</c> (<c>First()</c> throws otherwise).</param>
        /// <param name="configureServices">Extra service registrations, applied after this server's own.</param>
        public TestServer(RequestDelegate app, TestServiceContext context, Action<KestrelServerOptions> configureKestrel, Action<IServiceCollection> configureServices)
        {
            _app = app;
            Context = context;

            _host = TransportSelector.GetWebHostBuilder(context.MemoryPoolFactory, context.ServerOptions.Limits.MaxRequestBufferSize)
                .UseKestrel(options =>
                {
                    configureKestrel(options);
                    _listenOptions = options.ListenOptions.First();
                })
                .ConfigureServices(services =>
                {
                    services.AddSingleton<IStartup>(this);
                    services.AddSingleton(context.LoggerFactory);
                    services.AddSingleton<IServer>(sp =>
                    {
                        // Manually configure options on the TestServiceContext.
                        // We're doing this so we can use the same instance that was passed in
                        var configureOptions = sp.GetServices<IConfigureOptions<KestrelServerOptions>>();
                        foreach (var c in configureOptions)
                        {
                            c.Configure(context.ServerOptions);
                        }

                        return new KestrelServer(sp.GetRequiredService<IConnectionListenerFactory>(), context);
                    });
                    configureServices(services);
                })
                .UseSetting(WebHostDefaults.ApplicationKey, typeof(TestServer).GetTypeInfo().Assembly.FullName)
                // The setting is machine-read (presumably parsed as a number by the host), so format it with the
                // invariant culture rather than the current culture.
                .UseSetting(WebHostDefaults.ShutdownTimeoutKey, TestConstants.DefaultTimeout.TotalSeconds.ToString(System.Globalization.CultureInfo.InvariantCulture))
                .Build();

            _host.Start();

            Context.Log.LogDebug($"TestServer is listening on port {Port}");
        }
예제 #12
0
        /// <summary>
        /// Opens a TCP connection to an HTTPS endpoint with a 1-second handshake timeout and never starts TLS. Verifies
        /// that the server closes the connection without sending data, and that event id 2 is logged at
        /// <see cref="LogLevel.Debug"/>.
        /// </summary>
        public async Task HandshakeTimesOutAndIsLoggedAsDebug()
        {
            var loggerProvider = new HandshakeErrorLoggerProvider();
            LoggerFactory.AddProvider(loggerProvider);

            var hostBuilder = TransportSelector.GetWebHostBuilder()
                .UseKestrel(options =>
                {
                    options.Listen(new IPEndPoint(IPAddress.Loopback, 0), listenOptions =>
                    {
                        listenOptions.UseHttps(o =>
                        {
                            o.ServerCertificate = new X509Certificate2(TestResources.TestCertificatePath, "testPassword");
                            o.HandshakeTimeout = TimeSpan.FromSeconds(1);
                        });
                    });
                })
                .ConfigureServices(AddTestLogging)
                .Configure(app => app.Run(httpContext => Task.CompletedTask));

            using (var host = hostBuilder.Build())
            {
                host.Start();

                using (var socket = await HttpClientSlim.GetSocket(new Uri($"https://127.0.0.1:{host.GetPort()}/")))
                using (var stream = new NetworkStream(socket, ownsSocket: false))
                {
                    // No data should be sent, and the server should close the connection (read returns 0)
                    // well before the test timeout expires.
                    Assert.Equal(0, await stream.ReadAsync(new byte[1], 0, 1).TimeoutAfter(TestConstants.DefaultTimeout));
                }
            }

            await loggerProvider.FilterLogger.LogTcs.Task.TimeoutAfter(TestConstants.DefaultTimeout);

            // Compare the numeric Id explicitly instead of relying on an implicit int conversion of the event id.
            Assert.Equal(2, loggerProvider.FilterLogger.LastEventId.Id);
            Assert.Equal(LogLevel.Debug, loggerProvider.FilterLogger.LastLogLevel);
        }
예제 #13
0
        /// <summary>
        /// Starts Kestrel without explicit URLs and verifies that it binds port 5000 (plus 5001 when
        /// <paramref name="mockHttps"/> is set), logs exactly one Debug message about binding to the default
        /// address(es), and answers each of <paramref name="addresses"/> with its normalized URL.
        /// </summary>
        /// <param name="addresses">URLs to request once the host is running.</param>
        /// <param name="mockHttps">When true, a default certificate is configured so an HTTPS default is also bound.</param>
        private async Task RegisterDefaultServerAddresses_Success(IEnumerable<string> addresses, bool mockHttps = false)
        {
            var hostBuilder = TransportSelector.GetWebHostBuilder()
                .ConfigureServices(AddTestLogging)
                .UseKestrel(options =>
                {
                    if (!mockHttps)
                    {
                        return;
                    }

                    options.DefaultCertificate = TestResources.GetTestCertificate();
                })
                .Configure(ConfigureEchoAddress);

            using (var host = hostBuilder.Build())
            {
                await host.StartAsync();

                Assert.Equal(5000, host.GetPort());
                if (mockHttps)
                {
                    Assert.Contains(5001, host.GetPorts());
                }

                // Either the HTTP+HTTPS or the HTTP-only default binding message is acceptable.
                var httpAndHttpsMessage = CoreStrings.FormatBindingToDefaultAddresses(Constants.DefaultServerAddress, Constants.DefaultServerHttpsAddress);
                var httpOnlyMessage = CoreStrings.FormatBindingToDefaultAddress(Constants.DefaultServerAddress);
                Assert.Single(TestApplicationErrorLogger.Messages, log =>
                    log.LogLevel == LogLevel.Debug &&
                    (string.Equals(httpAndHttpsMessage, log.Message, StringComparison.Ordinal) ||
                     string.Equals(httpOnlyMessage, log.Message, StringComparison.Ordinal)));

                foreach (var address in addresses)
                {
                    var body = await HttpClientSlim.GetStringAsync(address, validateCertificate: false);
                    Assert.Equal(new Uri(address).ToString(), body);
                }

                await host.StopAsync();
            }
        }
예제 #14
0
        /// <summary>
        /// Writes ten zero bytes instead of a TLS handshake to an HTTPS endpoint. Verifies that event id 1 is logged at
        /// <see cref="LogLevel.Debug"/> and that the error logger recorded no errors.
        /// </summary>
        public async Task ClientHandshakeFailureLoggedAsDebug()
        {
            var loggerProvider = new HandshakeErrorLoggerProvider();
            LoggerFactory.AddProvider(loggerProvider);

            var hostBuilder = TransportSelector.GetWebHostBuilder()
                .UseKestrel(options =>
                {
                    options.Listen(new IPEndPoint(IPAddress.Loopback, 0), listenOptions =>
                    {
                        listenOptions.UseHttps(TestResources.TestCertificatePath, "testPassword");
                    });
                })
                .ConfigureServices(AddTestLogging)
                .Configure(app => { });

            using (var host = hostBuilder.Build())
            {
                host.Start();

                var serverUri = new Uri($"https://127.0.0.1:{host.GetPort()}/");
                using (var socket = await HttpClientSlim.GetSocket(serverUri))
                using (var stream = new NetworkStream(socket))
                {
                    // Send null bytes; leaving the using blocks closes the connection.
                    await stream.WriteAsync(new byte[10], 0, 10);
                }

                await loggerProvider.FilterLogger.LogTcs.Task.TimeoutAfter(TestConstants.DefaultTimeout);
            }

            Assert.Equal(1, loggerProvider.FilterLogger.LastEventId.Id);
            Assert.Equal(LogLevel.Debug, loggerProvider.FilterLogger.LastLogLevel);

            var noErrorsLogged = loggerProvider.ErrorLogger.TotalErrorsLogged == 0;
            Assert.True(noErrorsLogged,
                userMessage: string.Join(Environment.NewLine, loggerProvider.ErrorLogger.ErrorMessages));
        }
예제 #15
0
        /// <summary>
        /// Starts Kestrel, with a test certificate as the HTTPS default, on <paramref name="addressInput"/>. Checks that
        /// each URL in <paramref name="testUrls"/>, suffixed with a port, is answered with that URL, minus any IPv6 scope id.
        /// </summary>
        /// <param name="addressInput">Value passed to <c>UseUrls</c>.</param>
        /// <param name="testUrls">Base URLs, without a port, to request.</param>
        /// <param name="testPort">Port appended to each URL; 0 means "use the port the host reports".</param>
        private async Task RegisterAddresses_Success(string addressInput, string[] testUrls, int testPort = 0)
        {
            var hostBuilder = TransportSelector.GetWebHostBuilder()
                .UseKestrel(serverOptions =>
                    serverOptions.ConfigureHttpsDefaults(httpsOptions =>
                        httpsOptions.ServerCertificate = TestResources.GetTestCertificate()))
                .ConfigureServices(AddTestLogging)
                .UseUrls(addressInput)
                .Configure(ConfigureEchoAddress);

            using (var host = hostBuilder.Build())
            {
                host.Start();

                foreach (var baseUrl in testUrls)
                {
                    var requestUrl = $"{baseUrl}:{(testPort == 0 ? host.GetPort() : testPort)}";
                    var body = await HttpClientSlim.GetStringAsync(requestUrl, validateCertificate: false);

                    // The IPv6 scope id (e.g. the "%1" in "fe80::3%1") is not sent over the wire, so drop it
                    // from the expected value. See https://github.com/aspnet/Common/pull/369
                    var expected = new Uri(requestUrl);
                    if (expected.HostNameType == UriHostNameType.IPv6)
                    {
                        expected = WithoutScopeId(expected);
                    }

                    Assert.Equal(expected.ToString(), body);
                }
            }

            // Rebuilds the URI from the raw IPv6 address bytes, which carries no scope id.
            Uri WithoutScopeId(Uri uri)
            {
                var uriBuilder = new UriBuilder(uri);
                var ip = IPAddress.Parse(uriBuilder.Host);
                uriBuilder.Host = new IPAddress(ip.GetAddressBytes()).ToString();
                return uriBuilder.Uri;
            }
        }
        /// <summary>
        /// Starts Kestrel on an <c>http://unix:/...</c> URL and verifies that a raw HTTP/1.1 GET sent over the Unix
        /// domain socket gets a 200 status line back.
        /// </summary>
        public async Task TestUnixDomainSocketWithUrl()
        {
            var path = Path.GetTempFileName();
            var url = $"http://unix:/{path}";

            // GetTempFileName creates the file; remove it so the socket can be created at that path.
            Delete(path);

            try
            {
                var hostBuilder = TransportSelector.GetWebHostBuilder()
                    .UseUrls(url)
                    .UseKestrel()
                    .ConfigureServices(AddTestLogging)
                    .Configure(app =>
                    {
                        app.Run(async context =>
                        {
                            await context.Response.WriteAsync("Hello World");
                        });
                    });

                using (var host = hostBuilder.Build())
                {
                    await host.StartAsync().DefaultTimeout();

                    // https://github.com/dotnet/corefx/issues/5999
                    // .NET Core HttpClient does not support unix sockets, it's difficult to parse raw response data. below is a little hacky way.
                    using (var socket = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified))
                    {
                        await socket.ConnectAsync(new UnixDomainSocketEndPoint(path)).DefaultTimeout();

                        var httpRequest = Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\nHost:\r\nConnection: close\r\n\r\n");
                        await socket.SendAsync(httpRequest, SocketFlags.None).DefaultTimeout();

                        // Read until the server closes the connection (Connection: close) or the buffer is full.
                        // Stopping at a full buffer avoids issuing a zero-length receive. Only the status line is
                        // inspected, so a truncated body does not matter.
                        var readBuffer = new byte[512];
                        var read = 0;
                        while (read < readBuffer.Length)
                        {
                            var bytesReceived = await socket.ReceiveAsync(readBuffer.AsMemory(read), SocketFlags.None).DefaultTimeout();
                            if (bytesReceived <= 0)
                            {
                                break;
                            }

                            read += bytesReceived;
                        }

                        // Status line is "HTTP/1.1 <code> <reason>"; extract the code between the first two spaces.
                        var httpResponse = Encoding.ASCII.GetString(readBuffer, 0, read);
                        int httpStatusStart = httpResponse.IndexOf(' ') + 1;
                        int httpStatusEnd = httpResponse.IndexOf(' ', httpStatusStart);

                        var httpStatus = int.Parse(httpResponse.Substring(httpStatusStart, httpStatusEnd - httpStatusStart));
                        Assert.Equal(StatusCodes.Status200OK, httpStatus);
                    }
                    await host.StopAsync().DefaultTimeout();
                }
            }
            finally
            {
                Delete(path);
            }
        }
예제 #17
0
        /// <summary>
        /// Streams a response over TLS until the client goes away mid-response, and verifies that the error logger did
        /// not record an <c>ObjectDisposedException</c> as a result.
        /// </summary>
        public async Task DoesNotThrowObjectDisposedExceptionOnConnectionAbort()
        {
            var loggerProvider = new HandshakeErrorLoggerProvider();
            LoggerFactory.AddProvider(loggerProvider);

            var hostBuilder = TransportSelector.GetWebHostBuilder()
                .UseKestrel(options =>
                {
                    options.Listen(new IPEndPoint(IPAddress.Loopback, 0), listenOptions =>
                    {
                        listenOptions.UseHttps(TestResources.TestCertificatePath, "testPassword");
                    });
                })
                .ConfigureServices(AddTestLogging)
                .ConfigureLogging(builder => builder.AddProvider(loggerProvider))
                .Configure(app => app.Run(async httpContext =>
                {
                    // Keep writing until the request is aborted.
                    var ct = httpContext.RequestAborted;
                    while (!ct.IsCancellationRequested)
                    {
                        try
                        {
                            await httpContext.Response.WriteAsync("hello, world", ct);
                            await Task.Delay(1000, ct);
                        }
                        catch (OperationCanceledException)
                        {
                            // Don't regard connection abort as an error. Catch the base type: Task.Delay throws
                            // TaskCanceledException, but APIs observing a canceled token may throw a plain
                            // OperationCanceledException, which would otherwise escape the request delegate.
                        }
                    }
                }));

            using (var host = hostBuilder.Build())
            {
                host.Start();

                using (var socket = await HttpClientSlim.GetSocket(new Uri($"https://127.0.0.1:{host.GetPort()}/")))
                using (var stream = new NetworkStream(socket, ownsSocket: false))
                using (var sslStream = new SslStream(stream, true, (sender, certificate, chain, errors) => true))
                {
                    await sslStream.AuthenticateAsClientAsync("127.0.0.1", clientCertificates: null,
                        enabledSslProtocols: SslProtocols.Tls11 | SslProtocols.Tls12,
                        checkCertificateRevocation: false);

                    var request = Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\nHost:\r\n\r\n");
                    await sslStream.WriteAsync(request, 0, request.Length);

                    // Temporary workaround for a deadlock when reading from an aborted client SslStream on Mac and Linux.
                    if (TestPlatformHelper.IsWindows)
                    {
                        await sslStream.ReadAsync(new byte[32], 0, 32);
                    }
                    else
                    {
                        await stream.ReadAsync(new byte[32], 0, 32);
                    }
                }
            }

            Assert.False(loggerProvider.ErrorLogger.ObjectDisposedExceptionLogged);
        }
        /// <summary>
        /// Builds and starts a host with one loopback endpoint whose application reads the request body only after
        /// <paramref name="startReadingRequestBody"/> completes, then validates it against <paramref name="expectedBody"/>.
        /// </summary>
        /// <param name="maxRequestBufferSize">Value for <c>Limits.MaxRequestBufferSize</c>; when it is smaller than the
        /// request-line or request-headers size limit, that limit is lowered to match.</param>
        /// <param name="expectedBody">Exact bytes the request body must contain.</param>
        /// <param name="useConnectionAdapter">Whether to add a <c>PassThroughConnectionAdapter</c> to the endpoint.</param>
        /// <param name="startReadingRequestBody">The application waits (up to 120 s) for this task before reading the body.</param>
        /// <param name="clientFinishedSendingRequestBody">The application waits (up to 120 s) for this task before checking
        /// that no extra bytes were sent.</param>
        /// <returns>The started host. Its application responds 500 with a diagnostic message when the body is shorter
        /// than, longer than, or different from <paramref name="expectedBody"/>. Otherwise it writes
        /// <c>bytesRead: {n}</c>.</returns>
        private IWebHost StartWebHost(long? maxRequestBufferSize,
                                      byte[] expectedBody,
                                      bool useConnectionAdapter,
                                      TaskCompletionSource<object> startReadingRequestBody,
                                      TaskCompletionSource<object> clientFinishedSendingRequestBody)
        {
            var host = TransportSelector.GetWebHostBuilder()
                .ConfigureLogging(_configureLoggingDelegate)
                .UseKestrel(options =>
                {
                    options.Listen(new IPEndPoint(IPAddress.Loopback, 0), listenOptions =>
                    {
                        if (useConnectionAdapter)
                        {
                            listenOptions.ConnectionAdapters.Add(new PassThroughConnectionAdapter());
                        }
                    });

                    options.Limits.MaxRequestBufferSize = maxRequestBufferSize;

                    // A request line or header block larger than the buffer could never be buffered, so cap those
                    // limits at the buffer size. The casts are safe because the value is below an int limit here.
                    if (maxRequestBufferSize.HasValue &&
                        maxRequestBufferSize.Value < options.Limits.MaxRequestLineSize)
                    {
                        options.Limits.MaxRequestLineSize = (int)maxRequestBufferSize;
                    }

                    if (maxRequestBufferSize.HasValue &&
                        maxRequestBufferSize.Value < options.Limits.MaxRequestHeadersTotalSize)
                    {
                        options.Limits.MaxRequestHeadersTotalSize = (int)maxRequestBufferSize;
                    }

                    // Disable the minimum body data rate; presumably so the deliberately delayed read is not treated
                    // as a slow client.
                    options.Limits.MinRequestBodyDataRate = null;
                })
                .UseContentRoot(Directory.GetCurrentDirectory())
                .Configure(app => app.Run(async context =>
                {
                    await startReadingRequestBody.Task.TimeoutAfter(TimeSpan.FromSeconds(120));

                    var buffer = new byte[expectedBody.Length];
                    var bytesRead = 0;
                    while (bytesRead < buffer.Length)
                    {
                        var count = await context.Request.Body.ReadAsync(buffer, bytesRead, buffer.Length - bytesRead);
                        if (count == 0)
                        {
                            // The body ended early. Without this check, every further read would also return 0 and
                            // the loop would spin forever instead of reporting the failure.
                            context.Response.StatusCode = StatusCodes.Status500InternalServerError;
                            await context.Response.WriteAsync("Client sent fewer bytes than expectedBody.Length");
                            return;
                        }

                        bytesRead += count;
                    }

                    await clientFinishedSendingRequestBody.Task.TimeoutAfter(TimeSpan.FromSeconds(120));

                    // Verify client didn't send extra bytes
                    if (await context.Request.Body.ReadAsync(new byte[1], 0, 1) != 0)
                    {
                        context.Response.StatusCode = StatusCodes.Status500InternalServerError;
                        await context.Response.WriteAsync("Client sent more bytes than expectedBody.Length");
                        return;
                    }

                    // Verify bytes received match expectedBody
                    for (int i = 0; i < expectedBody.Length; i++)
                    {
                        if (buffer[i] != expectedBody[i])
                        {
                            context.Response.StatusCode = StatusCodes.Status500InternalServerError;
                            await context.Response.WriteAsync($"Bytes received do not match expectedBody at position {i}");
                            return;
                        }
                    }

                    await context.Response.WriteAsync($"bytesRead: {bytesRead}");
                }))
                .Build();

            host.Start();

            return host;
        }
예제 #19
0
        /// <summary>
        /// End-to-end echo test over a Unix domain socket. Kestrel listens on a temporary path with an
        /// echo connection handler. A client sends "Hello World" and asserts the same bytes come back.
        /// The test then waits for the server-side connection loop to finish before stopping the host.
        /// </summary>
        public async Task TestUnixDomainSocket()
        {
            // Path.GetTempFileName creates an empty file. Remove it so nothing exists at the path
            // when Kestrel binds the Unix socket there.
            var path = Path.GetTempFileName();

            Delete(path);

            try
            {
                // Must be TaskCreationOptions. Passing a TaskContinuationOptions value binds to the
                // TaskCompletionSource(object state) overload and the option is silently ignored (CA2247).
                var serverConnectionCompletedTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);

                // Connection handler: writes back every chunk it reads until the input completes or the
                // read is cancelled. It signals serverConnectionCompletedTcs when it exits.
                async Task EchoServer(ConnectionContext connection)
                {
                    // For graceful shutdown
                    var notificationFeature = connection.Features.Get<IConnectionLifetimeNotificationFeature>();

                    try
                    {
                        while (true)
                        {
                            var result = await connection.Transport.Input.ReadAsync(notificationFeature.ConnectionClosedRequested);

                            if (result.IsCompleted)
                            {
                                Logger.LogDebug("Application receive loop ending for connection {connectionId}.", connection.ConnectionId);
                                break;
                            }

                            await connection.Transport.Output.WriteAsync(result.Buffer.ToArray());

                            connection.Transport.Input.AdvanceTo(result.Buffer.End);
                        }
                    }
                    catch (OperationCanceledException)
                    {
                        Logger.LogDebug("Graceful shutdown triggered for {connectionId}.", connection.ConnectionId);
                    }
                    finally
                    {
                        serverConnectionCompletedTcs.TrySetResult(null);
                    }
                }

                var hostBuilder = TransportSelector.GetWebHostBuilder()
                    .UseKestrel(o =>
                    {
                        o.ListenUnixSocket(path, builder =>
                        {
                            builder.Run(EchoServer);
                        });
                    })
                    .ConfigureServices(AddTestLogging)
                    .Configure(c => { });

                using (var host = hostBuilder.Build())
                {
                    await host.StartAsync();

                    using (var socket = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified))
                    {
                        await socket.ConnectAsync(new UnixDomainSocketEndPoint(path));

                        var data = Encoding.ASCII.GetBytes("Hello World");
                        await socket.SendAsync(data, SocketFlags.None);

                        var buffer = new byte[data.Length];
                        var read = 0;
                        while (read < data.Length)
                        {
                            var received = await socket.ReceiveAsync(buffer.AsMemory(read, buffer.Length - read), SocketFlags.None);
                            if (received == 0)
                            {
                                // The peer closed before echoing everything. Stop instead of spinning
                                // forever; the assertion below reports the mismatch.
                                break;
                            }

                            read += received;
                        }

                        Assert.Equal(data, buffer);
                    }

                    // Wait for the server to complete the loop because of the FIN sent when the client
                    // socket above was disposed.
                    await serverConnectionCompletedTcs.Task.DefaultTimeout();

                    await host.StopAsync();
                }
            }
            finally
            {
                Delete(path);
            }
        }
예제 #20
0
        /// <summary>
        /// Streams a <paramref name="contentLength"/>-byte POST body to Kestrel over a raw TCP socket in
        /// fixed 1 MiB chunks. Asserts that the server reports having read every byte.
        /// </summary>
        /// <param name="contentLength">Total body size; must be a whole multiple of the chunk size.</param>
        /// <param name="checkBytes">
        /// When true, the client fills each chunk with the repeating pattern 0..255. The server then checks
        /// every received byte against its absolute offset in the body.
        /// </param>
        public async Task LargeUpload(long contentLength, bool checkBytes)
        {
            const int bufferLength = 1024 * 1024;

            Assert.True(contentLength % bufferLength == 0, $"{nameof(contentLength)} sent must be evenly divisible by {bufferLength}.");
            Assert.True(bufferLength % 256 == 0, $"{nameof(bufferLength)} must be evenly divisible by 256");

            var builder = TransportSelector.GetWebHostBuilder()
                .ConfigureServices(AddTestLogging)
                .UseKestrel(options =>
                {
                    // Let the whole body through and turn off the minimum body data-rate check.
                    options.Limits.MaxRequestBodySize = contentLength;
                    options.Limits.MinRequestBodyDataRate = null;
                })
                .UseUrls("http://127.0.0.1:0/")
                .Configure(app => app.Run(async context =>
                {
                    // Drain the request body, tracking the absolute offset of each chunk.
                    var chunk = new byte[bufferLength];
                    long bodyOffset = 0;

                    while (true)
                    {
                        var count = await context.Request.Body.ReadAsync(chunk, 0, chunk.Length);
                        if (count <= 0)
                        {
                            break;
                        }

                        if (checkBytes)
                        {
                            for (var offset = 0; offset < count; offset++)
                            {
                                // Assert.Equal is too slow for this per-byte hot path, so use Assert.True.
                                Assert.True((byte)((bodyOffset + offset) % 256) == chunk[offset], "Data received is incorrect");
                            }
                        }

                        bodyOffset += count;
                    }

                    await context.Response.WriteAsync($"bytesRead: {bodyOffset.ToString()}");
                }));

            using (var host = builder.Build())
            {
                host.Start();

                using (var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
                {
                    socket.Connect(new IPEndPoint(IPAddress.Loopback, host.GetPort()));
                    socket.Send(Encoding.ASCII.GetBytes("POST / HTTP/1.1\r\nHost: \r\n"));
                    socket.Send(Encoding.ASCII.GetBytes($"Content-Length: {contentLength}\r\n\r\n"));

                    // One reusable chunk. It stays zero-filled unless byte checking is requested.
                    var payload = new byte[bufferLength];
                    if (checkBytes)
                    {
                        for (var index = 0; index < payload.Length; index++)
                        {
                            payload[index] = (byte)index;
                        }
                    }

                    var chunkCount = contentLength / payload.Length;
                    for (long sentChunks = 0; sentChunks < chunkCount; sentChunks++)
                    {
                        socket.Send(payload);
                    }

                    using (var stream = new NetworkStream(socket))
                    {
                        await AssertStreamContains(stream, $"bytesRead: {contentLength}");
                    }
                }
            }
        }
예제 #21
0
        /// <summary>
        /// Builds and starts a Kestrel host on an ephemeral loopback port. Its single endpoint reads exactly
        /// <paramref name="expectedBody"/>.Length request-body bytes and then replies with "bytesRead: N".
        /// If the body is shorter or longer than expected, it replies with a 500 and a diagnostic message instead.
        /// </summary>
        /// <param name="maxRequestBufferSize">
        /// Applied to Limits.MaxRequestBufferSize. When set and smaller than the request-line or total-header
        /// limits, those limits are lowered to the same value.
        /// </param>
        /// <param name="expectedBody">Only its length is used: the number of body bytes the handler reads.</param>
        /// <param name="useConnectionAdapter">When true, calls UsePassThrough() on the listen options.</param>
        /// <param name="startReadingRequestBody">The handler waits (up to 120s) on this before reading the body.</param>
        /// <param name="clientFinishedSendingRequestBody">
        /// The handler waits (up to 120s) on this before checking that no extra bytes were sent.
        /// </param>
        /// <param name="memoryPoolFactory">Optional factory forwarded to TransportSelector.GetWebHostBuilder.</param>
        /// <returns>The started host. It is not disposed here.</returns>
        private async Task<IWebHost> StartWebHost(long? maxRequestBufferSize,
                                                  byte[] expectedBody,
                                                  bool useConnectionAdapter,
                                                  TaskCompletionSource startReadingRequestBody,
                                                  TaskCompletionSource clientFinishedSendingRequestBody,
                                                  Func<MemoryPool<byte>> memoryPoolFactory = null)
        {
            var host = TransportSelector.GetWebHostBuilder(memoryPoolFactory, maxRequestBufferSize)
                .ConfigureServices(AddTestLogging)
                .UseKestrel(options =>
                {
                    options.Listen(new IPEndPoint(IPAddress.Loopback, 0), listenOptions =>
                    {
                        if (useConnectionAdapter)
                        {
                            listenOptions.UsePassThrough();
                        }
                    });

                    options.Limits.MaxRequestBufferSize = maxRequestBufferSize;

                    // Cap the request-line and total-header limits at the buffer limit when it is smaller.
                    if (maxRequestBufferSize is long bufferLimit)
                    {
                        if (bufferLimit < options.Limits.MaxRequestLineSize)
                        {
                            options.Limits.MaxRequestLineSize = (int)bufferLimit;
                        }

                        if (bufferLimit < options.Limits.MaxRequestHeadersTotalSize)
                        {
                            options.Limits.MaxRequestHeadersTotalSize = (int)bufferLimit;
                        }
                    }

                    options.Limits.MinRequestBodyDataRate = null;

                    options.Limits.MaxRequestBodySize = _dataLength;
                })
                .UseContentRoot(Directory.GetCurrentDirectory())
                .Configure(app => app.Run(async context =>
                {
                    await startReadingRequestBody.Task.TimeoutAfter(TimeSpan.FromSeconds(120));

                    var buffer = new byte[expectedBody.Length];
                    var bytesRead = 0;
                    while (bytesRead < buffer.Length)
                    {
                        var read = await context.Request.Body.ReadAsync(buffer, bytesRead, buffer.Length - bytesRead);
                        if (read == 0)
                        {
                            // The body ended before expectedBody.Length bytes arrived. ReadAsync keeps returning
                            // 0 at end of stream, so without this check the loop would never terminate.
                            context.Response.StatusCode = StatusCodes.Status500InternalServerError;
                            await context.Response.WriteAsync("Client sent fewer bytes than expectedBody.Length");
                            return;
                        }

                        bytesRead += read;
                    }

                    await clientFinishedSendingRequestBody.Task.TimeoutAfter(TimeSpan.FromSeconds(120));

                    // Verify client didn't send extra bytes
                    if (await context.Request.Body.ReadAsync(new byte[1], 0, 1) != 0)
                    {
                        context.Response.StatusCode = StatusCodes.Status500InternalServerError;
                        await context.Response.WriteAsync("Client sent more bytes than expectedBody.Length");
                        return;
                    }

                    await context.Response.WriteAsync($"bytesRead: {bytesRead.ToString()}");
                }))
                .Build();

            await host.StartAsync();

            return host;
        }
예제 #22
0
        /// <summary>
        /// Verifies that starting Kestrel on http://localhost:{port} throws an IOException. The exception
        /// message must name the loopback address of <paramref name="addressFamily"/> when that address and
        /// port are already bound by another socket.
        /// </summary>
        /// <remarks>
        /// Each attempt picks a port by binding IPv6Any to port 0. It then occupies the same port on this
        /// family's loopback address. An attempt is retried (up to 10 times per failure kind) in two cases:
        /// that bind fails, or Kestrel reports the conflict on the other family's loopback address instead.
        /// Either can happen if something else binds the port in between.
        /// </remarks>
        /// <param name="addressFamily">InterNetwork occupies 127.0.0.1; any other value occupies [::1].</param>
        private void ThrowsWhenBindingLocalhostToAddressInUse(AddressFamily addressFamily)
        {
            TestApplicationErrorLogger.IgnoredExceptions.Add(typeof(IOException));

            var addressInUseCount = 0;
            var wrongMessageCount = 0;

            var address = addressFamily == AddressFamily.InterNetwork ? IPAddress.Loopback : IPAddress.IPv6Loopback;
            var otherAddressFamily = addressFamily == AddressFamily.InterNetwork ? AddressFamily.InterNetworkV6 : AddressFamily.InterNetwork;

            while (addressInUseCount < 10 && wrongMessageCount < 10)
            {
                int port;

                using (var socket = new Socket(AddressFamily.InterNetworkV6, SocketType.Stream, ProtocolType.Tcp))
                {
                    // Bind first to IPv6Any to ensure both the IPv4 and IPv6 ports are available.
                    socket.Bind(new IPEndPoint(IPAddress.IPv6Any, 0));
                    socket.Listen(0);
                    port = ((IPEndPoint)socket.LocalEndPoint).Port;
                }

                using (var socket = new Socket(addressFamily, SocketType.Stream, ProtocolType.Tcp))
                {
                    try
                    {
                        // Occupy the port on this family's loopback address so Kestrel's bind conflicts.
                        socket.Bind(new IPEndPoint(address, port));
                        socket.Listen(0);
                    }
                    catch (SocketException)
                    {
                        // This family's loopback address was taken on that port in the meantime; pick another port.
                        addressInUseCount++;
                        continue;
                    }

                    var hostBuilder = TransportSelector.GetWebHostBuilder()
                        .ConfigureServices(AddTestLogging)
                        .UseKestrel()
                        .UseUrls($"http://localhost:{port}")
                        .Configure(ConfigureEchoAddress);

                    using (var host = hostBuilder.Build())
                    {
                        var exception = Assert.Throws<IOException>(() => host.Start());

                        var thisAddressString = $"http://{(addressFamily == AddressFamily.InterNetwork ? "127.0.0.1" : "[::1]")}:{port}";
                        var otherAddressString = $"http://{(addressFamily == AddressFamily.InterNetworkV6 ? "127.0.0.1" : "[::1]")}:{port}";

                        if (exception.Message == CoreStrings.FormatEndpointAlreadyInUse(otherAddressString))
                        {
                            // Don't fail immediately, because it's possible that something else really did bind to the
                            // same port for the other address family between the IPv6Any bind above and now.
                            wrongMessageCount++;
                            continue;
                        }

                        Assert.Equal(CoreStrings.FormatEndpointAlreadyInUse(thisAddressString), exception.Message);
                        break;
                    }
                }
            }

            if (addressInUseCount >= 10)
            {
                // The bind counted by addressInUseCount is on `address`, i.e. this family, not the other one.
                Assert.True(false, $"The corresponding {addressFamily} address was already in use 10 times.");
            }

            if (wrongMessageCount >= 10)
            {
                Assert.True(false, $"An error for a conflict with {otherAddressFamily} was thrown 10 times.");
            }
        }
예제 #23
0
        /// <summary>
        /// End-to-end echo test over a Unix domain socket. Kestrel listens on a temporary path with an
        /// echo connection handler, and a client sends "Hello World" and asserts the same bytes come back.
        /// </summary>
        public async Task TestUnixDomainSocket()
        {
            // Path.GetTempFileName creates an empty file. Remove it so nothing exists at the path
            // when Kestrel binds the Unix socket there.
            var path = Path.GetTempFileName();

            Delete(path);

            try
            {
                // Connection handler: writes back every chunk it reads until the input completes or the
                // read is cancelled.
                async Task EchoServer(ConnectionContext connection)
                {
                    // For graceful shutdown
                    var notificationFeature = connection.Features.Get<IConnectionLifetimeNotificationFeature>();

                    try
                    {
                        while (true)
                        {
                            var result = await connection.Transport.Input.ReadAsync(notificationFeature.ConnectionClosedRequested);

                            if (result.IsCompleted)
                            {
                                break;
                            }

                            await connection.Transport.Output.WriteAsync(result.Buffer.ToArray());

                            connection.Transport.Input.AdvanceTo(result.Buffer.End);
                        }
                    }
                    catch (OperationCanceledException)
                    {
                        // Intentionally ignored: cancellation of the read via ConnectionClosedRequested is
                        // the graceful-shutdown path, and the handler simply returns.
                    }
                }

                var hostBuilder = TransportSelector.GetWebHostBuilder()
                    .UseKestrel(o =>
                    {
                        o.ListenUnixSocket(path, builder =>
                        {
                            builder.Run(EchoServer);
                        });
                    })
                    .ConfigureServices(AddTestLogging)
                    .Configure(c => { });

                using (var host = hostBuilder.Build())
                {
                    await host.StartAsync();

                    using (var socket = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified))
                    {
                        await socket.ConnectAsync(new UnixDomainSocketEndPoint(path));

                        var data = Encoding.ASCII.GetBytes("Hello World");
                        await socket.SendAsync(data, SocketFlags.None);

                        var buffer = new byte[data.Length];
                        var read = 0;
                        while (read < data.Length)
                        {
                            var received = await socket.ReceiveAsync(buffer.AsMemory(read, buffer.Length - read), SocketFlags.None);
                            if (received == 0)
                            {
                                // The peer closed before echoing everything. Stop instead of spinning
                                // forever; the assertion below reports the mismatch.
                                break;
                            }

                            read += received;
                        }

                        Assert.Equal(data, buffer);
                    }

                    await host.StopAsync();
                }
            }
            finally
            {
                Delete(path);
            }
        }