/// <summary>
        /// Accepts one connection on <paramref name="serverListener"/>, writes the test payload
        /// to it and validates the rows read back from the same connection.
        /// </summary>
        /// <param name="serverListener">Socket on which the connection is accepted.</param>
        /// <param name="payloadWriter">Writer that emits the test data to the output stream.</param>
        internal static void TestTaskRunnerReadWrite(
            ISocketWrapper serverListener,
            PayloadWriter payloadWriter)
        {
            // Number of rows the test expects to read back.
            const int expectedRowCount = 10;

            using (ISocketWrapper connection = serverListener.Accept())
            {
                System.IO.Stream fromClient = connection.InputStream;
                System.IO.Stream toClient = connection.OutputStream;

                // Send the test data first, then consume whatever the client sends back.
                payloadWriter.WriteTestData(toClient);
                List<object[]> rows = PayloadReader.Read(fromClient);

                Assert.Equal(expectedRowCount, rows.Count);

                for (int rowIndex = 0; rowIndex < expectedRowCount; rowIndex++)
                {
                    object[] current = rows[rowIndex];

                    // Two UDFs registered, thus expecting two columns.
                    // Refer to TestData.GetDefaultCommandPayload().
                    Assert.Equal(2, current.Length);
                    Assert.Equal($"udf2 udf1 {rowIndex}", current[0]);
                    Assert.Equal(2 * rowIndex, current[1]);
                }
            }
        }
Example #2
0
        /// <summary>
        /// Starts a <see cref="SimpleWorker"/> for the given version against a local listener
        /// and checks that it completes the task runner read/write exchange within 5 seconds.
        /// </summary>
        /// <param name="version">Version string, parsed with <see cref="Version"/>.</param>
        public void TestsSimpleWorkerTaskRunners(string version)
        {
            using ISocketWrapper listener = SocketFactory.CreateSocket();
            var listenerEndPoint = (IPEndPoint)listener.LocalEndPoint;

            listener.Listen();

            var workerVersion = new Version(version);
            var worker = new SimpleWorker(workerVersion);

            // The worker connects back to the listener from a background task.
            Task workerTask = Task.Run(() => worker.Run(listenerEndPoint.Port));

            PayloadWriter payloadWriter = new PayloadWriterFactory().Create(workerVersion);

            using (ISocketWrapper connection = listener.Accept())
            {
                bool isVersion32OrLater =
                    workerVersion.Major > 3 ||
                    (workerVersion.Major == 3 && workerVersion.Minor >= 2);

                if (isVersion32OrLater)
                {
                    // For 3.2 and above an Int32 is read off the connection before the
                    // payload exchange; the value itself is not used by this test.
                    SerDe.ReadInt32(connection.InputStream);
                }

                TaskRunnerTests.TestTaskRunnerReadWrite(connection, payloadWriter);
            }

            bool finishedInTime = workerTask.Wait(5000);
            Assert.True(finishedInTime);
        }
Example #3
0
        /// <summary>
        /// Accepts TCP connections from the JVM side on the given listener. Every accepted connection is
        /// wrapped in a new TaskRunner that is added to waitingTaskRunners and recorded in taskRunnerRegistry.
        /// Any exception is logged and terminates the process with exit code -1.
        /// </summary>
        private void StartDaemonServer(ISocketWrapper listener)
        {
            logger.LogInfo("StartDaemonServer ...");

            // SPARK_REUSE_WORKER is set on the JVM side; only the exact value "1" enables reuse.
            string reuseSetting = Environment.GetEnvironmentVariable("SPARK_REUSE_WORKER");
            var secret = Environment.GetEnvironmentVariable("PYTHON_WORKER_FACTORY_SECRET");
            bool sparkReuseWorker = reuseSetting != null && reuseSetting.Equals("1");

            try
            {
                int startedWorkThreads = 0;

                for (int nextTaskRunnerId = 1; ; nextTaskRunnerId++)
                {
                    var connection = listener.Accept();
                    logger.LogInfo("Connection accepted for taskRunnerId: {0}", nextTaskRunnerId);

                    using (var stream = connection.GetStream())
                    {
                        // write taskRunnerId to JVM side
                        SerDe.Write(stream, nextTaskRunnerId);
                        stream.Flush();
                    }

                    var runner = new TaskRunner(nextTaskRunnerId, connection, sparkReuseWorker, secret);
                    waitingTaskRunners.Add(runner);
                    taskRunnerRegistry[nextTaskRunnerId] = runner;

                    // Launch new work threads until there are as many as registered task runners.
                    //
                    // There are several options for running the threads that do the actual work:
                    //   Option 1. TPL - Task Parallel Library
                    //   Option 2. ThreadPool
                    //   Option 3. Self managed threads group
                    // Option 3 is selected after testing in a real cluster because it gave the best performance.
                    // With option 1 or 2 the boot time was observed to be as large as 50 ~ 60s, while it is
                    // always less than 1s for option 3. Perhaps this is because TPL and ThreadPool are not
                    // suitable for long running threads.
                    int registeredRunners = taskRunnerRegistry.Count();
                    for (; startedWorkThreads < registeredRunners; startedWorkThreads++)
                    {
                        new Thread(FetchAndRun).Start();
                    }
                }
            }
            catch (Exception e)
            {
                logger.LogError("StartDaemonServer exception, will exit");
                logger.LogException(e);
                Environment.Exit(-1);
            }
        }
Example #4
0
        /// <summary>
        /// Accepts connections from the JVM while the server is running and registers each
        /// one as a <see cref="CallbackConnection"/> for the worker threads to pick up.
        /// If an exception escapes the accept loop it is logged and Shutdown() is called.
        /// </summary>
        /// <param name="listener">Listening socket on which connections are accepted.</param>
        private void StartServer(ISocketWrapper listener)
        {
            try
            {
                long nextConnectionId = 1;
                int startedWorkerThreads = 0;

                while (_isRunning)
                {
                    ISocketWrapper accepted = listener.Accept();
                    var callbackConnection =
                        new CallbackConnection(nextConnectionId, accepted, _callbackHandlers);

                    _waitingConnections.Add(callbackConnection);
                    _connections[nextConnectionId] = callbackConnection;
                    nextConnectionId++;

                    int connectionCount = CurrentNumConnections;

                    // Start worker threads until there are at least as many of them as there
                    // are CallbackConnections. CallbackConnections are expected to stay open
                    // and reuse the socket to service repeated callback requests. However, if
                    // there is an issue with a connection, then CallbackConnection.Run will
                    // return, freeing up extra worker threads to service any
                    // _waitingConnections.
                    //
                    // For example, assume there were 5 worker threads, each servicing a
                    // CallbackConnection (5 total healthy connections). If 2 CallbackConnection
                    // sockets closed unexpectedly, there would be 5 worker threads and 3
                    // healthy connections. If a new connection request arrived, the
                    // CallbackConnection would be added to _waitingConnections and no new
                    // worker thread would be started (2 worker threads are already waiting to
                    // take CallbackConnections from _waitingConnections).
                    for (; startedWorkerThreads < connectionCount; startedWorkerThreads++)
                    {
                        var worker = new Thread(RunWorkerThread);
                        worker.IsBackground = true;
                        worker.Start();
                    }

                    s_logger.LogInfo(
                        $"Pool snapshot: [NumThreads:{startedWorkerThreads}], [NumConnections:{connectionCount}]");
                }
            }
            catch (Exception e)
            {
                s_logger.LogError($"StartServer() exits with exception: {e}");
                Shutdown();
            }
        }
Example #5
0
        /// <summary>
        /// Starts listening on <c>innerSocket</c> and launches a background task that accepts a single
        /// connection and applies the accumulator updates received on it until <c>serverShutdown</c> is set.
        /// </summary>
        /// <returns>The port of the local endpoint <c>innerSocket</c> is listening on.</returns>
        internal int StartUpdateServer()
        {
            innerSocket.Listen();
            Task.Run(() =>
            {
                try
                {
                    // NOTE(review): BinaryFormatter deserializes whatever arrives on the socket. It is unsafe
                    // for untrusted input and obsolete in current .NET. It is left in place here because the
                    // sender's wire format is not visible from this method.
                    IFormatter formatter = new BinaryFormatter();
                    using (var s = innerSocket.Accept())
                    using (var ns = s.GetStream())
                    {
                        while (!serverShutdown)
                        {
                            // Each batch starts with the number of updates, followed by one serialized
                            // (accumulator id, value) tuple per update.
                            int numUpdates = SerDe.ReadInt(ns);
                            for (int i = 0; i < numUpdates; i++)
                            {
                                var ms     = new MemoryStream(SerDe.ReadBytes(ns));
                                var update = (Tuple<int, dynamic>)formatter.Deserialize(ms);

                                if (Accumulator.accumulatorRegistry.ContainsKey(update.Item1))
                                {
                                    // The accumulator's element type is only known at runtime, so Add is
                                    // invoked through reflection.
                                    Accumulator accumulator = Accumulator.accumulatorRegistry[update.Item1];
                                    accumulator.GetType().GetMethod("Add").Invoke(accumulator, new object[] { update.Item2 });
                                }
                                else
                                {
                                    Console.Error.WriteLine("WARN: cann't find update.Key: {0} for accumulator, will create a new one", update.Item1);
                                    var genericAccumulatorType  = typeof(Accumulator<>);
                                    var specificAccumulatorType = genericAccumulatorType.MakeGenericType(update.Item2.GetType());
                                    // The created instance is not stored here; presumably the constructor
                                    // registers it - verify against Accumulator<T>.
                                    Activator.CreateInstance(specificAccumulatorType, new object[] { update.Item1, update.Item2 });
                                }
                            }
                            ns.WriteByte((byte)1); // acknowledge byte other than -1
                            ns.Flush();
                        }
                    }
                }
                catch (SocketException e)
                {
                    // 10004: a blocking operation was interrupted by a call to WSACancelBlockingCall -
                    // ISocketWrapper.Close canceled Accept() as expected, so that case is swallowed.
                    if (e.ErrorCode != 10004)
                    {
                        // Rethrow with "throw;" so the original stack trace is preserved
                        // ("throw e;" would reset it to this line).
                        throw;
                    }
                }
                catch (Exception e)
                {
                    logger.LogError(e.ToString());
                    throw;
                }
            });

            return((innerSocket.LocalEndPoint as IPEndPoint).Port);
        }
        /// <summary>
        /// Runs a <see cref="CallbackConnection"/> over the server side of a local socket pair and
        /// verifies that it stops running after cancellation, after a throwing handler, or after an
        /// explicit close request from the client.
        /// </summary>
        /// <param name="callbackHandlersDict">Handlers passed to the connection under test.</param>
        /// <param name="callbackHandler">Handler exercised through WriteAndReadTestData.</param>
        /// <param name="inputToHandler">Input value forwarded to WriteAndReadTestData.</param>
        /// <param name="token">Token passed to CallbackConnection.Run.</param>
        private void TestCallbackConnection(
            ConcurrentDictionary<int, ICallbackHandler> callbackHandlersDict,
            ITestCallbackHandler callbackHandler,
            int inputToHandler,
            CancellationToken token)
        {
            using ISocketWrapper serverListener = SocketFactory.CreateSocket();
            serverListener.Listen();

            var listenerEndPoint = (IPEndPoint)serverListener.LocalEndPoint;

            using ISocketWrapper clientSocket = SocketFactory.CreateSocket();
            clientSocket.Connect(listenerEndPoint.Address, listenerEndPoint.Port);

            // Don't use "using" here. The CallbackConnection will dispose the socket.
            ISocketWrapper acceptedSocket = serverListener.Accept();
            var callbackConnection = new CallbackConnection(0, acceptedSocket, callbackHandlersDict);
            Task runTask = Task.Run(() => callbackConnection.Run(token));

            if (!token.IsCancellationRequested)
            {
                WriteAndReadTestData(clientSocket, callbackHandler, inputToHandler);

                if (!callbackHandler.Throws)
                {
                    // A healthy connection keeps running until it is asked to close.
                    Assert.True(callbackConnection.IsRunning);

                    // Clean up CallbackConnection
                    Stream toServer = clientSocket.OutputStream;
                    SerDe.Write(toServer, (int)CallbackConnection.ConnectionStatus.REQUEST_CLOSE);
                    toServer.Flush();
                }
            }

            // Every path ends the same way: Run returns and the connection reports not running.
            runTask.Wait();
            Assert.False(callbackConnection.IsRunning);
        }
Example #7
0
        /// <summary>
        /// Connects a client socket to a local listener, runs a <see cref="CallbackConnection"/> over
        /// that client socket on a background task, and drives it from the accepted server-side
        /// socket through WriteAndReadTestData.
        /// </summary>
        /// <param name="callbackHandlersDict">Handlers passed to the connection under test.</param>
        /// <param name="callbackHandler">Handler exercised through WriteAndReadTestData.</param>
        /// <param name="inputToHandler">Input value forwarded to WriteAndReadTestData.</param>
        /// <param name="token">Token passed to CallbackConnection.Run and WriteAndReadTestData.</param>
        private void TestCallbackConnection(
            ConcurrentDictionary<int, ICallbackHandler> callbackHandlersDict,
            ITestCallbackHandler callbackHandler,
            int inputToHandler,
            CancellationToken token)
        {
            using ISocketWrapper listener = SocketFactory.CreateSocket();
            listener.Listen();

            var listenerEndPoint = (IPEndPoint)listener.LocalEndPoint;

            // The client socket is handed to the CallbackConnection and is not disposed here.
            ISocketWrapper connectingSocket = SocketFactory.CreateSocket();
            connectingSocket.Connect(listenerEndPoint.Address, listenerEndPoint.Port);

            var connectionUnderTest =
                new CallbackConnection(0, connectingSocket, callbackHandlersDict);

            // Fire and forget: the returned task is intentionally not awaited.
            _ = Task.Run(() => connectionUnderTest.Run(token));

            using ISocketWrapper acceptedSocket = listener.Accept();
            WriteAndReadTestData(acceptedSocket, callbackHandler, inputToHandler, token);
        }
Example #8
0
        /// <summary>
        /// Accepts connections from the JVM in an endless loop. Each connection is authenticated
        /// when a worker factory secret is configured, then wrapped in a TaskRunner and queued for
        /// the worker threads. Any exception is logged and terminates the process with exit code -1.
        /// </summary>
        private void StartServer(ISocketWrapper listener)
        {
            try
            {
                bool reuseWorker =
                    Environment.GetEnvironmentVariable("SPARK_REUSE_WORKER") == "1";

                string secret = Utils.SettingUtils.GetWorkerFactorySecret();

                // Authentication only takes place when a non-blank secret is configured.
                bool requiresAuthentication = !string.IsNullOrWhiteSpace(secret);

                int nextTaskRunnerId = 1;
                int startedWorkerThreads = 0;

                while (true)
                {
                    ISocketWrapper connection = listener.Accept();
                    s_logger.LogInfo($"New connection accepted for TaskRunner [{nextTaskRunnerId}]");

                    bool authenticated = true;
                    if (requiresAuthentication)
                    {
                        // The Spark side expects the PID from a forked process.
                        // In .NET implementation, a task runner id is used instead.
                        SerDe.Write(connection.OutputStream, nextTaskRunnerId);
                        connection.OutputStream.Flush();

                        if (ConfigurationService.IsDatabricks)
                        {
                            // On Databricks one extra string is read and discarded before
                            // authentication.
                            SerDe.ReadString(connection.InputStream);
                        }

                        authenticated = Authenticator.AuthenticateAsServer(connection, secret);
                    }

                    if (!authenticated)
                    {
                        // Use SerDe.ReadBytes() to detect Java side has closed socket
                        // properly. ReadBytes() will block until the socket is closed.
                        s_logger.LogError(
                            "Authentication failed. Waiting for JVM side to close socket.");
                        SerDe.ReadBytes(connection.InputStream);

                        connection.Dispose();
                        continue;
                    }

                    var runner = new TaskRunner(
                        nextTaskRunnerId,
                        connection,
                        reuseWorker,
                        _version);

                    _waitingTaskRunners.Add(runner);
                    _taskRunners[nextTaskRunnerId] = runner;

                    nextTaskRunnerId++;

                    // When reuseWorker is set to true, the number of task runners will be always
                    // one greater than the number of worker threads since TaskRunner.Run() does
                    // not return so that the task runner object is not removed from _taskRunners.
                    int taskRunnerCount = CurrentNumTaskRunners;

                    // Note that the current implementation of RunWorkerThread() does not
                    // return. If more graceful exit is required, RunWorkerThread() can be
                    // updated to return upon receiving a signal from this main thread.
                    for (; startedWorkerThreads < taskRunnerCount; startedWorkerThreads++)
                    {
                        new Thread(RunWorkerThread).Start();
                    }

                    s_logger.LogInfo(
                        $"Pool snapshot: [NumThreads:{startedWorkerThreads}], [NumTaskRunners:{taskRunnerCount}]");
                }
            }
            catch (Exception e)
            {
                s_logger.LogError($"StartServer() exits with exception: {e}");
                Environment.Exit(-1);
            }
        }
Example #9
0
        /// <summary>
        /// Accepts TCP connections from the JVM side on the given listener. Every accepted connection is
        /// wrapped in a new TaskRunner that is added to waitingTaskRunners and recorded in taskRunnerRegistry.
        /// Any exception is logged and terminates the process with exit code -1.
        /// </summary>
        private void StartDaemonServer(ISocketWrapper listener)
        {
            logger.LogInfo("StartDaemonServer ...");

            // SPARK_REUSE_WORKER is set on the JVM side; only the exact value "1" enables reuse.
            string reuseSetting = Environment.GetEnvironmentVariable("SPARK_REUSE_WORKER");
            bool sparkReuseWorker = reuseSetting != null && reuseSetting.Equals("1");

            try
            {
                int startedWorkThreads = 0;

                for (int nextTaskRunnerId = 1; ; nextTaskRunnerId++)
                {
                    var connection = listener.Accept();
                    logger.LogInfo("Connection accepted for taskRunnerId: {0}", nextTaskRunnerId);

                    using (var stream = connection.GetStream())
                    {
                        // write taskRunnerId to JVM side
                        SerDe.Write(stream, nextTaskRunnerId);
                        stream.Flush();
                    }

                    var runner = new TaskRunner(nextTaskRunnerId, connection, sparkReuseWorker);
                    waitingTaskRunners.Add(runner);
                    taskRunnerRegistry[nextTaskRunnerId] = runner;

                    // Launch new work threads until there are as many as registered task runners.
                    //
                    // There are several options for running the threads that do the actual work:
                    //   Option 1. TPL - Task Parallel Library
                    //   Option 2. ThreadPool
                    //   Option 3. Self managed threads group
                    // Option 3 is selected after testing in a real cluster because it gave the best performance.
                    // With option 1 or 2 the boot time was observed to be as large as 50 ~ 60s, while it is
                    // always less than 1s for option 3. Perhaps this is because TPL and ThreadPool are not
                    // suitable for long running threads.
                    int registeredRunners = taskRunnerRegistry.Count();
                    for (; startedWorkThreads < registeredRunners; startedWorkThreads++)
                    {
                        new Thread(FetchAndRun).Start();
                    }
                }
            }
            catch (Exception e)
            {
                logger.LogError("StartDaemonServer exception, will exit");
                logger.LogException(e);
                Environment.Exit(-1);
            }
        }
Example #10
0
        /// <summary>
        /// Exercises a socket wrapper end to end: starts listening on <paramref name="serverSocket"/>,
        /// runs an echo/bulk-send server on a background task, checks that invalid client operations
        /// throw, and then walks the client through the echo exchange before closing both sockets.
        /// </summary>
        /// <param name="serverSocket">Socket used as the listening side; it is closed before returning.</param>
        private void SocketTest(ISocketWrapper serverSocket)
        {
            // Terminator the client scans for at the end of the bulk transfer.
            const string eofMarker = "<EOF>";

            serverSocket.Listen();
            if (serverSocket is RioSocketWrapper)
            {
                // Do nothing for second listen operation.
                Assert.DoesNotThrow(() => serverSocket.Listen(int.MaxValue));
            }

            var port = ((IPEndPoint)serverSocket.LocalEndPoint).Port;

            var clientMsg = "Hello Message from client";
            var clientMsgBytes = Encoding.UTF8.GetBytes(clientMsg);

            // Server side of the exchange; mirrors the client steps further down.
            Task.Run(() =>
            {
                var bytes = new byte[1024];
                using (var socket = serverSocket.Accept())
                {
                    using (var s = socket.GetStream())
                    {
                        // Receive data
                        var bytesRec = s.Read(bytes, 0, bytes.Length);
                        // send echo message.
                        s.Write(bytes, 0, bytesRec);
                        s.Flush();

                        // Receive one byte
                        var oneByte = s.ReadByte();

                        // Send echo one byte
                        byte[] oneBytes = { (byte)oneByte };
                        s.Write(oneBytes, 0, oneBytes.Length);

                        Thread.SpinWait(0);

                        // Keep sending to ensure no memory leak
                        var longBytes = Encoding.UTF8.GetBytes(new string('x', 8192));
                        for (int i = 0; i < 1000; i++)
                        {
                            s.Write(longBytes, 0, longBytes.Length);
                        }
                        byte[] msg = Encoding.ASCII.GetBytes("This is a test<EOF>");
                        s.Write(msg, 0, msg.Length);

                        // Receive echo byte.
                        s.ReadByte();
                    }
                }
            });

            var clientSock = SocketFactory.CreateSocket();

            // Validate invalid operations on a socket that is not connected yet.
            Assert.Throws<InvalidOperationException>(() => clientSock.GetStream());
            Assert.Throws<InvalidOperationException>(() => clientSock.Receive());
            Assert.Throws<InvalidOperationException>(() => clientSock.Send(null));
            Assert.Throws<SocketException>(() => clientSock.Connect(IPAddress.Any, 1024));

            clientSock.Connect(IPAddress.Loopback, port);

            // Validate invalid operations on a connected socket.
            var byteBuf = ByteBufPool.Default.Allocate();
            Assert.Throws<ArgumentException>(() => clientSock.Send(byteBuf));
            byteBuf.Release();

            Assert.Throws<SocketException>(() => clientSock.Listen());
            if (clientSock is RioSocketWrapper)
            {
                Assert.Throws<InvalidOperationException>(() => clientSock.Accept());
            }

            using (var s = clientSock.GetStream())
            {
                // Send message
                s.Write(clientMsgBytes, 0, clientMsgBytes.Length);
                // Receive echo message
                var bytes = new byte[1024];
                var bytesRec = s.Read(bytes, 0, bytes.Length);
                Assert.AreEqual(clientMsgBytes.Length, bytesRec);
                var recvStr = Encoding.UTF8.GetString(bytes, 0, bytesRec);
                Assert.AreEqual(clientMsg, recvStr);

                // Send one byte
                byte[] oneBytes = { 1 };
                s.Write(oneBytes, 0, oneBytes.Length);

                // Receive echo message
                var oneByte = s.ReadByte();
                Assert.AreEqual((byte)1, oneByte);

                // Keep receiving to ensure no memory leak.
                // A single Read may return any number of bytes, so the marker can straddle two
                // reads; the tail of the previous chunk is carried over so it is still found.
                var carry = string.Empty;
                while (true)
                {
                    bytesRec = s.Read(bytes, 0, bytes.Length);

                    // Zero bytes means the stream ended; fail instead of looping forever.
                    Assert.IsTrue(bytesRec > 0, "Stream ended before the <EOF> marker was received.");

                    recvStr = carry + Encoding.UTF8.GetString(bytes, 0, bytesRec);
                    if (recvStr.IndexOf(eofMarker, StringComparison.OrdinalIgnoreCase) > -1)
                    {
                        break;
                    }

                    // Keep just enough characters to complete a marker split across reads.
                    int keep = Math.Min(recvStr.Length, eofMarker.Length - 1);
                    carry = recvStr.Substring(recvStr.Length - keep);
                }
                // send echo bytes
                s.Write(oneBytes, 0, oneBytes.Length);
            }

            clientSock.Close();
            // Verify invalid operation
            Assert.Throws<ObjectDisposedException>(() => clientSock.Receive());

            serverSocket.Close();
        }
Example #11
0
        /// <summary>
        /// Exercises a socket wrapper end to end: starts listening on <paramref name="serverSocket"/>,
        /// runs an echo/bulk-send server on a background task, checks that invalid client operations
        /// throw, and then walks the client through the echo exchange before closing both sockets.
        /// </summary>
        /// <param name="serverSocket">Socket used as the listening side; it is closed before returning.</param>
        private void SocketTest(ISocketWrapper serverSocket)
        {
            // Terminator the client scans for at the end of the bulk transfer.
            const string eofMarker = "<EOF>";

            serverSocket.Listen();
            if (serverSocket is RioSocketWrapper)
            {
                // Do nothing for second listen operation.
                Assert.DoesNotThrow(() => serverSocket.Listen(int.MaxValue));
            }

            var port = ((IPEndPoint)serverSocket.LocalEndPoint).Port;

            var clientMsg      = "Hello Message from client";
            var clientMsgBytes = Encoding.UTF8.GetBytes(clientMsg);

            // Server side of the exchange; mirrors the client steps further down.
            Task.Run(() =>
            {
                var bytes = new byte[1024];
                using (var socket = serverSocket.Accept())
                {
                    using (var s = socket.GetStream())
                    {
                        // Receive data
                        var bytesRec = s.Read(bytes, 0, bytes.Length);
                        // send echo message.
                        s.Write(bytes, 0, bytesRec);
                        s.Flush();

                        // Receive one byte
                        var oneByte = s.ReadByte();

                        // Send echo one byte
                        byte[] oneBytes = { (byte)oneByte };
                        s.Write(oneBytes, 0, oneBytes.Length);

                        Thread.SpinWait(0);

                        // Keep sending to ensure no memory leak
                        var longBytes = Encoding.UTF8.GetBytes(new string('x', 8192));
                        for (int i = 0; i < 1000; i++)
                        {
                            s.Write(longBytes, 0, longBytes.Length);
                        }
                        byte[] msg = Encoding.ASCII.GetBytes("This is a test<EOF>");
                        s.Write(msg, 0, msg.Length);

                        // Receive echo byte.
                        s.ReadByte();
                    }
                }
            });

            var clientSock = SocketFactory.CreateSocket();

            // Validate invalid operations on a socket that is not connected yet.
            Assert.Throws<InvalidOperationException>(() => clientSock.GetStream());
            Assert.Throws<InvalidOperationException>(() => clientSock.Receive());
            Assert.Throws<InvalidOperationException>(() => clientSock.Send(null));
            Assert.Throws<SocketException>(() => clientSock.Connect(IPAddress.Any, 1024));

            clientSock.Connect(IPAddress.Loopback, port);

            // Validate invalid operations on a connected socket.
            var byteBuf = ByteBufPool.Default.Allocate();

            Assert.Throws<ArgumentException>(() => clientSock.Send(byteBuf));
            byteBuf.Release();

            Assert.Throws<SocketException>(() => clientSock.Listen());
            if (clientSock is RioSocketWrapper)
            {
                Assert.Throws<InvalidOperationException>(() => clientSock.Accept());
            }

            using (var s = clientSock.GetStream())
            {
                // Send message
                s.Write(clientMsgBytes, 0, clientMsgBytes.Length);
                // Receive echo message
                var bytes    = new byte[1024];
                var bytesRec = s.Read(bytes, 0, bytes.Length);
                Assert.AreEqual(clientMsgBytes.Length, bytesRec);
                var recvStr = Encoding.UTF8.GetString(bytes, 0, bytesRec);
                Assert.AreEqual(clientMsg, recvStr);

                // Send one byte
                byte[] oneBytes = { 1 };
                s.Write(oneBytes, 0, oneBytes.Length);

                // Receive echo message
                var oneByte = s.ReadByte();
                Assert.AreEqual((byte)1, oneByte);

                // Keep receiving to ensure no memory leak.
                // A single Read may return any number of bytes, so the marker can straddle two
                // reads; the tail of the previous chunk is carried over so it is still found.
                var carry = string.Empty;
                while (true)
                {
                    bytesRec = s.Read(bytes, 0, bytes.Length);

                    // Zero bytes means the stream ended; fail instead of looping forever.
                    Assert.IsTrue(bytesRec > 0, "Stream ended before the <EOF> marker was received.");

                    recvStr = carry + Encoding.UTF8.GetString(bytes, 0, bytesRec);
                    if (recvStr.IndexOf(eofMarker, StringComparison.OrdinalIgnoreCase) > -1)
                    {
                        break;
                    }

                    // Keep just enough characters to complete a marker split across reads.
                    int keep = Math.Min(recvStr.Length, eofMarker.Length - 1);
                    carry = recvStr.Substring(recvStr.Length - keep);
                }
                // send echo bytes
                s.Write(oneBytes, 0, oneBytes.Length);
            }

            clientSock.Close();
            // Verify invalid operation
            Assert.Throws<ObjectDisposedException>(() => clientSock.Receive());

            serverSocket.Close();
        }