Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions packages/bigframes/bigframes/bigquery/_operations/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,9 @@ def if_(
prompt: PROMPT_TYPE,
*,
connection_id: str | None = None,
endpoint: str | None = None,
optimization_mode: Literal["minimize_cost", "maximize_quality"] = "minimize_cost",
max_error_ratio: float | None = None,
) -> series.Series:
"""
Evaluates the prompt to True or False. Compared to `ai.generate_bool()`, this function
Expand All @@ -838,20 +841,26 @@ def if_(
1 Illinois
dtype: string

.. note::

This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the
Service Specific Terms (https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is"
and might have limited support. For more information, see the launch stage descriptions
(https://cloud.google.com/products#product-launch-stages).

Args:
prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]):
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
or pandas Series.
connection_id (str, optional):
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
If not provided, the query uses your end-user credential.
endpoint (str, optional):
Specifies the Vertex AI endpoint to use for the model. For example `"gemini-2.5-flash"`. You can specify any
generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and
uses the full endpoint of the model. If you don't specify an ENDPOINT value, BigQuery ML dynamically chooses a model based on your query to have the
best cost to quality tradeoff for the task.
optimization_mode (Literal["minimize_cost", "maximize_quality"]):
Specifies the optimization strategy to use. Supported values are:
* "minimize_cost" (default): uses a local, distilled model to process the majority of rows, reducing latency and cost.
* "maximize_quality": always uses the remote LLM for inference.
max_error_ratio (float):
A float value between 0.0 and 1.0 that contains the maximum acceptable ratio of row-level inference failures to
rows processed by this function. If this value is exceeded, then the query fails. The default value is 1.0.
This argument isn't supported when `optimization_mode` is set to "minimize_cost".

Returns:
bigframes.series.Series: A new series of bools.
Expand All @@ -863,6 +872,9 @@ def if_(
operator = ai_ops.AIIf(
prompt_context=tuple(prompt_context),
connection_id=connection_id,
endpoint=endpoint,
optimization_mode=optimization_mode,
max_error_ratio=max_error_ratio,
)

return series_list[0]._apply_nary_op(operator, series_list[1:])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1983,6 +1983,9 @@ def ai_if(*values: ibis_types.Value, op: ops.AIIf) -> ibis_types.StructValue:
return ai_ops.AIIf(
_construct_prompt(values, op.prompt_context), # type: ignore
op.connection_id, # type: ignore
op.endpoint, # type: ignore
op.optimization_mode.upper() if op.optimization_mode is not None else None, # type: ignore
op.max_error_ratio, # type: ignore
).to_expr()


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,12 @@ def _construct_named_args(op: ops.ScalarOp) -> list[sge.Kwarg]:
expression=sge.JSON(this=sge.Literal.string(value)),
)
)
elif field == "optimization_mode":
args.append(
sge.Kwarg(this=field, expression=sge.Literal.string(value.upper()))
)
elif field == "max_error_ratio":
args.append(sge.Kwarg(this=field, expression=sge.Literal.number(value)))
elif field == "request_type":
args.append(
sge.Kwarg(this=field, expression=sge.Literal.string(value.upper()))
Expand Down
3 changes: 3 additions & 0 deletions packages/bigframes/bigframes/operations/ai_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ class AIIf(base_ops.NaryOp):

prompt_context: Tuple[str | None, ...]
connection_id: str | None
endpoint: str | None = None
optimization_mode: str | None = None
max_error_ratio: float | None = None

def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
return dtypes.BOOL_DTYPE
Expand Down
6 changes: 5 additions & 1 deletion packages/bigframes/tests/system/small/bigquery/test_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,11 @@ def test_ai_if(session):
s2 = bpd.Series(["fruit", "tree"], session=session)
prompt = (s1, " is a ", s2)

result = bbq.ai.if_(prompt)
result = bbq.ai.if_(
prompt,
optimization_mode="maximize_quality",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The value "maximize_quality" is not a valid option for optimization_mode. Based on the Literal definition in the function signature and the BigQuery ML documentation, this should be "maximize_performance".

Suggested change
optimization_mode="maximize_quality",
optimization_mode="maximize_performance",

max_error_ratio=0.5,
)

assert _contains_no_nulls(result)
assert result.dtype == dtypes.BOOL_DTYPE
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
SELECT
AI.IF(prompt => (`string_col`, ' is the same as ', `string_col`)) AS `result`
AI.IF(
prompt => (`string_col`, ' is the same as ', `string_col`),
optimization_mode => 'MINIMIZE_COST',
max_error_ratio => 0.5
) AS `result`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
SELECT
AI.IF(
prompt => (`string_col`, ' is the same as ', `string_col`),
connection_id => 'bigframes-dev.us.bigframes-default-connection'
connection_id => 'bigframes-dev.us.bigframes-default-connection',
optimization_mode => 'MINIMIZE_COST',
max_error_ratio => 0.5
) AS `result`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
SELECT
AI.IF(
prompt => (`string_col`, ' is the same as ', `string_col`),
endpoint => 'gemini-2.5-flash'
) AS `result`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,24 @@ def test_ai_if(scalar_types_df: dataframe.DataFrame, snapshot, connection_id):
op = ops.AIIf(
prompt_context=(None, " is the same as ", None),
connection_id=connection_id,
optimization_mode="minimize_cost",
max_error_ratio=0.5,
)

sql = utils._apply_ops_to_sql(
scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
)

snapshot.assert_match(sql, "out.sql")


def test_ai_if_with_endpoint(scalar_types_df: dataframe.DataFrame, snapshot):
col_name = "string_col"

op = ops.AIIf(
prompt_context=(None, " is the same as ", None),
connection_id=None,
endpoint="gemini-2.5-flash",
)

sql = utils._apply_ops_to_sql(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ class AIIf(Value):

prompt: Value
connection_id: Optional[Value[dt.String]]
endpoint: Optional[Value[dt.String]] = None
optimization_mode: Optional[Value[dt.String]] = None
max_error_ratio: Optional[Value[dt.Float64]] = None

shape = rlz.shape_like("prompt")

Expand Down
Loading