/// <summary>
        /// Reads the given stream to construct a BroadcastVariables object.
        /// </summary>
        /// <remarks>
        /// Each broadcast id read from <paramref name="stream"/> either registers a value
        /// (non-negative id) or removes one from <c>BroadcastRegistry</c> (negative id).
        /// Values come from the decryption server socket when
        /// <c>DecryptionServerNeeded</c> is set, otherwise from a file whose path is read
        /// from <paramref name="stream"/>. The decryption server socket, if opened, is
        /// disposed before this method returns, including when an exception is thrown.
        /// </remarks>
        /// <param name="stream">The stream to read from</param>
        /// <returns>BroadcastVariables object</returns>
        /// <exception cref="Exception">
        /// The broadcast id from the decryption server does not match the id in the payload.
        /// </exception>
        internal BroadcastVariables Process(Stream stream)
        {
            var broadcastVars = new BroadcastVariables();
            ISocketWrapper socket = null;

            try
            {
                broadcastVars.DecryptionServerNeeded = SerDe.ReadBool(stream);
                // A negative count in the payload is treated as zero variables.
                broadcastVars.Count = Math.Max(SerDe.ReadInt32(stream), 0);

                if (broadcastVars.DecryptionServerNeeded)
                {
                    broadcastVars.DecryptionServerPort = SerDe.ReadInt32(stream);
                    broadcastVars.Secret = SerDe.ReadString(stream);

                    // Only connect when there is something to read from the server.
                    if (broadcastVars.Count > 0)
                    {
                        socket = SocketFactory.CreateSocket();
                        socket.Connect(
                            IPAddress.Loopback,
                            broadcastVars.DecryptionServerPort,
                            broadcastVars.Secret);
                    }
                }

                // NOTE(review): BinaryFormatter is unsafe for untrusted data and is removed in
                // .NET 9; the values here come from the local JVM/files, but consider migrating.
                var formatter = new BinaryFormatter();

                for (int i = 0; i < broadcastVars.Count; ++i)
                {
                    long bid = SerDe.ReadInt64(stream);
                    if (bid >= 0)
                    {
                        if (broadcastVars.DecryptionServerNeeded)
                        {
                            // The server echoes the id before the value; the two must agree.
                            long readBid = SerDe.ReadInt64(socket.InputStream);
                            if (bid != readBid)
                            {
                                throw new Exception("The Broadcast Id received from the encryption " +
                                                    $"server {readBid} is different from the Broadcast Id received " +
                                                    $"from the payload {bid}.");
                            }
                            object value = formatter.Deserialize(socket.InputStream);
                            BroadcastRegistry.Add(bid, value);
                        }
                        else
                        {
                            string path = SerDe.ReadString(stream);
                            using FileStream fStream =
                                      File.Open(path, FileMode.Open, FileAccess.Read, FileShare.Read);
                            object value = formatter.Deserialize(fStream);
                            BroadcastRegistry.Add(bid, value);
                        }
                    }
                    else
                    {
                        // Negative ids encode a removal: the id to remove is -bid - 1.
                        bid = -bid - 1;
                        BroadcastRegistry.Remove(bid);
                    }
                }
            }
            finally
            {
                // Previously the socket leaked if Connect, the id check, or deserialization threw.
                socket?.Dispose();
            }

            return broadcastVars;
        }
Example #2
0
        /// <summary>
        /// Marks this instance as disposed, then releases the underlying socket if there is one.
        /// </summary>
        public void Dispose()
        {
            state.SetDisposed();

            // A null socket (presumably: a connection was never made) has nothing to release.
            socket?.Dispose();
        }
Example #3
0
        /// <summary>
        /// Stops the <see cref="CallbackServer"/>: cancels its token source, disposes or clears
        /// its connection bookkeeping and listener, marks it as not running, and finally invokes
        /// <c>DotnetHandler.closeCallback</c> on the JVM side.
        /// </summary>
        private void Shutdown()
        {
            s_logger.LogInfo("Shutting down CallbackServer");

            // Signal cancellation before any of the shared state below is torn down.
            _tokenSource.Cancel();

            // Release local resources.
            _waitingConnections.Dispose();
            _connections.Clear();
            _callbackHandlers.Clear();
            if (_listener != null)
            {
                _listener.Dispose();
            }

            _isRunning = false;

            // Notify the JVM; presumably this closes its end of the callback channel.
            _jvm.CallStaticJavaMethod("DotnetHandler", "closeCallback");
        }
Example #4
0
        /// <summary>
        /// Sends a method-invocation request to the JVM backend and decodes the typed reply.
        /// </summary>
        /// <param name="isStatic">Whether the JVM method is static.</param>
        /// <param name="classNameOrJvmObjectReference">
        /// The class name or JVM object reference the method is invoked on, as accepted by
        /// <c>PayloadHelper.BuildPayload</c>.
        /// </param>
        /// <param name="methodName">Name of the JVM method to invoke.</param>
        /// <param name="args">Arguments serialized into the request payload.</param>
        /// <returns>
        /// The value selected by the reply's type tag: null ('n'), a
        /// <see cref="JvmObjectReference"/> ('j'), string ('c'), int ('i'), long ('g'),
        /// double ('d'), bool ('b'), or the result of <c>ReadCollection</c> ('l').
        /// </returns>
        /// <exception cref="Exception">
        /// The JVM reported a failure; the inner exception is a <see cref="JvmException"/>
        /// carrying the JVM stack trace.
        /// </exception>
        /// <exception cref="NotSupportedException">The reply carried an unknown type tag.</exception>
        /// <exception cref="EndOfStreamException">
        /// The connection ended before the type tag or a boolean value was received.
        /// </exception>
        /// <remarks>
        /// On success the socket is returned to <c>_sockets</c> for reuse; on any exception it is
        /// disposed and the exception is rethrown.
        /// </remarks>
        private object CallJavaMethod(
            bool isStatic,
            object classNameOrJvmObjectReference,
            string methodName,
            object[] args)
        {
            object         returnValue = null;
            ISocketWrapper socket      = null;

            try
            {
                // The payload buffer is cached and reused, so rewind it before writing this request.
                MemoryStream payloadMemoryStream = s_payloadMemoryStream ??= new MemoryStream();
                payloadMemoryStream.Position = 0;
                PayloadHelper.BuildPayload(
                    payloadMemoryStream,
                    isStatic,
                    classNameOrJvmObjectReference,
                    methodName,
                    args);

                socket = GetConnection();

                // Send only the bytes written for this request; the reused buffer may hold
                // leftover data beyond Position.
                Stream outputStream = socket.OutputStream;
                outputStream.Write(
                    payloadMemoryStream.GetBuffer(),
                    0,
                    (int)payloadMemoryStream.Position);
                outputStream.Flush();

                Stream inputStream        = socket.InputStream;
                int    isMethodCallFailed = SerDe.ReadInt32(inputStream);
                if (isMethodCallFailed != 0)
                {
                    string jvmFullStackTrace = SerDe.ReadString(inputStream);
                    string errorMessage      = BuildErrorMessage(
                        isStatic,
                        classNameOrJvmObjectReference,
                        methodName,
                        args);
                    _logger.LogError(errorMessage);
                    _logger.LogError(jvmFullStackTrace);
                    throw new Exception(errorMessage, new JvmException(jvmFullStackTrace));
                }

                // Stream.ReadByte() returns -1 at end of stream; report that explicitly instead
                // of letting Convert.ToChar throw an unrelated OverflowException.
                int typeByte = inputStream.ReadByte();
                if (typeByte < 0)
                {
                    throw new EndOfStreamException(
                        "Connection closed before the return type identifier was received.");
                }

                char typeAsChar = Convert.ToChar(typeByte);
                switch (typeAsChar) // TODO: Add support for other types.
                {
                case 'n':
                    break;

                case 'j':
                    returnValue = new JvmObjectReference(SerDe.ReadString(inputStream), this);
                    break;

                case 'c':
                    returnValue = SerDe.ReadString(inputStream);
                    break;

                case 'i':
                    returnValue = SerDe.ReadInt32(inputStream);
                    break;

                case 'g':
                    returnValue = SerDe.ReadInt64(inputStream);
                    break;

                case 'd':
                    returnValue = SerDe.ReadDouble(inputStream);
                    break;

                case 'b':
                {
                    // Without this check, -1 (end of stream) would silently convert to true.
                    int boolByte = inputStream.ReadByte();
                    if (boolByte < 0)
                    {
                        throw new EndOfStreamException(
                            "Connection closed before the boolean return value was received.");
                    }

                    returnValue = Convert.ToBoolean(boolByte);
                    break;
                }

                case 'l':
                    returnValue = ReadCollection(inputStream);
                    break;

                default:
                    // Convert typeAsChar to UInt32 because the char may be non-printable.
                    throw new NotSupportedException(
                              string.Format(
                                  "Identifier for type 0x{0:X} not supported",
                                  Convert.ToUInt32(typeAsChar)));
                }
                _sockets.Enqueue(socket);
            }
            catch (Exception e)
            {
                _logger.LogException(e);
                socket?.Dispose();
                throw;
            }

            return(returnValue);
        }
Example #5
0
        /// <summary>
        /// Repeatedly processes payloads arriving on <c>_socket</c>. The loop ends in any of
        /// these cases:
        /// - the JVM closes the socket;
        /// - a payload was not read through completely;
        /// - socket reuse is disabled;
        /// - an exception occurs.
        /// The socket is always disposed before returning.
        /// </summary>
        public void Run()
        {
            s_logger.LogInfo($"[{TaskId}] Starting with ReuseSocket[{_reuseSocket}].");

            bool attachDebugger =
                EnvironmentUtils.GetEnvironmentVariableAsBool("DOTNET_WORKER_DEBUG");
            if (attachDebugger)
            {
                Debugger.Launch();
            }

            _isRunning = true;
            Stream input  = _socket.InputStream;
            Stream output = _socket.OutputStream;

            try
            {
                while (_isRunning)
                {
                    Payload payload = ProcessStream(input, output, _version, out bool readComplete);

                    if (payload == null)
                    {
                        // A null payload signals that the JVM closed the socket.
                        _isRunning = false;
                        s_logger.LogWarn(
                            $"[{TaskId}] Read null payload. Socket is closed by JVM.");
                        continue;
                    }

                    output.Flush();
                    ++_numTasksRun;

                    if (!readComplete)
                    {
                        // A socket that was not fully read cannot be reused.
                        _isRunning = false;

                        // Give the server time to finish so closing does not cause a 'connection reset'.
                        s_logger.LogInfo($"[{TaskId}] Sleep 500 millisecond to close socket.");
                        Thread.Sleep(500);
                    }
                    else if (!_reuseSocket)
                    {
                        _isRunning = false;

                        // SerDe.ReadBytes() blocks until the JVM closes its side, which confirms
                        // the socket was closed properly.
                        s_logger.LogInfo($"[{TaskId}] Waiting for JVM side to close socket.");
                        SerDe.ReadBytes(input);
                        s_logger.LogInfo($"[{TaskId}] JVM side has closed socket.");
                    }
                }
            }
            catch (Exception e)
            {
                _isRunning = false;
                s_logger.LogError($"[{TaskId}] Exiting with exception: {e}");
            }
            finally
            {
                try
                {
                    _socket.Dispose();
                }
                catch (Exception ex)
                {
                    s_logger.LogWarn($"[{TaskId}] Exception while closing socket: {ex}");
                }

                s_logger.LogInfo($"[{TaskId}] Finished running {_numTasksRun} task(s).");
            }
        }
Example #6
0
        /// <summary>
        /// Accepts JVM connections from <paramref name="listener"/> in an endless loop.
        /// Each connection is authenticated when a worker factory secret is configured.
        /// Authenticated sockets are wrapped in a new <see cref="TaskRunner"/>, and worker
        /// threads are started until they match the current number of task runners. Any
        /// exception ends the process with exit code -1.
        /// </summary>
        /// <param name="listener">Listening socket that incoming connections are accepted from.</param>
        private void StartServer(ISocketWrapper listener)
        {
            try
            {
                string reuseWorkerSetting = Environment.GetEnvironmentVariable("SPARK_REUSE_WORKER");
                bool reuseWorker = "1".Equals(reuseWorkerSetting);

                string secret = Utils.SettingUtils.GetWorkerFactorySecret();
                bool authenticationRequired = !string.IsNullOrWhiteSpace(secret);

                int nextTaskRunnerId  = 1;
                int workerThreadCount = 0;

                while (true)
                {
                    ISocketWrapper socket = listener.Accept();
                    s_logger.LogInfo($"New connection accepted for TaskRunner [{nextTaskRunnerId}]");

                    bool authenticated = true;
                    if (authenticationRequired)
                    {
                        // Spark expects the PID of a forked process here; the .NET worker sends
                        // the task runner id instead.
                        SerDe.Write(socket.OutputStream, nextTaskRunnerId);
                        socket.OutputStream.Flush();

                        if (ConfigurationService.IsDatabricks)
                        {
                            SerDe.ReadString(socket.InputStream);
                        }

                        authenticated = Authenticator.AuthenticateAsServer(socket, secret);
                    }

                    if (!authenticated)
                    {
                        // SerDe.ReadBytes() blocks until the JVM closes the socket, confirming
                        // it was closed properly before it is disposed here.
                        s_logger.LogError(
                            "Authentication failed. Waiting for JVM side to close socket.");
                        SerDe.ReadBytes(socket.InputStream);

                        socket.Dispose();
                        continue;
                    }

                    var taskRunner = new TaskRunner(
                        nextTaskRunnerId,
                        socket,
                        reuseWorker,
                        _version);

                    _waitingTaskRunners.Add(taskRunner);
                    _taskRunners[nextTaskRunnerId] = taskRunner;
                    ++nextTaskRunnerId;

                    // With worker reuse enabled, TaskRunner.Run() does not return, so the runner
                    // stays in _taskRunners and the runner count stays one ahead of the thread count.
                    int numTaskRunners = CurrentNumTaskRunners;

                    // RunWorkerThread() does not currently return; a graceful exit would need it to
                    // watch for a signal from this thread.
                    while (workerThreadCount < numTaskRunners)
                    {
                        new Thread(RunWorkerThread).Start();
                        ++workerThreadCount;
                    }

                    s_logger.LogInfo(
                        $"Pool snapshot: [NumThreads:{workerThreadCount}], [NumTaskRunners:{numTaskRunners}]");
                }
            }
            catch (Exception e)
            {
                s_logger.LogError($"StartServer() exits with exception: {e}");
                Environment.Exit(-1);
            }
        }
Example #7
0
        /// <summary>
        /// Sends a method-invocation request to the JVM backend and decodes the typed reply.
        /// </summary>
        /// <param name="isStatic">Whether the JVM method is static.</param>
        /// <param name="classNameOrJvmObjectReference">
        /// The class name or JVM object reference the method is invoked on, as accepted by
        /// <c>PayloadHelper.BuildPayload</c>.
        /// </param>
        /// <param name="methodName">Name of the JVM method to invoke.</param>
        /// <param name="args">Arguments serialized into the request payload.</param>
        /// <returns>
        /// The value selected by the reply's type tag: null ('n'), a
        /// <see cref="JvmObjectReference"/> ('j'), string ('c'), int ('i'), long ('g'),
        /// double ('d'), bool ('b'), or the result of <c>ReadCollection</c> ('l').
        /// </returns>
        /// <exception cref="Exception">
        /// The JVM reported a failure; the inner exception is a <see cref="JvmException"/>.
        /// The connection is kept for reuse in this case.
        /// </exception>
        /// <exception cref="NotSupportedException">The reply carried an unknown type tag.</exception>
        /// <exception cref="EndOfStreamException">
        /// The connection ended before the type tag or a boolean value was received.
        /// </exception>
        private object CallJavaMethod(
            bool isStatic,
            object classNameOrJvmObjectReference,
            string methodName,
            object[] args)
        {
            object         returnValue         = null;
            ISocketWrapper socket              = null;
            bool           connectionRequested = false;

            try
            {
                // dotnet-interactive does not have a dedicated thread to process
                // code submissions and each code submission can be processed in different
                // threads. DotnetHandler uses the CLR thread id to ensure that the same
                // JVM thread is used to handle the request, which means that code submitted
                // through dotnet-interactive may be executed in different JVM threads. To
                // mitigate this, when running in the REPL, submit requests to the DotnetHandler
                // using the same thread id. This mitigation has some limitations in multithreaded
                // scenarios. If a JVM method is blocking and needs a JVM method call issued by a
                // separate thread to unblock it, then this scenario is not supported.
                //
                // ie, `StreamingQuery.AwaitTermination()` is a blocking call and requires
                // `StreamingQuery.Stop()` to be called to unblock it. However, the `Stop`
                // call will never run because DotnetHandler will assign the method call to
                // run on the same thread that `AwaitTermination` is running on.
                Thread       thread              = _isRunningRepl ? null : Thread.CurrentThread;
                int          threadId            = thread == null ? ThreadIdForRepl : thread.ManagedThreadId;
                MemoryStream payloadMemoryStream = s_payloadMemoryStream ??= new MemoryStream();
                payloadMemoryStream.Position = 0;

                PayloadHelper.BuildPayload(
                    payloadMemoryStream,
                    isStatic,
                    _processId,
                    threadId,
                    classNameOrJvmObjectReference,
                    methodName,
                    args);

                // NOTE(review): GetConnection() is assumed to acquire _socketSemaphore (TODO confirm).
                // Record that it was attempted so the finally block does not release a permit
                // that was never taken when building the payload above throws.
                connectionRequested = true;
                socket = GetConnection();

                // Send only the bytes written for this request; the reused buffer may hold
                // leftover data beyond Position.
                Stream outputStream = socket.OutputStream;
                outputStream.Write(
                    payloadMemoryStream.GetBuffer(),
                    0,
                    (int)payloadMemoryStream.Position);
                outputStream.Flush();

                if (thread != null)
                {
                    _jvmThreadPoolGC.TryAddThread(thread);
                }

                Stream inputStream        = socket.InputStream;
                int    isMethodCallFailed = SerDe.ReadInt32(inputStream);
                if (isMethodCallFailed != 0)
                {
                    string jvmFullStackTrace = SerDe.ReadString(inputStream);
                    string errorMessage      = BuildErrorMessage(
                        isStatic,
                        classNameOrJvmObjectReference,
                        methodName,
                        args);
                    _logger.LogError(errorMessage);
                    _logger.LogError(jvmFullStackTrace);
                    throw new Exception(errorMessage, new JvmException(jvmFullStackTrace));
                }

                // Stream.ReadByte() returns -1 at end of stream; report that explicitly instead
                // of letting Convert.ToChar throw an unrelated OverflowException.
                int typeByte = inputStream.ReadByte();
                if (typeByte < 0)
                {
                    throw new EndOfStreamException(
                        "Connection closed before the return type identifier was received.");
                }

                char typeAsChar = Convert.ToChar(typeByte);
                switch (typeAsChar) // TODO: Add support for other types.
                {
                case 'n':
                    break;

                case 'j':
                    returnValue = new JvmObjectReference(SerDe.ReadString(inputStream), this);
                    break;

                case 'c':
                    returnValue = SerDe.ReadString(inputStream);
                    break;

                case 'i':
                    returnValue = SerDe.ReadInt32(inputStream);
                    break;

                case 'g':
                    returnValue = SerDe.ReadInt64(inputStream);
                    break;

                case 'd':
                    returnValue = SerDe.ReadDouble(inputStream);
                    break;

                case 'b':
                {
                    // Without this check, -1 (end of stream) would silently convert to true.
                    int boolByte = inputStream.ReadByte();
                    if (boolByte < 0)
                    {
                        throw new EndOfStreamException(
                            "Connection closed before the boolean return value was received.");
                    }

                    returnValue = Convert.ToBoolean(boolByte);
                    break;
                }

                case 'l':
                    returnValue = ReadCollection(inputStream);
                    break;

                default:
                    // Convert typeAsChar to UInt32 because the char may be non-printable.
                    throw new NotSupportedException(
                              string.Format(
                                  "Identifier for type 0x{0:X} not supported",
                                  Convert.ToUInt32(typeAsChar)));
                }
                _sockets.Enqueue(socket);
            }
            catch (Exception e)
            {
                _logger.LogException(e);

                if (e.InnerException is JvmException)
                {
                    // DotnetBackendHandler caught JVM exception and passed back to dotnet.
                    // We can reuse this connection.
                    _sockets.Enqueue(socket);
                }
                else
                {
                    // In rare cases we may hit the Netty connection thread deadlock.
                    // If max backend threads is 10 and we are currently using 10 active
                    // connections (0 in the _sockets queue). When we hit this exception,
                    // the socket?.Dispose() will not requeue this socket and we will release
                    // the semaphore. Then in the next thread (assuming the other 9 connections
                    // are still busy), a new connection will be made to the backend and this
                    // connection may be scheduled on the blocked Netty thread.
                    socket?.Dispose();
                }

                throw;
            }
            finally
            {
                if (connectionRequested)
                {
                    _socketSemaphore.Release();
                }
            }

            return(returnValue);
        }
Example #8
0
        /// <summary>
        /// Sends a method-invocation request, tagged with the current managed thread id, to the
        /// JVM backend and decodes the typed reply.
        /// </summary>
        /// <param name="isStatic">Whether the JVM method is static.</param>
        /// <param name="classNameOrJvmObjectReference">
        /// The class name or JVM object reference the method is invoked on, as accepted by
        /// <c>PayloadHelper.BuildPayload</c>.
        /// </param>
        /// <param name="methodName">Name of the JVM method to invoke.</param>
        /// <param name="args">Arguments serialized into the request payload.</param>
        /// <returns>
        /// The value selected by the reply's type tag: null ('n'), a
        /// <see cref="JvmObjectReference"/> ('j'), string ('c'), int ('i'), long ('g'),
        /// double ('d'), bool ('b'), or the result of <c>ReadCollection</c> ('l').
        /// </returns>
        /// <exception cref="Exception">
        /// The JVM reported a failure; the inner exception is a <see cref="JvmException"/>.
        /// The connection is kept for reuse in this case.
        /// </exception>
        /// <exception cref="NotSupportedException">The reply carried an unknown type tag.</exception>
        /// <exception cref="EndOfStreamException">
        /// The connection ended before the type tag or a boolean value was received.
        /// </exception>
        private object CallJavaMethod(
            bool isStatic,
            object classNameOrJvmObjectReference,
            string methodName,
            object[] args)
        {
            object         returnValue         = null;
            ISocketWrapper socket              = null;
            bool           connectionRequested = false;

            try
            {
                Thread       thread = Thread.CurrentThread;
                MemoryStream payloadMemoryStream = s_payloadMemoryStream ??= new MemoryStream();
                payloadMemoryStream.Position = 0;
                PayloadHelper.BuildPayload(
                    payloadMemoryStream,
                    isStatic,
                    thread.ManagedThreadId,
                    classNameOrJvmObjectReference,
                    methodName,
                    args);

                // NOTE(review): GetConnection() is assumed to acquire _socketSemaphore (TODO confirm).
                // Record that it was attempted so the finally block does not release a permit
                // that was never taken when building the payload above throws.
                connectionRequested = true;
                socket = GetConnection();

                // Send only the bytes written for this request; the reused buffer may hold
                // leftover data beyond Position.
                Stream outputStream = socket.OutputStream;
                outputStream.Write(
                    payloadMemoryStream.GetBuffer(),
                    0,
                    (int)payloadMemoryStream.Position);
                outputStream.Flush();

                _jvmThreadPoolGC.TryAddThread(thread);

                Stream inputStream        = socket.InputStream;
                int    isMethodCallFailed = SerDe.ReadInt32(inputStream);
                if (isMethodCallFailed != 0)
                {
                    string jvmFullStackTrace = SerDe.ReadString(inputStream);
                    string errorMessage      = BuildErrorMessage(
                        isStatic,
                        classNameOrJvmObjectReference,
                        methodName,
                        args);
                    _logger.LogError(errorMessage);
                    _logger.LogError(jvmFullStackTrace);
                    throw new Exception(errorMessage, new JvmException(jvmFullStackTrace));
                }

                // Stream.ReadByte() returns -1 at end of stream; report that explicitly instead
                // of letting Convert.ToChar throw an unrelated OverflowException.
                int typeByte = inputStream.ReadByte();
                if (typeByte < 0)
                {
                    throw new EndOfStreamException(
                        "Connection closed before the return type identifier was received.");
                }

                char typeAsChar = Convert.ToChar(typeByte);
                switch (typeAsChar) // TODO: Add support for other types.
                {
                case 'n':
                    break;

                case 'j':
                    returnValue = new JvmObjectReference(SerDe.ReadString(inputStream), this);
                    break;

                case 'c':
                    returnValue = SerDe.ReadString(inputStream);
                    break;

                case 'i':
                    returnValue = SerDe.ReadInt32(inputStream);
                    break;

                case 'g':
                    returnValue = SerDe.ReadInt64(inputStream);
                    break;

                case 'd':
                    returnValue = SerDe.ReadDouble(inputStream);
                    break;

                case 'b':
                {
                    // Without this check, -1 (end of stream) would silently convert to true.
                    int boolByte = inputStream.ReadByte();
                    if (boolByte < 0)
                    {
                        throw new EndOfStreamException(
                            "Connection closed before the boolean return value was received.");
                    }

                    returnValue = Convert.ToBoolean(boolByte);
                    break;
                }

                case 'l':
                    returnValue = ReadCollection(inputStream);
                    break;

                default:
                    // Convert typeAsChar to UInt32 because the char may be non-printable.
                    throw new NotSupportedException(
                              string.Format(
                                  "Identifier for type 0x{0:X} not supported",
                                  Convert.ToUInt32(typeAsChar)));
                }
                _sockets.Enqueue(socket);
            }
            catch (Exception e)
            {
                _logger.LogException(e);

                if (e.InnerException is JvmException)
                {
                    // DotnetBackendHandler caught JVM exception and passed back to dotnet.
                    // We can reuse this connection.
                    _sockets.Enqueue(socket);
                }
                else
                {
                    // In rare cases we may hit the Netty connection thread deadlock.
                    // If max backend threads is 10 and we are currently using 10 active
                    // connections (0 in the _sockets queue). When we hit this exception,
                    // the socket?.Dispose() will not requeue this socket and we will release
                    // the semaphore. Then in the next thread (assuming the other 9 connections
                    // are still busy), a new connection will be made to the backend and this
                    // connection may be scheduled on the blocked Netty thread.
                    socket?.Dispose();
                }

                throw;
            }
            finally
            {
                if (connectionRequested)
                {
                    _socketSemaphore.Release();
                }
            }

            return(returnValue);
        }
Example #9
0
        /// <summary>
        /// Run and start processing the callback connection.
        /// </summary>
        /// <remarks>
        /// Loops over <c>ProcessStream</c> while <c>_isRunning</c> is true. The loop stops when:
        /// - the JVM requests a close;
        /// - the socket is closed;
        /// - a request was not read through completely;
        /// - an exception occurs.
        /// Cancelling <paramref name="token"/> invokes <c>Stop()</c> while this method is
        /// running. The socket is always disposed before returning.
        /// </remarks>
        /// <param name="token">Cancellation token used to stop the connection.</param>
        internal void Run(CancellationToken token)
        {
            _isRunning = true;
            Stream inputStream  = _socket.InputStream;
            Stream outputStream = _socket.OutputStream;

            // Dispose the registration when Run returns. Otherwise a long-lived token keeps this
            // connection reachable and would still call Stop() on it after it has finished.
            // NOTE(review): Dispose() waits for a Stop() callback that is running concurrently;
            // this assumes Stop() never blocks waiting for Run() to return.
            using CancellationTokenRegistration stopRegistration = token.Register(() => Stop());

            try
            {
                while (_isRunning)
                {
                    ConnectionStatus connectionStatus =
                        ProcessStream(inputStream, outputStream, out bool readComplete);

                    if (connectionStatus == ConnectionStatus.OK)
                    {
                        outputStream.Flush();

                        ++_numCallbacksRun;

                        // If the socket is not read through completely, then it cannot be reused.
                        if (!readComplete)
                        {
                            _isRunning = false;

                            // Wait for server to complete to avoid 'connection reset' exception.
                            s_logger.LogInfo(
                                $"[{ConnectionId}] Sleep 500 millisecond to close socket.");
                            Thread.Sleep(500);
                        }
                    }
                    else if (connectionStatus == ConnectionStatus.REQUEST_CLOSE)
                    {
                        _isRunning = false;
                        s_logger.LogInfo(
                            $"[{ConnectionId}] Request to close connection received.");
                    }
                    else
                    {
                        _isRunning = false;
                        s_logger.LogWarn($"[{ConnectionId}] Socket is closed by JVM.");
                    }
                }
            }
            catch (Exception e)
            {
                _isRunning = false;
                s_logger.LogError($"[{ConnectionId}] Exiting with exception: {e}");
            }
            finally
            {
                try
                {
                    _socket.Dispose();
                }
                catch (Exception e)
                {
                    s_logger.LogWarn($"[{ConnectionId}] Exception while closing socket {e}");
                }

                s_logger.LogInfo(
                    $"[{ConnectionId}] Finished running {_numCallbacksRun} callback(s).");
            }
        }