Code Example #1
        public void TestArrowSqlCommandExecutorWithEmptyInput()
        {
            var udfWrapper = new Sql.ArrowUdfWrapper<StringArray, StringArray>(
                (strings) => (StringArray)ToArrowArray(
                    Enumerable.Range(0, strings.Length)
                    .Select(i => $"udf: {strings.GetString(i)}")
                    .ToArray()));

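            // Build a single command that runs the Arrow UDF, using Row serialization on both ends.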
            var command = new SqlCommand()
            {
                ArgOffsets = new[] { 0 },
                NumChainedFunctions = 1,
                WorkerFunction = new Sql.ArrowWorkerFunction(udfWrapper.Execute),
                SerializerMode = CommandSerDe.SerializedMode.Row,
                DeserializerMode = CommandSerDe.SerializedMode.Row
            };

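            // Package the command into a payload that is evaluated as a scalar pandas (Arrow) UDF.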
            var commandPayload = new Worker.CommandPayload()
            {
                EvalType = UdfUtils.PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                Commands = new[] { command }
            };

            using (var inputStream = new MemoryStream())
                using (var outputStream = new MemoryStream())
                {
                    // Write test data to the input stream.
                    Schema schema = new Schema.Builder()
                                    .Field(b => b.Name("arg1").DataType(StringType.Default))
                                    .Build();
                    var arrowWriter = new ArrowStreamWriter(inputStream, schema);

                    // The .NET ArrowStreamWriter doesn't currently support writing just a
                    // schema with no batches - but Java does. We use Reflection to simulate
                    // the request Spark sends.
                    MethodInfo writeSchemaMethod = arrowWriter.GetType().GetMethod(
                        "WriteSchemaAsync",
                        BindingFlags.NonPublic | BindingFlags.Instance);

                    writeSchemaMethod.Invoke(
                        arrowWriter,
                        new object[] { schema, CancellationToken.None });

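                    // With no record batches to send, the zero written here serves as the
                    // end-of-stream marker for the Arrow input.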
                    SerDe.Write(inputStream, 0);

                    inputStream.Seek(0, SeekOrigin.Begin);

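                    // Run the command against the empty Arrow input and capture the execution stats.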
                    CommandExecutorStat stat = new CommandExecutor().Execute(
                        inputStream,
                        outputStream,
                        0,
                        commandPayload);

                    // Validate that all the data on the stream is read.
                    Assert.Equal(inputStream.Length, inputStream.Position);
                    Assert.Equal(0, stat.NumEntriesProcessed);

                    // Validate the output stream.
                    outputStream.Seek(0, SeekOrigin.Begin);
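                    // The executor writes a START_ARROW_STREAM marker before the Arrow payload.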
                    int arrowLength = SerDe.ReadInt32(outputStream);
                    Assert.Equal((int)SpecialLengths.START_ARROW_STREAM, arrowLength);
                    var arrowReader = new ArrowStreamReader(outputStream);
                    RecordBatch outputBatch = arrowReader.ReadNextRecordBatch();

                    Assert.Equal(1, outputBatch.Schema.Fields.Count);
                    Assert.IsType<StringType>(outputBatch.Schema.GetFieldByIndex(0).DataType);

                    Assert.Equal(0, outputBatch.Length);
                    Assert.Single(outputBatch.Arrays);

                    var array = (StringArray)outputBatch.Arrays.ElementAt(0);
                    Assert.Equal(0, array.Length);

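                    // The output should end with a zero, signaling no further Arrow data.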
                    int end = SerDe.ReadInt32(outputStream);
                    Assert.Equal(0, end);

                    // Validate all the data on the stream is read.
                    Assert.Equal(outputStream.Length, outputStream.Position);
                }
        }