示例#1
0
        /// <summary>
        /// Reads batches of pickled rows from <paramref name="inputStream"/>, runs the
        /// command runner created from <paramref name="commands"/> on every row, and writes
        /// the results of each batch to <paramref name="outputStream"/>.
        /// </summary>
        /// <param name="inputStream">Stream the length-prefixed row batches are read from</param>
        /// <param name="outputStream">Stream the output rows of each batch are written to</param>
        /// <param name="commands">Commands used to create the command runner</param>
        /// <returns>Statistics whose NumEntriesProcessed is the number of input rows handled</returns>
        protected override CommandExecutorStat ExecuteCore(
            Stream inputStream,
            Stream outputStream,
            SqlCommand[] commands)
        {
            var stat = new CommandExecutorStat();
            ICommandRunner commandRunner = CreateCommandRunner(commands);

            // Each object in this list is considered a row on the Spark side. The list is
            // reused across batches and emptied after every write.
            var outputRows = new List<object>();

            while (true)
            {
                // SpecialLengths.END_OF_DATA_SECTION is sent as the length when the input has
                // no rows at all or when every row has already been read.
                int messageLength = SerDe.ReadInt32(inputStream);
                if (messageLength == (int)SpecialLengths.END_OF_DATA_SECTION)
                {
                    break;
                }

                // Non-positive lengths other than SpecialLengths.NULL are ignored.
                bool isNullMarker = messageLength == (int)SpecialLengths.NULL;
                if (messageLength <= 0 && !isNullMarker)
                {
                    continue;
                }

                // A non-positive length that got this far cannot carry a batch of rows.
                if (messageLength <= 0)
                {
                    throw new InvalidDataException(
                              $"Invalid message length: {messageLength}");
                }

                // Every entry of the batch is one input row; a null column value shows up as
                // a null entry at the matching index of that row.
                object[] batch = PythonSerDe.GetUnpickledObjects(inputStream, messageLength);

                foreach (object inputRow in batch)
                {
                    // SQL UDFs do not use the split id, hence the constant 0.
                    outputRows.Add(commandRunner.Run(0, inputRow));
                }

                // The pickled input size doubles as the initial (estimated) output buffer
                // size because input and output contain the same number of rows.
                WriteOutput(outputStream, outputRows, messageLength);
                stat.NumEntriesProcessed += batch.Length;
                outputRows.Clear();
            }

            return stat;
        }
示例#2
0
 /// <summary>
 /// Unpickles <paramref name="length"/> bytes from <paramref name="stream"/> and converts
 /// every unpickled RowConstructor into its row.
 /// </summary>
 /// <param name="stream">Stream holding the pickled rows</param>
 /// <param name="length">Number of pickled bytes passed on to the unpickler</param>
 /// <returns>Array with one row per unpickled object</returns>
 public object Deserialize(Stream stream, int length)
 {
     // See the AutoBatchedPickler class in spark/core/src/main/scala/org/apache/
     // spark/api/python/SerDeUtil.scala for how the Rows may be batched.
     var unpickled = PythonSerDe.GetUnpickledObjects(stream, length);

     // The explicit range type makes the query cast each element to RowConstructor.
     var rows = from RowConstructor constructor in unpickled
                select constructor.GetRow();

     return rows.ToArray();
 }
示例#3
0
        /// <summary>
        /// Collects pickled row objects from the given socket.
        /// </summary>
        /// <param name="socket">Socket to get the stream from</param>
        /// <returns>Collection of row objects, produced lazily while enumerating</returns>
        /// <exception cref="InvalidDataException">
        /// Thrown during enumeration when an unpickled object is not a RowConstructor.
        /// </exception>
        public IEnumerable <Row> Collect(ISocketWrapper socket)
        {
            Stream inputStream = socket.InputStream;

            int? length;

            // Keep reading batches until no length is available or a non-positive length arrives.
            while (((length = SerDe.ReadBytesLength(inputStream)) != null) && (length.GetValueOrDefault() > 0))
            {
                object[] unpickledObjects = PythonSerDe.GetUnpickledObjects(inputStream, length.GetValueOrDefault());

                foreach (object unpickled in unpickledObjects)
                {
                    if (unpickled is RowConstructor rowConstructor)
                    {
                        yield return rowConstructor.GetRow();
                    }
                    else
                    {
                        // An unchecked "as" cast would fail here with a bare
                        // NullReferenceException; report what was actually received instead.
                        string actualType = unpickled?.GetType().FullName ?? "null";
                        throw new InvalidDataException(
                                  $"Expected an unpickled RowConstructor but got: {actualType}");
                    }
                }
            }
        }
示例#4
0
        /// <summary>
        /// Pickles two rows sharing one schema and checks that unpickling them yields
        /// RowConstructor objects whose rows equal the originals.
        /// </summary>
        public void RowConstructorTest()
        {
            Pickler pickler = CreatePickler();

            var schema = (StructType)DataType.ParseDataType(_testJsonSchema);
            var expectedRows = new[]
            {
                new Row(new object[] { 10, "name1" }, schema),
                new Row(new object[] { 15, "name2" }, schema)
            };
            var pickledBytes = pickler.dumps(expectedRows);

            // Unpickling the bytes below is what invokes RowConstructor.construct().
            var unpickledData = PythonSerDe.GetUnpickledObjects(new MemoryStream(pickledBytes));

            Assert.Equal(expectedRows.Length, unpickledData.Length);
            for (int index = 0; index < expectedRows.Length; ++index)
            {
                Assert.Equal(
                    expectedRows[index],
                    (unpickledData[index] as RowConstructor).GetRow());
            }
        }
示例#5
0
        /// <summary>
        /// Connects to the loopback port given by <paramref name="info"/> and lazily yields
        /// the items read from that connection, decoded according to
        /// <paramref name="serializedMode"/>.
        /// </summary>
        /// <param name="info">Supplies the port to connect to and an optional secret to send first</param>
        /// <param name="serializedMode">How each received buffer is decoded</param>
        /// <param name="type">Type whose first constructor builds the result in Pair mode</param>
        /// <returns>Decoded items; Row mode yields one item per row of every batch</returns>
        public IEnumerable <dynamic> Collect(SocketInfo info, SerializedMode serializedMode, Type type)
        {
            // NOTE(review): BinaryFormatter deserializes whatever arrives on the socket -
            // confirm the peer is trusted.
            IFormatter formatter = new BinaryFormatter();
            var sock = SocketFactory.CreateSocket();

            sock.Connect(IPAddress.Loopback, info.Port, null);

            using (var s = sock.GetStream())
            {
                // When a secret is configured it is sent first and the reply is only logged.
                if (info.Secret != null)
                {
                    SerDe.Write(s, info.Secret);
                    var reply = SerDe.ReadString(s);
                    Logger.LogDebug("Connect back to JVM: " + reply);
                }

                // Reading ends with the first null or empty buffer.
                for (byte[] buffer = SerDe.ReadBytes(s);
                     buffer != null && buffer.Length > 0;
                     buffer = SerDe.ReadBytes(s))
                {
                    switch (serializedMode)
                    {
                        case SerializedMode.Byte:
                        {
                            yield return formatter.Deserialize(new MemoryStream(buffer));
                            break;
                        }

                        case SerializedMode.String:
                        {
                            yield return Encoding.UTF8.GetString(buffer);
                            break;
                        }

                        case SerializedMode.Pair:
                        {
                            // The second half of the pair is the next buffer on the stream.
                            var firstStream = new MemoryStream(buffer);
                            var secondStream = new MemoryStream(SerDe.ReadBytes(s));

                            ConstructorInfo pairConstructor = type.GetConstructors()[0];
                            object[] pairArguments =
                            {
                                formatter.Deserialize(firstStream),
                                formatter.Deserialize(secondStream)
                            };
                            yield return pairConstructor.Invoke(pairArguments);
                            break;
                        }

                        case SerializedMode.Row:
                        {
                            foreach (var item in PythonSerDe.GetUnpickledObjects(buffer))
                            {
                                yield return (item as RowConstructor).GetRow();
                            }
                            break;
                        }
                    }
                }
            }
        }
示例#6
0
        /// <summary>
        /// Writes several single-row pickled batches followed by an END_OF_STREAM marker
        /// into a memory stream, then reads them back and checks the "age" of each row.
        /// </summary>
        public void TestSerDeWithPythonSerDe()
        {
            const int expectedCount = 5;

            using (var ms = new MemoryStream())
            {
                new StructTypePickler().Register();
                new RowPickler().Register();
                var pickler = new Pickler();

                // Every batch is written as its byte length followed by the pickled bytes.
                for (int age = 0; age < expectedCount; age++)
                {
                    var pickleBytes = pickler.dumps(new[] { RowHelper.BuildRowForBasicSchema(age) });
                    SerDe.Write(ms, pickleBytes.Length);
                    SerDe.Write(ms, pickleBytes);
                }

                SerDe.Write(ms, (int)SpecialLengths.END_OF_STREAM);
                ms.Flush();

                // Rewind so the batches can be read back from the start.
                ms.Position = 0;
                int count = 0;
                while (true)
                {
                    int length = SerDe.ReadInt(ms);
                    if (length <= 0 && length == (int)SpecialLengths.END_OF_STREAM)
                    {
                        break;
                    }

                    // Only a positive length is followed by a payload.
                    byte[] outBuffer = length > 0 ? SerDe.ReadBytes(ms, length) : null;

                    var unpickledObjs = PythonSerDe.GetUnpickledObjects(outBuffer);
                    var rows = unpickledObjs
                               .Select(item => (item as RowConstructor).GetRow())
                               .ToArray();
                    Assert.AreEqual(1, rows.Length);
                    Assert.AreEqual(count, rows[0].Get("age"));
                    count++;
                }

                Assert.AreEqual(expectedCount, count);
            }
        }
示例#7
0
        /// <summary>
        /// Connects to the given loopback port and lazily yields the items read from that
        /// connection, decoded according to <paramref name="serializedMode"/>.
        /// </summary>
        /// <param name="port">Loopback TCP port to connect to</param>
        /// <param name="serializedMode">How each received buffer is decoded</param>
        /// <param name="type">Type whose first constructor builds the result in Pair mode</param>
        /// <returns>Decoded items; Row mode yields one item per row of every batch</returns>
        /// <exception cref="InvalidCastException">
        /// Thrown during enumeration in Row mode when an unpickled object is not a RowConstructor.
        /// </exception>
        public IEnumerable <dynamic> Collect(int port, SerializedMode serializedMode, Type type)
        {
            // NOTE(review): BinaryFormatter is unsafe on untrusted input; here it deserializes
            // whatever arrives on the loopback socket - confirm the peer is trusted.
            IFormatter formatter = new BinaryFormatter();

            // NetworkStream(Socket) does not take ownership of the socket, so the socket needs
            // its own using block; otherwise it leaks once enumeration ends or Connect throws.
            using (var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
            {
                sock.Connect(IPAddress.Loopback, port);

                using (NetworkStream s = new NetworkStream(sock))
                {
                    byte[] buffer;

                    // Reading ends with the first null or empty buffer.
                    while ((buffer = SerDe.ReadBytes(s)) != null && buffer.Length > 0)
                    {
                        if (serializedMode == SerializedMode.Byte)
                        {
                            MemoryStream ms = new MemoryStream(buffer);
                            yield return formatter.Deserialize(ms);
                        }
                        else if (serializedMode == SerializedMode.String)
                        {
                            yield return Encoding.UTF8.GetString(buffer);
                        }
                        else if (serializedMode == SerializedMode.Pair)
                        {
                            // The second half of the pair is the next buffer on the stream.
                            MemoryStream ms  = new MemoryStream(buffer);
                            MemoryStream ms2 = new MemoryStream(SerDe.ReadBytes(s));

                            ConstructorInfo ci = type.GetConstructors()[0];
                            yield return ci.Invoke(new object[] { formatter.Deserialize(ms), formatter.Deserialize(ms2) });
                        }
                        else if (serializedMode == SerializedMode.Row)
                        {
                            var unpickledObjects = PythonSerDe.GetUnpickledObjects(buffer);
                            foreach (var item in unpickledObjects)
                            {
                                // A direct cast names the offending type on a mismatch instead
                                // of failing later with a NullReferenceException.
                                yield return ((RowConstructor)item).GetRow();
                            }
                        }
                    }
                }
            }
        }
示例#8
0
        /// <summary>
        /// Reads length-prefixed batches of pickled rows from the socket's input stream and
        /// yields the rows they contain.
        /// </summary>
        /// <param name="socket">Socket to get the stream from</param>
        /// <returns>Lazily produced sequence of row objects</returns>
        public IEnumerable <Row> Collect(ISocketWrapper socket)
        {
            Stream inputStream = socket.InputStream;

            while (true)
            {
                // A missing or non-positive length marks the end of the data.
                int? length = SerDe.ReadBytesLength(inputStream);
                if ((length == null) || (length.GetValueOrDefault() <= 0))
                {
                    yield break;
                }

                object[] unpickledObjects =
                    PythonSerDe.GetUnpickledObjects(inputStream, length.GetValueOrDefault());

                foreach (object unpickled in unpickledObjects)
                {
                    // An unpickled object is either a RowConstructor (row not materialized yet)
                    // or a single-element array holding an already materialized Row. See
                    // RowConstructor.construct() for how Row objects are unpickled.
                    if (unpickled is RowConstructor rc)
                    {
                        yield return rc.GetRow();
                    }
                    else if (unpickled is object[] objs && objs.Length == 1 && objs[0] is Row row)
                    {
                        yield return row;
                    }
                    else
                    {
                        throw new NotSupportedException(
                                  string.Format("Unpickle type {0} is not supported",
                                                unpickled.GetType()));
                    }
                }
            }
        }
示例#9
0
        /// <summary>
        /// Reads the next message from inputStream and decodes it according to the
        /// SerializedMode named by deserializedMode.
        /// </summary>
        /// <param name="messageLength">Length of the message payload that was read from the stream</param>
        /// <returns>
        /// The decoded items: every row of the batch in Row mode, otherwise a single-element
        /// array (holding null when there is no payload, except in Pair mode where the element
        /// is always a key/value pair).
        /// </returns>
        private object[] GetNext(int messageLength)
        {
            var mode = (SerializedMode)Enum.Parse(typeof(SerializedMode), deserializedMode);

            switch (mode)
            {
            case SerializedMode.String:
            {
                object text = null;
                if (messageLength > 0)
                {
                    byte[] raw = SerDe.ReadBytes(inputStream, messageLength);
                    text = SerDe.ToString(raw);
                }

                return new object[] { text };
            }

            case SerializedMode.Row:
            {
                Debug.Assert(messageLength > 0);
                byte[] pickled = SerDe.ReadBytes(inputStream, messageLength);

                // One result entry per unpickled row.
                return PythonSerDe.GetUnpickledObjects(pickled)
                       .Select(item => (object)(item as RowConstructor).GetRow())
                       .ToArray();
            }

            case SerializedMode.Pair:
            {
                byte[] pairKey = null;
                if (messageLength > 0)
                {
                    pairKey = SerDe.ReadBytes(inputStream, messageLength);
                }

                // The value follows the key on the stream with its own length prefix.
                int valueLength = SerDe.ReadInt(inputStream);
                if (valueLength <= 0 && valueLength != (int)SpecialLengths.NULL)
                {
                    throw new Exception(string.Format("unexpected valueLength: {0}", valueLength));
                }

                byte[] pairValue = valueLength > 0 ? SerDe.ReadBytes(inputStream, valueLength) : null;

                return new object[] { new KeyValuePair<byte[], byte[]>(pairKey, pairValue) };
            }

            case SerializedMode.None:
            {
                // The raw bytes are handed back without any decoding.
                object payload = null;
                if (messageLength > 0)
                {
                    payload = SerDe.ReadBytes(inputStream, messageLength);
                }

                return new object[] { payload };
            }

            case SerializedMode.Byte:
            default:
            {
                object deserialized = null;
                if (messageLength > 0)
                {
                    byte[] raw = SerDe.ReadBytes(inputStream, messageLength);
                    deserialized = formatter.Deserialize(new MemoryStream(raw));
                }

                return new object[] { deserialized };
            }
            }
        }
示例#10
0
        /// <summary>
        /// Lazily reads length-prefixed messages from <paramref name="inputStream"/> until
        /// SpecialLengths.END_OF_DATA_SECTION is read, yielding the items decoded according to
        /// <paramref name="serializedMode"/>. A stopwatch accumulates the time spent reading
        /// from the stream; the total is logged once the end of the data section is reached.
        /// </summary>
        /// <param name="inputStream">Stream the messages are read from</param>
        /// <param name="serializedMode">Name of a SerializedMode member, parsed with Enum.Parse</param>
        /// <param name="isFuncSqlUdf">
        /// Only used in Row mode: 0 yields the rows obtained through RowConstructor.GetRow(),
        /// any other value yields the unpickled objects unchanged.
        /// </param>
        /// <returns>Lazily produced sequence of decoded items</returns>
        private static IEnumerable <dynamic> GetIterator(Stream inputStream, string serializedMode, int isFuncSqlUdf)
        {
            logger.LogInfo("Serialized mode in GetIterator: " + serializedMode);
            IFormatter formatter = new BinaryFormatter();
            var        mode      = (SerializedMode)Enum.Parse(typeof(SerializedMode), serializedMode);
            int        messageLength;
            // The stopwatch runs only around reads from inputStream; it is stopped before
            // every yield and restarted at the end of each loop iteration.
            Stopwatch  watch = Stopwatch.StartNew();

            while ((messageLength = SerDe.ReadInt(inputStream)) != (int)SpecialLengths.END_OF_DATA_SECTION)
            {
                watch.Stop();
                // Lengths that are neither positive nor SpecialLengths.NULL are skipped.
                if (messageLength > 0 || messageLength == (int)SpecialLengths.NULL)
                {
                    watch.Start();
                    // A NULL length carries no payload, so the buffer stays null.
                    byte[] buffer = messageLength > 0 ? SerDe.ReadBytes(inputStream, messageLength) : null;
                    watch.Stop();
                    switch (mode)
                    {
                    case SerializedMode.String:
                    {
                        if (messageLength > 0)
                        {
                            if (buffer == null)
                            {
                                logger.LogDebug("Buffer is null. Message length is {0}", messageLength);
                            }
                            yield return(SerDe.ToString(buffer));
                        }
                        else
                        {
                            yield return(null);
                        }
                        break;
                    }

                    case SerializedMode.Row:
                    {
                        Debug.Assert(messageLength > 0);
                        var unpickledObjects = PythonSerDe.GetUnpickledObjects(buffer);

                        if (isFuncSqlUdf == 0)
                        {
                            // Convert every unpickled RowConstructor into its row.
                            foreach (var row in unpickledObjects.Select(item => (item as RowConstructor).GetRow()))
                            {
                                yield return(row);
                            }
                        }
                        else
                        {
                            // Hand the unpickled objects on without conversion.
                            foreach (var row in unpickledObjects)
                            {
                                yield return(row);
                            }
                        }

                        break;
                    }

                    case SerializedMode.Pair:
                    {
                        // The buffer read above is the key; the value follows on the stream
                        // with its own length prefix.
                        byte[] pairKey   = buffer;
                        byte[] pairValue = null;

                        watch.Start();
                        int valueLength = SerDe.ReadInt(inputStream);
                        if (valueLength > 0)
                        {
                            pairValue = SerDe.ReadBytes(inputStream, valueLength);
                        }
                        else if (valueLength == (int)SpecialLengths.NULL)
                        {
                            pairValue = null;
                        }
                        else
                        {
                            throw new Exception(string.Format("unexpected valueLength: {0}", valueLength));
                        }
                        watch.Stop();

                        yield return(new KeyValuePair <byte[], byte[]>(pairKey, pairValue));

                        break;
                    }

                    case SerializedMode.None:     //just return raw bytes
                    {
                        yield return(buffer);

                        break;
                    }

                    // Byte mode and any other mode value are deserialized with the formatter.
                    case SerializedMode.Byte:
                    default:
                    {
                        if (buffer != null)
                        {
                            var ms = new MemoryStream(buffer);
                            yield return(formatter.Deserialize(ms));
                        }
                        else
                        {
                            yield return(null);
                        }
                        break;
                    }
                    }
                }
                // Resume timing so the next length read in the loop condition is measured.
                watch.Start();
            }

            logger.LogInfo(string.Format("total receive time: {0}", watch.ElapsedMilliseconds));
        }