/// <summary>
/// Serializes a JVM method-call request into <paramref name="destination"/>.
/// The request is prefixed with a 4-byte length holding the number of bytes
/// written after that length field.
/// </summary>
/// <param name="destination">
/// Stream to write into; writing starts at its current position, and on return the
/// position is just past the last byte written.
/// </param>
/// <param name="isStaticMethod">Whether the target method is static.</param>
/// <param name="classNameOrJvmObjectReference">
/// Class name or JVM object reference; serialized through its <c>ToString()</c>.
/// </param>
/// <param name="methodName">Name of the method to invoke.</param>
/// <param name="args">Call arguments, serialized by <c>ConvertArgsToBytes</c>.</param>
internal static void BuildPayload(
    MemoryStream destination,
    bool isStaticMethod,
    object classNameOrJvmObjectReference,
    string methodName,
    object[] args)
{
    // Reserve space for the total length; it is back-filled once the body is written.
    long originalPosition = destination.Position;
    destination.Position += sizeof(int);

    SerDe.Write(destination, isStaticMethod);
    SerDe.Write(destination, classNameOrJvmObjectReference.ToString());
    SerDe.Write(destination, methodName);
    SerDe.Write(destination, args.Length);
    ConvertArgsToBytes(destination, args);

    // The length covers only the bytes after the length field, measured from where
    // this payload started rather than from the beginning of the stream, so the
    // prefix is correct even when the stream did not start at position 0.
    long afterPosition = destination.Position;
    destination.Position = originalPosition;
    SerDe.Write(destination, (int)(afterPosition - originalPosition - sizeof(int)));
    destination.Position = afterPosition;
}
/// <summary>
/// Writes a 4-byte item count followed by every element of <paramref name="enumlist"/>.
/// </summary>
/// <remarks>
/// The sequence is enumerated exactly once. Its length is not known up front, so a
/// 4-byte slot is reserved for the count and back-filled after the last element is
/// written. On return the stream is positioned after the last element.
/// </remarks>
/// <param name="stream">Destination stream; writing starts at its current position.</param>
/// <param name="enumlist">Elements to serialize.</param>
/// <param name="writefunc">Callback that serializes one element into the stream.</param>
internal static void WriteIEnumerableObjects<T>(
    MemoryStream stream,
    IEnumerable<T> enumlist,
    Action<MemoryStream, T> writefunc)
{
    long countSlot = stream.Position;
    stream.Position = countSlot + sizeof(int);

    int written = 0;
    using (IEnumerator<T> cursor = enumlist.GetEnumerator())
    {
        while (cursor.MoveNext())
        {
            ++written;
            writefunc(stream, cursor.Current);
        }
    }

    long endOfItems = stream.Position;
    stream.Position = countSlot;
    SerDe.Write(stream, written);
    stream.Position = endOfItems;
}
/// <summary>
/// Serializes each element of <paramref name="args"/> into <paramref name="destination"/>.
/// </summary>
/// <remarks>
/// A <c>null</c> element is written as the single byte <c>'n'</c>. Otherwise, when
/// <paramref name="addTypeIdPrefix"/> is true, the element is preceded by
/// <c>GetTypeId(element type)</c>.
///
/// Supported values:
/// <list type="bullet">
/// <item><description>int, long, string, bool and double.</description></item>
/// <item><description>byte[], int[], long[], double[] and double[][].</description></item>
/// <item><description>Sequences of byte[], string, <c>IJvmObjectReferenceProvider</c> or <c>GenericRow</c>.</description></item>
/// <item><description>Types accepted by <c>IsDictionary</c>.</description></item>
/// <item><description>Single <c>IJvmObjectReferenceProvider</c> instances.</description></item>
/// </list>
/// </remarks>
/// <param name="destination">Stream to append to.</param>
/// <param name="args">Values to serialize; elements may be null.</param>
/// <param name="addTypeIdPrefix">Whether to prefix each non-null value with its type id.</param>
/// <exception cref="NotSupportedException">A value's runtime type is not handled.</exception>
internal static void ConvertArgsToBytes(
    MemoryStream destination,
    object[] args,
    bool addTypeIdPrefix = true)
{
    // Reusable one-element buffer for recursing into dictionary keys/values.
    object[] convertArgs = null;

    foreach (object arg in args)
    {
        if (arg == null)
        {
            destination.WriteByte((byte)'n');
            continue;
        }

        Type argType = arg.GetType();

        if (addTypeIdPrefix)
        {
            SerDe.Write(destination, GetTypeId(argType));
        }

        switch (Type.GetTypeCode(argType))
        {
            case TypeCode.Int32:
                SerDe.Write(destination, (int)arg);
                break;
            case TypeCode.Int64:
                SerDe.Write(destination, (long)arg);
                break;
            case TypeCode.String:
                SerDe.Write(destination, (string)arg);
                break;
            case TypeCode.Boolean:
                SerDe.Write(destination, (bool)arg);
                break;
            case TypeCode.Double:
                SerDe.Write(destination, (double)arg);
                break;
            case TypeCode.Object:
                // Case order matters: earlier patterns take precedence for types that
                // match more than one of them.
                switch (arg)
                {
                    case byte[] argByteArray:
                        SerDe.Write(destination, argByteArray.Length);
                        SerDe.Write(destination, argByteArray);
                        break;
                    case int[] argInt32Array:
                        SerDe.Write(destination, s_int32TypeId);
                        SerDe.Write(destination, argInt32Array.Length);
                        foreach (int i in argInt32Array)
                        {
                            SerDe.Write(destination, i);
                        }
                        break;
                    case long[] argInt64Array:
                        SerDe.Write(destination, s_int64TypeId);
                        SerDe.Write(destination, argInt64Array.Length);
                        foreach (long l in argInt64Array)
                        {
                            SerDe.Write(destination, l);
                        }
                        break;
                    case double[] argDoubleArray:
                        SerDe.Write(destination, s_doubleTypeId);
                        SerDe.Write(destination, argDoubleArray.Length);
                        foreach (double d in argDoubleArray)
                        {
                            SerDe.Write(destination, d);
                        }
                        break;
                    case double[][] argDoubleArrayArray:
                        SerDe.Write(destination, s_doubleArrayArrayTypeId);
                        SerDe.Write(destination, argDoubleArrayArray.Length);
                        foreach (double[] doubleArray in argDoubleArrayArray)
                        {
                            SerDe.Write(destination, doubleArray.Length);
                            foreach (double d in doubleArray)
                            {
                                SerDe.Write(destination, d);
                            }
                        }
                        break;
                    case IEnumerable<byte[]> argByteArrayEnumerable:
                        SerDe.Write(destination, s_byteArrayTypeId);
                        WriteCountPrefixed(
                            destination,
                            argByteArrayEnumerable,
                            (stream, bytes) =>
                            {
                                SerDe.Write(stream, bytes.Length);
                                stream.Write(bytes, 0, bytes.Length);
                            });
                        break;
                    case IEnumerable<string> argStringEnumerable:
                        SerDe.Write(destination, s_stringTypeId);
                        WriteCountPrefixed(
                            destination,
                            argStringEnumerable,
                            (stream, text) => SerDe.Write(stream, text));
                        break;
                    case IEnumerable<IJvmObjectReferenceProvider> argJvmEnumerable:
                        SerDe.Write(destination, s_jvmObjectTypeId);
                        WriteCountPrefixed(
                            destination,
                            argJvmEnumerable,
                            (stream, jvmObject) => SerDe.Write(stream, jvmObject.Reference.Id));
                        break;
                    case IEnumerable<GenericRow> argRowEnumerable:
                        // Each row: its value count, then its values with type-id prefixes.
                        WriteCountPrefixed(
                            destination,
                            argRowEnumerable,
                            (stream, row) =>
                            {
                                SerDe.Write(stream, row.Values.Length);
                                ConvertArgsToBytes(stream, row.Values, true);
                            });
                        break;
                    case var _ when IsDictionary(argType):
                        // Generic dictionary, but we don't have it strongly typed as
                        // Dictionary<T,U>, so read it through the non-generic view.
                        var dictionary = (IDictionary)arg;

                        // Snapshot keys and values in a single pass. Keys and values
                        // are then written in matching order, and every entry is kept
                        // even when the source uses a custom key comparer.
                        var dictKeys = new List<object>(dictionary.Count);
                        var dictValues = new List<object>(dictionary.Count);
                        IDictionaryEnumerator entries = dictionary.GetEnumerator();
                        while (entries.MoveNext())
                        {
                            dictKeys.Add(entries.Key);
                            dictValues.Add(entries.Value);
                        }

                        // Below serialization is corresponding to deserialization method
                        // ReadMap() of SerDe.scala.
                        // dictionary's length
                        SerDe.Write(destination, dictKeys.Count);
                        // keys' data type
                        SerDe.Write(destination, GetTypeId(argType.GetGenericArguments()[0]));
                        // keys' length, same as dictionary's length
                        SerDe.Write(destination, dictKeys.Count);

                        if (convertArgs == null)
                        {
                            convertArgs = new object[1];
                        }

                        foreach (object key in dictKeys)
                        {
                            convertArgs[0] = key;
                            // keys, do not need type prefix.
                            ConvertArgsToBytes(destination, convertArgs, false);
                        }

                        // values' length, same as dictionary's length
                        SerDe.Write(destination, dictValues.Count);
                        foreach (object value in dictValues)
                        {
                            convertArgs[0] = value;
                            // values, need type prefix.
                            ConvertArgsToBytes(destination, convertArgs, true);
                        }
                        break;
                    case IJvmObjectReferenceProvider argProvider:
                        SerDe.Write(destination, argProvider.Reference.Id);
                        break;
                    default:
                        throw new NotSupportedException($"Type {argType} is not supported");
                }
                break;
            default:
                // Without this, unhandled type codes (e.g. float, char, decimal) would
                // write nothing. That leaves any already-written type id without a
                // payload and desynchronizes the stream.
                throw new NotSupportedException($"Type {argType} is not supported");
        }
    }

    // Writes a 4-byte item count followed by each item. The count is only known after
    // enumerating, so a slot is reserved first and back-filled at the end.
    void WriteCountPrefixed<T>(
        MemoryStream stream,
        IEnumerable<T> items,
        Action<MemoryStream, T> writeItem)
    {
        long countPosition = stream.Position;
        stream.Position += sizeof(int);

        int count = 0;
        foreach (T item in items)
        {
            ++count;
            writeItem(stream, item);
        }

        long endPosition = stream.Position;
        stream.Position = countPosition;
        SerDe.Write(stream, count);
        stream.Position = endPosition;
    }
}
/// <summary>
/// Reads and handles one request from the input stream and writes the response
/// to the output stream.
/// </summary>
/// <remarks>
/// Requests flagged CLOSE are not dispatched. Requests flagged CALLBACK are looked up
/// by callback id and, if the END_OF_STREAM marker follows the payload, passed to the
/// registered <see cref="ICallbackHandler"/>. On failure the exception text is sent
/// back, prefixed with DOTNET_EXCEPTION_THROWN, and the exception is rethrown.
/// </remarks>
/// <param name="inputStream">The input stream.</param>
/// <param name="outputStream">The output stream.</param>
/// <param name="readComplete">True if stream is read completely, false otherwise.</param>
/// <returns>The connection status.</returns>
/// <exception cref="InvalidOperationException">
/// An unexpected request flag or an unregistered callback id was received.
/// </exception>
private ConnectionStatus ProcessStream(
    Stream inputStream,
    Stream outputStream,
    out bool readComplete)
{
    readComplete = false;

    try
    {
        byte[] requestFlagBytes = SerDe.ReadBytes(inputStream, sizeof(int));
        // For socket stream, read on the stream returns 0, which
        // SerDe.ReadBytes() returns as null to denote the stream is closed.
        if (requestFlagBytes == null)
        {
            return ConnectionStatus.SOCKET_CLOSED;
        }

        // Check value of the initial request. Expected values are:
        // - CallbackFlags.CLOSE
        // - CallbackFlags.CALLBACK
        int requestFlag = BinaryPrimitives.ReadInt32BigEndian(requestFlagBytes);
        if (requestFlag == (int)CallbackFlags.CLOSE)
        {
            return ConnectionStatus.REQUEST_CLOSE;
        }

        if (requestFlag != (int)CallbackFlags.CALLBACK)
        {
            throw new InvalidOperationException(
                $"Unexpected callback flag received. Expected: {CallbackFlags.CALLBACK}, Received: {requestFlag}.");
        }

        // Use callback id to get the registered handler.
        int callbackId = SerDe.ReadInt32(inputStream);
        if (!_callbackHandlers.TryGetValue(callbackId, out ICallbackHandler callbackHandler))
        {
            throw new InvalidOperationException($"Unregistered callback id: {callbackId}");
        }

        s_logger.LogInfo(
            $"[{ConnectionId}] Received request for callback id: {callbackId}, callback handler: {callbackHandler}");

        // Save contents of callback handler data to be used later.
        // NOTE(review): if SerDe.ReadBytes returns null here (it uses null to signal a
        // closed stream), the MemoryStream constructor throws ArgumentNullException.
        // Confirm whether that case deserves a clearer error.
        using var callbackDataStream =
            new MemoryStream(SerDe.ReadBytes(inputStream, SerDe.ReadInt32(inputStream)));

        // Check the end of stream.
        int endOfStream = SerDe.ReadInt32(inputStream);
        if (endOfStream == (int)CallbackFlags.END_OF_STREAM)
        {
            s_logger.LogDebug($"[{ConnectionId}] Received END_OF_STREAM signal.");

            // Run callback handler.
            callbackHandler.Run(callbackDataStream);

            SerDe.Write(outputStream, (int)CallbackFlags.END_OF_STREAM);
            readComplete = true;
        }
        else
        {
            // This may happen when the input data is not read completely.
            s_logger.LogWarn(
                $"[{ConnectionId}] Unexpected end of stream: {endOfStream}.");
            // Write flag to indicate the connection should be closed.
            SerDe.Write(outputStream, (int)CallbackFlags.CLOSE);
        }

        return ConnectionStatus.OK;
    }
    catch (Exception e)
    {
        s_logger.LogError($"[{ConnectionId}] ProcessStream() failed with exception: {e}");

        // Best-effort attempt to report the failure to the peer before rethrowing.
        try
        {
            SerDe.Write(outputStream, (int)CallbackFlags.DOTNET_EXCEPTION_THROWN);
            SerDe.Write(outputStream, e.ToString());
        }
        catch (IOException)
        {
            // JVM closed the socket.
        }
        catch (Exception ex)
        {
            s_logger.LogError(
                $"[{ConnectionId}] Writing exception to stream failed with exception: {ex}");
        }

        throw;
    }
}