Code Example #1
File: SqlCommandExecutor.cs  Project: sullaspqr/spark
        private CommandExecutorStat ExecuteArrowGroupedMapCommand(
            Stream inputStream,
            Stream outputStream,
            SqlCommand[] commands)
        {
            Debug.Assert(commands.Length == 1,
                         "Grouped Map UDFs do not support combining multiple UDFs.");

            var stat   = new CommandExecutorStat();
            var worker = (ArrowGroupedMapWorkerFunction)commands[0].WorkerFunction;

            SerDe.Write(outputStream, (int)SpecialLengths.START_ARROW_STREAM);

            IpcOptions        ipcOptions = ArrowIpcOptions();
            ArrowStreamWriter writer     = null;

            foreach (RecordBatch input in GetInputIterator(inputStream))
            {
                RecordBatch result = worker.Func(input);

                int numEntries = result.Length;
                stat.NumEntriesProcessed += numEntries;

                if (writer == null)
                {
                    writer =
                        new ArrowStreamWriter(outputStream, result.Schema, leaveOpen: true, ipcOptions);
                }

                // TODO: Remove sync-over-async once WriteRecordBatch exists.
                writer.WriteRecordBatchAsync(result).GetAwaiter().GetResult();
            }

            WriteEnd(outputStream, ipcOptions);
            writer?.Dispose();

            return stat;
        }
Code Example #2
        private CommandExecutorStat ExecuteArrowGroupedMapCommand(
            Stream inputStream,
            Stream outputStream,
            SqlCommand[] commands)
        {
            Debug.Assert(commands.Length == 1,
                         "Grouped Map UDFs do not support combining multiple UDFs.");

            var stat   = new CommandExecutorStat();
            var worker = (ArrowGroupedMapWorkerFunction)commands[0].WorkerFunction;

            SerDe.Write(outputStream, (int)SpecialLengths.START_ARROW_STREAM);

            IpcOptions        ipcOptions = ArrowIpcOptions();
            ArrowStreamWriter writer     = null;

            foreach (RecordBatch input in GetInputIterator(inputStream))
            {
                RecordBatch batch = worker.Func(input);

                RecordBatch final      = WrapColumnsInStructIfApplicable(batch);
                int         numEntries = final.Length;
                stat.NumEntriesProcessed += numEntries;

                if (writer == null)
                {
                    writer =
                        new ArrowStreamWriter(outputStream, final.Schema, leaveOpen: true, ipcOptions);
                }

                writer.WriteRecordBatch(final);
            }

            WriteEnd(outputStream, ipcOptions);
            writer?.Dispose();

            return stat;
        }
Code Example #3
        private static async Task TestReaderFromMemory(
            Func<ArrowStreamReader, RecordBatch, Task> verificationFunc,
            bool writeEnd)
        {
            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100);

            byte[] buffer;
            using (MemoryStream stream = new MemoryStream())
            {
                ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema);
                await writer.WriteRecordBatchAsync(originalBatch);

                if (writeEnd)
                {
                    await writer.WriteEndAsync();
                }
                buffer = stream.GetBuffer();
            }

            ArrowStreamReader reader = new ArrowStreamReader(buffer);

            await verificationFunc(reader, originalBatch);
        }
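
A minimal usage sketch (not part of the original source) of how this helper might be invoked; the test name is hypothetical, and the verification lambda simply reads the batch back and compares row counts:

        [Fact]
        public async Task ReadsBatchBackFromMemory()
        {
            // Assumed caller: read the single batch written by TestReaderFromMemory
            // and check that the round-tripped row count matches the original.
            await TestReaderFromMemory(
                async (reader, originalBatch) =>
                {
                    RecordBatch readBatch = await reader.ReadNextRecordBatchAsync();
                    Assert.Equal(originalBatch.Length, readBatch.Length);
                },
                writeEnd: true);
        }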
Code Example #4
        public async Task TestArrowGroupedMapCommandExecutor()
        {
            StringArray ConvertStrings(StringArray strings)
            {
                return (StringArray)ToArrowArray(
                    Enumerable.Range(0, strings.Length)
                        .Select(i => $"udf: {strings.GetString(i)}")
                        .ToArray());
            }

            Int64Array ConvertInt64s(Int64Array int64s)
            {
                return (Int64Array)ToArrowArray(
                    Enumerable.Range(0, int64s.Length)
                        .Select(i => int64s.Values[i] + 100)
                        .ToArray());
            }

            Schema resultSchema = new Schema.Builder()
                                  .Field(b => b.Name("arg1").DataType(StringType.Default))
                                  .Field(b => b.Name("arg2").DataType(Int64Type.Default))
                                  .Build();

            var udfWrapper = new Sql.ArrowGroupedMapUdfWrapper(
                (batch) => new RecordBatch(
                    resultSchema,
                    new IArrowArray[]
                    {
                        ConvertStrings((StringArray)batch.Column(0)),
                        ConvertInt64s((Int64Array)batch.Column(1)),
                    },
                    batch.Length));

            var command = new SqlCommand()
            {
                ArgOffsets          = new[] { 0 },
                NumChainedFunctions = 1,
                WorkerFunction      = new Sql.ArrowGroupedMapWorkerFunction(udfWrapper.Execute),
                SerializerMode      = CommandSerDe.SerializedMode.Row,
                DeserializerMode    = CommandSerDe.SerializedMode.Row
            };

            var commandPayload = new Worker.CommandPayload()
            {
                EvalType = UdfUtils.PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
                Commands = new[] { command }
            };

            using (var inputStream = new MemoryStream())
                using (var outputStream = new MemoryStream())
                {
                    int numRows = 10;

                    // Write test data to the input stream.
                    Schema schema = new Schema.Builder()
                                    .Field(b => b.Name("arg1").DataType(StringType.Default))
                                    .Field(b => b.Name("arg2").DataType(Int64Type.Default))
                                    .Build();
                    var arrowWriter = new ArrowStreamWriter(inputStream, schema);
                    await arrowWriter.WriteRecordBatchAsync(
                        new RecordBatch(
                            schema,
                            new[]
                            {
                                ToArrowArray(
                                    Enumerable.Range(0, numRows)
                                        .Select(i => i.ToString())
                                        .ToArray()),
                                ToArrowArray(
                                    Enumerable.Range(0, numRows)
                                        .Select(i => (long)i)
                                        .ToArray())
                            },
                            numRows));

                    inputStream.Seek(0, SeekOrigin.Begin);

                    CommandExecutorStat stat = new CommandExecutor().Execute(
                        inputStream,
                        outputStream,
                        0,
                        commandPayload);

                    // Validate that all the data on the stream is read.
                    Assert.Equal(inputStream.Length, inputStream.Position);
                    Assert.Equal(numRows, stat.NumEntriesProcessed);

                    // Validate the output stream.
                    outputStream.Seek(0, SeekOrigin.Begin);
                    int arrowLength = SerDe.ReadInt32(outputStream);
                    Assert.Equal((int)SpecialLengths.START_ARROW_STREAM, arrowLength);
                    var         arrowReader = new ArrowStreamReader(outputStream);
                    RecordBatch outputBatch = await arrowReader.ReadNextRecordBatchAsync();

                    Assert.Equal(numRows, outputBatch.Length);
                    Assert.Equal(2, outputBatch.ColumnCount);

                    var stringArray = (StringArray)outputBatch.Column(0);
                    for (int i = 0; i < numRows; ++i)
                    {
                        Assert.Equal($"udf: {i}", stringArray.GetString(i));
                    }

                    var longArray = (Int64Array)outputBatch.Column(1);
                    for (int i = 0; i < numRows; ++i)
                    {
                        Assert.Equal(100 + i, longArray.Values[i]);
                    }

                    int end = SerDe.ReadInt32(outputStream);
                    Assert.Equal(0, end);

                    // Validate all the data on the stream is read.
                    Assert.Equal(outputStream.Length, outputStream.Position);
                }
        }
Code Example #5
        public void TestArrowSqlCommandExecutorWithEmptyInput()
        {
            var udfWrapper = new Sql.ArrowUdfWrapper<StringArray, StringArray>(
                (strings) => (StringArray)ToArrowArray(
                    Enumerable.Range(0, strings.Length)
                    .Select(i => $"udf: {strings.GetString(i)}")
                    .ToArray()));

            var command = new SqlCommand()
            {
                ArgOffsets          = new[] { 0 },
                NumChainedFunctions = 1,
                WorkerFunction      = new Sql.ArrowWorkerFunction(udfWrapper.Execute),
                SerializerMode      = CommandSerDe.SerializedMode.Row,
                DeserializerMode    = CommandSerDe.SerializedMode.Row
            };

            var commandPayload = new Worker.CommandPayload()
            {
                EvalType = UdfUtils.PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                Commands = new[] { command }
            };

            using (var inputStream = new MemoryStream())
                using (var outputStream = new MemoryStream())
                {
                    // Write test data to the input stream.
                    Schema schema = new Schema.Builder()
                                    .Field(b => b.Name("arg1").DataType(StringType.Default))
                                    .Build();
                    var arrowWriter = new ArrowStreamWriter(inputStream, schema);

                    // The .NET ArrowStreamWriter doesn't currently support writing just a
                    // schema with no batches - but Java does. We use Reflection to simulate
                    // the request Spark sends.
                    MethodInfo writeSchemaMethod = arrowWriter.GetType().GetMethod(
                        "WriteSchemaAsync",
                        BindingFlags.NonPublic | BindingFlags.Instance);

                    writeSchemaMethod.Invoke(
                        arrowWriter,
                        new object[] { schema, CancellationToken.None });

                    SerDe.Write(inputStream, 0);

                    inputStream.Seek(0, SeekOrigin.Begin);

                    CommandExecutorStat stat = new CommandExecutor().Execute(
                        inputStream,
                        outputStream,
                        0,
                        commandPayload);

                    // Validate that all the data on the stream is read.
                    Assert.Equal(inputStream.Length, inputStream.Position);
                    Assert.Equal(0, stat.NumEntriesProcessed);

                    // Validate the output stream.
                    outputStream.Seek(0, SeekOrigin.Begin);
                    int arrowLength = SerDe.ReadInt32(outputStream);
                    Assert.Equal((int)SpecialLengths.START_ARROW_STREAM, arrowLength);
                    var         arrowReader = new ArrowStreamReader(outputStream);
                    RecordBatch outputBatch = arrowReader.ReadNextRecordBatch();

                    Assert.Equal(1, outputBatch.Schema.Fields.Count);
                    Assert.IsType<StringType>(outputBatch.Schema.GetFieldByIndex(0).DataType);

                    Assert.Equal(0, outputBatch.Length);
                    Assert.Single(outputBatch.Arrays);

                    var array = (StringArray)outputBatch.Arrays.ElementAt(0);
                    Assert.Equal(0, array.Length);

                    int end = SerDe.ReadInt32(outputStream);
                    Assert.Equal(0, end);

                    // Validate all the data on the stream is read.
                    Assert.Equal(outputStream.Length, outputStream.Position);
                }
        }
Code Example #6
        public async Task TestArrowSqlCommandExecutorWithMultiCommands()
        {
            var udfWrapper1 = new Sql.ArrowUdfWrapper<StringArray, StringArray>(
                (strings) => (StringArray)ToArrowArray(
                    Enumerable.Range(0, strings.Length)
                    .Select(i => $"udf: {strings.GetString(i)}")
                    .ToArray()));
            var udfWrapper2 = new Sql.ArrowUdfWrapper<Int32Array, Int32Array, Int32Array>(
                (arg1, arg2) => (Int32Array)ToArrowArray(
                    Enumerable.Range(0, arg1.Length)
                    .Select(i => arg1.Values[i] * arg2.Values[i])
                    .ToArray()));

            var command1 = new SqlCommand()
            {
                ArgOffsets          = new[] { 0 },
                NumChainedFunctions = 1,
                WorkerFunction      = new Sql.ArrowWorkerFunction(udfWrapper1.Execute),
                SerializerMode      = CommandSerDe.SerializedMode.Row,
                DeserializerMode    = CommandSerDe.SerializedMode.Row
            };

            var command2 = new SqlCommand()
            {
                ArgOffsets          = new[] { 1, 2 },
                NumChainedFunctions = 1,
                WorkerFunction      = new Sql.ArrowWorkerFunction(udfWrapper2.Execute),
                SerializerMode      = CommandSerDe.SerializedMode.Row,
                DeserializerMode    = CommandSerDe.SerializedMode.Row
            };

            var commandPayload = new Worker.CommandPayload()
            {
                EvalType = UdfUtils.PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                Commands = new[] { command1, command2 }
            };

            using (var inputStream = new MemoryStream())
                using (var outputStream = new MemoryStream())
                {
                    int numRows = 10;

                    // Write test data to the input stream.
                    Schema schema = new Schema.Builder()
                                    .Field(b => b.Name("arg1").DataType(StringType.Default))
                                    .Field(b => b.Name("arg2").DataType(Int32Type.Default))
                                    .Field(b => b.Name("arg3").DataType(Int32Type.Default))
                                    .Build();
                    var arrowWriter = new ArrowStreamWriter(inputStream, schema);
                    await arrowWriter.WriteRecordBatchAsync(
                        new RecordBatch(
                            schema,
                            new[]
                            {
                                ToArrowArray(
                                    Enumerable.Range(0, numRows)
                                        .Select(i => i.ToString())
                                        .ToArray()),
                                ToArrowArray(Enumerable.Range(0, numRows).ToArray()),
                                ToArrowArray(Enumerable.Range(0, numRows).ToArray()),
                            },
                            numRows));

                    inputStream.Seek(0, SeekOrigin.Begin);

                    CommandExecutorStat stat = new CommandExecutor().Execute(
                        inputStream,
                        outputStream,
                        0,
                        commandPayload);

                    // Validate all the data on the stream is read.
                    Assert.Equal(inputStream.Length, inputStream.Position);
                    Assert.Equal(numRows, stat.NumEntriesProcessed);

                    // Validate the output stream.
                    outputStream.Seek(0, SeekOrigin.Begin);
                    var arrowLength = SerDe.ReadInt32(outputStream);
                    Assert.Equal((int)SpecialLengths.START_ARROW_STREAM, arrowLength);
                    var         arrowReader = new ArrowStreamReader(outputStream);
                    RecordBatch outputBatch = await arrowReader.ReadNextRecordBatchAsync();

                    Assert.Equal(numRows, outputBatch.Length);
                    Assert.Equal(2, outputBatch.Arrays.Count());
                    var array1 = (StringArray)outputBatch.Arrays.ElementAt(0);
                    var array2 = (Int32Array)outputBatch.Arrays.ElementAt(1);
                    for (int i = 0; i < numRows; ++i)
                    {
                        Assert.Equal($"udf: {i}", array1.GetString(i));
                        Assert.Equal(i * i, array2.Values[i]);
                    }

                    int end = SerDe.ReadInt32(outputStream);
                    Assert.Equal(0, end);

                    // Validate all the data on the stream is read.
                    Assert.Equal(outputStream.Length, outputStream.Position);
                }
        }
Code Example #7
 public async Task WriteBatch()
 {
     ArrowStreamWriter writer = new ArrowStreamWriter(_memoryStream, _batch.Schema);
     await writer.WriteRecordBatchAsync(_batch);
 }
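
For context, a self-contained round-trip sketch (an assumption, not taken from the project; the RoundTripSketch type and RunAsync method are illustrative names) that writes a batch with ArrowStreamWriter and reads it back with ArrowStreamReader, using the same Apache.Arrow APIs as the examples above:

using System;
using System.IO;
using System.Threading.Tasks;
using Apache.Arrow;
using Apache.Arrow.Ipc;
using Apache.Arrow.Types;

public static class RoundTripSketch
{
    public static async Task RunAsync()
    {
        // Build a one-column batch; Int32Array.Builder assembles the value buffer
        // and validity bitmap so no raw ArrowBuffer handling is needed.
        Schema schema = new Schema.Builder()
            .Field(f => f.Name("value").DataType(Int32Type.Default))
            .Build();
        Int32Array values = new Int32Array.Builder().AppendRange(new[] { 1, 2, 3 }).Build();
        var batch = new RecordBatch(schema, new IArrowArray[] { values }, length: 3);

        using var stream = new MemoryStream();
        using (var writer = new ArrowStreamWriter(stream, schema, leaveOpen: true))
        {
            await writer.WriteRecordBatchAsync(batch);
            await writer.WriteEndAsync(); // end-of-stream marker
        }

        // Rewind and read the batch back.
        stream.Seek(0, SeekOrigin.Begin);
        var reader = new ArrowStreamReader(stream);
        RecordBatch roundTripped = await reader.ReadNextRecordBatchAsync();
        Console.WriteLine(roundTripped.Length); // prints 3
    }
}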
Code Example #8
        public async Task TestDataFrameSqlCommandExecutorWithSingleCommand(
            Version sparkVersion,
            IpcOptions ipcOptions)
        {
            var udfWrapper = new Sql.DataFrameUdfWrapper<ArrowStringDataFrameColumn, ArrowStringDataFrameColumn>(
                (strings) => strings.Apply(cur => $"udf: {cur}"));

            var command = new SqlCommand()
            {
                ArgOffsets          = new[] { 0 },
                NumChainedFunctions = 1,
                WorkerFunction      = new Sql.DataFrameWorkerFunction(udfWrapper.Execute),
                SerializerMode      = CommandSerDe.SerializedMode.Row,
                DeserializerMode    = CommandSerDe.SerializedMode.Row
            };

            var commandPayload = new Worker.CommandPayload()
            {
                EvalType = UdfUtils.PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                Commands = new[] { command }
            };

            using var inputStream  = new MemoryStream();
            using var outputStream = new MemoryStream();
            int numRows = 10;

            // Write test data to the input stream.
            Schema schema = new Schema.Builder()
                            .Field(b => b.Name("arg1").DataType(StringType.Default))
                            .Build();
            var arrowWriter =
                new ArrowStreamWriter(inputStream, schema, leaveOpen: false, ipcOptions);
            await arrowWriter.WriteRecordBatchAsync(
                new RecordBatch(
                    schema,
                    new[]
                    {
                        ToArrowArray(
                            Enumerable.Range(0, numRows)
                                .Select(i => i.ToString())
                                .ToArray())
                    },
                    numRows));

            inputStream.Seek(0, SeekOrigin.Begin);

            CommandExecutorStat stat = new CommandExecutor(sparkVersion).Execute(
                inputStream,
                outputStream,
                0,
                commandPayload);

            // Validate that all the data on the stream is read.
            Assert.Equal(inputStream.Length, inputStream.Position);
            Assert.Equal(numRows, stat.NumEntriesProcessed);

            // Validate the output stream.
            outputStream.Seek(0, SeekOrigin.Begin);
            int arrowLength = SerDe.ReadInt32(outputStream);

            Assert.Equal((int)SpecialLengths.START_ARROW_STREAM, arrowLength);
            var         arrowReader = new ArrowStreamReader(outputStream);
            RecordBatch outputBatch = await arrowReader.ReadNextRecordBatchAsync();

            Assert.Equal(numRows, outputBatch.Length);
            Assert.Single(outputBatch.Arrays);
            var array = (StringArray)outputBatch.Arrays.ElementAt(0);

            // Validate the single command.
            for (int i = 0; i < numRows; ++i)
            {
                Assert.Equal($"udf: {i}", array.GetString(i));
            }

            CheckEOS(outputStream, ipcOptions);

            // Validate all the data on the stream is read.
            Assert.Equal(outputStream.Length, outputStream.Position);
        }
Code Example #9
        public async Task WriteBatchWithCorrectPaddingAsync()
        {
            byte value1 = 0x04;
            byte value2 = 0x14;
            var  batch  = new RecordBatch(
                new Schema.Builder()
                .Field(f => f.Name("age").DataType(Int32Type.Default))
                .Field(f => f.Name("characterCount").DataType(Int32Type.Default))
                .Build(),
                new IArrowArray[]
                {
                    new Int32Array(
                        new ArrowBuffer(new byte[] { value1, value1, 0x00, 0x00 }),
                        ArrowBuffer.Empty,
                        length: 1,
                        nullCount: 0,
                        offset: 0),
                    new Int32Array(
                        new ArrowBuffer(new byte[] { value2, value2, 0x00, 0x00 }),
                        ArrowBuffer.Empty,
                        length: 1,
                        nullCount: 0,
                        offset: 0)
                },
                length: 1);

            await TestRoundTripRecordBatchAsync(batch);

            using (MemoryStream stream = new MemoryStream())
            {
                using (var writer = new ArrowStreamWriter(stream, batch.Schema, leaveOpen: true))
                {
                    await writer.WriteRecordBatchAsync(batch);

                    await writer.WriteEndAsync();
                }

                byte[] writtenBytes = stream.ToArray();

                // ensure that the data buffers at the end are 8-byte aligned
                Assert.Equal(value1, writtenBytes[writtenBytes.Length - 24]);
                Assert.Equal(value1, writtenBytes[writtenBytes.Length - 23]);
                for (int i = 22; i > 16; i--)
                {
                    Assert.Equal(0, writtenBytes[writtenBytes.Length - i]);
                }

                Assert.Equal(value2, writtenBytes[writtenBytes.Length - 16]);
                Assert.Equal(value2, writtenBytes[writtenBytes.Length - 15]);
                for (int i = 14; i > 8; i--)
                {
                    Assert.Equal(0, writtenBytes[writtenBytes.Length - i]);
                }

                // verify the EOS is written correctly
                for (int i = 8; i > 4; i--)
                {
                    Assert.Equal(0xFF, writtenBytes[writtenBytes.Length - i]);
                }
                for (int i = 4; i > 0; i--)
                {
                    Assert.Equal(0x00, writtenBytes[writtenBytes.Length - i]);
                }
            }
        }
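
As a side note (an illustration, not part of the test), the first column above could also be produced through the builder API; the raw buffer { 0x04, 0x04, 0x00, 0x00 } is the little-endian encoding of the single value 0x0404:

            // Illustrative equivalent of the first Int32Array, built via the builder
            // rather than from raw buffers; 0x0404 == 1028.
            Int32Array ageColumn = new Int32Array.Builder()
                .Append(0x0404)
                .Build();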