예제 #1
0
 /// <summary>
 /// Reports whether <paramref name="record"/> carries the epoch stored in
 /// <c>EncyptedClientEpoch</c> (presumably the first client epoch whose records are encrypted — TODO confirm).
 /// </summary>
 /// <param name="record">Incoming record whose epoch is compared.</param>
 /// <returns>
 /// <c>false</c> while <c>EncyptedClientEpoch</c> has no value; otherwise whether the record's
 /// epoch equals that value.
 /// </returns>
 internal bool IsEncypted(DTLSRecord record)
 {
     if (!EncyptedClientEpoch.HasValue)
     {
         return false;
     }

     return record.Epoch == EncyptedClientEpoch.Value;
 }
예제 #2
0
 /// <summary>
 /// Inserts <paramref name="record"/> into <c>_Records</c> ahead of the first stored record with a
 /// greater epoch, or with the same epoch and a greater sequence number. If <c>_Records</c> is already
 /// ordered by (Epoch, SequenceNumber), it stays ordered. A record whose epoch and sequence number
 /// both match a stored record is dropped as a duplicate.
 /// </summary>
 /// <param name="record">Record to queue.</param>
 /// <remarks>All access happens while holding the lock on <c>_Records</c>.</remarks>
 public void Add(DTLSRecord record)
 {
     lock (_Records)
     {
         for (int position = 0; position < _Records.Count; position++)
         {
             DTLSRecord existing = _Records[position];

             if (record.Epoch < existing.Epoch)
             {
                 _Records.Insert(position, record);
                 return;
             }

             if (record.Epoch != existing.Epoch)
             {
                 // Later epoch than this entry: keep scanning.
                 continue;
             }

             if (record.SequenceNumber == existing.SequenceNumber)
             {
                 // Same epoch and sequence number: duplicate, ignore it.
                 return;
             }

             if (record.SequenceNumber < existing.SequenceNumber)
             {
                 _Records.Insert(position, record);
                 return;
             }
         }

         _Records.Add(record);
     }
 }
예제 #3
0
        /// <summary>
        /// Encrypts <paramref name="data"/> into a single DTLS ApplicationData record for the current
        /// epoch and starts sending it to <c>_ServerEndPoint</c> without waiting for completion.
        /// </summary>
        /// <param name="data">Plaintext application payload.</param>
        /// <remarks>
        /// Best-effort: every exception is caught. In DEBUG builds it is written to the console.
        /// Nothing is sent when <c>_Socket</c> is <c>null</c>, although a sequence number has already
        /// been consumed by then.
        /// </remarks>
        public void Send(byte[] data)
        {
            try
            {
                DTLSRecord record = new DTLSRecord();
                record.RecordType     = TRecordType.ApplicationData;
                record.Epoch          = _Epoch;
                record.SequenceNumber = NextSequenceNumber();
                if (_Version != null)
                {
                    record.Version = _Version;
                }
                // Cipher sequence number: epoch shifted into bits 48 and up, record sequence number below.
                long sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;
                record.Fragment = _Cipher.EncodePlaintext(sequenceNumber, (byte)TRecordType.ApplicationData, data, 0, data.Length);
                int    responseSize = DTLSRecord.RECORD_OVERHEAD + record.Fragment.Length;
                byte[] response     = new byte[responseSize];
                using (MemoryStream stream = new MemoryStream(response))
                {
                    record.Serialise(stream);
                }

                // Read the field once so the null check and the send use the same socket.
                var socket = _Socket;
                if (socket != null)
                {
                    SocketAsyncEventArgs parameters = new SocketAsyncEventArgs()
                    {
                        RemoteEndPoint = _ServerEndPoint
                    };
                    parameters.SetBuffer(response, 0, responseSize);
                    // SocketAsyncEventArgs is IDisposable. When the send is pending, Completed fires and
                    // disposes it. When SendToAsync completes synchronously (returns false) or throws,
                    // Completed is not raised, so dispose it here.
                    parameters.Completed += (sender, args) => args.Dispose();
                    bool pending = false;
                    try
                    {
                        pending = socket.SendToAsync(parameters);
                    }
                    finally
                    {
                        if (!pending)
                        {
                            parameters.Dispose();
                        }
                    }
                }
            }
#if DEBUG
            catch (Exception ex)
            {
                Console.WriteLine(ex.ToString());
#else
            catch
            {
#endif
            }
        }
예제 #4
0
        /// <summary>
        /// Builds a DTLS Handshake record that wraps <paramref name="handshakeMessage"/>.
        /// It writes the handshake header and message body into the fragment, and adds the plaintext
        /// fragment to the handshake hash unless the message is a HelloVerifyRequest. If the session
        /// has a cipher, it then encrypts the fragment.
        /// </summary>
        /// <param name="session">Supplies the version, epoch, next sequence number, handshake state and cipher.</param>
        /// <param name="handshakeMessage">Message serialised into the record body.</param>
        /// <param name="messageSequence">Handshake message sequence written into the header.</param>
        /// <returns>The populated record; it is not serialised to bytes here.</returns>
        private DTLSRecord CreateRecord(Session session, IHandshakeMessage handshakeMessage, ushort messageSequence)
        {
            int bodyLength = handshakeMessage.CalculateSize(session.Version);

            DTLSRecord result = new DTLSRecord();
            result.RecordType     = TRecordType.Handshake;
            result.Epoch          = session.Epoch;
            result.SequenceNumber = session.NextSequenceNumber();
            result.Fragment       = new byte[HandshakeRecord.RECORD_OVERHEAD + bodyLength];

            if (session.Version != null)
            {
                result.Version = session.Version;
            }

            // The message is sent unfragmented: total length and fragment length are both the full body size.
            HandshakeRecord header = new HandshakeRecord();
            header.MessageType    = handshakeMessage.MessageType;
            header.MessageSeq     = messageSequence;
            header.Length         = (uint)bodyLength;
            header.FragmentLength = (uint)bodyLength;

            using (MemoryStream writer = new MemoryStream(result.Fragment))
            {
                header.Serialise(writer);
                handshakeMessage.Serialise(writer, session.Version);
            }

            // HelloVerifyRequest is left out of the handshake hash.
            bool includeInHash = handshakeMessage.MessageType != THandshakeType.HelloVerifyRequest;
            if (includeInHash)
            {
                session.Handshake.UpdateHandshakeHash(result.Fragment);
            }

            if (session.Cipher == null)
            {
                return result;
            }

            long cipherSequence = ((long)result.Epoch << 48) + result.SequenceNumber;
            result.Fragment = session.Cipher.EncodePlaintext(cipherSequence, (byte)TRecordType.Handshake, result.Fragment, 0, result.Fragment.Length);
            return result;
        }
예제 #5
0
        /// <summary>
        /// Encrypts <paramref name="data"/> into one DTLS ApplicationData record for the current epoch
        /// and sends it through <c>_Socket</c>.
        /// </summary>
        /// <param name="data">Plaintext application payload.</param>
        /// <param name="timeout">Passed through unchanged to the underlying socket send.</param>
        /// <exception cref="ArgumentNullException"><paramref name="data"/> is <c>null</c>.</exception>
        /// <exception cref="InvalidOperationException">The socket or the cipher has not been set.</exception>
        public async Task SendAsync(byte[] data, TimeSpan timeout)
        {
            if (data == null)
            {
                throw new ArgumentNullException(nameof(data));
            }

            // Capture the fields once so the null checks and the later uses see the same objects.
            var socket = this._Socket;
            var cipher = this._Cipher;

            // InvalidOperationException derives from Exception, so existing catch (Exception) handlers still work.
            if (socket == null)
            {
                throw new InvalidOperationException("Socket Cannot be Null");
            }

            if (cipher == null)
            {
                throw new InvalidOperationException("Cipher Cannot be Null");
            }

            var record = new DTLSRecord
            {
                RecordType     = TRecordType.ApplicationData,
                Epoch          = this._Epoch,
                SequenceNumber = this._NextSequenceNumber()
            };

            // Only override the record's default version when one is set, as the other record builders do.
            if (this._Version != null)
            {
                record.Version = this._Version;
            }

            // Cipher sequence number: epoch shifted into bits 48 and up, record sequence number below.
            var sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;

            record.Fragment = cipher.EncodePlaintext(sequenceNumber, (byte)TRecordType.ApplicationData, data, 0, data.Length);

            var recordSize  = DTLSRecord.RECORD_OVERHEAD + record.Fragment.Length;
            var recordBytes = new byte[recordSize];

            using (var stream = new MemoryStream(recordBytes))
            {
                record.Serialise(stream);
            }

            await socket.SendAsync(recordBytes, timeout).ConfigureAwait(false);
        }
예제 #6
0
        /// <summary>
        /// Sends a two-byte DTLS Alert record (level, description) for the current epoch to
        /// <c>_ServerEndPoint</c>, encrypting it when a cipher is set. The send is not awaited.
        /// </summary>
        /// <param name="alertLevel">Written as the first fragment byte.</param>
        /// <param name="alertDescription">Written as the second fragment byte.</param>
        /// <remarks>Nothing is sent when <c>_Socket</c> is <c>null</c>.</remarks>
        private void SendAlert(TAlertLevel alertLevel, TAlertDescription alertDescription)
        {
            DTLSRecord record = new DTLSRecord();

            record.RecordType     = TRecordType.Alert;
            record.Epoch          = _Epoch;
            record.SequenceNumber = NextSequenceNumber();
            if (_Version != null)
            {
                record.Version = _Version;
            }
            // Cipher sequence number: epoch shifted into bits 48 and up, record sequence number below.
            long sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;

            byte[] data = new byte[2];
            data[0] = (byte)alertLevel;
            data[1] = (byte)alertDescription;
            if (_Cipher == null)
            {
                record.Fragment = data;
            }
            else
            {
                // Seal with the record's own content type (Alert). In (D)TLS the content type is part of
                // the MAC / AEAD additional data, which the receiver rebuilds from the record header.
                // Passing ApplicationData here made encrypted alerts undecryptable.
                record.Fragment = _Cipher.EncodePlaintext(sequenceNumber, (byte)TRecordType.Alert, data, 0, data.Length);
            }
            int responseSize = DTLSRecord.RECORD_OVERHEAD + record.Fragment.Length;

            byte[] response = new byte[responseSize];
            using (MemoryStream stream = new MemoryStream(response))
            {
                record.Serialise(stream);
            }

            // Read the field once; skip the send when there is no socket.
            var socket = _Socket;
            if (socket == null)
            {
                return;
            }

            SocketAsyncEventArgs parameters = new SocketAsyncEventArgs()
            {
                RemoteEndPoint = _ServerEndPoint
            };
            parameters.SetBuffer(response, 0, responseSize);
            // Dispose the event args when the send finishes. A pending send raises Completed; a
            // synchronous completion (false) or an exception does not, so dispose here in those cases.
            parameters.Completed += (sender, args) => args.Dispose();
            bool pending = false;
            try
            {
                pending = socket.SendToAsync(parameters);
            }
            finally
            {
                if (!pending)
                {
                    parameters.Dispose();
                }
            }
        }
예제 #7
0
 /// <summary>
 /// Completion handler for <c>ReceiveFromAsync</c>. It splits each received datagram into DTLS
 /// records and queues them in <c>_Records</c>, tagged with the sender's endpoint. It then signals
 /// <c>_TriggerProcessRecords</c> and re-arms the receive on the same socket.
 /// </summary>
 /// <param name="sender">Socket the receive was issued on; used to re-arm the receive.</param>
 /// <param name="e">Receive arguments. The same instance and buffer are reused for the next receive.</param>
 /// <remarks>
 /// A receive of zero bytes ends the loop. When <c>ReceiveFromAsync</c> completes synchronously it
 /// returns <c>false</c> and does not raise <c>Completed</c>. That data is processed by the next
 /// loop iteration here, so it is not lost and the receive loop does not stall.
 /// </remarks>
 private void ReceiveCallback(object sender, SocketAsyncEventArgs e)
 {
     Socket socket = sender as Socket;
     // Loop rather than recurse, so repeated synchronous completions cannot grow the stack.
     while (e.BytesTransferred != 0)
     {
         int    count = e.BytesTransferred;
         // Copy out of the receive buffer, because the buffer is reused by the next receive below.
         byte[] data  = new byte[count];
         Buffer.BlockCopy(e.Buffer, 0, data, 0, count);
         using (MemoryStream stream = new MemoryStream(data))
         {
             // Keep deserialising until the datagram is exhausted (it may hold several records).
             while (stream.Position < stream.Length)
             {
                 long start = stream.Position;
                 DTLSRecord record = DTLSRecord.Deserialise(stream);
                 if (record != null)
                 {
                     record.RemoteEndPoint = e.RemoteEndPoint;
                     _Records.Add(record);
                     _TriggerProcessRecords.Set();
                 }
                 if (stream.Position == start)
                 {
                     // Deserialise consumed nothing; retrying at the same offset would loop forever.
                     break;
                 }
             }
         }

         if (socket == null)
         {
             return;
         }

         System.Net.EndPoint remoteEndPoint;
         if (socket.AddressFamily == AddressFamily.InterNetwork)
         {
             remoteEndPoint = new IPEndPoint(IPAddress.Any, 0);
         }
         else
         {
             remoteEndPoint = new IPEndPoint(IPAddress.IPv6Any, 0);
         }
         e.RemoteEndPoint = remoteEndPoint;
         e.SetBuffer(0, 4096);
         if (socket.ReceiveFromAsync(e))
         {
             // Pending: the Completed event will invoke this handler again.
             return;
         }
     }
 }
예제 #8
0
        /// <summary>
        /// Sends a ChangeCipherSpec record (a single byte with value 1) to the session's remote endpoint,
        /// using the session's current epoch and next sequence number. It then calls
        /// <c>session.ChangeEpoch()</c>, presumably to advance to the next epoch (TODO confirm).
        /// </summary>
        /// <param name="session">Session to send on.</param>
        /// <exception cref="ArgumentNullException"><paramref name="session"/> is <c>null</c>.</exception>
        private void SendChangeCipherSpec(Session session)
        {
            if (session == null)
            {
                throw new ArgumentNullException(nameof(session));
            }

            var size         = 1;
            var responseSize = DTLSRecord.RECORD_OVERHEAD + size;
            var response     = new byte[responseSize];
            var record       = new DTLSRecord
            {
                RecordType     = TRecordType.ChangeCipherSpec,
                Epoch          = session.Epoch,
                SequenceNumber = session.NextSequenceNumber(),
                Fragment       = new byte[size]
            };

            record.Fragment[0] = 1;
            if (session.Version != null)
            {
                record.Version = session.Version;
            }

            using (var stream = new MemoryStream(response))
            {
                record.Serialise(stream);
            }

            var parameters = new SocketAsyncEventArgs()
            {
                RemoteEndPoint = session.RemoteEndPoint
            };
            parameters.SetBuffer(response, 0, responseSize);
            // Dispose the event args once the send finishes. A pending send raises Completed; a
            // synchronous completion (false) or an exception does not, so dispose here in those cases.
            parameters.Completed += (sender, args) => args.Dispose();
            var pending = false;
            try
            {
                pending = this._Socket.SendToAsync(parameters);
            }
            finally
            {
                if (!pending)
                {
                    parameters.Dispose();
                }
            }

            // The ChangeCipherSpec record above was built with the pre-change epoch.
            session.ChangeEpoch();
        }
예제 #9
0
        /// <summary>
        /// Handles one incoming DTLS Handshake record on the server side.
        /// </summary>
        /// <param name="record">Received record; its <c>RemoteEndPoint</c> is used to look up the client session.</param>
        /// <remarks>
        /// If the record belongs to the session's encrypted client epoch, the fragment is decrypted first.
        /// Before decrypting, the method waits up to 50 x 10 ms for the session cipher to be assigned.
        /// The handshake message type then selects the action:
        /// <list type="bullet">
        /// <item>ClientHello without a cookie: create or reset the session and reply with a HelloVerifyRequest.</item>
        /// <item>ClientHello with a matching cookie: choose a cipher suite and send ServerHello. Depending on the
        /// key exchange, also send Certificate, ServerKeyExchange and CertificateRequest. Finish with ServerHelloDone.</item>
        /// <item>Certificate / CertificateVerify: record the client certificate info and update the handshake hash.</item>
        /// <item>ClientKeyExchange: derive the pre-master secret and assign the session cipher.</item>
        /// <item>Finished: check the client's verify data, then send ChangeCipherSpec and the server Finished.</item>
        /// </list>
        /// Other message types are ignored, as are messages that need a session when there is none.
        /// </remarks>
        /// <exception cref="InvalidOperationException">
        /// Thrown when an encrypted record arrives but no cipher becomes available, or when the client's
        /// Finished message is missing or does not verify.
        /// </exception>
        public void ProcessHandshake(DTLSRecord record)
        {
            SocketAddress address = record.RemoteEndPoint.Serialize();
            Session session = Sessions.GetSession(address);
            byte[] data;
            if ((session != null) && session.IsEncypted(record))
            {
                // session.Cipher is assigned while handling ClientKeyExchange. That may still be in progress
                // (presumably on another thread — TODO confirm), so poll briefly before giving up.
                var cipher = session.Cipher;
                int count = 0;
                while ((cipher == null) && (count < 50))
                {
                    System.Threading.Thread.Sleep(10);
                    count++;
                    cipher = session.Cipher;
                }

                if (cipher == null)
                {
                    throw new InvalidOperationException("No cipher available to decrypt handshake record");
                }

                long sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;
                data = cipher.DecodeCiphertext(sequenceNumber, (byte)TRecordType.Handshake, record.Fragment, 0, record.Fragment.Length);
            }
            else
            {
                data = record.Fragment;
            }
            using (MemoryStream stream = new MemoryStream(data))
            {
                HandshakeRecord handshakeRecord = HandshakeRecord.Deserialise(stream);
                if (handshakeRecord != null)
                {
#if DEBUG
                    Console.WriteLine(handshakeRecord.MessageType.ToString());
#endif
                    switch (handshakeRecord.MessageType)
                    {
                        case THandshakeType.HelloRequest:
                            //HelloReq
                            break;
                        case THandshakeType.ClientHello:
                            ClientHello clientHello = ClientHello.Deserialise(stream);
                            if (clientHello != null)
                            {

                                byte[] cookie = clientHello.CalculateCookie(record.RemoteEndPoint, _HelloSecret);

                                if (clientHello.Cookie == null)
                                {
                                    // First ClientHello: answer with a cookie challenge (HelloVerifyRequest).
                                    Version version = clientHello.ClientVersion;
                                    if (ServerVersion < version)
                                        version = ServerVersion;
                                    if (session == null)
                                    {
                                        session = new Session();
                                        session.SessionID = Guid.NewGuid();
                                        session.RemoteEndPoint = record.RemoteEndPoint;
                                        session.Version = version;
                                        Sessions.AddSession(address, session);
                                    }
                                    else
                                    {
                                        session.Reset();
                                        session.Version = version;
                                    }
                                    session.ClientEpoch = record.Epoch;
                                    session.ClientSequenceNumber = record.SequenceNumber;
                                    //session.Handshake.UpdateHandshakeHash(data);
                                    HelloVerifyRequest helloVerifyRequest = new HelloVerifyRequest();
                                    helloVerifyRequest.Cookie = cookie;
                                    helloVerifyRequest.ServerVersion = ServerVersion;
                                    SendResponse(session, (IHandshakeMessage)helloVerifyRequest, 0);

                                }
                                else
                                {

                                    if (session != null && session.Cipher != null && !session.IsEncypted(record))
                                    {
                                        session.Reset();
                                    }

                                    // A ClientHello whose cookie does not match is ignored.
                                    if (TLSUtils.ByteArrayCompare(clientHello.Cookie, cookie))
                                    {
                                        Version version = clientHello.ClientVersion;
                                        if (ServerVersion < version)
                                            version = ServerVersion;
                                        if (clientHello.SessionID == null)
                                        {
                                            if (session == null)
                                            {
                                                session = new Session();
                                                session.NextSequenceNumber();
                                                session.SessionID = Guid.NewGuid();
                                                session.RemoteEndPoint = record.RemoteEndPoint;
                                                Sessions.AddSession(address, session);
                                            }
                                        }
                                        else
                                        {
                                            Guid sessionID = Guid.Empty;
                                            if (clientHello.SessionID.Length >= 16)
                                            {
                                                byte[] receivedSessionID = new byte[16];
                                                Buffer.BlockCopy(clientHello.SessionID, 0, receivedSessionID, 0, 16);
                                                sessionID = new Guid(receivedSessionID);
                                            }
                                            if (sessionID != Guid.Empty)
                                                session = Sessions.GetSession(sessionID);
                                            if (session == null)
                                            {
                                                //need to Find Session
                                                session = new Session();
                                                session.SessionID = Guid.NewGuid();
                                                session.NextSequenceNumber();
                                                session.RemoteEndPoint = record.RemoteEndPoint;
                                                Sessions.AddSession(address, session);
                                                //session.Version = clientHello.ClientVersion;
                                            }
                                        }
                                        session.Version = version;
                                        session.Handshake.InitaliseHandshakeHash(version < DTLSRecord.Version1_2);
                                        session.Handshake.UpdateHandshakeHash(data);
                                        // First client-offered suite that the server supports and can use wins.
                                        // NOTE(review): if none matches, the handshake continues with
                                        // TLS_NULL_WITH_NULL_NULL instead of being aborted. Confirm this is intended.
                                        TCipherSuite cipherSuite = TCipherSuite.TLS_NULL_WITH_NULL_NULL;
                                        foreach (TCipherSuite item in clientHello.CipherSuites)
                                        {
                                            if (_SupportedCipherSuites.ContainsKey(item) && CipherSuites.SupportedVersion(item, session.Version) && CipherSuites.SuiteUsable(item, PrivateKey, _PSKIdentities, _ValidatePSK != null))
                                            {
                                                cipherSuite = item;
                                                break;
                                            }
                                        }

                                        TKeyExchangeAlgorithm keyExchangeAlgorithm = CipherSuites.GetKeyExchangeAlgorithm(cipherSuite);

                                        ServerHello serverHello = new ServerHello();
                                        // 32-byte session ID made of the 16-byte session GUID written twice.
                                        byte[] clientSessionID = new byte[32];
                                        byte[] temp = session.SessionID.ToByteArray();
                                        Buffer.BlockCopy(temp, 0, clientSessionID, 0, 16);
                                        Buffer.BlockCopy(temp, 0, clientSessionID, 16, 16);

                                        serverHello.SessionID = clientSessionID;// session.SessionID.ToByteArray();
                                        serverHello.Random = new RandomData();
                                        serverHello.Random.Generate();
                                        serverHello.CipherSuite = (ushort)cipherSuite;
                                        serverHello.ServerVersion = session.Version;

                                        THashAlgorithm hash = THashAlgorithm.SHA256;
                                        TEllipticCurve curve = TEllipticCurve.secp521r1;
                                        if (clientHello.Extensions != null)
                                        {
                                            foreach (Extension extension in clientHello.Extensions)
                                            {
                                                if (extension.SpecifcExtension is ClientCertificateTypeExtension)
                                                {
                                                    ClientCertificateTypeExtension clientCertificateType = extension.SpecifcExtension as ClientCertificateTypeExtension;
                                                    //TCertificateType certificateType = TCertificateType.Unknown;
                                                    //foreach (TCertificateType item in clientCertificateType.CertificateTypes)
                                                    //{

                                                    //}
                                                    //serverHello.AddExtension(new ClientCertificateTypeExtension(certificateType));
                                                }
                                                else if (extension.SpecifcExtension is EllipticCurvesExtension)
                                                {
                                                    EllipticCurvesExtension ellipticCurves = extension.SpecifcExtension as EllipticCurvesExtension;
                                                    foreach (TEllipticCurve item in ellipticCurves.SupportedCurves)
                                                    {
                                                        if (EllipticCurveFactory.SupportedCurve(item))
                                                        {
                                                            curve = item;
                                                            break;
                                                        }
                                                    }
                                                }
                                                else if (extension.SpecifcExtension is ServerCertificateTypeExtension)
                                                {
                                                    //serverHello.AddExtension();
                                                }
                                                else if (extension.SpecifcExtension is SignatureAlgorithmsExtension)
                                                {
                                                    SignatureAlgorithmsExtension signatureAlgorithms = extension.SpecifcExtension as SignatureAlgorithmsExtension;
                                                    foreach (SignatureHashAlgorithm item in signatureAlgorithms.SupportedAlgorithms)
                                                    {
                                                        if (item.Signature == TSignatureAlgorithm.ECDSA)
                                                        {
                                                            hash = item.Hash;
                                                            break;
                                                        }
                                                    }
                                                }
                                            }

                                        }

                                        session.Handshake.CipherSuite = cipherSuite;
                                        session.Handshake.ClientRandom = clientHello.Random;
                                        session.Handshake.ServerRandom = serverHello.Random;


                                        if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
                                        {
                                            EllipticCurvePointFormatsExtension pointFormatsExtension = new EllipticCurvePointFormatsExtension();
                                            pointFormatsExtension.SupportedPointFormats.Add(TEllipticCurvePointFormat.Uncompressed);
                                            serverHello.AddExtension(pointFormatsExtension);
                                        }
                                        session.Handshake.MessageSequence = 1;
                                        SendResponse(session, serverHello, session.Handshake.MessageSequence);
                                        session.Handshake.MessageSequence++;

                                        if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
                                        {
                                            if (Certificate != null)
                                            {
                                                SendResponse(session, Certificate, session.Handshake.MessageSequence);
                                                session.Handshake.MessageSequence++;
                                            }
                                            ECDHEKeyExchange keyExchange = new ECDHEKeyExchange();
                                            keyExchange.Curve = curve;
                                            keyExchange.KeyExchangeAlgorithm = keyExchangeAlgorithm;
                                            keyExchange.ClientRandom = clientHello.Random;
                                            keyExchange.ServerRandom = serverHello.Random;
                                            keyExchange.GenerateEphemeralKey();
                                            session.Handshake.KeyExchange = keyExchange;
                                            if (session.Version == DTLSRecord.DefaultVersion)
                                                hash = THashAlgorithm.SHA1;
                                            ECDHEServerKeyExchange serverKeyExchange = new ECDHEServerKeyExchange(keyExchange, hash, TSignatureAlgorithm.ECDSA, PrivateKey);
                                            SendResponse(session, serverKeyExchange, session.Handshake.MessageSequence);
                                            session.Handshake.MessageSequence++;
                                            if (_RequireClientCertificate)
                                            {
                                                CertificateRequest certificateRequest = new CertificateRequest();
                                                certificateRequest.CertificateTypes.Add(TClientCertificateType.ECDSASign);
                                                certificateRequest.SupportedAlgorithms.Add(new SignatureHashAlgorithm() { Hash = THashAlgorithm.SHA256, Signature = TSignatureAlgorithm.ECDSA });
                                                SendResponse(session, certificateRequest, session.Handshake.MessageSequence);
                                                session.Handshake.MessageSequence++;
                                            }
                                        }
                                        else if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_PSK)
                                        {
                                            ECDHEKeyExchange keyExchange = new ECDHEKeyExchange();
                                            keyExchange.Curve = curve;
                                            keyExchange.KeyExchangeAlgorithm = keyExchangeAlgorithm;
                                            keyExchange.ClientRandom = clientHello.Random;
                                            keyExchange.ServerRandom = serverHello.Random;
                                            keyExchange.GenerateEphemeralKey();
                                            session.Handshake.KeyExchange = keyExchange;
                                            ECDHEPSKServerKeyExchange serverKeyExchange = new ECDHEPSKServerKeyExchange(keyExchange);
                                            SendResponse(session, serverKeyExchange, session.Handshake.MessageSequence);
                                            session.Handshake.MessageSequence++;

                                        }
                                        else if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.PSK)
                                        {
                                            PSKKeyExchange keyExchange = new PSKKeyExchange();
                                            keyExchange.KeyExchangeAlgorithm = keyExchangeAlgorithm;
                                            keyExchange.ClientRandom = clientHello.Random;
                                            keyExchange.ServerRandom = serverHello.Random;
                                            session.Handshake.KeyExchange = keyExchange;
                                            //Need to be able to hint identity?? for PSK if not hinting don't really need key exchange message
                                            //PSKServerKeyExchange serverKeyExchange = new PSKServerKeyExchange();
                                            //SendResponse(session, serverKeyExchange, session.Handshake.MessageSequence);
                                            //session.Handshake.MessageSequence++;
                                        }
                                        SendResponse(session, new ServerHelloDone(), session.Handshake.MessageSequence);
                                        session.Handshake.MessageSequence++;
                                    }
                                }
                            }
                            break;
                        case THandshakeType.ServerHello:
                            break;
                        case THandshakeType.HelloVerifyRequest:
                            break;
                        case THandshakeType.Certificate:
                            Certificate clientCertificate = Certificate.Deserialise(stream, TCertificateType.X509);
                            // Without a session (or a parsed certificate) there is nothing to attach it to.
                            if ((session != null) && (clientCertificate != null))
                            {
                                if (clientCertificate.CertChain.Count > 0)
                                {
                                    session.CertificateInfo = Certificates.GetCertificateInfo(clientCertificate.CertChain[0], TCertificateFormat.CER);
                                }
                                session.Handshake.UpdateHandshakeHash(data);
                            }
                            break;
                        case THandshakeType.ServerKeyExchange:
                            break;
                        case THandshakeType.CertificateRequest:
                            break;
                        case THandshakeType.ServerHelloDone:
                            break;
                        case THandshakeType.CertificateVerify:
                            if (session != null)
                            {
                                // NOTE(review): the parsed CertificateVerify is discarded. Its signature is not
                                // checked against the client certificate here, so this method does not prove that
                                // the client holds the certificate's private key. Confirm whether this is verified elsewhere.
                                CertificateVerify.Deserialise(stream, session.Version);
                                session.Handshake.UpdateHandshakeHash(data);
                            }
                            break;
                        case THandshakeType.ClientKeyExchange:
                            // Ignored unless a ServerHello set up a key exchange for this session.
                            if ((session != null) && (session.Handshake.KeyExchange != null))
                            {
                                session.Handshake.UpdateHandshakeHash(data);
                                byte[] preMasterSecret = null;
                                if (session.Handshake.KeyExchange.KeyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
                                {
                                    ECDHEClientKeyExchange clientKeyExchange = ECDHEClientKeyExchange.Deserialise(stream);
                                    if (clientKeyExchange != null)
                                    {
                                        ECDHEKeyExchange ecKeyExchange = session.Handshake.KeyExchange as ECDHEKeyExchange;
                                        preMasterSecret = ecKeyExchange.GetPreMasterSecret(clientKeyExchange.PublicKeyBytes);
                                    }
                                }
                                else if (session.Handshake.KeyExchange.KeyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_PSK)
                                {
                                    ECDHEPSKClientKeyExchange clientKeyExchange = ECDHEPSKClientKeyExchange.Deserialise(stream);
                                    if (clientKeyExchange != null)
                                    {
                                        session.PSKIdentity = Encoding.UTF8.GetString(clientKeyExchange.PSKIdentity);
                                        byte[] psk = _PSKIdentities.GetKey(clientKeyExchange.PSKIdentity);

                                        // _ValidatePSK is optional (see the null check in suite selection).
                                        if ((psk == null) && (_ValidatePSK != null))
                                        {
                                            psk = _ValidatePSK(clientKeyExchange.PSKIdentity);
                                            if (psk != null)
                                            {
                                                _PSKIdentities.AddIdentity(clientKeyExchange.PSKIdentity, psk);
                                            }
                                        }

                                        if (psk != null)
                                        {
                                            ECDHEKeyExchange ecKeyExchange = session.Handshake.KeyExchange as ECDHEKeyExchange;
                                            byte[] otherSecret = ecKeyExchange.GetPreMasterSecret(clientKeyExchange.PublicKeyBytes);
                                            preMasterSecret = TLSUtils.GetPSKPreMasterSecret(otherSecret, psk);

                                        }
                                    }
                                }
                                else if (session.Handshake.KeyExchange.KeyExchangeAlgorithm == TKeyExchangeAlgorithm.PSK)
                                {
                                    PSKClientKeyExchange clientKeyExchange = PSKClientKeyExchange.Deserialise(stream);
                                    if (clientKeyExchange != null)
                                    {
                                        session.PSKIdentity = Encoding.UTF8.GetString(clientKeyExchange.PSKIdentity);
                                        byte[] psk = _PSKIdentities.GetKey(clientKeyExchange.PSKIdentity);

                                        // _ValidatePSK is optional (see the null check in suite selection).
                                        if ((psk == null) && (_ValidatePSK != null))
                                        {
                                            psk = _ValidatePSK(clientKeyExchange.PSKIdentity);
                                            if (psk != null)
                                            {
                                                _PSKIdentities.AddIdentity(clientKeyExchange.PSKIdentity, psk);
                                            }
                                        }

                                        if (psk != null)
                                        {
                                            // Plain PSK: the "other secret" is a zero-filled array of the PSK's length.
                                            byte[] otherSecret = new byte[psk.Length];
                                            preMasterSecret = TLSUtils.GetPSKPreMasterSecret(otherSecret, psk);
                                        }
                                    }
                                }

                                if (preMasterSecret != null)
                                {
                                    //session.MasterSecret = TLSUtils.CalculateMasterSecret(preMasterSecret, session.KeyExchange);
                                    //TLSUtils.AssignCipher(session);

                                    session.Cipher = TLSUtils.AssignCipher(preMasterSecret, false, session.Version, session.Handshake);
                                }
                            }
                            break;
                        case THandshakeType.Finished:
                            Finished finished = Finished.Deserialise(stream);
                            if (session != null)
                            {
                                if (finished == null)
                                {
                                    throw new InvalidOperationException("Malformed Finished message");
                                }
                                byte[] handshakeHash = session.Handshake.GetHash();
                                byte[] calculatedVerifyData = TLSUtils.GetVerifyData(session.Version,session.Handshake,false, true, handshakeHash);
#if DEBUG
                                Console.Write("Handshake Hash:");
                                TLSUtils.WriteToConsole(handshakeHash);
                                Console.Write("Sent Verify:");
                                TLSUtils.WriteToConsole(finished.VerifyData);
                                Console.Write("Calc Verify:");
                                TLSUtils.WriteToConsole(calculatedVerifyData);
#endif
                                if (TLSUtils.ByteArrayCompare(finished.VerifyData, calculatedVerifyData))
                                {

                                    SendChangeCipherSpec(session);
                                    // The client's Finished goes into the hash before the server's verify data is computed.
                                    session.Handshake.UpdateHandshakeHash(data);
                                    handshakeHash = session.Handshake.GetHash();
                                    Finished serverFinished = new Finished();
                                    serverFinished.VerifyData = TLSUtils.GetVerifyData(session.Version,session.Handshake,false, false, handshakeHash);
                                    SendResponse(session, serverFinished, session.Handshake.MessageSequence);
                                    session.Handshake.MessageSequence++;
                                }
                                else
                                {
                                    throw new InvalidOperationException("Finished verify data mismatch");
                                }
                            }
                            break;
                        default:
                            break;
                    }
                }
            }
        }
예제 #10
0
		/// <summary>
		/// Wraps <paramref name="handshakeMessage"/> in a handshake header and a DTLS record stamped with the
		/// session's current epoch and next sequence number.
		/// </summary>
		/// <param name="session">Session supplying version, epoch, sequence number, handshake hash and cipher.</param>
		/// <param name="handshakeMessage">The handshake message to serialise.</param>
		/// <param name="messageSequence">Handshake message_seq written into the handshake header.</param>
		/// <returns>
		/// The record. Its fragment is the serialised handshake header plus message, encrypted with
		/// <c>session.Cipher</c> when one is assigned.
		/// </returns>
		private DTLSRecord CreateRecord(Session session, IHandshakeMessage handshakeMessage, ushort messageSequence)
		{
			int messageSize = handshakeMessage.CalculateSize(session.Version);
			byte[] plaintext = new byte[HandshakeRecord.RECORD_OVERHEAD + messageSize];

			var result = new DTLSRecord
			{
				RecordType = TRecordType.Handshake,
				Epoch = session.Epoch,
				SequenceNumber = session.NextSequenceNumber(),
				Fragment = plaintext
			};
			if (session.Version != null)
			{
				result.Version = session.Version;
			}

			// The whole message goes in one fragment, so fragment length equals total length.
			var header = new HandshakeRecord
			{
				MessageType = handshakeMessage.MessageType,
				MessageSeq = messageSequence,
				Length = (uint)messageSize,
				FragmentLength = (uint)messageSize
			};
			using (var writer = new MemoryStream(plaintext))
			{
				header.Serialise(writer);
				handshakeMessage.Serialise(writer, session.Version);
			}

			// HelloVerifyRequest is never part of the handshake hash.
			if (handshakeMessage.MessageType != THandshakeType.HelloVerifyRequest)
			{
				session.Handshake.UpdateHandshakeHash(plaintext);
			}

			if (session.Cipher == null)
			{
				return result;
			}

			// The cipher's sequence input is the 16-bit epoch followed by the 48-bit record sequence number.
			long cipherSequence = ((long)result.Epoch << 48) + result.SequenceNumber;
			result.Fragment = session.Cipher.EncodePlaintext(cipherSequence, (byte)TRecordType.Handshake, plaintext, 0, plaintext.Length);
			return result;
		}
예제 #11
0
		/// <summary>
		/// Sends an unencrypted ChangeCipherSpec record to the session's remote endpoint and then calls
		/// <c>session.ChangeEpoch()</c>.
		/// </summary>
		/// <param name="session">Session whose current epoch and next sequence number stamp the record.</param>
		private void SendChangeCipherSpec(Session session)
		{
			// The ChangeCipherSpec body is the single byte 1.
			const int size = 1;
			int responseSize = DTLSRecord.RECORD_OVERHEAD + size;
			byte[] response = new byte[responseSize];
			DTLSRecord record = new DTLSRecord();
			record.RecordType = TRecordType.ChangeCipherSpec;
			record.Epoch = session.Epoch;
			record.SequenceNumber = session.NextSequenceNumber();
			record.Fragment = new byte[size];
			record.Fragment[0] = 1;
			if (session.Version != null)
			{
				record.Version = session.Version;
			}
			using (MemoryStream stream = new MemoryStream(response))
			{
				record.Serialise(stream);
			}
			SocketAsyncEventArgs parameters = new SocketAsyncEventArgs()
			{
				RemoteEndPoint = session.RemoteEndPoint
			};
			parameters.SetBuffer(response, 0, responseSize);
			// SocketAsyncEventArgs is IDisposable. SendToAsync raises Completed only when it returns true
			// (pending). When it completes synchronously (returns false) it must be disposed here instead.
			parameters.Completed += (sender, e) => e.Dispose();
			if (!_Socket.SendToAsync(parameters))
			{
				parameters.Dispose();
			}
			session.ChangeEpoch();
		}
예제 #12
0
        /// <summary>
        /// Dispatches one received DTLS record according to its record type.
        /// </summary>
        /// <param name="address">Serialized remote address, used when removing the session and sending alerts.</param>
        /// <param name="session">
        /// Session for the remote endpoint. It may be null, in which case ChangeCipherSpec, Alert and
        /// ApplicationData records are ignored. Handshake records are always forwarded.
        /// </param>
        /// <param name="record">The received record.</param>
        /// <remarks>
        /// Any exception thrown while processing is caught and answered with a fatal InternalError alert.
        /// </remarks>
        private void ProcessRecord(SocketAddress address, Session session, DTLSRecord record)
        {
            try
            {
#if DEBUG
                Console.WriteLine(record.RecordType.ToString());
#endif
                switch (record.RecordType)
                {
                    case TRecordType.ChangeCipherSpec:
                        if (session != null)
                        {
                            // The client moves to its next epoch; its sequence numbers restart at zero.
                            session.ClientEpoch++;
                            session.ClientSequenceNumber = 0;
                            session.SetEncyptChange(record);
                        }
                        break;
                    case TRecordType.Alert:
                        if (session != null)
                        {
                            AlertRecord alertRecord;
                            try
                            {
                                if (session.Cipher == null)
                                {
                                    alertRecord = AlertRecord.Deserialise(record.Fragment);
                                }
                                else
                                {
                                    long sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;
                                    byte[] data = session.Cipher.DecodeCiphertext(sequenceNumber, (byte)TRecordType.Alert, record.Fragment, 0, record.Fragment.Length);
                                    alertRecord = AlertRecord.Deserialise(data);
                                }
                            }
                            catch
                            {
                                // An alert that cannot be decrypted or parsed is treated as fatal.
                                alertRecord = new AlertRecord();
                                alertRecord.AlertLevel = TAlertLevel.Fatal;
                            }
                            if (alertRecord.AlertLevel == TAlertLevel.Fatal)
                            {
                                _Sessions.Remove(session, address);
                            }
                            else if ((alertRecord.AlertLevel == TAlertLevel.Warning) || (alertRecord.AlertDescription == TAlertDescription.CloseNotify))
                            {
                                // NOTE(review): every warning-level alert ends the session here, not only close_notify.
                                if (alertRecord.AlertDescription == TAlertDescription.CloseNotify)
                                {
                                    SendAlert(session, address, TAlertLevel.Warning, TAlertDescription.CloseNotify);
                                }
                                _Sessions.Remove(session, address);
                            }
                        }
                        break;
                    case TRecordType.Handshake:
                        _Handshake.ProcessHandshake(record);
                        break;
                    case TRecordType.ApplicationData:
                        // Application data is only accepted once a cipher has been negotiated.
                        if (session?.Cipher != null)
                        {
                            long sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;
                            byte[] data = session.Cipher.DecodeCiphertext(sequenceNumber, (byte)TRecordType.ApplicationData, record.Fragment, 0, record.Fragment.Length);
                            // Null-conditional invoke reads the delegate once, so a handler unsubscribing
                            // on another thread cannot null it between the check and the call.
                            DataReceived?.Invoke(record.RemoteEndPoint, data);
                        }
                        break;
                    default:
                        break;
                }
            }
#if DEBUG
            catch (Exception ex)
            {
                Console.WriteLine(ex.ToString());
#else
            catch
            {
#endif
                // NOTE(review): session may be null here; confirm SendAlert tolerates that.
                SendAlert(session, address, TAlertLevel.Fatal, TAlertDescription.InternalError);
            }
        }
예제 #13
0
		/// <summary>
		/// Sends a handshake message to the session's remote endpoint:
		/// <list type="number">
		/// <item>serialises it into a single DTLS handshake record;</item>
		/// <item>adds it to the handshake hash, except for HelloVerifyRequest;</item>
		/// <item>encrypts it when the session has a cipher;</item>
		/// <item>sends it.</item>
		/// </list>
		/// </summary>
		/// <param name="session">Destination session. It supplies version, epoch, sequence number, hash and cipher.</param>
		/// <param name="handshakeMessage">The message to send.</param>
		/// <param name="messageSequence">Handshake message_seq written into the handshake header.</param>
		/// <exception cref="NotSupportedException">
		/// The message body does not fit in one record of at most <c>_MaxPacketSize</c> bytes. Splitting a
		/// handshake message across several records is not implemented.
		/// </exception>
		private void SendResponse(Session session, IHandshakeMessage handshakeMessage, ushort messageSequence)
		{
			int size = handshakeMessage.CalculateSize(session.Version);
			// Room left for the message body after BOTH the record header and the handshake header.
			// NOTE(review): cipher expansion (IV/MAC/tag) is not included in this budget.
			int maxPayloadSize = _MaxPacketSize - DTLSRecord.RECORD_OVERHEAD - HandshakeRecord.RECORD_OVERHEAD;
			if (size > maxPayloadSize)
			{
				// Fail loudly rather than silently dropping the message and stalling the handshake.
				throw new NotSupportedException($"Handshake message of {size} bytes exceeds the {maxPayloadSize}-byte record payload limit; fragmentation is not supported.");
			}

			DTLSRecord record = new DTLSRecord();
			record.RecordType = TRecordType.Handshake;
			record.Epoch = session.Epoch;
			record.SequenceNumber = session.NextSequenceNumber();
			record.Fragment = new byte[HandshakeRecord.RECORD_OVERHEAD + size];
			if (session.Version != null)
			{
				record.Version = session.Version;
			}
			HandshakeRecord handshakeRecord = new HandshakeRecord();
			handshakeRecord.MessageType = handshakeMessage.MessageType;
			handshakeRecord.MessageSeq = messageSequence;
			handshakeRecord.Length = (uint)size;
			handshakeRecord.FragmentLength = (uint)size;
			using (MemoryStream stream = new MemoryStream(record.Fragment))
			{
				handshakeRecord.Serialise(stream);
				handshakeMessage.Serialise(stream, session.Version);
			}
			// HelloVerifyRequest is excluded from the handshake hash (RFC 6347, section 4.2.6).
			if (handshakeMessage.MessageType != THandshakeType.HelloVerifyRequest)
			{
				session.Handshake.UpdateHandshakeHash(record.Fragment);
			}
			if (session.Cipher != null)
			{
				long sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;
				record.Fragment = session.Cipher.EncodePlaintext(sequenceNumber, (byte)TRecordType.Handshake, record.Fragment, 0, record.Fragment.Length);
			}
			// Covers both the plaintext and the (longer) encrypted fragment.
			int responseSize = DTLSRecord.RECORD_OVERHEAD + record.Fragment.Length;
			byte[] response = new byte[responseSize];
			using (MemoryStream stream = new MemoryStream(response))
			{
				record.Serialise(stream);
			}
			SocketAsyncEventArgs parameters = new SocketAsyncEventArgs()
			{
				RemoteEndPoint = session.RemoteEndPoint
			};
			parameters.SetBuffer(response, 0, responseSize);
			// Completed fires only when SendToAsync returns true; otherwise dispose immediately.
			parameters.Completed += (sender, e) => e.Dispose();
			if (!_Socket.SendToAsync(parameters))
			{
				parameters.Dispose();
			}
		}
예제 #14
0
 /// <summary>
 /// Marks the epoch immediately after <paramref name="record"/>'s epoch as the first encrypted client epoch.
 /// </summary>
 /// <param name="record">The record whose epoch precedes the switch to encryption.</param>
 internal void SetEncyptChange(DTLSRecord record) => EncyptedClientEpoch = (ushort)(record.Epoch + 1);
예제 #15
0
        /// <summary>
        /// Client side: processes one handshake record received from the server and advances the handshake.
        /// </summary>
        /// <param name="record">
        /// The received handshake record. When its epoch equals <c>_EncyptedServerEpoch</c>, its fragment is
        /// decrypted with <c>_Cipher</c> before parsing.
        /// </param>
        /// <exception cref="InvalidOperationException">
        /// The record is encrypted but no cipher became available within the wait period.
        /// </exception>
        private void ProcessHandshake(DTLSRecord record)
        {
            byte[] data = record.Fragment;
            if (_EncyptedServerEpoch.HasValue && (_EncyptedServerEpoch.Value == record.Epoch))
            {
                // _Cipher is assigned while handling ServerKeyExchange/ServerHelloDone, and an encrypted
                // record may arrive before that. Poll up to 500 x 10 ms for it.
                int count = 0;
                while ((_Cipher == null) && (count < 500))
                {
                    System.Threading.Thread.Sleep(10);
                    count++;
                }

                if (_Cipher == null)
                {
                    throw new InvalidOperationException("Need Cipher for Encrypted Session");
                }

                long sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;
                data = _Cipher.DecodeCiphertext(sequenceNumber, (byte)TRecordType.Handshake, record.Fragment, 0, record.Fragment.Length);
            }

            using (MemoryStream stream = new MemoryStream(data))
            {
                HandshakeRecord handshakeRecord = HandshakeRecord.Deserialise(stream);
                if (handshakeRecord == null)
                {
                    return;
                }
#if DEBUG
                Console.WriteLine(handshakeRecord.MessageType.ToString());
#endif
                switch (handshakeRecord.MessageType)
                {
                    case THandshakeType.HelloRequest:
                    case THandshakeType.ClientHello:
                    case THandshakeType.CertificateVerify:
                    case THandshakeType.ClientKeyExchange:
                        // Not acted on by the client.
                        break;

                    case THandshakeType.ServerHello:
                    {
                        ServerHello serverHello = ServerHello.Deserialise(stream);
                        if (serverHello != null)
                        {
                            _ServerEpoch = record.Epoch;
                            _HandshakeInfo.UpdateHandshakeHash(data);
                            _HandshakeInfo.CipherSuite  = (TCipherSuite)serverHello.CipherSuite;
                            _HandshakeInfo.ServerRandom = serverHello.Random;
                            // Use the lower of our supported version and the server's version.
                            Version version = SupportedVersion;
                            if (serverHello.ServerVersion < version)
                            {
                                version = serverHello.ServerVersion;
                            }
                            _Version = version;
                        }
                        break;
                    }

                    case THandshakeType.HelloVerifyRequest:
                    {
                        // Not hashed: the ClientHello is re-sent with the server's cookie.
                        HelloVerifyRequest helloVerifyRequest = HelloVerifyRequest.Deserialise(stream);
                        if (helloVerifyRequest != null)
                        {
                            _Version = helloVerifyRequest.ServerVersion;
                            SendHello(helloVerifyRequest.Cookie);
                        }
                        break;
                    }

                    case THandshakeType.Certificate:
                        // NOTE(review): the server certificate is only added to the hash here, not validated.
                        _HandshakeInfo.UpdateHandshakeHash(data);
                        break;

                    case THandshakeType.ServerKeyExchange:
                    {
                        _HandshakeInfo.UpdateHandshakeHash(data);
                        TKeyExchangeAlgorithm keyExchangeAlgorithm = CipherSuites.GetKeyExchangeAlgorithm(_HandshakeInfo.CipherSuite);
                        byte[] preMasterSecret = null;
                        if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
                        {
                            // NOTE(review): nothing in this method verifies the server's signature over these
                            // parameters; confirm it is checked elsewhere.
                            ECDHEServerKeyExchange serverKeyExchange = ECDHEServerKeyExchange.Deserialise(stream, _Version);
                            ECDHEKeyExchange keyExchangeECDHE = new ECDHEKeyExchange
                            {
                                CipherSuite          = _HandshakeInfo.CipherSuite,
                                Curve                = serverKeyExchange.EllipticCurve,
                                KeyExchangeAlgorithm = keyExchangeAlgorithm,
                                ClientRandom         = _HandshakeInfo.ClientRandom,
                                ServerRandom         = _HandshakeInfo.ServerRandom
                            };
                            keyExchangeECDHE.GenerateEphemeralKey();
                            _ClientKeyExchange = new ECDHEClientKeyExchange(keyExchangeECDHE.PublicKey);
                            preMasterSecret    = keyExchangeECDHE.GetPreMasterSecret(serverKeyExchange.PublicKeyBytes);
                        }
                        else if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_PSK)
                        {
                            ECDHEPSKServerKeyExchange serverKeyExchange = ECDHEPSKServerKeyExchange.Deserialise(stream, _Version);
                            ECDHEKeyExchange keyExchangeECDHE = new ECDHEKeyExchange
                            {
                                CipherSuite          = _HandshakeInfo.CipherSuite,
                                Curve                = serverKeyExchange.EllipticCurve,
                                KeyExchangeAlgorithm = keyExchangeAlgorithm,
                                ClientRandom         = _HandshakeInfo.ClientRandom,
                                ServerRandom         = _HandshakeInfo.ServerRandom
                            };
                            keyExchangeECDHE.GenerateEphemeralKey();
                            ECDHEPSKClientKeyExchange clientKeyExchange = new ECDHEPSKClientKeyExchange(keyExchangeECDHE.PublicKey);
                            // Identity selection, in order:
                            //   1. the key matching the server's identity hint, if we have one;
                            //   2. otherwise any identity already selected;
                            //   3. otherwise a random configured identity.
                            if (serverKeyExchange.PSKIdentityHint != null)
                            {
                                byte[] key = _PSKIdentities.GetKey(serverKeyExchange.PSKIdentityHint);
                                if (key != null)
                                {
                                    _PSKIdentity = new PSKIdentity()
                                    {
                                        Identity = serverKeyExchange.PSKIdentityHint,
                                        Key      = key
                                    };
                                }
                            }
                            if (_PSKIdentity == null)
                            {
                                _PSKIdentity = _PSKIdentities.GetRandom();
                            }
                            clientKeyExchange.PSKIdentity = _PSKIdentity.Identity;
                            _ClientKeyExchange            = clientKeyExchange;
                            byte[] otherSecret = keyExchangeECDHE.GetPreMasterSecret(serverKeyExchange.PublicKeyBytes);
                            preMasterSecret = TLSUtils.GetPSKPreMasterSecret(otherSecret, _PSKIdentity.Key);
                        }
                        else if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.PSK)
                        {
                            PSKServerKeyExchange serverKeyExchange = PSKServerKeyExchange.Deserialise(stream, _Version);
                            PSKClientKeyExchange clientKeyExchange = new PSKClientKeyExchange();
                            if (serverKeyExchange.PSKIdentityHint != null)
                            {
                                byte[] key = _PSKIdentities.GetKey(serverKeyExchange.PSKIdentityHint);
                                if (key != null)
                                {
                                    _PSKIdentity = new PSKIdentity()
                                    {
                                        Identity = serverKeyExchange.PSKIdentityHint,
                                        Key      = key
                                    };
                                }
                            }
                            if (_PSKIdentity == null)
                            {
                                _PSKIdentity = _PSKIdentities.GetRandom();
                            }
                            // Plain PSK (RFC 4279 section 2): other_secret is N zero bytes, N = PSK length.
                            byte[] otherSecret = new byte[_PSKIdentity.Key.Length];
                            clientKeyExchange.PSKIdentity = _PSKIdentity.Identity;
                            _ClientKeyExchange            = clientKeyExchange;
                            preMasterSecret = TLSUtils.GetPSKPreMasterSecret(otherSecret, _PSKIdentity.Key);
                        }
                        _Cipher = TLSUtils.AssignCipher(preMasterSecret, true, _Version, _HandshakeInfo);
                        break;
                    }

                    case THandshakeType.CertificateRequest:
                        _HandshakeInfo.UpdateHandshakeHash(data);
                        _SendCertificate = true;
                        break;

                    case THandshakeType.ServerHelloDone:
                    {
                        _HandshakeInfo.UpdateHandshakeHash(data);
                        if (_Cipher == null)
                        {
                            // No ServerKeyExchange was processed. For a plain-PSK suite the premaster secret
                            // is derived here instead.
                            TKeyExchangeAlgorithm keyExchangeAlgorithm = CipherSuites.GetKeyExchangeAlgorithm(_HandshakeInfo.CipherSuite);
                            if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.PSK)
                            {
                                PSKClientKeyExchange clientKeyExchange = new PSKClientKeyExchange();
                                _PSKIdentity = _PSKIdentities.GetRandom();
                                byte[] otherSecret = new byte[_PSKIdentity.Key.Length];
                                clientKeyExchange.PSKIdentity = _PSKIdentity.Identity;
                                _ClientKeyExchange            = clientKeyExchange;
                                byte[] preMasterSecret = TLSUtils.GetPSKPreMasterSecret(otherSecret, _PSKIdentity.Key);
                                _Cipher = TLSUtils.AssignCipher(preMasterSecret, true, _Version, _HandshakeInfo);
                            }
                        }

                        // Client flight: [Certificate], ClientKeyExchange, [CertificateVerify],
                        // ChangeCipherSpec, Finished.
                        if (_SendCertificate)
                        {
                            SendHandshakeMessage(_Certificate, false);
                        }
                        SendHandshakeMessage(_ClientKeyExchange, false);
                        if (_SendCertificate)
                        {
                            // Signs the current handshake hash with the client's private key.
                            CertificateVerify certificateVerify = new CertificateVerify();
                            byte[] signatureHash = _HandshakeInfo.GetHash();
                            certificateVerify.SignatureHashAlgorithm = new SignatureHashAlgorithm()
                            {
                                Signature = TSignatureAlgorithm.ECDSA,
                                Hash      = THashAlgorithm.SHA256
                            };
                            certificateVerify.Signature = TLSUtils.Sign(_PrivateKey, true, _Version, _HandshakeInfo, certificateVerify.SignatureHashAlgorithm, signatureHash);
                            SendHandshakeMessage(certificateVerify, false);
                        }
                        SendChangeCipherSpec();
                        byte[] handshakeHash = _HandshakeInfo.GetHash();
                        Finished finished = new Finished
                        {
                            VerifyData = TLSUtils.GetVerifyData(_Version, _HandshakeInfo, true, true, handshakeHash)
                        };
                        SendHandshakeMessage(finished, true);
#if DEBUG
                        Console.WriteLine($"Handshake Hash: {TLSUtils.WriteToString(handshakeHash)}");
                        Console.WriteLine($"Sent Verify: {TLSUtils.WriteToString(finished.VerifyData)}");
#endif
                        break;
                    }

                    case THandshakeType.Finished:
                    {
                        Finished serverFinished = Finished.Deserialise(stream);
                        byte[] handshakeHash = _HandshakeInfo.GetHash();
                        byte[] calculatedVerifyData = TLSUtils.GetVerifyData(_Version, _HandshakeInfo, true, false, handshakeHash);
#if DEBUG
                        Console.WriteLine($"Recieved Verify: {TLSUtils.WriteToString(serverFinished.VerifyData)}");
                        Console.WriteLine($"Calc Verify: {TLSUtils.WriteToString(calculatedVerifyData)}");
#endif
                        if (TLSUtils.ByteArrayCompare(serverFinished.VerifyData, calculatedVerifyData))
                        {
#if DEBUG
                            Console.WriteLine("Handshake Complete");
#endif
                            _Connected.Set();
                        }
                        // NOTE(review): on a mismatch nothing is signalled; _Connected simply stays unset.
                        break;
                    }

                    default:
                        break;
                }
            }
        }
예제 #16
0
        public void ProcessHandshake(DTLSRecord record)
        {
            if (record == null)
            {
                throw new ArgumentNullException(nameof(record));
            }

            var address = record.RemoteEndPoint.Serialize();
            var session = this.Sessions.GetSession(address);
            var data    = record.Fragment;

            if ((session != null) && session.IsEncypted(record))
            {
                var count = 0;
                while ((session.Cipher == null) && (count < 50))
                {
                    System.Threading.Thread.Sleep(10);
                    count++;
                }

                if (session.Cipher == null)
                {
                    throw new Exception("Need Cipher for Encrypted Session");
                }

                var sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;
                data = session.Cipher.DecodeCiphertext(sequenceNumber, (byte)TRecordType.Handshake, record.Fragment, 0, record.Fragment.Length);
            }

            using (var stream = new MemoryStream(data))
            {
                var handshakeRecord = HandshakeRecord.Deserialise(stream);
                switch (handshakeRecord.MessageType)
                {
                case THandshakeType.HelloRequest:
                {
                    //HelloReq
                    break;
                }

                case THandshakeType.ClientHello:
                {
                    var clientHello = ClientHello.Deserialise(stream);
                    var cookie      = clientHello.CalculateCookie(record.RemoteEndPoint, this._HelloSecret);

                    if (clientHello.Cookie == null)
                    {
                        var vers = clientHello.ClientVersion;
                        if (this.ServerVersion < vers)
                        {
                            vers = this.ServerVersion;
                        }

                        if (session == null)
                        {
                            session = new Session
                            {
                                SessionID      = Guid.NewGuid(),
                                RemoteEndPoint = record.RemoteEndPoint,
                                Version        = vers
                            };
                            this.Sessions.AddSession(address, session);
                        }
                        else
                        {
                            session.Reset();
                            session.Version = vers;
                        }

                        session.ClientEpoch          = record.Epoch;
                        session.ClientSequenceNumber = record.SequenceNumber;

                        var helloVerifyRequest = new HelloVerifyRequest
                        {
                            Cookie        = cookie,
                            ServerVersion = ServerVersion
                        };

                        this.SendResponse(session, helloVerifyRequest, 0);
                        break;
                    }

                    if (session != null && session.Cipher != null && !session.IsEncypted(record))
                    {
                        session.Reset();
                    }

                    if (!clientHello.Cookie.SequenceEqual(cookie))
                    {
                        break;
                    }

                    var version = clientHello.ClientVersion;
                    if (this.ServerVersion < version)
                    {
                        version = this.ServerVersion;
                    }

                    if (clientHello.SessionID == null)
                    {
                        if (session == null)
                        {
                            session = new Session();
                            session.NextSequenceNumber();
                            session.SessionID      = Guid.NewGuid();
                            session.RemoteEndPoint = record.RemoteEndPoint;
                            this.Sessions.AddSession(address, session);
                        }
                    }
                    else
                    {
                        if (clientHello.SessionID.Length >= 16)
                        {
                            session = this.Sessions.GetSession(new Guid(clientHello.SessionID.Take(16).ToArray()));
                        }

                        if (session == null)
                        {
                            //need to Find Session
                            session = new Session
                            {
                                SessionID = Guid.NewGuid()
                            };
                            session.NextSequenceNumber();
                            session.RemoteEndPoint = record.RemoteEndPoint;
                            this.Sessions.AddSession(address, session);
                        }
                    }

                    session.Version = version;
                    var cipherSuite = TCipherSuite.TLS_NULL_WITH_NULL_NULL;
                    foreach (TCipherSuite item in clientHello.CipherSuites)
                    {
                        if (this._SupportedCipherSuites.ContainsKey(item) && CipherSuites.SupportedVersion(item, session.Version) && CipherSuites.SuiteUsable(item, this.PrivateKey, this._PSKIdentities, this._ValidatePSK != null))
                        {
                            cipherSuite = item;
                            break;
                        }
                    }

                    var clientSessionID = new byte[32];
                    var temp            = session.SessionID.ToByteArray();
                    Buffer.BlockCopy(temp, 0, clientSessionID, 0, 16);
                    Buffer.BlockCopy(temp, 0, clientSessionID, 16, 16);

                    var serverHello = new ServerHello
                    {
                        SessionID     = clientSessionID,    // session.SessionID.ToByteArray();
                        Random        = new RandomData(),
                        CipherSuite   = (ushort)cipherSuite,
                        ServerVersion = session.Version
                    };
                    serverHello.Random.Generate();

                    session.Handshake.UpdateHandshakeHash(data);
                    session.Handshake.CipherSuite  = cipherSuite;
                    session.Handshake.ClientRandom = clientHello.Random;
                    session.Handshake.ServerRandom = serverHello.Random;

                    var hash  = THashAlgorithm.SHA256;
                    var curve = TEllipticCurve.secp521r1;
                    if (clientHello.Extensions != null)
                    {
                        foreach (var extension in clientHello.Extensions)
                        {
                            if (extension.SpecificExtension is ClientCertificateTypeExtension)
                            {
                                var clientCertificateType = extension.SpecificExtension as ClientCertificateTypeExtension;
                                //TCertificateType certificateType = TCertificateType.Unknown;
                                //foreach (TCertificateType item in clientCertificateType.CertificateTypes)
                                //{

                                //}
                                //serverHello.AddExtension(new ClientCertificateTypeExtension(certificateType));
                            }
                            else if (extension.SpecificExtension is EllipticCurvesExtension)
                            {
                                var ellipticCurves = extension.SpecificExtension as EllipticCurvesExtension;
                                foreach (var item in ellipticCurves.SupportedCurves)
                                {
                                    if (EllipticCurveFactory.SupportedCurve(item))
                                    {
                                        curve = item;
                                        break;
                                    }
                                }
                            }
                            else if (extension.SpecificExtension is ServerCertificateTypeExtension)
                            {
                                //serverHello.AddExtension();
                            }
                            else if (extension.SpecificExtension is SignatureAlgorithmsExtension)
                            {
                                var signatureAlgorithms = extension.SpecificExtension as SignatureAlgorithmsExtension;
                                foreach (var item in signatureAlgorithms.SupportedAlgorithms)
                                {
                                    if (item.Signature == TSignatureAlgorithm.ECDSA)
                                    {
                                        hash = item.Hash;
                                        break;
                                    }
                                }
                            }
                        }
                    }

                    var keyExchangeAlgorithm = CipherSuites.GetKeyExchangeAlgorithm(cipherSuite);
                    if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
                    {
                        var pointFormatsExtension = new EllipticCurvePointFormatsExtension();
                        pointFormatsExtension.SupportedPointFormats.Add(TEllipticCurvePointFormat.Uncompressed);
                        serverHello.AddExtension(pointFormatsExtension);
                    }

                    session.Handshake.MessageSequence = 1;
                    this.SendResponse(session, serverHello, session.Handshake.MessageSequence);
                    session.Handshake.MessageSequence++;

                    if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
                    {
                        if (this.Certificate != null)
                        {
                            this.SendResponse(session, this.Certificate, session.Handshake.MessageSequence);
                            session.Handshake.MessageSequence++;
                        }

                        var keyExchange = new ECDHEKeyExchange
                        {
                            Curve = curve,
                            KeyExchangeAlgorithm = keyExchangeAlgorithm,
                            ClientRandom         = clientHello.Random,
                            ServerRandom         = serverHello.Random
                        };

                        keyExchange.GenerateEphemeralKey();
                        session.Handshake.KeyExchange = keyExchange;
                        if (session.Version == DTLSRecord.DefaultVersion)
                        {
                            hash = THashAlgorithm.SHA1;
                        }

                        var serverKeyExchange = new ECDHEServerKeyExchange(keyExchange, hash, TSignatureAlgorithm.ECDSA, this.PrivateKey);
                        this.SendResponse(session, serverKeyExchange, session.Handshake.MessageSequence);

                        session.Handshake.MessageSequence++;
                        if (this._RequireClientCertificate)
                        {
                            var certificateRequest = new CertificateRequest();
                            certificateRequest.CertificateTypes.Add(TClientCertificateType.ECDSASign);
                            certificateRequest.SupportedAlgorithms.Add(new SignatureHashAlgorithm()
                                {
                                    Hash = THashAlgorithm.SHA256, Signature = TSignatureAlgorithm.ECDSA
                                });
                            this.SendResponse(session, certificateRequest, session.Handshake.MessageSequence);
                            session.Handshake.MessageSequence++;
                        }
                    }
                    else if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_PSK)
                    {
                        var keyExchange = new ECDHEKeyExchange
                        {
                            Curve = curve,
                            KeyExchangeAlgorithm = keyExchangeAlgorithm,
                            ClientRandom         = clientHello.Random,
                            ServerRandom         = serverHello.Random
                        };

                        keyExchange.GenerateEphemeralKey();
                        session.Handshake.KeyExchange = keyExchange;
                        var serverKeyExchange = new ECDHEPSKServerKeyExchange(keyExchange);
                        this.SendResponse(session, serverKeyExchange, session.Handshake.MessageSequence);
                        session.Handshake.MessageSequence++;
                    }
                    else if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.PSK)
                    {
                        var keyExchange = new PSKKeyExchange
                        {
                            KeyExchangeAlgorithm = keyExchangeAlgorithm,
                            ClientRandom         = clientHello.Random,
                            ServerRandom         = serverHello.Random
                        };
                        session.Handshake.KeyExchange = keyExchange;
                        //Need to be able to hint identity?? for PSK if not hinting don't really need key exchange message
                        //PSKServerKeyExchange serverKeyExchange = new PSKServerKeyExchange();
                        //SendResponse(session, serverKeyExchange, session.Handshake.MessageSequence);
                        //session.Handshake.MessageSequence++;
                    }
                    this.SendResponse(session, new ServerHelloDone(), session.Handshake.MessageSequence);
                    session.Handshake.MessageSequence++;
                    break;
                }

                case THandshakeType.ServerHello:
                {
                    break;
                }

                case THandshakeType.HelloVerifyRequest:
                {
                    break;
                }

                case THandshakeType.Certificate:
                {
                    var clientCertificate = Certificate.Deserialise(stream, TCertificateType.X509);
                    if (clientCertificate.CertChain.Count > 0)
                    {
                        session.CertificateInfo = Certificates.GetCertificateInfo(clientCertificate.CertChain[0], TCertificateFormat.CER);
                    }
                    session.Handshake.UpdateHandshakeHash(data);
                    break;
                }

                case THandshakeType.ServerKeyExchange:
                {
                    break;
                }

                case THandshakeType.CertificateRequest:
                {
                    break;
                }

                case THandshakeType.ServerHelloDone:
                {
                    break;
                }

                case THandshakeType.CertificateVerify:
                {
                    var certificateVerify = CertificateVerify.Deserialise(stream, session.Version);
                    session.Handshake.UpdateHandshakeHash(data);
                }
                break;

                case THandshakeType.ClientKeyExchange:
                {
                    if ((session == null) || (session.Handshake.KeyExchange == null))
                    {
                        break;
                    }

                    session.Handshake.UpdateHandshakeHash(data);
                    byte[] preMasterSecret = null;
                    if (session.Handshake.KeyExchange.KeyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
                    {
                        var clientKeyExchange = ECDHEClientKeyExchange.Deserialise(stream);
                        if (clientKeyExchange != null)
                        {
                            var ecKeyExchange = session.Handshake.KeyExchange as ECDHEKeyExchange;
                            preMasterSecret = ecKeyExchange.GetPreMasterSecret(clientKeyExchange.PublicKeyBytes);
                        }
                    }
                    else if (session.Handshake.KeyExchange.KeyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_PSK)
                    {
                        var clientKeyExchange = ECDHEPSKClientKeyExchange.Deserialise(stream);
                        if (clientKeyExchange != null)
                        {
                            session.PSKIdentity = Encoding.UTF8.GetString(clientKeyExchange.PSKIdentity);
                            var psk = this._PSKIdentities.GetKey(clientKeyExchange.PSKIdentity);

                            if (psk == null)
                            {
                                psk = this._ValidatePSK(clientKeyExchange.PSKIdentity);
                                if (psk != null)
                                {
                                    this._PSKIdentities.AddIdentity(clientKeyExchange.PSKIdentity, psk);
                                }
                            }

                            if (psk != null)
                            {
                                var ecKeyExchange = session.Handshake.KeyExchange as ECDHEKeyExchange;
                                var otherSecret   = ecKeyExchange.GetPreMasterSecret(clientKeyExchange.PublicKeyBytes);
                                preMasterSecret = TLSUtils.GetPSKPreMasterSecret(otherSecret, psk);
                            }
                        }
                    }
                    else if (session.Handshake.KeyExchange.KeyExchangeAlgorithm == TKeyExchangeAlgorithm.PSK)
                    {
                        var clientKeyExchange = PSKClientKeyExchange.Deserialise(stream);
                        if (clientKeyExchange != null)
                        {
                            session.PSKIdentity = Encoding.UTF8.GetString(clientKeyExchange.PSKIdentity);
                            var psk = this._PSKIdentities.GetKey(clientKeyExchange.PSKIdentity);

                            if (psk == null)
                            {
                                psk = this._ValidatePSK(clientKeyExchange.PSKIdentity);
                                if (psk != null)
                                {
                                    this._PSKIdentities.AddIdentity(clientKeyExchange.PSKIdentity, psk);
                                }
                            }

                            if (psk != null)
                            {
                                var ecKeyExchange = session.Handshake.KeyExchange as ECDHEKeyExchange;
                                var otherSecret   = new byte[psk.Length];
                                preMasterSecret = TLSUtils.GetPSKPreMasterSecret(otherSecret, psk);
                            }
                        }
                    }

                    if (preMasterSecret != null)
                    {
                        //session.MasterSecret = TLSUtils.CalculateMasterSecret(preMasterSecret, session.KeyExchange);
                        //TLSUtils.AssignCipher(session);

                        session.Cipher = TLSUtils.AssignCipher(preMasterSecret, false, session.Version, session.Handshake);
                    }
                    break;
                }

                case THandshakeType.Finished:
                {
                    var finished = Finished.Deserialise(stream);
                    if (session == null)
                    {
                        break;
                    }

                    var handshakeHash        = session.Handshake.GetHash(session.Version);
                    var calculatedVerifyData = TLSUtils.GetVerifyData(session.Version, session.Handshake, false, true, handshakeHash);
                    if (!finished.VerifyData.SequenceEqual(calculatedVerifyData))
                    {
                        throw new Exception();
                    }

                    this.SendChangeCipherSpec(session);
                    session.Handshake.UpdateHandshakeHash(data);
                    handshakeHash = session.Handshake.GetHash(session.Version);
                    var serverFinished = new Finished
                    {
                        VerifyData = TLSUtils.GetVerifyData(session.Version, session.Handshake, false, false, handshakeHash)
                    };
                    this.SendResponse(session, serverFinished, session.Handshake.MessageSequence);
                    session.Handshake.MessageSequence++;
                    break;
                }

                default:
                    break;
                }
            }
        }
예제 #17
0
        /// <summary>
        /// Dispatches one received DTLS record according to its record type.
        /// </summary>
        /// <remarks>
        /// <list type="bullet">
        /// <item>ChangeCipherSpec: once a server epoch is known, advances it, resets the expected server
        /// sequence number and records that epoch as the encrypted server epoch.</item>
        /// <item>Alert: decodes the alert, decrypting it when a cipher exists and an encrypted server epoch has
        /// been recorded. A fatal or undecodable alert sets <c>_Connected</c>. A close_notify is answered
        /// with our own close_notify and also sets <c>_Connected</c>.</item>
        /// <item>Handshake: forwarded to <c>ProcessHandshake</c>.</item>
        /// <item>ApplicationData: decrypted (only when a cipher exists) and raised through
        /// <c>DataReceived</c>.</item>
        /// </list>
        /// Every exception is swallowed; DEBUG builds write it to the console.
        /// </remarks>
        /// <param name="record">The received record.</param>
        private void ProcessRecord(DTLSRecord record)
        {
            try
            {
#if DEBUG
                Console.WriteLine(record.RecordType.ToString());
#endif
                switch (record.RecordType)
                {
                case TRecordType.ChangeCipherSpec:
                    if (_ServerEpoch.HasValue)
                    {
                        _ServerEpoch++;
                        _ServerSequenceNumber = 0;
                        _EncyptedServerEpoch  = _ServerEpoch;
                    }
                    break;

                case TRecordType.Alert:
                    AlertRecord alertRecord;
                    try
                    {
                        if ((_Cipher == null) || (!_EncyptedServerEpoch.HasValue))
                        {
                            // No encrypted epoch yet: the fragment is the plaintext alert.
                            alertRecord = AlertRecord.Deserialise(record.Fragment);
                        }
                        else
                        {
                            // The upper 16 bits carry the epoch, the lower 48 bits the record sequence number.
                            long   sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;
                            byte[] data           = _Cipher.DecodeCiphertext(sequenceNumber, (byte)TRecordType.Alert, record.Fragment, 0, record.Fragment.Length);
                            alertRecord = AlertRecord.Deserialise(data);
                        }
                    }
                    catch
                    {
                        // An alert that cannot be decoded is treated as fatal.
                        alertRecord            = new AlertRecord();
                        alertRecord.AlertLevel = TAlertLevel.Fatal;
                    }

                    if (alertRecord.AlertLevel == TAlertLevel.Fatal)
                    {
                        _Connected.Set();
                        //Terminate
                    }
                    else if ((alertRecord.AlertLevel == TAlertLevel.Warning) || (alertRecord.AlertDescription == TAlertDescription.CloseNotify))
                    {
                        if (alertRecord.AlertDescription == TAlertDescription.CloseNotify)
                        {
                            // Answer the peer's close_notify with our own before signalling.
                            SendAlert(TAlertLevel.Warning, TAlertDescription.CloseNotify);
                            _Connected.Set();
                        }
                        //_Sessions.Remove(session, address);
                    }
                    break;

                case TRecordType.Handshake:
                    ProcessHandshake(record);
                    _ServerSequenceNumber = record.SequenceNumber + 1;
                    break;

                case TRecordType.ApplicationData:
                    if (_Cipher != null)
                    {
                        long   sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;
                        byte[] data           = _Cipher.DecodeCiphertext(sequenceNumber, (byte)TRecordType.ApplicationData, record.Fragment, 0, record.Fragment.Length);

                        // Read the delegate only once. A separate null check followed by a second read can throw
                        // a NullReferenceException if the last handler unsubscribes on another thread in between.
                        DataReceived?.Invoke(record.RemoteEndPoint, data);
                    }
                    _ServerSequenceNumber = record.SequenceNumber + 1;
                    break;

                default:
                    break;
                }
            }
#if DEBUG
            catch (Exception ex)
            {
                Console.WriteLine(ex.ToString());
#else
            catch
            {
#endif
            }
        }
예제 #18
0
        /// <summary>
        /// Sends a two-byte alert record (level, description) to the server endpoint in the current epoch.
        /// </summary>
        /// <remarks>
        /// The alert is encrypted when a cipher has been assigned; otherwise it is sent in the clear.
        /// Nothing is sent when no socket is available.
        /// </remarks>
        /// <param name="alertLevel">Alert level byte (e.g. warning or fatal).</param>
        /// <param name="alertDescription">Alert description byte (e.g. close_notify).</param>
        private void SendAlert(TAlertLevel alertLevel, TAlertDescription alertDescription)
        {
            DTLSRecord record = new DTLSRecord();
            record.RecordType = TRecordType.Alert;
            record.Epoch = _Epoch;
            record.SequenceNumber = NextSequenceNumber();
            if (_Version != null)
            {
                record.Version = _Version;
            }

            byte[] data = new byte[] { (byte)alertLevel, (byte)alertDescription };
            if (_Cipher == null)
            {
                record.Fragment = data;
            }
            else
            {
                // The upper 16 bits carry the epoch, the lower 48 bits the record sequence number.
                long sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;

                // The content type is part of the data the cipher authenticates, so it must match the record
                // type (Alert). Incoming alerts are also decoded with TRecordType.Alert. Encoding an alert as
                // ApplicationData makes a conforming peer's integrity check fail.
                record.Fragment = _Cipher.EncodePlaintext(sequenceNumber, (byte)TRecordType.Alert, data, 0, data.Length);
            }

            int responseSize = DTLSRecord.RECORD_OVERHEAD + record.Fragment.Length;
            byte[] response = new byte[responseSize];
            using (MemoryStream stream = new MemoryStream(response))
            {
                record.Serialise(stream);
            }

            SocketAsyncEventArgs parameters = new SocketAsyncEventArgs()
            {
                RemoteEndPoint = _ServerEndPoint
            };
            parameters.SetBuffer(response, 0, responseSize);

            // Match Send(): silently skip the transmission when the socket is not available.
            if (_Socket != null)
            {
                _Socket.SendToAsync(parameters);
            }
        }
예제 #19
0
        /// <summary>
        /// Client-side handler for one handshake record received from the server.
        /// </summary>
        /// <remarks>
        /// <para>If the record belongs to the encrypted server epoch, it is decrypted first. The method waits up
        /// to 500 x 10 ms for the cipher to be assigned.</para>
        /// <para>Fragmented handshake messages are buffered until the final fragment arrives. They are then
        /// reassembled into one unfragmented message.</para>
        /// <para>The complete message is dispatched on its handshake type. ServerHello, HelloVerifyRequest,
        /// Certificate, ServerKeyExchange, CertificateRequest, ServerHelloDone, NewSessionTicket and Finished are
        /// acted on. The remaining types are ignored.</para>
        /// </remarks>
        /// <param name="record">The received handshake record.</param>
        /// <exception cref="ArgumentNullException"><paramref name="record"/> is null.</exception>
        /// <exception cref="InvalidOperationException">The record is encrypted but no cipher became available.</exception>
        /// <exception cref="NotImplementedException">ServerHelloDone arrived without a cipher for an unsupported key-exchange algorithm.</exception>
        private async Task _ProcessHandshakeAsync(DTLSRecord record)
        {
            if (record == null)
            {
                throw new ArgumentNullException(nameof(record));
            }

            var data = record.Fragment;

            if (this._EncyptedServerEpoch == record.Epoch)
            {
                // The cipher may still be being derived; poll for it before giving up.
                var count = 0;
                while ((this._Cipher == null) && (count < 500))
                {
                    await Task.Delay(10).ConfigureAwait(false);

                    count++;
                }

                if (this._Cipher == null)
                {
                    throw new InvalidOperationException("Need Cipher for Encrypted Session");
                }

                // The upper 16 bits carry the epoch, the lower 48 bits the record sequence number.
                var sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;
                data = this._Cipher.DecodeCiphertext(sequenceNumber, (byte)TRecordType.Handshake, record.Fragment, 0, record.Fragment.Length);
            }

            using (var tempStream = new MemoryStream(data))
            {
                var handshakeRec = HandshakeRecord.Deserialise(tempStream);
                if (handshakeRec.Length > (handshakeRec.FragmentLength + handshakeRec.FragmentOffset))
                {
                    // Not the final fragment: buffer it and wait for the rest of the message.
                    this._IsFragment = true;
                    this._FragmentedRecordList.Add(data);
                    return;
                }
                else if (this._IsFragment)
                {
                    // Final fragment: join the bodies of all buffered fragments, each without its handshake
                    // header, in arrival order.
                    // NOTE(review): FragmentOffset is not consulted, so out-of-order or duplicated fragments
                    // are not handled.
                    this._FragmentedRecordList.Add(data);
                    var messageBody = this._FragmentedRecordList
                        .SelectMany(rec => rec.Skip(HandshakeRecord.RECORD_OVERHEAD))
                        .ToArray();

                    // Prefix a synthetic header that describes the reassembled message as a single fragment.
                    var tempHandshakeRec = new HandshakeRecord()
                    {
                        Length         = handshakeRec.Length,
                        MessageSeq     = handshakeRec.MessageSeq,
                        MessageType    = handshakeRec.MessageType,
                        FragmentLength = handshakeRec.Length,
                        FragmentOffset = 0
                    };

                    var tempHandshakeBytes = new byte[HandshakeRecord.RECORD_OVERHEAD];
                    using (var updateStream = new MemoryStream(tempHandshakeBytes))
                    {
                        tempHandshakeRec.Serialise(updateStream);
                    }

                    data = tempHandshakeBytes.Concat(messageBody).ToArray();
                }
            }

            // 'data' now holds one complete handshake message, so the reassembly buffer is no longer needed.
            // It is reset before dispatching rather than after. Otherwise a failure while handling this message
            // would leave stale fragments behind and corrupt the next fragmented message.
            this._IsFragment = false;
            this._FragmentedRecordList.Clear();

            using (var stream = new MemoryStream(data))
            {
                var handshakeRecord = HandshakeRecord.Deserialise(stream);
                switch (handshakeRecord.MessageType)
                {
                case THandshakeType.HelloRequest:
                {
                    break;
                }

                case THandshakeType.ClientHello:
                {
                    break;
                }

                case THandshakeType.ServerHello:
                {
                    var serverHello = ServerHello.Deserialise(stream);
                    this._HandshakeInfo.UpdateHandshakeHash(data);
                    this._ServerEpoch = record.Epoch;
                    this._HandshakeInfo.CipherSuite  = (TCipherSuite)serverHello.CipherSuite;
                    this._HandshakeInfo.ServerRandom = serverHello.Random;
                    // NOTE(review): relies on Version's <= operator. DTLS version numbers decrease as the
                    // protocol version increases; confirm the operator accounts for that.
                    this._Version = serverHello.ServerVersion <= this._Version ? serverHello.ServerVersion : _SupportedVersion;
                    break;
                }

                case THandshakeType.HelloVerifyRequest:
                {
                    // Re-send the ClientHello carrying the server-supplied cookie.
                    var helloVerifyRequest = HelloVerifyRequest.Deserialise(stream);
                    this._Version = helloVerifyRequest.ServerVersion;
                    await this._SendHelloAsync(helloVerifyRequest.Cookie).ConfigureAwait(false);

                    break;
                }

                case THandshakeType.Certificate:
                {
                    var cert = Certificate.Deserialise(stream, TCertificateType.X509);
                    this._HandshakeInfo.UpdateHandshakeHash(data);
                    this.ServerCertificate = cert.Cert;
                    break;
                }

                case THandshakeType.ServerKeyExchange:
                {
                    // Build the matching ClientKeyExchange now and derive the cipher from the pre-master secret.
                    this._HandshakeInfo.UpdateHandshakeHash(data);
                    var    keyExchangeAlgorithm = CipherSuites.GetKeyExchangeAlgorithm(this._HandshakeInfo.CipherSuite);
                    byte[] preMasterSecret      = null;
                    if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
                    {
                        var serverKeyExchange = ECDHEServerKeyExchange.Deserialise(stream, this._Version);
                        var keyExchangeECDHE  = new ECDHEKeyExchange
                        {
                            CipherSuite          = this._HandshakeInfo.CipherSuite,
                            Curve                = serverKeyExchange.EllipticCurve,
                            KeyExchangeAlgorithm = keyExchangeAlgorithm,
                            ClientRandom         = this._HandshakeInfo.ClientRandom,
                            ServerRandom         = this._HandshakeInfo.ServerRandom
                        };
                        keyExchangeECDHE.GenerateEphemeralKey();
                        var clientKeyExchange = new ECDHEClientKeyExchange(keyExchangeECDHE.PublicKey);
                        this._ClientKeyExchange = clientKeyExchange;
                        preMasterSecret         = keyExchangeECDHE.GetPreMasterSecret(serverKeyExchange.PublicKeyBytes);
                    }
                    else if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_PSK)
                    {
                        var serverKeyExchange = ECDHEPSKServerKeyExchange.Deserialise(stream);
                        var keyExchangeECDHE  = new ECDHEKeyExchange
                        {
                            CipherSuite          = this._HandshakeInfo.CipherSuite,
                            Curve                = serverKeyExchange.EllipticCurve,
                            KeyExchangeAlgorithm = keyExchangeAlgorithm,
                            ClientRandom         = this._HandshakeInfo.ClientRandom,
                            ServerRandom         = this._HandshakeInfo.ServerRandom
                        };
                        keyExchangeECDHE.GenerateEphemeralKey();
                        var clientKeyExchange = new ECDHEPSKClientKeyExchange(keyExchangeECDHE.PublicKey);

                        // Prefer the identity hinted by the server if we hold a key for it; otherwise fall back
                        // to an identity picked by PSKIdentities.GetRandom().
                        if (serverKeyExchange.PSKIdentityHint != null)
                        {
                            var key = this.PSKIdentities.GetKey(serverKeyExchange.PSKIdentityHint);
                            if (key != null)
                            {
                                this._PSKIdentity = new PSKIdentity()
                                {
                                    Identity = serverKeyExchange.PSKIdentityHint, Key = key
                                };
                            }
                        }
                        if (this._PSKIdentity == null)
                        {
                            this._PSKIdentity = this.PSKIdentities.GetRandom();
                        }

                        clientKeyExchange.PSKIdentity = this._PSKIdentity.Identity;
                        this._ClientKeyExchange       = clientKeyExchange;
                        var otherSecret = keyExchangeECDHE.GetPreMasterSecret(serverKeyExchange.PublicKeyBytes);
                        preMasterSecret = TLSUtils.GetPSKPreMasterSecret(otherSecret, this._PSKIdentity.Key);
                    }
                    else if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.PSK)
                    {
                        var serverKeyExchange = PSKServerKeyExchange.Deserialise(stream);
                        var clientKeyExchange = new PSKClientKeyExchange();
                        if (serverKeyExchange.PSKIdentityHint != null)
                        {
                            var key = this.PSKIdentities.GetKey(serverKeyExchange.PSKIdentityHint);
                            if (key != null)
                            {
                                this._PSKIdentity = new PSKIdentity()
                                {
                                    Identity = serverKeyExchange.PSKIdentityHint, Key = key
                                };
                            }
                        }
                        if (this._PSKIdentity == null)
                        {
                            this._PSKIdentity = this.PSKIdentities.GetRandom();
                        }

                        // Plain PSK: the "other secret" is a zero-filled buffer as long as the key.
                        var otherSecret = new byte[this._PSKIdentity.Key.Length];
                        clientKeyExchange.PSKIdentity = this._PSKIdentity.Identity;
                        this._ClientKeyExchange       = clientKeyExchange;
                        preMasterSecret = TLSUtils.GetPSKPreMasterSecret(otherSecret, this._PSKIdentity.Key);
                    }
                    // NOTE(review): for any other key-exchange algorithm preMasterSecret is still null here.
                    this._Cipher = TLSUtils.AssignCipher(preMasterSecret, true, this._Version, this._HandshakeInfo);
                    break;
                }

                case THandshakeType.CertificateRequest:
                {
                    // Remember to send our certificate (and a CertificateVerify) after ServerHelloDone.
                    this._HandshakeInfo.UpdateHandshakeHash(data);
                    this._SendCertificate = true;
                    break;
                }

                case THandshakeType.ServerHelloDone:
                {
                    this._HandshakeInfo.UpdateHandshakeHash(data);
                    var keyExchangeAlgorithm = CipherSuites.GetKeyExchangeAlgorithm(this._HandshakeInfo.CipherSuite);

                    // No ServerKeyExchange produced a cipher: derive it here for PSK or RSA.
                    if (this._Cipher == null)
                    {
                        if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.PSK)
                        {
                            var clientKeyExchange = new PSKClientKeyExchange();
                            this._PSKIdentity = this.PSKIdentities.GetRandom();
                            var otherSecret = new byte[this._PSKIdentity.Key.Length];
                            clientKeyExchange.PSKIdentity = this._PSKIdentity.Identity;
                            this._ClientKeyExchange       = clientKeyExchange;
                            var preMasterSecret = TLSUtils.GetPSKPreMasterSecret(otherSecret, this._PSKIdentity.Key);
                            this._Cipher = TLSUtils.AssignCipher(preMasterSecret, true, this._Version, this._HandshakeInfo);
                        }
                        else if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.RSA)
                        {
                            // The pre-master secret is encrypted with the server certificate received earlier.
                            var clientKeyExchange = new RSAClientKeyExchange();
                            this._ClientKeyExchange = clientKeyExchange;
                            var rsaPreMasterSecret = TLSUtils.GetRsaPreMasterSecret(this._Version);
                            clientKeyExchange.PremasterSecret = TLSUtils.GetEncryptedRsaPreMasterSecret(this.ServerCertificate, rsaPreMasterSecret);
                            this._Cipher = TLSUtils.AssignCipher(rsaPreMasterSecret, true, this._Version, this._HandshakeInfo);
                        }
                        else
                        {
                            throw new NotImplementedException($"Key Exchange Algorithm {keyExchangeAlgorithm} Not Implemented");
                        }
                    }

                    // Client flight: [Certificate], ClientKeyExchange, [CertificateVerify], ChangeCipherSpec, Finished.
                    if (this._SendCertificate)
                    {
                        await this._SendHandshakeMessageAsync(this._Certificate, false).ConfigureAwait(false);
                    }

                    await this._SendHandshakeMessageAsync(this._ClientKeyExchange, false).ConfigureAwait(false);

                    if (this._SendCertificate)
                    {
                        // ECDSA/SHA-256 by default; RSA key exchange signs with RSA/SHA-1.
                        var signatureHashAlgorithm = new SignatureHashAlgorithm()
                        {
                            Signature = TSignatureAlgorithm.ECDSA, Hash = THashAlgorithm.SHA256
                        };
                        if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.RSA)
                        {
                            signatureHashAlgorithm = new SignatureHashAlgorithm()
                            {
                                Signature = TSignatureAlgorithm.RSA, Hash = THashAlgorithm.SHA1
                            };
                        }

                        var certVerify = new CertificateVerify
                        {
                            SignatureHashAlgorithm = signatureHashAlgorithm,
                            Signature = TLSUtils.Sign(this._PrivateKey, this._PrivateKeyRsa, true, this._Version, this._HandshakeInfo, signatureHashAlgorithm, this._HandshakeInfo.GetHash(this._Version))
                        };

                        await this._SendHandshakeMessageAsync(certVerify, false).ConfigureAwait(false);
                    }

                    await this._SendChangeCipherSpecAsync().ConfigureAwait(false);

                    var handshakeHash = this._HandshakeInfo.GetHash(this._Version);
                    var finished      = new Finished
                    {
                        VerifyData = TLSUtils.GetVerifyData(this._Version, this._HandshakeInfo, true, true, handshakeHash)
                    };

                    // Presumably the 'true' flag sends Finished under the new cipher; confirm in _SendHandshakeMessageAsync.
                    await this._SendHandshakeMessageAsync(finished, true).ConfigureAwait(false);

                    break;
                }

                case THandshakeType.NewSessionTicket:
                {
                    this._HandshakeInfo.UpdateHandshakeHash(data);
                    break;
                }

                case THandshakeType.CertificateVerify:
                {
                    break;
                }

                case THandshakeType.ClientKeyExchange:
                {
                    break;
                }

                case THandshakeType.Finished:
                {
                    // The handshake completes only when the server's verify data matches ours.
                    // NOTE(review): a mismatch is silently ignored; no alert is sent.
                    var serverFinished       = Finished.Deserialise(stream);
                    var handshakeHash        = this._HandshakeInfo.GetHash(this._Version);
                    var calculatedVerifyData = TLSUtils.GetVerifyData(this._Version, this._HandshakeInfo, true, false, handshakeHash);
                    if (serverFinished.VerifyData.SequenceEqual(calculatedVerifyData))
                    {
                        this._ConnectionComplete = true;
                    }
                    break;
                }

                default:
                {
                    break;
                }
                }
            }
        }
예제 #20
0
        /// <summary>
        /// Encrypts <paramref name="data"/> into a single ApplicationData record for the current epoch and
        /// queues it to the server endpoint.
        /// </summary>
        /// <remarks>
        /// Nothing is transmitted when no socket is available. Every exception (for example a missing cipher)
        /// is swallowed; DEBUG builds write it to the console.
        /// </remarks>
        /// <param name="data">Plaintext payload to send.</param>
        public void Send(byte[] data)
        {
            try
            {
                var record = new DTLSRecord
                {
                    RecordType     = TRecordType.ApplicationData,
                    Epoch          = _Epoch,
                    SequenceNumber = NextSequenceNumber()
                };
                if (_Version != null)
                {
                    record.Version = _Version;
                }

                // Epoch in the upper bits, 48-bit record sequence number below it.
                long fullSequence = ((long)record.Epoch << 48) + record.SequenceNumber;
                record.Fragment = _Cipher.EncodePlaintext(fullSequence, (byte)TRecordType.ApplicationData, data, 0, data.Length);

                int packetLength = DTLSRecord.RECORD_OVERHEAD + record.Fragment.Length;
                var packet = new byte[packetLength];
                using (var writer = new MemoryStream(packet))
                {
                    record.Serialise(writer);
                }

                var sendArgs = new SocketAsyncEventArgs { RemoteEndPoint = _ServerEndPoint };
                sendArgs.SetBuffer(packet, 0, packetLength);
                _Socket?.SendToAsync(sendArgs);
            }
#if DEBUG
            catch (Exception ex)
            {
                Console.WriteLine(ex.ToString());
#else
            catch
            {
#endif

            }
        }
예제 #21
0
        /// <summary>
        /// Client-side dispatcher for one received DTLS record. It tracks the server epoch on
        /// ChangeCipherSpec, reacts to alerts, hands handshake records to <c>ProcessHandshake</c>, and
        /// decrypts application data for <c>DataReceived</c> subscribers.
        /// </summary>
        /// <param name="record">The received record.</param>
        /// <remarks>All exceptions are swallowed (written to the console in DEBUG builds).</remarks>
        private void ProcessRecord(DTLSRecord record)
        {
            try
            {
#if DEBUG
                Console.WriteLine(record.RecordType.ToString());
#endif
                switch (record.RecordType)
                {
                    case TRecordType.ChangeCipherSpec:
                        // Only acted on once ServerHello has recorded a server epoch. Records from the new
                        // epoch are treated as encrypted and numbered from zero again.
                        if (_ServerEpoch.HasValue)
                        {
                            _ServerEpoch++;
                            _ServerSequenceNumber = 0;
                            _EncyptedServerEpoch = _ServerEpoch;
                        }
                        break;
                    case TRecordType.Alert:
                        AlertRecord alertRecord;
                        try
                        {
                            if ((_Cipher == null) || (!_EncyptedServerEpoch.HasValue))
                            {
                                alertRecord = AlertRecord.Deserialise(record.Fragment);
                            }
                            else
                            {
                                long sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;
                                byte[] data = _Cipher.DecodeCiphertext(sequenceNumber, (byte)TRecordType.Alert, record.Fragment, 0, record.Fragment.Length);
                                alertRecord = AlertRecord.Deserialise(data);
                            }
                        }
                        catch
                        {
                            // An alert that cannot be decrypted or parsed is treated as fatal.
                            alertRecord = new AlertRecord();
                            alertRecord.AlertLevel = TAlertLevel.Fatal;
                        }
                        if (alertRecord.AlertLevel == TAlertLevel.Fatal)
                        {
                            // Wake anyone waiting on _Connected; presumably the connection attempt is over.
                            _Connected.Set();
                        }
                        else if ((alertRecord.AlertLevel == TAlertLevel.Warning) || (alertRecord.AlertDescription == TAlertDescription.CloseNotify))
                        {
                            // Answer close_notify with our own close_notify; other warnings are ignored.
                            if (alertRecord.AlertDescription == TAlertDescription.CloseNotify)
                            {
                                SendAlert(TAlertLevel.Warning, TAlertDescription.CloseNotify);
                                _Connected.Set();
                            }
                        }
                        break;
                    case TRecordType.Handshake:
                        ProcessHandshake(record);
                        _ServerSequenceNumber = record.SequenceNumber + 1;
                        break;
                    case TRecordType.ApplicationData:
                        if (_Cipher != null)
                        {
                            long sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;
                            byte[] data = _Cipher.DecodeCiphertext(sequenceNumber, (byte)TRecordType.ApplicationData, record.Fragment, 0, record.Fragment.Length);
                            // Null-conditional invoke reads the delegate once, so a concurrent unsubscribe
                            // cannot turn the call into a NullReferenceException.
                            DataReceived?.Invoke(record.RemoteEndPoint, data);
                        }
                        _ServerSequenceNumber = record.SequenceNumber + 1;
                        break;
                    default:
                        break;
                }
            }
#if DEBUG
            catch (Exception ex)
            {
                Console.WriteLine(ex.ToString());
#else
            catch
            {
#endif
            }
        }
예제 #22
0
        /// <summary>
        /// Client-side handler for one DTLS Handshake record. It decrypts the record when it belongs to the
        /// encrypted server epoch, deserialises the handshake header, and dispatches on the message type
        /// to drive the client side of the handshake.
        /// </summary>
        /// <param name="record">The received handshake record.</param>
        /// <exception cref="TimeoutException">
        /// The record belongs to the encrypted server epoch but no cipher became available within
        /// roughly 5 seconds (500 polls of 10 ms).
        /// </exception>
        private void ProcessHandshake(DTLSRecord record)
        {
            byte[] data = record.Fragment;
            if (_EncyptedServerEpoch.HasValue && (_EncyptedServerEpoch.Value == record.Epoch))
            {
                // The cipher is created while handling ServerKeyExchange / ServerHelloDone. Poll briefly in case
                // that has not finished yet (presumably it runs on another thread — confirm).
                int count = 0;
                while ((_Cipher == null) && (count < 500))
                {
                    System.Threading.Thread.Sleep(10);
                    count++;
                }

                if (_Cipher == null)
                    throw new TimeoutException("No cipher is available to decrypt an encrypted handshake record.");

                long sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;
                data = _Cipher.DecodeCiphertext(sequenceNumber, (byte)TRecordType.Handshake, record.Fragment, 0, record.Fragment.Length);
            }

            using (MemoryStream stream = new MemoryStream(data))
            {
                HandshakeRecord handshakeRecord = HandshakeRecord.Deserialise(stream);
                if (handshakeRecord != null)
                {
#if DEBUG
                    Console.WriteLine(handshakeRecord.MessageType.ToString());
#endif
                    // Locals declared in one case section (keyExchangeAlgorithm, preMasterSecret, handshakeHash)
                    // are reassigned in later sections: C# scopes switch-section locals to the whole switch block.
                    switch (handshakeRecord.MessageType)
                    {
                        case THandshakeType.HelloRequest:
                            break;
                        case THandshakeType.ClientHello:
                            break;
                        case THandshakeType.ServerHello:
                            // Record the server epoch, cipher suite and random. Use the server's version
                            // if it compares below our SupportedVersion.
                            ServerHello serverHello = ServerHello.Deserialise(stream);
                            if (serverHello != null)
                            {
                                _ServerEpoch = record.Epoch;
                                _HandshakeInfo.UpdateHandshakeHash(data);
                                _HandshakeInfo.CipherSuite = (TCipherSuite)serverHello.CipherSuite;
                                _HandshakeInfo.ServerRandom = serverHello.Random;
                                Version version = SupportedVersion;
                                if (serverHello.ServerVersion < version)
                                    version = serverHello.ServerVersion;
                                _Version = version;
                            }
                            break;
                        case THandshakeType.HelloVerifyRequest:
                            // The server asked for a cookie: adopt its version and call SendHello with the cookie.
                            HelloVerifyRequest helloVerifyRequest = HelloVerifyRequest.Deserialise(stream);
                            if (helloVerifyRequest != null)
                            {
                                _Version = helloVerifyRequest.ServerVersion;
                                SendHello(helloVerifyRequest.Cookie);
                            }
                            break;
                        case THandshakeType.Certificate:
                            // NOTE(review): the server certificate is only added to the handshake hash; it is
                            // not parsed or validated in this method.
                            _HandshakeInfo.UpdateHandshakeHash(data);
                            break;
                        case THandshakeType.ServerKeyExchange:
                            // Derive the pre-master secret for the negotiated key exchange, prepare the
                            // ClientKeyExchange to send after ServerHelloDone, and create the cipher.
                            _HandshakeInfo.UpdateHandshakeHash(data);
                            TKeyExchangeAlgorithm keyExchangeAlgorithm = CipherSuites.GetKeyExchangeAlgorithm(_HandshakeInfo.CipherSuite);
                            byte[] preMasterSecret = null;
                            if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
                            {
                                ECDHEServerKeyExchange serverKeyExchange = ECDHEServerKeyExchange.Deserialise(stream, _Version);
                                ECDHEKeyExchange keyExchangeECDHE = new ECDHEKeyExchange();
                                keyExchangeECDHE.CipherSuite = _HandshakeInfo.CipherSuite;
                                keyExchangeECDHE.Curve = serverKeyExchange.EllipticCurve;
                                keyExchangeECDHE.KeyExchangeAlgorithm = keyExchangeAlgorithm;
                                keyExchangeECDHE.ClientRandom = _HandshakeInfo.ClientRandom;
                                keyExchangeECDHE.ServerRandom = _HandshakeInfo.ServerRandom;
                                keyExchangeECDHE.GenerateEphemeralKey();
                                ECDHEClientKeyExchange clientKeyExchange = new ECDHEClientKeyExchange(keyExchangeECDHE.PublicKey);
                                _ClientKeyExchange = clientKeyExchange;
                                preMasterSecret = keyExchangeECDHE.GetPreMasterSecret(serverKeyExchange.PublicKeyBytes);
                            }
                            else if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_PSK)
                            {
                                ECDHEPSKServerKeyExchange serverKeyExchange = ECDHEPSKServerKeyExchange.Deserialise(stream, _Version);
                                ECDHEKeyExchange keyExchangeECDHE = new ECDHEKeyExchange();
                                keyExchangeECDHE.CipherSuite = _HandshakeInfo.CipherSuite;
                                keyExchangeECDHE.Curve = serverKeyExchange.EllipticCurve;
                                keyExchangeECDHE.KeyExchangeAlgorithm = keyExchangeAlgorithm;
                                keyExchangeECDHE.ClientRandom = _HandshakeInfo.ClientRandom;
                                keyExchangeECDHE.ServerRandom = _HandshakeInfo.ServerRandom;
                                keyExchangeECDHE.GenerateEphemeralKey();
                                ECDHEPSKClientKeyExchange clientKeyExchange = new ECDHEPSKClientKeyExchange(keyExchangeECDHE.PublicKey);
                                // Prefer the identity matching the server's hint. Otherwise keep any identity
                                // already chosen, or fall back to a random configured one.
                                if (serverKeyExchange.PSKIdentityHint != null)
                                {
                                    byte[] key = _PSKIdentities.GetKey(serverKeyExchange.PSKIdentityHint);
                                    if (key != null)
                                        _PSKIdentity = new PSKIdentity() { Identity = serverKeyExchange.PSKIdentityHint, Key = key };
                                }
                                if (_PSKIdentity == null)
                                    _PSKIdentity = _PSKIdentities.GetRandom();
                                clientKeyExchange.PSKIdentity = _PSKIdentity.Identity;
                                _ClientKeyExchange = clientKeyExchange;
                                byte[] otherSecret = keyExchangeECDHE.GetPreMasterSecret(serverKeyExchange.PublicKeyBytes);
                                preMasterSecret = TLSUtils.GetPSKPreMasterSecret(otherSecret, _PSKIdentity.Key);
                            }
                            else if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.PSK)
                            {
                                PSKServerKeyExchange serverKeyExchange = PSKServerKeyExchange.Deserialise(stream, _Version);
                                PSKClientKeyExchange clientKeyExchange = new PSKClientKeyExchange();
                                if (serverKeyExchange.PSKIdentityHint != null)
                                {
                                    byte[] key = _PSKIdentities.GetKey(serverKeyExchange.PSKIdentityHint);
                                    if (key != null)
                                        _PSKIdentity = new PSKIdentity() { Identity = serverKeyExchange.PSKIdentityHint, Key = key };
                                }
                                if (_PSKIdentity == null)
                                    _PSKIdentity = _PSKIdentities.GetRandom();
                                // Plain PSK: "other secret" is an all-zero array as long as the key.
                                byte[] otherSecret = new byte[_PSKIdentity.Key.Length];
                                clientKeyExchange.PSKIdentity = _PSKIdentity.Identity;
                                _ClientKeyExchange = clientKeyExchange;
                                preMasterSecret = TLSUtils.GetPSKPreMasterSecret(otherSecret, _PSKIdentity.Key);
                            }
                            // NOTE(review): for any other key-exchange algorithm preMasterSecret is still null here.
                            _Cipher = TLSUtils.AssignCipher(preMasterSecret, true, _Version, _HandshakeInfo);

                            break;
                        case THandshakeType.CertificateRequest:
                            _HandshakeInfo.UpdateHandshakeHash(data);
                            _SendCertificate = true;
                            break;
                        case THandshakeType.ServerHelloDone:
                            _HandshakeInfo.UpdateHandshakeHash(data);
                            // No cipher yet means no ServerKeyExchange was processed. For a plain PSK suite,
                            // derive the cipher here using a random configured identity.
                            if (_Cipher == null)
                            {
                                keyExchangeAlgorithm = CipherSuites.GetKeyExchangeAlgorithm(_HandshakeInfo.CipherSuite);
                                if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.PSK)
                                {
                                    PSKClientKeyExchange clientKeyExchange = new PSKClientKeyExchange();
                                    _PSKIdentity = _PSKIdentities.GetRandom();
                                    byte[] otherSecret = new byte[_PSKIdentity.Key.Length];
                                    clientKeyExchange.PSKIdentity = _PSKIdentity.Identity;
                                    _ClientKeyExchange = clientKeyExchange;
                                    preMasterSecret = TLSUtils.GetPSKPreMasterSecret(otherSecret, _PSKIdentity.Key);
                                    _Cipher = TLSUtils.AssignCipher(preMasterSecret, true, _Version, _HandshakeInfo);
                                }
                            }

                            // Client flight: Certificate (if requested), ClientKeyExchange, CertificateVerify
                            // (if requested), ChangeCipherSpec, Finished. Only Finished passes `true` to
                            // SendHandshakeMessage (presumably the encrypt flag).
                            if (_SendCertificate)
                            {
                                SendHandshakeMessage(_Certificate, false);
                            }
                            SendHandshakeMessage(_ClientKeyExchange, false);
                            if (_SendCertificate)
                            {
                                // NOTE(review): signature algorithm is hard-coded to ECDSA with SHA-256.
                                CertificateVerify certificateVerify = new CertificateVerify();
                                byte[] signatureHash = _HandshakeInfo.GetHash();
                                certificateVerify.SignatureHashAlgorithm = new SignatureHashAlgorithm() { Signature = TSignatureAlgorithm.ECDSA, Hash = THashAlgorithm.SHA256 };
                                certificateVerify.Signature = TLSUtils.Sign(_PrivateKey, true, _Version, _HandshakeInfo, certificateVerify.SignatureHashAlgorithm, signatureHash);
                                SendHandshakeMessage(certificateVerify, false);
                            }
                            SendChangeCipherSpec();
                            byte[] handshakeHash = _HandshakeInfo.GetHash();
                            Finished finished = new Finished();
                            finished.VerifyData = TLSUtils.GetVerifyData(_Version,_HandshakeInfo,true, true, handshakeHash);
                            SendHandshakeMessage(finished, true);
#if DEBUG
                            Console.Write("Handshake Hash:");
                            TLSUtils.WriteToConsole(handshakeHash);
                            Console.Write("Sent Verify:");
                            TLSUtils.WriteToConsole(finished.VerifyData);
#endif
                            break;
                        case THandshakeType.CertificateVerify:
                            break;
                        case THandshakeType.ClientKeyExchange:
                            break;
                        case THandshakeType.Finished:
                            // Compare the server's verify data with our own computation. Only a match
                            // signals _Connected.
                            // NOTE(review): a mismatch is silently ignored (no alert is sent).
                            Finished serverFinished = Finished.Deserialise(stream);
                            handshakeHash = _HandshakeInfo.GetHash();
                            byte[] calculatedVerifyData = TLSUtils.GetVerifyData(_Version,_HandshakeInfo, true, false, handshakeHash);
#if DEBUG
                            Console.Write("Recieved Verify:");
                            TLSUtils.WriteToConsole(serverFinished.VerifyData);
                            Console.Write("Calc Verify:");
                            TLSUtils.WriteToConsole(calculatedVerifyData);
#endif
                            if (TLSUtils.ByteArrayCompare(serverFinished.VerifyData, calculatedVerifyData))
                            {
#if DEBUG
                                Console.WriteLine("Handshake Complete");
#endif
                                _Connected.Set();
                            }
                            break;
                        default:
                            break;
                    }
                }
            }
        }
예제 #23
0
        /// <summary>
        /// Completion handler for <c>ReceiveFromAsync</c>. It splits each received datagram into DTLS records,
        /// routes each record to the thread pool (no session yet) or to <c>CheckSession</c>, and then
        /// re-arms the receive on the same <see cref="SocketAsyncEventArgs"/>.
        /// </summary>
        /// <param name="sender">The receiving socket; if it is not a <see cref="Socket"/>, the receive is not re-armed.</param>
        /// <param name="e">The completed receive operation.</param>
        /// <remarks>
        /// A zero-byte completion ends the receive loop (it is not re-armed). When <c>ReceiveFromAsync</c>
        /// completes synchronously, the next datagram is handled by looping rather than by recursing, so a
        /// burst of already-queued datagrams cannot exhaust the stack.
        /// </remarks>
        private void ReceiveCallback(object sender, SocketAsyncEventArgs e)
        {
            while (true)
            {
                int count = e.BytesTransferred;
                if (count == 0)
                {
#if DEBUG
                    Console.WriteLine($"ReceiveCallback got 0 bytes");
#endif
                    return;
                }
#if DEBUG
                Console.WriteLine($"ReceiveCallback got {count} bytes");
#endif

                // Copy out of the shared receive buffer before it is reused by the next receive.
                byte[] data = new byte[count];
                Buffer.BlockCopy(e.Buffer, 0, data, 0, count);
                using (MemoryStream stream = new MemoryStream(data))
                {
                    while (stream.Position < stream.Length)
                    {
                        long start = stream.Position;
                        DTLSRecord record = DTLSRecord.Deserialise(stream);
                        if (record != null)
                        {
                            record.RemoteEndPoint = e.RemoteEndPoint;
                            SocketAddress address = record.RemoteEndPoint.Serialize();
                            Session       session = _Sessions.GetSession(address);
                            if (session == null)
                            {
#if DEBUG
                                Console.WriteLine($"session was null");
#endif
                                ThreadPool.QueueUserWorkItem(ProcessRecord, record);
                            }
                            else
                            {
                                CheckSession(session, record);
                            }
                        }
                        if (stream.Position <= start)
                        {
                            // Deserialise consumed nothing; retrying would spin forever on the same bytes.
                            break;
                        }
                    }
                }

                if (!(sender is Socket socket))
                {
                    return;
                }

                // Reset the endpoint to the wildcard address of the socket's family before receiving again.
                if (socket.AddressFamily == AddressFamily.InterNetwork)
                {
                    e.RemoteEndPoint = new IPEndPoint(IPAddress.Any, 0);
                }
                else
                {
                    e.RemoteEndPoint = new IPEndPoint(IPAddress.IPv6Any, 0);
                }
                e.SetBuffer(0, 4096);
                if (socket.ReceiveFromAsync(e))
                {
                    // Pending: the Completed event will invoke this handler again.
                    return;
                }
                // Completed synchronously: Completed will not fire, so handle the new datagram on the next iteration.
            }
        }
예제 #24
0
        /// <summary>
        /// Server-side dispatcher for one DTLS record from <paramref name="address"/>. Handshake records always
        /// go to the handshake processor; every other record type is only acted on when
        /// <paramref name="session"/> is non-null.
        /// </summary>
        /// <param name="address">Serialised remote address; used as the session key.</param>
        /// <param name="session">The peer's session, or <c>null</c> if none exists yet.</param>
        /// <param name="record">The record to process.</param>
        /// <remarks>
        /// Any exception is answered with a fatal internal_error alert to the peer (and logged in DEBUG builds).
        /// </remarks>
        private void ProcessRecord(SocketAddress address, Session session, DTLSRecord record)
        {
            try
            {
#if DEBUG
                Console.WriteLine($"> RecordType={record.RecordType.ToString()} sessionID={session?.SessionID} remoteEndPoint={record?.RemoteEndPoint}");
                Console.Write($"Data: {TLSUtils.WriteToString(record?.Fragment)}");
#endif
                TRecordType kind = record.RecordType;
                if (kind == TRecordType.Handshake)
                {
                    _Handshake.ProcessHandshake(record);
                }
                else if (session != null)
                {
                    if (kind == TRecordType.ChangeCipherSpec)
                    {
                        // The peer moves to its next epoch; its sequence numbers restart from zero.
                        session.ClientEpoch++;
                        session.ClientSequenceNumber = 0;
                        session.SetEncyptChange(record);
                    }
                    else if (kind == TRecordType.Alert)
                    {
                        AlertRecord alert;
                        try
                        {
                            if (session.Cipher != null)
                            {
                                long recordSequence = ((long)record.Epoch << 48) + record.SequenceNumber;
                                byte[] plainAlert = session.Cipher.DecodeCiphertext(recordSequence, (byte)TRecordType.Alert, record.Fragment, 0, record.Fragment.Length);
                                alert = AlertRecord.Deserialise(plainAlert);
                            }
                            else
                            {
                                alert = AlertRecord.Deserialise(record.Fragment);
                            }
                        }
                        catch
                        {
                            // An alert that cannot be decrypted or parsed counts as fatal.
                            alert = new AlertRecord { AlertLevel = TAlertLevel.Fatal };
                        }

                        bool fatal = alert.AlertLevel == TAlertLevel.Fatal;
                        bool closeNotify = !fatal && alert.AlertDescription == TAlertDescription.CloseNotify;
                        if (closeNotify)
                        {
                            // Echo close_notify back before tearing the session down.
                            SendAlert(session, address, TAlertLevel.Warning, TAlertDescription.CloseNotify);
                        }
                        if (fatal || closeNotify || alert.AlertLevel == TAlertLevel.Warning)
                        {
                            _Sessions.Remove(session, address);
                        }
                    }
                    else if (kind == TRecordType.ApplicationData && session.Cipher != null)
                    {
                        long recordSequence = ((long)record.Epoch << 48) + record.SequenceNumber;
                        byte[] plaintext = session.Cipher.DecodeCiphertext(recordSequence, (byte)TRecordType.ApplicationData, record.Fragment, 0, record.Fragment.Length);
                        DataReceived?.Invoke(record.RemoteEndPoint, plaintext);
                    }
                }
            }
#if DEBUG
            catch (Exception ex)
            {
                Console.WriteLine(ex.ToString());
                StackFrame callStack = new StackFrame(1, true);
                Console.WriteLine($"Exception! Type: { ex.GetType()}\n\tData: { ex.Data.Count}\n\tMessage: { ex.Message}\n\tSource: { ex.Source}\n\t" +
                                  $"StackTrace: { ex.StackTrace}\n\tFile: {callStack.GetFileName()}\n\t" +
                                  $"Line: {callStack.GetFileLineNumber()}");
#else
            catch
            {
#endif
                SendAlert(session, address, TAlertLevel.Fatal, TAlertDescription.InternalError);
            }
        }
예제 #25
0
        /// <summary>
        /// Asynchronously handles one received DTLS record on the client side.
        /// </summary>
        /// <param name="record">The record to process. A <c>null</c> record is reported on the console and ignored.</param>
        /// <returns>A task that completes once the record has been handled.</returns>
        /// <remarks>
        /// ChangeCipherSpec, Alert and Handshake records reset <c>_ReceivedData</c> to an empty array.
        /// ApplicationData records replace it with the decrypted payload when a cipher exists. Handshake and
        /// ApplicationData records advance <c>_ServerSequenceNumber</c>. Every exception, including the
        /// null-argument one, is caught and written to the console.
        /// </remarks>
        private async Task _ProcessRecordAsync(DTLSRecord record)
        {
            try
            {
                if (record == null)
                {
                    throw new ArgumentNullException(nameof(record));
                }

                var recordType = record.RecordType;
                if (recordType == TRecordType.ApplicationData)
                {
                    if (this._Cipher != null)
                    {
                        var sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;
                        var plaintext      = this._Cipher.DecodeCiphertext(sequenceNumber, (byte)TRecordType.ApplicationData, record.Fragment, 0, record.Fragment.Length);
                        this._DataReceivedFunction?.Invoke(record.RemoteEndPoint, plaintext);
                        this._ReceivedData = plaintext;
                    }
                    this._ServerSequenceNumber = record.SequenceNumber + 1;
                }
                else if (recordType == TRecordType.ChangeCipherSpec)
                {
                    this._ReceivedData = new byte[0];
                    // Only meaningful once ServerHello has fixed the server epoch.
                    if (this._ServerEpoch.HasValue)
                    {
                        this._ServerEpoch++;
                        this._ServerSequenceNumber = 0;
                        this._EncyptedServerEpoch  = this._ServerEpoch;
                    }
                }
                else if (recordType == TRecordType.Alert)
                {
                    this._ReceivedData = new byte[0];
                    AlertRecord alertRecord;
                    try
                    {
                        if (this._Cipher != null && this._EncyptedServerEpoch.HasValue)
                        {
                            var sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;
                            alertRecord = AlertRecord.Deserialise(this._Cipher.DecodeCiphertext(sequenceNumber, (byte)TRecordType.Alert, record.Fragment, 0, record.Fragment.Length));
                        }
                        else
                        {
                            alertRecord = AlertRecord.Deserialise(record.Fragment);
                        }
                    }
                    catch
                    {
                        // Undecodable alerts are treated as fatal.
                        alertRecord = new AlertRecord
                        {
                            AlertLevel = TAlertLevel.Fatal
                        };
                    }

                    if (alertRecord.AlertLevel == TAlertLevel.Fatal)
                    {
                        this._ConnectionComplete = true;
                    }
                    else if (alertRecord.AlertDescription == TAlertDescription.CloseNotify)
                    {
                        // Answer close_notify in kind; other non-fatal alerts are ignored.
                        await this._SendAlertAsync(TAlertLevel.Warning, TAlertDescription.CloseNotify).ConfigureAwait(false);

                        this._ConnectionComplete = true;
                    }
                }
                else if (recordType == TRecordType.Handshake)
                {
                    this._ReceivedData = new byte[0];
                    await this._ProcessHandshakeAsync(record).ConfigureAwait(false);

                    this._ServerSequenceNumber = record.SequenceNumber + 1;
                }
            }
            catch (Exception ex)
            {
                Console.WriteLine(ex.ToString());
            }
        }
예제 #26
0
        /// <summary>
        /// Sequencing wrapper around <c>ProcessRecord(SocketAddress, Session, DTLSRecord)</c>. Its signature
        /// matches <c>WaitCallback</c>, so it can be queued on the thread pool. <paramref name="state"/> is
        /// expected to be a <c>DTLSRecord</c>; anything else is ignored.
        /// </summary>
        /// <param name="state">The record to process, boxed as <c>object</c>.</param>
        /// <remarks>
        /// A BouncyCastle <c>TlsFatalAlert</c> is answered with a fatal alert carrying its description.
        /// Any other exception is answered with a fatal internal_error alert.
        /// </remarks>
        private void ProcessRecord(object state)
        {
            DTLSRecord record = state as DTLSRecord;

            if (record != null)
            {
                SocketAddress address = null;
                Session       session = null;
                try
                {
                    address = record.RemoteEndPoint.Serialize();
                    session = _Sessions.GetSession(address);
                    if (session == null)
                    {
                        // No session yet: process the record directly, then, if processing created a
                        // session, count this record against its expected sequence number.
                        ProcessRecord(address, session, record);
                        session = _Sessions.GetSession(address);
                        if (session != null)
                        {
                            lock (session)
                            {
                                // ChangeCipherSpec is not counted (presumably because its handler resets the
                                // sequence number for the new epoch — confirm).
                                if (record.RecordType != TRecordType.ChangeCipherSpec)
                                {
                                    session.ClientSequenceNumber++;
                                }
                            }
                        }
                    }
                    else
                    {
                        // Accept the exact next expected record, or anything older (an earlier epoch, or an
                        // earlier sequence number in the current epoch).
                        // NOTE(review): records from the future are neither processed nor buffered here;
                        // confirm the caller queues them in session.Records.
                        // NOTE(review): ClientEpoch/ClientSequenceNumber are read here without the lock
                        // that guards their increment below.
                        bool processRecord = false;
                        if ((session.ClientEpoch == record.Epoch) && (session.ClientSequenceNumber == record.SequenceNumber))
                        {
                            processRecord = true;
                        }
                        else if (session.ClientEpoch > record.Epoch)
                        {
                            processRecord = true;
                        }
                        else if ((session.ClientEpoch == record.Epoch) && (session.ClientSequenceNumber > record.SequenceNumber))
                        {
                            processRecord = true;
                        }
                        if (processRecord)
                        {
                            // Process this record, then drain buffered records for as long as the head of
                            // session.Records is exactly the next expected (epoch, sequence number).
                            // NOTE(review): an older, replayed record also increments ClientSequenceNumber
                            // below — confirm this is intended.
                            do
                            {
                                ProcessRecord(address, session, record);
                                lock (session)
                                {
                                    if (record.RecordType != TRecordType.ChangeCipherSpec)
                                    {
                                        session.ClientSequenceNumber++;
                                    }
                                }
                                record = session.Records.PeekRecord();
                                if (record != null)
                                {
                                    if ((session.ClientSequenceNumber == record.SequenceNumber) && (session.ClientEpoch == record.Epoch))
                                    {
                                        session.Records.RemoveRecord();
                                    }
                                    else
                                    {
                                        record = null;
                                    }
                                }
                            } while (record != null);
                        }
                    }
                }
                catch (Org.BouncyCastle.Crypto.Tls.TlsFatalAlert ex)
                {
                    SendAlert(session, address, TAlertLevel.Fatal, (TAlertDescription)ex.AlertDescription);
                }
                catch
                {
                    SendAlert(session, address, TAlertLevel.Fatal, TAlertDescription.InternalError);
                }
            }
        }
예제 #27
0
        /// <summary>
        /// Serialises a handshake message into one or more complete DTLS records, one per datagram.
        /// </summary>
        /// <param name="handshakeMessage">The handshake message to encode. Must not be null.</param>
        /// <param name="encrypt">
        /// When true and <c>_Cipher</c> is non-null, each record's fragment is encrypted with it.
        /// </param>
        /// <returns>
        /// The serialised records. A message whose body does not fit in one packet is split into several
        /// handshake fragments that share the same message sequence number and total length.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="handshakeMessage"/> is null.</exception>
        /// <exception cref="InvalidOperationException">
        /// <c>_MaxPacketSize</c> cannot hold the record header, the handshake header and at least one body byte.
        /// </exception>
        private IEnumerable <byte[]> _GetBytes(IHandshakeMessage handshakeMessage, bool encrypt)
        {
            if (handshakeMessage == null)
            {
                throw new ArgumentNullException(nameof(handshakeMessage));
            }

            var size = handshakeMessage.CalculateSize(this._Version);

            // Each datagram holds a DTLS record header AND a handshake header in front of the body bytes,
            // so both overheads are subtracted from the packet budget; otherwise datagrams exceed _MaxPacketSize.
            // NOTE(review): cipher expansion (IV/MAC/padding) of encrypted records is not accounted for here.
            var maxPayloadSize = _MaxPacketSize - DTLSRecord.RECORD_OVERHEAD - HandshakeRecord.RECORD_OVERHEAD;
            if (maxPayloadSize <= 0)
            {
                throw new InvalidOperationException("Maximum packet size is too small to carry a DTLS handshake fragment.");
            }

            // HelloVerifyRequest and a ClientHello without a cookie are kept out of the handshake hash.
            var clientHello = handshakeMessage as ClientHello;
            var updateHash  = !(handshakeMessage.MessageType == THandshakeType.HelloVerifyRequest ||
                                (handshakeMessage.MessageType == THandshakeType.ClientHello && clientHello != null && clientHello.Cookie == null));

            // Serialise the body once; every fragment (and the hash input) is a slice of this buffer.
            var body = new byte[size];
            using (var stream = new MemoryStream(body))
            {
                handshakeMessage.Serialise(stream, this._Version);
            }

            var messageSequence = this._MessageSequence;

            // Builds "handshake header + body[start, start + count)" for one fragment of this message.
            Func<int, int, byte[]> buildFragment = (start, count) =>
            {
                var handshakeRecord = new HandshakeRecord
                {
                    MessageType    = handshakeMessage.MessageType,
                    MessageSeq     = messageSequence,
                    Length         = (uint)size,
                    FragmentOffset = (uint)start,
                    FragmentLength = (uint)count
                };

                var fragment = new byte[HandshakeRecord.RECORD_OVERHEAD + count];
                using (var stream = new MemoryStream(fragment))
                {
                    handshakeRecord.Serialise(stream);
                }
                Buffer.BlockCopy(body, start, fragment, HandshakeRecord.RECORD_OVERHEAD, count);
                return fragment;
            };

            if (updateHash)
            {
                // The hash input is the message in unfragmented form (offset 0, fragment length == total length),
                // whether or not it is split on the wire.
                this._HandshakeInfo.UpdateHandshakeHash(buildFragment(0, size));
            }

            var records        = new List<byte[]>();
            var fragmentOffset = 0;

            // do/while so that a message with an empty body still produces exactly one record.
            do
            {
                var fragmentLength = Math.Min(maxPayloadSize, size - fragmentOffset);

                var record = new DTLSRecord
                {
                    RecordType     = TRecordType.Handshake,
                    Epoch          = _Epoch,
                    SequenceNumber = this._NextSequenceNumber(),
                    Fragment       = buildFragment(fragmentOffset, fragmentLength),
                    Version        = this._Version
                };

                if ((this._Cipher != null) && encrypt)
                {
                    // Explicit 64-bit sequence number: epoch in the top 16 bits, record sequence below it.
                    var sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;
                    record.Fragment    = this._Cipher.EncodePlaintext(sequenceNumber, (byte)TRecordType.Handshake, record.Fragment, 0, record.Fragment.Length);
                }

                var response = new byte[DTLSRecord.RECORD_OVERHEAD + record.Fragment.Length];
                using (var stream = new MemoryStream(response))
                {
                    record.Serialise(stream);
                }

                records.Add(response);
                fragmentOffset += fragmentLength;
            }
            while (fragmentOffset < size);

            this._MessageSequence++;
            return records;
        }
예제 #28
0
        /// <summary>
        /// Handles one record received from <paramref name="address"/>, dispatching on its record type.
        /// </summary>
        /// <param name="address">Serialised remote address; passed to session removal and to alert sending.</param>
        /// <param name="session">
        /// The session for this peer, or null if none is known. ChangeCipherSpec, Alert and ApplicationData
        /// records are ignored without a session.
        /// </param>
        /// <param name="record">The received DTLS record.</param>
        /// <remarks>
        /// Exceptions thrown while processing are not propagated. They are answered with a fatal
        /// <c>InternalError</c> alert, and are also written to the console in DEBUG builds.
        /// </remarks>
        private void ProcessRecord(SocketAddress address, Session session, DTLSRecord record)
        {
            try
            {
#if DEBUG
                Console.WriteLine(record.RecordType.ToString());
#endif
                switch (record.RecordType)
                {
                case TRecordType.ChangeCipherSpec:
                    if (session != null)
                    {
                        // The peer moves to its next epoch, and record sequence numbers restart at zero.
                        session.ClientEpoch++;
                        session.ClientSequenceNumber = 0;
                        session.SetEncyptChange(record);
                    }
                    break;

                case TRecordType.Alert:
                    if (session != null)
                    {
                        AlertRecord alertRecord;
                        try
                        {
                            if (session.Cipher == null)
                            {
                                alertRecord = AlertRecord.Deserialise(record.Fragment);
                            }
                            else
                            {
                                long   sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;
                                byte[] data           = session.Cipher.DecodeCiphertext(sequenceNumber, (byte)TRecordType.Alert, record.Fragment, 0, record.Fragment.Length);
                                alertRecord = AlertRecord.Deserialise(data);
                            }
                        }
                        catch
                        {
                            // An alert that cannot be decrypted or parsed is treated as fatal.
                            alertRecord            = new AlertRecord();
                            alertRecord.AlertLevel = TAlertLevel.Fatal;
                        }

                        if (alertRecord.AlertLevel == TAlertLevel.Fatal)
                        {
                            _Sessions.Remove(session, address);
                        }
                        else if ((alertRecord.AlertLevel == TAlertLevel.Warning) || (alertRecord.AlertDescription == TAlertDescription.CloseNotify))
                        {
                            // Answer close_notify with our own close_notify before dropping the session.
                            if (alertRecord.AlertDescription == TAlertDescription.CloseNotify)
                            {
                                SendAlert(session, address, TAlertLevel.Warning, TAlertDescription.CloseNotify);
                            }
                            _Sessions.Remove(session, address);
                        }
                    }
                    break;

                case TRecordType.Handshake:
                    _Handshake.ProcessHandshake(record);
                    break;

                case TRecordType.ApplicationData:
                    // Application data is only accepted once a cipher has been established for the session.
                    if ((session != null) && (session.Cipher != null))
                    {
                        long   sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;
                        byte[] data           = session.Cipher.DecodeCiphertext(sequenceNumber, (byte)TRecordType.ApplicationData, record.Fragment, 0, record.Fragment.Length);

                        // The delegate is read once, so a concurrent unsubscribe cannot null it between the check and the call.
                        DataReceived?.Invoke(record.RemoteEndPoint, data);
                    }
                    break;

                default:
                    break;
                }
            }
#if DEBUG
            catch (Exception ex)
            {
                Console.WriteLine(ex.ToString());
#else
            catch
            {
#endif
                SendAlert(session, address, TAlertLevel.Fatal, TAlertDescription.InternalError);
            }
        }
예제 #29
0
        /// <summary>
        /// Serialises a handshake message and sends it to the session's remote end point as one or more DTLS
        /// records, one per datagram. Messages that do not fit in a single packet are split into handshake
        /// fragments that share <paramref name="messageSequence"/> and the total message length.
        /// </summary>
        /// <param name="session">
        /// The session the message belongs to. It supplies the version, epoch, record sequence numbers,
        /// cipher, handshake hash and remote end point. Must not be null.
        /// </param>
        /// <param name="handshakeMessage">The message to send. Must not be null.</param>
        /// <param name="messageSequence">Handshake message_seq written into every fragment header.</param>
        /// <exception cref="ArgumentNullException"><paramref name="session"/> or <paramref name="handshakeMessage"/> is null.</exception>
        /// <exception cref="InvalidOperationException">
        /// <c>_MaxPacketSize</c> cannot hold the record header, the handshake header and at least one body byte.
        /// </exception>
        private void SendResponse(Session session, IHandshakeMessage handshakeMessage, ushort messageSequence)
        {
            if (session == null)
            {
                throw new ArgumentNullException(nameof(session));
            }

            if (handshakeMessage == null)
            {
                throw new ArgumentNullException(nameof(handshakeMessage));
            }

            var size = handshakeMessage.CalculateSize(session.Version);

            // Each datagram holds a DTLS record header AND a handshake header in front of the body bytes,
            // so both overheads are subtracted from the packet budget.
            // NOTE(review): cipher expansion (IV/MAC/padding) of encrypted records is not accounted for here.
            var maxPayloadSize = this._MaxPacketSize - DTLSRecord.RECORD_OVERHEAD - HandshakeRecord.RECORD_OVERHEAD;
            if (maxPayloadSize <= 0)
            {
                throw new InvalidOperationException("Maximum packet size is too small to carry a DTLS handshake fragment.");
            }

            // Serialise the body once; every fragment (and the hash input) is a slice of this buffer.
            var body = new byte[size];
            using (var stream = new MemoryStream(body))
            {
                handshakeMessage.Serialise(stream, session.Version);
            }

            // Builds "handshake header + body[start, start + count)" for one fragment of this message.
            Func<int, int, byte[]> buildFragment = (start, count) =>
            {
                var handshakeRecord = new HandshakeRecord
                {
                    MessageType    = handshakeMessage.MessageType,
                    MessageSeq     = messageSequence,
                    Length         = (uint)size,
                    FragmentOffset = (uint)start,
                    FragmentLength = (uint)count
                };

                var fragment = new byte[HandshakeRecord.RECORD_OVERHEAD + count];
                using (var stream = new MemoryStream(fragment))
                {
                    handshakeRecord.Serialise(stream);
                }
                Buffer.BlockCopy(body, start, fragment, HandshakeRecord.RECORD_OVERHEAD, count);
                return fragment;
            };

            // HelloVerifyRequest is kept out of the handshake hash. Every other message is hashed in its
            // unfragmented form (offset 0, fragment length == total length), even if it is split on the wire.
            if (handshakeMessage.MessageType != THandshakeType.HelloVerifyRequest)
            {
                session.Handshake.UpdateHandshakeHash(buildFragment(0, size));
            }

            var fragmentOffset = 0;

            // do/while so that a message with an empty body still produces exactly one record.
            do
            {
                var fragmentLength = Math.Min(maxPayloadSize, size - fragmentOffset);

                var record = new DTLSRecord
                {
                    RecordType     = TRecordType.Handshake,
                    Epoch          = session.Epoch,
                    SequenceNumber = session.NextSequenceNumber(),
                    Fragment       = buildFragment(fragmentOffset, fragmentLength)
                };

                if (session.Version != null)
                {
                    record.Version = session.Version;
                }

                if (session.Cipher != null)
                {
                    // Explicit 64-bit sequence number: epoch in the top 16 bits, record sequence below it.
                    var sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;
                    record.Fragment    = session.Cipher.EncodePlaintext(sequenceNumber, (byte)TRecordType.Handshake, record.Fragment, 0, record.Fragment.Length);
                }

                var response = new byte[DTLSRecord.RECORD_OVERHEAD + record.Fragment.Length];
                using (var stream = new MemoryStream(response))
                {
                    record.Serialise(stream);
                }

                var parameters = new SocketAsyncEventArgs()
                {
                    RemoteEndPoint = session.RemoteEndPoint
                };
                parameters.SetBuffer(response, 0, response.Length);
                this._Socket.SendToAsync(parameters);

                fragmentOffset += fragmentLength;
            }
            while (fragmentOffset < size);
        }
예제 #30
0
        public void ProcessHandshake(DTLSRecord record)
        {
#if DEBUG
            Console.WriteLine($"> ProcessHandshake got {record}");
#endif
            SocketAddress address = record.RemoteEndPoint.Serialize();
            Session       session = Sessions.GetSession(address);
            byte[]        data;
            if ((session != null) && session.IsEncypted(record))
            {
                int count = 0;
                while ((session.Cipher == null) && (count < (HandshakeTimeout / HANDSHAKE_DWELL_TIME)))
                {
                    System.Threading.Thread.Sleep(HANDSHAKE_DWELL_TIME);
                    count++;
                }

                if (session.Cipher == null)
                {
                    throw new Exception($"HandshakeTimeout: >{HandshakeTimeout}");
                }

                if (session.Cipher != null)
                {
                    long sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;
                    data = session.Cipher.DecodeCiphertext(sequenceNumber, (byte)TRecordType.Handshake, record.Fragment, 0, record.Fragment.Length);
                }
                else
                {
                    data = record.Fragment;
                }
            }
            else
            {
                data = record.Fragment;
            }
            using (MemoryStream stream = new MemoryStream(data))
            {
                HandshakeRecord handshakeRecord = HandshakeRecord.Deserialise(stream);
                if (handshakeRecord != null)
                {
#if DEBUG
                    Console.WriteLine(handshakeRecord.MessageType.ToString());
#endif
                    switch (handshakeRecord.MessageType)
                    {
                    case THandshakeType.HelloRequest:
                        //HelloReq
                        break;

                    case THandshakeType.ClientHello:
                        ClientHello clientHello = ClientHello.Deserialise(stream);
                        if (clientHello != null)
                        {
                            byte[] cookie = clientHello.CalculateCookie(record.RemoteEndPoint, _HelloSecret);

                            if (clientHello.Cookie == null)
                            {
                                Version version = clientHello.ClientVersion;
                                if (ServerVersion < version)
                                {
                                    version = ServerVersion;
                                }
                                if (session == null)
                                {
                                    session = new Session
                                    {
                                        SessionID      = Guid.NewGuid(),
                                        RemoteEndPoint = record.RemoteEndPoint,
                                        Version        = version
                                    };
                                    Sessions.AddSession(address, session);
                                }
                                else
                                {
                                    session.Reset();
                                    session.Version = version;
                                }
                                session.ClientEpoch          = record.Epoch;
                                session.ClientSequenceNumber = record.SequenceNumber;
                                //session.Handshake.UpdateHandshakeHash(data);
                                HelloVerifyRequest helloVerifyRequest = new HelloVerifyRequest
                                {
                                    Cookie        = cookie,
                                    ServerVersion = ServerVersion
                                };
                                SendResponse(session, (IHandshakeMessage)helloVerifyRequest, 0);
                            }
                            else
                            {
                                if (session != null && session.Cipher != null && !session.IsEncypted(record))
                                {
                                    session.Reset();
                                }

                                if (TLSUtils.ByteArrayCompare(clientHello.Cookie, cookie))
                                {
                                    Version version = clientHello.ClientVersion;
                                    if (ServerVersion < version)
                                    {
                                        version = ServerVersion;
                                    }
                                    if (clientHello.SessionID == null)
                                    {
                                        if (session == null)
                                        {
                                            session = new Session();
                                            session.NextSequenceNumber();
                                            session.SessionID      = Guid.NewGuid();
                                            session.RemoteEndPoint = record.RemoteEndPoint;
                                            Sessions.AddSession(address, session);
                                        }
                                    }
                                    else
                                    {
                                        Guid sessionID = Guid.Empty;
                                        if (clientHello.SessionID.Length >= 16)
                                        {
                                            byte[] receivedSessionID = new byte[16];
                                            Buffer.BlockCopy(clientHello.SessionID, 0, receivedSessionID, 0, 16);
                                            sessionID = new Guid(receivedSessionID);
                                        }
                                        if (sessionID != Guid.Empty)
                                        {
                                            session = Sessions.GetSession(sessionID);
                                        }
                                        if (session == null)
                                        {
                                            //need to Find Session
                                            session = new Session
                                            {
                                                SessionID = Guid.NewGuid()
                                            };
                                            session.NextSequenceNumber();
                                            session.RemoteEndPoint = record.RemoteEndPoint;
                                            Sessions.AddSession(address, session);
                                            //session.Version = clientHello.ClientVersion;
                                        }
                                    }
                                    session.Version = version;
                                    session.Handshake.InitaliseHandshakeHash(version < DTLSRecord.Version1_2);
                                    session.Handshake.UpdateHandshakeHash(data);
                                    TCipherSuite cipherSuite = TCipherSuite.TLS_NULL_WITH_NULL_NULL;
                                    foreach (TCipherSuite item in clientHello.CipherSuites)
                                    {
                                        if (_SupportedCipherSuites.ContainsKey(item) && CipherSuites.SupportedVersion(item, session.Version) && CipherSuites.SuiteUsable(item, PrivateKey, _PSKIdentities, _ValidatePSK != null))
                                        {
                                            cipherSuite = item;
                                            break;
                                        }
                                    }

                                    TKeyExchangeAlgorithm keyExchangeAlgorithm = CipherSuites.GetKeyExchangeAlgorithm(cipherSuite);

                                    ServerHello serverHello     = new ServerHello();
                                    byte[]      clientSessionID = new byte[32];
                                    byte[]      temp            = session.SessionID.ToByteArray();
                                    Buffer.BlockCopy(temp, 0, clientSessionID, 0, 16);
                                    Buffer.BlockCopy(temp, 0, clientSessionID, 16, 16);

                                    serverHello.SessionID = clientSessionID;    // session.SessionID.ToByteArray();
                                    serverHello.Random    = new RandomData();
                                    serverHello.Random.Generate();
                                    serverHello.CipherSuite   = (ushort)cipherSuite;
                                    serverHello.ServerVersion = session.Version;

                                    THashAlgorithm hash  = THashAlgorithm.SHA256;
                                    TEllipticCurve curve = TEllipticCurve.secp521r1;
                                    if (clientHello.Extensions != null)
                                    {
                                        foreach (Extension extension in clientHello.Extensions)
                                        {
                                            if (extension.SpecifcExtension is ClientCertificateTypeExtension)
                                            {
                                                ClientCertificateTypeExtension clientCertificateType = extension.SpecifcExtension as ClientCertificateTypeExtension;
                                                //TCertificateType certificateType = TCertificateType.Unknown;
                                                //foreach (TCertificateType item in clientCertificateType.CertificateTypes)
                                                //{

                                                //}
                                                //serverHello.AddExtension(new ClientCertificateTypeExtension(certificateType));
                                            }
                                            else if (extension.SpecifcExtension is EllipticCurvesExtension)
                                            {
                                                EllipticCurvesExtension ellipticCurves = extension.SpecifcExtension as EllipticCurvesExtension;
                                                foreach (TEllipticCurve item in ellipticCurves.SupportedCurves)
                                                {
                                                    if (EllipticCurveFactory.SupportedCurve(item))
                                                    {
                                                        curve = item;
                                                        break;
                                                    }
                                                }
                                            }
                                            else if (extension.SpecifcExtension is ServerCertificateTypeExtension)
                                            {
                                                //serverHello.AddExtension();
                                            }
                                            else if (extension.SpecifcExtension is SignatureAlgorithmsExtension)
                                            {
                                                SignatureAlgorithmsExtension signatureAlgorithms = extension.SpecifcExtension as SignatureAlgorithmsExtension;
                                                foreach (SignatureHashAlgorithm item in signatureAlgorithms.SupportedAlgorithms)
                                                {
                                                    if (item.Signature == TSignatureAlgorithm.ECDSA)
                                                    {
                                                        hash = item.Hash;
                                                        break;
                                                    }
                                                }
                                            }
                                        }
                                    }

                                    session.Handshake.CipherSuite  = cipherSuite;
                                    session.Handshake.ClientRandom = clientHello.Random;
                                    session.Handshake.ServerRandom = serverHello.Random;


                                    if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
                                    {
                                        EllipticCurvePointFormatsExtension pointFormatsExtension = new EllipticCurvePointFormatsExtension();
                                        pointFormatsExtension.SupportedPointFormats.Add(TEllipticCurvePointFormat.Uncompressed);
                                        serverHello.AddExtension(pointFormatsExtension);
                                    }
                                    session.Handshake.MessageSequence = 1;
                                    SendResponse(session, serverHello, session.Handshake.MessageSequence);
                                    session.Handshake.MessageSequence++;

                                    if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
                                    {
                                        if (Certificate != null)
                                        {
                                            SendResponse(session, Certificate, session.Handshake.MessageSequence);
                                            session.Handshake.MessageSequence++;
                                        }
                                        ECDHEKeyExchange keyExchange = new ECDHEKeyExchange
                                        {
                                            Curve = curve,
                                            KeyExchangeAlgorithm = keyExchangeAlgorithm,
                                            ClientRandom         = clientHello.Random,
                                            ServerRandom         = serverHello.Random
                                        };
                                        keyExchange.GenerateEphemeralKey();
                                        session.Handshake.KeyExchange = keyExchange;
                                        if (session.Version == DTLSRecord.DefaultVersion)
                                        {
                                            hash = THashAlgorithm.SHA1;
                                        }
                                        ECDHEServerKeyExchange serverKeyExchange = new ECDHEServerKeyExchange(keyExchange, hash, TSignatureAlgorithm.ECDSA, PrivateKey);
                                        SendResponse(session, serverKeyExchange, session.Handshake.MessageSequence);
                                        session.Handshake.MessageSequence++;
                                        if (_RequireClientCertificate)
                                        {
                                            CertificateRequest certificateRequest = new CertificateRequest();
                                            certificateRequest.CertificateTypes.Add(TClientCertificateType.ECDSASign);
                                            certificateRequest.SupportedAlgorithms.Add(new SignatureHashAlgorithm()
                                            {
                                                Hash = THashAlgorithm.SHA256, Signature = TSignatureAlgorithm.ECDSA
                                            });
                                            SendResponse(session, certificateRequest, session.Handshake.MessageSequence);
                                            session.Handshake.MessageSequence++;
                                        }
                                    }
                                    else if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_PSK)
                                    {
                                        ECDHEKeyExchange keyExchange = new ECDHEKeyExchange
                                        {
                                            Curve = curve,
                                            KeyExchangeAlgorithm = keyExchangeAlgorithm,
                                            ClientRandom         = clientHello.Random,
                                            ServerRandom         = serverHello.Random
                                        };
                                        keyExchange.GenerateEphemeralKey();
                                        session.Handshake.KeyExchange = keyExchange;
                                        ECDHEPSKServerKeyExchange serverKeyExchange = new ECDHEPSKServerKeyExchange(keyExchange);
                                        SendResponse(session, serverKeyExchange, session.Handshake.MessageSequence);
                                        session.Handshake.MessageSequence++;
                                    }
                                    else if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.PSK)
                                    {
                                        PSKKeyExchange keyExchange = new PSKKeyExchange
                                        {
                                            KeyExchangeAlgorithm = keyExchangeAlgorithm,
                                            ClientRandom         = clientHello.Random,
                                            ServerRandom         = serverHello.Random
                                        };
                                        session.Handshake.KeyExchange = keyExchange;
                                        //Need to be able to hint identity?? for PSK if not hinting don't really need key exchange message
                                        //PSKServerKeyExchange serverKeyExchange = new PSKServerKeyExchange();
                                        //SendResponse(session, serverKeyExchange, session.Handshake.MessageSequence);
                                        //session.Handshake.MessageSequence++;
                                    }
                                    SendResponse(session, new ServerHelloDone(), session.Handshake.MessageSequence);
                                    session.Handshake.MessageSequence++;
                                }
                            }
                        }
                        break;

                    case THandshakeType.ServerHello:
                        break;

                    case THandshakeType.HelloVerifyRequest:
                        break;

                    case THandshakeType.Certificate:
                        Certificate clientCertificate = Certificate.Deserialise(stream, TCertificateType.X509);
                        if (clientCertificate.CertChain.Count > 0)
                        {
                            session.CertificateInfo = Certificates.GetCertificateInfo(clientCertificate.CertChain[0], TCertificateFormat.CER);
                        }
                        session.Handshake.UpdateHandshakeHash(data);
                        break;

                    case THandshakeType.ServerKeyExchange:
                        break;

                    case THandshakeType.CertificateRequest:
                        break;

                    case THandshakeType.ServerHelloDone:
                        break;

                    case THandshakeType.CertificateVerify:
                        CertificateVerify certificateVerify = CertificateVerify.Deserialise(stream, session.Version);
                        session.Handshake.UpdateHandshakeHash(data);
                        break;

                    case THandshakeType.ClientKeyExchange:
                        if ((session == null) || (session.Handshake.KeyExchange == null))
                        {
                        }
                        else
                        {
                            session.Handshake.UpdateHandshakeHash(data);
                            byte[] preMasterSecret = null;
                            if (session.Handshake.KeyExchange.KeyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
                            {
                                ECDHEClientKeyExchange clientKeyExchange = ECDHEClientKeyExchange.Deserialise(stream);
                                if (clientKeyExchange != null)
                                {
                                    ECDHEKeyExchange ecKeyExchange = session.Handshake.KeyExchange as ECDHEKeyExchange;
                                    preMasterSecret = ecKeyExchange.GetPreMasterSecret(clientKeyExchange.PublicKeyBytes);
                                }
                            }
                            else if (session.Handshake.KeyExchange.KeyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_PSK)
                            {
                                ECDHEPSKClientKeyExchange clientKeyExchange = ECDHEPSKClientKeyExchange.Deserialise(stream);
                                if (clientKeyExchange != null)
                                {
                                    session.PSKIdentity = Encoding.UTF8.GetString(clientKeyExchange.PSKIdentity);
                                    byte[] psk = _PSKIdentities.GetKey(clientKeyExchange.PSKIdentity);

                                    if (psk == null)
                                    {
                                        psk = _ValidatePSK(clientKeyExchange.PSKIdentity);
                                        if (psk != null)
                                        {
                                            _PSKIdentities.AddIdentity(clientKeyExchange.PSKIdentity, psk);
                                        }
                                    }

                                    if (psk != null)
                                    {
                                        ECDHEKeyExchange ecKeyExchange = session.Handshake.KeyExchange as ECDHEKeyExchange;
                                        byte[]           otherSecret   = ecKeyExchange.GetPreMasterSecret(clientKeyExchange.PublicKeyBytes);
                                        preMasterSecret = TLSUtils.GetPSKPreMasterSecret(otherSecret, psk);
                                    }
                                }
                            }
                            else if (session.Handshake.KeyExchange.KeyExchangeAlgorithm == TKeyExchangeAlgorithm.PSK)
                            {
                                PSKClientKeyExchange clientKeyExchange = PSKClientKeyExchange.Deserialise(stream);
                                if (clientKeyExchange != null)
                                {
                                    session.PSKIdentity = Encoding.UTF8.GetString(clientKeyExchange.PSKIdentity);
                                    byte[] psk = _PSKIdentities.GetKey(clientKeyExchange.PSKIdentity);

                                    if (psk == null)
                                    {
                                        psk = _ValidatePSK(clientKeyExchange.PSKIdentity);
                                        if (psk != null)
                                        {
                                            _PSKIdentities.AddIdentity(clientKeyExchange.PSKIdentity, psk);
                                        }
                                    }

                                    if (psk != null)
                                    {
                                        ECDHEKeyExchange ecKeyExchange = session.Handshake.KeyExchange as ECDHEKeyExchange;
                                        byte[]           otherSecret   = new byte[psk.Length];
                                        preMasterSecret = TLSUtils.GetPSKPreMasterSecret(otherSecret, psk);
                                    }
                                }
                            }

                            if (preMasterSecret != null)
                            {
                                //session.MasterSecret = TLSUtils.CalculateMasterSecret(preMasterSecret, session.KeyExchange);
                                //TLSUtils.AssignCipher(session);

                                session.Cipher = TLSUtils.AssignCipher(preMasterSecret, false, session.Version, session.Handshake);
                            }
                            else
                            {
                                Console.WriteLine($"preMasterSecret is null!");
                            }
                        }
                        break;

                    case THandshakeType.Finished:
                        Finished finished = Finished.Deserialise(stream);
                        if (session != null)
                        {
                            byte[] handshakeHash        = session.Handshake.GetHash();
                            byte[] calculatedVerifyData = TLSUtils.GetVerifyData(session.Version, session.Handshake, false, true, handshakeHash);
#if DEBUG
                            Console.Write($"Handshake Hash: {TLSUtils.WriteToString(handshakeHash)}");
                            Console.Write($"Sent Verify: {TLSUtils.WriteToString(finished.VerifyData)}");
                            Console.Write($"Calc Verify: {TLSUtils.WriteToString(calculatedVerifyData)}");
#endif
                            if (TLSUtils.ByteArrayCompare(finished.VerifyData, calculatedVerifyData))
                            {
                                SendChangeCipherSpec(session);
                                session.Handshake.UpdateHandshakeHash(data);
                                handshakeHash = session.Handshake.GetHash();
                                Finished serverFinished = new Finished
                                {
                                    VerifyData = TLSUtils.GetVerifyData(session.Version, session.Handshake, false, false, handshakeHash)
                                };
                                SendResponse(session, serverFinished, session.Handshake.MessageSequence);
                                session.Handshake.MessageSequence++;
                            }
                            else
                            {
                                throw new Exception();
                            }
                        }
                        break;

                    default:
                        break;
                    }
                }
            }
        }
예제 #31
0
		/// <summary>
		/// Reads one DTLS record (header followed by its fragment) from <paramref name="stream"/>.
		/// </summary>
		/// <param name="stream">Stream positioned at the start of a record header.</param>
		/// <returns>The parsed record. Its fragment stays null when the declared length is 0.</returns>
		/// <remarks>
		/// The record type is not validated, and header bytes read with ReadByte are not checked for
		/// end-of-stream (-1). Each version byte is converted with 255 - value. If the stream ends
		/// before the declared fragment length has been read, the unread tail of the fragment is left
		/// zero-filled and no error is raised.
		/// </remarks>
		public static DTLSRecord Deserialise(Stream stream)
		{
			DTLSRecord record = new DTLSRecord();

			// Header: type, version (major then minor), epoch, 48-bit sequence number, fragment length.
			record._RecordType = (TRecordType)stream.ReadByte();
			int major = 255 - stream.ReadByte();
			int minor = 255 - stream.ReadByte();
			record._Version = new Version(major, minor);
			record._Epoch = NetworkByteOrderConverter.ToUInt16(stream);
			record._SequenceNumber = NetworkByteOrderConverter.ToInt48(stream);
			record._Length = NetworkByteOrderConverter.ToUInt16(stream);

			if (record._Length > 0)
			{
				byte[] fragment = new byte[record._Length];
				record._Fragment = fragment;

				// Stream.Read may return fewer bytes than requested, so keep reading until the
				// fragment is full or the stream stops producing data.
				int filled = stream.Read(fragment, 0, fragment.Length);
				while (filled < fragment.Length)
				{
					int chunk = stream.Read(fragment, filled, fragment.Length - filled);
					if (chunk <= 0)
					{
						break;
					}
					filled += chunk;
				}
			}
			return record;
		}
예제 #32
0
 /// <summary>
 /// Sets <c>EncyptedClientEpoch</c> to the epoch that follows <paramref name="record"/>'s epoch.
 /// </summary>
 /// <param name="record">Record whose epoch is the last one before the change.</param>
 /// <remarks>
 /// NOTE(review): the value is narrowed with a ushort cast. In an unchecked context an epoch of
 /// 65535 therefore wraps to 0. Confirm that callers never reach that epoch.
 /// </remarks>
 internal void SetEncyptChange(DTLSRecord record) => EncyptedClientEpoch = (ushort)(record.Epoch + 1);
예제 #33
0
        /// <summary>
        /// Handles a single record received from the server, dispatching on its record type.
        /// </summary>
        /// <param name="record">The record to process.</param>
        /// <remarks>
        /// Any exception thrown while processing is caught here. DEBUG builds write it to the
        /// console; release builds discard it.
        /// </remarks>
        private void ProcessRecord(DTLSRecord record)
        {
            try
            {
#if DEBUG
                Console.WriteLine(record.RecordType.ToString());
#endif
                switch (record.RecordType)
                {
                case TRecordType.ChangeCipherSpec:
                    // Move to the server's next epoch and restart its sequence numbering. Also copy
                    // the epoch into _EncyptedServerEpoch, which the Alert branch checks when
                    // deciding whether to decrypt.
                    if (_ServerEpoch.HasValue)
                    {
                        _ServerEpoch++;
                        _ServerSequenceNumber = 0;
                        _EncyptedServerEpoch  = _ServerEpoch;
                    }
                    break;

                case TRecordType.Alert:
                    AlertRecord alertRecord;
                    try
                    {
                        if ((_Cipher == null) || (!_EncyptedServerEpoch.HasValue))
                        {
                            // No cipher or encrypted epoch yet: parse the fragment without decrypting.
                            alertRecord = AlertRecord.Deserialise(record.Fragment);
                        }
                        else
                        {
                            // The cipher sequence number is the epoch shifted into the top 16 bits
                            // plus the record's 48-bit sequence number.
                            long   sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;
                            byte[] data           = _Cipher.DecodeCiphertext(sequenceNumber, (byte)TRecordType.Alert, record.Fragment, 0, record.Fragment.Length);
                            alertRecord = AlertRecord.Deserialise(data);
                        }
                    }
                    catch
                    {
                        // An alert that cannot be decrypted or parsed is treated as fatal.
                        alertRecord = new AlertRecord
                        {
                            AlertLevel = TAlertLevel.Fatal
                        };
                    }
                    if (alertRecord.AlertLevel == TAlertLevel.Fatal)
                    {
                        // TODO: terminate the connection; for now only the connected event is signalled.
                        _Connected.Set();
                    }
                    else if ((alertRecord.AlertLevel == TAlertLevel.Warning) || (alertRecord.AlertDescription == TAlertDescription.CloseNotify))
                    {
                        if (alertRecord.AlertDescription == TAlertDescription.CloseNotify)
                        {
                            // Answer the peer's close_notify with our own before signalling.
                            SendAlert(TAlertLevel.Warning, TAlertDescription.CloseNotify);
                            _Connected.Set();
                        }
                    }
                    break;

                case TRecordType.Handshake:
                    ProcessHandshake(record);
                    _ServerSequenceNumber = record.SequenceNumber + 1;
                    break;

                case TRecordType.ApplicationData:
                    // Data is only decrypted and surfaced once a cipher exists. Without one it is
                    // dropped, but the expected server sequence number still advances.
                    if (_Cipher != null)
                    {
                        long   sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;
                        byte[] data           = _Cipher.DecodeCiphertext(sequenceNumber, (byte)TRecordType.ApplicationData, record.Fragment, 0, record.Fragment.Length);
                        DataReceived?.Invoke(record.RemoteEndPoint, data);
                    }
                    _ServerSequenceNumber = record.SequenceNumber + 1;
                    break;

                default:
                    break;
                }
            }
#if DEBUG
            catch (Exception ex)
            {
                Console.WriteLine(ex.ToString());
                // Report the frame that actually threw. new StackFrame(1, true) would describe the
                // caller of this method, not the location of the failure.
                StackFrame throwSite = new StackTrace(ex, true).GetFrame(0);
                Console.WriteLine($"Exception! Type: { ex.GetType()}\n\tData: { ex.Data.Count}\n\tMessage: { ex.Message}\n\tSource: { ex.Source}\n\t" +
                                  $"StackTrace: { ex.StackTrace}\n\tFile: {throwSite?.GetFileName()}\n\t" +
                                  $"Line: {throwSite?.GetFileLineNumber()}");
            }
#else
            catch
            {
                // Release builds discard the exception, presumably so that a single malformed
                // record does not stop record processing.
            }
#endif
        }
예제 #34
0
 /// <summary>
 /// Queues <paramref name="record"/> for ProcessRecord on the thread pool when
 /// <paramref name="session"/> has reached or already passed it. Otherwise the record is
 /// buffered in <c>session.Records</c>.
 /// </summary>
 /// <param name="session">Session the record belongs to.</param>
 /// <param name="record">The received record.</param>
 private void CheckSession(Session session, DTLSRecord record)
 {
     // Ready: the record is from an earlier epoch, or from the current epoch with a sequence
     // number the session has reached or already passed. Records further ahead are buffered.
     bool ready = (session.ClientEpoch > record.Epoch)
                  || ((session.ClientEpoch == record.Epoch) && (session.ClientSequenceNumber >= record.SequenceNumber));
     if (!ready)
     {
         lock (session)
         {
             // Re-test under the lock using the same rule. The session may have advanced while we
             // waited, and a record it has now reached or passed must be dispatched, not left in
             // the buffer.
             ready = (session.ClientEpoch > record.Epoch)
                     || ((session.ClientEpoch == record.Epoch) && (session.ClientSequenceNumber >= record.SequenceNumber));
             if (!ready)
             {
                 session.Records.Add(record);
             }
         }
     }
     if (ready)
     {
         ThreadPool.QueueUserWorkItem(ProcessRecord, record);
     }
 }