예제 #1
0
        /// <summary>
        /// Appends a length-prefixed method-invocation payload to <paramref name="destination"/>,
        /// starting at the stream's current position.
        /// </summary>
        /// <remarks>
        /// Layout written: a 32-bit length covering every byte that follows it, then the static
        /// flag, <paramref name="classNameOrJvmObjectReference"/>.ToString(), the method name,
        /// the argument count and the arguments as serialized by ConvertArgsToBytes (with type ids).
        /// </remarks>
        /// <param name="destination">Stream the payload is appended to; left positioned after it.</param>
        /// <param name="isStaticMethod">Whether the call targets a static method.</param>
        /// <param name="classNameOrJvmObjectReference">Class name or object reference; written via ToString().</param>
        /// <param name="methodName">Name of the method to invoke.</param>
        /// <param name="args">Arguments to serialize.</param>
        internal static void BuildPayload(
            MemoryStream destination,
            bool isStaticMethod,
            object classNameOrJvmObjectReference,
            string methodName,
            object[] args)
        {
            // Reserve space for the length prefix; it is back-filled once the body size is known.
            long lengthPosition = destination.Position;
            destination.Position += sizeof(int);

            SerDe.Write(destination, isStaticMethod);
            SerDe.Write(destination, classNameOrJvmObjectReference.ToString());
            SerDe.Write(destination, methodName);
            SerDe.Write(destination, args.Length);
            ConvertArgsToBytes(destination, args);

            long endPosition = destination.Position;

            // The prefix counts only the bytes after itself. Measure from where this payload
            // started rather than from offset 0 so the value stays correct when the stream
            // already contained data before this call.
            int payloadLength = (int)(endPosition - lengthPosition - sizeof(int));

            destination.Position = lengthPosition;
            SerDe.Write(destination, payloadLength);
            destination.Position = endPosition;
        }
예제 #2
0
        /// <summary>
        /// Writes every element of <paramref name="enumlist"/> to <paramref name="stream"/> using
        /// <paramref name="writefunc"/>, preceded by a 32-bit element count. The count slot is
        /// skipped first and back-filled once enumeration has finished, so the sequence is
        /// enumerated only once.
        /// </summary>
        /// <param name="stream">Destination stream; left positioned after the last element.</param>
        /// <param name="enumlist">Elements to write.</param>
        /// <param name="writefunc">Serializer invoked once per element.</param>
        internal static void WriteIEnumerableObjects <T>(MemoryStream stream, IEnumerable <T> enumlist, Action <MemoryStream, T> writefunc)
        {
            long countOffset = stream.Position;
            stream.Position = countOffset + sizeof(int);

            int written = 0;
            using (IEnumerator<T> cursor = enumlist.GetEnumerator())
            {
                while (cursor.MoveNext())
                {
                    written++;
                    writefunc(stream, cursor.Current);
                }
            }

            // Jump back, record how many elements were emitted, then return to the end.
            long endOffset = stream.Position;
            stream.Position = countOffset;
            SerDe.Write(stream, written);
            stream.Position = endOffset;
        }
예제 #3
0
        /// <summary>
        /// Serializes each element of <paramref name="args"/> into <paramref name="destination"/>,
        /// in order, starting at the stream's current position.
        /// </summary>
        /// <param name="destination">Stream the serialized bytes are written to.</param>
        /// <param name="args">Values to serialize. A null element is written as the single byte 'n'.</param>
        /// <param name="addTypeIdPrefix">
        /// When true, every non-null value is preceded by GetTypeId(value's runtime type).
        /// Dictionary keys are serialized with this set to false.
        /// </param>
        /// <exception cref="NotSupportedException">
        /// Thrown for a TypeCode.Object value that matches none of the supported shapes.
        /// </exception>
        internal static void ConvertArgsToBytes(
            MemoryStream destination,
            object[] args,
            bool addTypeIdPrefix = true)
        {
            // One-element buffer reused when recursing into dictionary keys and values.
            object[] convertArgs = null;

            foreach (object arg in args)
            {
                if (arg == null)
                {
                    destination.WriteByte((byte)'n');
                    continue;
                }

                Type argType = arg.GetType();

                if (addTypeIdPrefix)
                {
                    SerDe.Write(destination, GetTypeId(argType));
                }

                switch (Type.GetTypeCode(argType))
                {
                case TypeCode.Int32:
                    SerDe.Write(destination, (int)arg);
                    break;

                case TypeCode.Int64:
                    SerDe.Write(destination, (long)arg);
                    break;

                case TypeCode.String:
                    SerDe.Write(destination, (string)arg);
                    break;

                case TypeCode.Boolean:
                    SerDe.Write(destination, (bool)arg);
                    break;

                case TypeCode.Double:
                    SerDe.Write(destination, (double)arg);
                    break;

                case TypeCode.Object:
                    // Case order matters: the first matching pattern wins.
                    switch (arg)
                    {
                    case byte[] argByteArray:
                        SerDe.Write(destination, argByteArray.Length);
                        SerDe.Write(destination, argByteArray);
                        break;

                    case int[] argInt32Array:
                        SerDe.Write(destination, s_int32TypeId);
                        SerDe.Write(destination, argInt32Array.Length);
                        foreach (int i in argInt32Array)
                        {
                            SerDe.Write(destination, i);
                        }
                        break;

                    case long[] argInt64Array:
                        SerDe.Write(destination, s_int64TypeId);
                        SerDe.Write(destination, argInt64Array.Length);
                        foreach (long i in argInt64Array)
                        {
                            SerDe.Write(destination, i);
                        }
                        break;

                    case double[] argDoubleArray:
                        SerDe.Write(destination, s_doubleTypeId);
                        SerDe.Write(destination, argDoubleArray.Length);
                        foreach (double d in argDoubleArray)
                        {
                            SerDe.Write(destination, d);
                        }
                        break;

                    case double[][] argDoubleArrayArray:
                        SerDe.Write(destination, s_doubleArrayArrayTypeId);
                        SerDe.Write(destination, argDoubleArrayArray.Length);
                        foreach (double[] doubleArray in argDoubleArrayArray)
                        {
                            SerDe.Write(destination, doubleArray.Length);
                            foreach (double d in doubleArray)
                            {
                                SerDe.Write(destination, d);
                            }
                        }
                        break;

                    case IEnumerable <byte[]> argByteArrayEnumerable:
                        SerDe.Write(destination, s_byteArrayTypeId);
                        WriteCountPrefixed(argByteArrayEnumerable, bytes =>
                        {
                            SerDe.Write(destination, bytes.Length);
                            destination.Write(bytes, 0, bytes.Length);
                        });
                        break;

                    case IEnumerable <string> argStringEnumerable:
                        SerDe.Write(destination, s_stringTypeId);
                        WriteCountPrefixed(argStringEnumerable, text => SerDe.Write(destination, text));
                        break;

                    case IEnumerable <IJvmObjectReferenceProvider> argJvmEnumerable:
                        SerDe.Write(destination, s_jvmObjectTypeId);
                        WriteCountPrefixed(
                            argJvmEnumerable,
                            jvmObject => SerDe.Write(destination, jvmObject.Reference.Id));
                        break;

                    case IEnumerable <GenericRow> argRowEnumerable:
                        // Unlike the enumerable cases above, no element type id is written here.
                        // Each row is its value count followed by its values, each with a type prefix.
                        WriteCountPrefixed(argRowEnumerable, row =>
                        {
                            SerDe.Write(destination, (int)row.Values.Length);
                            ConvertArgsToBytes(destination, row.Values, true);
                        });
                        break;

                    case var _ when IsDictionary(argType):
                        // Generic dictionary, but we don't have it strongly typed as
                        // Dictionary<T,U>. Copy it once so the count and both key and value
                        // passes below all see the same entries in the same order.
                        var dictInterface = (IDictionary)arg;

                        var dict = new Dictionary <object, object>(dictInterface.Count);
                        IDictionaryEnumerator iter = dictInterface.GetEnumerator();
                        while (iter.MoveNext())
                        {
                            dict[iter.Key] = iter.Value;
                        }

                        // Below serialization is corresponding to deserialization method
                        // ReadMap() of SerDe.scala.

                        // dictionary's length
                        SerDe.Write(destination, dict.Count);

                        // keys' data type.
                        // NOTE(review): assumes IsDictionary only accepts types whose first
                        // generic argument is the key type; a non-generic subclass would make
                        // GetGenericArguments() empty and throw here — confirm.
                        SerDe.Write(
                            destination,
                            GetTypeId(argType.GetGenericArguments()[0]));
                        // keys' length, same as dictionary's length
                        SerDe.Write(destination, dict.Count);
                        if (convertArgs == null)
                        {
                            convertArgs = new object[1];
                        }
                        foreach (KeyValuePair <object, object> kv in dict)
                        {
                            convertArgs[0] = kv.Key;
                            // keys, do not need type prefix.
                            ConvertArgsToBytes(destination, convertArgs, false);
                        }

                        // values' length, same as dictionary's length
                        SerDe.Write(destination, dict.Count);
                        foreach (KeyValuePair <object, object> kv in dict)
                        {
                            convertArgs[0] = kv.Value;
                            // values, need type prefix.
                            ConvertArgsToBytes(destination, convertArgs, true);
                        }
                        break;

                    case IJvmObjectReferenceProvider argProvider:
                        SerDe.Write(destination, argProvider.Reference.Id);
                        break;

                    default:
                        throw new NotSupportedException($"Type {argType} is not supported");
                    }
                    break;
                }
            }

            // Writes a 32-bit item count followed by each item (via writeItem). The count slot is
            // reserved first and back-filled afterwards, so the sequence is enumerated only once.
            void WriteCountPrefixed<TItem>(IEnumerable<TItem> items, Action<TItem> writeItem)
            {
                long countPosition = destination.Position;
                destination.Position += sizeof(int);

                int itemCount = 0;
                foreach (TItem item in items)
                {
                    ++itemCount;
                    writeItem(item);
                }

                long endPosition = destination.Position;
                destination.Position = countPosition;
                SerDe.Write(destination, itemCount);
                destination.Position = endPosition;
            }
        }
예제 #4
0
        /// <summary>
        /// Handles one request read from <paramref name="inputStream"/>. The request is either a
        /// close request or a callback invocation whose completion flag is written to
        /// <paramref name="outputStream"/>.
        /// </summary>
        /// <param name="inputStream">Stream the request is read from.</param>
        /// <param name="outputStream">Stream that status flags, and exception text on failure, are written to.</param>
        /// <param name="readComplete">
        /// True only when the payload was followed by END_OF_STREAM, the handler ran, and the
        /// END_OF_STREAM acknowledgement was written; false otherwise.
        /// </param>
        /// <returns>
        /// SOCKET_CLOSED when no request flag could be read, REQUEST_CLOSE on a CLOSE flag, and
        /// OK otherwise. Exceptions are reported to <paramref name="outputStream"/> and rethrown.
        /// </returns>
        private ConnectionStatus ProcessStream(
            Stream inputStream,
            Stream outputStream,
            out bool readComplete)
        {
            readComplete = false;

            try
            {
                // A socket read returning 0 surfaces from SerDe.ReadBytes() as null: the peer closed.
                byte[] header = SerDe.ReadBytes(inputStream, sizeof(int));
                if (header == null)
                {
                    return ConnectionStatus.SOCKET_CLOSED;
                }

                // The leading flag must be either CLOSE or CALLBACK.
                int flag = BinaryPrimitives.ReadInt32BigEndian(header);
                if (flag == (int)CallbackFlags.CLOSE)
                {
                    return ConnectionStatus.REQUEST_CLOSE;
                }

                if (flag != (int)CallbackFlags.CALLBACK)
                {
                    throw new Exception(
                        $"Unexpected callback flag received. Expected: {CallbackFlags.CALLBACK}, Received: {flag}.");
                }

                // The callback id that follows selects the registered handler.
                int id = SerDe.ReadInt32(inputStream);
                if (!_callbackHandlers.TryGetValue(id, out ICallbackHandler handler))
                {
                    throw new Exception($"Unregistered callback id: {id}");
                }

                s_logger.LogInfo(
                    $"[{ConnectionId}] Received request for callback id: {id}, callback handler: {handler}");

                // Buffer the length-prefixed handler data; it is only used once the trailer checks out.
                int payloadLength = SerDe.ReadInt32(inputStream);
                using var payload = new MemoryStream(SerDe.ReadBytes(inputStream, payloadLength));

                int trailer = SerDe.ReadInt32(inputStream);
                if (trailer != (int)CallbackFlags.END_OF_STREAM)
                {
                    // This may happen when the input data is not read completely.
                    s_logger.LogWarn($"[{ConnectionId}] Unexpected end of stream: {trailer}.");

                    // Tell the peer the connection should be closed.
                    SerDe.Write(outputStream, (int)CallbackFlags.CLOSE);
                    return ConnectionStatus.OK;
                }

                s_logger.LogDebug($"[{ConnectionId}] Received END_OF_STREAM signal.");

                handler.Run(payload);

                SerDe.Write(outputStream, (int)CallbackFlags.END_OF_STREAM);
                readComplete = true;
                return ConnectionStatus.OK;
            }
            catch (Exception failure)
            {
                s_logger.LogError($"[{ConnectionId}] ProcessStream() failed with exception: {failure}");
                ReportFailureToPeer(failure);
                throw;
            }

            // Best-effort attempt to forward the exception to the peer; never throws.
            void ReportFailureToPeer(Exception failure)
            {
                try
                {
                    SerDe.Write(outputStream, (int)CallbackFlags.DOTNET_EXCEPTION_THROWN);
                    SerDe.Write(outputStream, failure.ToString());
                }
                catch (IOException)
                {
                    // JVM closed the socket.
                }
                catch (Exception writeFailure)
                {
                    s_logger.LogError(
                        $"[{ConnectionId}] Writing exception to stream failed with exception: {writeFailure}");
                }
            }
        }