/// <summary>
/// Reads the next complete payload by delegating to <c>ProtocolUtility.ReadPayloadAsync</c> with this object's
/// buffered reader, byte handler and sequence-number callback.
/// </summary>
/// <param name="cache">Holder that <c>ProtocolUtility.ReadPayloadAsync</c> clears and then uses to accumulate multi-packet payloads.</param>
/// <param name="protocolErrorBehavior">How protocol errors are reported; passed through unchanged.</param>
/// <param name="ioBehavior">I/O mode; passed through unchanged.</param>
/// <returns>The payload bytes.</returns>
public ValueTask<ArraySegment<byte>> ReadPayloadAsync(ArraySegmentHolder<byte> cache, ProtocolErrorBehavior protocolErrorBehavior, IOBehavior ioBehavior)
{
    // The '!' operators only silence nullable warnings; presumably both fields are assigned before any read — TODO confirm.
    var reader = m_bufferedByteReader!;
    var handler = m_byteHandler!;
    return ProtocolUtility.ReadPayloadAsync(reader, handler, m_getNextSequenceNumber, cache, protocolErrorBehavior, ioBehavior);
}
/// <summary>
/// Initializes a <c>CompressedByteHandler</c> that remembers the compressed payload handler to read from and the
/// protocol error behavior to apply.
/// </summary>
/// <param name="compressedPayloadHandler">The handler stored in <c>m_compressedPayloadHandler</c>.</param>
/// <param name="protocolErrorBehavior">The behavior stored in <c>m_protocolErrorBehavior</c>.</param>
public CompressedByteHandler(CompressedPayloadHandler compressedPayloadHandler, ProtocolErrorBehavior protocolErrorBehavior)
{
    this.m_compressedPayloadHandler = compressedPayloadHandler;
    this.m_protocolErrorBehavior = protocolErrorBehavior;
}
/// <summary>
/// Reads packets until <c>HasReadPayload</c> reports a complete payload. Stays on the synchronous path while packet
/// reads complete synchronously, and switches to a continuation the first time a read is still pending.
/// </summary>
/// <param name="bufferedByteReader">Source of raw packet bytes.</param>
/// <param name="byteHandler">Handler passed to the reader.</param>
/// <param name="getNextSequenceNumber">Supplies the expected sequence number for each packet.</param>
/// <param name="previousPayloads">Accumulator for payloads split across several packets.</param>
/// <param name="protocolErrorBehavior">How protocol errors are reported.</param>
/// <param name="ioBehavior">I/O mode passed to the reader.</param>
/// <returns>The assembled payload.</returns>
private static ValueTask<ArraySegment<byte>> DoReadPayloadAsync(BufferedByteReader bufferedByteReader, IByteHandler byteHandler, Func<int> getNextSequenceNumber, ArraySegmentHolder<byte> previousPayloads, ProtocolErrorBehavior protocolErrorBehavior, IOBehavior ioBehavior)
{
    while (true)
    {
        var packetTask = ReadPacketAsync(bufferedByteReader, byteHandler, getNextSequenceNumber, protocolErrorBehavior, ioBehavior);
        if (!packetTask.IsCompleted)
            return ContinueAfterPacket(packetTask, bufferedByteReader, byteHandler, getNextSequenceNumber, previousPayloads, protocolErrorBehavior, ioBehavior);

        if (HasReadPayload(previousPayloads, packetTask.Result, protocolErrorBehavior, out var completedPayload))
            return completedPayload;
    }

    // Everything arrives as a parameter (no captures), so the lambda object is only created when a packet read
    // actually goes asynchronous.
    ValueTask<ArraySegment<byte>> ContinueAfterPacket(ValueTask<Packet> pendingPacket, BufferedByteReader reader, IByteHandler handler, Func<int> nextSequenceNumber, ArraySegmentHolder<byte> payloads, ProtocolErrorBehavior errorBehavior, IOBehavior io) =>
        pendingPacket.ContinueWith(packet => HasReadPayload(payloads, packet, errorBehavior, out var payload)
            ? payload
            : DoReadPayloadAsync(reader, handler, nextSequenceNumber, payloads, errorBehavior, io));
}
/// <summary>
/// Appends <paramref name="packet"/> to the payload accumulated in <paramref name="previousPayloads"/> and reports
/// whether the payload is now complete.
/// </summary>
/// <remarks>
/// A packet whose contents are shorter than <c>ProtocolUtility.MaxPacketSize</c> ends the payload. A packet of exactly
/// that size means more packets follow, so its bytes are copied into the accumulator and <c>false</c> is returned.
/// </remarks>
/// <param name="previousPayloads">Accumulator for a payload split across packets; updated in place.</param>
/// <param name="packet">The packet just read.</param>
/// <param name="protocolErrorBehavior">Not used by this method; kept for signature compatibility.</param>
/// <param name="result">When the method returns <c>true</c>, a completed task holding the full payload; otherwise <c>default</c>.</param>
/// <returns><c>true</c> if the payload is complete.</returns>
private static bool HasReadPayload(ArraySegmentHolder<byte> previousPayloads, Packet packet, ProtocolErrorBehavior protocolErrorBehavior, out ValueTask<ArraySegment<byte>> result)
{
    var contents = packet.Contents;

    // Fast path: a lone packet that is not followed by continuation packets is the whole payload, so no copy is needed.
    if (previousPayloads.Count == 0 && contents.Count < ProtocolUtility.MaxPacketSize)
    {
        result = new ValueTask<ArraySegment<byte>>(contents);
        return true;
    }

    var previousPayloadsArray = previousPayloads.Array;
    var requiredLength = previousPayloads.Offset + previousPayloads.Count + contents.Count;
    if (previousPayloadsArray is null)
    {
        previousPayloadsArray = new byte[Math.Max(ProtocolUtility.MaxPacketSize + 1, requiredLength)];
    }
    else if (requiredLength > previousPayloadsArray.Length)
    {
        // Double the buffer, but never to less than what is required. A single doubling is not enough when the existing
        // buffer is small, and Length * 2 can overflow to a negative number for very large buffers.
        Array.Resize(ref previousPayloadsArray, Math.Max(previousPayloadsArray.Length * 2, requiredLength));
    }

    Buffer.BlockCopy(contents.Array, contents.Offset, previousPayloadsArray, previousPayloads.Offset + previousPayloads.Count, contents.Count);
    previousPayloads.ArraySegment = new ArraySegment<byte>(previousPayloadsArray, previousPayloads.Offset, previousPayloads.Count + contents.Count);

    // A short packet terminates a multi-packet payload.
    if (contents.Count < ProtocolUtility.MaxPacketSize)
    {
        result = new ValueTask<ArraySegment<byte>>(previousPayloads.ArraySegment);
        return true;
    }

    result = default(ValueTask<ArraySegment<byte>>);
    return false;
}
/// <summary>
/// Wraps <paramref name="payloadBytes"/> in a completed <c>Packet</c> when at least <paramref name="payloadLength"/>
/// bytes were read; otherwise reports the short read according to <paramref name="protocolErrorBehavior"/>.
/// </summary>
/// <param name="payloadBytes">The bytes actually read.</param>
/// <param name="payloadLength">The payload length announced by the packet header.</param>
/// <param name="protocolErrorBehavior">When <c>Throw</c>, a short read yields a faulted task; otherwise a default task.</param>
/// <returns>The packet, a faulted task carrying an <see cref="System.IO.EndOfStreamException"/>, or <c>default</c>.</returns>
private static ValueTask<Packet> CreatePacketFromPayload(ArraySegment<byte> payloadBytes, int payloadLength, ProtocolErrorBehavior protocolErrorBehavior)
{
    if (payloadBytes.Count < payloadLength)
    {
        if (protocolErrorBehavior != ProtocolErrorBehavior.Throw)
            return default(ValueTask<Packet>);

        var message = "Expected to read {0} payload bytes but only received {1}.".FormatInvariant(payloadLength, payloadBytes.Count);
        return ValueTaskExtensions.FromException<Packet>(new EndOfStreamException(message));
    }

    return new ValueTask<Packet>(new Packet(payloadBytes));
}
/// <summary>
/// Reads one complete payload. <paramref name="cache"/> is reset via <c>Clear()</c> before reading and is then used to
/// accumulate the payload if it spans several packets.
/// </summary>
/// <param name="bufferedByteReader">Source of raw packet bytes.</param>
/// <param name="byteHandler">Handler passed to the reader.</param>
/// <param name="getNextSequenceNumber">Supplies the expected sequence number for each packet.</param>
/// <param name="cache">Reusable accumulator for multi-packet payloads.</param>
/// <param name="protocolErrorBehavior">How protocol errors are reported.</param>
/// <param name="ioBehavior">I/O mode passed to the reader.</param>
/// <returns>The payload bytes.</returns>
public static ValueTask<ArraySegment<byte>> ReadPayloadAsync(BufferedByteReader bufferedByteReader, IByteHandler byteHandler, Func<int> getNextSequenceNumber, ArraySegmentHolder<byte> cache, ProtocolErrorBehavior protocolErrorBehavior, IOBehavior ioBehavior)
{
    cache.Clear();
    var payloadTask = DoReadPayloadAsync(bufferedByteReader, byteHandler, getNextSequenceNumber, cache, protocolErrorBehavior, ioBehavior);
    return payloadTask;
}
/// <summary>
/// Reads a 4-byte packet header and then the packet body via <c>ReadPacketAfterHeader</c>. A continuation is added
/// only if the header read does not complete synchronously.
/// </summary>
/// <param name="bufferedByteReader">Source of raw packet bytes.</param>
/// <param name="byteHandler">Handler passed to the reader.</param>
/// <param name="getNextSequenceNumber">Supplies the expected sequence number.</param>
/// <param name="protocolErrorBehavior">How protocol errors are reported.</param>
/// <param name="ioBehavior">I/O mode passed to the reader.</param>
/// <returns>The packet read.</returns>
private static ValueTask<Packet> ReadPacketAsync(BufferedByteReader bufferedByteReader, IByteHandler byteHandler, Func<int> getNextSequenceNumber, ProtocolErrorBehavior protocolErrorBehavior, IOBehavior ioBehavior)
{
    var headerTask = bufferedByteReader.ReadBytesAsync(byteHandler, 4, ioBehavior);
    return headerTask.IsCompleted
        ? ReadPacketAfterHeader(headerTask.Result, bufferedByteReader, byteHandler, getNextSequenceNumber, protocolErrorBehavior, ioBehavior)
        : ContinueAfterHeader(headerTask, bufferedByteReader, byteHandler, getNextSequenceNumber, protocolErrorBehavior, ioBehavior);

    // Non-capturing local function: the lambda below is only allocated when the header read is still pending.
    ValueTask<Packet> ContinueAfterHeader(ValueTask<ArraySegment<byte>> pendingHeader, BufferedByteReader reader, IByteHandler handler, Func<int> nextSequenceNumber, ProtocolErrorBehavior errorBehavior, IOBehavior io) =>
        pendingHeader.ContinueWith(header => ReadPacketAfterHeader(header, reader, handler, nextSequenceNumber, errorBehavior, io));
}
/// <summary>
/// Validates a packet header and reads the payload it announces.
/// </summary>
/// <remarks>
/// Header bytes 0-2 hold the payload length and byte 3 holds the packet sequence number. The sequence check is skipped
/// when <c>getNextSequenceNumber() % 256</c> evaluates to -1.
/// </remarks>
/// <param name="headerBytes">The header bytes read; fewer than 4 indicates the stream ended early.</param>
/// <param name="bufferedByteReader">Source of the payload bytes.</param>
/// <param name="byteHandler">Handler passed to the reader.</param>
/// <param name="getNextSequenceNumber">Called exactly once (after the length check) to obtain the expected sequence number.</param>
/// <param name="protocolErrorBehavior">How short reads and out-of-order packets are reported.</param>
/// <param name="ioBehavior">I/O mode passed to the reader.</param>
/// <returns>The packet, a faulted task, or <c>default</c> depending on <paramref name="protocolErrorBehavior"/>.</returns>
private static ValueTask<Packet> ReadPacketAfterHeader(ArraySegment<byte> headerBytes, BufferedByteReader bufferedByteReader, IByteHandler byteHandler, Func<int> getNextSequenceNumber, ProtocolErrorBehavior protocolErrorBehavior, IOBehavior ioBehavior)
{
    if (headerBytes.Count < 4)
    {
        if (protocolErrorBehavior != ProtocolErrorBehavior.Throw)
            return default(ValueTask<Packet>);
        return ValueTaskExtensions.FromException<Packet>(new EndOfStreamException("Expected to read 4 header bytes but only received {0}.".FormatInvariant(headerBytes.Count)));
    }

    var header = headerBytes.Array;
    var payloadLength = (int)SerializationUtility.ReadUInt32(header, headerBytes.Offset, 3);
    int packetSequenceNumber = header[headerBytes.Offset + 3];

    var expectedSequenceNumber = getNextSequenceNumber() % 256;
    if (expectedSequenceNumber != -1 && packetSequenceNumber != expectedSequenceNumber)
    {
        return protocolErrorBehavior == ProtocolErrorBehavior.Ignore
            ? default(ValueTask<Packet>)
            : ValueTaskExtensions.FromException<Packet>(MySqlProtocolException.CreateForPacketOutOfOrder(expectedSequenceNumber, packetSequenceNumber));
    }

    var payloadTask = bufferedByteReader.ReadBytesAsync(byteHandler, payloadLength, ioBehavior);
    if (!payloadTask.IsCompleted)
        return ContinueAfterPayload(payloadTask, payloadLength, protocolErrorBehavior);
    return CreatePacketFromPayload(payloadTask.Result, payloadLength, protocolErrorBehavior);

    // Non-capturing local function: the lambda below is only allocated when the payload read is still pending.
    ValueTask<Packet> ContinueAfterPayload(ValueTask<ArraySegment<byte>> pendingPayload, int expectedLength, ProtocolErrorBehavior errorBehavior) =>
        pendingPayload.ContinueWith(payload => CreatePacketFromPayload(payload, expectedLength, errorBehavior));
}
/// <summary>
/// Reads a 4-byte packet header and then the packet body via <c>ReadPacketAfterHeader</c>. The async local function
/// is only entered when the header read does not complete synchronously.
/// </summary>
/// <param name="bufferedByteReader">Source of raw packet bytes.</param>
/// <param name="byteHandler">Handler passed to the reader.</param>
/// <param name="getNextSequenceNumber">Supplies the expected sequence number.</param>
/// <param name="protocolErrorBehavior">How protocol errors are reported.</param>
/// <param name="ioBehavior">I/O mode passed to the reader.</param>
/// <returns>The packet read.</returns>
private static ValueTask<Packet> ReadPacketAsync(BufferedByteReader bufferedByteReader, IByteHandler byteHandler, Func<int> getNextSequenceNumber, ProtocolErrorBehavior protocolErrorBehavior, IOBehavior ioBehavior)
{
    var headerBytesTask = bufferedByteReader.ReadBytesAsync(byteHandler, 4, ioBehavior);
    if (headerBytesTask.IsCompleted)
        return ReadPacketAfterHeader(headerBytesTask.Result, bufferedByteReader, byteHandler, getNextSequenceNumber, protocolErrorBehavior, ioBehavior);
    return AddContinuation(headerBytesTask, bufferedByteReader, byteHandler, getNextSequenceNumber, protocolErrorBehavior, ioBehavior);

    // Static local function: it cannot capture, so the async state machine is only created on the slow path.
    static async ValueTask<Packet> AddContinuation(ValueTask<ArraySegment<byte>> headerBytes_, BufferedByteReader bufferedByteReader_, IByteHandler byteHandler_, Func<int> getNextSequenceNumber_, ProtocolErrorBehavior protocolErrorBehavior_, IOBehavior ioBehavior_) =>
        await ReadPacketAfterHeader(await headerBytes_.ConfigureAwait(false), bufferedByteReader_, byteHandler_, getNextSequenceNumber_, protocolErrorBehavior_, ioBehavior_).ConfigureAwait(false);
}
/// <summary>
/// Returns up to <paramref name="count"/> bytes of uncompressed data. Buffered data in <c>m_remainingData</c> is
/// served first. When none remains, the next compressed packet is read, validated, inflated if needed, and buffered.
/// </summary>
/// <param name="count">Maximum number of bytes to return; fewer may be returned.</param>
/// <param name="protocolErrorBehavior">When <c>Ignore</c>, protocol errors produce a default result instead of a faulted task.</param>
/// <param name="ioBehavior">I/O mode passed to the underlying reader.</param>
/// <returns>A segment of at most <paramref name="count"/> bytes.</returns>
private ValueTask<ArraySegment<byte>> ReadBytesAsync(int count, ProtocolErrorBehavior protocolErrorBehavior, IOBehavior ioBehavior)
{
    // satisfy the read from cache if possible
    if (m_remainingData.Count > 0)
    {
        int bytesToRead = Math.Min(m_remainingData.Count, count);
        var result = new ArraySegment<byte>(m_remainingData.Array, m_remainingData.Offset, bytesToRead);
        m_remainingData = m_remainingData.Slice(bytesToRead);
        return new ValueTask<ArraySegment<byte>>(result);
    }

    // read the compressed header (seven bytes)
    return m_bufferedByteReader.ReadBytesAsync(m_byteHandler, 7, ioBehavior)
        .ContinueWith(headerReadBytes =>
        {
            if (headerReadBytes.Count < 7)
            {
                return protocolErrorBehavior == ProtocolErrorBehavior.Ignore ?
                    default(ValueTask<ArraySegment<byte>>) :
                    ValueTaskExtensions.FromException<ArraySegment<byte>>(new EndOfStreamException("Wanted to read 7 bytes but only read {0} when reading compressed packet header".FormatInvariant(headerReadBytes.Count)));
            }

            // header layout: bytes 0-2 compressed payload length, byte 3 sequence number, bytes 4-6 uncompressed length
            var payloadLength = (int)SerializationUtility.ReadUInt32(headerReadBytes.Array, headerReadBytes.Offset, 3);
            int packetSequenceNumber = headerReadBytes.Array[headerReadBytes.Offset + 3];
            var uncompressedLength = (int)SerializationUtility.ReadUInt32(headerReadBytes.Array, headerReadBytes.Offset + 4, 3);

            // verify the compressed packet sequence number
            var expectedSequenceNumber = GetNextCompressedSequenceNumber() % 256;
            if (packetSequenceNumber != expectedSequenceNumber)
            {
                if (protocolErrorBehavior == ProtocolErrorBehavior.Ignore)
                    return default(ValueTask<ArraySegment<byte>>);

                var exception = new InvalidOperationException("Packet received out-of-order. Expected {0}; got {1}.".FormatInvariant(expectedSequenceNumber, packetSequenceNumber));
                return ValueTaskExtensions.FromException<ArraySegment<byte>>(exception);
            }

            // MySQL protocol resets the uncompressed sequence number back to the sequence number of this compressed packet.
            // This isn't in the documentation, but the code explicitly notes that uncompressed packets are modified by compression:
            //  - https://github.com/mysql/mysql-server/blob/c28e258157f39f25e044bb72e8bae1ff00989a3d/sql/net_serv.cc#L276
            //  - https://github.com/mysql/mysql-server/blob/c28e258157f39f25e044bb72e8bae1ff00989a3d/sql/net_serv.cc#L225-L227
            if (!m_isContinuationPacket)
                m_uncompressedSequenceNumber = packetSequenceNumber;

            // except this doesn't happen when uncompressed packets need to be broken up across multiple compressed packets
            m_isContinuationPacket = payloadLength == ProtocolUtility.MaxPacketSize || uncompressedLength == ProtocolUtility.MaxPacketSize;

            return m_bufferedByteReader.ReadBytesAsync(m_byteHandler, payloadLength, ioBehavior)
                .ContinueWith(payloadReadBytes =>
                {
                    if (payloadReadBytes.Count < payloadLength)
                    {
                        return protocolErrorBehavior == ProtocolErrorBehavior.Ignore ?
                            default(ValueTask<ArraySegment<byte>>) :
                            ValueTaskExtensions.FromException<ArraySegment<byte>>(new EndOfStreamException("Wanted to read {0} bytes but only read {1} when reading compressed payload".FormatInvariant(payloadLength, payloadReadBytes.Count)));
                    }

                    if (uncompressedLength == 0)
                    {
                        // data is uncompressed
                        m_remainingData = payloadReadBytes;
                    }
                    else
                    {
                        // check CMF (Compression Method and Flags) and FLG (Flags) bytes for expected values
                        var cmf = payloadReadBytes.Array[payloadReadBytes.Offset];
                        var flg = payloadReadBytes.Array[payloadReadBytes.Offset + 1];
                        if (cmf != 0x78 || ((flg & 0x40) == 0x40) || ((cmf * 256 + flg) % 31 != 0))
                        {
                            // CMF = 0x78: 32K Window Size + deflate compression
                            // FLG & 0x40: has preset dictionary (not supported)
                            // CMF*256+FLG is a multiple of 31: header checksum
                            return protocolErrorBehavior == ProtocolErrorBehavior.Ignore ?
                                default(ValueTask<ArraySegment<byte>>) :
                                ValueTaskExtensions.FromException<ArraySegment<byte>>(new NotSupportedException("Unsupported zlib header: {0:X2}{1:X2}".FormatInvariant(cmf, flg)));
                        }

                        // zlib format (https://www.ietf.org/rfc/rfc1950.txt) is: [two header bytes] [deflate-compressed data] [four-byte checksum]
                        // .NET implements the middle part with DeflateStream; need to handle header and checksum explicitly
                        const int headerSize = 2;
                        const int checksumSize = 4;
                        var uncompressedData = new byte[uncompressedLength];
                        var totalBytesRead = 0;
                        using (var compressedStream = new MemoryStream(payloadReadBytes.Array, payloadReadBytes.Offset + headerSize, payloadReadBytes.Count - headerSize - checksumSize))
                        using (var decompressingStream = new DeflateStream(compressedStream, CompressionMode.Decompress))
                        {
                            // Stream.Read may return fewer bytes than requested (DeflateStream does so on .NET 6+), so keep
                            // reading until the announced length has been produced or the stream is exhausted.
                            int bytesRead;
                            while (totalBytesRead < uncompressedLength &&
                                (bytesRead = decompressingStream.Read(uncompressedData, totalBytesRead, uncompressedLength - totalBytesRead)) > 0)
                            {
                                totalBytesRead += bytesRead;
                            }
                        }

                        var checksum = ComputeAdler32Checksum(uncompressedData, 0, totalBytesRead);
                        int adlerStartOffset = payloadReadBytes.Offset + payloadReadBytes.Count - checksumSize;
                        if (payloadReadBytes.Array[adlerStartOffset + 0] != ((checksum >> 24) & 0xFF) ||
                            payloadReadBytes.Array[adlerStartOffset + 1] != ((checksum >> 16) & 0xFF) ||
                            payloadReadBytes.Array[adlerStartOffset + 2] != ((checksum >> 8) & 0xFF) ||
                            payloadReadBytes.Array[adlerStartOffset + 3] != (checksum & 0xFF))
                        {
                            return protocolErrorBehavior == ProtocolErrorBehavior.Ignore ?
                                default(ValueTask<ArraySegment<byte>>) :
                                ValueTaskExtensions.FromException<ArraySegment<byte>>(new NotSupportedException("Invalid Adler-32 checksum of uncompressed data."));
                        }

                        // only expose the inflated data once its checksum has been verified
                        m_remainingData = new ArraySegment<byte>(uncompressedData, 0, totalBytesRead);
                    }

                    // clamp like the cached path above: the packet may hold fewer bytes than the caller asked for
                    var bytesToReturn = Math.Min(m_remainingData.Count, count);
                    var result = m_remainingData.Slice(0, bytesToReturn);
                    m_remainingData = m_remainingData.Slice(bytesToReturn);
                    return new ValueTask<ArraySegment<byte>>(result);
                });
        });
}
/// <summary>
/// Reads the next complete payload through a <c>CompressedByteHandler</c> created for this call. Uncompressed sequence
/// numbers come from <c>GetNextUncompressedSequenceNumber</c>.
/// </summary>
/// <param name="protocolErrorBehavior">How protocol errors are reported; also given to the byte handler.</param>
/// <param name="ioBehavior">I/O mode passed through unchanged.</param>
/// <returns>The payload bytes.</returns>
public ValueTask<ArraySegment<byte>> ReadPayloadAsync(ProtocolErrorBehavior protocolErrorBehavior, IOBehavior ioBehavior)
{
    var reader = m_bufferedByteReader;
    var compressedHandler = new CompressedByteHandler(this, protocolErrorBehavior);
    return ProtocolUtility.ReadPayloadAsync(reader, compressedHandler, GetNextUncompressedSequenceNumber, default(ArraySegment<byte>), protocolErrorBehavior, ioBehavior);
}
/// <summary>
/// Reads the next complete payload using this object's buffered reader and byte handler. <c>GetNextSequenceNumber</c>
/// is invoked for each expected sequence number.
/// </summary>
/// <param name="protocolErrorBehavior">How protocol errors are reported.</param>
/// <param name="ioBehavior">I/O mode passed through unchanged.</param>
/// <returns>The payload bytes.</returns>
public ValueTask<ArraySegment<byte>> ReadPayloadAsync(ProtocolErrorBehavior protocolErrorBehavior, IOBehavior ioBehavior)
{
    var noCache = default(ArraySegment<byte>);
    return ProtocolUtility.ReadPayloadAsync(m_bufferedByteReader, m_byteHandler, () => GetNextSequenceNumber(), noCache, protocolErrorBehavior, ioBehavior);
}
/// <summary>
/// Wraps <paramref name="payloadBytes"/> in a completed <c>Packet</c> when at least <paramref name="payloadLength"/>
/// bytes were read; otherwise reports the short read according to <paramref name="protocolErrorBehavior"/>.
/// </summary>
/// <param name="payloadBytes">The bytes actually read.</param>
/// <param name="payloadLength">The payload length announced by the packet header.</param>
/// <param name="packetSequenceNumber">Sequence number stored on the created packet.</param>
/// <param name="protocolErrorBehavior">When <c>Throw</c>, a short read yields a faulted task; otherwise a default task.</param>
/// <returns>The packet, a faulted task carrying an <see cref="System.IO.EndOfStreamException"/>, or <c>default</c>.</returns>
private static ValueTask<Packet> CreatePacketFromPayload(ArraySegment<byte> payloadBytes, int payloadLength, int packetSequenceNumber, ProtocolErrorBehavior protocolErrorBehavior)
{
    if (payloadBytes.Count >= payloadLength)
        return new ValueTask<Packet>(new Packet(packetSequenceNumber, payloadBytes));

    if (protocolErrorBehavior != ProtocolErrorBehavior.Throw)
        return default(ValueTask<Packet>);

    // Report how many bytes were expected vs received rather than the generic EndOfStreamException message.
    var message = string.Format(System.Globalization.CultureInfo.InvariantCulture, "Expected to read {0} payload bytes but only received {1}.", payloadLength, payloadBytes.Count);
    return ValueTaskExtensions.FromException<Packet>(new EndOfStreamException(message));
}
/// <summary>
/// Receives one packet (4-byte header followed by the payload) from <c>m_socket</c>.
/// </summary>
/// <remarks>
/// <c>m_buffer</c> is the receive buffer; <c>m_offset</c>/<c>m_end</c> delimit the unconsumed bytes in it. A dedicated
/// array is allocated only when the payload is larger than <c>m_buffer</c>. Header bytes 0-2 hold the payload length
/// and byte 3 the sequence id, which must equal <c>m_sequenceId &amp; 0xFF</c>.
/// NOTE(review): <paramref name="cancellationToken"/> is not observed, and an end-of-stream while reading the payload
/// throws even when <paramref name="protocolErrorBehavior"/> is Ignore.
/// </remarks>
/// <param name="protocolErrorBehavior">When <c>Ignore</c>, header-phase errors return <c>null</c> instead of throwing.</param>
/// <param name="ioBehavior">Selects <c>ReceiveAsync</c> (asynchronous) or blocking <c>Receive</c>.</param>
/// <param name="cancellationToken">Currently unused.</param>
/// <returns>The payload, or <c>null</c> on an ignored header-phase protocol error.</returns>
/// <exception cref="System.IO.EndOfStreamException">The socket returned no data before the packet was complete.</exception>
/// <exception cref="InvalidOperationException">The packet's sequence id was not the expected one.</exception>
private async Task<PayloadData> ReceivePacketAsync(ProtocolErrorBehavior protocolErrorBehavior, IOBehavior ioBehavior, CancellationToken cancellationToken)
{
    // With fewer than 4 buffered bytes, move them to the front of the buffer to leave room for the header read.
    if (m_end - m_offset < 4)
    {
        if (m_end - m_offset > 0)
            Buffer.BlockCopy(m_buffer, m_offset, m_buffer, 0, m_end - m_offset);
        m_end -= m_offset;
        m_offset = 0;
    }

    // read packet header
    int offset = m_end;
    int count = m_buffer.Length - m_end;
    while (m_end - m_offset < 4)
    {
        int bytesRead;
        if (ioBehavior == IOBehavior.Asynchronous)
        {
            m_socketAwaitable.EventArgs.SetBuffer(offset, count);
            await m_socket.ReceiveAsync(m_socketAwaitable);
            bytesRead = m_socketAwaitable.EventArgs.BytesTransferred;
        }
        else
        {
            bytesRead = m_socket.Receive(m_buffer, offset, count, SocketFlags.None);
        }

        if (bytesRead <= 0)
        {
            if (protocolErrorBehavior == ProtocolErrorBehavior.Ignore)
                return null;
            throw new EndOfStreamException();
        }
        offset += bytesRead;
        m_end += bytesRead;
        count -= bytesRead;
    }

    // decode packet header
    int payloadLength = (int)SerializationUtility.ReadUInt32(m_buffer, m_offset, 3);

    // The header starts at m_offset, which is non-zero when several packets arrived in one receive,
    // so the sequence id must be read (and reported) from m_offset + 3, not from index 3.
    var packetSequenceId = m_buffer[m_offset + 3];
    if (packetSequenceId != (byte)(m_sequenceId & 0xFF))
    {
        if (protocolErrorBehavior == ProtocolErrorBehavior.Ignore)
            return null;
        throw new InvalidOperationException("Packet received out-of-order. Expected {0}; got {1}.".FormatInvariant(m_sequenceId & 0xFF, packetSequenceId));
    }
    m_sequenceId++;
    m_offset += 4;

    // fast path: the whole payload is already buffered
    if (m_end - m_offset >= payloadLength)
    {
        offset = m_offset;
        m_offset += payloadLength;
        return new PayloadData(new ArraySegment<byte>(m_buffer, offset, payloadLength));
    }

    // allocate a larger buffer if necessary
    var readData = m_buffer;
    if (payloadLength > m_buffer.Length)
    {
        readData = new byte[payloadLength];
        if (ioBehavior == IOBehavior.Asynchronous)
            m_socketAwaitable.EventArgs.SetBuffer(readData, 0, 0);
    }
    Buffer.BlockCopy(m_buffer, m_offset, readData, 0, m_end - m_offset);
    m_end -= m_offset;
    m_offset = 0;

    // read payload
    offset = m_end;
    count = readData.Length - m_end;
    while (m_end < payloadLength)
    {
        int bytesRead;
        if (ioBehavior == IOBehavior.Asynchronous)
        {
            m_socketAwaitable.EventArgs.SetBuffer(offset, count);
            await m_socket.ReceiveAsync(m_socketAwaitable);
            bytesRead = m_socketAwaitable.EventArgs.BytesTransferred;
        }
        else
        {
            bytesRead = m_socket.Receive(readData, offset, count, SocketFlags.None);
        }

        if (bytesRead <= 0)
            throw new EndOfStreamException();
        offset += bytesRead;
        m_end += bytesRead;
        count -= bytesRead;
    }

    if (payloadLength > m_buffer.Length)
    {
        // switch back to original buffer since a larger one was allocated; nothing of it remains buffered
        if (ioBehavior == IOBehavior.Asynchronous)
            m_socketAwaitable.EventArgs.SetBuffer(m_buffer, 0, 0);
        m_end = 0;
    }
    else
    {
        // payload occupies m_buffer[0, payloadLength); any extra bytes received belong to the next packet
        m_offset = payloadLength;
    }

    return new PayloadData(new ArraySegment<byte>(readData, 0, payloadLength));
}