Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 78 additions & 38 deletions sentry_sdk/integrations/asyncpg.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
from __future__ import annotations

import contextlib
import re
from typing import Any, TypeVar, Callable, Awaitable, Iterator
from typing import Any, Awaitable, Callable, Iterator, TypeVar, Union

import sentry_sdk
from sentry_sdk.consts import OP, SPANDATA
from sentry_sdk.integrations import _check_minimum_version, Integration, DidNotEnable
from sentry_sdk.integrations import DidNotEnable, Integration, _check_minimum_version
from sentry_sdk.traces import StreamedSpan
from sentry_sdk.tracing import Span
from sentry_sdk.tracing_utils import add_query_source, record_sql_queries
from sentry_sdk.tracing_utils import (
add_query_source,
has_span_streaming_enabled,
record_sql_queries_supporting_streaming,
)
from sentry_sdk.utils import (
capture_internal_exceptions,
ensure_integration_enabled,
parse_version,
capture_internal_exceptions,
)

try:
Expand Down Expand Up @@ -62,7 +68,8 @@ def _normalize_query(query: str) -> str:

def _wrap_execute(f: "Callable[..., Awaitable[T]]") -> "Callable[..., Awaitable[T]]":
async def _inner(*args: "Any", **kwargs: "Any") -> "T":
if sentry_sdk.get_client().get_integration(AsyncPGIntegration) is None:
client = sentry_sdk.get_client()
if client.get_integration(AsyncPGIntegration) is None:
return await f(*args, **kwargs)

# Avoid recording calls to _execute twice.
Expand All @@ -73,7 +80,7 @@ async def _inner(*args: "Any", **kwargs: "Any") -> "T":
return await f(*args, **kwargs)

query = _normalize_query(args[1])
with record_sql_queries(
with record_sql_queries_supporting_streaming(
cursor=None,
query=query,
params_list=None,
Expand All @@ -82,9 +89,13 @@ async def _inner(*args: "Any", **kwargs: "Any") -> "T":
span_origin=AsyncPGIntegration.origin,
) as span:
res = await f(*args, **kwargs)
if isinstance(span, StreamedSpan):
with capture_internal_exceptions():
add_query_source(span)

with capture_internal_exceptions():
add_query_source(span)
Comment thread
ericapisani marked this conversation as resolved.
if not isinstance(span, StreamedSpan):
Comment thread
ericapisani marked this conversation as resolved.
with capture_internal_exceptions():
add_query_source(span)

return res

Expand All @@ -101,15 +112,16 @@ def _record(
params_list: "tuple[Any, ...] | None",
*,
executemany: bool = False,
) -> "Iterator[Span]":
integration = sentry_sdk.get_client().get_integration(AsyncPGIntegration)
) -> "Iterator[Union[Span, StreamedSpan]]":
client = sentry_sdk.get_client()
integration = client.get_integration(AsyncPGIntegration)
if integration is not None and not integration._record_params:
params_list = None

param_style = "pyformat" if params_list else None

query = _normalize_query(query)
with record_sql_queries(
with record_sql_queries_supporting_streaming(
Comment thread
ericapisani marked this conversation as resolved.
cursor=cursor,
query=query,
params_list=params_list,
Comment thread
ericapisani marked this conversation as resolved.
Expand Down Expand Up @@ -152,7 +164,6 @@ def _inner(*args: "Any", **kwargs: "Any") -> "T": # noqa: N807
) as span:
_set_db_data(span, args[0])
res = f(*args, **kwargs)
span.set_data("db.cursor", res)

return res

Expand All @@ -163,56 +174,85 @@ def _wrap_connect_addr(
f: "Callable[..., Awaitable[T]]",
) -> "Callable[..., Awaitable[T]]":
async def _inner(*args: "Any", **kwargs: "Any") -> "T":
if sentry_sdk.get_client().get_integration(AsyncPGIntegration) is None:
client = sentry_sdk.get_client()
if client.get_integration(AsyncPGIntegration) is None:
return await f(*args, **kwargs)

user = kwargs["params"].user
database = kwargs["params"].database

with sentry_sdk.start_span(
op=OP.DB,
name="connect",
origin=AsyncPGIntegration.origin,
) as span:
span.set_data(SPANDATA.DB_SYSTEM, "postgresql")
addr = kwargs.get("addr")
addr = kwargs.get("addr")

if has_span_streaming_enabled(client.options):
span_attributes = {
"sentry.op": OP.DB,
"sentry.origin": AsyncPGIntegration.origin,
SPANDATA.DB_SYSTEM: "postgresql",
SPANDATA.DB_USER: user,
SPANDATA.DB_NAME: database,
SPANDATA.DB_DRIVER_NAME: "asyncpg",
}
if addr:
try:
span.set_data(SPANDATA.SERVER_ADDRESS, addr[0])
span.set_data(SPANDATA.SERVER_PORT, addr[1])
span_attributes[SPANDATA.SERVER_ADDRESS] = addr[0]
span_attributes[SPANDATA.SERVER_PORT] = addr[1]
except IndexError:
pass
span.set_data(SPANDATA.DB_NAME, database)
span.set_data(SPANDATA.DB_USER, user)
span.set_data(SPANDATA.DB_DRIVER_NAME, "asyncpg")

with capture_internal_exceptions():
sentry_sdk.add_breadcrumb(
message="connect", category="query", data=span._data
)
res = await f(*args, **kwargs)
with sentry_sdk.traces.start_span(
name="connect", attributes=span_attributes
) as span:
with capture_internal_exceptions():
sentry_sdk.add_breadcrumb(
message="connect", category="query", data=span_attributes
)
res = await f(*args, **kwargs)

else:
with sentry_sdk.start_span(
op=OP.DB,
name="connect",
origin=AsyncPGIntegration.origin,
) as span:
span.set_data(SPANDATA.DB_SYSTEM, "postgresql")
if addr:
try:
span.set_data(SPANDATA.SERVER_ADDRESS, addr[0])
span.set_data(SPANDATA.SERVER_PORT, addr[1])
except IndexError:
pass
span.set_data(SPANDATA.DB_NAME, database)
span.set_data(SPANDATA.DB_USER, user)
span.set_data(SPANDATA.DB_DRIVER_NAME, "asyncpg")

with capture_internal_exceptions():
sentry_sdk.add_breadcrumb(
message="connect", category="query", data=span._data
)
res = await f(*args, **kwargs)

return res

return _inner


def _set_db_data(span: "Span", conn: "Any") -> None:
span.set_data(SPANDATA.DB_SYSTEM, "postgresql")
span.set_data(SPANDATA.DB_DRIVER_NAME, "asyncpg")
def _set_db_data(span: "Union[Span, StreamedSpan]", conn: "Any") -> None:
set_value = span.set_attribute if isinstance(span, StreamedSpan) else span.set_data

set_value(SPANDATA.DB_SYSTEM, "postgresql")
set_value(SPANDATA.DB_DRIVER_NAME, "asyncpg")

addr = conn._addr
if addr:
try:
span.set_data(SPANDATA.SERVER_ADDRESS, addr[0])
span.set_data(SPANDATA.SERVER_PORT, addr[1])
set_value(SPANDATA.SERVER_ADDRESS, addr[0])
set_value(SPANDATA.SERVER_PORT, addr[1])
except IndexError:
pass

database = conn._params.database
if database:
span.set_data(SPANDATA.DB_NAME, database)
set_value(SPANDATA.DB_NAME, database)

user = conn._params.user
if user:
span.set_data(SPANDATA.DB_USER, user)
set_value(SPANDATA.DB_USER, user)
Loading
Loading