From a1f353e795ee34ad6fc708e6d1a2b69200fac63e Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Tue, 28 Apr 2026 22:32:04 +0000 Subject: [PATCH 1/5] feat(bigframes): update ai.if_() params to match the SQL version --- .../bigframes/bigquery/_operations/ai.py | 28 ++++++++++++++----- .../ibis_compiler/scalar_op_registry.py | 3 ++ .../compile/sqlglot/expressions/ai_ops.py | 6 ++++ .../bigframes/bigframes/operations/ai_ops.py | 3 ++ .../tests/system/small/bigquery/test_ai.py | 6 +++- .../test_ai_ops/test_ai_if/None/out.sql | 6 +++- .../out.sql | 4 ++- .../test_ai_if_with_endpoint/out.sql | 6 ++++ .../sqlglot/expressions/test_ai_ops.py | 18 ++++++++++++ .../ibis/expr/operations/ai_ops.py | 3 ++ 10 files changed, 73 insertions(+), 10 deletions(-) create mode 100644 packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if_with_endpoint/out.sql diff --git a/packages/bigframes/bigframes/bigquery/_operations/ai.py b/packages/bigframes/bigframes/bigquery/_operations/ai.py index 025dbb6aae9d..ecc1e14fb06b 100644 --- a/packages/bigframes/bigframes/bigquery/_operations/ai.py +++ b/packages/bigframes/bigframes/bigquery/_operations/ai.py @@ -817,6 +817,11 @@ def if_( prompt: PROMPT_TYPE, *, connection_id: str | None = None, + endpoint: str | None = None, + optimization_mode: Literal[ + "minimize_cost", "maximize_performance" + ] = "minimize_cost", + max_error_ratio: float = 1.0, ) -> series.Series: """ Evaluates the prompt to True or False. Compared to `ai.generate_bool()`, this function @@ -838,13 +843,6 @@ def if_( 1 Illinois dtype: string - .. note:: - - This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the - Service Specific Terms(https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is" - and might have limited support. For more information, see the launch stage descriptions - (https://cloud.google.com/products#product-launch-stages). - Args: prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]): A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series @@ -852,6 +850,19 @@ def if_( connection_id (str, optional): Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`. If not provided, the query uses your end-user credential. + endpoint (str, optional): + Specifies the Vertex AI endpoint to use for the model. For example `"gemini-2.5-flash"`. You can specify any + generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and + uses the full endpoint of the model. If you don't specify an ENDPOINT value, BigQuery ML dynamically chooses a model based on your query to have the + best cost to quality tradeoff for the task. + optimization_mode (Literal["minimize_cost", "maximize_performance"]): + Specifies the optimization strategy to use. Supported values are: + * "minimize_cost" (default): uses a local, distilled model to process the majority of rows, reducing latency and cost. + * "maximize_performance": always uses the remote LLM for inference. + max_error_ratio (float): + A float value between 0.0 and 1.0 that contains the maximum acceptable ratio of row-level inference failures to + rows processed on this function. If this value is exceeded, then the query fails. The default value is 1.0. + This argument isn't supported when `optimization_mode` is set to "minimize_cost". Returns: bigframes.series.Series: A new series of bools. @@ -863,6 +874,9 @@ def if_( operator = ai_ops.AIIf( prompt_context=tuple(prompt_context), connection_id=connection_id, + endpoint=endpoint, + optimization_mode=optimization_mode, + max_error_ratio=max_error_ratio, ) return series_list[0]._apply_nary_op(operator, series_list[1:]) diff --git a/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py b/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py index d56b8ec4387e..732d2ebfac05 100644 --- a/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py +++ b/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py @@ -1983,6 +1983,9 @@ def ai_if(*values: ibis_types.Value, op: ops.AIIf) -> ibis_types.StructValue: return ai_ops.AIIf( _construct_prompt(values, op.prompt_context), # type: ignore op.connection_id, # type: ignore + op.endpoint, # type: ignore + op.optimization_mode.upper() if op.optimization_mode is not None else None, # type: ignore + op.max_error_ratio, # type: ignore ).to_expr() diff --git a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py index ddf1df64c015..2860c9d50d20 100644 --- a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py +++ b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py @@ -139,6 +139,12 @@ def _construct_named_args(op: ops.ScalarOp) -> list[sge.Kwarg]: expression=sge.JSON(this=sge.Literal.string(value)), ) ) + elif field == "optimization_mode": + args.append( + sge.Kwarg(this=field, expression=sge.Literal.string(value.upper())) + ) + elif field == "max_error_ratio": + args.append(sge.Kwarg(this=field, expression=sge.Literal.number(value))) elif field == "request_type": args.append( sge.Kwarg(this=field, expression=sge.Literal.string(value.upper())) diff --git a/packages/bigframes/bigframes/operations/ai_ops.py b/packages/bigframes/bigframes/operations/ai_ops.py index 37da57540a0e..fa471ecaf9a2 100644 --- a/packages/bigframes/bigframes/operations/ai_ops.py +++ b/packages/bigframes/bigframes/operations/ai_ops.py @@ -146,6 +146,9 @@ class AIIf(base_ops.NaryOp): prompt_context: Tuple[str | None, ...] connection_id: str | None + endpoint: str | None = None + optimization_mode: str | None = None + max_error_ratio: float | None = None def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: return dtypes.BOOL_DTYPE diff --git a/packages/bigframes/tests/system/small/bigquery/test_ai.py b/packages/bigframes/tests/system/small/bigquery/test_ai.py index 3bf1f1276ca4..a7d3fc30cdce 100644 --- a/packages/bigframes/tests/system/small/bigquery/test_ai.py +++ b/packages/bigframes/tests/system/small/bigquery/test_ai.py @@ -323,7 +323,11 @@ def test_ai_if(session): s2 = bpd.Series(["fruit", "tree"], session=session) prompt = (s1, " is a ", s2) - result = bbq.ai.if_(prompt) + result = bbq.ai.if_( + prompt, + optimization_mode="maximize_quality", + max_error_ratio=0.5, + ) assert _contains_no_nulls(result) assert result.dtype == dtypes.BOOL_DTYPE diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/None/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/None/out.sql index bae091982ea8..7696a12c5893 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/None/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/None/out.sql @@ -1,3 +1,7 @@ SELECT - AI.IF(prompt => (`string_col`, ' is the same as ', `string_col`)) AS `result` + AI.IF( + prompt => (`string_col`, ' is the same as ', `string_col`), + optimization_mode => 'MINIMIZE_COST', + max_error_ratio => 0.5 + ) AS `result` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/bigframes-dev.us.bigframes-default-connection/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/bigframes-dev.us.bigframes-default-connection/out.sql index 698523d2e0b4..dc8707487b54 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/bigframes-dev.us.bigframes-default-connection/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/bigframes-dev.us.bigframes-default-connection/out.sql @@ -1,6 +1,8 @@ SELECT AI.IF( prompt => (`string_col`, ' is the same as ', `string_col`), - connection_id => 'bigframes-dev.us.bigframes-default-connection' + connection_id => 'bigframes-dev.us.bigframes-default-connection', + optimization_mode => 'MINIMIZE_COST', + max_error_ratio => 0.5 ) AS `result` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if_with_endpoint/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if_with_endpoint/out.sql new file mode 100644 index 000000000000..5074584bd72d --- /dev/null +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if_with_endpoint/out.sql @@ -0,0 +1,6 @@ +SELECT + AI.IF( + prompt => (`string_col`, ' is the same as ', `string_col`), + endpoint => 'gemini-2.5-flash' + ) AS `result` +FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py index 05d3d64eddaa..c6dacee3fe6a 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py @@ -358,6 +358,24 @@ def test_ai_if(scalar_types_df: dataframe.DataFrame, snapshot, connection_id): op = ops.AIIf( prompt_context=(None, " is the same as ", None), connection_id=connection_id, + optimization_mode="minimize_cost", + max_error_ratio=0.5, + ) + + sql = utils._apply_ops_to_sql( + scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] + ) + + snapshot.assert_match(sql, "out.sql") + + +def test_ai_if_with_endpoint(scalar_types_df: dataframe.DataFrame, snapshot): + col_name = "string_col" + + op = ops.AIIf( + prompt_context=(None, " is the same as ", None), + connection_id=None, + endpoint="gemini-2.5-flash", ) sql = utils._apply_ops_to_sql( diff --git a/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py b/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py index 150d7bccb32d..51cb3d415903 100644 --- a/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py +++ b/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py @@ -138,6 +138,9 @@ class AIIf(Value): prompt: Value connection_id: Optional[Value[dt.String]] + endpoint: Optional[Value[dt.String]] = None + optimization_mode: Optional[Value[dt.String]] = None + max_error_ratio: Optional[Value[dt.Float64]] = None shape = rlz.shape_like("prompt") From f1ac91660d94b7d41452c73f7956048ee9a6ac31 Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Tue, 28 Apr 2026 22:35:32 +0000 Subject: [PATCH 2/5] fix param type --- packages/bigframes/bigframes/bigquery/_operations/ai.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/packages/bigframes/bigframes/bigquery/_operations/ai.py b/packages/bigframes/bigframes/bigquery/_operations/ai.py index ecc1e14fb06b..2a684bef43cb 100644 --- a/packages/bigframes/bigframes/bigquery/_operations/ai.py +++ b/packages/bigframes/bigframes/bigquery/_operations/ai.py @@ -818,9 +818,7 @@ def if_( *, connection_id: str | None = None, endpoint: str | None = None, - optimization_mode: Literal[ - "minimize_cost", "maximize_performance" - ] = "minimize_cost", + optimization_mode: Literal["minimize_cost", "maximize_quality"] = "minimize_cost", max_error_ratio: float = 1.0, ) -> series.Series: """ @@ -855,10 +853,10 @@ def if_( generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and uses the full endpoint of the model. If you don't specify an ENDPOINT value, BigQuery ML dynamically chooses a model based on your query to have the best cost to quality tradeoff for the task. - optimization_mode (Literal["minimize_cost", "maximize_performance"]): + optimization_mode (Literal["minimize_cost", "maximize_quality"]): Specifies the optimization strategy to use. Supported values are: * "minimize_cost" (default): uses a local, distilled model to process the majority of rows, reducing latency and cost. - * "maximize_performance": always uses the remote LLM for inference. + * "maximize_quality": always uses the remote LLM for inference. max_error_ratio (float): A float value between 0.0 and 1.0 that contains the maximum acceptable ratio of row-level inference failures to rows processed on this function. If this value is exceeded, then the query fails. The default value is 1.0. From e22c4d5ea0652a99b5c6413a09cd1f321163b115 Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Tue, 28 Apr 2026 15:48:42 -0700 Subject: [PATCH 3/5] Update packages/bigframes/bigframes/bigquery/_operations/ai.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- packages/bigframes/bigframes/bigquery/_operations/ai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/bigframes/bigframes/bigquery/_operations/ai.py b/packages/bigframes/bigframes/bigquery/_operations/ai.py index 2a684bef43cb..db148ef8460d 100644 --- a/packages/bigframes/bigframes/bigquery/_operations/ai.py +++ b/packages/bigframes/bigframes/bigquery/_operations/ai.py @@ -821,7 +821,7 @@ def if_( optimization_mode: Literal["minimize_cost", "maximize_quality"] = "minimize_cost", max_error_ratio: float = 1.0, ) -> series.Series: - """ + max_error_ratio: float | None = None, Evaluates the prompt to True or False. Compared to `ai.generate_bool()`, this function provides optimization such that not all rows are evaluated with the LLM. From 08a8ffa039aff48d6e7db335a793ff08d2e3b379 Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Tue, 28 Apr 2026 15:49:18 -0700 Subject: [PATCH 4/5] Update packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../bigframes/core/compile/sqlglot/expressions/ai_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py index 2860c9d50d20..a906e65144b0 100644 --- a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py +++ b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py @@ -139,11 +139,11 @@ def _construct_named_args(op: ops.ScalarOp) -> list[sge.Kwarg]: expression=sge.JSON(this=sge.Literal.string(value)), ) ) - elif field == "optimization_mode": + elif field == "optimization_mode" and value is not None: args.append( sge.Kwarg(this=field, expression=sge.Literal.string(value.upper())) ) - elif field == "max_error_ratio": + elif field == "max_error_ratio" and value is not None: args.append(sge.Kwarg(this=field, expression=sge.Literal.number(value))) elif field == "request_type": args.append( From e3593c2059ed26e951d2e535de00430edec28a41 Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Tue, 28 Apr 2026 22:51:40 +0000 Subject: [PATCH 5/5] fix AI commits --- packages/bigframes/bigframes/bigquery/_operations/ai.py | 4 ++-- .../bigframes/core/compile/sqlglot/expressions/ai_ops.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/bigframes/bigframes/bigquery/_operations/ai.py b/packages/bigframes/bigframes/bigquery/_operations/ai.py index db148ef8460d..3911c6a913a6 100644 --- a/packages/bigframes/bigframes/bigquery/_operations/ai.py +++ b/packages/bigframes/bigframes/bigquery/_operations/ai.py @@ -819,9 +819,9 @@ def if_( connection_id: str | None = None, endpoint: str | None = None, optimization_mode: Literal["minimize_cost", "maximize_quality"] = "minimize_cost", - max_error_ratio: float = 1.0, -) -> series.Series: max_error_ratio: float | None = None, +) -> series.Series: + """ Evaluates the prompt to True or False. Compared to `ai.generate_bool()`, this function provides optimization such that not all rows are evaluated with the LLM. diff --git a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py index a906e65144b0..2860c9d50d20 100644 --- a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py +++ b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py @@ -139,11 +139,11 @@ def _construct_named_args(op: ops.ScalarOp) -> list[sge.Kwarg]: expression=sge.JSON(this=sge.Literal.string(value)), ) ) - elif field == "optimization_mode" and value is not None: + elif field == "optimization_mode": args.append( sge.Kwarg(this=field, expression=sge.Literal.string(value.upper())) ) - elif field == "max_error_ratio" and value is not None: + elif field == "max_error_ratio": args.append(sge.Kwarg(this=field, expression=sge.Literal.number(value))) elif field == "request_type": args.append(