diff --git a/packages/bigframes/bigframes/bigquery/_operations/ai.py b/packages/bigframes/bigframes/bigquery/_operations/ai.py index 025dbb6aae9d..3911c6a913a6 100644 --- a/packages/bigframes/bigframes/bigquery/_operations/ai.py +++ b/packages/bigframes/bigframes/bigquery/_operations/ai.py @@ -817,6 +817,9 @@ def if_( prompt: PROMPT_TYPE, *, connection_id: str | None = None, + endpoint: str | None = None, + optimization_mode: Literal["minimize_cost", "maximize_quality"] = "minimize_cost", + max_error_ratio: float | None = None, ) -> series.Series: """ Evaluates the prompt to True or False. Compared to `ai.generate_bool()`, this function @@ -838,13 +841,6 @@ def if_( 1 Illinois dtype: string - .. note:: - - This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the - Service Specific Terms(https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is" - and might have limited support. For more information, see the launch stage descriptions - (https://cloud.google.com/products#product-launch-stages). - Args: prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]): A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series @@ -852,6 +848,19 @@ def if_( connection_id (str, optional): Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`. If not provided, the query uses your end-user credential. + endpoint (str, optional): + Specifies the Vertex AI endpoint to use for the model. For example `"gemini-2.5-flash"`. You can specify any + generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and + uses the full endpoint of the model. If you don't specify an ENDPOINT value, BigQuery ML dynamically chooses a model based on your query to have the + best cost to quality tradeoff for the task. 
+ optimization_mode (Literal["minimize_cost", "maximize_quality"], optional): + Specifies the optimization strategy to use. Supported values are: + * "minimize_cost" (default): uses a local, distilled model to process the majority of rows, reducing latency and cost. + * "maximize_quality": always uses the remote LLM for inference. + max_error_ratio (float, optional): + A float value between 0.0 and 1.0 that specifies the maximum acceptable ratio of row-level inference failures to + rows processed by this function. If this value is exceeded, then the query fails. The default value is 1.0. + This argument is only supported when `optimization_mode` is set to "maximize_quality"; it isn't supported with the default "minimize_cost" mode. Returns: bigframes.series.Series: A new series of bools. @@ -863,6 +872,9 @@ def if_( operator = ai_ops.AIIf( prompt_context=tuple(prompt_context), connection_id=connection_id, + endpoint=endpoint, + optimization_mode=optimization_mode, + max_error_ratio=max_error_ratio, ) return series_list[0]._apply_nary_op(operator, series_list[1:]) diff --git a/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py b/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py index d56b8ec4387e..732d2ebfac05 100644 --- a/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py +++ b/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py @@ -1983,6 +1983,9 @@ def ai_if(*values: ibis_types.Value, op: ops.AIIf) -> ibis_types.StructValue: return ai_ops.AIIf( _construct_prompt(values, op.prompt_context), # type: ignore op.connection_id, # type: ignore + op.endpoint, # type: ignore + op.optimization_mode.upper() if op.optimization_mode is not None else None, # type: ignore + op.max_error_ratio, # type: ignore ).to_expr() diff --git a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py index ddf1df64c015..2860c9d50d20 100644 --- 
a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py +++ b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py @@ -139,6 +139,12 @@ def _construct_named_args(op: ops.ScalarOp) -> list[sge.Kwarg]: expression=sge.JSON(this=sge.Literal.string(value)), ) ) + elif field == "optimization_mode": + args.append( + sge.Kwarg(this=field, expression=sge.Literal.string(value.upper())) + ) + elif field == "max_error_ratio": + args.append(sge.Kwarg(this=field, expression=sge.Literal.number(value))) elif field == "request_type": args.append( sge.Kwarg(this=field, expression=sge.Literal.string(value.upper())) diff --git a/packages/bigframes/bigframes/operations/ai_ops.py b/packages/bigframes/bigframes/operations/ai_ops.py index 37da57540a0e..fa471ecaf9a2 100644 --- a/packages/bigframes/bigframes/operations/ai_ops.py +++ b/packages/bigframes/bigframes/operations/ai_ops.py @@ -146,6 +146,9 @@ class AIIf(base_ops.NaryOp): prompt_context: Tuple[str | None, ...] connection_id: str | None + endpoint: str | None = None + optimization_mode: str | None = None + max_error_ratio: float | None = None def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: return dtypes.BOOL_DTYPE diff --git a/packages/bigframes/tests/system/small/bigquery/test_ai.py b/packages/bigframes/tests/system/small/bigquery/test_ai.py index 3bf1f1276ca4..a7d3fc30cdce 100644 --- a/packages/bigframes/tests/system/small/bigquery/test_ai.py +++ b/packages/bigframes/tests/system/small/bigquery/test_ai.py @@ -323,7 +323,11 @@ def test_ai_if(session): s2 = bpd.Series(["fruit", "tree"], session=session) prompt = (s1, " is a ", s2) - result = bbq.ai.if_(prompt) + result = bbq.ai.if_( + prompt, + optimization_mode="maximize_quality", + max_error_ratio=0.5, + ) assert _contains_no_nulls(result) assert result.dtype == dtypes.BOOL_DTYPE diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/None/out.sql 
b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/None/out.sql index bae091982ea8..7696a12c5893 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/None/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/None/out.sql @@ -1,3 +1,7 @@ SELECT - AI.IF(prompt => (`string_col`, ' is the same as ', `string_col`)) AS `result` + AI.IF( + prompt => (`string_col`, ' is the same as ', `string_col`), + optimization_mode => 'MINIMIZE_COST', + max_error_ratio => 0.5 + ) AS `result` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/bigframes-dev.us.bigframes-default-connection/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/bigframes-dev.us.bigframes-default-connection/out.sql index 698523d2e0b4..dc8707487b54 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/bigframes-dev.us.bigframes-default-connection/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/bigframes-dev.us.bigframes-default-connection/out.sql @@ -1,6 +1,8 @@ SELECT AI.IF( prompt => (`string_col`, ' is the same as ', `string_col`), - connection_id => 'bigframes-dev.us.bigframes-default-connection' + connection_id => 'bigframes-dev.us.bigframes-default-connection', + optimization_mode => 'MINIMIZE_COST', + max_error_ratio => 0.5 ) AS `result` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if_with_endpoint/out.sql 
b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if_with_endpoint/out.sql new file mode 100644 index 000000000000..5074584bd72d --- /dev/null +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if_with_endpoint/out.sql @@ -0,0 +1,6 @@ +SELECT + AI.IF( + prompt => (`string_col`, ' is the same as ', `string_col`), + endpoint => 'gemini-2.5-flash' + ) AS `result` +FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py index 05d3d64eddaa..c6dacee3fe6a 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py @@ -358,6 +358,24 @@ def test_ai_if(scalar_types_df: dataframe.DataFrame, snapshot, connection_id): op = ops.AIIf( prompt_context=(None, " is the same as ", None), connection_id=connection_id, + optimization_mode="minimize_cost", + max_error_ratio=0.5, + ) + + sql = utils._apply_ops_to_sql( + scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] + ) + + snapshot.assert_match(sql, "out.sql") + + +def test_ai_if_with_endpoint(scalar_types_df: dataframe.DataFrame, snapshot): + col_name = "string_col" + + op = ops.AIIf( + prompt_context=(None, " is the same as ", None), + connection_id=None, + endpoint="gemini-2.5-flash", ) sql = utils._apply_ops_to_sql( diff --git a/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py b/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py index 150d7bccb32d..51cb3d415903 100644 --- a/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py +++ b/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py @@ -138,6 
+138,9 @@ class AIIf(Value): prompt: Value connection_id: Optional[Value[dt.String]] + endpoint: Optional[Value[dt.String]] = None + optimization_mode: Optional[Value[dt.String]] = None + max_error_ratio: Optional[Value[dt.Float64]] = None shape = rlz.shape_like("prompt")