close
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions bigquery/google/cloud/bigquery/_pandas_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,11 +380,23 @@ def _tabledata_list_page_to_arrow(page, column_names, arrow_types):
for column_index, arrow_type in enumerate(arrow_types):
arrays.append(pyarrow.array(page._columns[column_index], type=arrow_type))

return pyarrow.RecordBatch.from_arrays(arrays, column_names)
if isinstance(column_names, pyarrow.Schema):
return pyarrow.RecordBatch.from_arrays(arrays, schema=column_names)
return pyarrow.RecordBatch.from_arrays(arrays, names=column_names)


def download_arrow_tabledata_list(pages, schema):
"""Use tabledata.list to construct an iterable of RecordBatches."""
"""Use tabledata.list to construct an iterable of RecordBatches.

Args:
pages (Iterator[:class:`google.api_core.page_iterator.Page`]):
An iterator over the result pages.
schema (Sequence[google.cloud.bigquery.schema.SchemaField]):
A description of the fields in result pages.
Yields:
:class:`pyarrow.RecordBatch`
The next page of records as a ``pyarrow`` record batch.
"""
column_names = bq_to_arrow_schema(schema) or [field.name for field in schema]
arrow_types = [bq_to_arrow_data_type(field) for field in schema]

Expand Down
72 changes: 72 additions & 0 deletions bigquery/tests/unit/test__pandas_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import pytest
import pytz

from google import api_core
from google.cloud.bigquery import schema


Expand Down Expand Up @@ -905,3 +906,74 @@ def test_dataframe_to_parquet_compression_method(module_under_test):
call_args = fake_write_table.call_args
assert call_args is not None
assert call_args.kwargs.get("compression") == "ZSTD"


@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_download_arrow_tabledata_list_unknown_field_type(module_under_test):
    """Convert a page whose schema contains an unrecognized field type.

    The second schema field uses the made-up type name ``ALIEN_FLOAT_TYPE``.
    The test expects ``download_arrow_tabledata_list`` to still yield a record
    batch with both columns (``Int64Array`` and ``DoubleArray``) and expects
    that no recorded warning contains "please pass schema= explicitly".
    """
    fake_page = api_core.page_iterator.Page(
        parent=mock.Mock(),
        items=[{"page_data": "foo"}],
        item_to_value=api_core.page_iterator._item_to_value_identity,
    )
    # The conversion code reads column-oriented data from ``_columns``.
    fake_page._columns = [[1, 10, 100], [2.2, 22.22, 222.222]]
    pages = [fake_page]

    bq_schema = [
        schema.SchemaField("population_size", "INTEGER"),
        schema.SchemaField("alien_field", "ALIEN_FLOAT_TYPE"),
    ]

    results_gen = module_under_test.download_arrow_tabledata_list(pages, bq_schema)

    with warnings.catch_warnings(record=True) as warned:
        # Record every warning regardless of the ambient filter configuration;
        # otherwise a warning could be silently dropped and the assertion
        # below would pass vacuously.
        warnings.simplefilter("always")
        result = next(results_gen)

    # NOTE(review): presumably this message comes from pyarrow when a record
    # batch is built without an explicit schema - confirm against pyarrow.
    unwanted_warnings = [
        warning
        for warning in warned
        if "please pass schema= explicitly" in str(warning).lower()
    ]
    assert not unwanted_warnings

    assert len(result.columns) == 2
    col = result.columns[0]
    assert type(col) is pyarrow.lib.Int64Array
    # ``to_pylist()`` yields native Python values, so the comparison does not
    # depend on how pyarrow scalar objects implement ``==``.
    assert col.to_pylist() == [1, 10, 100]
    col = result.columns[1]
    assert type(col) is pyarrow.lib.DoubleArray
    assert col.to_pylist() == [2.2, 22.22, 222.222]


@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_download_arrow_tabledata_list_known_field_type(module_under_test):
    """Convert a page whose schema contains only recognized field types.

    Both schema fields use standard type names (``INTEGER`` and ``STRING``).
    The test expects ``download_arrow_tabledata_list`` to yield a record batch
    with an ``Int64Array`` and a ``StringArray`` column, and expects that no
    recorded warning contains "please pass schema= explicitly".
    """
    fake_page = api_core.page_iterator.Page(
        parent=mock.Mock(),
        items=[{"page_data": "foo"}],
        item_to_value=api_core.page_iterator._item_to_value_identity,
    )
    # The conversion code reads column-oriented data from ``_columns``.
    fake_page._columns = [[1, 10, 100], ["2.2", "22.22", "222.222"]]
    pages = [fake_page]

    bq_schema = [
        schema.SchemaField("population_size", "INTEGER"),
        schema.SchemaField("non_alien_field", "STRING"),
    ]

    results_gen = module_under_test.download_arrow_tabledata_list(pages, bq_schema)

    with warnings.catch_warnings(record=True) as warned:
        # Record every warning regardless of the ambient filter configuration;
        # otherwise a warning could be silently dropped and the assertion
        # below would pass vacuously.
        warnings.simplefilter("always")
        result = next(results_gen)

    # NOTE(review): presumably this message comes from pyarrow when a record
    # batch is built without an explicit schema - confirm against pyarrow.
    unwanted_warnings = [
        warning
        for warning in warned
        if "please pass schema= explicitly" in str(warning).lower()
    ]
    assert not unwanted_warnings

    assert len(result.columns) == 2
    col = result.columns[0]
    assert type(col) is pyarrow.lib.Int64Array
    # ``to_pylist()`` yields native Python values, so the comparison does not
    # depend on how pyarrow scalar objects implement ``==``.
    assert col.to_pylist() == [1, 10, 100]
    col = result.columns[1]
    assert type(col) is pyarrow.lib.StringArray
    assert col.to_pylist() == ["2.2", "22.22", "222.222"]