public async Task SendShortData_Success()
{
    // Payload the client sends as one complete binary message; the server checks it arrives intact.
    byte[] originalData = Encoding.UTF8.GetBytes("Hello World");

    await using var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
    {
        Assert.True(context.WebSockets.IsWebSocketRequest);
        var socket = await context.WebSockets.AcceptWebSocketAsync();

        // Buffer sized exactly to the payload; a single receive is expected to return the whole message.
        var buffer = new byte[originalData.Length];
        var received = await socket.ReceiveAsync(new ArraySegment<byte>(buffer), CancellationToken.None);

        Assert.True(received.EndOfMessage);
        Assert.Equal(originalData.Length, received.Count);
        Assert.Equal(WebSocketMessageType.Binary, received.MessageType);
        Assert.Equal(originalData, buffer);
    });

    using var client = new ClientWebSocket();
    await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None);
    await client.SendAsync(new ArraySegment<byte>(originalData), WebSocketMessageType.Binary, true, CancellationToken.None);
}
public async Task NegotiateSubProtocol_Success()
{
    await using var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
    {
        Assert.True(context.WebSockets.IsWebSocketRequest);
        // Every sub-protocol offered by the client arrives in one comma-separated request header.
        Assert.Equal("alpha, bravo, charlie", context.Request.Headers["Sec-WebSocket-Protocol"]);
        // Deliberately accept with a casing that differs from what the client offered.
        await context.WebSockets.AcceptWebSocketAsync("Bravo");
    });

    using var client = new ClientWebSocket();
    foreach (var offered in new[] { "alpha", "bravo", "charlie" })
    {
        client.Options.AddSubProtocol(offered);
    }

    await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None);

    // The Windows ClientWebSocket reports the casing from the response header ("Bravo"). The managed
    // implementation appears to match that header case-insensitively against the list built by the
    // AddSubProtocol calls and then reports the entry from that list. Both are acceptable, so the
    // comparison ignores case. The offered names stay lower-case on purpose so this test codifies
    // that behavior rather than hiding it.
    Assert.Equal("Bravo", client.SubProtocol, ignoreCase: true);
}
public async Task MultipleValueHeadersNotOverridden()
{
    await using var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
    {
        Assert.True(context.WebSockets.IsWebSocketRequest);
        await context.WebSockets.AcceptWebSocketAsync();

        // Accepting the upgrade must not collapse the extra tokens the client put in these headers.
        Assert.Equal("Upgrade, keep-alive", context.Request.Headers.Connection.ToString());
        Assert.Equal("websocket, example", context.Request.Headers.Upgrade.ToString());
    });

    using var client = new HttpClient();
    var target = new UriBuilder(new Uri($"ws://127.0.0.1:{port}/")) { Scheme = "http" };

    // Hand-craft a valid WebSocket upgrade request whose Connection and Upgrade headers carry two values each.
    using var request = new HttpRequestMessage(HttpMethod.Get, target.ToString());
    var headers = request.Headers;
    headers.Connection.Clear();
    headers.Connection.Add("Upgrade");
    headers.Connection.Add("keep-alive");
    headers.Upgrade.Add(new System.Net.Http.Headers.ProductHeaderValue("websocket"));
    headers.Upgrade.Add(new System.Net.Http.Headers.ProductHeaderValue("example"));
    headers.Add(HeaderNames.SecWebSocketVersion, "13");
    // Sec-WebSocket-Key must be a base64 encoding of exactly 16 bytes.
    headers.Add(HeaderNames.SecWebSocketKey, Convert.ToBase64String(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 }, Base64FormattingOptions.None));

    var response = await client.SendAsync(request);
    Assert.Equal(HttpStatusCode.SwitchingProtocols, response.StatusCode);
}
public async Task OriginIsNotValidatedForNonWebSocketRequests()
{
    // Only http://example.com is configured as an allowed origin for this server.
    await using var server = KestrelWebSocketHelpers.CreateServer(
        LoggerFactory,
        out var port,
        context =>
        {
            Assert.False(context.WebSockets.IsWebSocketRequest);
            return Task.CompletedTask;
        },
        o => o.AllowedOrigins.Add("http://example.com"));

    using var client = new HttpClient();
    var target = new UriBuilder(new Uri($"ws://127.0.0.1:{port}/")) { Scheme = "http" };

    // A plain (non-upgrade) GET carrying a disallowed Origin is still expected to succeed.
    using var request = new HttpRequestMessage(HttpMethod.Get, target.ToString());
    request.Headers.Add("Origin", "http://notexample.com");

    var response = await client.SendAsync(request);
    Assert.Equal(HttpStatusCode.OK, response.StatusCode);
}
public async Task CloseFromCloseSent_Success()
{
    const string closeDescription = "Test Closed";

    await using var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
    {
        Assert.True(context.WebSockets.IsWebSocketRequest);
        var socket = await context.WebSockets.AcceptWebSocketAsync();

        // The first thing the server should observe is the client's close frame.
        var buffer = new byte[1024];
        var closeFrame = await socket.ReceiveAsync(new ArraySegment<byte>(buffer), CancellationToken.None);
        Assert.True(closeFrame.EndOfMessage);
        Assert.Equal(0, closeFrame.Count);
        Assert.Equal(WebSocketMessageType.Close, closeFrame.MessageType);
        Assert.Equal(WebSocketCloseStatus.NormalClosure, closeFrame.CloseStatus);
        Assert.Equal(closeDescription, closeFrame.CloseStatusDescription);

        // Reply with the same status and description to complete the close handshake.
        await socket.CloseAsync(closeFrame.CloseStatus.Value, closeFrame.CloseStatusDescription, CancellationToken.None);
    });

    using var client = new ClientWebSocket();
    await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None);

    // Send only our close frame; the client is expected to sit in CloseSent afterwards.
    await client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, closeDescription, CancellationToken.None);
    Assert.Equal(WebSocketState.CloseSent, client.State);

    // Finishing the close from CloseSent should leave the client Closed.
    await client.CloseAsync(WebSocketCloseStatus.NormalClosure, closeDescription, CancellationToken.None);
    Assert.Equal(WebSocketState.Closed, client.State);
}
public async Task SendLongData_Success()
{
    // Completed by the server once it has verified the payload.
    var serverDone = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
    var originalData = Encoding.UTF8.GetBytes(new string('a', 0x1FFFF));

    await using var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
    {
        Assert.True(context.WebSockets.IsWebSocketRequest);
        var socket = await context.WebSockets.AcceptWebSocketAsync();

        var buffer = new byte[originalData.Length];
        var received = await socket.ReceiveAsync(new ArraySegment<byte>(buffer), CancellationToken.None);
        Assert.True(received.EndOfMessage);
        Assert.Equal(WebSocketMessageType.Binary, received.MessageType);
        Assert.Equal(originalData, buffer);

        serverDone.SetResult();
    });

    // The client gets its own scope so it is disposed before we wait on the server.
    using (var client = new ClientWebSocket())
    {
        await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None);
        await client.SendAsync(new ArraySegment<byte>(originalData), WebSocketMessageType.Binary, true, CancellationToken.None);
    }

    // Wait before the server is disposed; otherwise the app could throw if it takes longer than the shutdown timeout.
    await serverDone.Task;
}
public async Task ReceiveLongData()
{
    var originalData = Encoding.UTF8.GetBytes(new string('a', 0x1FFFF));

    await using var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
    {
        Assert.True(context.WebSockets.IsWebSocketRequest);
        var socket = await context.WebSockets.AcceptWebSocketAsync();
        // The whole payload goes out as a single binary message.
        await socket.SendAsync(new ArraySegment<byte>(originalData), WebSocketMessageType.Binary, true, CancellationToken.None);
    });

    using var client = new ClientWebSocket();
    await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None);

    // Keep receiving into the unused tail of the buffer until the end of the message is reported.
    var buffer = new byte[originalData.Length];
    var filled = 0;
    WebSocketReceiveResult chunk;
    do
    {
        var free = buffer.Length - filled;
        chunk = await client.ReceiveAsync(new ArraySegment<byte>(buffer, filled, free), CancellationToken.None);
        filled += chunk.Count;
        Assert.Equal(WebSocketMessageType.Binary, chunk.MessageType);
    }
    while (!chunk.EndOfMessage);

    Assert.Equal(originalData.Length, filled);
    Assert.Equal(WebSocketMessageType.Binary, chunk.MessageType);
    Assert.Equal(originalData, buffer);
}
public async Task CompressionNegotiationCanChooseExtension(string clientHeader, string expectedResponse)
{
    await using var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
    {
        Assert.True(context.WebSockets.IsWebSocketRequest);
        var acceptContext = new WebSocketAcceptContext()
        {
            DangerousEnableCompression = true,
            ServerMaxWindowBits = 13,
        };
        await context.WebSockets.AcceptWebSocketAsync(acceptContext);
    });

    using var client = new HttpClient();
    var target = new UriBuilder(new Uri($"ws://127.0.0.1:{port}/")) { Scheme = "http" };

    // Hand-craft a valid WebSocket upgrade request that offers the extension header under test.
    using var request = new HttpRequestMessage(HttpMethod.Get, target.ToString());
    SetGenericWebSocketRequest(request);
    request.Headers.Add(HeaderNames.SecWebSocketExtensions, clientHeader);

    var response = await client.SendAsync(request);
    Assert.Equal(HttpStatusCode.SwitchingProtocols, response.StatusCode);

    // Any multiple response values are joined with "; " before comparing.
    var negotiated = response.Headers.GetValues(HeaderNames.SecWebSocketExtensions).Aggregate((l, r) => $"{l}; {r}");
    Assert.Equal(expectedResponse, negotiated);
}
public async Task CompressionNegotiationIgnoredIfNotEnabledOnServer()
{
    // The server accepts with default settings, i.e. without enabling compression.
    await using var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
    {
        Assert.True(context.WebSockets.IsWebSocketRequest);
        await context.WebSockets.AcceptWebSocketAsync();
    });

    using var client = new HttpClient();
    var target = new UriBuilder(new Uri($"ws://127.0.0.1:{port}/")) { Scheme = "http" };

    // Hand-craft a valid WebSocket upgrade request that asks for permessage-deflate.
    using var request = new HttpRequestMessage(HttpMethod.Get, target.ToString());
    SetGenericWebSocketRequest(request);
    request.Headers.Add(HeaderNames.SecWebSocketExtensions, "permessage-deflate");

    var response = await client.SendAsync(request);
    Assert.Equal(HttpStatusCode.SwitchingProtocols, response.StatusCode);

    // The upgrade succeeds, but no extension is echoed back.
    var echoesExtensions = response.Headers.Contains(HeaderNames.SecWebSocketExtensions);
    Assert.False(echoesExtensions);
}
public async Task CanSendAndReceiveCompressedData()
{
    // Echo server: every received frame is sent straight back until the client closes.
    await using var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
    {
        Assert.True(context.WebSockets.IsWebSocketRequest);
        using var socket = await context.WebSockets.AcceptWebSocketAsync(new WebSocketAcceptContext()
        {
            DangerousEnableCompression = true,
            ServerMaxWindowBits = 13
        });

        var buffer = new byte[1024];
        var frame = await socket.ReceiveAsync(buffer, CancellationToken.None);
        while (frame.MessageType != WebSocketMessageType.Close)
        {
            await socket.SendAsync(buffer.AsMemory(0, frame.Count), frame.MessageType, frame.EndOfMessage, default);
            frame = await socket.ReceiveAsync(buffer, CancellationToken.None);
        }

        await socket.CloseAsync(WebSocketCloseStatus.NormalClosure, null, default);
    });

    using var client = new ClientWebSocket();
    client.Options.DangerousDeflateOptions = new WebSocketDeflateOptions()
    {
        ServerMaxWindowBits = 12,
        ClientMaxWindowBits = 11,
    };
    await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None);

    // Random bytes larger than the server's 1024-byte buffer, so the echo comes back in pieces.
    const int sendCount = 8193;
    var sent = new byte[sendCount];
    var echoed = new byte[sendCount];
    Random.Shared.NextBytes(sent);

    await client.SendAsync(sent.AsMemory(0, sendCount), WebSocketMessageType.Binary, true, default);

    var totalRecv = 0;
    while (totalRecv < sendCount)
    {
        var chunk = await client.ReceiveAsync(echoed.AsMemory(totalRecv), default);
        totalRecv += chunk.Count;
        if (!chunk.EndOfMessage)
        {
            continue;
        }

        // End of the echoed message: it must be complete and byte-for-byte identical.
        Assert.Equal(sendCount, totalRecv);
        for (var i = 0; i < sendCount; ++i)
        {
            Assert.True(sent[i] == echoed[i], $"offset {i} not equal: {sent[i]} == {echoed[i]}");
        }
    }

    await client.CloseAsync(WebSocketCloseStatus.NormalClosure, null, default);
}
public async Task SendFragmentedData_Success()
{
    var originalData = Encoding.UTF8.GetBytes("Hello World");

    // Lock-step signal between the two sides. The server completes the current instance after it has
    // verified each fragment. The client installs a fresh instance before sending the next fragment.
    // The server lambda reads the captured variable each time, so it always completes the newest one.
    var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);

    await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
    {
        Assert.True(context.WebSockets.IsWebSocketRequest);
        var webSocket = await context.WebSockets.AcceptWebSocketAsync();
        var serverBuffer = new byte[originalData.Length];

        // Fragment 1: two bytes, message not finished yet.
        var result = await webSocket.ReceiveAsync(new ArraySegment<byte>(serverBuffer), CancellationToken.None);
        Assert.False(result.EndOfMessage);
        Assert.Equal(2, result.Count);
        int totalReceived = result.Count;
        Assert.Equal(WebSocketMessageType.Binary, result.MessageType);
        tcs.SetResult();

        // Fragment 2: two more bytes, appended after what has been received so far.
        result = await webSocket.ReceiveAsync(
            new ArraySegment<byte>(serverBuffer, totalReceived, serverBuffer.Length - totalReceived), CancellationToken.None);
        Assert.False(result.EndOfMessage);
        Assert.Equal(2, result.Count);
        totalReceived += result.Count;
        Assert.Equal(WebSocketMessageType.Binary, result.MessageType);
        tcs.SetResult();

        // Fragment 3: the final seven bytes, which complete the message.
        result = await webSocket.ReceiveAsync(
            new ArraySegment<byte>(serverBuffer, totalReceived, serverBuffer.Length - totalReceived), CancellationToken.None);
        Assert.True(result.EndOfMessage);
        Assert.Equal(7, result.Count);
        totalReceived += result.Count;
        Assert.Equal(WebSocketMessageType.Binary, result.MessageType);
        Assert.Equal(originalData.Length, totalReceived);
        Assert.Equal(originalData, serverBuffer);

        // Tell the client the reassembled message has been verified, so it does not tear down early.
        tcs.SetResult();
    }))
    {
        using (var client = new ClientWebSocket())
        {
            await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None);

            await client.SendAsync(new ArraySegment<byte>(originalData, 0, 2), WebSocketMessageType.Binary, false, CancellationToken.None);
            await tcs.Task;

            tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
            await client.SendAsync(new ArraySegment<byte>(originalData, 2, 2), WebSocketMessageType.Binary, false, CancellationToken.None);
            await tcs.Task;

            tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
            await client.SendAsync(new ArraySegment<byte>(originalData, 4, 7), WebSocketMessageType.Binary, true, CancellationToken.None);
            // Wait for the server's final checks before the client and server are disposed.
            await tcs.Task;
        }
    }
}
public async Task Connect_Success()
{
    // The server only accepts the upgrade; the test passes if the client handshake completes.
    await using var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
    {
        Assert.True(context.WebSockets.IsWebSocketRequest);
        await context.WebSockets.AcceptWebSocketAsync();
    });

    using var client = new ClientWebSocket();
    var endpoint = new Uri($"ws://127.0.0.1:{port}/");
    await client.ConnectAsync(endpoint, CancellationToken.None);
}
public async Task CommonHeadersAreSetToInternedStrings()
{
    await using var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
    {
        Assert.True(context.WebSockets.IsWebSocketRequest);
        await context.WebSockets.AcceptWebSocketAsync();

        // Assert.Same compares references, so each header value must be the very same
        // string instance as the corresponding constant, not merely an equal string.
        var headers = context.Request.Headers;
        Assert.Same(HeaderNames.Upgrade, headers.Connection.ToString());
        Assert.Same(Constants.Headers.UpgradeWebSocket, headers.Upgrade.ToString());
        Assert.Same(Constants.Headers.SupportedVersion, headers.SecWebSocketVersion.ToString());
    });

    using var client = new ClientWebSocket();
    await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None);
}
public async Task OriginIsValidatedForWebSocketRequests(HttpStatusCode expectedCode, params string[] origins)
{
    await using var server = KestrelWebSocketHelpers.CreateServer(
        LoggerFactory,
        out var port,
        context =>
        {
            Assert.True(context.WebSockets.IsWebSocketRequest);
            return Task.CompletedTask;
        },
        o =>
        {
            // Leave AllowedOrigins untouched when no origin list was supplied.
            if (origins is null)
            {
                return;
            }

            foreach (var allowed in origins)
            {
                o.AllowedOrigins.Add(allowed);
            }
        });

    using var client = new HttpClient();
    var target = new UriBuilder(new Uri($"ws://127.0.0.1:{port}/")) { Scheme = "http" };

    // Hand-craft a valid WebSocket upgrade request that claims to come from http://example.com.
    using var request = new HttpRequestMessage(HttpMethod.Get, target.ToString());
    var headers = request.Headers;
    headers.Connection.Clear();
    headers.Connection.Add("Upgrade");
    headers.Upgrade.Add(new System.Net.Http.Headers.ProductHeaderValue("websocket"));
    headers.Add(HeaderNames.SecWebSocketVersion, "13");
    // Sec-WebSocket-Key must be a base64 encoding of exactly 16 bytes.
    headers.Add(HeaderNames.SecWebSocketKey, Convert.ToBase64String(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 }, Base64FormattingOptions.None));
    headers.Add(HeaderNames.Origin, "http://example.com");

    var response = await client.SendAsync(request);
    Assert.Equal(expectedCode, response.StatusCode);
}