Example #1
    public void getOutputSchemaInner(string sql)
    {
        reader = new ArrowStreamReader(this.outputBuffer, leaveOpen: false);
        // Read one batch first to obtain the Arrow schema.
        reader.ReadNextRecordBatch();

        this.schema = ArrowSchemaToASARecordSchema(reader.Schema);

        var result =
            SqlCompiler.Compile(
                sql,
                new QueryBindings(
                    new Dictionary<string, InputDescription>
                    {
                        { "input", new InputDescription(this.schema, InputType.Stream) }
                    }));

        var step = result.Steps.First();

        Schema.Builder builder = new Schema.Builder();
        foreach (KeyValuePair<string, int> kv in step.Output.PayloadSchema.Ordinals.OrderBy(kv => kv.Value))
        {
            builder = builder.Field(f => f
                .Name(kv.Key)
                .DataType(ASATypeToArrowType(step.Output.PayloadSchema[kv.Value].Schema))
                .Nullable(false));
        }

        this.outputArrowSchema = builder.Build();

        this.writer = new ArrowStreamWriter(this.inputBuffer, this.outputArrowSchema);
        // Write an empty batch to send the schema to the Java side.
        var emptyRecordBatch = createOutputRecordBatch(new List<IRecord>());

        WriteRecordBatch(emptyRecordBatch);
    }
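The pattern above (read one batch so the reader materializes its schema, then open a writer that announces a derived schema) reduces to a short sketch; the `input` and `output` streams here are hypothetical stand-ins:

    // Minimal sketch: pull the schema from an incoming Arrow stream, then
    // start an outgoing stream that carries a schema before any batches.
    var reader = new ArrowStreamReader(input);
    reader.ReadNextRecordBatch();      // forces the schema to be read
    Schema schema = reader.Schema;

    var writer = new ArrowStreamWriter(output, schema);
    writer.WriteStart();               // emits the schema message
    writer.WriteEnd();                 // emits the end-of-stream marker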
Example #2
        public async Task CanWriteToNetworkStreamAsync()
        {
            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100);

            const int   port     = 32154;
            TcpListener listener = new TcpListener(IPAddress.Loopback, port);

            listener.Start();

            using (TcpClient sender = new TcpClient())
            {
                sender.Connect(IPAddress.Loopback, port);
                NetworkStream stream = sender.GetStream();

                using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema))
                {
                    await writer.WriteRecordBatchAsync(originalBatch);

                    await writer.WriteEndAsync();

                    stream.Flush();
                }
            }

            using (TcpClient receiver = listener.AcceptTcpClient())
            {
                NetworkStream stream = receiver.GetStream();
                using (var reader = new ArrowStreamReader(stream))
                {
                    RecordBatch newBatch = reader.ReadNextRecordBatch();
                    ArrowReaderVerifier.CompareBatches(originalBatch, newBatch);
                }
            }
        }
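One caveat worth noting: the hard-coded port (32154) can collide when tests run in parallel. A hedged variant (not from the original test) binds to port 0 and lets the OS choose:

    // Bind to an ephemeral port and read back what the OS assigned.
    TcpListener listener = new TcpListener(IPAddress.Loopback, 0);
    listener.Start();
    int port = ((IPEndPoint)listener.LocalEndpoint).Port;

Example #3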
        public void TestEmptyDataFrameRecordBatch()
        {
            PrimitiveDataFrameColumn<int> ageColumn    = new PrimitiveDataFrameColumn<int>("Age");
            PrimitiveDataFrameColumn<int> lengthColumn = new PrimitiveDataFrameColumn<int>("CharCount");
            ArrowStringDataFrameColumn    stringColumn = new ArrowStringDataFrameColumn("Empty");
            DataFrame df = new DataFrame(new List<DataFrameColumn>()
            {
                ageColumn, lengthColumn, stringColumn
            });

            IEnumerable<RecordBatch> recordBatches = df.ToArrowRecordBatches();
            bool foundARecordBatch = false;

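            // Even with zero rows, ToArrowRecordBatches() is expected to yield at
            // least one zero-length batch so the schema can round-trip (asserted below).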
            foreach (RecordBatch recordBatch in recordBatches)
            {
                foundARecordBatch = true;
                MemoryStream      stream = new MemoryStream();
                ArrowStreamWriter writer = new ArrowStreamWriter(stream, recordBatch.Schema);
                writer.WriteRecordBatchAsync(recordBatch).GetAwaiter().GetResult();

                stream.Position = 0;
                ArrowStreamReader reader          = new ArrowStreamReader(stream);
                RecordBatch       readRecordBatch = reader.ReadNextRecordBatch();
                while (readRecordBatch != null)
                {
                    RecordBatchComparer.CompareBatches(recordBatch, readRecordBatch);
                    readRecordBatch = reader.ReadNextRecordBatch();
                }
            }
            Assert.True(foundARecordBatch);
        }
Example #4
        private IEnumerable<RecordBatch> GetInputIterator(Stream inputStream)
        {
            using (var reader = new ArrowStreamReader(inputStream, leaveOpen: true))
            {
                RecordBatch batch;
                bool        returnedResult = false;
                while ((batch = reader.ReadNextRecordBatch()) != null)
                {
                    yield return batch;

                    returnedResult = true;
                }

                if (!returnedResult)
                {
                    // When no input batches were received, return an empty RecordBatch
                    // in order to create and write back the result schema.

                    int columnCount = reader.Schema.Fields.Count;
                    var arrays      = new IArrowArray[columnCount];
                    for (int i = 0; i < columnCount; ++i)
                    {
                        IArrowType type = reader.Schema.GetFieldByIndex(i).DataType;
                        arrays[i] = ArrowArrayHelpers.CreateEmptyArray(type);
                    }
                    yield return new RecordBatch(reader.Schema, arrays, 0);
                }
            }
        }
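Example #5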
        public async Task Ctor_MemoryPool_AllocatesFromPool(bool shouldLeaveOpen, bool createDictionaryArray, int expectedAllocations)
        {
            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: createDictionaryArray);

            using (MemoryStream stream = new MemoryStream())
            {
                ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema);
                await writer.WriteRecordBatchAsync(originalBatch);

                await writer.WriteEndAsync();

                stream.Position = 0;

                var memoryPool           = new TestMemoryAllocator();
                ArrowStreamReader reader = new ArrowStreamReader(stream, memoryPool, shouldLeaveOpen);
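                // The first read should pull buffers from the supplied allocator;
                // the assertions below verify the pool observed those allocations.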
                reader.ReadNextRecordBatch();

                Assert.Equal(expectedAllocations, memoryPool.Statistics.Allocations);
                Assert.True(memoryPool.Statistics.BytesAllocated > 0);

                reader.Dispose();

                if (shouldLeaveOpen)
                {
                    Assert.True(stream.Position > 0);
                }
                else
                {
                    Assert.Throws<ObjectDisposedException>(() => stream.Position);
                }
            }
        }
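Example #6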
        private static async Task TestRoundTripRecordBatchesAsync(List<RecordBatch> originalBatches, IpcOptions options = null)
        {
            using (MemoryStream stream = new MemoryStream())
            {
                using (var writer = new ArrowStreamWriter(stream, originalBatches[0].Schema, leaveOpen: true, options))
                {
                    foreach (RecordBatch originalBatch in originalBatches)
                    {
                        await writer.WriteRecordBatchAsync(originalBatch);
                    }
                    await writer.WriteEndAsync();
                }

                stream.Position = 0;

                using (var reader = new ArrowStreamReader(stream))
                {
                    foreach (RecordBatch originalBatch in originalBatches)
                    {
                        RecordBatch newBatch = reader.ReadNextRecordBatch();
                        ArrowReaderVerifier.CompareBatches(originalBatch, newBatch);
                    }
                }
            }
        }
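Example #7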
        public void CanWriteToNetworkStream(bool createDictionaryArray, int port)
        {
            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: createDictionaryArray);

            TcpListener listener = new TcpListener(IPAddress.Loopback, port);

            listener.Start();

            using (TcpClient sender = new TcpClient())
            {
                sender.Connect(IPAddress.Loopback, port);
                NetworkStream stream = sender.GetStream();

                using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema))
                {
                    writer.WriteRecordBatch(originalBatch);
                    writer.WriteEnd();

                    stream.Flush();
                }
            }

            using (TcpClient receiver = listener.AcceptTcpClient())
            {
                NetworkStream stream = receiver.GetStream();
                using (var reader = new ArrowStreamReader(stream))
                {
                    RecordBatch newBatch = reader.ReadNextRecordBatch();
                    ArrowReaderVerifier.CompareBatches(originalBatch, newBatch);
                }
            }
        }
Example #8
        public static string ApacheArrowToJSON(string base64)
        {
            try
            {
                byte[] bytes = Convert.FromBase64String(base64);
                using (ArrowStreamReader reader = new ArrowStreamReader(bytes))
                {
                    // Read one batch so the reader materializes the stream's schema.
                    reader.ReadNextRecordBatch();

                    var metadata = new JObject();
                    var schema   = new JObject();

                    var fields = new JArray();
                    if (reader.Schema?.Fields != null)
                    {
                        foreach (var _field in reader.Schema.Fields)
                        {
                            var field = new JObject();
                            field[nameof(_field.Value.Name)]       = _field.Value.Name;
                            field[nameof(_field.Value.IsNullable)] = _field.Value.IsNullable;
                            field[nameof(_field.Value.DataType)]   = JObject.Parse(JsonConvert.SerializeObject(_field.Value.DataType));

                            if (_field.Value.HasMetadata)
                            {
                                metadata = new JObject();
                                foreach (var _fieldMetadata in _field.Value.Metadata)
                                {
                                    metadata[_fieldMetadata.Key] = _fieldMetadata.Value;
                                }
                                field[nameof(metadata)] = metadata;
                            }

                            fields.Add(field);
                        }
                    }
                    schema[nameof(fields)] = fields;

                    metadata = new JObject();
                    if (reader.Schema?.Metadata != null)
                    {
                        foreach (var _metadata in reader.Schema.Metadata)
                        {
                            metadata[_metadata.Key] = _metadata.Value;
                        }
                    }
                    schema[nameof(metadata)] = metadata;

                    return schema.ToString(Formatting.Indented);
                }
            }
            catch (Exception ex)
            {
                return $"Something went wrong while processing the schema:{Environment.NewLine}{Environment.NewLine}{ex}";
            }
        }
Example #9
        public static async Task VerifyReaderAsync(ArrowStreamReader reader, RecordBatch originalBatch)
        {
            RecordBatch readBatch = await reader.ReadNextRecordBatchAsync();

            CompareBatches(originalBatch, readBatch);

            // There should only be one batch - calling ReadNextRecordBatchAsync again should return null.
            Assert.Null(await reader.ReadNextRecordBatchAsync());
            Assert.Null(await reader.ReadNextRecordBatchAsync());
        }
Example #10
        public static void VerifyReader(ArrowStreamReader reader, RecordBatch originalBatch)
        {
            RecordBatch readBatch = reader.ReadNextRecordBatch();

            CompareBatches(originalBatch, readBatch);

            // There should only be one batch - calling ReadNextRecordBatch again should return null.
            Assert.Null(reader.ReadNextRecordBatch());
            Assert.Null(reader.ReadNextRecordBatch());
        }
Example #11
        public async Task<double> ArrowReaderWithMemory()
        {
            double      sum    = 0;
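            // GetBuffer() exposes the stream's whole backing array, including unused
            // capacity; the trailing zero bytes appear to read as an end-of-stream
            // marker, so the reader stops cleanly after the written data.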
            var         reader = new ArrowStreamReader(_memoryStream.GetBuffer());
            RecordBatch recordBatch;

            while ((recordBatch = await reader.ReadNextRecordBatchAsync()) != null)
            {
                sum += SumAllNumbers(recordBatch);
            }
            return sum;
        }
Example #12
        /// <summary>
        /// Create input iterator from the given input stream.
        /// </summary>
        /// <param name="inputStream">Stream to read from</param>
        /// <returns>One column set per record batch read from the stream.</returns>
        private IEnumerable<ReadOnlyMemory<IArrowArray>> GetInputIterator(Stream inputStream)
        {
            IArrowArray[] arrays      = null;
            int           columnCount = 0;

            try
            {
                using (var reader = new ArrowStreamReader(inputStream, leaveOpen: true))
                {
                    RecordBatch batch;
                    while ((batch = reader.ReadNextRecordBatch()) != null)
                    {
                        columnCount = batch.ColumnCount;
                        if (arrays == null)
                        {
                            // Note that every batch in a stream has the same schema.
                            arrays = ArrayPool<IArrowArray>.Shared.Rent(columnCount);
                        }

                        for (int i = 0; i < columnCount; ++i)
                        {
                            arrays[i] = batch.Column(i);
                        }

                        yield return new ReadOnlyMemory<IArrowArray>(arrays, 0, columnCount);
                    }

                    if (arrays == null)
                    {
                        // When no input batches were received, return empty IArrowArrays
                        // in order to create and write back the result schema.
                        columnCount = reader.Schema.Fields.Count;
                        arrays      = ArrayPool<IArrowArray>.Shared.Rent(columnCount);

                        for (int i = 0; i < columnCount; ++i)
                        {
                            arrays[i] = null;
                        }
                        yield return new ReadOnlyMemory<IArrowArray>(arrays, 0, columnCount);
                    }
                }
            }
            finally
            {
                if (arrays != null)
                {
                    arrays.AsSpan(0, columnCount).Clear();
                    ArrayPool<IArrowArray>.Shared.Return(arrays);
                }
            }
        }
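Because each iteration reuses the same pooled array, callers must consume a yielded column set before advancing the iterator. A hypothetical caller (`stream` and `Process` are illustrative only):

    foreach (ReadOnlyMemory<IArrowArray> columns in GetInputIterator(stream))
    {
        Process(columns.Span); // do not cache `columns` across iterations
    }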
Example #13
        /// <summary>
        /// Verifies that the stream reader reads multiple times when a stream
        /// only returns a subset of the data from each Read.
        /// </summary>
        private static async Task TestReaderFromPartialReadStream(Func<ArrowStreamReader, RecordBatch, Task> verificationFunc)
        {
            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100);

            using (PartialReadStream stream = new PartialReadStream())
            {
                ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema);
                await writer.WriteRecordBatchAsync(originalBatch);

                stream.Position = 0;

                ArrowStreamReader reader = new ArrowStreamReader(stream);
                await verificationFunc(reader, originalBatch);
            }
        }
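PartialReadStream is test infrastructure not shown here. One plausible minimal stand-in caps how many bytes each read returns, which is enough to exercise the reader's retry loop (a sketch; the cap of 16 bytes is arbitrary):

    // Hypothetical stream that returns at most 16 bytes per read call.
    internal sealed class ChunkedReadStream : MemoryStream
    {
        public override int Read(byte[] buffer, int offset, int count) =>
            base.Read(buffer, offset, Math.Min(count, 16));

        public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
            base.ReadAsync(buffer, offset, Math.Min(count, 16), cancellationToken);
    }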
Example #14
        public async Task<double> ArrowReaderWithMemoryStream_ManagedMemory()
        {
            double      sum    = 0;
            var         reader = new ArrowStreamReader(_memoryStream, s_allocator);
            RecordBatch recordBatch;

            while ((recordBatch = await reader.ReadNextRecordBatchAsync()) != null)
            {
                using (recordBatch)
                {
                    sum += SumAllNumbers(recordBatch);
                }
            }
            return sum;
        }
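Example #15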
        public static string ApacheArrowToJSON(string base64)
        {
            try
            {
                byte[] bytes = Convert.FromBase64String(base64);
                using (ArrowStreamReader reader = new ArrowStreamReader(bytes))
                {
                    reader.ReadNextRecordBatch();
                    return JsonConvert.SerializeObject(reader.Schema, Formatting.Indented);
                }
            }
            catch (Exception ex)
            {
                return $"Something went wrong while processing the schema:{Environment.NewLine}{Environment.NewLine}{ex}";
            }
        }
Example #16
        private static async Task TestReaderFromMemory(Func<ArrowStreamReader, RecordBatch, Task> verificationFunc)
        {
            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100);

            byte[] buffer;
            using (MemoryStream stream = new MemoryStream())
            {
                ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema);
                await writer.WriteRecordBatchAsync(originalBatch);

                buffer = stream.GetBuffer();
            }

            ArrowStreamReader reader = new ArrowStreamReader(buffer);

            await verificationFunc(reader, originalBatch);
        }
Example #17
        private static async Task TestRoundTripRecordBatch(RecordBatch originalBatch)
        {
            using (MemoryStream stream = new MemoryStream())
            {
                using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true))
                {
                    await writer.WriteRecordBatchAsync(originalBatch);
                }

                stream.Position = 0;

                using (var reader = new ArrowStreamReader(stream))
                {
                    RecordBatch newBatch = reader.ReadNextRecordBatch();
                    ArrowReaderVerifier.CompareBatches(originalBatch, newBatch);
                }
            }
        }
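Example #18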
        public void WritesEmptyFile()
        {
            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 1);

            var stream = new MemoryStream();
            var writer = new ArrowStreamWriter(stream, originalBatch.Schema);

            writer.WriteStart();
            writer.WriteEnd();

            stream.Position = 0;

            var         reader    = new ArrowStreamReader(stream);
            RecordBatch readBatch = reader.ReadNextRecordBatch();

            Assert.Null(readBatch);
            SchemaComparer.Compare(originalBatch.Schema, reader.Schema);
        }
Example #19
        private static void TestRoundTripRecordBatch(RecordBatch originalBatch, IpcOptions options = null)
        {
            using (MemoryStream stream = new MemoryStream())
            {
                using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true, options))
                {
                    writer.WriteRecordBatch(originalBatch);
                    writer.WriteEnd();
                }

                stream.Position = 0;

                using (var reader = new ArrowStreamReader(stream))
                {
                    RecordBatch newBatch = reader.ReadNextRecordBatch();
                    ArrowReaderVerifier.CompareBatches(originalBatch, newBatch);
                }
            }
        }
Example #20
        public async Task WriteEmptyBatch()
        {
            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 0);

            using (MemoryStream stream = new MemoryStream())
            {
                using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true))
                {
                    await writer.WriteRecordBatchAsync(originalBatch);
                }

                stream.Position = 0;

                using (var reader = new ArrowStreamReader(stream))
                {
                    RecordBatch newBatch = reader.ReadNextRecordBatch();
                    ArrowReaderVerifier.CompareBatches(originalBatch, newBatch);
                }
            }
        }
Example #21
        private static async Task TestReaderFromStream(
            Func<ArrowStreamReader, RecordBatch, Task> verificationFunc,
            bool writeEnd, bool createDictionaryArray)
        {
            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: createDictionaryArray);

            using (MemoryStream stream = new MemoryStream())
            {
                ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema);
                await writer.WriteRecordBatchAsync(originalBatch);

                if (writeEnd)
                {
                    await writer.WriteEndAsync();
                }

                stream.Position = 0;

                ArrowStreamReader reader = new ArrowStreamReader(stream);
                await verificationFunc(reader, originalBatch);
            }
        }
Example #22
        public async Task ReadRecordBatch()
        {
            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100);

            byte[] buffer;
            using (MemoryStream stream = new MemoryStream())
            {
                ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema);
                await writer.WriteRecordBatchAsync(originalBatch);

                buffer = stream.GetBuffer();
            }

            ArrowStreamReader reader    = new ArrowStreamReader(buffer);
            RecordBatch       readBatch = reader.ReadNextRecordBatch();

            CompareBatches(originalBatch, readBatch);

            // There should only be one batch - calling ReadNextRecordBatch again should return null.
            Assert.Null(reader.ReadNextRecordBatch());
            Assert.Null(reader.ReadNextRecordBatch());
        }
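Example #23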
        public async Task TestDataFrameSqlCommandExecutorWithSingleCommand(
            Version sparkVersion,
            IpcOptions ipcOptions)
        {
            var udfWrapper = new Sql.DataFrameUdfWrapper<ArrowStringDataFrameColumn, ArrowStringDataFrameColumn>(
                (strings) => strings.Apply(cur => $"udf: {cur}"));

            var command = new SqlCommand()
            {
                ArgOffsets          = new[] { 0 },
                NumChainedFunctions = 1,
                WorkerFunction      = new Sql.DataFrameWorkerFunction(udfWrapper.Execute),
                SerializerMode      = CommandSerDe.SerializedMode.Row,
                DeserializerMode    = CommandSerDe.SerializedMode.Row
            };

            var commandPayload = new Worker.CommandPayload()
            {
                EvalType = UdfUtils.PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                Commands = new[] { command }
            };

            using var inputStream  = new MemoryStream();
            using var outputStream = new MemoryStream();
            int numRows = 10;

            // Write test data to the input stream.
            Schema schema = new Schema.Builder()
                            .Field(b => b.Name("arg1").DataType(StringType.Default))
                            .Build();
            var arrowWriter =
                new ArrowStreamWriter(inputStream, schema, leaveOpen: false, ipcOptions);
            await arrowWriter.WriteRecordBatchAsync(
                new RecordBatch(
                    schema,
                    new[]
                    {
                        ToArrowArray(
                            Enumerable.Range(0, numRows)
                                .Select(i => i.ToString())
                                .ToArray())
                    },
                    numRows));

            inputStream.Seek(0, SeekOrigin.Begin);

            CommandExecutorStat stat = new CommandExecutor(sparkVersion).Execute(
                inputStream,
                outputStream,
                0,
                commandPayload);

            // Validate that all the data on the stream is read.
            Assert.Equal(inputStream.Length, inputStream.Position);
            Assert.Equal(numRows, stat.NumEntriesProcessed);

            // Validate the output stream.
            outputStream.Seek(0, SeekOrigin.Begin);
            int arrowLength = SerDe.ReadInt32(outputStream);

            Assert.Equal((int)SpecialLengths.START_ARROW_STREAM, arrowLength);
            var         arrowReader = new ArrowStreamReader(outputStream);
            RecordBatch outputBatch = await arrowReader.ReadNextRecordBatchAsync();

            Assert.Equal(numRows, outputBatch.Length);
            Assert.Single(outputBatch.Arrays);
            var array = (StringArray)outputBatch.Arrays.ElementAt(0);

            // Validate the single command.
            for (int i = 0; i < numRows; ++i)
            {
                Assert.Equal($"udf: {i}", array.GetString(i));
            }

            CheckEOS(outputStream, ipcOptions);

            // Validate all the data on the stream is read.
            Assert.Equal(outputStream.Length, outputStream.Position);
        }
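CheckEOS is a helper from the surrounding test class and is not shown. A plausible reconstruction (hypothetical, inferred from the raw reads in the next example) is that it consumes the Arrow end-of-stream marker, whose shape depends on the IPC format in use:

    // Hypothetical sketch, assuming IpcOptions exposes WriteLegacyIpcFormat:
    // the current IPC format precedes the terminating zero length with a
    // 0xFFFFFFFF continuation marker; the legacy format writes the zero alone.
    private static void CheckEOS(Stream stream, IpcOptions ipcOptions)
    {
        if (!ipcOptions.WriteLegacyIpcFormat)
        {
            Assert.Equal(-1, SerDe.ReadInt32(stream)); // continuation marker
        }
        Assert.Equal(0, SerDe.ReadInt32(stream));      // end-of-stream length
    }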
Example #24
        public async Task TestArrowSqlCommandExecutorWithMultiCommands()
        {
            var udfWrapper1 = new Sql.ArrowUdfWrapper<StringArray, StringArray>(
                (strings) => (StringArray)ToArrowArray(
                    Enumerable.Range(0, strings.Length)
                    .Select(i => $"udf: {strings.GetString(i)}")
                    .ToArray()));
            var udfWrapper2 = new Sql.ArrowUdfWrapper<Int32Array, Int32Array, Int32Array>(
                (arg1, arg2) => (Int32Array)ToArrowArray(
                    Enumerable.Range(0, arg1.Length)
                    .Select(i => arg1.Values[i] * arg2.Values[i])
                    .ToArray()));

            var command1 = new SqlCommand()
            {
                ArgOffsets          = new[] { 0 },
                NumChainedFunctions = 1,
                WorkerFunction      = new Sql.ArrowWorkerFunction(udfWrapper1.Execute),
                SerializerMode      = CommandSerDe.SerializedMode.Row,
                DeserializerMode    = CommandSerDe.SerializedMode.Row
            };

            var command2 = new SqlCommand()
            {
                ArgOffsets          = new[] { 1, 2 },
                NumChainedFunctions = 1,
                WorkerFunction      = new Sql.ArrowWorkerFunction(udfWrapper2.Execute),
                SerializerMode      = CommandSerDe.SerializedMode.Row,
                DeserializerMode    = CommandSerDe.SerializedMode.Row
            };

            var commandPayload = new Worker.CommandPayload()
            {
                EvalType = UdfUtils.PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                Commands = new[] { command1, command2 }
            };

            using (var inputStream = new MemoryStream())
                using (var outputStream = new MemoryStream())
                {
                    int numRows = 10;

                    // Write test data to the input stream.
                    Schema schema = new Schema.Builder()
                                    .Field(b => b.Name("arg1").DataType(StringType.Default))
                                    .Field(b => b.Name("arg2").DataType(Int32Type.Default))
                                    .Field(b => b.Name("arg3").DataType(Int32Type.Default))
                                    .Build();
                    var arrowWriter = new ArrowStreamWriter(inputStream, schema);
                    await arrowWriter.WriteRecordBatchAsync(
                        new RecordBatch(
                            schema,
                            new[]
                            {
                                ToArrowArray(
                                    Enumerable.Range(0, numRows)
                                        .Select(i => i.ToString())
                                        .ToArray()),
                                ToArrowArray(Enumerable.Range(0, numRows).ToArray()),
                                ToArrowArray(Enumerable.Range(0, numRows).ToArray()),
                            },
                            numRows));

                    inputStream.Seek(0, SeekOrigin.Begin);

                    CommandExecutorStat stat = new CommandExecutor().Execute(
                        inputStream,
                        outputStream,
                        0,
                        commandPayload);

                    // Validate all the data on the stream is read.
                    Assert.Equal(inputStream.Length, inputStream.Position);
                    Assert.Equal(numRows, stat.NumEntriesProcessed);

                    // Validate the output stream.
                    outputStream.Seek(0, SeekOrigin.Begin);
                    var arrowLength = SerDe.ReadInt32(outputStream);
                    Assert.Equal((int)SpecialLengths.START_ARROW_STREAM, arrowLength);
                    var         arrowReader = new ArrowStreamReader(outputStream);
                    RecordBatch outputBatch = await arrowReader.ReadNextRecordBatchAsync();

                    Assert.Equal(numRows, outputBatch.Length);
                    Assert.Equal(2, outputBatch.Arrays.Count());
                    var array1 = (StringArray)outputBatch.Arrays.ElementAt(0);
                    var array2 = (Int32Array)outputBatch.Arrays.ElementAt(1);
                    for (int i = 0; i < numRows; ++i)
                    {
                        Assert.Equal($"udf: {i}", array1.GetString(i));
                        Assert.Equal(i * i, array2.Values[i]);
                    }

                    int end = SerDe.ReadInt32(outputStream);
                    Assert.Equal(0, end);

                    // Validate all the data on the stream is read.
                    Assert.Equal(outputStream.Length, outputStream.Position);
                }
        }
Example #25
        public async Task TestArrowGroupedMapCommandExecutor()
        {
            StringArray ConvertStrings(StringArray strings)
            {
                return (StringArray)ToArrowArray(
                    Enumerable.Range(0, strings.Length)
                        .Select(i => $"udf: {strings.GetString(i)}")
                        .ToArray());
            }

            Int64Array ConvertInt64s(Int64Array int64s)
            {
                return (Int64Array)ToArrowArray(
                    Enumerable.Range(0, int64s.Length)
                        .Select(i => int64s.Values[i] + 100)
                        .ToArray());
            }

            Schema resultSchema = new Schema.Builder()
                                  .Field(b => b.Name("arg1").DataType(StringType.Default))
                                  .Field(b => b.Name("arg2").DataType(Int64Type.Default))
                                  .Build();

            var udfWrapper = new Sql.ArrowGroupedMapUdfWrapper(
                (batch) => new RecordBatch(
                    resultSchema,
                    new IArrowArray[]
                    {
                        ConvertStrings((StringArray)batch.Column(0)),
                        ConvertInt64s((Int64Array)batch.Column(1)),
                    },
                    batch.Length));

            var command = new SqlCommand()
            {
                ArgOffsets          = new[] { 0 },
                NumChainedFunctions = 1,
                WorkerFunction      = new Sql.ArrowGroupedMapWorkerFunction(udfWrapper.Execute),
                SerializerMode      = CommandSerDe.SerializedMode.Row,
                DeserializerMode    = CommandSerDe.SerializedMode.Row
            };

            var commandPayload = new Worker.CommandPayload()
            {
                EvalType = UdfUtils.PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
                Commands = new[] { command }
            };

            using (var inputStream = new MemoryStream())
                using (var outputStream = new MemoryStream())
                {
                    int numRows = 10;

                    // Write test data to the input stream.
                    Schema schema = new Schema.Builder()
                                    .Field(b => b.Name("arg1").DataType(StringType.Default))
                                    .Field(b => b.Name("arg2").DataType(Int64Type.Default))
                                    .Build();
                    var arrowWriter = new ArrowStreamWriter(inputStream, schema);
                    await arrowWriter.WriteRecordBatchAsync(
                        new RecordBatch(
                            schema,
                            new[]
                            {
                                ToArrowArray(
                                    Enumerable.Range(0, numRows)
                                        .Select(i => i.ToString())
                                        .ToArray()),
                                ToArrowArray(
                                    Enumerable.Range(0, numRows)
                                        .Select(i => (long)i)
                                        .ToArray())
                            },
                            numRows));

                    inputStream.Seek(0, SeekOrigin.Begin);

                    CommandExecutorStat stat = new CommandExecutor().Execute(
                        inputStream,
                        outputStream,
                        0,
                        commandPayload);

                    // Validate that all the data on the stream is read.
                    Assert.Equal(inputStream.Length, inputStream.Position);
                    Assert.Equal(numRows, stat.NumEntriesProcessed);

                    // Validate the output stream.
                    outputStream.Seek(0, SeekOrigin.Begin);
                    int arrowLength = SerDe.ReadInt32(outputStream);
                    Assert.Equal((int)SpecialLengths.START_ARROW_STREAM, arrowLength);
                    var         arrowReader = new ArrowStreamReader(outputStream);
                    RecordBatch outputBatch = await arrowReader.ReadNextRecordBatchAsync();

                    Assert.Equal(numRows, outputBatch.Length);
                    Assert.Equal(2, outputBatch.ColumnCount);

                    var stringArray = (StringArray)outputBatch.Column(0);
                    for (int i = 0; i < numRows; ++i)
                    {
                        Assert.Equal($"udf: {i}", stringArray.GetString(i));
                    }

                    var longArray = (Int64Array)outputBatch.Column(1);
                    for (int i = 0; i < numRows; ++i)
                    {
                        Assert.Equal(100 + i, longArray.Values[i]);
                    }

                    int end = SerDe.ReadInt32(outputStream);
                    Assert.Equal(0, end);

                    // Validate all the data on the stream is read.
                    Assert.Equal(outputStream.Length, outputStream.Position);
                }
        }
Example #26
        public void TestArrowSqlCommandExecutorWithEmptyInput()
        {
            var udfWrapper = new Sql.ArrowUdfWrapper<StringArray, StringArray>(
                (strings) => (StringArray)ToArrowArray(
                    Enumerable.Range(0, strings.Length)
                    .Select(i => $"udf: {strings.GetString(i)}")
                    .ToArray()));

            var command = new SqlCommand()
            {
                ArgOffsets          = new[] { 0 },
                NumChainedFunctions = 1,
                WorkerFunction      = new Sql.ArrowWorkerFunction(udfWrapper.Execute),
                SerializerMode      = CommandSerDe.SerializedMode.Row,
                DeserializerMode    = CommandSerDe.SerializedMode.Row
            };

            var commandPayload = new Worker.CommandPayload()
            {
                EvalType = UdfUtils.PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                Commands = new[] { command }
            };

            using (var inputStream = new MemoryStream())
                using (var outputStream = new MemoryStream())
                {
                    // Write test data to the input stream.
                    Schema schema = new Schema.Builder()
                                    .Field(b => b.Name("arg1").DataType(StringType.Default))
                                    .Build();
                    var arrowWriter = new ArrowStreamWriter(inputStream, schema);

                    // The .NET ArrowStreamWriter doesn't currently support writing just a
                    // schema with no batches - but Java does. We use Reflection to simulate
                    // the request Spark sends.
                    MethodInfo writeSchemaMethod = arrowWriter.GetType().GetMethod(
                        "WriteSchemaAsync",
                        BindingFlags.NonPublic | BindingFlags.Instance);

                    writeSchemaMethod.Invoke(
                        arrowWriter,
                        new object[] { schema, CancellationToken.None });
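                    // The Task returned by the reflective call is not awaited; against
                    // a MemoryStream the write completes synchronously, so the schema is
                    // fully written before the stream is rewound below.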

                    SerDe.Write(inputStream, 0);

                    inputStream.Seek(0, SeekOrigin.Begin);

                    CommandExecutorStat stat = new CommandExecutor().Execute(
                        inputStream,
                        outputStream,
                        0,
                        commandPayload);

                    // Validate that all the data on the stream is read.
                    Assert.Equal(inputStream.Length, inputStream.Position);
                    Assert.Equal(0, stat.NumEntriesProcessed);

                    // Validate the output stream.
                    outputStream.Seek(0, SeekOrigin.Begin);
                    int arrowLength = SerDe.ReadInt32(outputStream);
                    Assert.Equal((int)SpecialLengths.START_ARROW_STREAM, arrowLength);
                    var         arrowReader = new ArrowStreamReader(outputStream);
                    RecordBatch outputBatch = arrowReader.ReadNextRecordBatch();

                    Assert.Equal(1, outputBatch.Schema.Fields.Count);
                    Assert.IsType<StringType>(outputBatch.Schema.GetFieldByIndex(0).DataType);

                    Assert.Equal(0, outputBatch.Length);
                    Assert.Single(outputBatch.Arrays);

                    var array = (StringArray)outputBatch.Arrays.ElementAt(0);
                    Assert.Equal(0, array.Length);

                    int end = SerDe.ReadInt32(outputStream);
                    Assert.Equal(0, end);

                    // Validate all the data on the stream is read.
                    Assert.Equal(outputStream.Length, outputStream.Position);
                }
        }
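Depending on the Apache.Arrow version, the reflection workaround above may be unnecessary: as the WritesEmptyFile example earlier shows, WriteStart() followed by WriteEnd() produces a schema-only stream through the public API (sketch):

    var arrowWriter = new ArrowStreamWriter(inputStream, schema);
    arrowWriter.WriteStart(); // schema message only
    arrowWriter.WriteEnd();   // end-of-stream marker, no batches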