Example #1
0
        /// <summary>
        /// Attempts to connect to every endpoint in parallel and returns the first connection that succeeds.
        /// Once one attempt succeeds, the remaining attempts are cancelled.
        /// </summary>
        /// <param name="iteration">Attempt number, forwarded to each per-endpoint attempt for tracing.</param>
        /// <param name="endpoints">Endpoints to try.</param>
        /// <param name="timeout">Overall time allowed for the attempts (also passed to each attempt).</param>
        /// <returns>A connected <see cref="TcpClient"/>; ownership passes to the caller.</returns>
        /// <exception cref="SecureTransportException">No endpoint could be connected to.</exception>
        private async Task<TcpClient> EstablishConnection(int iteration, IPEndPoint[] endpoints, TimeSpan timeout)
        {
            using (var cancelEstablishConnection = CancellationTokenSource.CreateLinkedTokenSource(this.cancellationTokenSource.Token))
            {
                cancelEstablishConnection.CancelAfter(timeout);

                var connectionTasks = new List<Task<TcpClient>>(endpoints.Length);
                foreach (IPEndPoint endpoint in endpoints)
                {
                    connectionTasks.Add(this.EstablishConnection(iteration, endpoint, timeout, cancelEstablishConnection.Token));
                }

                // Wait for every attempt to settle (not just the first winner) so the linked token source
                // is not disposed while attempts still use it, and so late winners can be cleaned up.
                TcpClient successfulClient = null;
                while (connectionTasks.Count > 0)
                {
                    Task<TcpClient> resolvedTask = await Task.WhenAny(connectionTasks);
                    connectionTasks.Remove(resolvedTask);

                    if (resolvedTask.Status != TaskStatus.RanToCompletion)
                    {
                        // Read the exception so a failed attempt is not later reported as an unobserved task exception.
                        var ignored = resolvedTask.Exception;
                        continue;
                    }

                    if (successfulClient == null)
                    {
                        successfulClient = resolvedTask.Result;
                        cancelEstablishConnection.Cancel();
                    }
                    else
                    {
                        // Another attempt connected before it observed the cancellation. Close the extra
                        // connection instead of overwriting (and leaking) the first winner.
                        resolvedTask.Result.Close();
                    }
                }

                if (successfulClient != null)
                {
                    return successfulClient;
                }

                throw SecureTransportException.ConnectionFailed();
            }
        }
Example #2
0
        /// <summary>
        /// Send a packet to the other side of the connection asynchronously.
        /// </summary>
        /// <param name="data">Data to send</param>
        /// <returns>A <see cref="Task"/> that tracks completion of the send operation</returns>
        /// <summary>
        /// Queues <paramref name="data"/> to be sent to the other side of the connection.
        /// </summary>
        /// <param name="data">Payload to send; must not be null.</param>
        /// <returns>
        /// The task of the queued packet's completion source. If the connection is disposed or its
        /// outgoing queue no longer accepts packets, nothing is queued and an already-completed task is returned.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="data"/> is null.</exception>
        /// <exception cref="SecureTransportException">The outgoing queue rejected the packet.</exception>
        public Task SendAsync(byte[] data)
        {
            if (data == null)
            {
                throw new ArgumentNullException(nameof(data));
            }

            bool queueClosed = this.disposed || this.outgoingPackets.IsAddingCompleted;
            if (queueClosed)
            {
                // Nothing can be queued any more: drop the packet and report completion.
                return Task.FromResult<object>(null);
            }

            Packet outgoing = default(Packet);
            outgoing.Id = Interlocked.Increment(ref this.nextPacketId);
            outgoing.Data = data;
            outgoing.CompletionSource = new TaskCompletionSource<object>();

            SecureTransportEventSource.Log.Send(this.transportId, this.connectionId, outgoing.Id, outgoing.Data.Length);

            bool queued = this.outgoingPackets.TryAdd(outgoing);
            if (!queued)
            {
                SecureTransportEventSource.Log.SendQueueFull(this.transportId, this.connectionId, outgoing.Id, outgoing.Data.Length);
                throw SecureTransportException.SendQueueFull();
            }

            // Signal that a packet is available. This wakes the PushPackets task, which performs
            // the actual send to the other side.
            this.outgoingPacketsAvailable.Release();

            return outgoing.CompletionSource.Task;
        }
Example #3
0
        /// <summary>
        /// Stops the server with a specified timeout
        /// </summary>
        /// <param name="timeout">Timeout waiting for server stop</param>
        /// <summary>
        /// Stops the running transport (started via StartServer or StartClient) and waits up to
        /// <paramref name="timeout"/> for its background task to signal that it has stopped.
        /// </summary>
        /// <param name="timeout">How long to wait for the stop to complete.</param>
        /// <exception cref="SecureTransportException">
        /// The transport was not started, or it did not stop within <paramref name="timeout"/>.
        /// </exception>
        public void Stop(TimeSpan timeout)
        {
            bool wasStarted = this.cancellationTokenSource != null;
            if (!wasStarted)
            {
                SecureTransportEventSource.Log.StopFailed_NotStarted(this.transportId);
                throw SecureTransportException.NotStarted();
            }

            this.StopAndCloseConnections();

            bool stoppedInTime = this.hasStopped.Wait(timeout);
            if (!stoppedInTime)
            {
                SecureTransportEventSource.Log.StopTimedout(this.transportId);
                throw SecureTransportException.StopTimedout();
            }

            // Clearing the source marks the transport as not started, so a later start call is accepted.
            this.cancellationTokenSource = null;
            SecureTransportEventSource.Log.Stopped(this.transportId);
        }
Example #4
0
        /// <summary>
        /// Makes a single TCP connection attempt to <paramref name="endpoint"/>.
        /// </summary>
        /// <param name="iteration">Attempt number, used for tracing only.</param>
        /// <param name="endpoint">Endpoint to connect to.</param>
        /// <param name="timeout">Maximum time to wait for the connection.</param>
        /// <param name="cancellationToken">Token that abandons the attempt early.</param>
        /// <returns>The connected client; ownership passes to the caller.</returns>
        /// <exception cref="SecureTransportException">
        /// The attempt timed out or was cancelled before the connection completed.
        /// </exception>
        /// <remarks>On any failure the client is closed before the exception propagates.</remarks>
        private async Task<TcpClient> EstablishConnection(int iteration, IPEndPoint endpoint, TimeSpan timeout, CancellationToken cancellationToken)
        {
            TcpClient client = new TcpClient();

            // Disable the Nagle algorithm. When NoDelay is set to true, TcpClient does not wait
            // until it has collected a significant amount of outgoing data before sending a packet.
            // This ensures that requests are sent out to the server immediately and helps reduce latency.
            client.NoDelay = true;
            var timer = Stopwatch.StartNew();

            try
            {
                SecureTransportEventSource.Log.EstablishConnection(this.transportId, iteration, endpoint.ToString(), (long)timeout.TotalMilliseconds);
                Task connectTask = client.ConnectAsync(endpoint.Address, endpoint.Port);

                using (var delayCancellation = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken))
                {
                    await Task.WhenAny(connectTask, Task.Delay(timeout, delayCancellation.Token));

                    // Release the timeout timer (and its registration on cancellationToken) as soon as the
                    // race is decided, instead of keeping it alive until the timeout would have elapsed.
                    delayCancellation.Cancel();
                }

                if (connectTask.IsCompleted)
                {
                    // Rethrows if the connect faulted; the catch below logs it and the finally closes the client.
                    await connectTask;
                    TcpClient connectedClient = client;

                    // Hand ownership to the caller so the finally block does not close it.
                    client = null;
                    SecureTransportEventSource.Log.ConnectSucceeded(this.transportId, iteration, endpoint.Address.ToString(), endpoint.Port, timer.ElapsedMilliseconds);
                    return connectedClient;
                }

                // The pending connect is abandoned; closing the client in the finally block can fault it.
                // Observe that fault so it is not reported later as an unobserved task exception.
                var ignoredContinuation = connectTask.ContinueWith(
                    t => { var ignored = t.Exception; },
                    CancellationToken.None,
                    TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously,
                    TaskScheduler.Default);

                throw SecureTransportException.CancellationRequested($"Connection attempt to {endpoint} was cancelled");
            }
            catch (Exception ex)
            {
                SecureTransportEventSource.Log.ConnectFailed(this.transportId, iteration, endpoint.Address.ToString(), endpoint.Port, ex.Message, timer.ElapsedMilliseconds);
                throw;
            }
            finally
            {
                client?.Close();
            }
        }
Example #5
0
        /// <summary>
        /// Starts server with a specified client SSL validation timeout
        /// </summary>
        /// <param name="endpoint">local endpoint start server</param>
        /// <param name="validationTimeout">Client SSL validation timeout</param>
        /// <returns>Server task</returns>
        /// <summary>
        /// Starts the server with a specified client SSL validation timeout.
        /// </summary>
        /// <param name="endpoint">Local endpoint the server listens on; must not be null.</param>
        /// <param name="validationTimeout">Client SSL validation timeout.</param>
        /// <returns>The server's listening task.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="endpoint"/> is null.</exception>
        /// <exception cref="SecureTransportException">
        /// The transport is already started, or the listening task did not signal start within the default start timeout.
        /// </exception>
        public Task StartServer(IPEndPoint endpoint, TimeSpan validationTimeout)
        {
            // Validate before touching any state: failing after cancellationTokenSource is assigned
            // would leave the transport permanently marked as started.
            if (endpoint == null)
            {
                throw new ArgumentNullException(nameof(endpoint));
            }

            if (this.cancellationTokenSource != null)
            {
                SecureTransportEventSource.Log.StartServerFailed_AlreadyStarted(this.transportId);
                throw SecureTransportException.AlreadyStarted();
            }

            this.cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(this.rootCancellationToken);
            this.hasStarted.Reset();
            this.hasStopped.Reset();
            SecureTransportEventSource.Log.StartServer(this.transportId, endpoint.ToString());
            var task = Task.Run(() => this.StartListening(endpoint, validationTimeout), this.cancellationTokenSource.Token);

            if (!this.hasStarted.Wait(DefaultStartTimeout))
            {
                SecureTransportEventSource.Log.StartTimedout(this.transportId);
                throw SecureTransportException.StartTimedout();
            }

            return task;
        }
Example #6
0
        /// <summary>
        /// Starts client with a specified server SSL validation timeout
        /// </summary>
        /// <param name="validationTimeout">Connection SSL validation timeout</param>
        /// <param name="endpoints">Endpoints to connect to</param>
        /// <returns>Client task</returns>
        /// <summary>
        /// Starts the client with a specified server SSL validation timeout.
        /// </summary>
        /// <param name="validationTimeout">Connection SSL validation timeout.</param>
        /// <param name="endpoints">Endpoints to connect to; must not be null.</param>
        /// <returns>The client's connecting task.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="endpoints"/> is null.</exception>
        /// <exception cref="SecureTransportException">
        /// The transport is already started, or the connecting task did not signal start within the default start timeout.
        /// </exception>
        public Task StartClient(TimeSpan validationTimeout, params IPEndPoint[] endpoints)
        {
            // Validate before touching any state: failing after cancellationTokenSource is assigned
            // would leave the transport permanently marked as started.
            if (endpoints == null)
            {
                throw new ArgumentNullException(nameof(endpoints));
            }

            if (this.cancellationTokenSource != null)
            {
                SecureTransportEventSource.Log.StartClientFailed_AlreadyStarted(this.transportId);
                throw SecureTransportException.AlreadyStarted();
            }

            this.cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(this.rootCancellationToken);
            this.hasStarted.Reset();
            this.hasStopped.Reset();
            SecureTransportEventSource.Log.StartClient(this.transportId, endpoints.Length);
            var task = Task.Run(() => this.StartConnecting(endpoints, validationTimeout), this.cancellationTokenSource.Token);

            if (!this.hasStarted.Wait(DefaultStartTimeout))
            {
                SecureTransportEventSource.Log.StartTimedout(this.transportId);
                throw SecureTransportException.StartTimedout();
            }

            return task;
        }
Example #7
0
        /// <summary>
        /// Gets the validated stream on client.
        /// </summary>
        /// <param name="serverName">Name of the server.</param>
        /// <param name="client">TCP client</param>
        /// <param name="timeout">Time to wait for the validation</param>
        /// <param name="cancellationToken">Token to be observed for cancellation signal</param>
        /// <returns>A Task that resolves to the validated stream</returns>
        /// <summary>
        /// Wraps the client's network stream in an <see cref="SslStream"/> and authenticates it as the
        /// client side of the TLS handshake.
        /// </summary>
        /// <param name="serverName">Name of the server, passed to <see cref="SslStream.AuthenticateAsClientAsync(string, X509CertificateCollection, System.Security.Authentication.SslProtocols, bool)"/>.</param>
        /// <param name="client">Connected TCP client whose stream is wrapped.</param>
        /// <param name="timeout">Time to wait for the validation.</param>
        /// <param name="cancellationToken">Token to be observed for cancellation signal.</param>
        /// <returns>A task that resolves to the validated stream.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="client"/> is null.</exception>
        /// <exception cref="SecureTransportException">
        /// The handshake did not complete within <paramref name="timeout"/> or before cancellation
        /// (both are reported as an SSL validation timeout).
        /// </exception>
        /// <remarks>
        /// On failure the SSL stream (or, if it was never created, the raw network stream) is disposed
        /// before the exception is rethrown.
        /// </remarks>
        public virtual async Task<Stream> AuthenticateAsClient(string serverName, TcpClient client, TimeSpan timeout, CancellationToken cancellationToken)
        {
            if (client == null)
            {
                throw new ArgumentNullException(nameof(client));
            }

            Stream clientStream = null;
            SslStream sslStream = null;

            try
            {
                clientStream = client.GetStream();

                sslStream = new SslStream(
                    clientStream,
                    false,
                    this.configuration.RemoteCertificateValidationCallback ?? this.certificateValidator.ValidateServerCertificate,
                    this.configuration.LocalCertificateSelectionCallback);

                X509CertificateCollection coll = new X509CertificateCollection(this.identities.ClientIdentities);

                SecureTransportEventSource.Log.AuthenticateAsClient(this.transportId, (int)timeout.TotalMilliseconds, this.configuration.MustCheckCertificateRevocation, this.configuration.MustCheckCertificateTrustChain);
                Task task = sslStream.AuthenticateAsClientAsync(
                    serverName,
                    coll,
                    this.configuration.SupportedProtocols,
                    this.configuration.MustCheckCertificateRevocation);

                using (var delayCancellation = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken))
                {
                    await Task.WhenAny(task, Task.Delay(timeout, delayCancellation.Token));

                    // Stop the timeout timer right away. Otherwise it, and its registration on
                    // cancellationToken, would stay alive for the rest of the timeout period.
                    delayCancellation.Cancel();
                }

                if (!task.IsCompleted)
                {
                    // The handshake is abandoned, and disposing sslStream in the catch below can fault it.
                    // Observe that fault so it is not reported later as an unobserved task exception.
                    var ignoredContinuation = task.ContinueWith(
                        t => { var ignored = t.Exception; },
                        CancellationToken.None,
                        TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously,
                        TaskScheduler.Default);

                    throw SecureTransportException.SslValidationTimedOut();
                }

                // The task has already completed; awaiting it rethrows if the handshake failed.
                await task;

                return sslStream;
            }
            catch (Exception ex)
            {
                SecureTransportEventSource.Log.AuthenticateAsClientFailed(this.transportId, ex.ToString());

                if (sslStream != null)
                {
                    // leaveInnerStreamOpen is false, so this also disposes clientStream.
                    sslStream.Dispose();
                }
                else
                {
                    clientStream?.Dispose();
                }

                throw;
            }
        }
Example #8
0
        /// <summary>
        /// Gets the validated stream on server.
        /// </summary>
        /// <param name="client">TCP client</param>
        /// <param name="timeout">Time to wait for the validation</param>
        /// <param name="cancellationToken">Token to be observed for cancellation signal</param>
        /// <returns>A Task that resolves to the validated stream</returns>
        /// <summary>
        /// Wraps the client's network stream in an <see cref="SslStream"/> and authenticates it as the
        /// server side of the TLS handshake.
        /// </summary>
        /// <param name="client">Accepted TCP client whose stream is wrapped.</param>
        /// <param name="timeout">Time to wait for the validation.</param>
        /// <param name="cancellationToken">Token to be observed for cancellation signal.</param>
        /// <returns>A task that resolves to the validated stream.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="client"/> is null.</exception>
        /// <exception cref="SecureTransportException">
        /// No server certificate is configured, or the handshake did not complete within
        /// <paramref name="timeout"/> or before cancellation (reported as an SSL validation timeout).
        /// </exception>
        /// <remarks>
        /// On failure the SSL stream is disposed. If it was never created, the raw network stream is
        /// disposed and the client is closed. In both cases the exception is then rethrown.
        /// </remarks>
        public virtual async Task<Stream> AuthenticateAsServer(TcpClient client, TimeSpan timeout, CancellationToken cancellationToken)
        {
            if (client == null)
            {
                throw new ArgumentNullException(nameof(client));
            }

            SslStream sslStream = null;
            Stream clientStream = null;

            try
            {
                // Obtained inside the try (as in AuthenticateAsClient) so that a failure here is logged
                // and the client is cleaned up like every other failure.
                clientStream = client.GetStream();

                // ServerIdentity is presumably the first configured server certificate; only that one is presented.
                X509Certificate serverCertificate = this.identities.ServerIdentity;
                if (serverCertificate == null)
                {
                    throw SecureTransportException.NoServerCertificate();
                }

                SecureTransportEventSource.Log.AuthenticateAsServer(this.transportId, (int)timeout.TotalMilliseconds, this.configuration.MustCheckCertificateRevocation, this.configuration.MustCheckCertificateTrustChain);
                sslStream = new SslStream(
                    clientStream,
                    false,
                    this.configuration.RemoteCertificateValidationCallback ?? this.certificateValidator.ValidateClientCertificate,
                    this.configuration.LocalCertificateSelectionCallback);

                Task task = sslStream.AuthenticateAsServerAsync(
                    serverCertificate,
                    clientCertificateRequired: this.configuration.IsClientCertificateRequired,
                    enabledSslProtocols: this.configuration.SupportedProtocols,
                    checkCertificateRevocation: this.configuration.MustCheckCertificateRevocation);

                using (var delayCancellation = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken))
                {
                    await Task.WhenAny(task, Task.Delay(timeout, delayCancellation.Token));

                    // Stop the timeout timer right away. Otherwise it, and its registration on
                    // cancellationToken, would stay alive for the rest of the timeout period.
                    delayCancellation.Cancel();
                }

                if (!task.IsCompleted)
                {
                    // The handshake is abandoned, and disposing sslStream in the catch below can fault it.
                    // Observe that fault so it is not reported later as an unobserved task exception.
                    var ignoredContinuation = task.ContinueWith(
                        t => { var ignored = t.Exception; },
                        CancellationToken.None,
                        TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously,
                        TaskScheduler.Default);

                    throw SecureTransportException.SslValidationTimedOut();
                }

                // The task has already completed; awaiting it rethrows if the handshake failed.
                await task;

                return sslStream;
            }
            catch (Exception ex)
            {
                SecureTransportEventSource.Log.AuthenticateAsServerFailed(this.transportId, ex.ToString());

                if (sslStream != null)
                {
                    // leaveInnerStreamOpen is false, so this also disposes clientStream.
                    sslStream.Dispose();
                }
                else
                {
                    clientStream?.Dispose();
                    client.Close();
                }

                throw;
            }
        }
Example #9
0
        /// <summary>
        /// Runs one established connection. It registers the connection in <c>activeConnections</c>,
        /// starts it, raises the new-connection callback, awaits <c>PullPackets</c>, raises the
        /// connection-lost callback, and finally unregisters and disposes the connection.
        /// </summary>
        /// <param name="client">Underlying TCP client; passed to the connection and used to report the remote endpoint.</param>
        /// <param name="secureStream">Authenticated stream handed to the <c>Connection</c>.</param>
        /// <param name="remoteIdentity">Identity of the remote party, stored in the connection configuration and reported to instrumentation.</param>
        /// <returns>A task that completes after the connection has been unregistered and disposed.</returns>
        /// <remarks>
        /// Exceptions are logged via HandleConnectionFailed and rethrown. The connection is disposed by the
        /// <c>using</c> statement after the finally block has removed it from <c>activeConnections</c>.
        /// </remarks>
        private async Task HandleConnection(TcpClient client, Stream secureStream, string remoteIdentity)
        {
            string remoteHostAddress = GetRemoteHostAddress(client);
            long   connectionId      = Interlocked.Increment(ref this.lastAssignedConnectionId);

            // Per-connection settings copied from the transport's configuration.
            var    configuration     = new Connection.Configuration
            {
                RemoteIdentity           = remoteIdentity,
                MaxLifeSpan              = this.MaxConnectionLifeSpan,
                MaxConnectionIdleTime    = this.configuration.MaxConnectionIdleTime,
                SendBufferSize           = this.configuration.SendBufferSize,
                ReceiveBufferSize        = this.configuration.ReceiveBufferSize,
                SendQueueLength          = this.configuration.SendQueueLength,
                MaxUnflushedPacketsCount = this.configuration.MaxUnflushedPacketsCount,
            };

            using (var connection = new Connection(this.transportId, connectionId, client, secureStream, configuration, this.cancellationTokenSource.Token))
            {
                // connectionId comes from Interlocked.Increment, so a duplicate key indicates a logic error.
                if (!this.activeConnections.TryAdd(connectionId, connection))
                {
                    Debug.Assert(false, "Failed to add connection to the active connections dictionary");
                    throw SecureTransportException.Unexpected("Failed to add connection to the active connections dictionary");
                }

                try
                {
                    // NOTE(review): the callback members are read twice (null check, then use); if they can be
                    // changed concurrently, copy each to a local first.
                    if (this.OnProtocolNegotiation != null)
                    {
                        connection.DoProtocolNegotiation = this.OnProtocolNegotiation;
                    }

                    connection.UseNetworkByteOrder = this.UseNetworkByteOrder;

                    await connection.Start(this.configuration.CommunicationProtocolVersion);

                    this.instrumentation.ConnectionCreated(connectionId, GetRemoteEndPoint(client), remoteIdentity);

                    // The callback runs synchronously; its duration is logged.
                    if (this.OnNewConnection != null)
                    {
                        var timer = Stopwatch.StartNew();
                        this.OnNewConnection(connection);
                        SecureTransportEventSource.Log.OnNewConnection(this.transportId, connectionId, timer.ElapsedMilliseconds);
                    }

                    // NOTE(review): presumably the receive loop, returning once the connection closes — confirm.
                    await connection.PullPackets();

                    // Only reached when PullPackets completes without throwing: a connection that fails
                    // with an exception does not raise OnConnectionLost.
                    if (this.OnConnectionLost != null)
                    {
                        var timer = Stopwatch.StartNew();
                        this.OnConnectionLost();
                        SecureTransportEventSource.Log.OnConnectionLost(this.transportId, connectionId, timer.ElapsedMilliseconds);
                    }
                }
                catch (Exception ex)
                {
                    SecureTransportEventSource.Log.HandleConnectionFailed(this.transportId, connectionId, ex.ToString());
                    throw;
                }
                finally
                {
                    // NOTE(review): reported even when Start threw before ConnectionCreated was reported.
                    this.instrumentation.ConnectionClosed(connectionId, GetRemoteEndPoint(client), remoteIdentity);

                    // NOTE(review): throwing here replaces any exception already propagating from the try block.
                    Connection removedConnection;
                    if (!this.activeConnections.TryRemove(connectionId, out removedConnection))
                    {
                        Debug.Assert(false, "Failed to remove connection from the active connections dictionary");
                        throw SecureTransportException.Unexpected("Failed to remove connection from the active connections dictionary");
                    }
                }
            }
        }
Example #10
0
        /// <summary>
        /// Server accept loop. It listens on <paramref name="localEndpoint"/> and hands each accepted client
        /// to <c>AcceptConnection</c>. The loop ends when cancellation is requested or when
        /// <c>ConsecutiveAcceptFailuresLimit</c> consecutive iterations fail.
        /// </summary>
        /// <param name="localEndpoint">Local endpoint to listen on.</param>
        /// <param name="validationTimeout">
        /// Used both as the wait limit for a free slot in <c>acceptConnectionsSemaphore</c> and as the
        /// timeout passed to <c>AcceptConnection</c>.
        /// </param>
        /// <remarks>
        /// <c>hasStarted</c> is signalled before the listener is created, so a caller waiting on it is
        /// released even if binding the listener fails. <c>hasStopped</c> is always signalled on exit.
        /// </remarks>
        private async Task StartListening(IPEndPoint localEndpoint, TimeSpan validationTimeout)
        {
            this.hasStarted.Set();
            try
            {
                this.listener = new TcpListener(localEndpoint);
                this.listener.Start();

                int iteration = 0;
                int consecutiveFailureCount = 0;
                while ((!this.cancellationTokenSource.Token.IsCancellationRequested) && (consecutiveFailureCount < ConsecutiveAcceptFailuresLimit))
                {
                    // True while this iteration holds a semaphore slot that has not yet been handed
                    // off to the AcceptConnection continuation.
                    bool mustReleaseSemaphore = false;
                    try
                    {
                        // The semaphore bounds concurrent AcceptConnection calls. A timed-out wait is
                        // thrown as an exception and counts as a failed iteration.
                        if (!await this.acceptConnectionsSemaphore.WaitAsync(validationTimeout, this.cancellationTokenSource.Token))
                        {
                            throw SecureTransportException.AcceptConnectionTimedout();
                        }

                        mustReleaseSemaphore = true;

                        iteration++;
                        Task <TcpClient> acceptTask = this.listener.AcceptTcpClientAsync();

                        using (var cancelSource = CancellationTokenSource.CreateLinkedTokenSource(this.cancellationTokenSource.Token))
                        {
                            // AcceptTcpClientAsync does not have an overload that takes a cancellationToken, so it
                            // does not return even if cancellation is in progress.  The following task will observe
                            // the cancellation token and go into the canceled state when cancellation is requested.
                            Task delayTask = Task.Delay(Timeout.InfiniteTimeSpan, cancelSource.Token);

                            await Task.WhenAny(acceptTask, delayTask);

                            // If cancellation wins the race, acceptTask is left pending and is not awaited here.
                            if (acceptTask.IsCompleted)
                            {
                                // Cancel the infinite delay so its registration is released.
                                cancelSource.Cancel();
                                TcpClient client = await acceptTask;
                                consecutiveFailureCount = 0;

                                // Ownership of the semaphore slot moves to the continuation, which releases it
                                // when AcceptConnection finishes. The task is not awaited, so the loop can accept
                                // the next client immediately.
                                // NOTE(review): the continuation does not observe faults from AcceptConnection.
                                mustReleaseSemaphore = false;
                                var ignoredTask = this.AcceptConnection(iteration, client, validationTimeout)
                                                  .ContinueWith(t => this.acceptConnectionsSemaphore.Release());
                            }
                        }
                    }
                    catch (Exception ex)
                    {
                        // NOTE(review): presumably also catches OperationCanceledException from WaitAsync during
                        // shutdown, which is then logged and counted as an accept failure.
                        consecutiveFailureCount++;
                        SecureTransportEventSource.Log.AcceptTcpClientFailed(this.transportId, iteration, consecutiveFailureCount, ex.ToString());
                    }
                    finally
                    {
                        if (mustReleaseSemaphore)
                        {
                            this.acceptConnectionsSemaphore.Release();
                        }
                    }
                }

                // NOTE(review): not in the finally block, so it is skipped if an exception escapes the outer try.
                this.listener.Stop();
            }
            finally
            {
                SecureTransportEventSource.Log.ListenerStopped(this.transportId);
                this.hasStopped.Set();
            }
        }
Example #11
0
        /// <summary>
        /// Loads the certificates described by <paramref name="paths"/>. Each entry is one of:
        /// <list type="bullet">
        /// <item><c>FILE:&lt;path&gt;</c>, loaded with <see cref="X509Certificate.CreateFromCertFile"/>
        /// (the prefix is case-insensitive);</item>
        /// <item>a bare thumbprint, looked up in the LocalMachine/My store;</item>
        /// <item><c>StoreName/StoreLocation/Thumbprint</c>, looked up in the named store.</item>
        /// </list>
        /// </summary>
        /// <param name="paths">Certificate specifiers; may be null.</param>
        /// <returns>
        /// The certificates that were loaded, in input order. Entries that fail to load are logged and
        /// skipped, including a thumbprint that is missing or that matches more than one distinct certificate.
        /// </returns>
        public static X509Certificate[] GetCertificatesFromThumbPrintOrFileName(string[] paths)
        {
            const string FilePrefix = "FILE:";
            var certificates = new List<X509Certificate>();

            try
            {
                if (paths != null)
                {
                    foreach (string specifier in paths)
                    {
                        try
                        {
                            // Match the prefix case-insensitively but keep the file path's original casing:
                            // upper-casing the path breaks lookups on case-sensitive file systems.
                            if (specifier.StartsWith(FilePrefix, StringComparison.OrdinalIgnoreCase))
                            {
                                // Add, rather than assigning by index: the list holds only the certificates
                                // loaded so far, so an index-based store would always be out of range.
                                certificates.Add(X509Certificate.CreateFromCertFile(specifier.Substring(FilePrefix.Length)));
                                continue;
                            }

                            // Culture-invariant so that store names and thumbprints are not altered by
                            // culture-specific casing rules (e.g. the Turkish dotted/dotless i).
                            string path = specifier.ToUpperInvariant();

                            StoreName name;
                            StoreLocation location;
                            string thumbprint;

                            string[] pieces = path.Split('/');
                            if (pieces.Length == 1)
                            {
                                // No store name in the specifier. Use the default store.
                                name = StoreName.My;
                                location = StoreLocation.LocalMachine;
                                thumbprint = path;
                            }
                            else
                            {
                                name = (StoreName)Enum.Parse(typeof(StoreName), pieces[0], true);
                                location = (StoreLocation)Enum.Parse(typeof(StoreLocation), pieces[1], true);
                                thumbprint = pieces[2];
                            }

                            using (X509Store store = new X509Store(name, location))
                            {
                                store.Open(OpenFlags.ReadOnly);
                                X509Certificate2 found = null;
                                foreach (X509Certificate2 result in store.Certificates.Find(X509FindType.FindByThumbprint, thumbprint, false))
                                {
                                    if (found == null || found.Equals(result))
                                    {
                                        found = result;
                                    }
                                    else
                                    {
                                        throw SecureTransportException.DuplicateCertificates(thumbprint);
                                    }
                                }

                                if (found == null)
                                {
                                    throw SecureTransportException.MissingCertificate(thumbprint);
                                }

                                certificates.Add(found);
                            }
                        }
                        catch (Exception ex)
                        {
                            // Best effort: a bad entry is logged and skipped so the other entries can still load.
                            SecureTransportEventSource.Log.GetCertificatesFromThumbprintOrFileNameFailed(ex.ToString());
                        }
                    }
                }

                // Hand the certificates to the caller and empty the list so the finally block does not dispose them.
                X509Certificate[] certificatesToReturn = certificates.ToArray();
                certificates.Clear();
                return certificatesToReturn;
            }
            finally
            {
                // Disposes only certificates that were never handed to the caller (an exception escaped first).
                foreach (var cert in certificates)
                {
                    cert?.Dispose();
                }
            }
        }