Special exception that is thrown when a received packet does not look like a TLS/SSL record. A user can check for this NotSslRecordException to detect that one peer is trying to use a secure connection while the other is using a plain one.
Inheritance: Exception
        /// <summary>
        /// Scans <paramref name="input"/> for complete SSL/TLS records and hands every run of
        /// whole records to <c>Unwrap</c> in a single call.
        /// </summary>
        /// <remarks>
        /// Scanning stops once the accumulated record length would exceed
        /// <c>TlsUtils.MAX_ENCRYPTED_PACKET_LENGTH</c>. When a record is only partially buffered,
        /// its length is remembered in <c>packetLength</c> so that the next call can test for
        /// completeness without re-reading the header. If the bytes at the scan position are not
        /// an SSL/TLS record, the rest of the buffer is discarded and a
        /// <c>NotSslRecordException</c> is fired through the pipeline and passed to
        /// <c>HandleFailure</c>.
        /// </remarks>
        /// <param name="context">Handler context used to fire the exception event.</param>
        /// <param name="input">Buffer holding the encrypted bytes received so far.</param>
        /// <param name="output">List that <c>Unwrap</c> receives for its decoded messages.</param>
        protected internal override void Decode(IChannelHandlerContext context, IByteBuffer input, List <object> output)
        {
            int firstByte = input.ReaderIndex;
            int lastByte  = input.WriterIndex;

            // A previous call may already have measured a record that was still incomplete.
            int rememberedLength = this.packetLength;
            if (rememberedLength > 0 && lastByte - firstByte < rememberedLength)
            {
                // Still waiting for the remainder of that record.
                return;
            }

            var recordLengths = new List <int>(4);
            int cursor        = firstByte;
            int consumed      = 0;

            if (rememberedLength > 0)
            {
                // The remembered record is now fully buffered; account for it up front.
                recordLengths.Add(rememberedLength);
                cursor           += rememberedLength;
                consumed          = rememberedLength;
                this.packetLength = 0;
            }

            bool sawNonSslRecord = false;

            while (consumed < TlsUtils.MAX_ENCRYPTED_PACKET_LENGTH)
            {
                int available = lastByte - cursor;
                if (available < TlsUtils.SSL_RECORD_HEADER_LENGTH)
                {
                    // Not even a full record header left.
                    break;
                }

                int recordLength = TlsUtils.GetEncryptedPacketLength(input, cursor);
                if (recordLength == -1)
                {
                    sawNonSslRecord = true;
                    break;
                }

                Contract.Assert(recordLength > 0);

                if (recordLength > available)
                {
                    // Header is readable but the body is not; remember the length for the next call.
                    this.packetLength = recordLength;
                    break;
                }

                int projected = consumed + recordLength;
                if (projected > TlsUtils.MAX_ENCRYPTED_PACKET_LENGTH)
                {
                    // Taking this record would exceed the per-call limit; leave it for later.
                    break;
                }

                // A whole record is available: note its boundary and move on to the next one.
                recordLengths.Add(recordLength);
                cursor  += recordLength;
                consumed = projected;
            }

            if (consumed > 0)
            {
                // Advance the reader index *before* unwrapping. Per the note carried over from
                // Netty (https://github.com/netty/netty/issues/1534), unwrapping may re-enter
                // decoding, so the consumed region must already be skipped by then.
                input.SkipBytes(consumed);
                this.Unwrap(context, input, firstByte, consumed, recordLengths, output);

                // Only ever flip the flag to true here; an earlier decode pass may have set it.
                if (!this.firedChannelRead && output.Count > 0)
                {
                    this.firedChannelRead = true;
                }
            }

            if (!sawNonSslRecord)
            {
                return;
            }

            // The peer sent something that is not an SSL/TLS record: report it and drop the rest.
            var failure = new NotSslRecordException(
                "not an SSL/TLS record: " + ByteBufferUtil.HexDump(input));
            input.SkipBytes(input.ReadableBytes);
            context.FireExceptionCaught(failure);
            this.HandleFailure(failure);
        }
Example #2
0
        /// <summary>
        /// Peeks at the buffered bytes for a TLS ClientHello and, if it carries a server_name
        /// (SNI) extension of type host_name, calls <c>Select</c> with the normalized host name.
        /// </summary>
        /// <remarks>
        /// Only ChangeCipherSpec/Alert records are skipped in <paramref name="input"/>; the
        /// ClientHello itself is read with absolute-index getters. When no host name can be
        /// extracted (unknown record type, malformed hello, parsing error),
        /// <c>serverTlsSniSettings.DefaultServerHostName</c> is selected instead. The method
        /// returns without selecting anything while more data is needed.
        /// </remarks>
        /// <param name="context">Handler context passed on to <c>Select</c> and failure notification.</param>
        /// <param name="input">Buffer holding the bytes received so far.</param>
        /// <param name="output">Decoded-message list; not written to by this method.</param>
        /// <exception cref="NotSslRecordException">
        /// The data is not an SSL/TLS record. The handshake is marked as failed and the failure
        /// is notified before the exception propagates.
        /// </exception>
        /// <exception cref="DecoderException">
        /// No host name was found and no default server host name is configured.
        /// </exception>
        protected override void Decode(IChannelHandlerContext context, IByteBuffer input, List <object> output)
        {
            if (!this.suppressRead && !this.handshakeFailed)
            {
                int       writerIndex = input.WriterIndex;
                Exception error       = null;
                try
                {
                    bool continueLoop = true;
                    for (int i = 0; i < MAX_SSL_RECORDS && continueLoop; i++)
                    {
                        int readerIndex   = input.ReaderIndex;
                        int readableBytes = writerIndex - readerIndex;
                        if (readableBytes < TlsUtils.SSL_RECORD_HEADER_LENGTH)
                        {
                            // Not enough data to determine the record type and length.
                            return;
                        }

                        int command = input.GetByte(readerIndex);
                        // tls, but not handshake command
                        switch (command)
                        {
                        case TlsUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
                        case TlsUtils.SSL_CONTENT_TYPE_ALERT:
                            int len = TlsUtils.GetEncryptedPacketLength(input, readerIndex);

                            // Not an SSL/TLS packet
                            if (len == TlsUtils.NOT_ENCRYPTED)
                            {
                                this.handshakeFailed = true;
                                var e = new NotSslRecordException(
                                    "not an SSL/TLS record: " + ByteBufferUtil.HexDump(input));
                                input.SkipBytes(input.ReadableBytes);

                                TlsUtils.NotifyHandshakeFailure(context, e);
                                throw e;
                            }
                            if (len == TlsUtils.NOT_ENOUGH_DATA ||
                                writerIndex - readerIndex - TlsUtils.SSL_RECORD_HEADER_LENGTH < len)
                            {
                                // Not enough data
                                return;
                            }

                            // increase readerIndex and try again.
                            input.SkipBytes(len);
                            continue;

                        case TlsUtils.SSL_CONTENT_TYPE_HANDSHAKE:
                            int majorVersion = input.GetByte(readerIndex + 1);

                            // SSLv3 or TLS
                            if (majorVersion == 3)
                            {
                                int packetLength = input.GetUnsignedShort(readerIndex + 3) + TlsUtils.SSL_RECORD_HEADER_LENGTH;

                                if (readableBytes < packetLength)
                                {
                                    // client hello incomplete; try again to decode once more data is ready.
                                    return;
                                }

                                // See https://tools.ietf.org/html/rfc5246#section-7.4.1.2
                                //
                                // Decode the ssl client hello packet.
                                // We have to skip bytes until SessionID (which sum to 43 bytes).
                                //
                                // struct {
                                //    ProtocolVersion client_version;
                                //    Random random;
                                //    SessionID session_id;
                                //    CipherSuite cipher_suites<2..2^16-2>;
                                //    CompressionMethod compression_methods<1..2^8-1>;
                                //    select (extensions_present) {
                                //        case false:
                                //            struct {};
                                //        case true:
                                //            Extension extensions<0..2^16-1>;
                                //    };
                                // } ClientHello;
                                //

                                int endOffset = readerIndex + packetLength;
                                int offset    = readerIndex + 43;

                                if (endOffset - offset < 6)
                                {
                                    continueLoop = false;
                                    break;
                                }

                                // Each variable-length field is preceded by its length; hop over them in turn.
                                int sessionIdLength = input.GetByte(offset);
                                offset += sessionIdLength + 1;

                                int cipherSuitesLength = input.GetUnsignedShort(offset);
                                offset += cipherSuitesLength + 2;

                                int compressionMethodLength = input.GetByte(offset);
                                offset += compressionMethodLength + 1;

                                int extensionsLength = input.GetUnsignedShort(offset);
                                offset += 2;
                                int extensionsLimit = offset + extensionsLength;

                                if (extensionsLimit > endOffset)
                                {
                                    // Extensions should never exceed the record boundary.
                                    continueLoop = false;
                                    break;
                                }

                                // Walk the extension list. Every way out of this loop either returns
                                // or clears continueLoop, so the outer loop ends as well.
                                for (;;)
                                {
                                    if (extensionsLimit - offset < 4)
                                    {
                                        continueLoop = false;
                                        break;
                                    }

                                    int extensionType = input.GetUnsignedShort(offset);
                                    offset += 2;

                                    int extensionLength = input.GetUnsignedShort(offset);
                                    offset += 2;

                                    if (extensionsLimit - offset < extensionLength)
                                    {
                                        continueLoop = false;
                                        break;
                                    }

                                    // SNI
                                    // See https://tools.ietf.org/html/rfc6066#page-6
                                    if (extensionType == 0)
                                    {
                                        // Skip the two-byte server_name_list length.
                                        offset += 2;
                                        if (extensionsLimit - offset < 3)
                                        {
                                            continueLoop = false;
                                            break;
                                        }

                                        int serverNameType = input.GetByte(offset);
                                        offset++;

                                        if (serverNameType == 0)
                                        {
                                            int serverNameLength = input.GetUnsignedShort(offset);
                                            offset += 2;

                                            if (serverNameLength <= 0 || extensionsLimit - offset < serverNameLength)
                                            {
                                                continueLoop = false;
                                                break;
                                            }

                                            string hostname = input.ToString(offset, serverNameLength, Encoding.UTF8);

                                            var idn = new IdnMapping()
                                            {
                                                AllowUnassigned = true
                                            };

                                            // GetAscii yields an ASCII-only name, so invariant lower-casing
                                            // gives the same result as en-US lower-casing on every target
                                            // framework, without allocating a CultureInfo per call.
                                            hostname = idn.GetAscii(hostname).ToLowerInvariant();

                                            this.Select(context, hostname);
                                            return;
                                        }
                                        else
                                        {
                                            // invalid enum value
                                            continueLoop = false;
                                            break;
                                        }
                                    }

                                    offset += extensionLength;
                                }
                            }
                            else
                            {
                                // Not SSLv3/TLS. C# has no implicit fall-through into 'default', so
                                // stop explicitly; otherwise the loop would re-read the same
                                // unconsumed bytes until MAX_SSL_RECORDS is exhausted.
                                continueLoop = false;
                            }

                            break;

                        default:
                            //not tls, ssl or application data, do not try sni
                            continueLoop = false;
                            break;
                        }
                    }
                }
                catch (NotSslRecordException)
                {
                    // The handshake was already marked as failed and the failure notified above.
                    // Propagate instead of falling back to the default host.
                    throw;
                }
                catch (Exception e)
                {
                    // unexpected encoding, ignore sni and use default
                    error = e;
                }

                if (this.serverTlsSniSettings.DefaultServerHostName != null)
                {
                    // Just select the default server TLS setting
                    this.Select(context, this.serverTlsSniSettings.DefaultServerHostName);
                }
                else
                {
                    this.handshakeFailed = true;
                    var e = new DecoderException($"failed to get the server TLS setting {error}");
                    TlsUtils.NotifyHandshakeFailure(context, e);
                    throw e;
                }
            }
        }
Example #3
0
        /// <summary>
        /// Peeks at the buffered TLS records for a ClientHello, aggregating it across records
        /// when it is fragmented, and calls <c>Select</c> with the host name returned by
        /// <c>ExtractSniHostname</c>.
        /// </summary>
        /// <remarks>
        /// <paramref name="input"/> is read with absolute-index getters only, so a call that
        /// returns early leaves the reader index untouched and the next call re-parses from the
        /// start (the aggregation buffer is cleared before reuse). When the data cannot carry
        /// SNI (Alert/ChangeCipherSpec record, non-ClientHello handshake, non-v3 record, unknown
        /// content type) or parsing fails, <c>_serverTlsSniSettings.DefaultServerHostName</c> is
        /// selected instead.
        /// </remarks>
        /// <param name="context">Handler context used for allocation, events and selection.</param>
        /// <param name="input">Buffer holding the bytes received so far.</param>
        /// <param name="output">Decoded-message list; not written to by this method.</param>
        /// <exception cref="NotSslRecordException">
        /// The data is not an SSL/TLS record. The handshake is marked as failed and the failure
        /// is notified before the exception propagates.
        /// </exception>
        /// <exception cref="DecoderException">
        /// No host name could be determined and no default server host name is configured.
        /// </exception>
        protected override void Decode(IChannelHandlerContext context, IByteBuffer input, List <object> output)
        {
            if (!_suppressRead && !_handshakeFailed)
            {
                Exception error = null;
                try
                {
                    int readerIndex     = input.ReaderIndex;
                    int readableBytes   = input.ReadableBytes;
                    int handshakeLength = -1;

                    // Check if we have enough data to determine the record type and length.
                    while (readableBytes >= TlsUtils.SSL_RECORD_HEADER_LENGTH)
                    {
                        int contentType = input.GetByte(readerIndex);
                        // tls, but not handshake command
                        switch (contentType)
                        {
                        case TlsUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
                        // fall-through
                        case TlsUtils.SSL_CONTENT_TYPE_ALERT:
                            int len = TlsUtils.GetEncryptedPacketLength(input, readerIndex);

                            // Not an SSL/TLS packet
                            if (len == TlsUtils.NOT_ENCRYPTED)
                            {
                                _handshakeFailed = true;
                                var e = new NotSslRecordException(
                                    "not an SSL/TLS record: " + ByteBufferUtil.HexDump(input));
                                _ = input.SkipBytes(input.ReadableBytes);
                                _ = context.FireUserEventTriggered(new SniCompletionEvent(e));
                                TlsUtils.NotifyHandshakeFailure(context, e, true);
                                throw e;
                            }
                            if (len == TlsUtils.NOT_ENOUGH_DATA)
                            {
                                // Not enough data
                                return;
                            }
                            // SNI can't be present in an ALERT or CHANGE_CIPHER_SPEC record, so we'll fall back and
                            // assume no SNI is present. Let's let the actual TLS implementation sort this out.
                            // Just select the default SslContext
                            goto SelectDefault;

                        case TlsUtils.SSL_CONTENT_TYPE_HANDSHAKE:
                            int majorVersion = input.GetByte(readerIndex + 1);

                            // SSLv3 or TLS
                            if (majorVersion == 3)
                            {
                                int packetLength = input.GetUnsignedShort(readerIndex + 3) + TlsUtils.SSL_RECORD_HEADER_LENGTH;

                                if (readableBytes < packetLength)
                                {
                                    // client hello incomplete; try again to decode once more data is ready.
                                    return;
                                }
                                else if (packetLength == TlsUtils.SSL_RECORD_HEADER_LENGTH)
                                {
                                    // Empty record: nothing to parse.
                                    goto SelectDefault;
                                }


                                int endOffset = readerIndex + packetLength;

                                // Let's check if we already parsed the handshake length or not.
                                if (handshakeLength == -1)
                                {
                                    if (readerIndex + 4 > endOffset)
                                    {
                                        // Need more data to read HandshakeType and handshakeLength (4 bytes)
                                        return;
                                    }

                                    int handshakeType = input.GetByte(readerIndex + TlsUtils.SSL_RECORD_HEADER_LENGTH);

                                    // Check if this is a clientHello(1)
                                    // See https://tools.ietf.org/html/rfc5246#section-7.4
                                    if (handshakeType != 1)
                                    {
                                        goto SelectDefault;
                                    }

                                    // Read the length of the handshake as it may arrive in fragments
                                    // See https://tools.ietf.org/html/rfc5246#section-7.4
                                    handshakeLength = input.GetUnsignedMedium(readerIndex + TlsUtils.SSL_RECORD_HEADER_LENGTH + 1);

                                    // Consume handshakeType and handshakeLength (this sums up as 4 bytes)
                                    readerIndex  += 4;
                                    packetLength -= 4;

                                    if (handshakeLength + 4 + TlsUtils.SSL_RECORD_HEADER_LENGTH <= packetLength)
                                    {
                                        // We have everything we need in one packet.
                                        // Skip the record header
                                        readerIndex += TlsUtils.SSL_RECORD_HEADER_LENGTH;
                                        Select(context, ExtractSniHostname(input, readerIndex, readerIndex + handshakeLength));
                                        return;
                                    }
                                    else
                                    {
                                        if (_handshakeBuffer is null)
                                        {
                                            _handshakeBuffer = context.Allocator.Buffer(handshakeLength);
                                        }
                                        else
                                        {
                                            // Clear the buffer so we can aggregate into it again.
                                            _ = _handshakeBuffer.Clear();
                                        }
                                    }
                                }

                                // Combine the encapsulated data in one buffer but not include the SSL_RECORD_HEADER
                                _ = _handshakeBuffer.WriteBytes(input, readerIndex + TlsUtils.SSL_RECORD_HEADER_LENGTH,
                                                                packetLength - TlsUtils.SSL_RECORD_HEADER_LENGTH);
                                readerIndex   += packetLength;
                                readableBytes -= packetLength;
                                if (handshakeLength <= _handshakeBuffer.ReadableBytes)
                                {
                                    Select(context, ExtractSniHostname(_handshakeBuffer, 0, handshakeLength));
                                    return;
                                }
                            }
                            else
                            {
                                // Not SSLv3/TLS. A plain 'break' would only leave the switch and the
                                // loop would spin forever on the same unconsumed bytes.
                                goto SelectDefault;
                            }

                            // Fragment aggregated but the ClientHello is still incomplete:
                            // move on to the next record.
                            break;

                        default:
                            // not tls, ssl or application data, do not try sni.
                            // 'break' here would only exit the switch and loop forever, because
                            // neither readerIndex nor readableBytes has changed.
                            goto SelectDefault;
                        }
                    }

                    // Fewer bytes than a record header remain and no decision has been made yet.
                    // Wait for more data instead of prematurely falling back to the default host.
                    return;
                }
                catch (NotSslRecordException)
                {
                    // The handshake was already marked as failed and the failure notified above.
                    // Propagate instead of falling back to the default host.
                    throw;
                }
                catch (Exception e)
                {
                    error = e;

                    // unexpected encoding, ignore sni and use default
                    if (Logger.WarnEnabled)
                    {
                        Logger.UnexpectedClientHelloPacket(input, e);
                    }
                }

SelectDefault:
                if (_serverTlsSniSettings.DefaultServerHostName is object)
                {
                    // Just select the default server TLS setting
                    Select(context, _serverTlsSniSettings.DefaultServerHostName);
                }
                else
                {
                    ReleaseHandshakeBuffer();

                    _handshakeFailed = true;
                    var e = new DecoderException($"failed to get the server TLS setting {error}");
                    TlsUtils.NotifyHandshakeFailure(context, e, true);
                    throw e;
                }
            }
        }