/// <summary>
/// Handles one handshake record received from a remote DTLS client and advances the server side
/// of the handshake: cookie exchange (HelloVerifyRequest), the ServerHello flight,
/// ClientKeyExchange processing and Finished verification. Messages that arrive before a session
/// exists for the peer, or whose body deserialises to null, are ignored.
/// </summary>
/// <param name="record">The received handshake record. When the peer's session reports the record
/// as encrypted, its fragment is decrypted with the session cipher before parsing.</param>
/// <exception cref="TimeoutException">The record is encrypted but no cipher was assigned to the
/// session after polling for roughly 500 ms.</exception>
/// <exception cref="InvalidOperationException">The client's Finished verify data does not match the
/// value computed from the handshake transcript.</exception>
public void ProcessHandshake(DTLSRecord record)
{
    SocketAddress address = record.RemoteEndPoint.Serialize();
    Session session = Sessions.GetSession(address);
    byte[] data = record.Fragment;
    if ((session != null) && session.IsEncypted(record))
    {
        // In this method the cipher is assigned while processing ClientKeyExchange; poll for it
        // (50 x 10 ms) in case that record is still being handled.
        // NOTE(review): this blocks the calling thread and assumes records are processed concurrently.
        int count = 0;
        while ((session.Cipher == null) && (count < 50))
        {
            System.Threading.Thread.Sleep(10);
            count++;
        }
        // Read the cipher once so the null check and the decrypt call use the same instance.
        var cipher = session.Cipher;
        if (cipher == null)
        {
            throw new TimeoutException("No cipher was assigned to the session in time to decrypt an encrypted handshake record.");
        }
        // DTLS record sequence: 16-bit epoch in the top bits, 48-bit sequence number below.
        long sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;
        data = cipher.DecodeCiphertext(sequenceNumber, (byte)TRecordType.Handshake, record.Fragment, 0, record.Fragment.Length);
    }

    using (MemoryStream stream = new MemoryStream(data))
    {
        HandshakeRecord handshakeRecord = HandshakeRecord.Deserialise(stream);
        if (handshakeRecord == null)
        {
            return;
        }
#if DEBUG
        Console.WriteLine(handshakeRecord.MessageType.ToString());
#endif
        switch (handshakeRecord.MessageType)
        {
            case THandshakeType.ClientHello:
            {
                ClientHello clientHello = ClientHello.Deserialise(stream);
                if (clientHello == null)
                {
                    break;
                }
                byte[] cookie = clientHello.CalculateCookie(record.RemoteEndPoint, _HelloSecret);
                if (clientHello.Cookie == null)
                {
                    // First ClientHello: reply with a HelloVerifyRequest carrying the cookie computed
                    // from the peer's endpoint and _HelloSecret.
                    Version helloVersion = clientHello.ClientVersion;
                    if (ServerVersion < helloVersion)
                    {
                        helloVersion = ServerVersion;
                    }
                    if (session == null)
                    {
                        session = new Session();
                        session.SessionID = Guid.NewGuid();
                        session.RemoteEndPoint = record.RemoteEndPoint;
                        session.Version = helloVersion;
                        Sessions.AddSession(address, session);
                    }
                    else
                    {
                        session.Reset();
                        session.Version = helloVersion;
                    }
                    session.ClientEpoch = record.Epoch;
                    session.ClientSequenceNumber = record.SequenceNumber;
                    HelloVerifyRequest helloVerifyRequest = new HelloVerifyRequest();
                    helloVerifyRequest.Cookie = cookie;
                    helloVerifyRequest.ServerVersion = ServerVersion;
                    SendResponse(session, (IHandshakeMessage)helloVerifyRequest, 0);
                    break;
                }

                // A plaintext ClientHello on a session that already has a cipher starts a new handshake.
                if (session != null && session.Cipher != null && !session.IsEncypted(record))
                {
                    session.Reset();
                }
                if (!TLSUtils.ByteArrayCompare(clientHello.Cookie, cookie))
                {
                    // Cookie does not match this peer: ignore the hello.
                    break;
                }

                Version version = clientHello.ClientVersion;
                if (ServerVersion < version)
                {
                    version = ServerVersion;
                }
                if (clientHello.SessionID == null)
                {
                    if (session == null)
                    {
                        session = new Session();
                        session.NextSequenceNumber();
                        session.SessionID = Guid.NewGuid();
                        session.RemoteEndPoint = record.RemoteEndPoint;
                        Sessions.AddSession(address, session);
                    }
                }
                else
                {
                    // The first 16 bytes of the client's session ID are interpreted as a session GUID.
                    Guid sessionID = Guid.Empty;
                    if (clientHello.SessionID.Length >= 16)
                    {
                        byte[] receivedSessionID = new byte[16];
                        Buffer.BlockCopy(clientHello.SessionID, 0, receivedSessionID, 0, 16);
                        sessionID = new Guid(receivedSessionID);
                    }
                    if (sessionID != Guid.Empty)
                    {
                        // NOTE(review): the session is looked up by the client-supplied ID alone; nothing
                        // here checks that it belongs to this remote endpoint, and a session found this
                        // way is not re-registered under `address` — confirm this is intended.
                        session = Sessions.GetSession(sessionID);
                    }
                    if (session == null)
                    {
                        session = new Session();
                        session.SessionID = Guid.NewGuid();
                        session.NextSequenceNumber();
                        session.RemoteEndPoint = record.RemoteEndPoint;
                        Sessions.AddSession(address, session);
                    }
                }
                session.Version = version;
                // The flag presumably selects the pre-1.2 transcript hash — confirm in Handshake.
                session.Handshake.InitaliseHandshakeHash(version < DTLSRecord.Version1_2);
                session.Handshake.UpdateHandshakeHash(data);

                // First suite in the client's order that is configured, valid for the version and usable
                // with the configured keys/PSKs.
                // NOTE(review): if nothing matches, the handshake continues with TLS_NULL_WITH_NULL_NULL
                // rather than failing.
                TCipherSuite cipherSuite = TCipherSuite.TLS_NULL_WITH_NULL_NULL;
                foreach (TCipherSuite item in clientHello.CipherSuites)
                {
                    if (_SupportedCipherSuites.ContainsKey(item) && CipherSuites.SupportedVersion(item, session.Version) && CipherSuites.SuiteUsable(item, PrivateKey, _PSKIdentities, _ValidatePSK != null))
                    {
                        cipherSuite = item;
                        break;
                    }
                }
                TKeyExchangeAlgorithm keyExchangeAlgorithm = CipherSuites.GetKeyExchangeAlgorithm(cipherSuite);

                ServerHello serverHello = new ServerHello();
                // The 16-byte session GUID is written twice to fill the 32-byte SessionID field.
                byte[] clientSessionID = new byte[32];
                byte[] temp = session.SessionID.ToByteArray();
                Buffer.BlockCopy(temp, 0, clientSessionID, 0, 16);
                Buffer.BlockCopy(temp, 0, clientSessionID, 16, 16);
                serverHello.SessionID = clientSessionID;
                serverHello.Random = new RandomData();
                serverHello.Random.Generate();
                serverHello.CipherSuite = (ushort)cipherSuite;
                serverHello.ServerVersion = session.Version;

                // Defaults, overridden by the client's extensions below.
                THashAlgorithm hash = THashAlgorithm.SHA256;
                TEllipticCurve curve = TEllipticCurve.secp521r1;
                if (clientHello.Extensions != null)
                {
                    // Client/server certificate type extensions are currently ignored.
                    foreach (Extension extension in clientHello.Extensions)
                    {
                        if (extension.SpecifcExtension is EllipticCurvesExtension)
                        {
                            EllipticCurvesExtension ellipticCurves = extension.SpecifcExtension as EllipticCurvesExtension;
                            foreach (TEllipticCurve item in ellipticCurves.SupportedCurves)
                            {
                                if (EllipticCurveFactory.SupportedCurve(item))
                                {
                                    curve = item;
                                    break;
                                }
                            }
                        }
                        else if (extension.SpecifcExtension is SignatureAlgorithmsExtension)
                        {
                            SignatureAlgorithmsExtension signatureAlgorithms = extension.SpecifcExtension as SignatureAlgorithmsExtension;
                            foreach (SignatureHashAlgorithm item in signatureAlgorithms.SupportedAlgorithms)
                            {
                                if (item.Signature == TSignatureAlgorithm.ECDSA)
                                {
                                    hash = item.Hash;
                                    break;
                                }
                            }
                        }
                    }
                }
                session.Handshake.CipherSuite = cipherSuite;
                session.Handshake.ClientRandom = clientHello.Random;
                session.Handshake.ServerRandom = serverHello.Random;
                if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
                {
                    EllipticCurvePointFormatsExtension pointFormatsExtension = new EllipticCurvePointFormatsExtension();
                    pointFormatsExtension.SupportedPointFormats.Add(TEllipticCurvePointFormat.Uncompressed);
                    serverHello.AddExtension(pointFormatsExtension);
                }

                // Send the server flight: ServerHello, optional Certificate / ServerKeyExchange /
                // CertificateRequest, then ServerHelloDone.
                session.Handshake.MessageSequence = 1;
                SendResponse(session, serverHello, session.Handshake.MessageSequence);
                session.Handshake.MessageSequence++;
                if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
                {
                    if (Certificate != null)
                    {
                        SendResponse(session, Certificate, session.Handshake.MessageSequence);
                        session.Handshake.MessageSequence++;
                    }
                    ECDHEKeyExchange keyExchange = new ECDHEKeyExchange();
                    keyExchange.Curve = curve;
                    keyExchange.KeyExchangeAlgorithm = keyExchangeAlgorithm;
                    keyExchange.ClientRandom = clientHello.Random;
                    keyExchange.ServerRandom = serverHello.Random;
                    keyExchange.GenerateEphemeralKey();
                    session.Handshake.KeyExchange = keyExchange;
                    if (session.Version == DTLSRecord.DefaultVersion)
                    {
                        hash = THashAlgorithm.SHA1;
                    }
                    ECDHEServerKeyExchange serverKeyExchange = new ECDHEServerKeyExchange(keyExchange, hash, TSignatureAlgorithm.ECDSA, PrivateKey);
                    SendResponse(session, serverKeyExchange, session.Handshake.MessageSequence);
                    session.Handshake.MessageSequence++;
                    if (_RequireClientCertificate)
                    {
                        CertificateRequest certificateRequest = new CertificateRequest();
                        certificateRequest.CertificateTypes.Add(TClientCertificateType.ECDSASign);
                        certificateRequest.SupportedAlgorithms.Add(new SignatureHashAlgorithm() { Hash = THashAlgorithm.SHA256, Signature = TSignatureAlgorithm.ECDSA });
                        SendResponse(session, certificateRequest, session.Handshake.MessageSequence);
                        session.Handshake.MessageSequence++;
                    }
                }
                else if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_PSK)
                {
                    ECDHEKeyExchange keyExchange = new ECDHEKeyExchange();
                    keyExchange.Curve = curve;
                    keyExchange.KeyExchangeAlgorithm = keyExchangeAlgorithm;
                    keyExchange.ClientRandom = clientHello.Random;
                    keyExchange.ServerRandom = serverHello.Random;
                    keyExchange.GenerateEphemeralKey();
                    session.Handshake.KeyExchange = keyExchange;
                    ECDHEPSKServerKeyExchange serverKeyExchange = new ECDHEPSKServerKeyExchange(keyExchange);
                    SendResponse(session, serverKeyExchange, session.Handshake.MessageSequence);
                    session.Handshake.MessageSequence++;
                }
                else if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.PSK)
                {
                    PSKKeyExchange keyExchange = new PSKKeyExchange();
                    keyExchange.KeyExchangeAlgorithm = keyExchangeAlgorithm;
                    keyExchange.ClientRandom = clientHello.Random;
                    keyExchange.ServerRandom = serverHello.Random;
                    session.Handshake.KeyExchange = keyExchange;
                    // No ServerKeyExchange is sent for plain PSK: it would only carry an identity hint,
                    // which is not supported here.
                }
                SendResponse(session, new ServerHelloDone(), session.Handshake.MessageSequence);
                session.Handshake.MessageSequence++;
                break;
            }
            case THandshakeType.Certificate:
            {
                // A Certificate before any session exists for this peer is ignored instead of
                // dereferencing a null session.
                if (session == null)
                {
                    break;
                }
                Certificate clientCertificate = Certificate.Deserialise(stream, TCertificateType.X509);
                if ((clientCertificate != null) && (clientCertificate.CertChain.Count > 0))
                {
                    session.CertificateInfo = Certificates.GetCertificateInfo(clientCertificate.CertChain[0], TCertificateFormat.CER);
                }
                session.Handshake.UpdateHandshakeHash(data);
                break;
            }
            case THandshakeType.CertificateVerify:
            {
                if (session == null)
                {
                    break;
                }
                // NOTE(review): the parsed CertificateVerify is discarded — its signature is never checked
                // against the client certificate here, so client-certificate authentication is not
                // enforced by this method. Confirm it is verified elsewhere.
                CertificateVerify.Deserialise(stream, session.Version);
                session.Handshake.UpdateHandshakeHash(data);
                break;
            }
            case THandshakeType.ClientKeyExchange:
            {
                // Without a session or a pending key exchange (set while answering ClientHello above)
                // there is nothing to complete.
                if ((session == null) || (session.Handshake.KeyExchange == null))
                {
                    break;
                }
                session.Handshake.UpdateHandshakeHash(data);
                byte[] preMasterSecret = null;
                TKeyExchangeAlgorithm algorithm = session.Handshake.KeyExchange.KeyExchangeAlgorithm;
                if (algorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
                {
                    ECDHEClientKeyExchange clientKeyExchange = ECDHEClientKeyExchange.Deserialise(stream);
                    if (clientKeyExchange != null)
                    {
                        ECDHEKeyExchange ecKeyExchange = session.Handshake.KeyExchange as ECDHEKeyExchange;
                        preMasterSecret = ecKeyExchange.GetPreMasterSecret(clientKeyExchange.PublicKeyBytes);
                    }
                }
                else if (algorithm == TKeyExchangeAlgorithm.ECDHE_PSK)
                {
                    ECDHEPSKClientKeyExchange clientKeyExchange = ECDHEPSKClientKeyExchange.Deserialise(stream);
                    if (clientKeyExchange != null)
                    {
                        session.PSKIdentity = Encoding.UTF8.GetString(clientKeyExchange.PSKIdentity);
                        byte[] psk = _PSKIdentities.GetKey(clientKeyExchange.PSKIdentity);
                        // _ValidatePSK is optional; only consult it when it is configured.
                        if ((psk == null) && (_ValidatePSK != null))
                        {
                            psk = _ValidatePSK(clientKeyExchange.PSKIdentity);
                            if (psk != null)
                            {
                                _PSKIdentities.AddIdentity(clientKeyExchange.PSKIdentity, psk);
                            }
                        }
                        if (psk != null)
                        {
                            ECDHEKeyExchange ecKeyExchange = session.Handshake.KeyExchange as ECDHEKeyExchange;
                            byte[] otherSecret = ecKeyExchange.GetPreMasterSecret(clientKeyExchange.PublicKeyBytes);
                            preMasterSecret = TLSUtils.GetPSKPreMasterSecret(otherSecret, psk);
                        }
                    }
                }
                else if (algorithm == TKeyExchangeAlgorithm.PSK)
                {
                    PSKClientKeyExchange clientKeyExchange = PSKClientKeyExchange.Deserialise(stream);
                    if (clientKeyExchange != null)
                    {
                        session.PSKIdentity = Encoding.UTF8.GetString(clientKeyExchange.PSKIdentity);
                        byte[] psk = _PSKIdentities.GetKey(clientKeyExchange.PSKIdentity);
                        // _ValidatePSK is optional; only consult it when it is configured.
                        if ((psk == null) && (_ValidatePSK != null))
                        {
                            psk = _ValidatePSK(clientKeyExchange.PSKIdentity);
                            if (psk != null)
                            {
                                _PSKIdentities.AddIdentity(clientKeyExchange.PSKIdentity, psk);
                            }
                        }
                        if (psk != null)
                        {
                            // Plain PSK has no Diffie-Hellman share: the "other secret" is N zero bytes
                            // (RFC 4279 section 2).
                            byte[] otherSecret = new byte[psk.Length];
                            preMasterSecret = TLSUtils.GetPSKPreMasterSecret(otherSecret, psk);
                        }
                    }
                }
                if (preMasterSecret != null)
                {
                    session.Cipher = TLSUtils.AssignCipher(preMasterSecret, false, session.Version, session.Handshake);
                }
                break;
            }
            case THandshakeType.Finished:
            {
                Finished finished = Finished.Deserialise(stream);
                if ((session == null) || (finished == null))
                {
                    break;
                }
                // Verify the client's Finished over the transcript so far (excluding this message).
                byte[] handshakeHash = session.Handshake.GetHash();
                byte[] calculatedVerifyData = TLSUtils.GetVerifyData(session.Version, session.Handshake, false, true, handshakeHash);
#if DEBUG
                Console.Write("Handshake Hash:");
                TLSUtils.WriteToConsole(handshakeHash);
                Console.Write("Sent Verify:");
                TLSUtils.WriteToConsole(finished.VerifyData);
                Console.Write("Calc Verify:");
                TLSUtils.WriteToConsole(calculatedVerifyData);
#endif
                if (!TLSUtils.ByteArrayCompare(finished.VerifyData, calculatedVerifyData))
                {
                    throw new InvalidOperationException("Client Finished verify data does not match the computed handshake verify data.");
                }
                // Answer with ChangeCipherSpec and the server Finished, whose verify data also covers
                // the client's Finished message.
                SendChangeCipherSpec(session);
                session.Handshake.UpdateHandshakeHash(data);
                handshakeHash = session.Handshake.GetHash();
                Finished serverFinished = new Finished();
                serverFinished.VerifyData = TLSUtils.GetVerifyData(session.Version, session.Handshake, false, false, handshakeHash);
                SendResponse(session, serverFinished, session.Handshake.MessageSequence);
                session.Handshake.MessageSequence++;
                break;
            }
            case THandshakeType.HelloRequest:
            case THandshakeType.ServerHello:
            case THandshakeType.HelloVerifyRequest:
            case THandshakeType.ServerKeyExchange:
            case THandshakeType.CertificateRequest:
            case THandshakeType.ServerHelloDone:
            default:
                // Server-to-client message types, and unknown types, are ignored by this server.
                break;
        }
    }
}
/// <summary>
/// Parses the body of a CertificateRequest handshake message from <paramref name="stream"/>.
/// </summary>
/// <param name="stream">Stream positioned at the start of the CertificateRequest body.</param>
/// <param name="version">Negotiated protocol version. The signature/hash algorithm list is only
/// read when <paramref name="version"/> is at least <c>DTLSRecord.Version1_2</c>.</param>
/// <returns>The parsed request. If the stream ends in the middle of a list, parsing stops and the
/// entries read so far are returned; a truncated distinguished name is not added.</returns>
public static CertificateRequest Deserialise(Stream stream, Version version)
{
    CertificateRequest result = new CertificateRequest();

    // certificate_types: one-byte count followed by one byte per type. ReadByte returns -1 at end of
    // stream, which leaves the count non-positive and the loop empty.
    int certificateTypeCount = stream.ReadByte();
    for (int index = 0; index < certificateTypeCount; index++)
    {
        int certificateType = stream.ReadByte();
        if (certificateType < 0)
        {
            return result;
        }
        result.CertificateTypes.Add((TClientCertificateType)certificateType);
    }

    if (version >= DTLSRecord.Version1_2)
    {
        // supported_signature_algorithms: two-byte length in bytes, then (hash, signature) byte pairs.
        ushort length = NetworkByteOrderConverter.ToUInt16(stream);
        int supportedAlgorithmsCount = length / 2;
        for (int index = 0; index < supportedAlgorithmsCount; index++)
        {
            int hash = stream.ReadByte();
            int signature = stream.ReadByte();
            if ((hash < 0) || (signature < 0))
            {
                return result;
            }
            result.SupportedAlgorithms.Add(new SignatureHashAlgorithm() { Hash = (THashAlgorithm)hash, Signature = (TSignatureAlgorithm)signature });
        }
    }

    // certificate_authorities: two-byte total length, then entries of (two-byte length, DER name).
    ushort certificateAuthoritiesLength = NetworkByteOrderConverter.ToUInt16(stream);
    int read = 0;
    while (read < certificateAuthoritiesLength)
    {
        ushort distinguishedNameLength = NetworkByteOrderConverter.ToUInt16(stream);
        read += 2 + distinguishedNameLength;
        byte[] distinguishedName = new byte[distinguishedNameLength];
        // Stream.Read may return fewer bytes than requested, so keep reading until the name is complete.
        int offset = 0;
        while (offset < distinguishedNameLength)
        {
            int bytesRead = stream.Read(distinguishedName, offset, distinguishedNameLength - offset);
            if (bytesRead <= 0)
            {
                // Truncated message: drop the partial name rather than returning zero-padded data.
                return result;
            }
            offset += bytesRead;
        }
        result.CertificateAuthorities.Add(distinguishedName);
    }
    return result;
}
/// <summary>
/// Handles one handshake record received from a remote DTLS client and advances the server side
/// of the handshake: cookie exchange (HelloVerifyRequest), the ServerHello flight,
/// ClientKeyExchange processing and Finished verification. Messages that arrive before a session
/// exists for the peer, or whose body deserialises to null, are ignored.
/// </summary>
/// <param name="record">The received handshake record. When the peer's session reports the record
/// as encrypted, its fragment is decrypted with the session cipher before parsing.</param>
/// <exception cref="TimeoutException">The record is encrypted but no cipher was assigned to the
/// session within <c>HandshakeTimeout</c> (polled every <c>HANDSHAKE_DWELL_TIME</c>; both are
/// presumably milliseconds since the dwell time is passed to <c>Thread.Sleep</c>).</exception>
/// <exception cref="InvalidOperationException">The client's Finished verify data does not match the
/// value computed from the handshake transcript.</exception>
public void ProcessHandshake(DTLSRecord record)
{
    SocketAddress address = record.RemoteEndPoint.Serialize();
    Session session = Sessions.GetSession(address);
    byte[] data = record.Fragment;
    if ((session != null) && session.IsEncypted(record))
    {
        // In this method the cipher is assigned while processing ClientKeyExchange; poll for it in
        // case that record is still being handled.
        // NOTE(review): this blocks the calling thread and assumes records are processed concurrently.
        int count = 0;
        while ((session.Cipher == null) && (count < (HandshakeTimeout / HANDSHAKE_DWELL_TIME)))
        {
            System.Threading.Thread.Sleep(HANDSHAKE_DWELL_TIME);
            count++;
        }
        // Read the cipher once so the null check and the decrypt call use the same instance.
        var cipher = session.Cipher;
        if (cipher == null)
        {
            throw new TimeoutException($"HandshakeTimeout: >{HandshakeTimeout}");
        }
        // DTLS record sequence: 16-bit epoch in the top bits, 48-bit sequence number below.
        long sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;
        data = cipher.DecodeCiphertext(sequenceNumber, (byte)TRecordType.Handshake, record.Fragment, 0, record.Fragment.Length);
    }

    using (MemoryStream stream = new MemoryStream(data))
    {
        HandshakeRecord handshakeRecord = HandshakeRecord.Deserialise(stream);
        if (handshakeRecord == null)
        {
            return;
        }
#if DEBUG
        Console.WriteLine(handshakeRecord.MessageType.ToString());
#endif
        switch (handshakeRecord.MessageType)
        {
            case THandshakeType.ClientHello:
            {
                ClientHello clientHello = ClientHello.Deserialise(stream);
                if (clientHello == null)
                {
                    break;
                }
                byte[] cookie = clientHello.CalculateCookie(record.RemoteEndPoint, _HelloSecret);
                if (clientHello.Cookie == null)
                {
                    // First ClientHello: reply with a HelloVerifyRequest carrying the cookie computed
                    // from the peer's endpoint and _HelloSecret.
                    Version helloVersion = clientHello.ClientVersion;
                    if (ServerVersion < helloVersion)
                    {
                        helloVersion = ServerVersion;
                    }
                    if (session == null)
                    {
                        session = new Session();
                        session.SessionID = Guid.NewGuid();
                        session.RemoteEndPoint = record.RemoteEndPoint;
                        session.Version = helloVersion;
                        Sessions.AddSession(address, session);
                    }
                    else
                    {
                        session.Reset();
                        session.Version = helloVersion;
                    }
                    session.ClientEpoch = record.Epoch;
                    session.ClientSequenceNumber = record.SequenceNumber;
                    HelloVerifyRequest helloVerifyRequest = new HelloVerifyRequest();
                    helloVerifyRequest.Cookie = cookie;
                    helloVerifyRequest.ServerVersion = ServerVersion;
                    SendResponse(session, (IHandshakeMessage)helloVerifyRequest, 0);
                    break;
                }

                // A plaintext ClientHello on a session that already has a cipher starts a new handshake.
                if (session != null && session.Cipher != null && !session.IsEncypted(record))
                {
                    session.Reset();
                }
                if (!TLSUtils.ByteArrayCompare(clientHello.Cookie, cookie))
                {
                    // Cookie does not match this peer: ignore the hello.
                    break;
                }

                Version version = clientHello.ClientVersion;
                if (ServerVersion < version)
                {
                    version = ServerVersion;
                }
                if (clientHello.SessionID == null)
                {
                    if (session == null)
                    {
                        session = new Session();
                        session.NextSequenceNumber();
                        session.SessionID = Guid.NewGuid();
                        session.RemoteEndPoint = record.RemoteEndPoint;
                        Sessions.AddSession(address, session);
                    }
                }
                else
                {
                    // The first 16 bytes of the client's session ID are interpreted as a session GUID.
                    Guid sessionID = Guid.Empty;
                    if (clientHello.SessionID.Length >= 16)
                    {
                        byte[] receivedSessionID = new byte[16];
                        Buffer.BlockCopy(clientHello.SessionID, 0, receivedSessionID, 0, 16);
                        sessionID = new Guid(receivedSessionID);
                    }
                    if (sessionID != Guid.Empty)
                    {
                        // NOTE(review): the session is looked up by the client-supplied ID alone; nothing
                        // here checks that it belongs to this remote endpoint, and a session found this
                        // way is not re-registered under `address` — confirm this is intended.
                        session = Sessions.GetSession(sessionID);
                    }
                    if (session == null)
                    {
                        session = new Session();
                        session.SessionID = Guid.NewGuid();
                        session.NextSequenceNumber();
                        session.RemoteEndPoint = record.RemoteEndPoint;
                        Sessions.AddSession(address, session);
                    }
                }
                session.Version = version;
                // The flag presumably selects the pre-1.2 transcript hash — confirm in Handshake.
                session.Handshake.InitaliseHandshakeHash(version < DTLSRecord.Version1_2);
                session.Handshake.UpdateHandshakeHash(data);

                // First suite in the client's order that is configured, valid for the version and usable
                // with the configured keys/PSKs.
                // NOTE(review): if nothing matches, the handshake continues with TLS_NULL_WITH_NULL_NULL
                // rather than failing.
                TCipherSuite cipherSuite = TCipherSuite.TLS_NULL_WITH_NULL_NULL;
                foreach (TCipherSuite item in clientHello.CipherSuites)
                {
                    if (_SupportedCipherSuites.ContainsKey(item) && CipherSuites.SupportedVersion(item, session.Version) && CipherSuites.SuiteUsable(item, PrivateKey, _PSKIdentities, _ValidatePSK != null))
                    {
                        cipherSuite = item;
                        break;
                    }
                }
                TKeyExchangeAlgorithm keyExchangeAlgorithm = CipherSuites.GetKeyExchangeAlgorithm(cipherSuite);

                ServerHello serverHello = new ServerHello();
                // The 16-byte session GUID is written twice to fill the 32-byte SessionID field.
                byte[] clientSessionID = new byte[32];
                byte[] temp = session.SessionID.ToByteArray();
                Buffer.BlockCopy(temp, 0, clientSessionID, 0, 16);
                Buffer.BlockCopy(temp, 0, clientSessionID, 16, 16);
                serverHello.SessionID = clientSessionID;
                serverHello.Random = new RandomData();
                serverHello.Random.Generate();
                serverHello.CipherSuite = (ushort)cipherSuite;
                serverHello.ServerVersion = session.Version;

                // Defaults, overridden by the client's extensions below.
                THashAlgorithm hash = THashAlgorithm.SHA256;
                TEllipticCurve curve = TEllipticCurve.secp521r1;
                if (clientHello.Extensions != null)
                {
                    // Client/server certificate type extensions are currently ignored.
                    foreach (Extension extension in clientHello.Extensions)
                    {
                        if (extension.SpecifcExtension is EllipticCurvesExtension)
                        {
                            EllipticCurvesExtension ellipticCurves = extension.SpecifcExtension as EllipticCurvesExtension;
                            foreach (TEllipticCurve item in ellipticCurves.SupportedCurves)
                            {
                                if (EllipticCurveFactory.SupportedCurve(item))
                                {
                                    curve = item;
                                    break;
                                }
                            }
                        }
                        else if (extension.SpecifcExtension is SignatureAlgorithmsExtension)
                        {
                            SignatureAlgorithmsExtension signatureAlgorithms = extension.SpecifcExtension as SignatureAlgorithmsExtension;
                            foreach (SignatureHashAlgorithm item in signatureAlgorithms.SupportedAlgorithms)
                            {
                                if (item.Signature == TSignatureAlgorithm.ECDSA)
                                {
                                    hash = item.Hash;
                                    break;
                                }
                            }
                        }
                    }
                }
                session.Handshake.CipherSuite = cipherSuite;
                session.Handshake.ClientRandom = clientHello.Random;
                session.Handshake.ServerRandom = serverHello.Random;
                if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
                {
                    EllipticCurvePointFormatsExtension pointFormatsExtension = new EllipticCurvePointFormatsExtension();
                    pointFormatsExtension.SupportedPointFormats.Add(TEllipticCurvePointFormat.Uncompressed);
                    serverHello.AddExtension(pointFormatsExtension);
                }

                // Send the server flight: ServerHello, optional Certificate / ServerKeyExchange /
                // CertificateRequest, then ServerHelloDone.
                session.Handshake.MessageSequence = 1;
                SendResponse(session, serverHello, session.Handshake.MessageSequence);
                session.Handshake.MessageSequence++;
                if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
                {
                    if (Certificate != null)
                    {
                        SendResponse(session, Certificate, session.Handshake.MessageSequence);
                        session.Handshake.MessageSequence++;
                    }
                    ECDHEKeyExchange keyExchange = new ECDHEKeyExchange();
                    keyExchange.Curve = curve;
                    keyExchange.KeyExchangeAlgorithm = keyExchangeAlgorithm;
                    keyExchange.ClientRandom = clientHello.Random;
                    keyExchange.ServerRandom = serverHello.Random;
                    keyExchange.GenerateEphemeralKey();
                    session.Handshake.KeyExchange = keyExchange;
                    if (session.Version == DTLSRecord.DefaultVersion)
                    {
                        hash = THashAlgorithm.SHA1;
                    }
                    ECDHEServerKeyExchange serverKeyExchange = new ECDHEServerKeyExchange(keyExchange, hash, TSignatureAlgorithm.ECDSA, PrivateKey);
                    SendResponse(session, serverKeyExchange, session.Handshake.MessageSequence);
                    session.Handshake.MessageSequence++;
                    if (_RequireClientCertificate)
                    {
                        CertificateRequest certificateRequest = new CertificateRequest();
                        certificateRequest.CertificateTypes.Add(TClientCertificateType.ECDSASign);
                        certificateRequest.SupportedAlgorithms.Add(new SignatureHashAlgorithm() { Hash = THashAlgorithm.SHA256, Signature = TSignatureAlgorithm.ECDSA });
                        SendResponse(session, certificateRequest, session.Handshake.MessageSequence);
                        session.Handshake.MessageSequence++;
                    }
                }
                else if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_PSK)
                {
                    ECDHEKeyExchange keyExchange = new ECDHEKeyExchange();
                    keyExchange.Curve = curve;
                    keyExchange.KeyExchangeAlgorithm = keyExchangeAlgorithm;
                    keyExchange.ClientRandom = clientHello.Random;
                    keyExchange.ServerRandom = serverHello.Random;
                    keyExchange.GenerateEphemeralKey();
                    session.Handshake.KeyExchange = keyExchange;
                    ECDHEPSKServerKeyExchange serverKeyExchange = new ECDHEPSKServerKeyExchange(keyExchange);
                    SendResponse(session, serverKeyExchange, session.Handshake.MessageSequence);
                    session.Handshake.MessageSequence++;
                }
                else if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.PSK)
                {
                    PSKKeyExchange keyExchange = new PSKKeyExchange();
                    keyExchange.KeyExchangeAlgorithm = keyExchangeAlgorithm;
                    keyExchange.ClientRandom = clientHello.Random;
                    keyExchange.ServerRandom = serverHello.Random;
                    session.Handshake.KeyExchange = keyExchange;
                    // No ServerKeyExchange is sent for plain PSK: it would only carry an identity hint,
                    // which is not supported here.
                }
                SendResponse(session, new ServerHelloDone(), session.Handshake.MessageSequence);
                session.Handshake.MessageSequence++;
                break;
            }
            case THandshakeType.Certificate:
            {
                // A Certificate before any session exists for this peer is ignored instead of
                // dereferencing a null session.
                if (session == null)
                {
                    break;
                }
                Certificate clientCertificate = Certificate.Deserialise(stream, TCertificateType.X509);
                if ((clientCertificate != null) && (clientCertificate.CertChain.Count > 0))
                {
                    session.CertificateInfo = Certificates.GetCertificateInfo(clientCertificate.CertChain[0], TCertificateFormat.CER);
                }
                session.Handshake.UpdateHandshakeHash(data);
                break;
            }
            case THandshakeType.CertificateVerify:
            {
                if (session == null)
                {
                    break;
                }
                // NOTE(review): the parsed CertificateVerify is discarded — its signature is never checked
                // against the client certificate here, so client-certificate authentication is not
                // enforced by this method. Confirm it is verified elsewhere.
                CertificateVerify.Deserialise(stream, session.Version);
                session.Handshake.UpdateHandshakeHash(data);
                break;
            }
            case THandshakeType.ClientKeyExchange:
            {
                // Without a session or a pending key exchange (set while answering ClientHello above)
                // there is nothing to complete.
                if ((session == null) || (session.Handshake.KeyExchange == null))
                {
                    break;
                }
                session.Handshake.UpdateHandshakeHash(data);
                byte[] preMasterSecret = null;
                TKeyExchangeAlgorithm algorithm = session.Handshake.KeyExchange.KeyExchangeAlgorithm;
                if (algorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
                {
                    ECDHEClientKeyExchange clientKeyExchange = ECDHEClientKeyExchange.Deserialise(stream);
                    if (clientKeyExchange != null)
                    {
                        ECDHEKeyExchange ecKeyExchange = session.Handshake.KeyExchange as ECDHEKeyExchange;
                        preMasterSecret = ecKeyExchange.GetPreMasterSecret(clientKeyExchange.PublicKeyBytes);
                    }
                }
                else if (algorithm == TKeyExchangeAlgorithm.ECDHE_PSK)
                {
                    ECDHEPSKClientKeyExchange clientKeyExchange = ECDHEPSKClientKeyExchange.Deserialise(stream);
                    if (clientKeyExchange != null)
                    {
                        session.PSKIdentity = Encoding.UTF8.GetString(clientKeyExchange.PSKIdentity);
                        byte[] psk = _PSKIdentities.GetKey(clientKeyExchange.PSKIdentity);
                        // _ValidatePSK is optional; only consult it when it is configured.
                        if ((psk == null) && (_ValidatePSK != null))
                        {
                            psk = _ValidatePSK(clientKeyExchange.PSKIdentity);
                            if (psk != null)
                            {
                                _PSKIdentities.AddIdentity(clientKeyExchange.PSKIdentity, psk);
                            }
                        }
                        if (psk != null)
                        {
                            ECDHEKeyExchange ecKeyExchange = session.Handshake.KeyExchange as ECDHEKeyExchange;
                            byte[] otherSecret = ecKeyExchange.GetPreMasterSecret(clientKeyExchange.PublicKeyBytes);
                            preMasterSecret = TLSUtils.GetPSKPreMasterSecret(otherSecret, psk);
                        }
                    }
                }
                else if (algorithm == TKeyExchangeAlgorithm.PSK)
                {
                    PSKClientKeyExchange clientKeyExchange = PSKClientKeyExchange.Deserialise(stream);
                    if (clientKeyExchange != null)
                    {
                        session.PSKIdentity = Encoding.UTF8.GetString(clientKeyExchange.PSKIdentity);
                        byte[] psk = _PSKIdentities.GetKey(clientKeyExchange.PSKIdentity);
                        // _ValidatePSK is optional; only consult it when it is configured.
                        if ((psk == null) && (_ValidatePSK != null))
                        {
                            psk = _ValidatePSK(clientKeyExchange.PSKIdentity);
                            if (psk != null)
                            {
                                _PSKIdentities.AddIdentity(clientKeyExchange.PSKIdentity, psk);
                            }
                        }
                        if (psk != null)
                        {
                            // Plain PSK has no Diffie-Hellman share: the "other secret" is N zero bytes
                            // (RFC 4279 section 2).
                            byte[] otherSecret = new byte[psk.Length];
                            preMasterSecret = TLSUtils.GetPSKPreMasterSecret(otherSecret, psk);
                        }
                    }
                }
                if (preMasterSecret != null)
                {
                    session.Cipher = TLSUtils.AssignCipher(preMasterSecret, false, session.Version, session.Handshake);
                }
                break;
            }
            case THandshakeType.Finished:
            {
                Finished finished = Finished.Deserialise(stream);
                if ((session == null) || (finished == null))
                {
                    break;
                }
                // Verify the client's Finished over the transcript so far (excluding this message).
                byte[] handshakeHash = session.Handshake.GetHash();
                byte[] calculatedVerifyData = TLSUtils.GetVerifyData(session.Version, session.Handshake, false, true, handshakeHash);
#if DEBUG
                Console.Write("Handshake Hash:");
                TLSUtils.WriteToConsole(handshakeHash);
                Console.Write("Sent Verify:");
                TLSUtils.WriteToConsole(finished.VerifyData);
                Console.Write("Calc Verify:");
                TLSUtils.WriteToConsole(calculatedVerifyData);
#endif
                if (!TLSUtils.ByteArrayCompare(finished.VerifyData, calculatedVerifyData))
                {
                    throw new InvalidOperationException("Client Finished verify data does not match the computed handshake verify data.");
                }
                // Answer with ChangeCipherSpec and the server Finished, whose verify data also covers
                // the client's Finished message.
                SendChangeCipherSpec(session);
                session.Handshake.UpdateHandshakeHash(data);
                handshakeHash = session.Handshake.GetHash();
                Finished serverFinished = new Finished();
                serverFinished.VerifyData = TLSUtils.GetVerifyData(session.Version, session.Handshake, false, false, handshakeHash);
                SendResponse(session, serverFinished, session.Handshake.MessageSequence);
                session.Handshake.MessageSequence++;
                break;
            }
            case THandshakeType.HelloRequest:
            case THandshakeType.ServerHello:
            case THandshakeType.HelloVerifyRequest:
            case THandshakeType.ServerKeyExchange:
            case THandshakeType.CertificateRequest:
            case THandshakeType.ServerHelloDone:
            default:
                // Server-to-client message types, and unknown types, are ignored by this server.
                break;
        }
    }
}
/// <summary>
/// Processes one handshake record received from a client. When the session is already in an
/// encrypted epoch the fragment is decrypted first; the handshake header is then parsed and the
/// message is dispatched on its type (ClientHello, Certificate, CertificateVerify,
/// ClientKeyExchange, Finished). Messages that only a server sends are ignored, as are
/// records whose handshake header cannot be parsed.
/// </summary>
/// <param name="record">The received DTLS record; must not be null.</param>
/// <exception cref="ArgumentNullException"><paramref name="record"/> is null.</exception>
/// <exception cref="InvalidOperationException">
/// The record is encrypted but the session still has no cipher after polling 50 times at 10 ms intervals.
/// </exception>
/// <exception cref="System.Security.Cryptography.CryptographicException">
/// The verify data in the client's Finished message does not match the locally computed value.
/// </exception>
public void ProcessHandshake(DTLSRecord record)
{
    if (record == null)
    {
        throw new ArgumentNullException(nameof(record));
    }

    var address = record.RemoteEndPoint.Serialize();
    var session = this.Sessions.GetSession(address);
    var data = record.Fragment;
    if ((session != null) && session.IsEncypted(record))
    {
        // The cipher is polled here, so it is presumably assigned concurrently (ClientKeyExchange
        // processing sets session.Cipher); wait briefly for it before giving up.
        var count = 0;
        while ((session.Cipher == null) && (count < 50))
        {
            System.Threading.Thread.Sleep(10);
            count++;
        }

        if (session.Cipher == null)
        {
            throw new InvalidOperationException("Need Cipher for Encrypted Session");
        }

        // Epoch goes into the bits above the 48-bit record sequence number.
        var sequenceNumber = ((long)record.Epoch << 48) + record.SequenceNumber;
        data = session.Cipher.DecodeCiphertext(sequenceNumber, (byte)TRecordType.Handshake, record.Fragment, 0, record.Fragment.Length);
    }

    using (var stream = new MemoryStream(data))
    {
        var handshakeRecord = HandshakeRecord.Deserialise(stream);
        if (handshakeRecord == null)
        {
            // Unparseable handshake header: drop the record instead of dereferencing null.
            return;
        }

        switch (handshakeRecord.MessageType)
        {
            case THandshakeType.ClientHello:
                this.HandleClientHello(record, address, session, data, stream);
                break;

            case THandshakeType.Certificate:
            {
                if (session == null)
                {
                    break;
                }

                var clientCertificate = Certificate.Deserialise(stream, TCertificateType.X509);
                if ((clientCertificate != null) && (clientCertificate.CertChain.Count > 0))
                {
                    session.CertificateInfo = Certificates.GetCertificateInfo(clientCertificate.CertChain[0], TCertificateFormat.CER);
                }

                // The message is part of the transcript whether or not its chain could be parsed.
                session.Handshake.UpdateHandshakeHash(data);
                break;
            }

            case THandshakeType.CertificateVerify:
            {
                if (session == null)
                {
                    break;
                }

                // NOTE(review): the parsed CertificateVerify is discarded and its signature is not
                // checked against session.CertificateInfo here — confirm client authentication is
                // verified elsewhere, otherwise requiring a client certificate proves nothing.
                CertificateVerify.Deserialise(stream, session.Version);
                session.Handshake.UpdateHandshakeHash(data);
                break;
            }

            case THandshakeType.ClientKeyExchange:
                this.HandleClientKeyExchange(session, data, stream);
                break;

            case THandshakeType.Finished:
            {
                var finished = Finished.Deserialise(stream);
                if ((session == null) || (finished == null))
                {
                    break;
                }

                // The client's verify data covers the transcript up to, but excluding, its own Finished.
                var handshakeHash = session.Handshake.GetHash(session.Version);
                var calculatedVerifyData = TLSUtils.GetVerifyData(session.Version, session.Handshake, false, true, handshakeHash);
                if (!finished.VerifyData.SequenceEqual(calculatedVerifyData))
                {
                    throw new System.Security.Cryptography.CryptographicException("Client Finished verify data does not match the handshake transcript");
                }

                this.SendChangeCipherSpec(session);

                // The server's verify data additionally covers the client's Finished message.
                session.Handshake.UpdateHandshakeHash(data);
                handshakeHash = session.Handshake.GetHash(session.Version);
                var serverFinished = new Finished
                {
                    VerifyData = TLSUtils.GetVerifyData(session.Version, session.Handshake, false, false, handshakeHash)
                };
                this.SendResponse(session, serverFinished, session.Handshake.MessageSequence);
                session.Handshake.MessageSequence++;
                break;
            }

            // Server-to-client messages: nothing to do when received by the server.
            case THandshakeType.HelloRequest:
            case THandshakeType.ServerHello:
            case THandshakeType.HelloVerifyRequest:
            case THandshakeType.ServerKeyExchange:
            case THandshakeType.CertificateRequest:
            case THandshakeType.ServerHelloDone:
                break;

            default:
                break;
        }
    }
}

/// <summary>
/// Handles a ClientHello. Without a cookie, a session is created or reset for the endpoint
/// and a HelloVerifyRequest carrying the computed cookie is sent. With a matching cookie, the
/// version, cipher suite, curve and signature hash are negotiated and the server flight
/// (ServerHello ... ServerHelloDone) is sent. Unparseable hellos and mismatched cookies are dropped.
/// </summary>
private void HandleClientHello(DTLSRecord record, SocketAddress address, Session session, byte[] data, MemoryStream stream)
{
    var clientHello = ClientHello.Deserialise(stream);
    if (clientHello == null)
    {
        return;
    }

    var cookie = clientHello.CalculateCookie(record.RemoteEndPoint, this._HelloSecret);
    if (clientHello.Cookie == null)
    {
        var vers = clientHello.ClientVersion;
        if (this.ServerVersion < vers)
        {
            vers = this.ServerVersion;
        }

        if (session == null)
        {
            session = new Session { SessionID = Guid.NewGuid(), RemoteEndPoint = record.RemoteEndPoint, Version = vers };
            this.Sessions.AddSession(address, session);
        }
        else
        {
            session.Reset();
            session.Version = vers;
        }

        session.ClientEpoch = record.Epoch;
        session.ClientSequenceNumber = record.SequenceNumber;
        var helloVerifyRequest = new HelloVerifyRequest { Cookie = cookie, ServerVersion = this.ServerVersion };
        this.SendResponse(session, helloVerifyRequest, 0);
        return;
    }

    // An unencrypted ClientHello on a session that already has a cipher restarts the handshake.
    if (session != null && session.Cipher != null && !session.IsEncypted(record))
    {
        session.Reset();
    }

    // The cookie must match the one computed for this endpoint; otherwise drop silently.
    if (!clientHello.Cookie.SequenceEqual(cookie))
    {
        return;
    }

    var version = clientHello.ClientVersion;
    if (this.ServerVersion < version)
    {
        version = this.ServerVersion;
    }

    // NOTE(review): a session found via the client-supplied session id is not checked against
    // record.RemoteEndPoint, so a hello from one endpoint can overwrite the handshake state of a
    // session created for another endpoint — confirm this is intended.
    if ((clientHello.SessionID != null) && (clientHello.SessionID.Length >= 16))
    {
        session = this.Sessions.GetSession(new Guid(clientHello.SessionID.Take(16).ToArray()));
    }

    if (session == null)
    {
        session = new Session();
        session.NextSequenceNumber();
        session.SessionID = Guid.NewGuid();
        session.RemoteEndPoint = record.RemoteEndPoint;
        this.Sessions.AddSession(address, session);
    }

    session.Version = version;

    // First client-offered suite that the server supports, that is valid for the negotiated
    // version and that CipherSuites.SuiteUsable accepts for the configured key/PSK material.
    // NOTE(review): if none qualifies, TLS_NULL_WITH_NULL_NULL is sent instead of aborting.
    var cipherSuite = TCipherSuite.TLS_NULL_WITH_NULL_NULL;
    foreach (TCipherSuite item in clientHello.CipherSuites)
    {
        if (this._SupportedCipherSuites.ContainsKey(item)
            && CipherSuites.SupportedVersion(item, session.Version)
            && CipherSuites.SuiteUsable(item, this.PrivateKey, this._PSKIdentities, this._ValidatePSK != null))
        {
            cipherSuite = item;
            break;
        }
    }

    // The 16-byte session GUID is written twice to fill the 32-byte session id field.
    var clientSessionID = new byte[32];
    var temp = session.SessionID.ToByteArray();
    Buffer.BlockCopy(temp, 0, clientSessionID, 0, 16);
    Buffer.BlockCopy(temp, 0, clientSessionID, 16, 16);
    var serverHello = new ServerHello
    {
        SessionID = clientSessionID,
        Random = new RandomData(),
        CipherSuite = (ushort)cipherSuite,
        ServerVersion = session.Version
    };
    serverHello.Random.Generate();

    session.Handshake.UpdateHandshakeHash(data);
    session.Handshake.CipherSuite = cipherSuite;
    session.Handshake.ClientRandom = clientHello.Random;
    session.Handshake.ServerRandom = serverHello.Random;

    // Only the elliptic-curves and signature-algorithms extensions influence the reply;
    // certificate-type extensions are currently not acted upon.
    var hash = THashAlgorithm.SHA256;
    var curve = TEllipticCurve.secp521r1;
    if (clientHello.Extensions != null)
    {
        foreach (var extension in clientHello.Extensions)
        {
            if (extension.SpecificExtension is EllipticCurvesExtension)
            {
                var ellipticCurves = extension.SpecificExtension as EllipticCurvesExtension;
                foreach (var item in ellipticCurves.SupportedCurves)
                {
                    if (EllipticCurveFactory.SupportedCurve(item))
                    {
                        curve = item;
                        break;
                    }
                }
            }
            else if (extension.SpecificExtension is SignatureAlgorithmsExtension)
            {
                var signatureAlgorithms = extension.SpecificExtension as SignatureAlgorithmsExtension;
                foreach (var item in signatureAlgorithms.SupportedAlgorithms)
                {
                    if (item.Signature == TSignatureAlgorithm.ECDSA)
                    {
                        hash = item.Hash;
                        break;
                    }
                }
            }
        }
    }

    var keyExchangeAlgorithm = CipherSuites.GetKeyExchangeAlgorithm(cipherSuite);
    if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
    {
        var pointFormatsExtension = new EllipticCurvePointFormatsExtension();
        pointFormatsExtension.SupportedPointFormats.Add(TEllipticCurvePointFormat.Uncompressed);
        serverHello.AddExtension(pointFormatsExtension);
    }

    session.Handshake.MessageSequence = 1;
    this.SendResponse(session, serverHello, session.Handshake.MessageSequence);
    session.Handshake.MessageSequence++;

    if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
    {
        if (this.Certificate != null)
        {
            this.SendResponse(session, this.Certificate, session.Handshake.MessageSequence);
            session.Handshake.MessageSequence++;
        }

        var keyExchange = new ECDHEKeyExchange
        {
            Curve = curve,
            KeyExchangeAlgorithm = keyExchangeAlgorithm,
            ClientRandom = clientHello.Random,
            ServerRandom = serverHello.Random
        };
        keyExchange.GenerateEphemeralKey();
        session.Handshake.KeyExchange = keyExchange;

        // For DTLSRecord.DefaultVersion the signature hash is forced to SHA-1,
        // regardless of what the client's signature_algorithms extension offered.
        if (session.Version == DTLSRecord.DefaultVersion)
        {
            hash = THashAlgorithm.SHA1;
        }

        var serverKeyExchange = new ECDHEServerKeyExchange(keyExchange, hash, TSignatureAlgorithm.ECDSA, this.PrivateKey);
        this.SendResponse(session, serverKeyExchange, session.Handshake.MessageSequence);
        session.Handshake.MessageSequence++;

        if (this._RequireClientCertificate)
        {
            var certificateRequest = new CertificateRequest();
            certificateRequest.CertificateTypes.Add(TClientCertificateType.ECDSASign);
            certificateRequest.SupportedAlgorithms.Add(new SignatureHashAlgorithm() { Hash = THashAlgorithm.SHA256, Signature = TSignatureAlgorithm.ECDSA });
            this.SendResponse(session, certificateRequest, session.Handshake.MessageSequence);
            session.Handshake.MessageSequence++;
        }
    }
    else if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.ECDHE_PSK)
    {
        var keyExchange = new ECDHEKeyExchange
        {
            Curve = curve,
            KeyExchangeAlgorithm = keyExchangeAlgorithm,
            ClientRandom = clientHello.Random,
            ServerRandom = serverHello.Random
        };
        keyExchange.GenerateEphemeralKey();
        session.Handshake.KeyExchange = keyExchange;
        var serverKeyExchange = new ECDHEPSKServerKeyExchange(keyExchange);
        this.SendResponse(session, serverKeyExchange, session.Handshake.MessageSequence);
        session.Handshake.MessageSequence++;
    }
    else if (keyExchangeAlgorithm == TKeyExchangeAlgorithm.PSK)
    {
        // Plain PSK: no ServerKeyExchange is sent because no identity hint is offered.
        session.Handshake.KeyExchange = new PSKKeyExchange
        {
            KeyExchangeAlgorithm = keyExchangeAlgorithm,
            ClientRandom = clientHello.Random,
            ServerRandom = serverHello.Random
        };
    }

    this.SendResponse(session, new ServerHelloDone(), session.Handshake.MessageSequence);
    session.Handshake.MessageSequence++;
}

/// <summary>
/// Handles a ClientKeyExchange. It derives the pre-master secret for the key exchange chosen
/// while processing the ClientHello and assigns the session cipher. The message is ignored when
/// there is no session or no pending key exchange. No cipher is assigned when the message cannot
/// be parsed or the PSK identity is unknown.
/// </summary>
private void HandleClientKeyExchange(Session session, byte[] data, MemoryStream stream)
{
    if ((session == null) || (session.Handshake.KeyExchange == null))
    {
        return;
    }

    session.Handshake.UpdateHandshakeHash(data);
    byte[] preMasterSecret = null;
    var algorithm = session.Handshake.KeyExchange.KeyExchangeAlgorithm;
    if (algorithm == TKeyExchangeAlgorithm.ECDHE_ECDSA)
    {
        var clientKeyExchange = ECDHEClientKeyExchange.Deserialise(stream);
        if (clientKeyExchange != null)
        {
            var ecKeyExchange = session.Handshake.KeyExchange as ECDHEKeyExchange;
            preMasterSecret = ecKeyExchange.GetPreMasterSecret(clientKeyExchange.PublicKeyBytes);
        }
    }
    else if (algorithm == TKeyExchangeAlgorithm.ECDHE_PSK)
    {
        var clientKeyExchange = ECDHEPSKClientKeyExchange.Deserialise(stream);
        if (clientKeyExchange != null)
        {
            session.PSKIdentity = Encoding.UTF8.GetString(clientKeyExchange.PSKIdentity);
            var psk = this.LookupPreSharedKey(clientKeyExchange.PSKIdentity);
            if (psk != null)
            {
                var ecKeyExchange = session.Handshake.KeyExchange as ECDHEKeyExchange;
                var otherSecret = ecKeyExchange.GetPreMasterSecret(clientKeyExchange.PublicKeyBytes);
                preMasterSecret = TLSUtils.GetPSKPreMasterSecret(otherSecret, psk);
            }
        }
    }
    else if (algorithm == TKeyExchangeAlgorithm.PSK)
    {
        var clientKeyExchange = PSKClientKeyExchange.Deserialise(stream);
        if (clientKeyExchange != null)
        {
            session.PSKIdentity = Encoding.UTF8.GetString(clientKeyExchange.PSKIdentity);
            var psk = this.LookupPreSharedKey(clientKeyExchange.PSKIdentity);
            if (psk != null)
            {
                // Plain PSK (RFC 4279): the "other secret" is psk.Length zero bytes.
                preMasterSecret = TLSUtils.GetPSKPreMasterSecret(new byte[psk.Length], psk);
            }
        }
    }

    if (preMasterSecret != null)
    {
        session.Cipher = TLSUtils.AssignCipher(preMasterSecret, false, session.Version, session.Handshake);
    }
}

/// <summary>
/// Returns the pre-shared key for <paramref name="identity"/>. The configured identity store is
/// tried first. Otherwise the optional validation callback is asked, and a non-null key it returns
/// is cached in the store. Returns null when neither source yields a key, including when no
/// callback is configured.
/// </summary>
private byte[] LookupPreSharedKey(byte[] identity)
{
    var psk = this._PSKIdentities.GetKey(identity);
    if ((psk == null) && (this._ValidatePSK != null))
    {
        psk = this._ValidatePSK(identity);
        if (psk != null)
        {
            this._PSKIdentities.AddIdentity(identity, psk);
        }
    }

    return psk;
}