Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions packages/bigframes/bigframes/pandas/io/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def read_avro(
*,
engine: str = "auto",
) -> bigframes.dataframe.DataFrame:
_set_default_session_location_if_possible_gcs(path)
return global_session.with_default_session(
bigframes.session.Session.read_avro,
path,
Expand Down Expand Up @@ -145,6 +146,7 @@ def read_csv(
write_engine: constants.WriteEngineType = "default",
**kwargs,
) -> bigframes.dataframe.DataFrame:
_set_default_session_location_if_possible_gcs(filepath_or_buffer)
return global_session.with_default_session(
bigframes.session.Session.read_csv,
filepath_or_buffer=filepath_or_buffer,
Expand Down Expand Up @@ -177,6 +179,7 @@ def read_json(
write_engine: constants.WriteEngineType = "default",
**kwargs,
) -> bigframes.dataframe.DataFrame:
_set_default_session_location_if_possible_gcs(path_or_buf)
return global_session.with_default_session(
bigframes.session.Session.read_json,
path_or_buf=path_or_buf,
Expand Down Expand Up @@ -535,6 +538,7 @@ def read_orc(
engine: str = "auto",
write_engine: constants.WriteEngineType = "default",
) -> bigframes.dataframe.DataFrame:
_set_default_session_location_if_possible_gcs(path)
return global_session.with_default_session(
bigframes.session.Session.read_orc,
path,
Expand Down Expand Up @@ -610,6 +614,7 @@ def read_parquet(
engine: str = "auto",
write_engine: constants.WriteEngineType = "default",
) -> bigframes.dataframe.DataFrame:
_set_default_session_location_if_possible_gcs(path)
return global_session.with_default_session(
bigframes.session.Session.read_parquet,
path,
Expand Down Expand Up @@ -638,6 +643,7 @@ def read_gbq_function(
def from_glob_path(
path: str, *, connection: Optional[str] = None, name: Optional[str] = None
) -> bigframes.dataframe.DataFrame:
_set_default_session_location_if_possible_gcs(path)
return global_session.with_default_session(
bigframes.session.Session.from_glob_path,
path=path,
Expand Down Expand Up @@ -729,3 +735,46 @@ def _set_default_session_location_if_possible_deferred_query(create_query):
else:
table = bqclient.get_table(query)
config.options.bigquery.location = table.location


def _get_storage_client():
    """Return a storage client built from the global BigQuery options, or None.

    A ``ClientsProvider`` is created from the current
    ``config.options.bigquery`` settings and its ``storageclient`` attribute
    is returned. This is best-effort: any exception raised while reading the
    options or while constructing the provider/client results in ``None``.
    """
    from bigframes.session import clients

    try:
        bq_options = config.options.bigquery
        provider = clients.ClientsProvider(
            project=bq_options.project,
            location=bq_options.location,
            use_regional_endpoints=bq_options.use_regional_endpoints,
            credentials=bq_options.credentials,
            application_name=bq_options.application_name,
            bq_kms_key_name=bq_options.kms_key_name,
            client_endpoints_override=bq_options.client_endpoints_override,
            requests_transport_adapters=bq_options.requests_transport_adapters,
        )
        storage_client = provider.storageclient
    except Exception:
        # Deliberately swallow everything; a None result signals that no
        # client could be obtained.
        return None
    return storage_client


def _set_default_session_location_if_possible_gcs(path: Any):
if isinstance(path, str) and path.startswith("gs://"):
global _default_location_lock

with _default_location_lock:
Comment on lines +761 to +763
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The variable _default_location_lock is used but not defined at the module level, and the threading module is not imported in this file. This will result in a NameError at runtime when _set_default_session_location_if_possible_gcs is called. Please define the lock at the module level (e.g., _default_location_lock = threading.Lock()) and ensure threading is imported.

if (
config.options.bigquery._session_started
or config.options.bigquery.location
or config.options.bigquery.use_regional_endpoints
):
return

bucket_name = path[5:].split("/", 1)[0]
storage_client = _get_storage_client()
if storage_client:
try:
bucket = storage_client.get_bucket(bucket_name)
if bucket.location:
config.options.bigquery.location = bucket.location
except Exception:
pass

40 changes: 40 additions & 0 deletions packages/bigframes/tests/unit/pandas/io/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,43 @@ def test_read_gbq_colab_calls_set_location(
assert kwargs["pyformat_args"] == sample_pyformat_args
assert not kwargs["dry_run"]
assert isinstance(result, bigframes.dataframe.DataFrame)


@mock.patch("bigframes.pandas.io.api._get_storage_client")
@mock.patch("bigframes.core.global_session.with_default_session")
def test_read_csv_gcs_sets_location(mock_with_default_session, mock_get_storage_client):
    """read_csv on a gs:// path adopts the bucket location when none is set."""
    import bigframes._config as config

    # Storage client stub whose bucket lookup reports a fixed location.
    fake_bucket = mock.Mock(location="us-east1")
    fake_client = mock.Mock()
    fake_client.get_bucket.return_value = fake_bucket
    mock_get_storage_client.return_value = fake_client

    # Start from options with no location, no started session and no
    # regional-endpoint setting.
    bq_options = config.options.bigquery
    bq_options.location = None
    bq_options._session_started = False
    bq_options.use_regional_endpoints = None

    bf_io_api.read_csv("gs://test-bucket/file.csv")

    assert config.options.bigquery.location == "us-east1"


@mock.patch("bigframes.pandas.io.api._get_storage_client")
@mock.patch("bigframes.core.global_session.with_default_session")
def test_read_csv_gcs_doesnt_overwrite_set_location(
    mock_with_default_session, mock_get_storage_client
):
    """read_csv on a gs:// path keeps a location that was already configured.

    NOTE(review): the location is assigned as "eu" but asserted as "EU"; this
    presumably relies on the options setter normalizing the value's case --
    confirm against the BigQuery options implementation.
    """
    import bigframes._config as config

    # Storage client stub whose bucket reports a different location than the
    # one configured below.
    fake_bucket = mock.Mock(location="us-east1")
    fake_client = mock.Mock()
    fake_client.get_bucket.return_value = fake_bucket
    mock_get_storage_client.return_value = fake_client

    # A location is explicitly configured before the read.
    bq_options = config.options.bigquery
    bq_options.location = "eu"
    bq_options._session_started = False
    bq_options.use_regional_endpoints = None

    bf_io_api.read_csv("gs://test-bucket/file.csv")

    assert config.options.bigquery.location == "EU"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The assertion expects "EU", but the location was set to "eu" at line 161. Since _set_default_session_location_if_possible_gcs returns early when a location is already set, the value will remain "eu", causing this test to fail. Based on the previous test case (line 148), locations appear to be case-sensitive in these assertions.

Suggested change
assert config.options.bigquery.location == "EU"
assert config.options.bigquery.location == "eu"



Loading