diff --git a/packages/bigframes/bigframes/pandas/io/api.py b/packages/bigframes/bigframes/pandas/io/api.py index b7ed1a65d922..1304cac616c4 100644 --- a/packages/bigframes/bigframes/pandas/io/api.py +++ b/packages/bigframes/bigframes/pandas/io/api.py @@ -99,6 +99,7 @@ def read_avro( *, engine: str = "auto", ) -> bigframes.dataframe.DataFrame: + _set_default_session_location_if_possible_gcs(path) return global_session.with_default_session( bigframes.session.Session.read_avro, path, @@ -145,6 +146,7 @@ def read_csv( write_engine: constants.WriteEngineType = "default", **kwargs, ) -> bigframes.dataframe.DataFrame: + _set_default_session_location_if_possible_gcs(filepath_or_buffer) return global_session.with_default_session( bigframes.session.Session.read_csv, filepath_or_buffer=filepath_or_buffer, @@ -177,6 +179,7 @@ def read_json( write_engine: constants.WriteEngineType = "default", **kwargs, ) -> bigframes.dataframe.DataFrame: + _set_default_session_location_if_possible_gcs(path_or_buf) return global_session.with_default_session( bigframes.session.Session.read_json, path_or_buf=path_or_buf, @@ -535,6 +538,7 @@ def read_orc( engine: str = "auto", write_engine: constants.WriteEngineType = "default", ) -> bigframes.dataframe.DataFrame: + _set_default_session_location_if_possible_gcs(path) return global_session.with_default_session( bigframes.session.Session.read_orc, path, @@ -610,6 +614,7 @@ def read_parquet( engine: str = "auto", write_engine: constants.WriteEngineType = "default", ) -> bigframes.dataframe.DataFrame: + _set_default_session_location_if_possible_gcs(path) return global_session.with_default_session( bigframes.session.Session.read_parquet, path, @@ -638,6 +643,7 @@ def read_gbq_function( def from_glob_path( path: str, *, connection: Optional[str] = None, name: Optional[str] = None ) -> bigframes.dataframe.DataFrame: + _set_default_session_location_if_possible_gcs(path) return global_session.with_default_session( bigframes.session.Session.from_glob_path, 
def _get_storage_client():
    """Build a Cloud Storage client from the global BigQuery options.

    Returns:
        The ``storageclient`` of a ``ClientsProvider`` configured from
        ``config.options.bigquery``, or ``None`` when the provider or the
        client could not be created.
    """
    import logging

    # Kept as a function-scope import, as in the original code.
    # NOTE(review): presumably this avoids an import cycle -- confirm.
    from bigframes.session import clients

    try:
        clients_provider = clients.ClientsProvider(
            project=config.options.bigquery.project,
            location=config.options.bigquery.location,
            use_regional_endpoints=config.options.bigquery.use_regional_endpoints,
            credentials=config.options.bigquery.credentials,
            application_name=config.options.bigquery.application_name,
            bq_kms_key_name=config.options.bigquery.kms_key_name,
            client_endpoints_override=config.options.bigquery.client_endpoints_override,
            requests_transport_adapters=config.options.bigquery.requests_transport_adapters,
        )
        return clients_provider.storageclient
    except Exception:
        # Location detection is best effort: a missing client must never
        # break the read call, but the reason is logged for debugging
        # instead of being dropped silently.
        logging.getLogger(__name__).debug(
            "Could not create a Cloud Storage client for location detection.",
            exc_info=True,
        )
        return None


def _set_default_session_location_if_possible_gcs(path: Any):
    """Default the BigQuery location to the location of a ``gs://`` bucket.

    Does nothing unless ``path`` is a ``str`` starting with ``gs://`` that
    names a bucket. The location is only written when no session has been
    started, no location is configured yet, and regional endpoints are not
    enabled. Any failure while looking up the bucket is logged at debug level
    and otherwise ignored, so callers are never interrupted.

    Args:
        path: The path (or buffer) given to a ``read_*`` function. Values
            that are not ``gs://`` strings are ignored.
    """
    import logging

    gcs_scheme = "gs://"
    if not isinstance(path, str) or not path.startswith(gcs_scheme):
        return

    # "gs://bucket/dir/file" -> "bucket"
    bucket_name = path[len(gcs_scheme):].partition("/")[0]
    if not bucket_name:
        # Without a bucket name there is nothing to look up; skip the client
        # creation and the request that could only fail.
        return

    # The check and the assignment happen under the same lock so concurrent
    # callers cannot both see "location unset" and then both write it.
    with _default_location_lock:
        bq_options = config.options.bigquery
        if (
            bq_options._session_started
            or bq_options.location
            or bq_options.use_regional_endpoints
        ):
            return

        storage_client = _get_storage_client()
        if not storage_client:
            return

        try:
            bucket_location = storage_client.get_bucket(bucket_name).location
            if bucket_location:
                bq_options.location = bucket_location
        except Exception:
            # Best effort: e.g. missing permission on the bucket or an
            # unreachable service should not fail the read itself.
            logging.getLogger(__name__).debug(
                "Could not determine the location of bucket %r.",
                bucket_name,
                exc_info=True,
            )
def _snapshot_bigquery_options(bq_options):
    """Capture the option values these tests overwrite."""
    return (
        bq_options._session_started,
        bq_options.location,
        bq_options.use_regional_endpoints,
    )


def _restore_bigquery_options(bq_options, snapshot):
    """Write back values captured by ``_snapshot_bigquery_options``."""
    session_started, location, use_regional_endpoints = snapshot
    # Clear the started flag first and restore it last.
    # NOTE(review): presumably the option setters reject changes once a
    # session has started -- confirm against the options class.
    bq_options._session_started = False
    bq_options.location = location
    bq_options.use_regional_endpoints = use_regional_endpoints
    bq_options._session_started = session_started


@mock.patch("bigframes.pandas.io.api._get_storage_client")
@mock.patch("bigframes.core.global_session.with_default_session")
def test_read_csv_gcs_sets_location(mock_with_default_session, mock_get_storage_client):
    """read_csv on a gs:// path adopts the bucket location when none is set."""
    import bigframes._config as config

    mock_bucket = mock.Mock()
    mock_bucket.location = "us-east1"
    mock_storage_client = mock.Mock()
    mock_storage_client.get_bucket.return_value = mock_bucket
    mock_get_storage_client.return_value = mock_storage_client

    bq_options = config.options.bigquery
    snapshot = _snapshot_bigquery_options(bq_options)
    try:
        # Reset the started flag before touching the other options.
        bq_options._session_started = False
        bq_options.location = None
        bq_options.use_regional_endpoints = None

        bf_io_api.read_csv("gs://test-bucket/file.csv")

        assert bq_options.location == "us-east1"
        # The bucket name must be parsed out of the gs:// URI.
        mock_storage_client.get_bucket.assert_called_once_with("test-bucket")
    finally:
        # The options object is process-wide; do not leak state to other tests.
        _restore_bigquery_options(bq_options, snapshot)


@mock.patch("bigframes.pandas.io.api._get_storage_client")
@mock.patch("bigframes.core.global_session.with_default_session")
def test_read_csv_gcs_doesnt_overwrite_set_location(mock_with_default_session, mock_get_storage_client):
    """read_csv on a gs:// path keeps a location the user already chose."""
    import bigframes._config as config

    mock_bucket = mock.Mock()
    mock_bucket.location = "us-east1"
    mock_storage_client = mock.Mock()
    mock_storage_client.get_bucket.return_value = mock_bucket
    mock_get_storage_client.return_value = mock_storage_client

    bq_options = config.options.bigquery
    snapshot = _snapshot_bigquery_options(bq_options)
    try:
        # Reset the started flag before touching the other options.
        bq_options._session_started = False
        bq_options.location = "eu"
        bq_options.use_regional_endpoints = None

        bf_io_api.read_csv("gs://test-bucket/file.csv")

        # NOTE(review): relies on the location setter normalizing "eu" to
        # "EU" -- confirm against the options class.
        assert bq_options.location == "EU"
        # With a location already configured, no bucket lookup should happen.
        mock_get_storage_client.assert_not_called()
    finally:
        # The options object is process-wide; do not leak state to other tests.
        _restore_bigquery_options(bq_options, snapshot)