Esempio n. 1
0
        /// <summary>
        /// Processes one inbound DTLS handshake record on the server side.
        /// </summary>
        /// <remarks>
        /// The session is looked up by the record's remote end point. If it exists and
        /// <c>session.IsEncypted(record)</c> is true, the fragment is first decrypted with
        /// <c>session.Cipher</c>. The wait for that cipher is up to 50 polls of 10 ms each.
        /// The handshake message is then dispatched on its type:
        /// <list type="bullet">
        /// <item>ClientHello without a cookie: creates or resets the session and replies with a HelloVerifyRequest.</item>
        /// <item>ClientHello with a matching cookie: picks cipher suite, curve and signature hash, then sends
        /// ServerHello, an optional Certificate, ServerKeyExchange (ECDHE suites), an optional CertificateRequest,
        /// and ServerHelloDone. A ClientHello whose cookie does not match is ignored.</item>
        /// <item>Certificate / CertificateVerify: records client certificate info and updates the handshake hash.</item>
        /// <item>ClientKeyExchange: derives the pre-master secret and assigns <c>session.Cipher</c>.</item>
        /// <item>Finished: checks the client's verify data and replies with ChangeCipherSpec and Finished.</item>
        /// </list>
        /// Messages that arrive without a matching session are ignored, except for ClientHello.
        /// </remarks>
        /// <param name="record">The received handshake record.</param>
        /// <exception cref="Exception">
        /// No session cipher became available to decrypt an encrypted record, or the client's
        /// Finished verify data did not match the locally calculated value.
        /// </exception>
        public void ProcessHandshake(DTLSRecord record)
        {
            // Wait budget for the session cipher when an encrypted record arrives: 50 polls x 10 ms.
            const int cipherWaitAttempts = 50;
            const int cipherWaitIntervalMs = 10;

            SocketAddress address = record.RemoteEndPoint.Serialize();
            Session session = Sessions.GetSession(address);
            byte[] data;
            if ((session != null) && session.IsEncypted(record))
            {
                // session.Cipher is assigned by the ClientKeyExchange branch below. Presumably that can
                // still be in progress on another thread when an encrypted record arrives, so poll briefly.
                int count = 0;
                while ((session.Cipher == null) && (count < cipherWaitAttempts))
                {
                    System.Threading.Thread.Sleep(cipherWaitIntervalMs);
                    count++;
                }

                // Read the cipher once so the null check and the decode use the same instance.
                var cipher = session.Cipher;
                if (cipher == null)
                {
                    throw new Exception("Timed out waiting for the session cipher needed to decrypt an encrypted handshake record.");
                }

                // The epoch occupies the bits above the 48-bit record sequence number.
                long sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;
                data = cipher.DecodeCiphertext(sequenceNumber, (byte)TRecordType.Handshake, record.Fragment, 0, record.Fragment.Length);
            }
            else
            {
                data = record.Fragment;
            }
            using (MemoryStream stream = new MemoryStream(data))
            {
                HandshakeRecord handshakeRecord = HandshakeRecord.Deserialise(stream);
                if (handshakeRecord != null)
                {
#if DEBUG
                    Console.WriteLine(handshakeRecord.MessageType.ToString());
#endif
                    switch (handshakeRecord.MessageType)
                    {
                        case THandshakeType.HelloRequest:
                            //HelloReq
                            break;
                        case THandshakeType.ClientHello:
                            ClientHello clientHello = ClientHello.Deserialise(stream);
                            if (clientHello != null)
                            {
                                byte[] cookie = clientHello.CalculateCookie(record.RemoteEndPoint, _HelloSecret);

                                if (clientHello.Cookie == null)
                                {
                                    // First ClientHello (no cookie yet): (re)initialise the session and
                                    // challenge the client with a HelloVerifyRequest carrying the cookie.
                                    Version version = clientHello.ClientVersion;
                                    if (ServerVersion < version)
                                        version = ServerVersion;
                                    if (session == null)
                                    {
                                        session = new Session();
                                        session.SessionID = Guid.NewGuid();
                                        session.RemoteEndPoint = record.RemoteEndPoint;
                                        session.Version = version;
                                        Sessions.AddSession(address, session);
                                    }
                                    else
                                    {
                                        session.Reset();
                                        session.Version = version;
                                    }
                                    session.ClientEpoch = record.Epoch;
                                    session.ClientSequenceNumber = record.SequenceNumber;
                                    // The cookie-less ClientHello is deliberately not added to the handshake hash.
                                    HelloVerifyRequest helloVerifyRequest = new HelloVerifyRequest();
                                    helloVerifyRequest.Cookie = cookie;
                                    helloVerifyRequest.ServerVersion = ServerVersion;
                                    SendResponse(session, (IHandshakeMessage)helloVerifyRequest, 0);

                                }
                                else
                                {
                                    // An unencrypted ClientHello on a session that already has a cipher starts a new handshake.
                                    if (session != null && session.Cipher != null && !session.IsEncypted(record))
                                    {
                                        session.Reset();
                                    }

                                    if (TLSUtils.ByteArrayCompare(clientHello.Cookie, cookie))
                                    {
                                        Version version = clientHello.ClientVersion;
                                        if (ServerVersion < version)
                                            version = ServerVersion;
                                        if (clientHello.SessionID == null)
                                        {
                                            if (session == null)
                                            {
                                                session = new Session();
                                                session.NextSequenceNumber();
                                                session.SessionID = Guid.NewGuid();
                                                session.RemoteEndPoint = record.RemoteEndPoint;
                                                Sessions.AddSession(address, session);
                                            }
                                        }
                                        else
                                        {
                                            // Session IDs issued by this server are GUIDs; only the first 16 bytes are used for lookup.
                                            Guid sessionID = Guid.Empty;
                                            if (clientHello.SessionID.Length >= 16)
                                            {
                                                byte[] receivedSessionID = new byte[16];
                                                Buffer.BlockCopy(clientHello.SessionID, 0, receivedSessionID, 0, 16);
                                                sessionID = new Guid(receivedSessionID);
                                            }
                                            if (sessionID != Guid.Empty)
                                                session = Sessions.GetSession(sessionID);
                                            if (session == null)
                                            {
                                                //need to Find Session
                                                session = new Session();
                                                session.SessionID = Guid.NewGuid();
                                                session.NextSequenceNumber();
                                                session.RemoteEndPoint = record.RemoteEndPoint;
                                                Sessions.AddSession(address, session);
                                            }
                                        }
                                        session.Version = version;
                                        session.Handshake.InitaliseHandshakeHash(version < DTLSRecord.Version1_2);
                                        session.Handshake.UpdateHandshakeHash(data);

                                        // Take the first client-offered suite that is configured, valid for the
                                        // negotiated version and usable with the available key material.
                                        TCipherSuite cipherSuite = TCipherSuite.TLS_NULL_WITH_NULL_NULL;
                                        foreach (TCipherSuite item in clientHello.CipherSuites)
                                        {
                                            if (_SupportedCipherSuites.ContainsKey(item) && CipherSuites.SupportedVersion(item, session.Version) && CipherSuites.SuiteUsable(item, PrivateKey, _PSKIdentities, _ValidatePSK != null))
                                            {
                                                cipherSuite = item;
                                                break;
                                            }
                                        }

                                        TKeyExchangeAlgorithm keyExchangeAlgorithm = CipherSuites.GetKeyExchangeAlgorithm(cipherSuite);

                                        ServerHello serverHello = new ServerHello();
                                        // The 32-byte session ID sent to the client is the 16-byte session GUID written twice.
                                        byte[] clientSessionID = new byte[32];
                                        byte[] temp = session.SessionID.ToByteArray();
                                        Buffer.BlockCopy(temp, 0, clientSessionID, 0, 16);
                                        Buffer.BlockCopy(temp, 0, clientSessionID, 16, 16);

                                        serverHello.SessionID = clientSessionID;
                                        serverHello.Random = new RandomData();
                                        serverHello.Random.Generate();
                                        serverHello.CipherSuite = (ushort)cipherSuite;
                                        serverHello.ServerVersion = session.Version;

                                        // Defaults used when the client sends no matching extension.
                                        THashAlgorithm hash = THashAlgorithm.SHA256;
                                        TEllipticCurve curve = TEllipticCurve.secp521r1;
                                        if (clientHello.Extensions != null)
                                        {
                                            foreach (Extension extension in clientHello.Extensions)
                                            {
                                                if (extension.SpecifcExtension is ClientCertificateTypeExtension)
                                                {
                                                    // Client certificate type negotiation is not implemented; the extension is ignored.
                                                }
                                                else if (extension.SpecifcExtension is EllipticCurvesExtension)
                                                {
                                                    // Use the first client-listed curve that the curve factory supports.
                                                    EllipticCurvesExtension ellipticCurves = extension.SpecifcExtension as EllipticCurvesExtension;
                                                    foreach (TEllipticCurve item in ellipticCurves.SupportedCurves)
                                                    {
                                                        if (EllipticCurveFactory.SupportedCurve(item))
                                                        {
                                                            curve = item;
                                                            break;
                                                        }
                                                    }
                                                }
                                                else if (extension.SpecifcExtension is ServerCertificateTypeExtension)
                                                {
                                                    // Server certificate type negotiation is not implemented; the extension is ignored.
                                                }
                                                else if (extension.SpecifcExtension is SignatureAlgorithmsExtension)
                                                {
                                                    // Use the hash of the first ECDSA signature algorithm the client offers.
                                                    SignatureAlgorithmsExtension signatureAlgorithms = extension.SpecifcExtension as SignatureAlgorithmsExtension;
                                                    foreach (SignatureHashAlgorithm item in signatureAlgorithms.SupportedAlgorithms)
                                                    {
                                                        if (item.Signature == TSignatureAlgorithm.ECDSA)
                                                        {
                                                            hash = item.Hash;
                                                            break;
                                                        }
                                                    }
                                                }
                                            }

                                        }

                                        session.Handshake.CipherSuite = cipherSuite;
                                        session.Handshake.ClientRandom = clientHello.Random;
                                        session.Handshake.ServerRandom = serverHello.Random;


                                        if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
                                        {
                                            EllipticCurvePointFormatsExtension pointFormatsExtension = new EllipticCurvePointFormatsExtension();
                                            pointFormatsExtension.SupportedPointFormats.Add(TEllipticCurvePointFormat.Uncompressed);
                                            serverHello.AddExtension(pointFormatsExtension);
                                        }
                                        // HelloVerifyRequest used message sequence 0; this flight starts at 1.
                                        session.Handshake.MessageSequence = 1;
                                        SendResponse(session, serverHello, session.Handshake.MessageSequence);
                                        session.Handshake.MessageSequence++;

                                        if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
                                        {
                                            if (Certificate != null)
                                            {
                                                SendResponse(session, Certificate, session.Handshake.MessageSequence);
                                                session.Handshake.MessageSequence++;
                                            }
                                            ECDHEKeyExchange keyExchange = new ECDHEKeyExchange();
                                            keyExchange.Curve = curve;
                                            keyExchange.KeyExchangeAlgorithm = keyExchangeAlgorithm;
                                            keyExchange.ClientRandom = clientHello.Random;
                                            keyExchange.ServerRandom = serverHello.Random;
                                            keyExchange.GenerateEphemeralKey();
                                            session.Handshake.KeyExchange = keyExchange;
                                            // NOTE(review): presumably DefaultVersion is the pre-1.2 protocol, which signs with SHA-1 — confirm.
                                            if (session.Version == DTLSRecord.DefaultVersion)
                                                hash = THashAlgorithm.SHA1;
                                            ECDHEServerKeyExchange serverKeyExchange = new ECDHEServerKeyExchange(keyExchange, hash, TSignatureAlgorithm.ECDSA, PrivateKey);
                                            SendResponse(session, serverKeyExchange, session.Handshake.MessageSequence);
                                            session.Handshake.MessageSequence++;
                                            if (_RequireClientCertificate)
                                            {
                                                CertificateRequest certificateRequest = new CertificateRequest();
                                                certificateRequest.CertificateTypes.Add(TClientCertificateType.ECDSASign);
                                                certificateRequest.SupportedAlgorithms.Add(new SignatureHashAlgorithm() { Hash = THashAlgorithm.SHA256, Signature = TSignatureAlgorithm.ECDSA });
                                                SendResponse(session, certificateRequest, session.Handshake.MessageSequence);
                                                session.Handshake.MessageSequence++;
                                            }
                                        }
                                        else if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_PSK)
                                        {
                                            ECDHEKeyExchange keyExchange = new ECDHEKeyExchange();
                                            keyExchange.Curve = curve;
                                            keyExchange.KeyExchangeAlgorithm = keyExchangeAlgorithm;
                                            keyExchange.ClientRandom = clientHello.Random;
                                            keyExchange.ServerRandom = serverHello.Random;
                                            keyExchange.GenerateEphemeralKey();
                                            session.Handshake.KeyExchange = keyExchange;
                                            ECDHEPSKServerKeyExchange serverKeyExchange = new ECDHEPSKServerKeyExchange(keyExchange);
                                            SendResponse(session, serverKeyExchange, session.Handshake.MessageSequence);
                                            session.Handshake.MessageSequence++;

                                        }
                                        else if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.PSK)
                                        {
                                            PSKKeyExchange keyExchange = new PSKKeyExchange();
                                            keyExchange.KeyExchangeAlgorithm = keyExchangeAlgorithm;
                                            keyExchange.ClientRandom = clientHello.Random;
                                            keyExchange.ServerRandom = serverHello.Random;
                                            session.Handshake.KeyExchange = keyExchange;
                                            // No ServerKeyExchange is sent for plain PSK because identity hints are not supported.
                                        }
                                        SendResponse(session, new ServerHelloDone(), session.Handshake.MessageSequence);
                                        session.Handshake.MessageSequence++;
                                    }
                                }
                            }
                            break;
                        case THandshakeType.ServerHello:
                            break;
                        case THandshakeType.HelloVerifyRequest:
                            break;
                        case THandshakeType.Certificate:
                            Certificate clientCertificate = Certificate.Deserialise(stream, TCertificateType.X509);
                            // A Certificate without an established session cannot be attributed; ignore it.
                            if (session != null)
                            {
                                if ((clientCertificate != null) && (clientCertificate.CertChain.Count > 0))
                                {
                                    session.CertificateInfo = Certificates.GetCertificateInfo(clientCertificate.CertChain[0], TCertificateFormat.CER);
                                }
                                session.Handshake.UpdateHandshakeHash(data);
                            }
                            break;
                        case THandshakeType.ServerKeyExchange:
                            break;
                        case THandshakeType.CertificateRequest:
                            break;
                        case THandshakeType.ServerHelloDone:
                            break;
                        case THandshakeType.CertificateVerify:
                            if (session != null)
                            {
                                // NOTE(review): the CertificateVerify is parsed but its signature is never checked
                                // against the client certificate, so possession of the private key is not proven.
                                CertificateVerify.Deserialise(stream, session.Version);
                                session.Handshake.UpdateHandshakeHash(data);
                            }
                            break;
                        case THandshakeType.ClientKeyExchange:
                            // Ignored unless a ClientHello on this session has already chosen a key exchange.
                            if ((session != null) && (session.Handshake.KeyExchange != null))
                            {
                                session.Handshake.UpdateHandshakeHash(data);
                                byte[] preMasterSecret = null;
                                if (session.Handshake.KeyExchange.KeyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
                                {
                                    ECDHEClientKeyExchange clientKeyExchange = ECDHEClientKeyExchange.Deserialise(stream);
                                    if (clientKeyExchange != null)
                                    {
                                        ECDHEKeyExchange ecKeyExchange = session.Handshake.KeyExchange as ECDHEKeyExchange;
                                        preMasterSecret = ecKeyExchange.GetPreMasterSecret(clientKeyExchange.PublicKeyBytes);
                                    }
                                }
                                else if (session.Handshake.KeyExchange.KeyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_PSK)
                                {
                                    ECDHEPSKClientKeyExchange clientKeyExchange = ECDHEPSKClientKeyExchange.Deserialise(stream);
                                    if (clientKeyExchange != null)
                                    {
                                        session.PSKIdentity = Encoding.UTF8.GetString(clientKeyExchange.PSKIdentity);
                                        byte[] psk = _PSKIdentities.GetKey(clientKeyExchange.PSKIdentity);

                                        // Fall back to the validation callback for unknown identities, but only if one is configured.
                                        if ((psk == null) && (_ValidatePSK != null))
                                        {
                                            psk = _ValidatePSK(clientKeyExchange.PSKIdentity);
                                            if (psk != null)
                                            {
                                                _PSKIdentities.AddIdentity(clientKeyExchange.PSKIdentity, psk);
                                            }
                                        }

                                        if (psk != null)
                                        {
                                            ECDHEKeyExchange ecKeyExchange = session.Handshake.KeyExchange as ECDHEKeyExchange;
                                            byte[] otherSecret = ecKeyExchange.GetPreMasterSecret(clientKeyExchange.PublicKeyBytes);
                                            preMasterSecret = TLSUtils.GetPSKPreMasterSecret(otherSecret, psk);

                                        }
                                    }
                                }
                                else if (session.Handshake.KeyExchange.KeyExchangeAlgorithm == TKeyExchangeAlgorithm.PSK)
                                {
                                    PSKClientKeyExchange clientKeyExchange = PSKClientKeyExchange.Deserialise(stream);
                                    if (clientKeyExchange != null)
                                    {
                                        session.PSKIdentity = Encoding.UTF8.GetString(clientKeyExchange.PSKIdentity);
                                        byte[] psk = _PSKIdentities.GetKey(clientKeyExchange.PSKIdentity);

                                        // Fall back to the validation callback for unknown identities, but only if one is configured.
                                        if ((psk == null) && (_ValidatePSK != null))
                                        {
                                            psk = _ValidatePSK(clientKeyExchange.PSKIdentity);
                                            if (psk != null)
                                            {
                                                _PSKIdentities.AddIdentity(clientKeyExchange.PSKIdentity, psk);
                                            }
                                        }

                                        if (psk != null)
                                        {
                                            // Plain PSK: the "other secret" is a zero-filled buffer as long as the PSK.
                                            byte[] otherSecret = new byte[psk.Length];
                                            preMasterSecret = TLSUtils.GetPSKPreMasterSecret(otherSecret, psk);
                                        }
                                    }
                                }

                                // If no pre-master secret was derived (bad message or unknown PSK), no cipher is assigned.
                                if (preMasterSecret != null)
                                {
                                    session.Cipher = TLSUtils.AssignCipher(preMasterSecret, false, session.Version, session.Handshake);
                                }
                            }
                            break;
                        case THandshakeType.Finished:
                            Finished finished = Finished.Deserialise(stream);
                            if (session != null)
                            {
                                // Verify against the hash of all messages before this Finished.
                                byte[] handshakeHash = session.Handshake.GetHash();
                                byte[] calculatedVerifyData = TLSUtils.GetVerifyData(session.Version,session.Handshake,false, true, handshakeHash);
#if DEBUG
                                Console.Write("Handshake Hash:");
                                TLSUtils.WriteToConsole(handshakeHash);
                                Console.Write("Sent Verify:");
                                TLSUtils.WriteToConsole(finished.VerifyData);
                                Console.Write("Calc Verify:");
                                TLSUtils.WriteToConsole(calculatedVerifyData);
#endif
                                if (TLSUtils.ByteArrayCompare(finished.VerifyData, calculatedVerifyData))
                                {

                                    SendChangeCipherSpec(session);
                                    // The server's Finished covers the client's Finished as well.
                                    session.Handshake.UpdateHandshakeHash(data);
                                    handshakeHash = session.Handshake.GetHash();
                                    Finished serverFinished = new Finished();
                                    serverFinished.VerifyData = TLSUtils.GetVerifyData(session.Version,session.Handshake,false, false, handshakeHash);
                                    SendResponse(session, serverFinished, session.Handshake.MessageSequence);
                                    session.Handshake.MessageSequence++;
                                }
                                else
                                {
                                    throw new Exception("Client Finished verify data does not match the calculated handshake verify data.");
                                }
                            }
                            break;
                        default:
                            break;
                    }
                }
            }
        }
Esempio n. 2
0
 /// <summary>
 /// Reads a CertificateRequest handshake message body from <paramref name="stream"/>.
 /// </summary>
 /// <remarks>
 /// Fields are read in this order:
 /// <list type="number">
 /// <item>A one-byte count of certificate types, followed by that many type bytes.</item>
 /// <item>Only when <paramref name="version"/> &gt;= <c>DTLSRecord.Version1_2</c>: a 16-bit byte length,
 /// followed by (hash, signature) byte pairs.</item>
 /// <item>A 16-bit total length for the certificate-authority list. Each entry is a 16-bit length
 /// followed by the distinguished-name bytes.</item>
 /// </list>
 /// 16-bit values are read with <c>NetworkByteOrderConverter.ToUInt16</c>.
 /// </remarks>
 /// <param name="stream">Stream positioned at the start of the message body.</param>
 /// <param name="version">Negotiated protocol version; decides whether the signature-algorithm list is present.</param>
 /// <returns>The parsed certificate request.</returns>
 /// <exception cref="EndOfStreamException">The stream ends in the middle of a distinguished name.</exception>
 public static CertificateRequest Deserialise(Stream stream, Version version)
 {
     CertificateRequest result = new CertificateRequest();

     // ReadByte returns -1 at end of stream; the loop bound then reads no types.
     int certificateTypeCount = stream.ReadByte();
     for (int index = 0; index < certificateTypeCount; index++)
     {
         result.CertificateTypes.Add((TClientCertificateType)stream.ReadByte());
     }

     if (version >= DTLSRecord.Version1_2)
     {
         // The length is in bytes; each entry is a (hash, signature) byte pair.
         ushort length = NetworkByteOrderConverter.ToUInt16(stream);
         int supportedAlgorithmsCount = length / 2;
         for (int index = 0; index < supportedAlgorithmsCount; index++)
         {
             THashAlgorithm hash = (THashAlgorithm)stream.ReadByte();
             TSignatureAlgorithm signature = (TSignatureAlgorithm)stream.ReadByte();
             result.SupportedAlgorithms.Add(new SignatureHashAlgorithm() { Hash = hash, Signature = signature });
         }
     }

     ushort certificateAuthoritiesLength = NetworkByteOrderConverter.ToUInt16(stream);
     int read = 0;
     while (certificateAuthoritiesLength > read)
     {
         ushort distinguishedNameLength = NetworkByteOrderConverter.ToUInt16(stream);
         read += (2 + distinguishedNameLength);
         byte[] distinguishedName = new byte[distinguishedNameLength];

         // Stream.Read may return fewer bytes than requested, so keep reading until the name is complete.
         int offset = 0;
         while (offset < distinguishedNameLength)
         {
             int bytesRead = stream.Read(distinguishedName, offset, distinguishedNameLength - offset);
             if (bytesRead <= 0)
             {
                 throw new EndOfStreamException("Stream ended inside a certificate authority distinguished name.");
             }
             offset += bytesRead;
         }
         result.CertificateAuthorities.Add(distinguishedName);
     }
     return result;
 }
        /// <summary>
        /// Processes one DTLS handshake record received from a client and drives the server
        /// side of the handshake for the sending endpoint.
        /// </summary>
        /// <param name="record">
        /// The received record. Its fragment is decrypted with the session cipher when the
        /// session reports the record as encrypted; otherwise it is parsed as plaintext.
        /// </param>
        /// <remarks>
        /// <list type="bullet">
        /// <item>ClientHello without a cookie: answered with a HelloVerifyRequest.</item>
        /// <item>ClientHello with a matching cookie: selects a cipher suite and sends ServerHello,
        /// the key-exchange specific messages (Certificate, ServerKeyExchange, CertificateRequest)
        /// and ServerHelloDone.</item>
        /// <item>ClientKeyExchange: derives the pre-master secret and assigns the session cipher.</item>
        /// <item>Finished: verified against the handshake hash and answered with
        /// ChangeCipherSpec + Finished.</item>
        /// </list>
        /// Certificate, CertificateVerify, ClientKeyExchange and Finished messages from an
        /// endpoint with no session are ignored.
        /// </remarks>
        /// <exception cref="TimeoutException">
        /// An encrypted record arrived but no session cipher was assigned after polling for
        /// <c>HandshakeTimeout</c> (in <see cref="System.Threading.Thread.Sleep(int)"/> units).
        /// </exception>
        /// <exception cref="System.Security.Authentication.AuthenticationException">
        /// The client's Finished verify data does not match the locally computed value.
        /// </exception>
        public void ProcessHandshake(DTLSRecord record)
        {
            SocketAddress address = record.RemoteEndPoint.Serialize();
            Session       session = Sessions.GetSession(address);

            // Plaintext by default; replaced below by the decrypted fragment for encrypted records.
            byte[] data = record.Fragment;
            if ((session != null) && session.IsEncypted(record))
            {
                // The cipher is only assigned once a ClientKeyExchange has been processed
                // (presumably by another receive thread), so poll for it before decrypting.
                int count = 0;
                while ((session.Cipher == null) && (count < (HandshakeTimeout / HANDSHAKE_DWELL_TIME)))
                {
                    System.Threading.Thread.Sleep(HANDSHAKE_DWELL_TIME);
                    count++;
                }

                if (session.Cipher == null)
                {
                    throw new TimeoutException($"HandshakeTimeout: >{HandshakeTimeout}");
                }

                // Record sequence number used by the cipher: epoch in the top 16 bits,
                // record sequence number in the low 48 bits.
                long sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;
                data = session.Cipher.DecodeCiphertext(sequenceNumber, (byte)TRecordType.Handshake, record.Fragment, 0, record.Fragment.Length);
            }

            using (MemoryStream stream = new MemoryStream(data))
            {
                HandshakeRecord handshakeRecord = HandshakeRecord.Deserialise(stream);
                if (handshakeRecord != null)
                {
#if DEBUG
                    Console.WriteLine(handshakeRecord.MessageType.ToString());
#endif
                    switch (handshakeRecord.MessageType)
                    {
                    case THandshakeType.HelloRequest:
                        //HelloReq
                        break;

                    case THandshakeType.ClientHello:
                        ClientHello clientHello = ClientHello.Deserialise(stream);
                        if (clientHello != null)
                        {
                            byte[] cookie = clientHello.CalculateCookie(record.RemoteEndPoint, _HelloSecret);

                            if (clientHello.Cookie == null)
                            {
                                // First ClientHello from this endpoint: reply with a HelloVerifyRequest
                                // carrying the cookie the client must echo back.
                                Version version = clientHello.ClientVersion;
                                if (ServerVersion < version)
                                {
                                    version = ServerVersion;
                                }
                                if (session == null)
                                {
                                    session                = new Session();
                                    session.SessionID      = Guid.NewGuid();
                                    session.RemoteEndPoint = record.RemoteEndPoint;
                                    session.Version        = version;
                                    Sessions.AddSession(address, session);
                                }
                                else
                                {
                                    session.Reset();
                                    session.Version = version;
                                }
                                session.ClientEpoch          = record.Epoch;
                                session.ClientSequenceNumber = record.SequenceNumber;
                                HelloVerifyRequest helloVerifyRequest = new HelloVerifyRequest();
                                helloVerifyRequest.Cookie        = cookie;
                                helloVerifyRequest.ServerVersion = ServerVersion;
                                SendResponse(session, (IHandshakeMessage)helloVerifyRequest, 0);
                            }
                            else
                            {
                                // A plaintext ClientHello on a session that already has a cipher
                                // presumably restarts the handshake, so drop the old session state.
                                if (session != null && session.Cipher != null && !session.IsEncypted(record))
                                {
                                    session.Reset();
                                }

                                // Only continue when the client echoed the cookie we would issue it;
                                // a mismatching cookie is silently ignored.
                                if (TLSUtils.ByteArrayCompare(clientHello.Cookie, cookie))
                                {
                                    Version version = clientHello.ClientVersion;
                                    if (ServerVersion < version)
                                    {
                                        version = ServerVersion;
                                    }
                                    if (clientHello.SessionID == null)
                                    {
                                        if (session == null)
                                        {
                                            session = new Session();
                                            session.NextSequenceNumber();
                                            session.SessionID      = Guid.NewGuid();
                                            session.RemoteEndPoint = record.RemoteEndPoint;
                                            Sessions.AddSession(address, session);
                                        }
                                    }
                                    else
                                    {
                                        // The first 16 bytes of the client's session ID are interpreted
                                        // as the Guid this server handed out in a previous ServerHello.
                                        Guid sessionID = Guid.Empty;
                                        if (clientHello.SessionID.Length >= 16)
                                        {
                                            byte[] receivedSessionID = new byte[16];
                                            Buffer.BlockCopy(clientHello.SessionID, 0, receivedSessionID, 0, 16);
                                            sessionID = new Guid(receivedSessionID);
                                        }
                                        if (sessionID != Guid.Empty)
                                        {
                                            session = Sessions.GetSession(sessionID);
                                        }
                                        if (session == null)
                                        {
                                            // Unknown session ID: start a fresh session for this endpoint.
                                            session           = new Session();
                                            session.SessionID = Guid.NewGuid();
                                            session.NextSequenceNumber();
                                            session.RemoteEndPoint = record.RemoteEndPoint;
                                            Sessions.AddSession(address, session);
                                        }
                                    }
                                    session.Version = version;
                                    session.Handshake.InitaliseHandshakeHash(version < DTLSRecord.Version1_2);
                                    session.Handshake.UpdateHandshakeHash(data);

                                    // Pick the first client-offered suite that this server supports and can use.
                                    TCipherSuite cipherSuite = TCipherSuite.TLS_NULL_WITH_NULL_NULL;
                                    foreach (TCipherSuite item in clientHello.CipherSuites)
                                    {
                                        if (_SupportedCipherSuites.ContainsKey(item) && CipherSuites.SupportedVersion(item, session.Version) && CipherSuites.SuiteUsable(item, PrivateKey, _PSKIdentities, _ValidatePSK != null))
                                        {
                                            cipherSuite = item;
                                            break;
                                        }
                                    }

                                    TKeyExchangeAlgorithm keyExchangeAlgorithm = CipherSuites.GetKeyExchangeAlgorithm(cipherSuite);

                                    // The 32-byte session ID sent to the client is the session Guid repeated twice.
                                    ServerHello serverHello     = new ServerHello();
                                    byte[]      clientSessionID = new byte[32];
                                    byte[]      temp            = session.SessionID.ToByteArray();
                                    Buffer.BlockCopy(temp, 0, clientSessionID, 0, 16);
                                    Buffer.BlockCopy(temp, 0, clientSessionID, 16, 16);

                                    serverHello.SessionID = clientSessionID;
                                    serverHello.Random    = new RandomData();
                                    serverHello.Random.Generate();
                                    serverHello.CipherSuite   = (ushort)cipherSuite;
                                    serverHello.ServerVersion = session.Version;

                                    // Defaults, overridden by the client's curve and signature-algorithm extensions.
                                    // Client/server certificate-type extensions are not negotiated.
                                    THashAlgorithm hash  = THashAlgorithm.SHA256;
                                    TEllipticCurve curve = TEllipticCurve.secp521r1;
                                    if (clientHello.Extensions != null)
                                    {
                                        foreach (Extension extension in clientHello.Extensions)
                                        {
                                            EllipticCurvesExtension      ellipticCurves      = extension.SpecifcExtension as EllipticCurvesExtension;
                                            SignatureAlgorithmsExtension signatureAlgorithms = extension.SpecifcExtension as SignatureAlgorithmsExtension;
                                            if (ellipticCurves != null)
                                            {
                                                foreach (TEllipticCurve item in ellipticCurves.SupportedCurves)
                                                {
                                                    if (EllipticCurveFactory.SupportedCurve(item))
                                                    {
                                                        curve = item;
                                                        break;
                                                    }
                                                }
                                            }
                                            else if (signatureAlgorithms != null)
                                            {
                                                foreach (SignatureHashAlgorithm item in signatureAlgorithms.SupportedAlgorithms)
                                                {
                                                    if (item.Signature == TSignatureAlgorithm.ECDSA)
                                                    {
                                                        hash = item.Hash;
                                                        break;
                                                    }
                                                }
                                            }
                                        }
                                    }

                                    session.Handshake.CipherSuite  = cipherSuite;
                                    session.Handshake.ClientRandom = clientHello.Random;
                                    session.Handshake.ServerRandom = serverHello.Random;

                                    if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
                                    {
                                        EllipticCurvePointFormatsExtension pointFormatsExtension = new EllipticCurvePointFormatsExtension();
                                        pointFormatsExtension.SupportedPointFormats.Add(TEllipticCurvePointFormat.Uncompressed);
                                        serverHello.AddExtension(pointFormatsExtension);
                                    }
                                    session.Handshake.MessageSequence = 1;
                                    SendResponse(session, serverHello, session.Handshake.MessageSequence);
                                    session.Handshake.MessageSequence++;

                                    if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
                                    {
                                        if (Certificate != null)
                                        {
                                            SendResponse(session, Certificate, session.Handshake.MessageSequence);
                                            session.Handshake.MessageSequence++;
                                        }
                                        ECDHEKeyExchange keyExchange = new ECDHEKeyExchange();
                                        keyExchange.Curve = curve;
                                        keyExchange.KeyExchangeAlgorithm = keyExchangeAlgorithm;
                                        keyExchange.ClientRandom         = clientHello.Random;
                                        keyExchange.ServerRandom         = serverHello.Random;
                                        keyExchange.GenerateEphemeralKey();
                                        session.Handshake.KeyExchange = keyExchange;
                                        // NOTE(review): presumably DefaultVersion is DTLS 1.0, whose
                                        // ServerKeyExchange signatures use SHA-1 — confirm.
                                        if (session.Version == DTLSRecord.DefaultVersion)
                                        {
                                            hash = THashAlgorithm.SHA1;
                                        }
                                        ECDHEServerKeyExchange serverKeyExchange = new ECDHEServerKeyExchange(keyExchange, hash, TSignatureAlgorithm.ECDSA, PrivateKey);
                                        SendResponse(session, serverKeyExchange, session.Handshake.MessageSequence);
                                        session.Handshake.MessageSequence++;
                                        if (_RequireClientCertificate)
                                        {
                                            CertificateRequest certificateRequest = new CertificateRequest();
                                            certificateRequest.CertificateTypes.Add(TClientCertificateType.ECDSASign);
                                            certificateRequest.SupportedAlgorithms.Add(new SignatureHashAlgorithm()
                                            {
                                                Hash = THashAlgorithm.SHA256, Signature = TSignatureAlgorithm.ECDSA
                                            });
                                            SendResponse(session, certificateRequest, session.Handshake.MessageSequence);
                                            session.Handshake.MessageSequence++;
                                        }
                                    }
                                    else if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_PSK)
                                    {
                                        ECDHEKeyExchange keyExchange = new ECDHEKeyExchange();
                                        keyExchange.Curve = curve;
                                        keyExchange.KeyExchangeAlgorithm = keyExchangeAlgorithm;
                                        keyExchange.ClientRandom         = clientHello.Random;
                                        keyExchange.ServerRandom         = serverHello.Random;
                                        keyExchange.GenerateEphemeralKey();
                                        session.Handshake.KeyExchange = keyExchange;
                                        ECDHEPSKServerKeyExchange serverKeyExchange = new ECDHEPSKServerKeyExchange(keyExchange);
                                        SendResponse(session, serverKeyExchange, session.Handshake.MessageSequence);
                                        session.Handshake.MessageSequence++;
                                    }
                                    else if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.PSK)
                                    {
                                        // Plain PSK: no ServerKeyExchange is sent (no identity hint).
                                        PSKKeyExchange keyExchange = new PSKKeyExchange();
                                        keyExchange.KeyExchangeAlgorithm = keyExchangeAlgorithm;
                                        keyExchange.ClientRandom         = clientHello.Random;
                                        keyExchange.ServerRandom         = serverHello.Random;
                                        session.Handshake.KeyExchange    = keyExchange;
                                    }
                                    SendResponse(session, new ServerHelloDone(), session.Handshake.MessageSequence);
                                    session.Handshake.MessageSequence++;
                                }
                            }
                        }
                        break;

                    case THandshakeType.ServerHello:
                        break;

                    case THandshakeType.HelloVerifyRequest:
                        break;

                    case THandshakeType.Certificate:
                        if (session == null)
                        {
                            // No handshake in progress for this endpoint; ignore the message.
                            break;
                        }
                        Certificate clientCertificate = Certificate.Deserialise(stream, TCertificateType.X509);
                        if ((clientCertificate != null) && (clientCertificate.CertChain.Count > 0))
                        {
                            session.CertificateInfo = Certificates.GetCertificateInfo(clientCertificate.CertChain[0], TCertificateFormat.CER);
                        }
                        session.Handshake.UpdateHandshakeHash(data);
                        break;

                    case THandshakeType.ServerKeyExchange:
                        break;

                    case THandshakeType.CertificateRequest:
                        break;

                    case THandshakeType.ServerHelloDone:
                        break;

                    case THandshakeType.CertificateVerify:
                        if (session == null)
                        {
                            // No handshake in progress for this endpoint; ignore the message.
                            break;
                        }
                        // NOTE(review): the parsed CertificateVerify is not checked against the
                        // client certificate here; only the handshake hash is updated. TODO verify.
                        CertificateVerify.Deserialise(stream, session.Version);
                        session.Handshake.UpdateHandshakeHash(data);
                        break;

                    case THandshakeType.ClientKeyExchange:
                        if ((session == null) || (session.Handshake.KeyExchange == null))
                        {
                            // No key exchange was set up by a preceding ClientHello; ignore the message.
                            break;
                        }

                        session.Handshake.UpdateHandshakeHash(data);
                        byte[] preMasterSecret = null;
                        TKeyExchangeAlgorithm negotiatedAlgorithm = session.Handshake.KeyExchange.KeyExchangeAlgorithm;
                        if (negotiatedAlgorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
                        {
                            ECDHEClientKeyExchange clientKeyExchange = ECDHEClientKeyExchange.Deserialise(stream);
                            if (clientKeyExchange != null)
                            {
                                ECDHEKeyExchange ecKeyExchange = session.Handshake.KeyExchange as ECDHEKeyExchange;
                                preMasterSecret = ecKeyExchange.GetPreMasterSecret(clientKeyExchange.PublicKeyBytes);
                            }
                        }
                        else if (negotiatedAlgorithm == TKeyExchangeAlgorithm.ECDHE_PSK)
                        {
                            ECDHEPSKClientKeyExchange clientKeyExchange = ECDHEPSKClientKeyExchange.Deserialise(stream);
                            if (clientKeyExchange != null)
                            {
                                session.PSKIdentity = Encoding.UTF8.GetString(clientKeyExchange.PSKIdentity);
                                byte[] psk = _PSKIdentities.GetKey(clientKeyExchange.PSKIdentity);

                                // Unknown identity: ask the optional validator (it may be unset) and cache its answer.
                                if ((psk == null) && (_ValidatePSK != null))
                                {
                                    psk = _ValidatePSK(clientKeyExchange.PSKIdentity);
                                    if (psk != null)
                                    {
                                        _PSKIdentities.AddIdentity(clientKeyExchange.PSKIdentity, psk);
                                    }
                                }

                                if (psk != null)
                                {
                                    ECDHEKeyExchange ecKeyExchange = session.Handshake.KeyExchange as ECDHEKeyExchange;
                                    byte[]           otherSecret   = ecKeyExchange.GetPreMasterSecret(clientKeyExchange.PublicKeyBytes);
                                    preMasterSecret = TLSUtils.GetPSKPreMasterSecret(otherSecret, psk);
                                }
                            }
                        }
                        else if (negotiatedAlgorithm == TKeyExchangeAlgorithm.PSK)
                        {
                            PSKClientKeyExchange clientKeyExchange = PSKClientKeyExchange.Deserialise(stream);
                            if (clientKeyExchange != null)
                            {
                                session.PSKIdentity = Encoding.UTF8.GetString(clientKeyExchange.PSKIdentity);
                                byte[] psk = _PSKIdentities.GetKey(clientKeyExchange.PSKIdentity);

                                // Unknown identity: ask the optional validator (it may be unset) and cache its answer.
                                if ((psk == null) && (_ValidatePSK != null))
                                {
                                    psk = _ValidatePSK(clientKeyExchange.PSKIdentity);
                                    if (psk != null)
                                    {
                                        _PSKIdentities.AddIdentity(clientKeyExchange.PSKIdentity, psk);
                                    }
                                }

                                if (psk != null)
                                {
                                    // Plain PSK: the "other secret" is psk.Length zero bytes.
                                    byte[] otherSecret = new byte[psk.Length];
                                    preMasterSecret = TLSUtils.GetPSKPreMasterSecret(otherSecret, psk);
                                }
                            }
                        }

                        // With no usable pre-master secret (e.g. unknown PSK identity) the cipher stays unset.
                        if (preMasterSecret != null)
                        {
                            session.Cipher = TLSUtils.AssignCipher(preMasterSecret, false, session.Version, session.Handshake);
                        }
                        break;

                    case THandshakeType.Finished:
                        Finished finished = Finished.Deserialise(stream);
                        if (session == null)
                        {
                            // No handshake in progress for this endpoint; ignore the message.
                            break;
                        }

                        byte[] handshakeHash        = session.Handshake.GetHash();
                        byte[] calculatedVerifyData = TLSUtils.GetVerifyData(session.Version, session.Handshake, false, true, handshakeHash);
#if DEBUG
                        Console.Write("Handshake Hash:");
                        TLSUtils.WriteToConsole(handshakeHash);
                        Console.Write("Sent Verify:");
                        TLSUtils.WriteToConsole(finished.VerifyData);
                        Console.Write("Calc Verify:");
                        TLSUtils.WriteToConsole(calculatedVerifyData);
#endif
                        if (!TLSUtils.ByteArrayCompare(finished.VerifyData, calculatedVerifyData))
                        {
                            throw new System.Security.Authentication.AuthenticationException("DTLS handshake failed: client Finished verify data does not match.");
                        }

                        // The client's Finished is hashed after verification, so the server's
                        // Finished covers it.
                        SendChangeCipherSpec(session);
                        session.Handshake.UpdateHandshakeHash(data);
                        handshakeHash = session.Handshake.GetHash();
                        Finished serverFinished = new Finished();
                        serverFinished.VerifyData = TLSUtils.GetVerifyData(session.Version, session.Handshake, false, false, handshakeHash);
                        SendResponse(session, serverFinished, session.Handshake.MessageSequence);
                        session.Handshake.MessageSequence++;
                        break;

                    default:
                        break;
                    }
                }
            }
        }
Esempio n. 4
0
        public void ProcessHandshake(DTLSRecord record)
        {
            if (record == null)
            {
                throw new ArgumentNullException(nameof(record));
            }

            var address = record.RemoteEndPoint.Serialize();
            var session = this.Sessions.GetSession(address);
            var data    = record.Fragment;

            if ((session != null) && session.IsEncypted(record))
            {
                var count = 0;
                while ((session.Cipher == null) && (count < 50))
                {
                    System.Threading.Thread.Sleep(10);
                    count++;
                }

                if (session.Cipher == null)
                {
                    throw new Exception("Need Cipher for Encrypted Session");
                }

                var sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;
                data = session.Cipher.DecodeCiphertext(sequenceNumber, (byte)TRecordType.Handshake, record.Fragment, 0, record.Fragment.Length);
            }

            using (var stream = new MemoryStream(data))
            {
                var handshakeRecord = HandshakeRecord.Deserialise(stream);
                switch (handshakeRecord.MessageType)
                {
                case THandshakeType.HelloRequest:
                {
                    //HelloReq
                    break;
                }

                case THandshakeType.ClientHello:
                {
                    var clientHello = ClientHello.Deserialise(stream);
                    var cookie      = clientHello.CalculateCookie(record.RemoteEndPoint, this._HelloSecret);

                    if (clientHello.Cookie == null)
                    {
                        var vers = clientHello.ClientVersion;
                        if (this.ServerVersion < vers)
                        {
                            vers = this.ServerVersion;
                        }

                        if (session == null)
                        {
                            session = new Session
                            {
                                SessionID      = Guid.NewGuid(),
                                RemoteEndPoint = record.RemoteEndPoint,
                                Version        = vers
                            };
                            this.Sessions.AddSession(address, session);
                        }
                        else
                        {
                            session.Reset();
                            session.Version = vers;
                        }

                        session.ClientEpoch          = record.Epoch;
                        session.ClientSequenceNumber = record.SequenceNumber;

                        var helloVerifyRequest = new HelloVerifyRequest
                        {
                            Cookie        = cookie,
                            ServerVersion = ServerVersion
                        };

                        this.SendResponse(session, helloVerifyRequest, 0);
                        break;
                    }

                    if (session != null && session.Cipher != null && !session.IsEncypted(record))
                    {
                        session.Reset();
                    }

                    if (!clientHello.Cookie.SequenceEqual(cookie))
                    {
                        break;
                    }

                    var version = clientHello.ClientVersion;
                    if (this.ServerVersion < version)
                    {
                        version = this.ServerVersion;
                    }

                    if (clientHello.SessionID == null)
                    {
                        if (session == null)
                        {
                            session = new Session();
                            session.NextSequenceNumber();
                            session.SessionID      = Guid.NewGuid();
                            session.RemoteEndPoint = record.RemoteEndPoint;
                            this.Sessions.AddSession(address, session);
                        }
                    }
                    else
                    {
                        if (clientHello.SessionID.Length >= 16)
                        {
                            session = this.Sessions.GetSession(new Guid(clientHello.SessionID.Take(16).ToArray()));
                        }

                        if (session == null)
                        {
                            //need to Find Session
                            session = new Session
                            {
                                SessionID = Guid.NewGuid()
                            };
                            session.NextSequenceNumber();
                            session.RemoteEndPoint = record.RemoteEndPoint;
                            this.Sessions.AddSession(address, session);
                        }
                    }

                    session.Version = version;
                    var cipherSuite = TCipherSuite.TLS_NULL_WITH_NULL_NULL;
                    foreach (TCipherSuite item in clientHello.CipherSuites)
                    {
                        if (this._SupportedCipherSuites.ContainsKey(item) && CipherSuites.SupportedVersion(item, session.Version) && CipherSuites.SuiteUsable(item, this.PrivateKey, this._PSKIdentities, this._ValidatePSK != null))
                        {
                            cipherSuite = item;
                            break;
                        }
                    }

                    var clientSessionID = new byte[32];
                    var temp            = session.SessionID.ToByteArray();
                    Buffer.BlockCopy(temp, 0, clientSessionID, 0, 16);
                    Buffer.BlockCopy(temp, 0, clientSessionID, 16, 16);

                    var serverHello = new ServerHello
                    {
                        SessionID     = clientSessionID,    // session.SessionID.ToByteArray();
                        Random        = new RandomData(),
                        CipherSuite   = (ushort)cipherSuite,
                        ServerVersion = session.Version
                    };
                    serverHello.Random.Generate();

                    session.Handshake.UpdateHandshakeHash(data);
                    session.Handshake.CipherSuite  = cipherSuite;
                    session.Handshake.ClientRandom = clientHello.Random;
                    session.Handshake.ServerRandom = serverHello.Random;

                    var hash  = THashAlgorithm.SHA256;
                    var curve = TEllipticCurve.secp521r1;
                    if (clientHello.Extensions != null)
                    {
                        foreach (var extension in clientHello.Extensions)
                        {
                            if (extension.SpecificExtension is ClientCertificateTypeExtension)
                            {
                                var clientCertificateType = extension.SpecificExtension as ClientCertificateTypeExtension;
                                //TCertificateType certificateType = TCertificateType.Unknown;
                                //foreach (TCertificateType item in clientCertificateType.CertificateTypes)
                                //{

                                //}
                                //serverHello.AddExtension(new ClientCertificateTypeExtension(certificateType));
                            }
                            else if (extension.SpecificExtension is EllipticCurvesExtension)
                            {
                                var ellipticCurves = extension.SpecificExtension as EllipticCurvesExtension;
                                foreach (var item in ellipticCurves.SupportedCurves)
                                {
                                    if (EllipticCurveFactory.SupportedCurve(item))
                                    {
                                        curve = item;
                                        break;
                                    }
                                }
                            }
                            else if (extension.SpecificExtension is ServerCertificateTypeExtension)
                            {
                                //serverHello.AddExtension();
                            }
                            else if (extension.SpecificExtension is SignatureAlgorithmsExtension)
                            {
                                var signatureAlgorithms = extension.SpecificExtension as SignatureAlgorithmsExtension;
                                foreach (var item in signatureAlgorithms.SupportedAlgorithms)
                                {
                                    if (item.Signature == TSignatureAlgorithm.ECDSA)
                                    {
                                        hash = item.Hash;
                                        break;
                                    }
                                }
                            }
                        }
                    }

                    var keyExchangeAlgorithm = CipherSuites.GetKeyExchangeAlgorithm(cipherSuite);
                    if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
                    {
                        var pointFormatsExtension = new EllipticCurvePointFormatsExtension();
                        pointFormatsExtension.SupportedPointFormats.Add(TEllipticCurvePointFormat.Uncompressed);
                        serverHello.AddExtension(pointFormatsExtension);
                    }

                    session.Handshake.MessageSequence = 1;
                    this.SendResponse(session, serverHello, session.Handshake.MessageSequence);
                    session.Handshake.MessageSequence++;

                    if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
                    {
                        if (this.Certificate != null)
                        {
                            this.SendResponse(session, this.Certificate, session.Handshake.MessageSequence);
                            session.Handshake.MessageSequence++;
                        }

                        var keyExchange = new ECDHEKeyExchange
                        {
                            Curve = curve,
                            KeyExchangeAlgorithm = keyExchangeAlgorithm,
                            ClientRandom         = clientHello.Random,
                            ServerRandom         = serverHello.Random
                        };

                        keyExchange.GenerateEphemeralKey();
                        session.Handshake.KeyExchange = keyExchange;
                        if (session.Version == DTLSRecord.DefaultVersion)
                        {
                            hash = THashAlgorithm.SHA1;
                        }

                        var serverKeyExchange = new ECDHEServerKeyExchange(keyExchange, hash, TSignatureAlgorithm.ECDSA, this.PrivateKey);
                        this.SendResponse(session, serverKeyExchange, session.Handshake.MessageSequence);

                        session.Handshake.MessageSequence++;
                        if (this._RequireClientCertificate)
                        {
                            var certificateRequest = new CertificateRequest();
                            certificateRequest.CertificateTypes.Add(TClientCertificateType.ECDSASign);
                            certificateRequest.SupportedAlgorithms.Add(new SignatureHashAlgorithm()
                                {
                                    Hash = THashAlgorithm.SHA256, Signature = TSignatureAlgorithm.ECDSA
                                });
                            this.SendResponse(session, certificateRequest, session.Handshake.MessageSequence);
                            session.Handshake.MessageSequence++;
                        }
                    }
                    else if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_PSK)
                    {
                        var keyExchange = new ECDHEKeyExchange
                        {
                            Curve = curve,
                            KeyExchangeAlgorithm = keyExchangeAlgorithm,
                            ClientRandom         = clientHello.Random,
                            ServerRandom         = serverHello.Random
                        };

                        keyExchange.GenerateEphemeralKey();
                        session.Handshake.KeyExchange = keyExchange;
                        var serverKeyExchange = new ECDHEPSKServerKeyExchange(keyExchange);
                        this.SendResponse(session, serverKeyExchange, session.Handshake.MessageSequence);
                        session.Handshake.MessageSequence++;
                    }
                    else if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.PSK)
                    {
                        var keyExchange = new PSKKeyExchange
                        {
                            KeyExchangeAlgorithm = keyExchangeAlgorithm,
                            ClientRandom         = clientHello.Random,
                            ServerRandom         = serverHello.Random
                        };
                        session.Handshake.KeyExchange = keyExchange;
                        //Need to be able to hint identity?? for PSK if not hinting don't really need key exchange message
                        //PSKServerKeyExchange serverKeyExchange = new PSKServerKeyExchange();
                        //SendResponse(session, serverKeyExchange, session.Handshake.MessageSequence);
                        //session.Handshake.MessageSequence++;
                    }
                    this.SendResponse(session, new ServerHelloDone(), session.Handshake.MessageSequence);
                    session.Handshake.MessageSequence++;
                    break;
                }

                case THandshakeType.ServerHello:
                {
                    break;
                }

                case THandshakeType.HelloVerifyRequest:
                {
                    break;
                }

                case THandshakeType.Certificate:
                {
                    var clientCertificate = Certificate.Deserialise(stream, TCertificateType.X509);
                    if (clientCertificate.CertChain.Count > 0)
                    {
                        session.CertificateInfo = Certificates.GetCertificateInfo(clientCertificate.CertChain[0], TCertificateFormat.CER);
                    }
                    session.Handshake.UpdateHandshakeHash(data);
                    break;
                }

                case THandshakeType.ServerKeyExchange:
                {
                    break;
                }

                case THandshakeType.CertificateRequest:
                {
                    break;
                }

                case THandshakeType.ServerHelloDone:
                {
                    break;
                }

                case THandshakeType.CertificateVerify:
                {
                    var certificateVerify = CertificateVerify.Deserialise(stream, session.Version);
                    session.Handshake.UpdateHandshakeHash(data);
                }
                break;

                case THandshakeType.ClientKeyExchange:
                {
                    if ((session == null) || (session.Handshake.KeyExchange == null))
                    {
                        break;
                    }

                    session.Handshake.UpdateHandshakeHash(data);
                    byte[] preMasterSecret = null;
                    if (session.Handshake.KeyExchange.KeyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
                    {
                        var clientKeyExchange = ECDHEClientKeyExchange.Deserialise(stream);
                        if (clientKeyExchange != null)
                        {
                            var ecKeyExchange = session.Handshake.KeyExchange as ECDHEKeyExchange;
                            preMasterSecret = ecKeyExchange.GetPreMasterSecret(clientKeyExchange.PublicKeyBytes);
                        }
                    }
                    else if (session.Handshake.KeyExchange.KeyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_PSK)
                    {
                        var clientKeyExchange = ECDHEPSKClientKeyExchange.Deserialise(stream);
                        if (clientKeyExchange != null)
                        {
                            session.PSKIdentity = Encoding.UTF8.GetString(clientKeyExchange.PSKIdentity);
                            var psk = this._PSKIdentities.GetKey(clientKeyExchange.PSKIdentity);

                            if (psk == null)
                            {
                                psk = this._ValidatePSK(clientKeyExchange.PSKIdentity);
                                if (psk != null)
                                {
                                    this._PSKIdentities.AddIdentity(clientKeyExchange.PSKIdentity, psk);
                                }
                            }

                            if (psk != null)
                            {
                                var ecKeyExchange = session.Handshake.KeyExchange as ECDHEKeyExchange;
                                var otherSecret   = ecKeyExchange.GetPreMasterSecret(clientKeyExchange.PublicKeyBytes);
                                preMasterSecret = TLSUtils.GetPSKPreMasterSecret(otherSecret, psk);
                            }
                        }
                    }
                    else if (session.Handshake.KeyExchange.KeyExchangeAlgorithm == TKeyExchangeAlgorithm.PSK)
                    {
                        var clientKeyExchange = PSKClientKeyExchange.Deserialise(stream);
                        if (clientKeyExchange != null)
                        {
                            session.PSKIdentity = Encoding.UTF8.GetString(clientKeyExchange.PSKIdentity);
                            var psk = this._PSKIdentities.GetKey(clientKeyExchange.PSKIdentity);

                            if (psk == null)
                            {
                                psk = this._ValidatePSK(clientKeyExchange.PSKIdentity);
                                if (psk != null)
                                {
                                    this._PSKIdentities.AddIdentity(clientKeyExchange.PSKIdentity, psk);
                                }
                            }

                            if (psk != null)
                            {
                                var ecKeyExchange = session.Handshake.KeyExchange as ECDHEKeyExchange;
                                var otherSecret   = new byte[psk.Length];
                                preMasterSecret = TLSUtils.GetPSKPreMasterSecret(otherSecret, psk);
                            }
                        }
                    }

                    if (preMasterSecret != null)
                    {
                        //session.MasterSecret = TLSUtils.CalculateMasterSecret(preMasterSecret, session.KeyExchange);
                        //TLSUtils.AssignCipher(session);

                        session.Cipher = TLSUtils.AssignCipher(preMasterSecret, false, session.Version, session.Handshake);
                    }
                    break;
                }

                case THandshakeType.Finished:
                {
                    var finished = Finished.Deserialise(stream);
                    if (session == null)
                    {
                        break;
                    }

                    var handshakeHash        = session.Handshake.GetHash(session.Version);
                    var calculatedVerifyData = TLSUtils.GetVerifyData(session.Version, session.Handshake, false, true, handshakeHash);
                    if (!finished.VerifyData.SequenceEqual(calculatedVerifyData))
                    {
                        throw new Exception();
                    }

                    this.SendChangeCipherSpec(session);
                    session.Handshake.UpdateHandshakeHash(data);
                    handshakeHash = session.Handshake.GetHash(session.Version);
                    var serverFinished = new Finished
                    {
                        VerifyData = TLSUtils.GetVerifyData(session.Version, session.Handshake, false, false, handshakeHash)
                    };
                    this.SendResponse(session, serverFinished, session.Handshake.MessageSequence);
                    session.Handshake.MessageSequence++;
                    break;
                }

                default:
                    break;
                }
            }
        }