/// <summary>
/// Serializes a JVM method-invocation request into <paramref name="destination"/>,
/// starting at the stream's current position. The request is prefixed with a 4-byte
/// length that covers every byte written after the prefix.
/// </summary>
/// <param name="destination">Stream that receives the payload.</param>
/// <param name="isStaticMethod">True when invoking a static method.</param>
/// <param name="classNameOrJvmObjectReference">Invocation target; its <c>ToString()</c>
/// value is what gets written.</param>
/// <param name="methodName">Name of the method to invoke.</param>
/// <param name="args">Method arguments, serialized by <c>ConvertArgsToBytes</c>.</param>
internal static void BuildPayload(
    MemoryStream destination,
    bool isStaticMethod,
    object classNameOrJvmObjectReference,
    string methodName,
    object[] args)
{
    // Reserve space for the total length; it is back-filled once the body is written.
    long originalPosition = destination.Position;
    destination.Position += sizeof(int);

    SerDe.Write(destination, isStaticMethod);
    SerDe.Write(destination, classNameOrJvmObjectReference.ToString());
    SerDe.Write(destination, methodName);
    SerDe.Write(destination, args.Length);
    ConvertArgsToBytes(destination, args);

    // The length is measured from where this payload started, not from offset 0,
    // so the prefix stays correct even if the caller did not rewind the stream.
    long afterPosition = destination.Position;
    destination.Position = originalPosition;
    SerDe.Write(destination, (int)(afterPosition - originalPosition) - sizeof(int));
    destination.Position = afterPosition;
}
/// <summary>
/// Decodes one micro-batch callback from <paramref name="inputStream"/> and hands it to
/// the registered batch function.
/// </summary>
/// <param name="inputStream">Stream positioned at the callback data: a JVM object
/// reference id (string) followed by the batch id (64-bit integer).</param>
public void Run(Stream inputStream)
{
    // The DataFrame reference id comes first on the wire.
    string referenceId = SerDe.ReadString(inputStream);
    var jvmReference = new JvmObjectReference(referenceId, _jvm);
    var batchDf = new DataFrame(jvmReference);

    long batchId = SerDe.ReadInt64(inputStream);

    _func(batchDf, batchId);
}
/// <summary>
/// Writes a 4-byte element count followed by each element of <paramref name="enumlist"/>,
/// serialized by <paramref name="writefunc"/>. The sequence is enumerated only once: the
/// count slot is reserved up front and filled in after enumeration finishes.
/// </summary>
/// <typeparam name="T">Element type.</typeparam>
/// <param name="stream">Stream that receives the count and elements.</param>
/// <param name="enumlist">Elements to write.</param>
/// <param name="writefunc">Serializer invoked once per element.</param>
internal static void WriteIEnumerableObjects<T>(
    MemoryStream stream,
    IEnumerable<T> enumlist,
    Action<MemoryStream, T> writefunc)
{
    long countSlot = stream.Position;
    stream.Position = countSlot + sizeof(int);

    int written = 0;
    using (IEnumerator<T> cursor = enumlist.GetEnumerator())
    {
        while (cursor.MoveNext())
        {
            written++;
            writefunc(stream, cursor.Current);
        }
    }

    // Jump back, record how many elements were written, then resume at the end.
    long end = stream.Position;
    stream.Position = countSlot;
    SerDe.Write(stream, written);
    stream.Position = end;
}
/// <summary>
/// Serializes method-call arguments into <paramref name="destination"/> in the format
/// the JVM backend reads.
/// </summary>
/// <param name="destination">Stream that receives the serialized arguments.</param>
/// <param name="args">Arguments to serialize. A <c>null</c> element is written as the
/// single byte <c>'n'</c>, with no type id prefix.</param>
/// <param name="addTypeIdPrefix">When true, each non-null argument is preceded by the
/// id returned by <c>GetTypeId</c> for its runtime type.</param>
/// <exception cref="NotSupportedException">An argument's runtime type has no
/// serialization in this method.</exception>
internal static void ConvertArgsToBytes(
    MemoryStream destination,
    object[] args,
    bool addTypeIdPrefix = true)
{
    // Single-element buffer reused when recursing into dictionary keys and values.
    object[] convertArgs = null;

    foreach (object arg in args)
    {
        if (arg == null)
        {
            destination.WriteByte((byte)'n');
            continue;
        }

        Type argType = arg.GetType();
        if (addTypeIdPrefix)
        {
            SerDe.Write(destination, GetTypeId(argType));
        }

        switch (Type.GetTypeCode(argType))
        {
            case TypeCode.Int32:
                SerDe.Write(destination, (int)arg);
                break;
            case TypeCode.Int64:
                SerDe.Write(destination, (long)arg);
                break;
            case TypeCode.String:
                SerDe.Write(destination, (string)arg);
                break;
            case TypeCode.Boolean:
                SerDe.Write(destination, (bool)arg);
                break;
            case TypeCode.Double:
                SerDe.Write(destination, (double)arg);
                break;
            case TypeCode.Object:
                switch (arg)
                {
                    case byte[] argByteArray:
                        SerDe.Write(destination, argByteArray.Length);
                        SerDe.Write(destination, argByteArray);
                        break;
                    case int[] argInt32Array:
                        SerDe.Write(destination, s_int32TypeId);
                        SerDe.Write(destination, argInt32Array.Length);
                        foreach (int i in argInt32Array)
                        {
                            SerDe.Write(destination, i);
                        }
                        break;
                    case long[] argInt64Array:
                        SerDe.Write(destination, s_int64TypeId);
                        SerDe.Write(destination, argInt64Array.Length);
                        foreach (long l in argInt64Array)
                        {
                            SerDe.Write(destination, l);
                        }
                        break;
                    case double[] argDoubleArray:
                        SerDe.Write(destination, s_doubleTypeId);
                        SerDe.Write(destination, argDoubleArray.Length);
                        foreach (double d in argDoubleArray)
                        {
                            SerDe.Write(destination, d);
                        }
                        break;
                    case double[][] argDoubleArrayArray:
                        SerDe.Write(destination, s_doubleArrayArrayTypeId);
                        SerDe.Write(destination, argDoubleArrayArray.Length);
                        foreach (double[] doubleArray in argDoubleArrayArray)
                        {
                            SerDe.Write(destination, doubleArray.Length);
                            foreach (double d in doubleArray)
                            {
                                SerDe.Write(destination, d);
                            }
                        }
                        break;
                    case IEnumerable<byte[]> argByteArrayEnumerable:
                    {
                        SerDe.Write(destination, s_byteArrayTypeId);
                        long countPosition = ReserveCount(destination);
                        int itemCount = 0;
                        foreach (byte[] b in argByteArrayEnumerable)
                        {
                            ++itemCount;
                            SerDe.Write(destination, b.Length);
                            destination.Write(b, 0, b.Length);
                        }
                        BackfillCount(destination, countPosition, itemCount);
                        break;
                    }
                    case IEnumerable<string> argStringEnumerable:
                    {
                        SerDe.Write(destination, s_stringTypeId);
                        long countPosition = ReserveCount(destination);
                        int itemCount = 0;
                        foreach (string s in argStringEnumerable)
                        {
                            ++itemCount;
                            SerDe.Write(destination, s);
                        }
                        BackfillCount(destination, countPosition, itemCount);
                        break;
                    }
                    case IEnumerable<IJvmObjectReferenceProvider> argJvmEnumerable:
                    {
                        SerDe.Write(destination, s_jvmObjectTypeId);
                        long countPosition = ReserveCount(destination);
                        int itemCount = 0;
                        foreach (IJvmObjectReferenceProvider jvmObject in argJvmEnumerable)
                        {
                            ++itemCount;
                            SerDe.Write(destination, jvmObject.Reference.Id);
                        }
                        BackfillCount(destination, countPosition, itemCount);
                        break;
                    }
                    case IEnumerable<GenericRow> argRowEnumerable:
                    {
                        // No element type id here: each row is written as its value
                        // count followed by its values, each with a type id prefix.
                        long countPosition = ReserveCount(destination);
                        int itemCount = 0;
                        foreach (GenericRow r in argRowEnumerable)
                        {
                            ++itemCount;
                            SerDe.Write(destination, r.Values.Length);
                            ConvertArgsToBytes(destination, r.Values, true);
                        }
                        BackfillCount(destination, countPosition, itemCount);
                        break;
                    }
                    case var _ when IsDictionary(argType):
                    {
                        // Generic dictionary that is not strongly typed here. Snapshot the
                        // entries in one pass so keys and values are written in matching
                        // order and the count matches the source exactly (re-inserting
                        // into a Dictionary<object, object> could merge keys that the
                        // source's own comparer kept distinct).
                        var dictInterface = (IDictionary)arg;
                        var keys = new List<object>(dictInterface.Count);
                        var values = new List<object>(dictInterface.Count);
                        IDictionaryEnumerator iter = dictInterface.GetEnumerator();
                        while (iter.MoveNext())
                        {
                            keys.Add(iter.Key);
                            values.Add(iter.Value);
                        }

                        // Below serialization corresponds to the deserialization
                        // method ReadMap() of SerDe.scala.
                        // dictionary's length
                        SerDe.Write(destination, keys.Count);
                        // keys' data type
                        SerDe.Write(destination, GetTypeId(argType.GetGenericArguments()[0]));
                        // keys' length, same as dictionary's length
                        SerDe.Write(destination, keys.Count);

                        convertArgs ??= new object[1];
                        foreach (object key in keys)
                        {
                            convertArgs[0] = key;
                            // keys, do not need type prefix.
                            ConvertArgsToBytes(destination, convertArgs, false);
                        }

                        // values' length, same as dictionary's length
                        SerDe.Write(destination, values.Count);
                        foreach (object value in values)
                        {
                            convertArgs[0] = value;
                            // values, need type prefix.
                            ConvertArgsToBytes(destination, convertArgs, true);
                        }
                        break;
                    }
                    case IJvmObjectReferenceProvider argProvider:
                        SerDe.Write(destination, argProvider.Reference.Id);
                        break;
                    default:
                        throw new NotSupportedException($"Type {argType} is not supported");
                }
                break;
            default:
                // No value encoding exists for this TypeCode. Continuing would emit a
                // value-less entry (possibly after a type id) and desynchronize the
                // payload, so fail before anything is sent.
                throw new NotSupportedException($"Type {argType} is not supported");
        }
    }

    // Reserves a 4-byte slot for an element count that is known only after enumeration.
    static long ReserveCount(MemoryStream stream)
    {
        long position = stream.Position;
        stream.Position += sizeof(int);
        return position;
    }

    // Writes count into the slot reserved at countPosition, then restores the write position.
    static void BackfillCount(MemoryStream stream, long countPosition, int count)
    {
        long endPosition = stream.Position;
        stream.Position = countPosition;
        SerDe.Write(stream, count);
        stream.Position = endPosition;
    }
}
/// <summary>
/// Reads a typed list sent by the JVM backend: a one-byte element type id, a 4-byte
/// element count, then the elements.
/// </summary>
/// <param name="s">Stream positioned at the element type id.</param>
/// <returns>
/// A <c>string[]</c> ('c'), <c>int[]</c> ('i'), <c>long[]</c> ('g'), <c>double[]</c> ('d'),
/// <c>double[][]</c> ('A'), <c>bool[]</c> ('b'), <c>byte[][]</c> ('r') or
/// <c>JvmObjectReference[]</c> ('j').
/// </returns>
/// <exception cref="EndOfStreamException">The stream ended while reading a one-byte field.</exception>
/// <exception cref="NotSupportedException">The element type id is not recognized.</exception>
private object ReadCollection(Stream s)
{
    char listItemTypeAsChar = (char)ReadRequiredByte(s);
    int numOfItemsInList = SerDe.ReadInt32(s);

    switch (listItemTypeAsChar)
    {
        case 'c':
        {
            var strArray = new string[numOfItemsInList];
            for (int itemIndex = 0; itemIndex < numOfItemsInList; ++itemIndex)
            {
                strArray[itemIndex] = SerDe.ReadString(s);
            }
            return strArray;
        }
        case 'i':
        {
            var intArray = new int[numOfItemsInList];
            for (int itemIndex = 0; itemIndex < numOfItemsInList; ++itemIndex)
            {
                intArray[itemIndex] = SerDe.ReadInt32(s);
            }
            return intArray;
        }
        case 'g':
        {
            var longArray = new long[numOfItemsInList];
            for (int itemIndex = 0; itemIndex < numOfItemsInList; ++itemIndex)
            {
                longArray[itemIndex] = SerDe.ReadInt64(s);
            }
            return longArray;
        }
        case 'd':
        {
            var doubleArray = new double[numOfItemsInList];
            for (int itemIndex = 0; itemIndex < numOfItemsInList; ++itemIndex)
            {
                doubleArray[itemIndex] = SerDe.ReadDouble(s);
            }
            return doubleArray;
        }
        case 'A':
        {
            // Each element is itself a nested collection that must decode to double[].
            // A direct cast surfaces a mismatch instead of silently storing null.
            var doubleArrayArray = new double[numOfItemsInList][];
            for (int itemIndex = 0; itemIndex < numOfItemsInList; ++itemIndex)
            {
                doubleArrayArray[itemIndex] = (double[])ReadCollection(s);
            }
            return doubleArrayArray;
        }
        case 'b':
        {
            var boolArray = new bool[numOfItemsInList];
            for (int itemIndex = 0; itemIndex < numOfItemsInList; ++itemIndex)
            {
                // A closed stream must not be read as `true` (ReadByte() returns -1 at EOF).
                boolArray[itemIndex] = ReadRequiredByte(s) != 0;
            }
            return boolArray;
        }
        case 'r':
        {
            var byteArrayArray = new byte[numOfItemsInList][];
            for (int itemIndex = 0; itemIndex < numOfItemsInList; ++itemIndex)
            {
                int byteArrayLen = SerDe.ReadInt32(s);
                byteArrayArray[itemIndex] = SerDe.ReadBytes(s, byteArrayLen);
            }
            return byteArrayArray;
        }
        case 'j':
        {
            var jvmObjectReferenceArray = new JvmObjectReference[numOfItemsInList];
            for (int itemIndex = 0; itemIndex < numOfItemsInList; ++itemIndex)
            {
                string itemIdentifier = SerDe.ReadString(s);
                jvmObjectReferenceArray[itemIndex] = new JvmObjectReference(itemIdentifier, this);
            }
            return jvmObjectReferenceArray;
        }
        default:
            // convert listItemTypeAsChar to UInt32 because the char may be non-printable
            throw new NotSupportedException(
                string.Format(
                    "Identifier for list item type 0x{0:X} not supported",
                    Convert.ToUInt32(listItemTypeAsChar)));
    }

    // Stream.ReadByte() signals end of stream with -1; report it explicitly rather than
    // letting Convert.ToChar throw OverflowException or Convert.ToBoolean yield true.
    static int ReadRequiredByte(Stream stream)
    {
        int value = stream.ReadByte();
        if (value < 0)
        {
            throw new EndOfStreamException(
                "Unexpected end of stream while reading a response from the JVM backend.");
        }
        return value;
    }
}
/// <summary>
/// Sends a method invocation to the JVM backend and decodes the returned value.
/// </summary>
/// <param name="isStatic">True when invoking a static method.</param>
/// <param name="classNameOrJvmObjectReference">Class name or JVM object reference to invoke on.</param>
/// <param name="methodName">Name of the method to invoke.</param>
/// <param name="args">Method arguments.</param>
/// <returns>
/// <c>null</c> ('n'), a <c>JvmObjectReference</c> ('j'), <c>string</c> ('c'), <c>int</c> ('i'),
/// <c>long</c> ('g'), <c>double</c> ('d'), <c>bool</c> ('b'), or an array from
/// <c>ReadCollection</c> ('l').
/// </returns>
/// <exception cref="Exception">The JVM reported a failure; the inner exception is a
/// <c>JvmException</c> carrying the JVM stack trace.</exception>
/// <exception cref="NotSupportedException">The return type id is not recognized.</exception>
/// <exception cref="EndOfStreamException">The connection closed while reading a one-byte field.</exception>
private object CallJavaMethod(
    bool isStatic,
    object classNameOrJvmObjectReference,
    string methodName,
    object[] args)
{
    object returnValue = null;
    ISocketWrapper socket = null;

    try
    {
        // NOTE(review): presumably s_payloadMemoryStream is per-thread ([ThreadStatic]);
        // it is shared otherwise — confirm at its declaration.
        MemoryStream payloadMemoryStream = s_payloadMemoryStream ??= new MemoryStream();
        payloadMemoryStream.Position = 0;
        PayloadHelper.BuildPayload(
            payloadMemoryStream,
            isStatic,
            classNameOrJvmObjectReference,
            methodName,
            args);

        socket = GetConnection();

        Stream outputStream = socket.OutputStream;
        outputStream.Write(
            payloadMemoryStream.GetBuffer(),
            0,
            (int)payloadMemoryStream.Position);
        outputStream.Flush();

        Stream inputStream = socket.InputStream;
        int isMethodCallFailed = SerDe.ReadInt32(inputStream);
        if (isMethodCallFailed != 0)
        {
            string jvmFullStackTrace = SerDe.ReadString(inputStream);
            string errorMessage = BuildErrorMessage(
                isStatic,
                classNameOrJvmObjectReference,
                methodName,
                args);
            _logger.LogError(errorMessage);
            _logger.LogError(jvmFullStackTrace);
            throw new Exception(errorMessage, new JvmException(jvmFullStackTrace));
        }

        char typeAsChar = (char)ReadRequiredByte(inputStream);
        switch (typeAsChar) // TODO: Add support for other types.
        {
            case 'n':
                break;
            case 'j':
                returnValue = new JvmObjectReference(SerDe.ReadString(inputStream), this);
                break;
            case 'c':
                returnValue = SerDe.ReadString(inputStream);
                break;
            case 'i':
                returnValue = SerDe.ReadInt32(inputStream);
                break;
            case 'g':
                returnValue = SerDe.ReadInt64(inputStream);
                break;
            case 'd':
                returnValue = SerDe.ReadDouble(inputStream);
                break;
            case 'b':
                // EOF (-1) must not decode as `true`.
                returnValue = ReadRequiredByte(inputStream) != 0;
                break;
            case 'l':
                returnValue = ReadCollection(inputStream);
                break;
            default:
                // Convert typeAsChar to UInt32 because the char may be non-printable.
                throw new NotSupportedException(
                    string.Format(
                        "Identifier for type 0x{0:X} not supported",
                        Convert.ToUInt32(typeAsChar)));
        }

        // The response was fully consumed, so the connection can be reused.
        _sockets.Enqueue(socket);
    }
    catch (Exception e)
    {
        _logger.LogException(e);
        socket?.Dispose();
        throw;
    }

    return returnValue;

    // Stream.ReadByte() signals end of stream with -1; report it explicitly rather than
    // letting Convert.ToChar throw OverflowException or Convert.ToBoolean yield true.
    static int ReadRequiredByte(Stream stream)
    {
        int value = stream.ReadByte();
        if (value < 0)
        {
            throw new EndOfStreamException(
                "Unexpected end of stream while reading a response from the JVM backend.");
        }
        return value;
    }
}
/// <summary>
/// Sends a method invocation to the JVM backend over a connection taken from
/// <c>GetConnection()</c> and decodes the returned value.
/// </summary>
/// <param name="isStatic">True when invoking a static method.</param>
/// <param name="classNameOrJvmObjectReference">Class name or JVM object reference to invoke on.</param>
/// <param name="methodName">Name of the method to invoke.</param>
/// <param name="args">Method arguments.</param>
/// <returns>
/// <c>null</c> ('n'), a <c>JvmObjectReference</c> ('j'), <c>string</c> ('c'), <c>int</c> ('i'),
/// <c>long</c> ('g'), <c>double</c> ('d'), <c>bool</c> ('b'), or an array from
/// <c>ReadCollection</c> ('l').
/// </returns>
/// <exception cref="Exception">The JVM reported a failure; the inner exception is a
/// <c>JvmException</c> carrying the JVM stack trace.</exception>
/// <exception cref="NotSupportedException">The return type id is not recognized.</exception>
/// <exception cref="EndOfStreamException">The connection closed while reading a one-byte field.</exception>
private object CallJavaMethod(
    bool isStatic,
    object classNameOrJvmObjectReference,
    string methodName,
    object[] args)
{
    object returnValue = null;
    ISocketWrapper socket = null;

    // Set just before GetConnection() is called. The finally block releases
    // _socketSemaphore only when this is true, so a failure while building the payload
    // (e.g. an unsupported argument type) does not release a permit that was never taken.
    // NOTE(review): assumes GetConnection() is where _socketSemaphore is acquired — confirm.
    bool connectionRequested = false;

    try
    {
        // dotnet-interactive does not have a dedicated thread to process
        // code submissions and each code submission can be processed in different
        // threads. DotnetHandler uses the CLR thread id to ensure that the same
        // JVM thread is used to handle the request, which means that code submitted
        // through dotnet-interactive may be executed in different JVM threads. To
        // mitigate this, when running in the REPL, submit requests to the DotnetHandler
        // using the same thread id. This mitigation has some limitations in multithreaded
        // scenarios. If a JVM method is blocking and needs a JVM method call issued by a
        // separate thread to unblock it, then this scenario is not supported.
        //
        // ie, `StreamingQuery.AwaitTermination()` is a blocking call and requires
        // `StreamingQuery.Stop()` to be called to unblock it. However, the `Stop`
        // call will never run because DotnetHandler will assign the method call to
        // run on the same thread that `AwaitTermination` is running on.
        Thread thread = _isRunningRepl ? null : Thread.CurrentThread;
        int threadId = thread == null ? ThreadIdForRepl : thread.ManagedThreadId;

        MemoryStream payloadMemoryStream = s_payloadMemoryStream ??= new MemoryStream();
        payloadMemoryStream.Position = 0;
        PayloadHelper.BuildPayload(
            payloadMemoryStream,
            isStatic,
            _processId,
            threadId,
            classNameOrJvmObjectReference,
            methodName,
            args);

        connectionRequested = true;
        socket = GetConnection();

        Stream outputStream = socket.OutputStream;
        outputStream.Write(
            payloadMemoryStream.GetBuffer(),
            0,
            (int)payloadMemoryStream.Position);
        outputStream.Flush();

        if (thread != null)
        {
            _jvmThreadPoolGC.TryAddThread(thread);
        }

        Stream inputStream = socket.InputStream;
        int isMethodCallFailed = SerDe.ReadInt32(inputStream);
        if (isMethodCallFailed != 0)
        {
            string jvmFullStackTrace = SerDe.ReadString(inputStream);
            string errorMessage = BuildErrorMessage(
                isStatic,
                classNameOrJvmObjectReference,
                methodName,
                args);
            _logger.LogError(errorMessage);
            _logger.LogError(jvmFullStackTrace);
            throw new Exception(errorMessage, new JvmException(jvmFullStackTrace));
        }

        char typeAsChar = (char)ReadRequiredByte(inputStream);
        switch (typeAsChar) // TODO: Add support for other types.
        {
            case 'n':
                break;
            case 'j':
                returnValue = new JvmObjectReference(SerDe.ReadString(inputStream), this);
                break;
            case 'c':
                returnValue = SerDe.ReadString(inputStream);
                break;
            case 'i':
                returnValue = SerDe.ReadInt32(inputStream);
                break;
            case 'g':
                returnValue = SerDe.ReadInt64(inputStream);
                break;
            case 'd':
                returnValue = SerDe.ReadDouble(inputStream);
                break;
            case 'b':
                // EOF (-1) must not decode as `true`.
                returnValue = ReadRequiredByte(inputStream) != 0;
                break;
            case 'l':
                returnValue = ReadCollection(inputStream);
                break;
            default:
                // Convert typeAsChar to UInt32 because the char may be non-printable.
                throw new NotSupportedException(
                    string.Format(
                        "Identifier for type 0x{0:X} not supported",
                        Convert.ToUInt32(typeAsChar)));
        }

        _sockets.Enqueue(socket);
    }
    catch (Exception e)
    {
        _logger.LogException(e);

        if (e.InnerException is JvmException)
        {
            // DotnetBackendHandler caught JVM exception and passed back to dotnet.
            // We can reuse this connection.
            _sockets.Enqueue(socket);
        }
        else
        {
            // In rare cases we may hit the Netty connection thread deadlock.
            // If max backend threads is 10 and we are currently using 10 active
            // connections (0 in the _sockets queue). When we hit this exception,
            // the socket?.Dispose() will not requeue this socket and we will release
            // the semaphore. Then in the next thread (assuming the other 9 connections
            // are still busy), a new connection will be made to the backend and this
            // connection may be scheduled on the blocked Netty thread.
            socket?.Dispose();
        }

        throw;
    }
    finally
    {
        if (connectionRequested)
        {
            _socketSemaphore.Release();
        }
    }

    return returnValue;

    // Stream.ReadByte() signals end of stream with -1; report it explicitly rather than
    // letting Convert.ToChar throw OverflowException or Convert.ToBoolean yield true.
    static int ReadRequiredByte(Stream stream)
    {
        int value = stream.ReadByte();
        if (value < 0)
        {
            throw new EndOfStreamException(
                "Unexpected end of stream while reading a response from the JVM backend.");
        }
        return value;
    }
}
/// <summary>
/// Sends a method invocation to the JVM backend over a connection taken from
/// <c>GetConnection()</c>, tagging it with the current managed thread id, and decodes
/// the returned value.
/// </summary>
/// <param name="isStatic">True when invoking a static method.</param>
/// <param name="classNameOrJvmObjectReference">Class name or JVM object reference to invoke on.</param>
/// <param name="methodName">Name of the method to invoke.</param>
/// <param name="args">Method arguments.</param>
/// <returns>
/// <c>null</c> ('n'), a <c>JvmObjectReference</c> ('j'), <c>string</c> ('c'), <c>int</c> ('i'),
/// <c>long</c> ('g'), <c>double</c> ('d'), <c>bool</c> ('b'), or an array from
/// <c>ReadCollection</c> ('l').
/// </returns>
/// <exception cref="Exception">The JVM reported a failure; the inner exception is a
/// <c>JvmException</c> carrying the JVM stack trace.</exception>
/// <exception cref="NotSupportedException">The return type id is not recognized.</exception>
/// <exception cref="EndOfStreamException">The connection closed while reading a one-byte field.</exception>
private object CallJavaMethod(
    bool isStatic,
    object classNameOrJvmObjectReference,
    string methodName,
    object[] args)
{
    object returnValue = null;
    ISocketWrapper socket = null;

    // Set just before GetConnection() is called. The finally block releases
    // _socketSemaphore only when this is true, so a failure while building the payload
    // (e.g. an unsupported argument type) does not release a permit that was never taken.
    // NOTE(review): assumes GetConnection() is where _socketSemaphore is acquired — confirm.
    bool connectionRequested = false;

    try
    {
        Thread thread = Thread.CurrentThread;

        MemoryStream payloadMemoryStream = s_payloadMemoryStream ??= new MemoryStream();
        payloadMemoryStream.Position = 0;
        PayloadHelper.BuildPayload(
            payloadMemoryStream,
            isStatic,
            thread.ManagedThreadId,
            classNameOrJvmObjectReference,
            methodName,
            args);

        connectionRequested = true;
        socket = GetConnection();

        Stream outputStream = socket.OutputStream;
        outputStream.Write(
            payloadMemoryStream.GetBuffer(),
            0,
            (int)payloadMemoryStream.Position);
        outputStream.Flush();

        _jvmThreadPoolGC.TryAddThread(thread);

        Stream inputStream = socket.InputStream;
        int isMethodCallFailed = SerDe.ReadInt32(inputStream);
        if (isMethodCallFailed != 0)
        {
            string jvmFullStackTrace = SerDe.ReadString(inputStream);
            string errorMessage = BuildErrorMessage(
                isStatic,
                classNameOrJvmObjectReference,
                methodName,
                args);
            _logger.LogError(errorMessage);
            _logger.LogError(jvmFullStackTrace);
            throw new Exception(errorMessage, new JvmException(jvmFullStackTrace));
        }

        char typeAsChar = (char)ReadRequiredByte(inputStream);
        switch (typeAsChar) // TODO: Add support for other types.
        {
            case 'n':
                break;
            case 'j':
                returnValue = new JvmObjectReference(SerDe.ReadString(inputStream), this);
                break;
            case 'c':
                returnValue = SerDe.ReadString(inputStream);
                break;
            case 'i':
                returnValue = SerDe.ReadInt32(inputStream);
                break;
            case 'g':
                returnValue = SerDe.ReadInt64(inputStream);
                break;
            case 'd':
                returnValue = SerDe.ReadDouble(inputStream);
                break;
            case 'b':
                // EOF (-1) must not decode as `true`.
                returnValue = ReadRequiredByte(inputStream) != 0;
                break;
            case 'l':
                returnValue = ReadCollection(inputStream);
                break;
            default:
                // Convert typeAsChar to UInt32 because the char may be non-printable.
                throw new NotSupportedException(
                    string.Format(
                        "Identifier for type 0x{0:X} not supported",
                        Convert.ToUInt32(typeAsChar)));
        }

        _sockets.Enqueue(socket);
    }
    catch (Exception e)
    {
        _logger.LogException(e);

        if (e.InnerException is JvmException)
        {
            // DotnetBackendHandler caught JVM exception and passed back to dotnet.
            // We can reuse this connection.
            _sockets.Enqueue(socket);
        }
        else
        {
            // In rare cases we may hit the Netty connection thread deadlock.
            // If max backend threads is 10 and we are currently using 10 active
            // connections (0 in the _sockets queue). When we hit this exception,
            // the socket?.Dispose() will not requeue this socket and we will release
            // the semaphore. Then in the next thread (assuming the other 9 connections
            // are still busy), a new connection will be made to the backend and this
            // connection may be scheduled on the blocked Netty thread.
            socket?.Dispose();
        }

        throw;
    }
    finally
    {
        if (connectionRequested)
        {
            _socketSemaphore.Release();
        }
    }

    return returnValue;

    // Stream.ReadByte() signals end of stream with -1; report it explicitly rather than
    // letting Convert.ToChar throw OverflowException or Convert.ToBoolean yield true.
    static int ReadRequiredByte(Stream stream)
    {
        int value = stream.ReadByte();
        if (value < 0)
        {
            throw new EndOfStreamException(
                "Unexpected end of stream while reading a response from the JVM backend.");
        }
        return value;
    }
}
/// <summary>
/// Handles one callback request read from <paramref name="inputStream"/>. The request is
/// a 4-byte big-endian flag, then (for a callback) the callback id, a 4-byte data length,
/// the callback data, and a 4-byte end-of-stream marker. The registered handler for the
/// id is run with the data and the outcome flag is written to
/// <paramref name="outputStream"/>.
/// </summary>
/// <param name="inputStream">The input stream.</param>
/// <param name="outputStream">The output stream.</param>
/// <param name="readComplete">True if the request was read completely (the end-of-stream
/// marker was received and the handler ran), false otherwise.</param>
/// <returns>
/// <c>SOCKET_CLOSED</c> if the stream closed before a flag arrived, <c>REQUEST_CLOSE</c> if
/// the JVM asked to close, otherwise <c>OK</c>.
/// </returns>
/// <remarks>
/// Any exception is logged and rethrown after a best-effort attempt to send
/// <c>DOTNET_EXCEPTION_THROWN</c> and the exception text back to the JVM.
/// </remarks>
private ConnectionStatus ProcessStream(
    Stream inputStream,
    Stream outputStream,
    out bool readComplete)
{
    readComplete = false;

    try
    {
        byte[] requestFlagBytes = SerDe.ReadBytes(inputStream, sizeof(int));
        // For socket stream, read on the stream returns 0, which
        // SerDe.ReadBytes() returns as null to denote the stream is closed.
        if (requestFlagBytes == null)
        {
            return ConnectionStatus.SOCKET_CLOSED;
        }

        // Check value of the initial request. Expected values are:
        // - CallbackFlags.CLOSE
        // - CallbackFlags.CALLBACK
        int requestFlag = BinaryPrimitives.ReadInt32BigEndian(requestFlagBytes);
        if (requestFlag == (int)CallbackFlags.CLOSE)
        {
            return ConnectionStatus.REQUEST_CLOSE;
        }
        else if (requestFlag != (int)CallbackFlags.CALLBACK)
        {
            throw new Exception(
                string.Format(
                    "Unexpected callback flag received. Expected: {0}, Received: {1}.",
                    CallbackFlags.CALLBACK,
                    requestFlag));
        }

        // Use callback id to get the registered handler.
        int callbackId = SerDe.ReadInt32(inputStream);
        if (!_callbackHandlers.TryGetValue(
            callbackId,
            out ICallbackHandler callbackHandler))
        {
            throw new Exception($"Unregistered callback id: {callbackId}");
        }

        s_logger.LogInfo(
            string.Format(
                "[{0}] Received request for callback id: {1}, callback handler: {2}",
                ConnectionId,
                callbackId,
                callbackHandler));

        // Save contents of callback handler data to be used later. An empty payload
        // does not go through SerDe.ReadBytes(), and a null result from it (closed
        // stream, per the check above) is reported explicitly instead of surfacing
        // as an ArgumentNullException from the MemoryStream constructor.
        int callbackDataLength = SerDe.ReadInt32(inputStream);
        byte[] callbackData = callbackDataLength == 0
            ? Array.Empty<byte>()
            : SerDe.ReadBytes(inputStream, callbackDataLength);
        if (callbackData == null)
        {
            throw new EndOfStreamException(
                $"[{ConnectionId}] Stream closed while reading {callbackDataLength} bytes of callback data.");
        }
        using var callbackDataStream = new MemoryStream(callbackData);

        // Check the end of stream.
        int endOfStream = SerDe.ReadInt32(inputStream);
        if (endOfStream == (int)CallbackFlags.END_OF_STREAM)
        {
            s_logger.LogDebug($"[{ConnectionId}] Received END_OF_STREAM signal.");

            // Run callback handler.
            callbackHandler.Run(callbackDataStream);

            SerDe.Write(outputStream, (int)CallbackFlags.END_OF_STREAM);
            readComplete = true;
        }
        else
        {
            // This may happen when the input data is not read completely.
            s_logger.LogWarn(
                $"[{ConnectionId}] Unexpected end of stream: {endOfStream}.");

            // Write flag to indicate the connection should be closed.
            SerDe.Write(outputStream, (int)CallbackFlags.CLOSE);
        }

        return ConnectionStatus.OK;
    }
    catch (Exception e)
    {
        s_logger.LogError($"[{ConnectionId}] ProcessStream() failed with exception: {e}");

        try
        {
            SerDe.Write(outputStream, (int)CallbackFlags.DOTNET_EXCEPTION_THROWN);
            SerDe.Write(outputStream, e.ToString());
        }
        catch (IOException)
        {
            // JVM closed the socket.
        }
        catch (Exception ex)
        {
            s_logger.LogError(
                $"[{ConnectionId}] Writing exception to stream failed with exception: {ex}");
        }

        throw;
    }
}