[0-9]+(?:\.[0-9]+)*) # release segment
+ (?P # pre-release
+ [-_\.]?
+ (?P(a|b|c|rc|alpha|beta|pre|preview))
+ [-_\.]?
+ (?P[0-9]+)?
+ )?
+ (?P # post release
+ (?:-(?P[0-9]+))
+ |
+ (?:
+ [-_\.]?
+ (?Ppost|rev|r)
+ [-_\.]?
+ (?P[0-9]+)?
+ )
+ )?
+ (?P # dev release
+ [-_\.]?
+ (?Pdev)
+ [-_\.]?
+ (?P[0-9]+)?
+ )?
)
- return value
+ (?:\+(?P[a-z0-9]+(?:[-_\.][a-z0-9]+)*))? # local version
+ """
+
+ pattern = re.compile(
+ r"^\s*" + VERSION_PATTERN + r"\s*$",
+ re.VERBOSE | re.IGNORECASE,
+ )
+
+ try:
+ release = pattern.match(version).groupdict()["release"] # type: ignore
+ release_tuple: "Tuple[int, ...]" = tuple(map(int, release.split(".")[:3]))
+ except (TypeError, ValueError, AttributeError):
+ return None
+ return release_tuple
-def _is_contextvars_broken():
- # type: () -> bool
+
+def _is_contextvars_broken() -> bool:
"""
Returns whether gevent/eventlet have patched the stdlib in a way where thread locals are now more "correct" than contextvars.
"""
try:
- import gevent # type: ignore
- from gevent.monkey import is_object_patched # type: ignore
+ import gevent
+ from gevent.monkey import is_object_patched
# Get the MAJOR and MINOR version numbers of Gevent
version_tuple = tuple(
@@ -976,9 +1372,18 @@ def _is_contextvars_broken():
pass
try:
+ import greenlet
from eventlet.patcher import is_monkey_patched # type: ignore
- if is_monkey_patched("thread"):
+ greenlet_version = parse_version(greenlet.__version__)
+
+ if greenlet_version is None:
+ logger.error(
+ "Internal error in Sentry SDK: Could not parse Greenlet version from greenlet.__version__."
+ )
+ return False
+
+ if is_monkey_patched("thread") and greenlet_version < (0, 5):
return True
except ImportError:
pass
@@ -986,29 +1391,35 @@ def _is_contextvars_broken():
return False
-def _make_threadlocal_contextvars(local):
- # type: (type) -> type
- class ContextVar(object):
+def _make_threadlocal_contextvars(local: type) -> type:
+ class ContextVar:
# Super-limited impl of ContextVar
- def __init__(self, name):
- # type: (str) -> None
+ def __init__(self, name: str, default: "Any" = None) -> None:
self._name = name
+ self._default = default
self._local = local()
+ self._original_local = local()
- def get(self, default):
- # type: (Any) -> Any
- return getattr(self._local, "value", default)
+ def get(self, default: "Any" = None) -> "Any":
+ return getattr(self._local, "value", default or self._default)
- def set(self, value):
- # type: (Any) -> None
+ def set(self, value: "Any") -> "Any":
+ token = str(random.getrandbits(64))
+ original_value = self.get()
+ setattr(self._original_local, token, original_value)
self._local.value = value
+ return token
+
+ def reset(self, token: "Any") -> None:
+ self._local.value = getattr(self._original_local, token)
+ # delete the original value (this way it works in Python 3.6+)
+ del self._original_local.__dict__[token]
return ContextVar
-def _get_contextvars():
- # type: () -> Tuple[bool, type]
+def _get_contextvars() -> "Tuple[bool, type]":
"""
Figure out the "right" contextvars installation to use. Returns a
`contextvars.ContextVar`-like class with a limited API.
@@ -1057,10 +1468,9 @@ def _get_contextvars():
"""
-def qualname_from_function(func):
- # type: (Callable[..., Any]) -> Optional[str]
+def qualname_from_function(func: "Callable[..., Any]") -> "Optional[str]":
"""Return the qualified name of func. Works with regular function, lambda, partial and partialmethod."""
- func_qualname = None # type: Optional[str]
+ func_qualname: "Optional[str]" = None
# Python 2
try:
@@ -1074,16 +1484,18 @@ def qualname_from_function(func):
prefix, suffix = "", ""
- if (
- _PARTIALMETHOD_AVAILABLE
- and hasattr(func, "_partialmethod")
- and isinstance(func._partialmethod, partialmethod) # type: ignore
- ):
- prefix, suffix = "partialmethod()"
- func = func._partialmethod.func # type: ignore
- elif isinstance(func, partial) and hasattr(func.func, "__name__"):
+ if isinstance(func, partial) and hasattr(func.func, "__name__"):
prefix, suffix = "partial()"
func = func.func
+ else:
+ # The _partialmethod attribute of methods wrapped with partialmethod() was renamed to __partialmethod__ in CPython 3.13:
+ # https://github.com/python/cpython/pull/16600
+ partial_method = getattr(func, "_partialmethod", None) or getattr(
+ func, "__partialmethod__", None
+ )
+ if isinstance(partial_method, partialmethod):
+ prefix, suffix = "partialmethod()"
+ func = partial_method.func
if hasattr(func, "__qualname__"):
func_qualname = func.__qualname__
@@ -1092,15 +1504,14 @@ def qualname_from_function(func):
# Python 3: methods, functions, classes
if func_qualname is not None:
- if hasattr(func, "__module__"):
+ if hasattr(func, "__module__") and isinstance(func.__module__, str):
func_qualname = func.__module__ + "." + func_qualname
func_qualname = prefix + func_qualname + suffix
return func_qualname
-def transaction_from_function(func):
- # type: (Callable[..., Any]) -> Optional[str]
+def transaction_from_function(func: "Callable[..., Any]") -> "Optional[str]":
return qualname_from_function(func)
@@ -1118,20 +1529,39 @@ class TimeoutThread(threading.Thread):
waiting_time and raises a custom ServerlessTimeout exception.
"""
- def __init__(self, waiting_time, configured_timeout):
- # type: (float, int) -> None
+ def __init__(
+ self,
+ waiting_time: float,
+ configured_timeout: int,
+ isolation_scope: "Optional[sentry_sdk.Scope]" = None,
+ current_scope: "Optional[sentry_sdk.Scope]" = None,
+ ) -> None:
threading.Thread.__init__(self)
self.waiting_time = waiting_time
self.configured_timeout = configured_timeout
+
+ self.isolation_scope = isolation_scope
+ self.current_scope = current_scope
+
self._stop_event = threading.Event()
- def stop(self):
- # type: () -> None
+ def stop(self) -> None:
self._stop_event.set()
- def run(self):
- # type: () -> None
+ def _capture_exception(self) -> "ExcInfo":
+ exc_info = sys.exc_info()
+
+ client = sentry_sdk.get_client()
+ event, hint = event_from_exception(
+ exc_info,
+ client_options=client.options,
+ mechanism={"type": "threading", "handled": False},
+ )
+ sentry_sdk.capture_event(event, hint=hint)
+ return exc_info
+
+ def run(self) -> None:
self._stop_event.wait(self.waiting_time)
if self._stop_event.is_set():
@@ -1144,6 +1574,18 @@ def run(self):
integer_configured_timeout = integer_configured_timeout + 1
# Raising Exception after timeout duration is reached
+ if self.isolation_scope is not None and self.current_scope is not None:
+ with sentry_sdk.scope.use_isolation_scope(self.isolation_scope):
+ with sentry_sdk.scope.use_scope(self.current_scope):
+ try:
+ raise ServerlessTimeoutWarning(
+ "WARNING : Function is expected to get timed out. Configured timeout duration = {} seconds.".format(
+ integer_configured_timeout
+ )
+ )
+ except Exception:
+ reraise(*self._capture_exception())
+
raise ServerlessTimeoutWarning(
"WARNING : Function is expected to get timed out. Configured timeout duration = {} seconds.".format(
integer_configured_timeout
@@ -1151,8 +1593,7 @@ def run(self):
)
-def to_base64(original):
- # type: (str) -> Optional[str]
+def to_base64(original: str) -> "Optional[str]":
"""
Convert a string to base64, via UTF-8. Returns None on invalid input.
"""
@@ -1168,8 +1609,7 @@ def to_base64(original):
return base64_string
-def from_base64(base64_string):
- # type: (str) -> Optional[str]
+def from_base64(base64_string: str) -> "Optional[str]":
"""
Convert a string from base64, via UTF-8. Returns None on invalid input.
"""
@@ -1193,8 +1633,12 @@ def from_base64(base64_string):
Components = namedtuple("Components", ["scheme", "netloc", "path", "query", "fragment"])
-def sanitize_url(url, remove_authority=True, remove_query_values=True):
- # type: (str, bool, bool) -> str
+def sanitize_url(
+ url: str,
+ remove_authority: bool = True,
+ remove_query_values: bool = True,
+ split: bool = False,
+) -> "Union[str, Components]":
"""
Removes the authority and query parameter values from a given URL.
"""
@@ -1223,48 +1667,51 @@ def sanitize_url(url, remove_authority=True, remove_query_values=True):
else:
query_string = parsed_url.query
- safe_url = urlunsplit(
- Components(
- scheme=parsed_url.scheme,
- netloc=netloc,
- query=query_string,
- path=parsed_url.path,
- fragment=parsed_url.fragment,
- )
+ components = Components(
+ scheme=parsed_url.scheme,
+ netloc=netloc,
+ query=query_string,
+ path=parsed_url.path,
+ fragment=parsed_url.fragment,
)
- return safe_url
+ if split:
+ return components
+ else:
+ return urlunsplit(components)
ParsedUrl = namedtuple("ParsedUrl", ["url", "query", "fragment"])
-def parse_url(url, sanitize=True):
-
- # type: (str, bool) -> ParsedUrl
+def parse_url(url: str, sanitize: bool = True) -> "ParsedUrl":
"""
Splits a URL into a url (including path), query and fragment. If sanitize is True, the query
parameters will be sanitized to remove sensitive data. The autority (username and password)
in the URL will always be removed.
"""
- url = sanitize_url(url, remove_authority=True, remove_query_values=sanitize)
+ parsed_url = sanitize_url(
+ url, remove_authority=True, remove_query_values=sanitize, split=True
+ )
- parsed_url = urlsplit(url)
base_url = urlunsplit(
Components(
- scheme=parsed_url.scheme,
- netloc=parsed_url.netloc,
+ scheme=parsed_url.scheme, # type: ignore
+ netloc=parsed_url.netloc, # type: ignore
query="",
- path=parsed_url.path,
+ path=parsed_url.path, # type: ignore
fragment="",
)
)
- return ParsedUrl(url=base_url, query=parsed_url.query, fragment=parsed_url.fragment)
+ return ParsedUrl(
+ url=base_url,
+ query=parsed_url.query, # type: ignore
+ fragment=parsed_url.fragment, # type: ignore
+ )
-def is_valid_sample_rate(rate, source):
- # type: (Any, str) -> bool
+def is_valid_sample_rate(rate: "Any", source: str) -> bool:
"""
Checks the given sample rate to make sure it is valid type and value (a
boolean or a number between 0 and 1, inclusive).
@@ -1294,33 +1741,424 @@ def is_valid_sample_rate(rate, source):
return True
+def match_regex_list(
+ item: str,
+ regex_list: "Optional[List[str]]" = None,
+ substring_matching: bool = False,
+) -> bool:
+ if regex_list is None:
+ return False
+
+ for item_matcher in regex_list:
+ if not substring_matching and item_matcher[-1] != "$":
+ item_matcher += "$"
+
+ matched = re.search(item_matcher, item)
+ if matched:
+ return True
+
+ return False
+
+
+def is_sentry_url(client: "sentry_sdk.client.BaseClient", url: str) -> bool:
+ """
+ Determines whether the given URL matches the Sentry DSN.
+ """
+ return (
+ client is not None
+ and client.transport is not None
+ and client.transport.parsed_dsn is not None
+ and client.transport.parsed_dsn.netloc in url
+ )
+
+
+def _generate_installed_modules() -> "Iterator[Tuple[str, str]]":
+ try:
+ from importlib import metadata
+
+ yielded = set()
+ for dist in metadata.distributions():
+ name = dist.metadata.get("Name", None) # type: ignore[attr-defined]
+ # `metadata` values may be `None`, see:
+ # https://github.com/python/cpython/issues/91216
+ # and
+ # https://github.com/python/importlib_metadata/issues/371
+ if name is not None:
+ normalized_name = _normalize_module_name(name)
+ if dist.version is not None and normalized_name not in yielded:
+ yield normalized_name, dist.version
+ yielded.add(normalized_name)
+
+ except ImportError:
+ # < py3.8
+ try:
+ import pkg_resources
+ except ImportError:
+ return
+
+ for info in pkg_resources.working_set:
+ yield _normalize_module_name(info.key), info.version
+
+
+def _normalize_module_name(name: str) -> str:
+ return name.lower()
+
+
+def _replace_hyphens_dots_and_underscores_with_dashes(name: str) -> str:
+ # https://peps.python.org/pep-0503/#normalized-names
+ return re.sub(r"[-_.]+", "-", name)
+
+
+def _get_installed_modules() -> "Dict[str, str]":
+ global _installed_modules
+ if _installed_modules is None:
+ _installed_modules = dict(_generate_installed_modules())
+ return _installed_modules
+
+
+def package_version(package: str) -> "Optional[Tuple[int, ...]]":
+ normalized_package = _normalize_module_name(
+ _replace_hyphens_dots_and_underscores_with_dashes(package)
+ )
+
+ installed_packages = {
+ _replace_hyphens_dots_and_underscores_with_dashes(module): v
+ for module, v in _get_installed_modules().items()
+ }
+ version = installed_packages.get(normalized_package)
+ if version is None:
+ return None
+
+ return parse_version(version)
+
+
+def reraise(
+ tp: "Optional[Type[BaseException]]",
+ value: "Optional[BaseException]",
+ tb: "Optional[Any]" = None,
+) -> "NoReturn":
+ assert value is not None
+ if value.__traceback__ is not tb:
+ raise value.with_traceback(tb)
+ raise value
+
+
+def _no_op(*_a: "Any", **_k: "Any") -> None:
+ """No-op function for ensure_integration_enabled."""
+ pass
+
+
+if TYPE_CHECKING:
+
+ @overload
+ def ensure_integration_enabled(
+ integration: "type[sentry_sdk.integrations.Integration]",
+ original_function: "Callable[P, R]",
+ ) -> "Callable[[Callable[P, R]], Callable[P, R]]": ...
+
+ @overload
+ def ensure_integration_enabled(
+ integration: "type[sentry_sdk.integrations.Integration]",
+ ) -> "Callable[[Callable[P, None]], Callable[P, None]]": ...
+
+
+def ensure_integration_enabled(
+ integration: "type[sentry_sdk.integrations.Integration]",
+ original_function: "Union[Callable[P, R], Callable[P, None]]" = _no_op,
+) -> "Callable[[Callable[P, R]], Callable[P, R]]":
+ """
+ Ensures a given integration is enabled prior to calling a Sentry-patched function.
+
+ The function takes as its parameters the integration that must be enabled and the original
+ function that the SDK is patching. The function returns a function that takes the
+ decorated (Sentry-patched) function as its parameter, and returns a function that, when
+ called, checks whether the given integration is enabled. If the integration is enabled, the
+ function calls the decorated, Sentry-patched function. If the integration is not enabled,
+ the original function is called.
+
+ The function also takes care of preserving the original function's signature and docstring.
+
+ Example usage:
+
+ ```python
+ @ensure_integration_enabled(MyIntegration, my_function)
+ def patch_my_function():
+ with sentry_sdk.start_transaction(...):
+ return my_function()
+ ```
+ """
+ if TYPE_CHECKING:
+ # Type hint to ensure the default function has the right typing. The overloads
+ # ensure the default _no_op function is only used when R is None.
+ original_function = cast(Callable[P, R], original_function)
+
+ def patcher(sentry_patched_function: "Callable[P, R]") -> "Callable[P, R]":
+ def runner(*args: "P.args", **kwargs: "P.kwargs") -> "R":
+ if sentry_sdk.get_client().get_integration(integration) is None:
+ return original_function(*args, **kwargs)
+
+ return sentry_patched_function(*args, **kwargs)
+
+ if original_function is _no_op:
+ return wraps(sentry_patched_function)(runner)
+
+ return wraps(original_function)(runner)
+
+ return patcher
+
+
if PY37:
- def nanosecond_time():
- # type: () -> int
+ def nanosecond_time() -> int:
return time.perf_counter_ns()
-elif PY33:
+else:
- def nanosecond_time():
- # type: () -> int
+ def nanosecond_time() -> int:
return int(time.perf_counter() * 1e9)
-else:
- def nanosecond_time():
- # type: () -> int
- raise AttributeError
+def now() -> float:
+ return time.perf_counter()
-if PY2:
+try:
+ from gevent import get_hub as get_gevent_hub
+ from gevent.monkey import is_module_patched
+except ImportError:
+ # it's not great that the signatures are different, get_hub can't return None
+ # consider adding an if TYPE_CHECKING to change the signature to Optional[Hub]
+ def get_gevent_hub() -> "Optional[Hub]": # type: ignore[misc]
+ return None
- def now():
- # type: () -> float
- return time.time()
+ def is_module_patched(mod_name: str) -> bool:
+ # unable to import from gevent means no modules have been patched
+ return False
-else:
- def now():
- # type: () -> float
- return time.perf_counter()
+def is_gevent() -> bool:
+ return is_module_patched("threading") or is_module_patched("_thread")
+
+
+def get_current_thread_meta(
+ thread: "Optional[threading.Thread]" = None,
+) -> "Tuple[Optional[int], Optional[str]]":
+ """
+ Try to get the id of the current thread, with various fall backs.
+ """
+
+ # if a thread is specified, that takes priority
+ if thread is not None:
+ try:
+ thread_id = thread.ident
+ thread_name = thread.name
+ if thread_id is not None:
+ return thread_id, thread_name
+ except AttributeError:
+ pass
+
+ # if the app is using gevent, we should look at the gevent hub first
+ # as the id there differs from what the threading module reports
+ if is_gevent():
+ gevent_hub = get_gevent_hub()
+ if gevent_hub is not None:
+ try:
+ # this is undocumented, so wrap it in try except to be safe
+ return gevent_hub.thread_ident, None
+ except AttributeError:
+ pass
+
+ # use the current thread's id if possible
+ try:
+ thread = threading.current_thread()
+ thread_id = thread.ident
+ thread_name = thread.name
+ if thread_id is not None:
+ return thread_id, thread_name
+ except AttributeError:
+ pass
+
+ # if we can't get the current thread id, fall back to the main thread id
+ try:
+ thread = threading.main_thread()
+ thread_id = thread.ident
+ thread_name = thread.name
+ if thread_id is not None:
+ return thread_id, thread_name
+ except AttributeError:
+ pass
+
+ # we've tried everything, time to give up
+ return None, None
+
+
+def should_be_treated_as_error(ty: "Any", value: "Any") -> bool:
+ if ty == SystemExit and hasattr(value, "code") and value.code in (0, None):
+ # https://docs.python.org/3/library/exceptions.html#SystemExit
+ return False
+
+ return True
+
+
+if TYPE_CHECKING:
+ T = TypeVar("T")
+
+
+def try_convert(convert_func: "Callable[[Any], T]", value: "Any") -> "Optional[T]":
+ """
+ Attempt to convert from an unknown type to a specific type, using the
+ given function. Return None if the conversion fails, i.e. if the function
+ raises an exception.
+ """
+ try:
+ if isinstance(value, convert_func): # type: ignore
+ return value
+ except TypeError:
+ pass
+
+ try:
+ return convert_func(value)
+ except Exception:
+ return None
+
+
+def safe_serialize(data: "Any") -> str:
+ """Safely serialize to a readable string."""
+
+ def serialize_item(
+ item: "Any",
+ ) -> "Union[str, dict[Any, Any], list[Any], tuple[Any, ...]]":
+ if callable(item):
+ try:
+ module = getattr(item, "__module__", None)
+ qualname = getattr(item, "__qualname__", None)
+ name = getattr(item, "__name__", "anonymous")
+
+ if module and qualname:
+ full_path = f"{module}.{qualname}"
+ elif module and name:
+ full_path = f"{module}.{name}"
+ else:
+ full_path = name
+
+ return f""
+ except Exception:
+ return f""
+ elif isinstance(item, dict):
+ return {k: serialize_item(v) for k, v in item.items()}
+ elif isinstance(item, (list, tuple)):
+ return [serialize_item(x) for x in item]
+ elif hasattr(item, "__dict__"):
+ try:
+ attrs = {
+ k: serialize_item(v)
+ for k, v in vars(item).items()
+ if not k.startswith("_")
+ }
+ return f"<{type(item).__name__} {attrs}>"
+ except Exception:
+ return repr(item)
+ else:
+ return item
+
+ try:
+ serialized = serialize_item(data)
+ return (
+ json.dumps(serialized, default=str)
+ if not isinstance(serialized, str)
+ else serialized
+ )
+ except Exception:
+ return str(data)
+
+
+def has_logs_enabled(options: "Optional[dict[str, Any]]") -> bool:
+ if options is None:
+ return False
+
+ return bool(
+ options.get("enable_logs", False)
+ or options["_experiments"].get("enable_logs", False)
+ )
+
+
+def get_before_send_log(
+ options: "Optional[dict[str, Any]]",
+) -> "Optional[Callable[[Log, Hint], Optional[Log]]]":
+ if options is None:
+ return None
+
+ return options.get("before_send_log") or options["_experiments"].get(
+ "before_send_log"
+ )
+
+
+def has_metrics_enabled(options: "Optional[dict[str, Any]]") -> bool:
+ if options is None:
+ return False
+
+ return bool(options.get("enable_metrics", True))
+
+
+def get_before_send_metric(
+ options: "Optional[dict[str, Any]]",
+) -> "Optional[Callable[[Metric, Hint], Optional[Metric]]]":
+ if options is None:
+ return None
+
+ return options.get("before_send_metric") or options["_experiments"].get(
+ "before_send_metric"
+ )
+
+
+def format_attribute(val: "Any") -> "AttributeValue":
+ """
+ Turn unsupported attribute value types into an AttributeValue.
+
+ We do this as soon as a user-provided attribute is set, to prevent spans,
+ logs, metrics and similar from having live references to various objects.
+
+ Note: This is not the final attribute value format. Before they're sent,
+ they're serialized further into the actual format the protocol expects:
+ https://develop.sentry.dev/sdk/telemetry/attributes/
+ """
+ if isinstance(val, (bool, int, float, str)):
+ return val
+
+ if isinstance(val, (list, tuple)) and not val:
+ return []
+ elif isinstance(val, list):
+ ty = type(val[0])
+ if ty in (str, int, float, bool) and all(type(v) is ty for v in val):
+ return copy.deepcopy(val)
+ elif isinstance(val, tuple):
+ ty = type(val[0])
+ if ty in (str, int, float, bool) and all(type(v) is ty for v in val):
+ return list(val)
+
+ return safe_repr(val)
+
+
+def serialize_attribute(val: "AttributeValue") -> "SerializedAttributeValue":
+ """Serialize attribute value to the transport format."""
+ if isinstance(val, bool):
+ return {"value": val, "type": "boolean"}
+ if isinstance(val, int):
+ return {"value": val, "type": "integer"}
+ if isinstance(val, float):
+ return {"value": val, "type": "double"}
+ if isinstance(val, str):
+ return {"value": val, "type": "string"}
+
+ if isinstance(val, list):
+ if not val:
+ return {"value": [], "type": "array"}
+
+ # Only lists of elements of a single type are supported
+ ty = type(val[0])
+ if ty in (int, str, bool, float) and all(type(v) is ty for v in val):
+ return {"value": val, "type": "array"}
+
+ # Coerce to string if we don't know what to do with the value. This should
+ # never happen as we pre-format early in format_attribute, but let's be safe.
+ return {"value": safe_repr(val), "type": "string"}
diff --git a/sentry_sdk/worker.py b/sentry_sdk/worker.py
index ca0ca28d94..7931f9c027 100644
--- a/sentry_sdk/worker.py
+++ b/sentry_sdk/worker.py
@@ -1,13 +1,14 @@
+from abc import ABC, abstractmethod
+import asyncio
import os
import threading
from time import sleep, time
-from sentry_sdk._compat import check_thread_support
from sentry_sdk._queue import Queue, FullError
-from sentry_sdk.utils import logger
+from sentry_sdk.utils import logger, mark_sentry_task_internal
from sentry_sdk.consts import DEFAULT_QUEUE_SIZE
-from sentry_sdk._types import TYPE_CHECKING
+from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Any
@@ -18,31 +19,57 @@
_TERMINATOR = object()
-class BackgroundWorker(object):
- def __init__(self, queue_size=DEFAULT_QUEUE_SIZE):
- # type: (int) -> None
- check_thread_support()
- self._queue = Queue(queue_size) # type: Queue
+class Worker(ABC):
+ """Base class for all workers."""
+
+ @property
+ @abstractmethod
+ def is_alive(self) -> bool:
+ """Whether the worker is alive and running."""
+ pass
+
+ @abstractmethod
+ def kill(self) -> None:
+ """Kill the worker. It will not process any more events."""
+ pass
+
+ def flush(
+ self, timeout: float, callback: "Optional[Callable[[int, float], Any]]" = None
+ ) -> None:
+ """Flush the worker, blocking until done or timeout is reached."""
+ return None
+
+ @abstractmethod
+ def full(self) -> bool:
+ """Whether the worker's queue is full."""
+ pass
+
+ @abstractmethod
+ def submit(self, callback: "Callable[[], Any]") -> bool:
+ """Schedule a callback. Returns True if queued, False if full."""
+ pass
+
+
+class BackgroundWorker(Worker):
+ def __init__(self, queue_size: int = DEFAULT_QUEUE_SIZE) -> None:
+ self._queue: "Queue" = Queue(queue_size)
self._lock = threading.Lock()
- self._thread = None # type: Optional[threading.Thread]
- self._thread_for_pid = None # type: Optional[int]
+ self._thread: "Optional[threading.Thread]" = None
+ self._thread_for_pid: "Optional[int]" = None
@property
- def is_alive(self):
- # type: () -> bool
+ def is_alive(self) -> bool:
if self._thread_for_pid != os.getpid():
return False
if not self._thread:
return False
return self._thread.is_alive()
- def _ensure_thread(self):
- # type: () -> None
+ def _ensure_thread(self) -> None:
if not self.is_alive:
self.start()
- def _timed_queue_join(self, timeout):
- # type: (float) -> bool
+ def _timed_queue_join(self, timeout: float) -> bool:
deadline = time() + timeout
queue = self._queue
@@ -59,19 +86,23 @@ def _timed_queue_join(self, timeout):
finally:
queue.all_tasks_done.release()
- def start(self):
- # type: () -> None
+ def start(self) -> None:
with self._lock:
if not self.is_alive:
self._thread = threading.Thread(
- target=self._target, name="raven-sentry.BackgroundWorker"
+ target=self._target, name="sentry-sdk.BackgroundWorker"
)
self._thread.daemon = True
- self._thread.start()
- self._thread_for_pid = os.getpid()
+ try:
+ self._thread.start()
+ self._thread_for_pid = os.getpid()
+ except RuntimeError:
+ # At this point we can no longer start because the interpreter
+ # is already shutting down. Sadly at this point we can no longer
+ # send out events.
+ self._thread = None
- def kill(self):
- # type: () -> None
+ def kill(self) -> None:
"""
Kill worker thread. Returns immediately. Not useful for
waiting on shutdown for events, use `flush` for that.
@@ -87,16 +118,17 @@ def kill(self):
self._thread = None
self._thread_for_pid = None
- def flush(self, timeout, callback=None):
- # type: (float, Optional[Any]) -> None
+ def flush(self, timeout: float, callback: "Optional[Any]" = None) -> None:
logger.debug("background worker got flush request")
with self._lock:
if self.is_alive and timeout > 0.0:
self._wait_flush(timeout, callback)
logger.debug("background worker flushed")
- def _wait_flush(self, timeout, callback):
- # type: (float, Optional[Any]) -> None
+ def full(self) -> bool:
+ return self._queue.full()
+
+ def _wait_flush(self, timeout: float, callback: "Optional[Any]") -> None:
initial_timeout = min(0.1, timeout)
if not self._timed_queue_join(initial_timeout):
pending = self._queue.qsize() + 1
@@ -108,8 +140,7 @@ def _wait_flush(self, timeout, callback):
pending = self._queue.qsize() + 1
logger.error("flush timed out, dropped %s events", pending)
- def submit(self, callback):
- # type: (Callable[[], None]) -> bool
+ def submit(self, callback: "Callable[[], Any]") -> bool:
self._ensure_thread()
try:
self._queue.put_nowait(callback)
@@ -117,8 +148,7 @@ def submit(self, callback):
except FullError:
return False
- def _target(self):
- # type: () -> None
+ def _target(self) -> None:
while True:
callback = self._queue.get()
try:
@@ -131,3 +161,151 @@ def _target(self):
finally:
self._queue.task_done()
sleep(0)
+
+
+class AsyncWorker(Worker):
+ def __init__(self, queue_size: int = DEFAULT_QUEUE_SIZE) -> None:
+ self._queue: "Optional[asyncio.Queue[Any]]" = None
+ self._queue_size = queue_size
+ self._task: "Optional[asyncio.Task[None]]" = None
+ # Event loop needs to remain in the same process
+ self._task_for_pid: "Optional[int]" = None
+ self._loop: "Optional[asyncio.AbstractEventLoop]" = None
+ # Track active callback tasks so they have a strong reference and can be cancelled on kill
+ self._active_tasks: "set[asyncio.Task[None]]" = set()
+
+ @property
+ def is_alive(self) -> bool:
+ if self._task_for_pid != os.getpid():
+ return False
+ if not self._task or not self._loop:
+ return False
+ return self._loop.is_running() and not self._task.done()
+
+ def kill(self) -> None:
+ if self._task:
+ # Cancel the main consumer task to prevent duplicate consumers
+ self._task.cancel()
+ # Also cancel any active callback tasks
+ # Avoid modifying the set while cancelling tasks
+ tasks_to_cancel = set(self._active_tasks)
+ for task in tasks_to_cancel:
+ task.cancel()
+ self._active_tasks.clear()
+ self._loop = None
+ self._task = None
+ self._task_for_pid = None
+
+ def start(self) -> None:
+ if not self.is_alive:
+ try:
+ self._loop = asyncio.get_running_loop()
+ # Always create a fresh queue on start to avoid stale items
+ self._queue = asyncio.Queue(maxsize=self._queue_size)
+ with mark_sentry_task_internal():
+ self._task = self._loop.create_task(self._target())
+ self._task_for_pid = os.getpid()
+ except RuntimeError:
+ # There is no event loop running
+ logger.warning("No event loop running, async worker not started")
+ self._loop = None
+ self._task = None
+ self._task_for_pid = None
+
+ def full(self) -> bool:
+ if self._queue is None:
+ return True
+ return self._queue.full()
+
+ def _ensure_task(self) -> None:
+ if not self.is_alive:
+ self.start()
+
+ async def _wait_flush(
+ self, timeout: float, callback: "Optional[Any]" = None
+ ) -> None:
+ if not self._loop or not self._loop.is_running() or self._queue is None:
+ return
+
+ initial_timeout = min(0.1, timeout)
+
+ # Timeout on the join
+ try:
+ await asyncio.wait_for(self._queue.join(), timeout=initial_timeout)
+ except asyncio.TimeoutError:
+ pending = self._queue.qsize() + len(self._active_tasks)
+ logger.debug("%d event(s) pending on flush", pending)
+ if callback is not None:
+ callback(pending, timeout)
+
+ try:
+ remaining_timeout = timeout - initial_timeout
+ await asyncio.wait_for(self._queue.join(), timeout=remaining_timeout)
+ except asyncio.TimeoutError:
+ pending = self._queue.qsize() + len(self._active_tasks)
+ logger.error("flush timed out, dropped %s events", pending)
+
+ def flush( # type: ignore[override]
+ self, timeout: float, callback: "Optional[Any]" = None
+ ) -> "Optional[asyncio.Task[None]]":
+ if self.is_alive and timeout > 0.0 and self._loop and self._loop.is_running():
+ with mark_sentry_task_internal():
+ return self._loop.create_task(self._wait_flush(timeout, callback))
+ return None
+
+ def submit(self, callback: "Callable[[], Any]") -> bool:
+ self._ensure_task()
+ if self._queue is None:
+ return False
+ try:
+ self._queue.put_nowait(callback)
+ return True
+ except asyncio.QueueFull:
+ return False
+
+ async def _target(self) -> None:
+ if self._queue is None:
+ return
+ try:
+ while True:
+ callback = await self._queue.get()
+ if callback is _TERMINATOR:
+ self._queue.task_done()
+ break
+ # Firing tasks instead of awaiting them allows for concurrent requests
+ with mark_sentry_task_internal():
+ task = asyncio.create_task(self._process_callback(callback))
+ # Create a strong reference to the task so it can be cancelled on kill
+ # and does not get garbage collected while running
+ self._active_tasks.add(task)
+ # Capture queue ref at dispatch time so done callbacks use the
+ # correct queue even if kill()/start() replace self._queue.
+ queue_ref = self._queue
+ task.add_done_callback(lambda t: self._on_task_complete(t, queue_ref))
+ # Yield to let the event loop run other tasks
+ await asyncio.sleep(0)
+ except asyncio.CancelledError:
+ pass # Expected during kill()
+
+ async def _process_callback(self, callback: "Callable[[], Any]") -> None:
+ # Callback is an async coroutine, need to await it
+ await callback()
+
+ def _on_task_complete(
+ self,
+ task: "asyncio.Task[None]",
+ queue: "Optional[asyncio.Queue[Any]]" = None,
+ ) -> None:
+ try:
+ task.result()
+ except asyncio.CancelledError:
+ pass # Task was cancelled, expected during shutdown
+ except Exception:
+ logger.error("Failed processing job", exc_info=True)
+ finally:
+ # Mark the task as done and remove it from the active tasks set
+ # Use the queue reference captured at dispatch time, not self._queue,
+ # to avoid calling task_done() on a different queue after kill()/start().
+ if queue is not None:
+ queue.task_done()
+ self._active_tasks.discard(task)
diff --git a/setup.py b/setup.py
index 7aa4430080..3942ee630e 100644
--- a/setup.py
+++ b/setup.py
@@ -21,7 +21,7 @@ def get_file_text(file_name):
setup(
name="sentry-sdk",
- version="1.19.1",
+ version="2.58.0",
author="Sentry Team and Contributors",
author_email="hello@sentry.io",
url="https://github.com/getsentry/sentry-python",
@@ -36,56 +36,81 @@ def get_file_text(file_name):
# PEP 561
package_data={"sentry_sdk": ["py.typed"]},
zip_safe=False,
- license="MIT",
+ license_expression="MIT",
+ python_requires=">=3.6",
install_requires=[
- 'urllib3>=1.25.7; python_version<="3.4"',
- 'urllib3>=1.26.9; python_version=="3.5"',
- 'urllib3>=1.26.11; python_version >="3.6"',
+ "urllib3>=1.26.11",
"certifi",
],
extras_require={
- "flask": ["flask>=0.11", "blinker>=1.1"],
- "quart": ["quart>=0.16.1", "blinker>=1.1"],
+ "aiohttp": ["aiohttp>=3.5"],
+ "anthropic": ["anthropic>=0.16"],
+ "arq": ["arq>=0.23"],
+ "asyncpg": ["asyncpg>=0.23"],
+ "beam": ["apache-beam>=2.12"],
"bottle": ["bottle>=0.12.13"],
- "falcon": ["falcon>=1.4"],
- "django": ["django>=1.8"],
- "sanic": ["sanic>=0.8"],
"celery": ["celery>=3"],
+ "celery-redbeat": ["celery-redbeat>=2"],
+ "chalice": ["chalice>=1.16.0"],
+ "clickhouse-driver": ["clickhouse-driver>=0.2.0"],
+ "django": ["django>=1.8"],
+ "falcon": ["falcon>=1.4"],
+ "fastapi": ["fastapi>=0.79.0"],
+ "flask": ["flask>=0.11", "blinker>=1.1", "markupsafe"],
+ "grpcio": ["grpcio>=1.21.1", "protobuf>=3.8.0"],
+ "http2": ["httpcore[http2]==1.*"],
+ "asyncio": ["httpcore[asyncio]==1.*"],
+ "httpx": ["httpx>=0.16.0"],
"huey": ["huey>=2"],
- "beam": ["apache-beam>=2.12"],
- "arq": ["arq>=0.23"],
+ "huggingface_hub": ["huggingface_hub>=0.22"],
+ "langchain": ["langchain>=0.0.210"],
+ "langgraph": ["langgraph>=0.6.6"],
+ "launchdarkly": ["launchdarkly-server-sdk>=9.8.0"],
+ "litellm": ["litellm>=1.77.5,!=1.82.7,!=1.82.8"],
+ "litestar": ["litestar>=2.0.0"],
+ "loguru": ["loguru>=0.5"],
+ "mcp": ["mcp>=1.15.0"],
+ "openai": ["openai>=1.0.0", "tiktoken>=0.3.0"],
+ "openfeature": ["openfeature-sdk>=0.7.1"],
+ "opentelemetry": ["opentelemetry-distro>=0.35b0"],
+ "opentelemetry-experimental": ["opentelemetry-distro"],
+ "opentelemetry-otlp": ["opentelemetry-distro[otlp]>=0.35b0"],
+ "pure-eval": ["pure_eval", "executing", "asttokens"],
+ "pydantic_ai": ["pydantic-ai>=1.0.0"],
+ "pymongo": ["pymongo>=3.1"],
+ "pyspark": ["pyspark>=2.4.4"],
+ "quart": ["quart>=0.16.1", "blinker>=1.1"],
"rq": ["rq>=0.6"],
- "aiohttp": ["aiohttp>=3.5"],
- "tornado": ["tornado>=5"],
+ "sanic": ["sanic>=0.8"],
"sqlalchemy": ["sqlalchemy>=1.2"],
- "pyspark": ["pyspark>=2.4.4"],
- "pure_eval": ["pure_eval", "executing", "asttokens"],
- "chalice": ["chalice>=1.16.0"],
- "httpx": ["httpx>=0.16.0"],
"starlette": ["starlette>=0.19.1"],
"starlite": ["starlite>=1.48"],
- "fastapi": ["fastapi>=0.79.0"],
- "pymongo": ["pymongo>=3.1"],
- "opentelemetry": ["opentelemetry-distro>=0.35b0"],
- "grpcio": ["grpcio>=1.21.1"]
+ "statsig": ["statsig>=0.55.3"],
+ "tornado": ["tornado>=6"],
+ "unleash": ["UnleashClient>=6.0.1"],
+ "google-genai": ["google-genai>=1.29.0"],
+ },
+ entry_points={
+ "opentelemetry_propagator": [
+ "sentry=sentry_sdk.integrations.opentelemetry:SentryPropagator"
+ ]
},
classifiers=[
"Development Status :: 5 - Production/Stable",
"Environment :: Web Environment",
"Intended Audience :: Developers",
- "License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
"Programming Language :: Python",
- "Programming Language :: Python :: 2",
- "Programming Language :: Python :: 2.7",
"Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3.4",
- "Programming Language :: Python :: 3.5",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.13",
+ "Programming Language :: Python :: 3.14",
"Topic :: Software Development :: Libraries :: Python Modules",
],
options={"bdist_wheel": {"universal": "1"}},
diff --git a/test-requirements.txt b/test-requirements.txt
deleted file mode 100644
index 5d449df716..0000000000
--- a/test-requirements.txt
+++ /dev/null
@@ -1,15 +0,0 @@
-pip # always use newest pip
-mock # for testing under python < 3.3
-pytest<7
-pytest-cov==2.8.1
-pytest-forked<=1.4.0
-pytest-localserver==0.5.0
-pytest-watch==4.2.0
-tox==3.7.0
-Werkzeug<2.1.0
-jsonschema==3.2.0
-pyrsistent==0.16.0 # TODO(py3): 0.17.0 requires python3, see https://github.com/tobgu/pyrsistent/issues/205
-executing
-asttokens
-responses
-ipdb
diff --git a/tests/__init__.py b/tests/__init__.py
index cac15f9333..2e4df719d5 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -1,6 +1,5 @@
import sys
-
-import pytest
+import warnings
# This is used in _capture_internal_warnings. We need to run this at import
# time because that's where many deprecation warnings might get thrown.
@@ -9,5 +8,5 @@
# gets loaded too late.
assert "sentry_sdk" not in sys.modules
-_warning_recorder_mgr = pytest.warns(None)
+_warning_recorder_mgr = warnings.catch_warnings(record=True)
_warning_recorder = _warning_recorder_mgr.__enter__()
diff --git a/tests/conftest.py b/tests/conftest.py
index 618f60d282..4e4943ba85 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,11 +1,31 @@
import json
import os
+import asyncio
+from urllib.parse import urlparse, parse_qs
import socket
+import warnings
+import brotli
+import gzip
+import io
+from dataclasses import dataclass
from threading import Thread
+from contextlib import contextmanager
+from http.server import BaseHTTPRequestHandler, HTTPServer
+from unittest import mock
+from collections import namedtuple
import pytest
+from pytest_localserver.http import WSGIServer
+from werkzeug.wrappers import Request, Response
import jsonschema
+try:
+ from starlette.testclient import TestClient
+ # Catch RuntimeError to prevent the following exception in aws_lambda tests.
+ # RuntimeError: The starlette.testclient module requires the httpx package to be installed.
+except (ImportError, RuntimeError):
+ TestClient = None
+
try:
import gevent
except ImportError:
@@ -16,27 +36,79 @@
except ImportError:
eventlet = None
+import sentry_sdk
+import sentry_sdk.utils
+from sentry_sdk.envelope import Envelope, parse_json
+from sentry_sdk.integrations import ( # noqa: F401
+ _DEFAULT_INTEGRATIONS,
+ _installed_integrations,
+ _processed_integrations,
+)
+from sentry_sdk.profiler import teardown_profiler
+from sentry_sdk.profiler.continuous_profiler import teardown_continuous_profiler
+from sentry_sdk.transport import Transport
+from sentry_sdk.utils import reraise
+
try:
- # Python 2
- import BaseHTTPServer
+ import openai
+except ImportError:
+ openai = None
- HTTPServer = BaseHTTPServer.HTTPServer
- BaseHTTPRequestHandler = BaseHTTPServer.BaseHTTPRequestHandler
-except Exception:
- # Python 3
- from http.server import BaseHTTPRequestHandler, HTTPServer
+try:
+ import anthropic
+except ImportError:
+ anthropic = None
+
+
+try:
+ import google
+except ImportError:
+ google = None
-import sentry_sdk
-from sentry_sdk._compat import iteritems, reraise, string_types
-from sentry_sdk.envelope import Envelope
-from sentry_sdk.integrations import _installed_integrations # noqa: F401
-from sentry_sdk.profiler import teardown_profiler
-from sentry_sdk.transport import Transport
-from sentry_sdk.utils import capture_internal_exceptions
from tests import _warning_recorder, _warning_recorder_mgr
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from typing import Any, Callable, MutableMapping, Optional
+ from collections.abc import Iterator
+
+try:
+ from httpx import (
+ ASGITransport,
+ Request as HttpxRequest,
+ Response as HttpxResponse,
+ AsyncByteStream,
+ AsyncClient,
+ )
+except ImportError:
+ ASGITransport = None
+ HttpxRequest = None
+ HttpxResponse = None
+ AsyncByteStream = None
+ AsyncClient = None
+
+
+try:
+ from anyio import create_memory_object_stream, create_task_group, EndOfStream
+ from mcp.types import (
+ JSONRPCMessage,
+ JSONRPCNotification,
+ JSONRPCRequest,
+ )
+ from mcp.shared.message import SessionMessage
+except ImportError:
+ create_memory_object_stream = None
+ create_task_group = None
+ EndOfStream = None
+
+ JSONRPCMessage = None
+ JSONRPCNotification = None
+ JSONRPCRequest = None
+ SessionMessage = None
+
SENTRY_EVENT_SCHEMA = "./checkouts/data-schemas/relay/event.schema.json"
@@ -46,37 +118,37 @@
with open(SENTRY_EVENT_SCHEMA) as f:
SENTRY_EVENT_SCHEMA = json.load(f)
-try:
- import pytest_benchmark
-except ImportError:
- @pytest.fixture
- def benchmark():
- return lambda x: x()
+from sentry_sdk import scope
-else:
- del pytest_benchmark
+
+@pytest.fixture(autouse=True)
+def clean_scopes():
+ """
+ Resets the scopes for every test to avoid leaking data between tests.
+ """
+ scope._global_scope = None
+ scope._isolation_scope.set(None)
+ scope._current_scope.set(None)
@pytest.fixture(autouse=True)
-def internal_exceptions(request, monkeypatch):
+def internal_exceptions(request):
errors = []
if "tests_internal_exceptions" in request.keywords:
return
- def _capture_internal_exception(self, exc_info):
+ def _capture_internal_exception(exc_info):
errors.append(exc_info)
@request.addfinalizer
def _():
- # rerasise the errors so that this just acts as a pass-through (that
+ # reraise the errors so that this just acts as a pass-through (that
# happens to keep track of the errors which pass through it)
for e in errors:
reraise(*e)
- monkeypatch.setattr(
- sentry_sdk.Hub, "_capture_internal_exception", _capture_internal_exception
- )
+ sentry_sdk.utils.capture_internal_exception = _capture_internal_exception
return errors
@@ -142,35 +214,6 @@ def _capture_internal_warnings():
raise AssertionError(warning)
-@pytest.fixture
-def monkeypatch_test_transport(monkeypatch, validate_event_schema):
- def check_event(event):
- def check_string_keys(map):
- for key, value in iteritems(map):
- assert isinstance(key, string_types)
- if isinstance(value, dict):
- check_string_keys(value)
-
- with capture_internal_exceptions():
- check_string_keys(event)
- validate_event_schema(event)
-
- def check_envelope(envelope):
- with capture_internal_exceptions():
- # Assert error events are sent without envelope to server, for compat.
- # This does not apply if any item in the envelope is an attachment.
- if not any(x.type == "attachment" for x in envelope.items):
- assert not any(item.data_category == "error" for item in envelope.items)
- assert not any(item.get_event() is not None for item in envelope.items)
-
- def inner(client):
- monkeypatch.setattr(
- client, "transport", TestTransport(check_event, check_envelope)
- )
-
- return inner
-
-
@pytest.fixture
def validate_event_schema(tmpdir):
def inner(event):
@@ -187,18 +230,34 @@ def reset_integrations():
with a clean slate to ensure monkeypatching works well,
but this also means some other stuff will be monkeypatched twice.
"""
- global _installed_integrations
+ global _DEFAULT_INTEGRATIONS, _processed_integrations
+ try:
+ _DEFAULT_INTEGRATIONS.remove(
+ "sentry_sdk.integrations.opentelemetry.integration.OpenTelemetryIntegration"
+ )
+ except ValueError:
+ pass
+ _processed_integrations.clear()
_installed_integrations.clear()
@pytest.fixture
-def sentry_init(monkeypatch_test_transport, request):
+def uninstall_integration():
+ """Use to force the next call to sentry_init to re-install/setup an integration."""
+
+ def inner(identifier):
+ _processed_integrations.discard(identifier)
+ _installed_integrations.discard(identifier)
+
+ return inner
+
+
+@pytest.fixture
+def sentry_init(request):
def inner(*a, **kw):
- hub = sentry_sdk.Hub.current
+ kw.setdefault("transport", TestTransport())
client = sentry_sdk.Client(*a, **kw)
- hub.bind_client(client)
- if "transport" not in kw:
- monkeypatch_test_transport(sentry_sdk.Hub.current.client)
+ sentry_sdk.get_global_scope().set_client(client)
if request.node.get_closest_marker("forked"):
# Do not run isolation if the test is already running in
@@ -206,38 +265,51 @@ def inner(*a, **kw):
# fork)
yield inner
else:
- with sentry_sdk.Hub(None):
+ old_client = sentry_sdk.get_global_scope().client
+ try:
+ sentry_sdk.get_current_scope().set_client(None)
yield inner
+ finally:
+ sentry_sdk.get_global_scope().set_client(old_client)
class TestTransport(Transport):
- def __init__(self, capture_event_callback, capture_envelope_callback):
+ def __init__(self):
Transport.__init__(self)
- self.capture_event = capture_event_callback
- self.capture_envelope = capture_envelope_callback
- self._queue = None
+
+ def capture_envelope(self, _: Envelope) -> None:
+ """No-op capture_envelope for tests"""
+ pass
+
+
+class TestTransportWithOptions(Transport):
+ """TestTransport above does not pass in the options and for some tests we need them"""
+
+ __test__ = False
+
+ def __init__(self, options=None):
+ Transport.__init__(self, options)
+
+ def capture_envelope(self, _: Envelope) -> None:
+ """No-op capture_envelope for tests"""
+ pass
@pytest.fixture
def capture_events(monkeypatch):
def inner():
events = []
- test_client = sentry_sdk.Hub.current.client
- old_capture_event = test_client.transport.capture_event
+ test_client = sentry_sdk.get_client()
old_capture_envelope = test_client.transport.capture_envelope
- def append_event(event):
- events.append(event)
- return old_capture_event(event)
-
- def append_envelope(envelope):
+ def append_event(envelope):
for item in envelope:
if item.headers.get("type") in ("event", "transaction"):
- test_client.transport.capture_event(item.payload.json)
+ events.append(item.payload.json)
return old_capture_envelope(envelope)
- monkeypatch.setattr(test_client.transport, "capture_event", append_event)
- monkeypatch.setattr(test_client.transport, "capture_envelope", append_envelope)
+ monkeypatch.setattr(test_client.transport, "capture_envelope", append_event)
+
return events
return inner
@@ -247,42 +319,79 @@ def append_envelope(envelope):
def capture_envelopes(monkeypatch):
def inner():
envelopes = []
- test_client = sentry_sdk.Hub.current.client
- old_capture_event = test_client.transport.capture_event
+ test_client = sentry_sdk.get_client()
old_capture_envelope = test_client.transport.capture_envelope
- def append_event(event):
- envelope = Envelope()
- envelope.add_event(event)
- envelopes.append(envelope)
- return old_capture_event(event)
-
def append_envelope(envelope):
envelopes.append(envelope)
return old_capture_envelope(envelope)
- monkeypatch.setattr(test_client.transport, "capture_event", append_event)
monkeypatch.setattr(test_client.transport, "capture_envelope", append_envelope)
+
return envelopes
return inner
+@dataclass
+class UnwrappedItem:
+ type: str
+ payload: dict
+
+
+@pytest.fixture
+def capture_items(monkeypatch):
+ """
+ Capture envelope payload, unfurling individual items.
+
+ Makes it easier to work with both events and attribute-based telemetry in
+ one test.
+ """
+
+ def inner(*types):
+ telemetry = []
+ test_client = sentry_sdk.get_client()
+ old_capture_envelope = test_client.transport.capture_envelope
+
+ def append_envelope(envelope):
+ for item in envelope:
+ if types and item.type not in types:
+ continue
+
+ if item.type in ("metric", "log", "span"):
+ for i in item.payload.json["items"]:
+ t = {k: v for k, v in i.items() if k != "attributes"}
+ t["attributes"] = {
+ k: v["value"] for k, v in i["attributes"].items()
+ }
+ telemetry.append(UnwrappedItem(type=item.type, payload=t))
+ else:
+ telemetry.append(
+ UnwrappedItem(type=item.type, payload=item.payload.json)
+ )
+
+ return old_capture_envelope(envelope)
+
+ monkeypatch.setattr(test_client.transport, "capture_envelope", append_envelope)
+
+ return telemetry
+
+ return inner
+
+
@pytest.fixture
-def capture_client_reports(monkeypatch):
+def capture_record_lost_event_calls(monkeypatch):
def inner():
- reports = []
- test_client = sentry_sdk.Hub.current.client
+ calls = []
+ test_client = sentry_sdk.get_client()
- def record_lost_event(reason, data_category=None, item=None):
- if data_category is None:
- data_category = item.data_category
- return reports.append((reason, data_category))
+ def record_lost_event(reason, data_category=None, item=None, *, quantity=1):
+ calls.append((reason, data_category, item, quantity))
monkeypatch.setattr(
test_client.transport, "record_lost_event", record_lost_event
)
- return reports
+ return calls
return inner
@@ -296,19 +405,21 @@ def inner():
events_r = os.fdopen(events_r, "rb", 0)
events_w = os.fdopen(events_w, "wb", 0)
- test_client = sentry_sdk.Hub.current.client
+ test_client = sentry_sdk.get_client()
- old_capture_event = test_client.transport.capture_event
+ old_capture_envelope = test_client.transport.capture_envelope
- def append(event):
- events_w.write(json.dumps(event).encode("utf-8"))
- events_w.write(b"\n")
- return old_capture_event(event)
+ def append(envelope):
+ event = envelope.get_event() or envelope.get_transaction_event()
+ if event is not None:
+ events_w.write(json.dumps(event).encode("utf-8"))
+ events_w.write(b"\n")
+ return old_capture_envelope(envelope)
def flush(timeout=None, callback=None):
events_w.write(b"flush\n")
- monkeypatch.setattr(test_client.transport, "capture_event", append)
+ monkeypatch.setattr(test_client.transport, "capture_envelope", append)
monkeypatch.setattr(test_client, "flush", flush)
return EventStreamReader(events_r, events_w)
@@ -316,7 +427,7 @@ def flush(timeout=None, callback=None):
return inner
-class EventStreamReader(object):
+class EventStreamReader:
def __init__(self, read_file, write_file):
self.read_file = read_file
self.write_file = write_file
@@ -382,7 +493,6 @@ def render_span(span):
root_span = event["contexts"]["trace"]
- # Return a list instead of a multiline string because black will know better how to format that
return "\n".join(render_span(root_span))
return inner
@@ -408,16 +518,10 @@ def string_containing_matcher():
"""
- class StringContaining(object):
+ class StringContaining:
def __init__(self, substring):
self.substring = substring
-
- try:
- # the `unicode` type only exists in python 2, so if this blows up,
- # we must be in py3 and have the `bytes` type
- self.valid_types = (str, unicode)
- except NameError:
- self.valid_types = (str, bytes)
+ self.valid_types = (str, bytes)
def __eq__(self, test_string):
if not isinstance(test_string, self.valid_types):
@@ -491,7 +595,7 @@ def dictionary_containing_matcher():
>>> f.assert_any_call(DictionaryContaining({"dogs": "yes"})) # no AssertionError
"""
- class DictionaryContaining(object):
+ class DictionaryContaining:
def __init__(self, subdict):
self.subdict = subdict
@@ -531,7 +635,7 @@ def object_described_by_matcher():
Used like this:
- >>> class Dog(object):
+ >>> class Dog:
... pass
...
>>> maisey = Dog()
@@ -543,7 +647,7 @@ def object_described_by_matcher():
>>> f.assert_any_call(ObjectDescribedBy(attrs={"name": "Maisey"})) # no AssertionError
"""
- class ObjectDescribedBy(object):
+ class ObjectDescribedBy:
def __init__(self, type=None, attrs=None):
self.type = type
self.attrs = attrs
@@ -573,14 +677,791 @@ def __ne__(self, test_obj):
@pytest.fixture
def teardown_profiling():
+ # Make sure that a previous test didn't leave the profiler running
+ teardown_profiler()
+ teardown_continuous_profiler()
+
yield
+
+ # Make sure that to shut down the profiler after the test
teardown_profiler()
+ teardown_continuous_profiler()
+
+
+@pytest.fixture()
+def suppress_deprecation_warnings():
+ """
+ Use this fixture to suppress deprecation warnings in a test.
+ Useful for testing deprecated SDK features.
+ """
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", DeprecationWarning)
+ yield
+
+
+@pytest.fixture
+def get_initialization_payload():
+ def inner(request_id: str):
+ return SessionMessage( # type: ignore
+ message=JSONRPCMessage( # type: ignore
+ root=JSONRPCRequest( # type: ignore
+ jsonrpc="2.0",
+ id=request_id,
+ method="initialize",
+ params={
+ "protocolVersion": "2025-11-25",
+ "capabilities": {},
+ "clientInfo": {"name": "test-client", "version": "1.0.0"},
+ },
+ )
+ )
+ )
+
+ return inner
+
+
+@pytest.fixture
+def get_initialized_notification_payload():
+ def inner():
+ return SessionMessage( # type: ignore
+ message=JSONRPCMessage( # type: ignore
+ root=JSONRPCNotification( # type: ignore
+ jsonrpc="2.0",
+ method="notifications/initialized",
+ )
+ )
+ )
+
+ return inner
+
+
+@pytest.fixture
+def get_mcp_command_payload():
+ def inner(method: str, params, request_id: str):
+ return SessionMessage( # type: ignore
+ message=JSONRPCMessage( # type: ignore
+ root=JSONRPCRequest( # type: ignore
+ jsonrpc="2.0",
+ id=request_id,
+ method=method,
+ params=params,
+ )
+ )
+ )
+
+ return inner
+
+
+@pytest.fixture
+def stdio(
+ get_initialization_payload,
+ get_initialized_notification_payload,
+ get_mcp_command_payload,
+):
+ async def inner(server, method: str, params, request_id: str | None = None):
+ if request_id is None:
+ request_id = "1"
+
+ read_stream_writer, read_stream = create_memory_object_stream(0) # type: ignore
+ write_stream, write_stream_reader = create_memory_object_stream(0) # type: ignore
+
+ result = {}
+
+ async def run_server():
+ await server.run(
+ read_stream, write_stream, server.create_initialization_options()
+ )
+
+ async def simulate_client(tg, result):
+ init_request = get_initialization_payload("1")
+ await read_stream_writer.send(init_request)
+
+ await write_stream_reader.receive()
+
+ initialized_notification = get_initialized_notification_payload()
+ await read_stream_writer.send(initialized_notification)
+
+ request = get_mcp_command_payload(
+ method, params=params, request_id=request_id
+ )
+ await read_stream_writer.send(request)
+
+ result["response"] = await write_stream_reader.receive()
+
+ tg.cancel_scope.cancel()
+
+ async with create_task_group() as tg: # type: ignore
+ tg.start_soon(run_server)
+ tg.start_soon(simulate_client, tg, result)
+
+ return result["response"]
+
+ return inner
+
+
+@pytest.fixture()
+def json_rpc():
+ def inner(app, method: str, params, request_id: str):
+ with TestClient(app) as client: # type: ignore
+ init_response = client.post(
+ "/mcp/",
+ headers={
+ "Accept": "application/json, text/event-stream",
+ "Content-Type": "application/json",
+ },
+ json={
+ "jsonrpc": "2.0",
+ "method": "initialize",
+ "params": {
+ "clientInfo": {"name": "test-client", "version": "1.0"},
+ "protocolVersion": "2025-11-25",
+ "capabilities": {},
+ },
+ "id": request_id,
+ },
+ )
+
+ session_id = init_response.headers["mcp-session-id"]
+
+ # Notification response is mandatory.
+ # https://modelcontextprotocol.io/specification/2025-11-25/basic/lifecycle
+ client.post(
+ "/mcp/",
+ headers={
+ "Accept": "application/json, text/event-stream",
+ "Content-Type": "application/json",
+ "mcp-session-id": session_id,
+ },
+ json={
+ "jsonrpc": "2.0",
+ "method": "notifications/initialized",
+ "params": {},
+ },
+ )
+
+ response = client.post(
+ "/mcp/",
+ headers={
+ "Accept": "application/json, text/event-stream",
+ "Content-Type": "application/json",
+ "mcp-session-id": session_id,
+ },
+ json={
+ "jsonrpc": "2.0",
+ "method": method,
+ "params": params,
+ "id": request_id,
+ },
+ )
+
+ return session_id, response
+
+ return inner
+
+
+@pytest.fixture()
+def select_mcp_transactions():
+ def inner(events):
+ return [
+ event
+ for event in events
+ if event["type"] == "transaction"
+ and event["contexts"]["trace"]["op"] == "mcp.server"
+ ]
+
+ return inner
+
+
+@pytest.fixture()
+def select_transactions_with_mcp_spans():
+ def inner(events, method_name):
+ return [
+ transaction
+ for transaction in events
+ if transaction["type"] == "transaction"
+ and any(
+ span["data"].get("mcp.method.name") == method_name
+ for span in transaction.get("spans", [])
+ )
+ ]
+
+ return inner
+
+
+@pytest.fixture()
+def json_rpc_sse():
+ class StreamingASGITransport(ASGITransport):
+ """
+ Simple transport whose only purpose is to keep GET request alive in SSE connections, allowing
+ tests involving SSE interactions to run in-process.
+ """
+
+ def __init__(
+ self,
+ app: "Callable",
+ keep_sse_alive: "asyncio.Event",
+ ) -> None:
+ self.keep_sse_alive = keep_sse_alive
+ super().__init__(app)
+
+ async def handle_async_request(
+ self, request: "HttpxRequest"
+ ) -> "HttpxResponse":
+ scope = {
+ "type": "http",
+ "method": request.method,
+ "headers": [(k.lower(), v) for (k, v) in request.headers.raw],
+ "path": request.url.path,
+ "query_string": request.url.query,
+ }
+
+ is_streaming_sse = scope["method"] == "GET" and scope["path"] == "/sse"
+ if not is_streaming_sse:
+ return await super().handle_async_request(request)
+
+ request_body = b""
+ if request.content:
+ request_body = await request.aread()
+
+ body_sender, body_receiver = create_memory_object_stream[bytes](0) # type: ignore
+
+ async def receive() -> "dict[str, Any]":
+ if self.keep_sse_alive.is_set():
+ return {"type": "http.disconnect"}
+
+ await self.keep_sse_alive.wait() # Keep alive :)
+ return {
+ "type": "http.request",
+ "body": request_body,
+ "more_body": False,
+ }
+
+ async def send(message: "MutableMapping[str, Any]") -> None:
+ if message["type"] == "http.response.body":
+ body = message.get("body", b"")
+ more_body = message.get("more_body", False)
+
+ if body == b"" and not more_body:
+ return
+
+ if body:
+ await body_sender.send(body)
+
+ if not more_body:
+ await body_sender.aclose()
+
+ async def run_app():
+ await self.app(scope, receive, send)
+
+ class StreamingBodyStream(AsyncByteStream): # type: ignore
+ def __init__(self, receiver):
+ self.receiver = receiver
+
+ async def __aiter__(self):
+ try:
+ async for chunk in self.receiver:
+ yield chunk
+ except EndOfStream: # type: ignore
+ pass
+
+ stream = StreamingBodyStream(body_receiver)
+ response = HttpxResponse(status_code=200, headers=[], stream=stream) # type: ignore
+
+ asyncio.create_task(run_app())
+ return response
+
+ def parse_sse_data_package(sse_chunk):
+ sse_text = sse_chunk.decode("utf-8")
+ json_str = sse_text.split("data: ")[1]
+ return json.loads(json_str)
+
+ async def inner(
+ app, method: str, params, request_id: str, keep_sse_alive: "asyncio.Event"
+ ):
+ context = {}
+
+ stream_complete = asyncio.Event()
+ endpoint_parsed = asyncio.Event()
+
+ # https://github.com/Kludex/starlette/issues/104#issuecomment-729087925
+ async with AsyncClient( # type: ignore
+ transport=StreamingASGITransport(app=app, keep_sse_alive=keep_sse_alive),
+ base_url="http://test",
+ ) as client:
+
+ async def parse_stream():
+ async with client.stream("GET", "/sse") as stream:
+ # Read directly from stream.stream instead of aiter_bytes()
+ async for chunk in stream.stream:
+ if b"event: endpoint" in chunk:
+ sse_text = chunk.decode("utf-8")
+ url = sse_text.split("data: ")[1]
+
+ parsed = urlparse(url)
+ query_params = parse_qs(parsed.query)
+ context["session_id"] = query_params["session_id"][0]
+ endpoint_parsed.set()
+ continue
+
+ if b"event: message" in chunk and b"structuredContent" in chunk:
+ context["response"] = parse_sse_data_package(chunk)
+ break
+ elif (
+ "result" in parse_sse_data_package(chunk)
+ and "content" in parse_sse_data_package(chunk)["result"]
+ ):
+ context["response"] = parse_sse_data_package(chunk)
+ break
+
+ stream_complete.set()
+
+ task = asyncio.create_task(parse_stream())
+ await endpoint_parsed.wait()
+
+ await client.post(
+ f"/messages/?session_id={context['session_id']}",
+ headers={
+ "Content-Type": "application/json",
+ },
+ json={
+ "jsonrpc": "2.0",
+ "method": "initialize",
+ "params": {
+ "clientInfo": {"name": "test-client", "version": "1.0"},
+ "protocolVersion": "2025-11-25",
+ "capabilities": {},
+ },
+ "id": request_id,
+ },
+ )
+
+ # Notification response is mandatory.
+ # https://modelcontextprotocol.io/specification/2025-11-25/basic/lifecycle
+ await client.post(
+ f"/messages/?session_id={context['session_id']}",
+ headers={
+ "Content-Type": "application/json",
+ "mcp-session-id": context["session_id"],
+ },
+ json={
+ "jsonrpc": "2.0",
+ "method": "notifications/initialized",
+ "params": {},
+ },
+ )
+
+ await client.post(
+ f"/messages/?session_id={context['session_id']}",
+ headers={
+ "Content-Type": "application/json",
+ "mcp-session-id": context["session_id"],
+ },
+ json={
+ "jsonrpc": "2.0",
+ "method": method,
+ "params": params,
+ "id": request_id,
+ },
+ )
+
+ await stream_complete.wait()
+ keep_sse_alive.set()
+
+ return task, context["session_id"], context["response"]
+
+ return inner
+
+
+@pytest.fixture()
+def async_iterator():
+ async def inner(values):
+ for value in values:
+ yield value
+
+ return inner
+
+
+@pytest.fixture
+def server_side_event_chunks():
+ def inner(events, include_event_type=True):
+ for event in events:
+ payload = event.model_dump()
+ chunk = (
+ f"event: {payload['type']}\ndata: {json.dumps(payload)}\n\n"
+ if include_event_type
+ else f"data: {json.dumps(payload)}\n\n"
+ )
+ yield chunk.encode("utf-8")
+
+ return inner
+
+
+@pytest.fixture
+def get_model_response():
+ def inner(response_content, serialize_pydantic=False, request_headers=None):
+ if request_headers is None:
+ request_headers = {}
+
+ model_request = HttpxRequest(
+ "POST",
+ "/responses",
+ headers=request_headers,
+ )
+
+ if serialize_pydantic:
+ response_content = json.dumps(
+ response_content.model_dump(
+ by_alias=True,
+ exclude_none=True,
+ )
+ ).encode("utf-8")
+
+ response = HttpxResponse(
+ 200,
+ request=model_request,
+ content=response_content,
+ )
+
+ return response
+
+ return inner
+
+
+@pytest.fixture
+def get_rate_limit_model_response():
+ def inner(request_headers=None):
+ if request_headers is None:
+ request_headers = {}
+
+ model_request = HttpxRequest(
+ "POST",
+ "/responses",
+ headers=request_headers,
+ )
+
+ response = HttpxResponse(
+ 429,
+ request=model_request,
+ )
+
+ return response
+
+ return inner
+
+
+@pytest.fixture
+def streaming_chat_completions_model_response():
+ return [
+ openai.types.chat.ChatCompletionChunk(
+ id="chatcmpl-test",
+ object="chat.completion.chunk",
+ created=10000000,
+ model="gpt-3.5-turbo",
+ choices=[
+ openai.types.chat.chat_completion_chunk.Choice(
+ index=0,
+ delta=openai.types.chat.chat_completion_chunk.ChoiceDelta(
+ role="assistant"
+ ),
+ finish_reason=None,
+ ),
+ ],
+ ),
+ openai.types.chat.ChatCompletionChunk(
+ id="chatcmpl-test",
+ object="chat.completion.chunk",
+ created=10000000,
+ model="gpt-3.5-turbo",
+ choices=[
+ openai.types.chat.chat_completion_chunk.Choice(
+ index=0,
+ delta=openai.types.chat.chat_completion_chunk.ChoiceDelta(
+ content="Tes"
+ ),
+ finish_reason=None,
+ ),
+ ],
+ ),
+ openai.types.chat.ChatCompletionChunk(
+ id="chatcmpl-test",
+ object="chat.completion.chunk",
+ created=10000000,
+ model="gpt-3.5-turbo",
+ choices=[
+ openai.types.chat.chat_completion_chunk.Choice(
+ index=0,
+ delta=openai.types.chat.chat_completion_chunk.ChoiceDelta(
+ content="t r"
+ ),
+ finish_reason=None,
+ ),
+ ],
+ ),
+ openai.types.chat.ChatCompletionChunk(
+ id="chatcmpl-test",
+ object="chat.completion.chunk",
+ created=10000000,
+ model="gpt-3.5-turbo",
+ choices=[
+ openai.types.chat.chat_completion_chunk.Choice(
+ index=0,
+ delta=openai.types.chat.chat_completion_chunk.ChoiceDelta(
+ content="esp"
+ ),
+ finish_reason=None,
+ ),
+ ],
+ ),
+ openai.types.chat.ChatCompletionChunk(
+ id="chatcmpl-test",
+ object="chat.completion.chunk",
+ created=10000000,
+ model="gpt-3.5-turbo",
+ choices=[
+ openai.types.chat.chat_completion_chunk.Choice(
+ index=0,
+ delta=openai.types.chat.chat_completion_chunk.ChoiceDelta(
+ content="ons"
+ ),
+ finish_reason=None,
+ ),
+ ],
+ ),
+ openai.types.chat.ChatCompletionChunk(
+ id="chatcmpl-test",
+ object="chat.completion.chunk",
+ created=10000000,
+ model="gpt-3.5-turbo",
+ choices=[
+ openai.types.chat.chat_completion_chunk.Choice(
+ index=0,
+ delta=openai.types.chat.chat_completion_chunk.ChoiceDelta(
+ content="e"
+ ),
+ finish_reason=None,
+ ),
+ ],
+ ),
+ openai.types.chat.ChatCompletionChunk(
+ id="chatcmpl-test",
+ object="chat.completion.chunk",
+ created=10000000,
+ model="gpt-3.5-turbo",
+ choices=[
+ openai.types.chat.chat_completion_chunk.Choice(
+ index=0,
+ delta=openai.types.chat.chat_completion_chunk.ChoiceDelta(),
+ finish_reason="stop",
+ ),
+ ],
+ usage=openai.types.CompletionUsage(
+ prompt_tokens=10,
+ completion_tokens=20,
+ total_tokens=30,
+ ),
+ ),
+ ]
+
+
+@pytest.fixture
+def nonstreaming_chat_completions_model_response():
+ return openai.types.chat.ChatCompletion(
+ id="chatcmpl-test",
+ choices=[
+ openai.types.chat.chat_completion.Choice(
+ index=0,
+ finish_reason="stop",
+ message=openai.types.chat.ChatCompletionMessage(
+ role="assistant", content="Test response"
+ ),
+ )
+ ],
+ created=1234567890,
+ model="gpt-3.5-turbo",
+ object="chat.completion",
+ usage=openai.types.CompletionUsage(
+ prompt_tokens=10,
+ completion_tokens=20,
+ total_tokens=30,
+ ),
+ )
+
+
+@pytest.fixture
+def openai_embedding_model_response():
+ return openai.types.CreateEmbeddingResponse(
+ data=[
+ openai.types.Embedding(
+ embedding=[0.1, 0.2, 0.3],
+ index=0,
+ object="embedding",
+ )
+ ],
+ model="text-embedding-ada-002",
+ object="list",
+ usage=openai.types.create_embedding_response.Usage(
+ prompt_tokens=5,
+ total_tokens=5,
+ ),
+ )
+
+
+@pytest.fixture
+def nonstreaming_responses_model_response():
+ return openai.types.responses.Response(
+ id="resp_123",
+ output=[
+ openai.types.responses.ResponseOutputMessage(
+ id="msg_123",
+ type="message",
+ status="completed",
+ content=[
+ openai.types.responses.ResponseOutputText(
+ text="Hello, how can I help you?",
+ type="output_text",
+ annotations=[],
+ )
+ ],
+ role="assistant",
+ )
+ ],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model="gpt-4",
+ object="response",
+ usage=openai.types.responses.ResponseUsage(
+ input_tokens=10,
+ input_tokens_details=openai.types.responses.response_usage.InputTokensDetails(
+ cached_tokens=0,
+ ),
+ output_tokens=20,
+ output_tokens_details=openai.types.responses.response_usage.OutputTokensDetails(
+ reasoning_tokens=5,
+ ),
+ total_tokens=30,
+ ),
+ )
+
+
+@pytest.fixture
+def nonstreaming_anthropic_model_response():
+ return anthropic.types.Message(
+ id="msg_123",
+ type="message",
+ role="assistant",
+ model="claude-3-opus-20240229",
+ content=[
+ anthropic.types.TextBlock(
+ type="text",
+ text="Hello, how can I help you?",
+ )
+ ],
+ stop_reason="end_turn",
+ stop_sequence=None,
+ usage=anthropic.types.Usage(
+ input_tokens=10,
+ output_tokens=20,
+ ),
+ )
+
+
+@pytest.fixture
+def nonstreaming_google_genai_model_response():
+ return google.genai.types.GenerateContentResponse(
+ response_id="resp_123",
+ candidates=[
+ google.genai.types.Candidate(
+ content=google.genai.types.Content(
+ role="model",
+ parts=[
+ google.genai.types.Part(
+ text="Hello, how can I help you?",
+ )
+ ],
+ ),
+ finish_reason="STOP",
+ )
+ ],
+ model_version="gemini/gemini-pro",
+ usage_metadata=google.genai.types.GenerateContentResponseUsageMetadata(
+ prompt_token_count=10,
+ candidates_token_count=20,
+ total_token_count=30,
+ ),
+ )
+
+
+@pytest.fixture
+def responses_tool_call_model_responses():
+ def inner(
+ tool_name: str,
+ arguments: str,
+ response_model: str,
+ response_text: str,
+ response_ids: "Iterator[str]",
+ usages: "Iterator[openai.types.responses.ResponseUsage]",
+ ):
+ yield openai.types.responses.Response(
+ id=next(response_ids),
+ output=[
+ openai.types.responses.ResponseFunctionToolCall(
+ id="call_123",
+ call_id="call_123",
+ name=tool_name,
+ type="function_call",
+ arguments=arguments,
+ )
+ ],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model=response_model,
+ object="response",
+ usage=next(usages),
+ )
+
+ yield openai.types.responses.Response(
+ id=next(response_ids),
+ output=[
+ openai.types.responses.ResponseOutputMessage(
+ id="msg_final",
+ type="message",
+ status="completed",
+ content=[
+ openai.types.responses.ResponseOutputText(
+ text=response_text,
+ type="output_text",
+ annotations=[],
+ )
+ ],
+ role="assistant",
+ )
+ ],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model=response_model,
+ object="response",
+ usage=next(usages),
+ )
+
+ return inner
class MockServerRequestHandler(BaseHTTPRequestHandler):
def do_GET(self): # noqa: N802
- # Process an HTTP GET request and return a response with an HTTP 200 status.
- self.send_response(200)
+ # Process an HTTP GET request and return a response.
+ # If the path ends with /status/, return status code .
+ # Otherwise return a 200 response.
+ code = 200
+ if "/status/" in self.path:
+ code = int(self.path[-3:])
+
+ self.send_response(code)
self.end_headers()
return
@@ -598,7 +1479,109 @@ def create_mock_http_server():
mock_server_port = get_free_port()
mock_server = HTTPServer(("localhost", mock_server_port), MockServerRequestHandler)
mock_server_thread = Thread(target=mock_server.serve_forever)
- mock_server_thread.setDaemon(True)
+ mock_server_thread.daemon = True
mock_server_thread.start()
return mock_server_port
+
+
+def unpack_werkzeug_response(response):
+ # werkzeug < 2.1 returns a tuple as client response, newer versions return
+ # an object
+ try:
+ return response.get_data(), response.status, response.headers
+ except AttributeError:
+ content, status, headers = response
+ return b"".join(content), status, headers
+
+
+def werkzeug_set_cookie(client, servername, key, value):
+ # client.set_cookie has a different signature in different werkzeug versions
+ try:
+ client.set_cookie(servername, key, value)
+ except TypeError:
+ client.set_cookie(key, value)
+
+
+@contextmanager
+def patch_start_tracing_child(
+ fake_transaction_is_none: bool = False,
+) -> "Iterator[Optional[mock.MagicMock]]":
+ if not fake_transaction_is_none:
+ fake_transaction = mock.MagicMock()
+ fake_start_child = mock.MagicMock()
+ fake_transaction.start_child = fake_start_child
+ else:
+ fake_transaction = None
+ fake_start_child = None
+
+ with mock.patch(
+ "sentry_sdk.tracing_utils.get_current_span", return_value=fake_transaction
+ ):
+ yield fake_start_child
+
+
+class ApproxDict(dict):
+ def __eq__(self, other):
+ # For an ApproxDict to equal another dict, the other dict just needs to contain
+ # all the keys from the ApproxDict with the same values.
+ #
+ # The other dict may contain additional keys with any value.
+ return all(key in other and other[key] == value for key, value in self.items())
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+
+CapturedData = namedtuple("CapturedData", ["path", "event", "envelope", "compressed"])
+
+
+class CapturingServer(WSGIServer):
+ def __init__(self, host="127.0.0.1", port=0, ssl_context=None):
+ WSGIServer.__init__(self, host, port, self, ssl_context=ssl_context)
+ self.code = 204
+ self.headers = {}
+ self.captured = []
+
+ def respond_with(self, code=200, headers=None):
+ self.code = code
+ if headers:
+ self.headers = headers
+
+ def clear_captured(self):
+ del self.captured[:]
+
+ def __call__(self, environ, start_response):
+ """
+ This is the WSGI application.
+ """
+ request = Request(environ)
+ event = envelope = None
+ content_encoding = request.headers.get("content-encoding")
+ if content_encoding == "gzip":
+ rdr = gzip.GzipFile(fileobj=io.BytesIO(request.data))
+ compressed = True
+ elif content_encoding == "br":
+ rdr = io.BytesIO(brotli.decompress(request.data))
+ compressed = True
+ else:
+ rdr = io.BytesIO(request.data)
+ compressed = False
+
+ if request.mimetype == "application/json":
+ event = parse_json(rdr.read())
+ else:
+ envelope = Envelope.deserialize_from(rdr)
+
+ self.captured.append(
+ CapturedData(
+ path=request.path,
+ event=event,
+ envelope=envelope,
+ compressed=compressed,
+ )
+ )
+
+ response = Response(status=self.code)
+ response.headers.extend(self.headers)
+ return response(environ, start_response)
diff --git a/tests/integrations/aiohttp/__init__.py b/tests/integrations/aiohttp/__init__.py
index b4711aadba..a585c11e34 100644
--- a/tests/integrations/aiohttp/__init__.py
+++ b/tests/integrations/aiohttp/__init__.py
@@ -1,3 +1,9 @@
+import os
+import sys
import pytest
-aiohttp = pytest.importorskip("aiohttp")
+pytest.importorskip("aiohttp")
+
+# Load `aiohttp_helpers` into the module search path to test request source path names relative to module. See
+# `test_request_source_with_module_in_search_path`
+sys.path.insert(0, os.path.join(os.path.dirname(__file__)))
diff --git a/tests/integrations/aiohttp/aiohttp_helpers/__init__.py b/tests/integrations/aiohttp/aiohttp_helpers/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/integrations/aiohttp/aiohttp_helpers/helpers.py b/tests/integrations/aiohttp/aiohttp_helpers/helpers.py
new file mode 100644
index 0000000000..86a6fa39e3
--- /dev/null
+++ b/tests/integrations/aiohttp/aiohttp_helpers/helpers.py
@@ -0,0 +1,2 @@
+async def get_request_with_client(client, url):
+ await client.get(url)
diff --git a/tests/integrations/aiohttp/test_aiohttp.py b/tests/integrations/aiohttp/test_aiohttp.py
index 7e49a285c3..849f9d017b 100644
--- a/tests/integrations/aiohttp/test_aiohttp.py
+++ b/tests/integrations/aiohttp/test_aiohttp.py
@@ -1,21 +1,32 @@
+import os
+import datetime
import asyncio
import json
+
from contextlib import suppress
+from unittest import mock
import pytest
+
from aiohttp import web
from aiohttp.client import ServerDisconnectedError
from aiohttp.web_request import Request
+from aiohttp.web_exceptions import (
+ HTTPInternalServerError,
+ HTTPNetworkAuthenticationRequired,
+ HTTPBadRequest,
+ HTTPNotFound,
+ HTTPUnavailableForLegalReasons,
+)
-from sentry_sdk.integrations.aiohttp import AioHttpIntegration
-
-try:
- from unittest import mock # python 3.3 and above
-except ImportError:
- import mock # python < 3.3
+from sentry_sdk import capture_message, start_transaction
+from sentry_sdk.integrations.aiohttp import AioHttpIntegration, create_trace_config
+from sentry_sdk.consts import SPANDATA
+from tests.conftest import ApproxDict
-async def test_basic(sentry_init, aiohttp_client, loop, capture_events):
+@pytest.mark.asyncio
+async def test_basic(sentry_init, aiohttp_client, capture_events):
sentry_init(integrations=[AioHttpIntegration()])
async def hello(request):
@@ -49,13 +60,16 @@ async def hello(request):
assert request["url"] == "http://{host}/".format(host=host)
assert request["headers"] == {
"Accept": "*/*",
- "Accept-Encoding": "gzip, deflate",
+ "Accept-Encoding": mock.ANY,
"Host": host,
"User-Agent": request["headers"]["User-Agent"],
+ "baggage": mock.ANY,
+ "sentry-trace": mock.ANY,
}
-async def test_post_body_not_read(sentry_init, aiohttp_client, loop, capture_events):
+@pytest.mark.asyncio
+async def test_post_body_not_read(sentry_init, aiohttp_client, capture_events):
from sentry_sdk.integrations.aiohttp import BODY_NOT_READ_MESSAGE
sentry_init(integrations=[AioHttpIntegration()])
@@ -84,7 +98,8 @@ async def hello(request):
assert request["data"] == BODY_NOT_READ_MESSAGE
-async def test_post_body_read(sentry_init, aiohttp_client, loop, capture_events):
+@pytest.mark.asyncio
+async def test_post_body_read(sentry_init, aiohttp_client, capture_events):
sentry_init(integrations=[AioHttpIntegration()])
body = {"some": "value"}
@@ -112,7 +127,8 @@ async def hello(request):
assert request["data"] == json.dumps(body)
-async def test_403_not_captured(sentry_init, aiohttp_client, loop, capture_events):
+@pytest.mark.asyncio
+async def test_403_not_captured(sentry_init, aiohttp_client, capture_events):
sentry_init(integrations=[AioHttpIntegration()])
async def hello(request):
@@ -130,8 +146,9 @@ async def hello(request):
assert not events
+@pytest.mark.asyncio
async def test_cancelled_error_not_captured(
- sentry_init, aiohttp_client, loop, capture_events
+ sentry_init, aiohttp_client, capture_events
):
sentry_init(integrations=[AioHttpIntegration()])
@@ -152,7 +169,8 @@ async def hello(request):
assert not events
-async def test_half_initialized(sentry_init, aiohttp_client, loop, capture_events):
+@pytest.mark.asyncio
+async def test_half_initialized(sentry_init, aiohttp_client, capture_events):
sentry_init(integrations=[AioHttpIntegration()])
sentry_init()
@@ -171,7 +189,8 @@ async def hello(request):
assert events == []
-async def test_tracing(sentry_init, aiohttp_client, loop, capture_events):
+@pytest.mark.asyncio
+async def test_tracing(sentry_init, aiohttp_client, capture_events):
sentry_init(integrations=[AioHttpIntegration()], traces_sample_rate=1.0)
async def hello(request):
@@ -195,6 +214,7 @@ async def hello(request):
)
+@pytest.mark.asyncio
@pytest.mark.parametrize(
"url,transaction_style,expected_transaction,expected_source",
[
@@ -245,11 +265,42 @@ async def hello(request):
assert event["transaction_info"] == {"source": expected_source}
+@pytest.mark.tests_internal_exceptions
+@pytest.mark.asyncio
+async def test_tracing_unparseable_url(sentry_init, aiohttp_client, capture_events):
+ sentry_init(integrations=[AioHttpIntegration()], traces_sample_rate=1.0)
+
+ async def hello(request):
+ return web.Response(text="hello")
+
+ app = web.Application()
+ app.router.add_get("/", hello)
+
+ events = capture_events()
+
+ client = await aiohttp_client(app)
+ with mock.patch(
+ "sentry_sdk.integrations.aiohttp.parse_url", side_effect=ValueError
+ ):
+ resp = await client.get("/")
+
+ assert resp.status == 200
+
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert (
+ event["transaction"]
+ == "tests.integrations.aiohttp.test_aiohttp.test_tracing_unparseable_url..hello"
+ )
+
+
+@pytest.mark.asyncio
async def test_traces_sampler_gets_request_object_in_sampling_context(
sentry_init,
aiohttp_client,
- DictionaryContaining, # noqa:N803
- ObjectDescribedBy,
+ DictionaryContaining, # noqa: N803
+ ObjectDescribedBy, # noqa: N803
):
traces_sampler = mock.Mock()
sentry_init(
@@ -275,3 +326,808 @@ async def kangaroo_handler(request):
}
)
)
+
+
+@pytest.mark.asyncio
+async def test_has_trace_if_performance_enabled(
+ sentry_init, aiohttp_client, capture_events
+):
+ sentry_init(integrations=[AioHttpIntegration()], traces_sample_rate=1.0)
+
+ async def hello(request):
+ capture_message("It's a good day to try dividing by 0")
+ 1 / 0
+
+ app = web.Application()
+ app.router.add_get("/", hello)
+
+ events = capture_events()
+
+ client = await aiohttp_client(app)
+ resp = await client.get("/")
+ assert resp.status == 500
+
+ msg_event, error_event, transaction_event = events
+
+ assert msg_event["contexts"]["trace"]
+ assert "trace_id" in msg_event["contexts"]["trace"]
+
+ assert error_event["contexts"]["trace"]
+ assert "trace_id" in error_event["contexts"]["trace"]
+
+ assert transaction_event["contexts"]["trace"]
+ assert "trace_id" in transaction_event["contexts"]["trace"]
+
+ assert (
+ error_event["contexts"]["trace"]["trace_id"]
+ == transaction_event["contexts"]["trace"]["trace_id"]
+ == msg_event["contexts"]["trace"]["trace_id"]
+ )
+
+
+@pytest.mark.asyncio
+async def test_has_trace_if_performance_disabled(
+ sentry_init, aiohttp_client, capture_events
+):
+ sentry_init(integrations=[AioHttpIntegration()])
+
+ async def hello(request):
+ capture_message("It's a good day to try dividing by 0")
+ 1 / 0
+
+ app = web.Application()
+ app.router.add_get("/", hello)
+
+ events = capture_events()
+
+ client = await aiohttp_client(app)
+ resp = await client.get("/")
+ assert resp.status == 500
+
+ msg_event, error_event = events
+
+ assert msg_event["contexts"]["trace"]
+ assert "trace_id" in msg_event["contexts"]["trace"]
+
+ assert error_event["contexts"]["trace"]
+ assert "trace_id" in error_event["contexts"]["trace"]
+
+ assert (
+ error_event["contexts"]["trace"]["trace_id"]
+ == msg_event["contexts"]["trace"]["trace_id"]
+ )
+
+
+@pytest.mark.asyncio
+async def test_trace_from_headers_if_performance_enabled(
+ sentry_init, aiohttp_client, capture_events
+):
+ sentry_init(integrations=[AioHttpIntegration()], traces_sample_rate=1.0)
+
+ async def hello(request):
+ capture_message("It's a good day to try dividing by 0")
+ 1 / 0
+
+ app = web.Application()
+ app.router.add_get("/", hello)
+
+ events = capture_events()
+
+ # The aiohttp_client is instrumented so will generate the sentry-trace header and add request.
+ # Get the sentry-trace header from the request so we can later compare with transaction events.
+ client = await aiohttp_client(app)
+ with start_transaction():
+ # Headers are only added to the span if there is an active transaction
+ resp = await client.get("/")
+
+ sentry_trace_header = resp.request_info.headers.get("sentry-trace")
+ trace_id = sentry_trace_header.split("-")[0]
+
+ assert resp.status == 500
+
+ # Last item is the custom transaction event wrapping `client.get("/")`
+ msg_event, error_event, transaction_event, _ = events
+
+ assert msg_event["contexts"]["trace"]
+ assert "trace_id" in msg_event["contexts"]["trace"]
+
+ assert error_event["contexts"]["trace"]
+ assert "trace_id" in error_event["contexts"]["trace"]
+
+ assert transaction_event["contexts"]["trace"]
+ assert "trace_id" in transaction_event["contexts"]["trace"]
+
+ assert msg_event["contexts"]["trace"]["trace_id"] == trace_id
+ assert error_event["contexts"]["trace"]["trace_id"] == trace_id
+ assert transaction_event["contexts"]["trace"]["trace_id"] == trace_id
+
+
+@pytest.mark.asyncio
+async def test_trace_from_headers_if_performance_disabled(
+ sentry_init, aiohttp_client, capture_events
+):
+ sentry_init(integrations=[AioHttpIntegration()])
+
+ async def hello(request):
+ capture_message("It's a good day to try dividing by 0")
+ 1 / 0
+
+ app = web.Application()
+ app.router.add_get("/", hello)
+
+ events = capture_events()
+
+ # The aiohttp_client is instrumented so will generate the sentry-trace header and add request.
+ # Get the sentry-trace header from the request so we can later compare with transaction events.
+ client = await aiohttp_client(app)
+ resp = await client.get("/")
+ sentry_trace_header = resp.request_info.headers.get("sentry-trace")
+ trace_id = sentry_trace_header.split("-")[0]
+
+ assert resp.status == 500
+
+ msg_event, error_event = events
+
+ assert msg_event["contexts"]["trace"]
+ assert "trace_id" in msg_event["contexts"]["trace"]
+
+ assert error_event["contexts"]["trace"]
+ assert "trace_id" in error_event["contexts"]["trace"]
+
+ assert msg_event["contexts"]["trace"]["trace_id"] == trace_id
+ assert error_event["contexts"]["trace"]["trace_id"] == trace_id
+
+
+@pytest.mark.asyncio
+async def test_crumb_capture(
+ sentry_init, aiohttp_raw_server, aiohttp_client, capture_events
+):
+ def before_breadcrumb(crumb, hint):
+ crumb["data"]["extra"] = "foo"
+ return crumb
+
+ sentry_init(
+ integrations=[AioHttpIntegration()], before_breadcrumb=before_breadcrumb
+ )
+
+ async def handler(request):
+ return web.Response(text="OK")
+
+ raw_server = await aiohttp_raw_server(handler)
+
+ with start_transaction():
+ events = capture_events()
+
+ client = await aiohttp_client(raw_server)
+ resp = await client.get("/")
+ assert resp.status == 200
+ capture_message("Testing!")
+
+ (event,) = events
+
+ crumb = event["breadcrumbs"]["values"][0]
+ assert crumb["type"] == "http"
+ assert crumb["category"] == "httplib"
+ assert crumb["data"] == ApproxDict(
+ {
+ "url": "http://127.0.0.1:{}/".format(raw_server.port),
+ "http.fragment": "",
+ "http.method": "GET",
+ "http.query": "",
+ "http.response.status_code": 200,
+ "reason": "OK",
+ "extra": "foo",
+ }
+ )
+
+
+@pytest.mark.parametrize(
+ "status_code,level",
+ [
+ (200, None),
+ (301, None),
+ (403, "warning"),
+ (405, "warning"),
+ (500, "error"),
+ ],
+)
+@pytest.mark.asyncio
+async def test_crumb_capture_client_error(
+ sentry_init,
+ aiohttp_raw_server,
+ aiohttp_client,
+ capture_events,
+ status_code,
+ level,
+):
+ sentry_init(integrations=[AioHttpIntegration()])
+
+ async def handler(request):
+ return web.Response(status=status_code)
+
+ raw_server = await aiohttp_raw_server(handler)
+
+ with start_transaction():
+ events = capture_events()
+
+ client = await aiohttp_client(raw_server)
+ resp = await client.get("/")
+ assert resp.status == status_code
+ capture_message("Testing!")
+
+ (event,) = events
+
+ crumb = event["breadcrumbs"]["values"][0]
+ assert crumb["type"] == "http"
+ if level is None:
+ assert "level" not in crumb
+ else:
+ assert crumb["level"] == level
+ assert crumb["category"] == "httplib"
+ assert crumb["data"] == ApproxDict(
+ {
+ "url": "http://127.0.0.1:{}/".format(raw_server.port),
+ "http.fragment": "",
+ "http.method": "GET",
+ "http.query": "",
+ "http.response.status_code": status_code,
+ }
+ )
+
+
+@pytest.mark.asyncio
+async def test_outgoing_trace_headers(sentry_init, aiohttp_raw_server, aiohttp_client):
+ sentry_init(
+ integrations=[AioHttpIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ async def handler(request):
+ return web.Response(text="OK")
+
+ raw_server = await aiohttp_raw_server(handler)
+
+ with start_transaction(
+ name="/interactions/other-dogs/new-dog",
+ op="greeting.sniff",
+ # make trace_id difference between transactions
+ trace_id="0123456789012345678901234567890",
+ ) as transaction:
+ client = await aiohttp_client(raw_server)
+ resp = await client.get("/")
+ request_span = transaction._span_recorder.spans[-1]
+
+ assert resp.request_info.headers[
+ "sentry-trace"
+ ] == "{trace_id}-{parent_span_id}-{sampled}".format(
+ trace_id=transaction.trace_id,
+ parent_span_id=request_span.span_id,
+ sampled=1,
+ )
+
+
+@pytest.mark.asyncio
+async def test_outgoing_trace_headers_append_to_baggage(
+ sentry_init, aiohttp_raw_server, aiohttp_client
+):
+ sentry_init(
+ integrations=[AioHttpIntegration()],
+ traces_sample_rate=1.0,
+ release="d08ebdb9309e1b004c6f52202de58a09c2268e42",
+ )
+
+ async def handler(request):
+ return web.Response(text="OK")
+
+ raw_server = await aiohttp_raw_server(handler)
+
+ with mock.patch("sentry_sdk.tracing_utils.Random.randrange", return_value=500000):
+ with start_transaction(
+ name="/interactions/other-dogs/new-dog",
+ op="greeting.sniff",
+ trace_id="0123456789012345678901234567890",
+ ):
+ client = await aiohttp_client(raw_server)
+ resp = await client.get("/", headers={"bagGage": "custom=value"})
+
+ assert (
+ resp.request_info.headers["baggage"]
+ == "custom=value,sentry-trace_id=0123456789012345678901234567890,sentry-sample_rand=0.500000,sentry-environment=production,sentry-release=d08ebdb9309e1b004c6f52202de58a09c2268e42,sentry-transaction=/interactions/other-dogs/new-dog,sentry-sample_rate=1.0,sentry-sampled=true"
+ )
+
+
+@pytest.mark.asyncio
+async def test_request_source_disabled(
+ sentry_init,
+ aiohttp_raw_server,
+ aiohttp_client,
+ capture_events,
+):
+ sentry_options = {
+ "integrations": [AioHttpIntegration()],
+ "traces_sample_rate": 1.0,
+ "enable_http_request_source": False,
+ "http_request_source_threshold_ms": 0,
+ }
+
+ sentry_init(**sentry_options)
+
+ # server for making span request
+ async def handler(request):
+ return web.Response(text="OK")
+
+ raw_server = await aiohttp_raw_server(handler)
+
+ async def hello(request):
+ span_client = await aiohttp_client(raw_server)
+ await span_client.get("/")
+ return web.Response(text="hello")
+
+ app = web.Application()
+ app.router.add_get(r"/", hello)
+
+ events = capture_events()
+
+ client = await aiohttp_client(app)
+ await client.get("/")
+
+ (event,) = events
+
+ span = event["spans"][-1]
+ assert span["description"].startswith("GET")
+
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO not in data
+ assert SPANDATA.CODE_NAMESPACE not in data
+ assert SPANDATA.CODE_FILEPATH not in data
+ assert SPANDATA.CODE_FUNCTION not in data
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("enable_http_request_source", [None, True])
+async def test_request_source_enabled(
+ sentry_init,
+ aiohttp_raw_server,
+ aiohttp_client,
+ capture_events,
+ enable_http_request_source,
+):
+ sentry_options = {
+ "integrations": [AioHttpIntegration()],
+ "traces_sample_rate": 1.0,
+ "http_request_source_threshold_ms": 0,
+ }
+ if enable_http_request_source is not None:
+ sentry_options["enable_http_request_source"] = enable_http_request_source
+
+ sentry_init(**sentry_options)
+
+ # server for making span request
+ async def handler(request):
+ return web.Response(text="OK")
+
+ raw_server = await aiohttp_raw_server(handler)
+
+ async def hello(request):
+ span_client = await aiohttp_client(raw_server)
+ await span_client.get("/")
+ return web.Response(text="hello")
+
+ app = web.Application()
+ app.router.add_get(r"/", hello)
+
+ events = capture_events()
+
+ client = await aiohttp_client(app)
+ await client.get("/")
+
+ (event,) = events
+
+ span = event["spans"][-1]
+ assert span["description"].startswith("GET")
+
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO in data
+ assert SPANDATA.CODE_NAMESPACE in data
+ assert SPANDATA.CODE_FILEPATH in data
+ assert SPANDATA.CODE_FUNCTION in data
+
+
+@pytest.mark.asyncio
+async def test_request_source(
+ sentry_init, aiohttp_raw_server, aiohttp_client, capture_events
+):
+ sentry_init(
+ integrations=[AioHttpIntegration()],
+ traces_sample_rate=1.0,
+ enable_http_request_source=True,
+ http_request_source_threshold_ms=0,
+ )
+
+ # server for making span request
+ async def handler(request):
+ return web.Response(text="OK")
+
+ raw_server = await aiohttp_raw_server(handler)
+
+ async def handler_with_outgoing_request(request):
+ span_client = await aiohttp_client(raw_server)
+ await span_client.get("/")
+ return web.Response(text="hello")
+
+ app = web.Application()
+ app.router.add_get(r"/", handler_with_outgoing_request)
+
+ events = capture_events()
+
+ client = await aiohttp_client(app)
+ await client.get("/")
+
+ (event,) = events
+
+ span = event["spans"][-1]
+ assert span["description"].startswith("GET")
+
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO in data
+ assert SPANDATA.CODE_NAMESPACE in data
+ assert SPANDATA.CODE_FILEPATH in data
+ assert SPANDATA.CODE_FUNCTION in data
+
+ assert type(data.get(SPANDATA.CODE_LINENO)) == int
+ assert data.get(SPANDATA.CODE_LINENO) > 0
+ assert (
+ data.get(SPANDATA.CODE_NAMESPACE) == "tests.integrations.aiohttp.test_aiohttp"
+ )
+ assert data.get(SPANDATA.CODE_FILEPATH).endswith(
+ "tests/integrations/aiohttp/test_aiohttp.py"
+ )
+
+ is_relative_path = data.get(SPANDATA.CODE_FILEPATH)[0] != os.sep
+ assert is_relative_path
+
+ assert data.get(SPANDATA.CODE_FUNCTION) == "handler_with_outgoing_request"
+
+
+@pytest.mark.asyncio
+async def test_request_source_with_module_in_search_path(
+ sentry_init, aiohttp_raw_server, aiohttp_client, capture_events
+):
+ """
+ Test that request source is relative to the path of the module it ran in
+ """
+ sentry_init(
+ integrations=[AioHttpIntegration()],
+ traces_sample_rate=1.0,
+ enable_http_request_source=True,
+ http_request_source_threshold_ms=0,
+ )
+
+ # server for making span request
+ async def handler(request):
+ return web.Response(text="OK")
+
+ raw_server = await aiohttp_raw_server(handler)
+
+ from aiohttp_helpers.helpers import get_request_with_client
+
+ async def handler_with_outgoing_request(request):
+ span_client = await aiohttp_client(raw_server)
+ await get_request_with_client(span_client, "/")
+ return web.Response(text="hello")
+
+ app = web.Application()
+ app.router.add_get(r"/", handler_with_outgoing_request)
+
+ events = capture_events()
+
+ client = await aiohttp_client(app)
+ await client.get("/")
+
+ (event,) = events
+
+ span = event["spans"][-1]
+ assert span["description"].startswith("GET")
+
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO in data
+ assert SPANDATA.CODE_NAMESPACE in data
+ assert SPANDATA.CODE_FILEPATH in data
+ assert SPANDATA.CODE_FUNCTION in data
+
+ assert type(data.get(SPANDATA.CODE_LINENO)) == int
+ assert data.get(SPANDATA.CODE_LINENO) > 0
+ assert data.get(SPANDATA.CODE_NAMESPACE) == "aiohttp_helpers.helpers"
+ assert data.get(SPANDATA.CODE_FILEPATH) == "aiohttp_helpers/helpers.py"
+
+ is_relative_path = data.get(SPANDATA.CODE_FILEPATH)[0] != os.sep
+ assert is_relative_path
+
+ assert data.get(SPANDATA.CODE_FUNCTION) == "get_request_with_client"
+
+
+@pytest.mark.asyncio
+async def test_no_request_source_if_duration_too_short(
+ sentry_init, aiohttp_raw_server, aiohttp_client, capture_events
+):
+ sentry_init(
+ integrations=[AioHttpIntegration()],
+ traces_sample_rate=1.0,
+ enable_http_request_source=True,
+ http_request_source_threshold_ms=100,
+ )
+
+ # server for making span request
+ async def handler(request):
+ return web.Response(text="OK")
+
+ raw_server = await aiohttp_raw_server(handler)
+
+ async def handler_with_outgoing_request(request):
+ span_client = await aiohttp_client(raw_server)
+ await span_client.get("/")
+ return web.Response(text="hello")
+
+ app = web.Application()
+ app.router.add_get(r"/", handler_with_outgoing_request)
+
+ events = capture_events()
+
+ def fake_create_trace_context(*args, **kwargs):
+ trace_context = create_trace_config()
+
+ async def overwrite_timestamps(session, trace_config_ctx, params):
+ span = trace_config_ctx.span
+ span.start_timestamp = datetime.datetime(2024, 1, 1, microsecond=0)
+ span.timestamp = datetime.datetime(2024, 1, 1, microsecond=99999)
+
+ trace_context.on_request_end.insert(0, overwrite_timestamps)
+
+ return trace_context
+
+ with mock.patch(
+ "sentry_sdk.integrations.aiohttp.create_trace_config",
+ fake_create_trace_context,
+ ):
+ client = await aiohttp_client(app)
+ await client.get("/")
+
+ (event,) = events
+
+ span = event["spans"][-1]
+ assert span["description"].startswith("GET")
+
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO not in data
+ assert SPANDATA.CODE_NAMESPACE not in data
+ assert SPANDATA.CODE_FILEPATH not in data
+ assert SPANDATA.CODE_FUNCTION not in data
+
+
+@pytest.mark.asyncio
+async def test_request_source_if_duration_over_threshold(
+ sentry_init, aiohttp_raw_server, aiohttp_client, capture_events
+):
+ sentry_init(
+ integrations=[AioHttpIntegration()],
+ traces_sample_rate=1.0,
+ enable_http_request_source=True,
+ http_request_source_threshold_ms=100,
+ )
+
+ # server for making span request
+ async def handler(request):
+ return web.Response(text="OK")
+
+ raw_server = await aiohttp_raw_server(handler)
+
+ async def handler_with_outgoing_request(request):
+ span_client = await aiohttp_client(raw_server)
+ await span_client.get("/")
+ return web.Response(text="hello")
+
+ app = web.Application()
+ app.router.add_get(r"/", handler_with_outgoing_request)
+
+ events = capture_events()
+
+ def fake_create_trace_context(*args, **kwargs):
+ trace_context = create_trace_config()
+
+ async def overwrite_timestamps(session, trace_config_ctx, params):
+ span = trace_config_ctx.span
+ span.start_timestamp = datetime.datetime(2024, 1, 1, microsecond=0)
+ span.timestamp = datetime.datetime(2024, 1, 1, microsecond=100001)
+
+ trace_context.on_request_end.insert(0, overwrite_timestamps)
+
+ return trace_context
+
+ with mock.patch(
+ "sentry_sdk.integrations.aiohttp.create_trace_config",
+ fake_create_trace_context,
+ ):
+ client = await aiohttp_client(app)
+ await client.get("/")
+
+ (event,) = events
+
+ span = event["spans"][-1]
+ assert span["description"].startswith("GET")
+
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO in data
+ assert SPANDATA.CODE_NAMESPACE in data
+ assert SPANDATA.CODE_FILEPATH in data
+ assert SPANDATA.CODE_FUNCTION in data
+
+ assert type(data.get(SPANDATA.CODE_LINENO)) == int
+ assert data.get(SPANDATA.CODE_LINENO) > 0
+ assert (
+ data.get(SPANDATA.CODE_NAMESPACE) == "tests.integrations.aiohttp.test_aiohttp"
+ )
+ assert data.get(SPANDATA.CODE_FILEPATH).endswith(
+ "tests/integrations/aiohttp/test_aiohttp.py"
+ )
+
+ is_relative_path = data.get(SPANDATA.CODE_FILEPATH)[0] != os.sep
+ assert is_relative_path
+
+ assert data.get(SPANDATA.CODE_FUNCTION) == "handler_with_outgoing_request"
+
+
+@pytest.mark.asyncio
+async def test_span_origin(
+ sentry_init,
+ aiohttp_raw_server,
+ aiohttp_client,
+ capture_events,
+):
+ sentry_init(
+ integrations=[AioHttpIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ # server for making span request
+ async def handler(request):
+ return web.Response(text="OK")
+
+ raw_server = await aiohttp_raw_server(handler)
+
+ async def hello(request):
+ span_client = await aiohttp_client(raw_server)
+ await span_client.get("/")
+ return web.Response(text="hello")
+
+ app = web.Application()
+ app.router.add_get(r"/", hello)
+
+ events = capture_events()
+
+ client = await aiohttp_client(app)
+ await client.get("/")
+
+ (event,) = events
+ assert event["contexts"]["trace"]["origin"] == "auto.http.aiohttp"
+ assert event["spans"][0]["origin"] == "auto.http.aiohttp"
+
+
+@pytest.mark.parametrize(
+ ("integration_kwargs", "exception_to_raise", "should_capture"),
+ (
+ ({}, None, False),
+ ({}, HTTPBadRequest, False),
+ (
+ {},
+ HTTPUnavailableForLegalReasons(None),
+ False,
+ ), # Highest 4xx status code (451)
+ ({}, HTTPInternalServerError, True),
+ ({}, HTTPNetworkAuthenticationRequired, True), # Highest 5xx status code (511)
+ ({"failed_request_status_codes": set()}, HTTPInternalServerError, False),
+ (
+ {"failed_request_status_codes": set()},
+ HTTPNetworkAuthenticationRequired,
+ False,
+ ),
+ ({"failed_request_status_codes": {404, *range(500, 600)}}, HTTPNotFound, True),
+ (
+ {"failed_request_status_codes": {404, *range(500, 600)}},
+ HTTPInternalServerError,
+ True,
+ ),
+ (
+ {"failed_request_status_codes": {404, *range(500, 600)}},
+ HTTPBadRequest,
+ False,
+ ),
+ ),
+)
+@pytest.mark.asyncio
+async def test_failed_request_status_codes(
+ sentry_init,
+ aiohttp_client,
+ capture_events,
+ integration_kwargs,
+ exception_to_raise,
+ should_capture,
+):
+ sentry_init(integrations=[AioHttpIntegration(**integration_kwargs)])
+ events = capture_events()
+
+ async def handle(_):
+ if exception_to_raise is not None:
+ raise exception_to_raise
+ else:
+ return web.Response(status=200)
+
+ app = web.Application()
+ app.router.add_get("/", handle)
+
+ client = await aiohttp_client(app)
+ resp = await client.get("/")
+
+ expected_status = (
+ 200 if exception_to_raise is None else exception_to_raise.status_code
+ )
+ assert resp.status == expected_status
+
+ if should_capture:
+ (event,) = events
+ assert event["exception"]["values"][0]["type"] == exception_to_raise.__name__
+ else:
+ assert not events
+
+
+@pytest.mark.asyncio
+async def test_failed_request_status_codes_with_returned_status(
+ sentry_init, aiohttp_client, capture_events
+):
+ """
+ Returning a web.Response with a failed_request_status_code should not be reported to Sentry.
+ """
+ sentry_init(integrations=[AioHttpIntegration(failed_request_status_codes={500})])
+ events = capture_events()
+
+ async def handle(_):
+ return web.Response(status=500)
+
+ app = web.Application()
+ app.router.add_get("/", handle)
+
+ client = await aiohttp_client(app)
+ resp = await client.get("/")
+
+ assert resp.status == 500
+ assert not events
+
+
+@pytest.mark.asyncio
+async def test_failed_request_status_codes_non_http_exception(
+ sentry_init, aiohttp_client, capture_events
+):
+ """
+ If an exception, which is not an instance of HTTPException, is raised, it should be captured, even if
+ failed_request_status_codes is empty.
+ """
+ sentry_init(integrations=[AioHttpIntegration(failed_request_status_codes=set())])
+ events = capture_events()
+
+ async def handle(_):
+ 1 / 0
+
+ app = web.Application()
+ app.router.add_get("/", handle)
+
+ client = await aiohttp_client(app)
+ resp = await client.get("/")
+ assert resp.status == 500
+
+ (event,) = events
+ assert event["exception"]["values"][0]["type"] == "ZeroDivisionError"
diff --git a/tests/integrations/anthropic/__init__.py b/tests/integrations/anthropic/__init__.py
new file mode 100644
index 0000000000..29ac4e6ff4
--- /dev/null
+++ b/tests/integrations/anthropic/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("anthropic")
diff --git a/tests/integrations/anthropic/test_anthropic.py b/tests/integrations/anthropic/test_anthropic.py
new file mode 100644
index 0000000000..e86f7e1fa9
--- /dev/null
+++ b/tests/integrations/anthropic/test_anthropic.py
@@ -0,0 +1,4428 @@
+import pytest
+from unittest import mock
+import json
+
+try:
+ from unittest.mock import AsyncMock
+except ImportError:
+
+ class AsyncMock(mock.MagicMock):
+ async def __call__(self, *args, **kwargs):
+ return super(AsyncMock, self).__call__(*args, **kwargs)
+
+
+from anthropic import Anthropic, AnthropicError, AsyncAnthropic
+from anthropic.types import MessageDeltaUsage, TextDelta, Usage
+from anthropic.types.content_block_delta_event import ContentBlockDeltaEvent
+from anthropic.types.content_block_start_event import ContentBlockStartEvent
+from anthropic.types.content_block_stop_event import ContentBlockStopEvent
+from anthropic.types.message import Message
+from anthropic.types.message_delta_event import MessageDeltaEvent
+from anthropic.types.message_start_event import MessageStartEvent
+
+try:
+ from anthropic.types import ErrorResponse, OverloadedError
+ from anthropic import APIStatusError
+except ImportError:
+ ErrorResponse = None
+ OverloadedError = None
+ APIStatusError = None
+
+try:
+ from anthropic.types import InputJSONDelta
+except ImportError:
+ try:
+ from anthropic.types import InputJsonDelta as InputJSONDelta
+ except ImportError:
+ pass
+
+try:
+ from anthropic.lib.streaming import TextEvent
+except ImportError:
+ TextEvent = None
+
+try:
+ # 0.27+
+ from anthropic.types.raw_message_delta_event import Delta
+ from anthropic.types.tool_use_block import ToolUseBlock
+except ImportError:
+ # pre 0.27
+ from anthropic.types.message_delta_event import Delta
+
+try:
+ from anthropic.types.text_block import TextBlock
+except ImportError:
+ from anthropic.types.content_block import ContentBlock as TextBlock
+
+from sentry_sdk import start_transaction, start_span
+from sentry_sdk._types import BLOB_DATA_SUBSTITUTE
+from sentry_sdk.consts import OP, SPANDATA
+from sentry_sdk.integrations.anthropic import (
+ AnthropicIntegration,
+ _set_output_data,
+ _collect_ai_data,
+ _transform_anthropic_content_block,
+ _RecordedUsage,
+)
+from sentry_sdk.ai.utils import transform_content_part, transform_message_content
+from sentry_sdk.utils import package_version
+
+
+ANTHROPIC_VERSION = package_version("anthropic")
+
+EXAMPLE_MESSAGE = Message(
+ id="msg_01XFDUDYJgAACzvnptvVoYEL",
+ model="model",
+ role="assistant",
+ content=[TextBlock(type="text", text="Hi, I'm Claude.")],
+ type="message",
+ stop_reason="end_turn",
+ usage=Usage(input_tokens=10, output_tokens=20),
+)
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+def test_nonstreaming_create_message(
+ sentry_init, capture_events, send_default_pii, include_prompts
+):
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+ client = Anthropic(api_key="z")
+ client.messages._post = mock.Mock(return_value=EXAMPLE_MESSAGE)
+
+ messages = [
+ {
+ "role": "user",
+ "content": "Hello, Claude",
+ }
+ ]
+
+ with start_transaction(name="anthropic"):
+ response = client.messages.create(
+ max_tokens=1024, messages=messages, model="model"
+ )
+
+ assert response == EXAMPLE_MESSAGE
+ usage = response.usage
+
+ assert usage.input_tokens == 10
+ assert usage.output_tokens == 20
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "anthropic"
+
+ assert len(event["spans"]) == 1
+ (span,) = event["spans"]
+
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["description"] == "chat model"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "model"
+
+ if send_default_pii and include_prompts:
+ assert (
+ span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ == '[{"role": "user", "content": "Hello, Claude"}]'
+ )
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] == "Hi, I'm Claude."
+ else:
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]
+
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 20
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 30
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is False
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_ID] == "msg_01XFDUDYJgAACzvnptvVoYEL"
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS] == ["end_turn"]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+async def test_nonstreaming_create_message_async(
+ sentry_init, capture_events, send_default_pii, include_prompts
+):
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+ client = AsyncAnthropic(api_key="z")
+ client.messages._post = AsyncMock(return_value=EXAMPLE_MESSAGE)
+
+ messages = [
+ {
+ "role": "user",
+ "content": "Hello, Claude",
+ }
+ ]
+
+ with start_transaction(name="anthropic"):
+ response = await client.messages.create(
+ max_tokens=1024, messages=messages, model="model"
+ )
+
+ assert response == EXAMPLE_MESSAGE
+ usage = response.usage
+
+ assert usage.input_tokens == 10
+ assert usage.output_tokens == 20
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "anthropic"
+
+ assert len(event["spans"]) == 1
+ (span,) = event["spans"]
+
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["description"] == "chat model"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "model"
+
+ if send_default_pii and include_prompts:
+ assert (
+ span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ == '[{"role": "user", "content": "Hello, Claude"}]'
+ )
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] == "Hi, I'm Claude."
+ else:
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]
+
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 20
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 30
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is False
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_ID] == "msg_01XFDUDYJgAACzvnptvVoYEL"
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+def test_streaming_create_message(
+ sentry_init,
+ capture_events,
+ send_default_pii,
+ include_prompts,
+ get_model_response,
+ server_side_event_chunks,
+):
+ client = Anthropic(api_key="z")
+
+ response = get_model_response(
+ server_side_event_chunks(
+ [
+ MessageStartEvent(
+ message=EXAMPLE_MESSAGE,
+ type="message_start",
+ ),
+ ContentBlockStartEvent(
+ type="content_block_start",
+ index=0,
+ content_block=TextBlock(type="text", text=""),
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="Hi", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="!", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text=" I'm Claude!", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockStopEvent(type="content_block_stop", index=0),
+ MessageDeltaEvent(
+ delta=Delta(stop_reason="max_tokens"),
+ usage=MessageDeltaUsage(output_tokens=10),
+ type="message_delta",
+ ),
+ ]
+ )
+ )
+
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ messages = [
+ {
+ "role": "user",
+ "content": "Hello, Claude",
+ }
+ ]
+
+ with mock.patch.object(
+ client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ with start_transaction(name="anthropic"):
+ message = client.messages.create(
+ max_tokens=1024, messages=messages, model="model", stream=True
+ )
+
+ for _ in message:
+ pass
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "anthropic"
+
+ span = next(span for span in event["spans"] if span["op"] == OP.GEN_AI_CHAT)
+
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["description"] == "chat model"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "model"
+
+ if send_default_pii and include_prompts:
+ assert (
+ span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ == '[{"role": "user", "content": "Hello, Claude"}]'
+ )
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] == "Hi! I'm Claude!"
+
+ else:
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]
+
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 20
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_ID] == "msg_01XFDUDYJgAACzvnptvVoYEL"
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS] == ["max_tokens"]
+
+
+def test_streaming_create_message_close(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ server_side_event_chunks,
+):
+ client = Anthropic(api_key="z")
+
+ response = get_model_response(
+ server_side_event_chunks(
+ [
+ MessageStartEvent(
+ message=EXAMPLE_MESSAGE,
+ type="message_start",
+ ),
+ ContentBlockStartEvent(
+ type="content_block_start",
+ index=0,
+ content_block=TextBlock(type="text", text=""),
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="Hi", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="!", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text=" I'm Claude!", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockStopEvent(type="content_block_stop", index=0),
+ MessageDeltaEvent(
+ delta=Delta(stop_reason="max_tokens"),
+ usage=MessageDeltaUsage(output_tokens=10),
+ type="message_delta",
+ ),
+ ]
+ )
+ )
+
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ messages = [
+ {
+ "role": "user",
+ "content": "Hello, Claude",
+ }
+ ]
+
+ with mock.patch.object(
+ client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ with start_transaction(name="anthropic"):
+ messages = client.messages.create(
+ max_tokens=1024, messages=messages, model="model", stream=True
+ )
+
+ for _ in range(4):
+ next(messages)
+
+ messages.close()
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "anthropic"
+
+ span = next(span for span in event["spans"] if span["op"] == OP.GEN_AI_CHAT)
+
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["description"] == "chat model"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "model"
+
+ assert (
+ span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ == '[{"role": "user", "content": "Hello, Claude"}]'
+ )
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] == "Hi!"
+
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 20
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 30
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_ID] == "msg_01XFDUDYJgAACzvnptvVoYEL"
+
+
+@pytest.mark.skipif(
+ ANTHROPIC_VERSION < (0, 41),
+ reason="Error classes moved in https://github.com/anthropics/anthropic-sdk-python/commit/4e0b15e22fe40e9aa513459564f641bf97c90954.",
+)
+def test_streaming_create_message_api_error(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ server_side_event_chunks,
+):
+ client = Anthropic(api_key="z")
+
+ response = get_model_response(
+ server_side_event_chunks(
+ [
+ MessageStartEvent(
+ message=EXAMPLE_MESSAGE,
+ type="message_start",
+ ),
+ ContentBlockStartEvent(
+ type="content_block_start",
+ index=0,
+ content_block=TextBlock(type="text", text=""),
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="Hi", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="!", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ErrorResponse(
+ type="error",
+ error=OverloadedError(
+ message="Overloaded", type="overloaded_error"
+ ),
+ ),
+ ]
+ )
+ )
+
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ messages = [
+ {
+ "role": "user",
+ "content": "Hello, Claude",
+ }
+ ]
+
+ with pytest.raises(APIStatusError), mock.patch.object(
+ client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ with start_transaction(name="anthropic"):
+ message = client.messages.create(
+ max_tokens=1024, messages=messages, model="model", stream=True
+ )
+
+ for _ in message:
+ pass
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "anthropic"
+
+ span = next(span for span in event["spans"] if span["op"] == OP.GEN_AI_CHAT)
+
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["description"] == "chat model"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "model"
+
+ assert (
+ span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ == '[{"role": "user", "content": "Hello, Claude"}]'
+ )
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] == "Hi!"
+
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 20
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 30
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_ID] == "msg_01XFDUDYJgAACzvnptvVoYEL"
+
+ assert span["status"] == "internal_error"
+ assert span["tags"]["status"] == "internal_error"
+ assert event["contexts"]["trace"]["status"] == "internal_error"
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+def test_stream_messages(
+ sentry_init,
+ capture_events,
+ send_default_pii,
+ include_prompts,
+ get_model_response,
+ server_side_event_chunks,
+):
+ client = Anthropic(api_key="z")
+
+ response = get_model_response(
+ server_side_event_chunks(
+ [
+ MessageStartEvent(
+ message=EXAMPLE_MESSAGE,
+ type="message_start",
+ ),
+ ContentBlockStartEvent(
+ type="content_block_start",
+ index=0,
+ content_block=TextBlock(type="text", text=""),
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="Hi", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="!", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text=" I'm Claude!", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockStopEvent(type="content_block_stop", index=0),
+ MessageDeltaEvent(
+ delta=Delta(stop_reason="max_tokens"),
+ usage=MessageDeltaUsage(output_tokens=10),
+ type="message_delta",
+ ),
+ ]
+ )
+ )
+
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ messages = [
+ {
+ "role": "user",
+ "content": "Hello, Claude",
+ }
+ ]
+
+ with mock.patch.object(
+ client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ with start_transaction(name="anthropic"):
+ with client.messages.stream(
+ max_tokens=1024,
+ messages=messages,
+ model="model",
+ ) as stream:
+ for event in stream:
+ pass
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "anthropic"
+
+ span = next(span for span in event["spans"] if span["op"] == OP.GEN_AI_CHAT)
+
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["description"] == "chat model"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "model"
+
+ if send_default_pii and include_prompts:
+ assert (
+ span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ == '[{"role": "user", "content": "Hello, Claude"}]'
+ )
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] == "Hi! I'm Claude!"
+
+ else:
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]
+
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 20
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_ID] == "msg_01XFDUDYJgAACzvnptvVoYEL"
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS] == ["max_tokens"]
+
+
+def test_stream_messages_close(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ server_side_event_chunks,
+):
+ client = Anthropic(api_key="z")
+
+ response = get_model_response(
+ server_side_event_chunks(
+ [
+ MessageStartEvent(
+ message=EXAMPLE_MESSAGE,
+ type="message_start",
+ ),
+ ContentBlockStartEvent(
+ type="content_block_start",
+ index=0,
+ content_block=TextBlock(type="text", text=""),
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="Hi", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="!", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text=" I'm Claude!", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockStopEvent(type="content_block_stop", index=0),
+ MessageDeltaEvent(
+ delta=Delta(stop_reason="max_tokens"),
+ usage=MessageDeltaUsage(output_tokens=10),
+ type="message_delta",
+ ),
+ ]
+ )
+ )
+
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ messages = [
+ {
+ "role": "user",
+ "content": "Hello, Claude",
+ }
+ ]
+
+ with mock.patch.object(
+ client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ with start_transaction(name="anthropic"):
+ with client.messages.stream(
+ max_tokens=1024,
+ messages=messages,
+ model="model",
+ ) as stream:
+ for _ in range(4):
+ next(stream)
+
+ # New versions add TextEvent, so consume one more event.
+ if TextEvent is not None and isinstance(next(stream), TextEvent):
+ next(stream)
+
+ stream.close()
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "anthropic"
+
+ span = next(span for span in event["spans"] if span["op"] == OP.GEN_AI_CHAT)
+
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["description"] == "chat model"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "model"
+
+ assert (
+ span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ == '[{"role": "user", "content": "Hello, Claude"}]'
+ )
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] == "Hi!"
+
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 20
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 30
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_ID] == "msg_01XFDUDYJgAACzvnptvVoYEL"
+
+
+@pytest.mark.skipif(
+ ANTHROPIC_VERSION < (0, 41),
+ reason="Error classes moved in https://github.com/anthropics/anthropic-sdk-python/commit/4e0b15e22fe40e9aa513459564f641bf97c90954.",
+)
+def test_stream_messages_api_error(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ server_side_event_chunks,
+):
+ client = Anthropic(api_key="z")
+
+ response = get_model_response(
+ server_side_event_chunks(
+ [
+ MessageStartEvent(
+ message=EXAMPLE_MESSAGE,
+ type="message_start",
+ ),
+ ContentBlockStartEvent(
+ type="content_block_start",
+ index=0,
+ content_block=TextBlock(type="text", text=""),
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="Hi", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="!", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ErrorResponse(
+ type="error",
+ error=OverloadedError(
+ message="Overloaded", type="overloaded_error"
+ ),
+ ),
+ ]
+ )
+ )
+
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ messages = [
+ {
+ "role": "user",
+ "content": "Hello, Claude",
+ }
+ ]
+
+ with pytest.raises(APIStatusError), mock.patch.object(
+ client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ with start_transaction(name="anthropic"):
+ with client.messages.stream(
+ max_tokens=1024,
+ messages=messages,
+ model="model",
+ ) as stream:
+ for event in stream:
+ pass
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "anthropic"
+
+ span = next(span for span in event["spans"] if span["op"] == OP.GEN_AI_CHAT)
+
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["description"] == "chat model"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "model"
+
+ assert (
+ span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ == '[{"role": "user", "content": "Hello, Claude"}]'
+ )
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] == "Hi!"
+
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 20
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 30
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_ID] == "msg_01XFDUDYJgAACzvnptvVoYEL"
+
+ assert span["status"] == "internal_error"
+ assert span["tags"]["status"] == "internal_error"
+ assert event["contexts"]["trace"]["status"] == "internal_error"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+async def test_streaming_create_message_async(
+ sentry_init,
+ capture_events,
+ send_default_pii,
+ include_prompts,
+ get_model_response,
+ async_iterator,
+ server_side_event_chunks,
+):
+ client = AsyncAnthropic(api_key="z")
+
+ response = get_model_response(
+ async_iterator(
+ server_side_event_chunks(
+ [
+ MessageStartEvent(
+ message=EXAMPLE_MESSAGE,
+ type="message_start",
+ ),
+ ContentBlockStartEvent(
+ type="content_block_start",
+ index=0,
+ content_block=TextBlock(type="text", text=""),
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="Hi", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="!", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text=" I'm Claude!", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockStopEvent(type="content_block_stop", index=0),
+ MessageDeltaEvent(
+ delta=Delta(stop_reason="max_tokens"),
+ usage=MessageDeltaUsage(output_tokens=10),
+ type="message_delta",
+ ),
+ ]
+ )
+ ),
+ )
+
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ default_integrations=False,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ messages = [
+ {
+ "role": "user",
+ "content": "Hello, Claude",
+ }
+ ]
+
+ with mock.patch.object(
+ client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ with start_transaction(name="anthropic"):
+ message = await client.messages.create(
+ max_tokens=1024, messages=messages, model="model", stream=True
+ )
+
+ async for _ in message:
+ pass
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "anthropic"
+
+ assert len(event["spans"]) == 1
+ (span,) = event["spans"]
+
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["description"] == "chat model"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "model"
+
+ if send_default_pii and include_prompts:
+ assert (
+ span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ == '[{"role": "user", "content": "Hello, Claude"}]'
+ )
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] == "Hi! I'm Claude!"
+
+ else:
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]
+
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 20
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_ID] == "msg_01XFDUDYJgAACzvnptvVoYEL"
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS] == ["max_tokens"]
+
+
+@pytest.mark.asyncio
+async def test_streaming_create_message_async_close(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ async_iterator,
+ server_side_event_chunks,
+):
+ client = AsyncAnthropic(api_key="z")
+
+ response = get_model_response(
+ async_iterator(
+ server_side_event_chunks(
+ [
+ MessageStartEvent(
+ message=EXAMPLE_MESSAGE,
+ type="message_start",
+ ),
+ ContentBlockStartEvent(
+ type="content_block_start",
+ index=0,
+ content_block=TextBlock(type="text", text=""),
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="Hi", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="!", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text=" I'm Claude!", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockStopEvent(type="content_block_stop", index=0),
+ MessageDeltaEvent(
+ delta=Delta(stop_reason="max_tokens"),
+ usage=MessageDeltaUsage(output_tokens=10),
+ type="message_delta",
+ ),
+ ]
+ )
+ )
+ )
+
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ messages = [
+ {
+ "role": "user",
+ "content": "Hello, Claude",
+ }
+ ]
+
+ with mock.patch.object(
+ client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ with start_transaction(name="anthropic"):
+ messages = await client.messages.create(
+ max_tokens=1024, messages=messages, model="model", stream=True
+ )
+
+ for _ in range(4):
+ await messages.__anext__()
+ await messages.close()
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "anthropic"
+
+ span = next(span for span in event["spans"] if span["op"] == OP.GEN_AI_CHAT)
+
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["description"] == "chat model"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "model"
+
+ assert (
+ span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ == '[{"role": "user", "content": "Hello, Claude"}]'
+ )
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] == "Hi!"
+
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 20
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 30
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_ID] == "msg_01XFDUDYJgAACzvnptvVoYEL"
+
+
+@pytest.mark.skipif(
+ ANTHROPIC_VERSION < (0, 41),
+ reason="Error classes moved in https://github.com/anthropics/anthropic-sdk-python/commit/4e0b15e22fe40e9aa513459564f641bf97c90954.",
+)
+@pytest.mark.asyncio
+async def test_streaming_create_message_async_api_error(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ async_iterator,
+ server_side_event_chunks,
+):
+ client = AsyncAnthropic(api_key="z")
+
+ response = get_model_response(
+ async_iterator(
+ server_side_event_chunks(
+ [
+ MessageStartEvent(
+ message=EXAMPLE_MESSAGE,
+ type="message_start",
+ ),
+ ContentBlockStartEvent(
+ type="content_block_start",
+ index=0,
+ content_block=TextBlock(type="text", text=""),
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="Hi", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="!", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ErrorResponse(
+ type="error",
+ error=OverloadedError(
+ message="Overloaded", type="overloaded_error"
+ ),
+ ),
+ ]
+ )
+ )
+ )
+
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ messages = [
+ {
+ "role": "user",
+ "content": "Hello, Claude",
+ }
+ ]
+
+ with pytest.raises(APIStatusError), mock.patch.object(
+ client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ with start_transaction(name="anthropic"):
+ message = await client.messages.create(
+ max_tokens=1024, messages=messages, model="model", stream=True
+ )
+
+ async for _ in message:
+ pass
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "anthropic"
+
+ span = next(span for span in event["spans"] if span["op"] == OP.GEN_AI_CHAT)
+
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["description"] == "chat model"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "model"
+
+ assert (
+ span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ == '[{"role": "user", "content": "Hello, Claude"}]'
+ )
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] == "Hi!"
+
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 20
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 30
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_ID] == "msg_01XFDUDYJgAACzvnptvVoYEL"
+
+ assert span["status"] == "internal_error"
+ assert span["tags"]["status"] == "internal_error"
+ assert event["contexts"]["trace"]["status"] == "internal_error"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+async def test_stream_message_async(
+ sentry_init,
+ capture_events,
+ send_default_pii,
+ include_prompts,
+ get_model_response,
+ async_iterator,
+ server_side_event_chunks,
+):
+ client = AsyncAnthropic(api_key="z")
+
+ response = get_model_response(
+ async_iterator(
+ server_side_event_chunks(
+ [
+ MessageStartEvent(
+ message=EXAMPLE_MESSAGE,
+ type="message_start",
+ ),
+ ContentBlockStartEvent(
+ type="content_block_start",
+ index=0,
+ content_block=TextBlock(type="text", text=""),
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="Hi", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="!", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text=" I'm Claude!", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockStopEvent(type="content_block_stop", index=0),
+ MessageDeltaEvent(
+ delta=Delta(),
+ usage=MessageDeltaUsage(output_tokens=10),
+ type="message_delta",
+ ),
+ ]
+ )
+ ),
+ )
+
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ messages = [
+ {
+ "role": "user",
+ "content": "Hello, Claude",
+ }
+ ]
+
+ with mock.patch.object(
+ client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ with start_transaction(name="anthropic"):
+ async with client.messages.stream(
+ max_tokens=1024,
+ messages=messages,
+ model="model",
+ ) as stream:
+ async for event in stream:
+ pass
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "anthropic"
+
+ assert len(event["spans"]) == 1
+ (span,) = event["spans"]
+
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["description"] == "chat model"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "model"
+
+ if send_default_pii and include_prompts:
+ assert (
+ span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ == '[{"role": "user", "content": "Hello, Claude"}]'
+ )
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] == "Hi! I'm Claude!"
+
+ else:
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]
+
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 20
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_ID] == "msg_01XFDUDYJgAACzvnptvVoYEL"
+
+
+@pytest.mark.skipif(
+ ANTHROPIC_VERSION < (0, 41),
+ reason="Error classes moved in https://github.com/anthropics/anthropic-sdk-python/commit/4e0b15e22fe40e9aa513459564f641bf97c90954.",
+)
+@pytest.mark.asyncio
+async def test_stream_messages_async_api_error(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ async_iterator,
+ server_side_event_chunks,
+):
+ client = AsyncAnthropic(api_key="z")
+
+ response = get_model_response(
+ async_iterator(
+ server_side_event_chunks(
+ [
+ MessageStartEvent(
+ message=EXAMPLE_MESSAGE,
+ type="message_start",
+ ),
+ ContentBlockStartEvent(
+ type="content_block_start",
+ index=0,
+ content_block=TextBlock(type="text", text=""),
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="Hi", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="!", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ErrorResponse(
+ type="error",
+ error=OverloadedError(
+ message="Overloaded", type="overloaded_error"
+ ),
+ ),
+ ]
+ )
+ )
+ )
+
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ messages = [
+ {
+ "role": "user",
+ "content": "Hello, Claude",
+ }
+ ]
+
+ with pytest.raises(APIStatusError), mock.patch.object(
+ client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ with start_transaction(name="anthropic"):
+ async with client.messages.stream(
+ max_tokens=1024,
+ messages=messages,
+ model="model",
+ ) as stream:
+ async for event in stream:
+ pass
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "anthropic"
+
+ span = next(span for span in event["spans"] if span["op"] == OP.GEN_AI_CHAT)
+
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["description"] == "chat model"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "model"
+
+ assert (
+ span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ == '[{"role": "user", "content": "Hello, Claude"}]'
+ )
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] == "Hi!"
+
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 20
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 30
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_ID] == "msg_01XFDUDYJgAACzvnptvVoYEL"
+
+ assert span["status"] == "internal_error"
+ assert span["tags"]["status"] == "internal_error"
+ assert event["contexts"]["trace"]["status"] == "internal_error"
+
+
+@pytest.mark.asyncio
+async def test_stream_messages_async_close(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ async_iterator,
+ server_side_event_chunks,
+):
+ client = AsyncAnthropic(api_key="z")
+
+ response = get_model_response(
+ async_iterator(
+ server_side_event_chunks(
+ [
+ MessageStartEvent(
+ message=EXAMPLE_MESSAGE,
+ type="message_start",
+ ),
+ ContentBlockStartEvent(
+ type="content_block_start",
+ index=0,
+ content_block=TextBlock(type="text", text=""),
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="Hi", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="!", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text=" I'm Claude!", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockStopEvent(type="content_block_stop", index=0),
+ MessageDeltaEvent(
+ delta=Delta(stop_reason="max_tokens"),
+ usage=MessageDeltaUsage(output_tokens=10),
+ type="message_delta",
+ ),
+ ]
+ )
+ )
+ )
+
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ messages = [
+ {
+ "role": "user",
+ "content": "Hello, Claude",
+ }
+ ]
+
+ with mock.patch.object(
+ client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ with start_transaction(name="anthropic"):
+ async with client.messages.stream(
+ max_tokens=1024,
+ messages=messages,
+ model="model",
+ ) as stream:
+ for _ in range(4):
+ await stream.__anext__()
+
+ # New versions add TextEvent, so consume one more event.
+ if TextEvent is not None and isinstance(
+ await stream.__anext__(), TextEvent
+ ):
+ await stream.__anext__()
+
+ await stream.close()
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "anthropic"
+
+ span = next(span for span in event["spans"] if span["op"] == OP.GEN_AI_CHAT)
+
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["description"] == "chat model"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "model"
+
+ assert (
+ span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ == '[{"role": "user", "content": "Hello, Claude"}]'
+ )
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] == "Hi!"
+
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 20
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 30
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_ID] == "msg_01XFDUDYJgAACzvnptvVoYEL"
+
+
+@pytest.mark.skipif(
+ ANTHROPIC_VERSION < (0, 27),
+ reason="Versions <0.27.0 do not include InputJSONDelta, which was introduced in >=0.27.0 along with a new message delta type for tool calling.",
+)
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+def test_streaming_create_message_with_input_json_delta(
+ sentry_init,
+ capture_events,
+ send_default_pii,
+ include_prompts,
+ get_model_response,
+ server_side_event_chunks,
+):
+ client = Anthropic(api_key="z")
+
+ response = get_model_response(
+ server_side_event_chunks(
+ [
+ MessageStartEvent(
+ message=Message(
+ id="msg_0",
+ content=[],
+ model="claude-3-5-sonnet-20240620",
+ role="assistant",
+ stop_reason=None,
+ stop_sequence=None,
+ type="message",
+ usage=Usage(input_tokens=366, output_tokens=10),
+ ),
+ type="message_start",
+ ),
+ ContentBlockStartEvent(
+ type="content_block_start",
+ index=0,
+ content_block=ToolUseBlock(
+ id="toolu_0", input={}, name="get_weather", type="tool_use"
+ ),
+ ),
+ ContentBlockDeltaEvent(
+ delta=InputJSONDelta(partial_json="", type="input_json_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=InputJSONDelta(
+ partial_json='{"location": "', type="input_json_delta"
+ ),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=InputJSONDelta(partial_json="S", type="input_json_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=InputJSONDelta(partial_json="an ", type="input_json_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=InputJSONDelta(
+ partial_json="Francisco, C", type="input_json_delta"
+ ),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=InputJSONDelta(partial_json='A"}', type="input_json_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockStopEvent(type="content_block_stop", index=0),
+ MessageDeltaEvent(
+ delta=Delta(stop_reason="tool_use", stop_sequence=None),
+ usage=MessageDeltaUsage(output_tokens=41),
+ type="message_delta",
+ ),
+ ]
+ )
+ )
+
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ messages = [
+ {
+ "role": "user",
+ "content": "What is the weather like in San Francisco?",
+ }
+ ]
+
+ with mock.patch.object(
+ client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ with start_transaction(name="anthropic"):
+ message = client.messages.create(
+ max_tokens=1024, messages=messages, model="model", stream=True
+ )
+
+ for _ in message:
+ pass
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "anthropic"
+
+ assert len(event["spans"]) == 1
+ (span,) = event["spans"]
+
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["description"] == "chat model"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "model"
+
+ if send_default_pii and include_prompts:
+ assert (
+ span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ == '[{"role": "user", "content": "What is the weather like in San Francisco?"}]'
+ )
+ assert (
+ span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT]
+ == '{"location": "San Francisco, CA"}'
+ )
+ else:
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]
+
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 366
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 41
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 407
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
+
+
+@pytest.mark.skipif(
+ ANTHROPIC_VERSION < (0, 27),
+ reason="Versions <0.27.0 do not include InputJSONDelta, which was introduced in >=0.27.0 along with a new message delta type for tool calling.",
+)
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+def test_stream_messages_with_input_json_delta(
+ sentry_init,
+ capture_events,
+ send_default_pii,
+ include_prompts,
+ get_model_response,
+ server_side_event_chunks,
+):
+ client = Anthropic(api_key="z")
+
+ response = get_model_response(
+ server_side_event_chunks(
+ [
+ MessageStartEvent(
+ message=Message(
+ id="msg_0",
+ content=[],
+ model="claude-3-5-sonnet-20240620",
+ role="assistant",
+ stop_reason=None,
+ stop_sequence=None,
+ type="message",
+ usage=Usage(input_tokens=366, output_tokens=10),
+ ),
+ type="message_start",
+ ),
+ ContentBlockStartEvent(
+ type="content_block_start",
+ index=0,
+ content_block=ToolUseBlock(
+ id="toolu_0", input={}, name="get_weather", type="tool_use"
+ ),
+ ),
+ ContentBlockDeltaEvent(
+ delta=InputJSONDelta(partial_json="", type="input_json_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=InputJSONDelta(
+ partial_json='{"location": "', type="input_json_delta"
+ ),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=InputJSONDelta(partial_json="S", type="input_json_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=InputJSONDelta(partial_json="an ", type="input_json_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=InputJSONDelta(
+ partial_json="Francisco, C", type="input_json_delta"
+ ),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=InputJSONDelta(partial_json='A"}', type="input_json_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockStopEvent(type="content_block_stop", index=0),
+ MessageDeltaEvent(
+ delta=Delta(stop_reason="tool_use", stop_sequence=None),
+ usage=MessageDeltaUsage(output_tokens=41),
+ type="message_delta",
+ ),
+ ]
+ )
+ )
+
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ messages = [
+ {
+ "role": "user",
+ "content": "What is the weather like in San Francisco?",
+ }
+ ]
+
+ with mock.patch.object(
+ client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ with start_transaction(name="anthropic"):
+ with client.messages.stream(
+ max_tokens=1024,
+ messages=messages,
+ model="model",
+ ) as stream:
+ for event in stream:
+ pass
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "anthropic"
+
+ assert len(event["spans"]) == 1
+ (span,) = event["spans"]
+
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["description"] == "chat model"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "model"
+
+ if send_default_pii and include_prompts:
+ assert (
+ span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ == '[{"role": "user", "content": "What is the weather like in San Francisco?"}]'
+ )
+ assert (
+ span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT]
+ == '{"location": "San Francisco, CA"}'
+ )
+ else:
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]
+
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 366
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 41
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 407
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(
+ ANTHROPIC_VERSION < (0, 27),
+ reason="Versions <0.27.0 do not include InputJSONDelta, which was introduced in >=0.27.0 along with a new message delta type for tool calling.",
+)
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+async def test_streaming_create_message_with_input_json_delta_async(
+ sentry_init,
+ capture_events,
+ send_default_pii,
+ include_prompts,
+ get_model_response,
+ async_iterator,
+ server_side_event_chunks,
+):
+ client = AsyncAnthropic(api_key="z")
+ response = get_model_response(
+ async_iterator(
+ server_side_event_chunks(
+ [
+ MessageStartEvent(
+ message=Message(
+ id="msg_0",
+ content=[],
+ model="claude-3-5-sonnet-20240620",
+ role="assistant",
+ stop_reason=None,
+ stop_sequence=None,
+ type="message",
+ usage=Usage(input_tokens=366, output_tokens=10),
+ ),
+ type="message_start",
+ ),
+ ContentBlockStartEvent(
+ type="content_block_start",
+ index=0,
+ content_block=ToolUseBlock(
+ id="toolu_0", input={}, name="get_weather", type="tool_use"
+ ),
+ ),
+ ContentBlockDeltaEvent(
+ delta=InputJSONDelta(partial_json="", type="input_json_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=InputJSONDelta(
+ partial_json='{"location": "', type="input_json_delta"
+ ),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=InputJSONDelta(partial_json="S", type="input_json_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=InputJSONDelta(
+ partial_json="an ", type="input_json_delta"
+ ),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=InputJSONDelta(
+ partial_json="Francisco, C", type="input_json_delta"
+ ),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=InputJSONDelta(
+ partial_json='A"}', type="input_json_delta"
+ ),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockStopEvent(type="content_block_stop", index=0),
+ MessageDeltaEvent(
+ delta=Delta(stop_reason="tool_use", stop_sequence=None),
+ usage=MessageDeltaUsage(output_tokens=41),
+ type="message_delta",
+ ),
+ ]
+ )
+ )
+ )
+
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ messages = [
+ {
+ "role": "user",
+ "content": "What is the weather like in San Francisco?",
+ }
+ ]
+
+ with mock.patch.object(
+ client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ with start_transaction(name="anthropic"):
+ message = await client.messages.create(
+ max_tokens=1024, messages=messages, model="model", stream=True
+ )
+
+ async for _ in message:
+ pass
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "anthropic"
+
+ assert len(event["spans"]) == 1
+ (span,) = event["spans"]
+
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["description"] == "chat model"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "model"
+
+ if send_default_pii and include_prompts:
+ assert (
+ span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ == '[{"role": "user", "content": "What is the weather like in San Francisco?"}]'
+ )
+ assert (
+ span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT]
+ == '{"location": "San Francisco, CA"}'
+ )
+
+ else:
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]
+
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 366
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 41
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 407
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(
+ ANTHROPIC_VERSION < (0, 27),
+ reason="Versions <0.27.0 do not include InputJSONDelta, which was introduced in >=0.27.0 along with a new message delta type for tool calling.",
+)
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+async def test_stream_message_with_input_json_delta_async(
+ sentry_init,
+ capture_events,
+ send_default_pii,
+ include_prompts,
+ get_model_response,
+ async_iterator,
+ server_side_event_chunks,
+):
+ client = AsyncAnthropic(api_key="z")
+ response = get_model_response(
+ async_iterator(
+ server_side_event_chunks(
+ [
+ MessageStartEvent(
+ message=Message(
+ id="msg_0",
+ content=[],
+ model="claude-3-5-sonnet-20240620",
+ role="assistant",
+ stop_reason=None,
+ stop_sequence=None,
+ type="message",
+ usage=Usage(input_tokens=366, output_tokens=10),
+ ),
+ type="message_start",
+ ),
+ ContentBlockStartEvent(
+ type="content_block_start",
+ index=0,
+ content_block=ToolUseBlock(
+ id="toolu_0", input={}, name="get_weather", type="tool_use"
+ ),
+ ),
+ ContentBlockDeltaEvent(
+ delta=InputJSONDelta(partial_json="", type="input_json_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=InputJSONDelta(
+ partial_json='{"location": "', type="input_json_delta"
+ ),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=InputJSONDelta(partial_json="S", type="input_json_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=InputJSONDelta(
+ partial_json="an ", type="input_json_delta"
+ ),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=InputJSONDelta(
+ partial_json="Francisco, C", type="input_json_delta"
+ ),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=InputJSONDelta(
+ partial_json='A"}', type="input_json_delta"
+ ),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockStopEvent(type="content_block_stop", index=0),
+ MessageDeltaEvent(
+ delta=Delta(stop_reason="tool_use", stop_sequence=None),
+ usage=MessageDeltaUsage(output_tokens=41),
+ type="message_delta",
+ ),
+ ]
+ )
+ )
+ )
+
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ messages = [
+ {
+ "role": "user",
+ "content": "What is the weather like in San Francisco?",
+ }
+ ]
+
+ with mock.patch.object(
+ client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ with start_transaction(name="anthropic"):
+ async with client.messages.stream(
+ max_tokens=1024,
+ messages=messages,
+ model="model",
+ ) as stream:
+ async for event in stream:
+ pass
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "anthropic"
+
+ assert len(event["spans"]) == 1
+ (span,) = event["spans"]
+
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["description"] == "chat model"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "model"
+
+ if send_default_pii and include_prompts:
+ assert (
+ span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ == '[{"role": "user", "content": "What is the weather like in San Francisco?"}]'
+ )
+ assert (
+ span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT]
+ == '{"location": "San Francisco, CA"}'
+ )
+
+ else:
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]
+
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 366
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 41
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 407
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
+
+
+def test_exception_message_create(sentry_init, capture_events):
+ sentry_init(integrations=[AnthropicIntegration()], traces_sample_rate=1.0)
+ events = capture_events()
+
+ client = Anthropic(api_key="z")
+ client.messages._post = mock.Mock(
+ side_effect=AnthropicError("API rate limit reached")
+ )
+ with pytest.raises(AnthropicError):
+ client.messages.create(
+ model="some-model",
+ messages=[{"role": "system", "content": "I'm throwing an exception"}],
+ max_tokens=1024,
+ )
+
+ (event, transaction) = events
+ assert event["level"] == "error"
+ assert transaction["contexts"]["trace"]["status"] == "internal_error"
+
+
+def test_span_status_error(sentry_init, capture_events):
+ sentry_init(integrations=[AnthropicIntegration()], traces_sample_rate=1.0)
+ events = capture_events()
+
+ with start_transaction(name="anthropic"):
+ client = Anthropic(api_key="z")
+ client.messages._post = mock.Mock(
+ side_effect=AnthropicError("API rate limit reached")
+ )
+ with pytest.raises(AnthropicError):
+ client.messages.create(
+ model="some-model",
+ messages=[{"role": "system", "content": "I'm throwing an exception"}],
+ max_tokens=1024,
+ )
+
+ (error, transaction) = events
+ assert error["level"] == "error"
+ assert transaction["spans"][0]["status"] == "internal_error"
+ assert transaction["spans"][0]["tags"]["status"] == "internal_error"
+ assert transaction["spans"][0]["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert transaction["spans"][0]["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+
+
+@pytest.mark.asyncio
+async def test_span_status_error_async(sentry_init, capture_events):
+ sentry_init(integrations=[AnthropicIntegration()], traces_sample_rate=1.0)
+ events = capture_events()
+
+ with start_transaction(name="anthropic"):
+ client = AsyncAnthropic(api_key="z")
+ client.messages._post = AsyncMock(
+ side_effect=AnthropicError("API rate limit reached")
+ )
+ with pytest.raises(AnthropicError):
+ await client.messages.create(
+ model="some-model",
+ messages=[{"role": "system", "content": "I'm throwing an exception"}],
+ max_tokens=1024,
+ )
+
+ (error, transaction) = events
+ assert error["level"] == "error"
+ assert transaction["spans"][0]["status"] == "internal_error"
+ assert transaction["spans"][0]["tags"]["status"] == "internal_error"
+ assert transaction["spans"][0]["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert transaction["spans"][0]["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+
+
+@pytest.mark.asyncio
+async def test_exception_message_create_async(sentry_init, capture_events):
+ sentry_init(integrations=[AnthropicIntegration()], traces_sample_rate=1.0)
+ events = capture_events()
+
+ client = AsyncAnthropic(api_key="z")
+ client.messages._post = AsyncMock(
+ side_effect=AnthropicError("API rate limit reached")
+ )
+ with pytest.raises(AnthropicError):
+ await client.messages.create(
+ model="some-model",
+ messages=[{"role": "system", "content": "I'm throwing an exception"}],
+ max_tokens=1024,
+ )
+
+ (event, transaction) = events
+ assert event["level"] == "error"
+ assert transaction["contexts"]["trace"]["status"] == "internal_error"
+
+
+def test_span_origin(sentry_init, capture_events):
+ sentry_init(
+ integrations=[AnthropicIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client = Anthropic(api_key="z")
+ client.messages._post = mock.Mock(return_value=EXAMPLE_MESSAGE)
+
+ messages = [
+ {
+ "role": "user",
+ "content": "Hello, Claude",
+ }
+ ]
+
+ with start_transaction(name="anthropic"):
+ client.messages.create(max_tokens=1024, messages=messages, model="model")
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+ assert event["spans"][0]["origin"] == "auto.ai.anthropic"
+ assert event["spans"][0]["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert event["spans"][0]["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+
+
+@pytest.mark.asyncio
+async def test_span_origin_async(sentry_init, capture_events):
+ sentry_init(
+ integrations=[AnthropicIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client = AsyncAnthropic(api_key="z")
+ client.messages._post = AsyncMock(return_value=EXAMPLE_MESSAGE)
+
+ messages = [
+ {
+ "role": "user",
+ "content": "Hello, Claude",
+ }
+ ]
+
+ with start_transaction(name="anthropic"):
+ await client.messages.create(max_tokens=1024, messages=messages, model="model")
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+ assert event["spans"][0]["origin"] == "auto.ai.anthropic"
+ assert event["spans"][0]["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert event["spans"][0]["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+
+
+@pytest.mark.skipif(
+ ANTHROPIC_VERSION < (0, 27),
+ reason="Versions <0.27.0 do not include InputJSONDelta.",
+)
+def test_collect_ai_data_with_input_json_delta():
+ event = ContentBlockDeltaEvent(
+ delta=InputJSONDelta(partial_json="test", type="input_json_delta"),
+ index=0,
+ type="content_block_delta",
+ )
+ model = None
+
+ usage = _RecordedUsage()
+ usage.output_tokens = 20
+ usage.input_tokens = 10
+
+ content_blocks = []
+
+ model, new_usage, new_content_blocks, response_id, finish_reason = _collect_ai_data(
+ event, model, usage, content_blocks
+ )
+ assert model is None
+ assert new_usage.input_tokens == usage.input_tokens
+ assert new_usage.output_tokens == usage.output_tokens
+ assert new_content_blocks == ["test"]
+ assert response_id is None
+ assert finish_reason is None
+
+
+@pytest.mark.skipif(
+ ANTHROPIC_VERSION < (0, 27),
+ reason="Versions <0.27.0 do not include InputJSONDelta.",
+)
+def test_set_output_data_with_input_json_delta(sentry_init):
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ with start_transaction(name="test"):
+ span = start_span()
+ integration = AnthropicIntegration()
+ json_deltas = ["{'test': 'data',", "'more': 'json'}"]
+ _set_output_data(
+ span,
+ integration,
+ model="",
+ input_tokens=10,
+ output_tokens=20,
+ cache_read_input_tokens=0,
+ cache_write_input_tokens=0,
+ content_blocks=[{"text": "".join(json_deltas), "type": "text"}],
+ )
+
+ assert (
+ span._data.get(SPANDATA.GEN_AI_RESPONSE_TEXT)
+ == "{'test': 'data','more': 'json'}"
+ )
+ assert span._data.get(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS) == 10
+ assert span._data.get(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS) == 20
+ assert span._data.get(SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS) == 30
+
+
+# Test messages with mixed roles including "ai" that should be mapped to "assistant"
+@pytest.mark.parametrize(
+ "test_message,expected_role",
+ [
+ ({"role": "system", "content": "You are helpful."}, "system"),
+ ({"role": "user", "content": "Hello"}, "user"),
+ (
+ {"role": "ai", "content": "Hi there!"},
+ "assistant",
+ ), # Should be mapped to "assistant"
+ (
+ {"role": "assistant", "content": "How can I help?"},
+ "assistant",
+ ), # Should stay "assistant"
+ ],
+)
+def test_anthropic_message_role_mapping(
+ sentry_init, capture_events, test_message, expected_role
+):
+ """Test that Anthropic integration properly maps message roles like 'ai' to 'assistant'"""
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ client = Anthropic(api_key="z")
+
+ def mock_messages_create(*args, **kwargs):
+ return Message(
+ id="msg_1",
+ content=[TextBlock(text="Hi there!", type="text")],
+ model="claude-3-opus",
+ role="assistant",
+ stop_reason="end_turn",
+ stop_sequence=None,
+ type="message",
+ usage=Usage(input_tokens=10, output_tokens=5),
+ )
+
+ client.messages._post = mock.Mock(return_value=mock_messages_create())
+
+ test_messages = [test_message]
+
+ with start_transaction(name="anthropic tx"):
+ client.messages.create(
+ model="claude-3-opus", max_tokens=10, messages=test_messages
+ )
+
+ (event,) = events
+ span = event["spans"][0]
+
+ # Verify that the span was created correctly
+ assert span["op"] == "gen_ai.chat"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES in span["data"]
+
+ # Parse the stored messages
+ stored_messages = json.loads(span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+
+ assert stored_messages[0]["role"] == expected_role
+
+
+def test_anthropic_message_truncation(sentry_init, capture_events):
+ """Test that large messages are truncated properly in Anthropic integration."""
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ client = Anthropic(api_key="z")
+ client.messages._post = mock.Mock(return_value=EXAMPLE_MESSAGE)
+
+ large_content = (
+ "This is a very long message that will exceed our size limits. " * 1000
+ )
+ messages = [
+ {"role": "user", "content": "small message 1"},
+ {"role": "assistant", "content": large_content},
+ {"role": "user", "content": large_content},
+ {"role": "assistant", "content": "small message 4"},
+ {"role": "user", "content": "small message 5"},
+ ]
+
+ with start_transaction():
+ client.messages.create(max_tokens=1024, messages=messages, model="model")
+
+ assert len(events) > 0
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ chat_spans = [
+ span for span in tx.get("spans", []) if span.get("op") == OP.GEN_AI_CHAT
+ ]
+ assert len(chat_spans) > 0
+
+ chat_span = chat_spans[0]
+ assert chat_span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert chat_span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES in chat_span["data"]
+
+ messages_data = chat_span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ assert isinstance(messages_data, str)
+
+ parsed_messages = json.loads(messages_data)
+ assert isinstance(parsed_messages, list)
+ assert len(parsed_messages) == 1
+ assert "small message 5" in str(parsed_messages[0])
+
+ assert tx["_meta"]["spans"]["0"]["data"]["gen_ai.request.messages"][""]["len"] == 5
+
+
+@pytest.mark.asyncio
+async def test_anthropic_message_truncation_async(sentry_init, capture_events):
+ """Test that large messages are truncated properly in Anthropic integration."""
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ client = AsyncAnthropic(api_key="z")
+ client.messages._post = mock.AsyncMock(return_value=EXAMPLE_MESSAGE)
+
+ large_content = (
+ "This is a very long message that will exceed our size limits. " * 1000
+ )
+ messages = [
+ {"role": "user", "content": "small message 1"},
+ {"role": "assistant", "content": large_content},
+ {"role": "user", "content": large_content},
+ {"role": "assistant", "content": "small message 4"},
+ {"role": "user", "content": "small message 5"},
+ ]
+
+ with start_transaction():
+ await client.messages.create(max_tokens=1024, messages=messages, model="model")
+
+ assert len(events) > 0
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ chat_spans = [
+ span for span in tx.get("spans", []) if span.get("op") == OP.GEN_AI_CHAT
+ ]
+ assert len(chat_spans) > 0
+
+ chat_span = chat_spans[0]
+ assert chat_span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert chat_span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES in chat_span["data"]
+
+ messages_data = chat_span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ assert isinstance(messages_data, str)
+
+ parsed_messages = json.loads(messages_data)
+ assert isinstance(parsed_messages, list)
+ assert len(parsed_messages) == 1
+ assert "small message 5" in str(parsed_messages[0])
+
+ assert tx["_meta"]["spans"]["0"]["data"]["gen_ai.request.messages"][""]["len"] == 5
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+def test_nonstreaming_create_message_with_system_prompt(
+ sentry_init, capture_events, send_default_pii, include_prompts
+):
+ """Test that system prompts are properly captured in GEN_AI_REQUEST_MESSAGES."""
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+ client = Anthropic(api_key="z")
+ client.messages._post = mock.Mock(return_value=EXAMPLE_MESSAGE)
+
+ messages = [
+ {
+ "role": "user",
+ "content": "Hello, Claude",
+ }
+ ]
+
+ with start_transaction(name="anthropic"):
+ response = client.messages.create(
+ max_tokens=1024,
+ messages=messages,
+ model="model",
+ system="You are a helpful assistant.",
+ )
+
+ assert response == EXAMPLE_MESSAGE
+ usage = response.usage
+
+ assert usage.input_tokens == 10
+ assert usage.output_tokens == 20
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "anthropic"
+
+ assert len(event["spans"]) == 1
+ (span,) = event["spans"]
+
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["description"] == "chat model"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "model"
+
+ if send_default_pii and include_prompts:
+ assert SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS in span["data"]
+ system_instructions = json.loads(
+ span["data"][SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS]
+ )
+ assert system_instructions == [
+ {"type": "text", "content": "You are a helpful assistant."}
+ ]
+
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES in span["data"]
+ stored_messages = json.loads(span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+ assert len(stored_messages) == 1
+ assert stored_messages[0]["role"] == "user"
+ assert stored_messages[0]["content"] == "Hello, Claude"
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] == "Hi, I'm Claude."
+ else:
+ assert SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS not in span["data"]
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]
+
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 20
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 30
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is False
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS] == ["end_turn"]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+async def test_nonstreaming_create_message_with_system_prompt_async(
+ sentry_init, capture_events, send_default_pii, include_prompts
+):
+ """Test that system prompts are properly captured in GEN_AI_REQUEST_MESSAGES (async)."""
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+ client = AsyncAnthropic(api_key="z")
+ client.messages._post = AsyncMock(return_value=EXAMPLE_MESSAGE)
+
+ messages = [
+ {
+ "role": "user",
+ "content": "Hello, Claude",
+ }
+ ]
+
+ with start_transaction(name="anthropic"):
+ response = await client.messages.create(
+ max_tokens=1024,
+ messages=messages,
+ model="model",
+ system="You are a helpful assistant.",
+ )
+
+ assert response == EXAMPLE_MESSAGE
+ usage = response.usage
+
+ assert usage.input_tokens == 10
+ assert usage.output_tokens == 20
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "anthropic"
+
+ assert len(event["spans"]) == 1
+ (span,) = event["spans"]
+
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["description"] == "chat model"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "model"
+
+ if send_default_pii and include_prompts:
+ assert SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS in span["data"]
+ system_instructions = json.loads(
+ span["data"][SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS]
+ )
+ assert system_instructions == [
+ {"type": "text", "content": "You are a helpful assistant."}
+ ]
+
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES in span["data"]
+ stored_messages = json.loads(span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+ assert len(stored_messages) == 1
+ assert stored_messages[0]["role"] == "user"
+ assert stored_messages[0]["content"] == "Hello, Claude"
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] == "Hi, I'm Claude."
+ else:
+ assert SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS not in span["data"]
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]
+
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 20
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 30
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is False
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS] == ["end_turn"]
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+def test_streaming_create_message_with_system_prompt(
+ sentry_init,
+ capture_events,
+ send_default_pii,
+ include_prompts,
+ get_model_response,
+ server_side_event_chunks,
+):
+ """Test that system prompts are properly captured in streaming mode."""
+ client = Anthropic(api_key="z")
+
+ response = get_model_response(
+ server_side_event_chunks(
+ [
+ MessageStartEvent(
+ message=EXAMPLE_MESSAGE,
+ type="message_start",
+ ),
+ ContentBlockStartEvent(
+ type="content_block_start",
+ index=0,
+ content_block=TextBlock(type="text", text=""),
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="Hi", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="!", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text=" I'm Claude!", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockStopEvent(type="content_block_stop", index=0),
+ MessageDeltaEvent(
+ delta=Delta(),
+ usage=MessageDeltaUsage(output_tokens=10),
+ type="message_delta",
+ ),
+ ]
+ )
+ )
+
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ messages = [
+ {
+ "role": "user",
+ "content": "Hello, Claude",
+ }
+ ]
+
+ with mock.patch.object(
+ client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ with start_transaction(name="anthropic"):
+ message = client.messages.create(
+ max_tokens=1024,
+ messages=messages,
+ model="model",
+ stream=True,
+ system="You are a helpful assistant.",
+ )
+
+ for _ in message:
+ pass
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "anthropic"
+
+ assert len(event["spans"]) == 1
+ (span,) = event["spans"]
+
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["description"] == "chat model"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "model"
+
+ if send_default_pii and include_prompts:
+ assert SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS in span["data"]
+ system_instructions = json.loads(
+ span["data"][SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS]
+ )
+ assert system_instructions == [
+ {"type": "text", "content": "You are a helpful assistant."}
+ ]
+
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES in span["data"]
+ stored_messages = json.loads(span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+ assert len(stored_messages) == 1
+ assert stored_messages[0]["role"] == "user"
+ assert stored_messages[0]["content"] == "Hello, Claude"
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] == "Hi! I'm Claude!"
+
+ else:
+ assert SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS not in span["data"]
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]
+
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 20
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+def test_stream_messages_with_system_prompt(
+ sentry_init,
+ capture_events,
+ send_default_pii,
+ include_prompts,
+ get_model_response,
+ server_side_event_chunks,
+):
+ """Test that system prompts are properly captured in streaming mode."""
+ client = Anthropic(api_key="z")
+
+ response = get_model_response(
+ server_side_event_chunks(
+ [
+ MessageStartEvent(
+ message=EXAMPLE_MESSAGE,
+ type="message_start",
+ ),
+ ContentBlockStartEvent(
+ type="content_block_start",
+ index=0,
+ content_block=TextBlock(type="text", text=""),
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="Hi", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="!", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text=" I'm Claude!", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockStopEvent(type="content_block_stop", index=0),
+ MessageDeltaEvent(
+ delta=Delta(),
+ usage=MessageDeltaUsage(output_tokens=10),
+ type="message_delta",
+ ),
+ ]
+ )
+ )
+
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ messages = [
+ {
+ "role": "user",
+ "content": "Hello, Claude",
+ }
+ ]
+
+ with mock.patch.object(
+ client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ with start_transaction(name="anthropic"):
+ with client.messages.stream(
+ max_tokens=1024,
+ messages=messages,
+ model="model",
+ system="You are a helpful assistant.",
+ ) as stream:
+ for event in stream:
+ pass
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "anthropic"
+
+ assert len(event["spans"]) == 1
+ (span,) = event["spans"]
+
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["description"] == "chat model"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "model"
+
+ if send_default_pii and include_prompts:
+ assert SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS in span["data"]
+ system_instructions = json.loads(
+ span["data"][SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS]
+ )
+ assert system_instructions == [
+ {"type": "text", "content": "You are a helpful assistant."}
+ ]
+
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES in span["data"]
+ stored_messages = json.loads(span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+ assert len(stored_messages) == 1
+ assert stored_messages[0]["role"] == "user"
+ assert stored_messages[0]["content"] == "Hello, Claude"
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] == "Hi! I'm Claude!"
+
+ else:
+ assert SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS not in span["data"]
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]
+
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 20
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+async def test_stream_message_with_system_prompt_async(
+ sentry_init,
+ capture_events,
+ send_default_pii,
+ include_prompts,
+ get_model_response,
+ async_iterator,
+ server_side_event_chunks,
+):
+ """Test that system prompts are properly captured in streaming mode (async)."""
+ client = AsyncAnthropic(api_key="z")
+
+ response = get_model_response(
+ async_iterator(
+ server_side_event_chunks(
+ [
+ MessageStartEvent(
+ message=EXAMPLE_MESSAGE,
+ type="message_start",
+ ),
+ ContentBlockStartEvent(
+ type="content_block_start",
+ index=0,
+ content_block=TextBlock(type="text", text=""),
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="Hi", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="!", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text=" I'm Claude!", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockStopEvent(type="content_block_stop", index=0),
+ MessageDeltaEvent(
+ delta=Delta(),
+ usage=MessageDeltaUsage(output_tokens=10),
+ type="message_delta",
+ ),
+ ]
+ )
+ )
+ )
+
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ messages = [
+ {
+ "role": "user",
+ "content": "Hello, Claude",
+ }
+ ]
+
+ with mock.patch.object(
+ client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ with start_transaction(name="anthropic"):
+ async with client.messages.stream(
+ max_tokens=1024,
+ messages=messages,
+ model="model",
+ system="You are a helpful assistant.",
+ ) as stream:
+ async for event in stream:
+ pass
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "anthropic"
+
+ assert len(event["spans"]) == 1
+ (span,) = event["spans"]
+
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["description"] == "chat model"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "model"
+
+ if send_default_pii and include_prompts:
+ assert SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS in span["data"]
+ system_instructions = json.loads(
+ span["data"][SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS]
+ )
+ assert system_instructions == [
+ {"type": "text", "content": "You are a helpful assistant."}
+ ]
+
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES in span["data"]
+ stored_messages = json.loads(span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+ assert len(stored_messages) == 1
+ assert stored_messages[0]["role"] == "user"
+ assert stored_messages[0]["content"] == "Hello, Claude"
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] == "Hi! I'm Claude!"
+
+ else:
+ assert SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS not in span["data"]
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]
+
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 20
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+async def test_streaming_create_message_with_system_prompt_async(
+ sentry_init,
+ capture_events,
+ send_default_pii,
+ include_prompts,
+ get_model_response,
+ async_iterator,
+ server_side_event_chunks,
+):
+ """Test that system prompts are properly captured in streaming mode (async)."""
+ client = AsyncAnthropic(api_key="z")
+
+ response = get_model_response(
+ async_iterator(
+ server_side_event_chunks(
+ [
+ MessageStartEvent(
+ message=EXAMPLE_MESSAGE,
+ type="message_start",
+ ),
+ ContentBlockStartEvent(
+ type="content_block_start",
+ index=0,
+ content_block=TextBlock(type="text", text=""),
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="Hi", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text="!", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockDeltaEvent(
+ delta=TextDelta(text=" I'm Claude!", type="text_delta"),
+ index=0,
+ type="content_block_delta",
+ ),
+ ContentBlockStopEvent(type="content_block_stop", index=0),
+ MessageDeltaEvent(
+ delta=Delta(),
+ usage=MessageDeltaUsage(output_tokens=10),
+ type="message_delta",
+ ),
+ ]
+ )
+ )
+ )
+
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ messages = [
+ {
+ "role": "user",
+ "content": "Hello, Claude",
+ }
+ ]
+
+ with mock.patch.object(
+ client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ with start_transaction(name="anthropic"):
+ message = await client.messages.create(
+ max_tokens=1024,
+ messages=messages,
+ model="model",
+ stream=True,
+ system="You are a helpful assistant.",
+ )
+
+ async for _ in message:
+ pass
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "anthropic"
+
+ assert len(event["spans"]) == 1
+ (span,) = event["spans"]
+
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["description"] == "chat model"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "model"
+
+ if send_default_pii and include_prompts:
+ assert SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS in span["data"]
+ system_instructions = json.loads(
+ span["data"][SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS]
+ )
+ assert system_instructions == [
+ {"type": "text", "content": "You are a helpful assistant."}
+ ]
+
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES in span["data"]
+ stored_messages = json.loads(span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+ assert len(stored_messages) == 1
+ assert stored_messages[0]["role"] == "user"
+ assert stored_messages[0]["content"] == "Hello, Claude"
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] == "Hi! I'm Claude!"
+
+ else:
+ assert SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS not in span["data"]
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]
+
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 20
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
+
+
+def test_system_prompt_with_complex_structure(sentry_init, capture_events):
+ """Test that complex system prompt structures (list of text blocks) are properly captured."""
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+ client = Anthropic(api_key="z")
+ client.messages._post = mock.Mock(return_value=EXAMPLE_MESSAGE)
+
+ # System prompt as list of text blocks
+ system_prompt = [
+ {"type": "text", "text": "You are a helpful assistant."},
+ {"type": "text", "text": "Be concise and clear."},
+ ]
+
+ messages = [
+ {
+ "role": "user",
+ "content": "Hello",
+ }
+ ]
+
+ with start_transaction(name="anthropic"):
+ response = client.messages.create(
+ max_tokens=1024, messages=messages, model="model", system=system_prompt
+ )
+
+ assert response == EXAMPLE_MESSAGE
+ assert len(events) == 1
+ (event,) = events
+
+ assert len(event["spans"]) == 1
+ (span,) = event["spans"]
+
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+
+ assert SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS in span["data"]
+ system_instructions = json.loads(span["data"][SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS])
+
+ # System content should be a list of text blocks
+ assert isinstance(system_instructions, list)
+ assert system_instructions == [
+ {"type": "text", "content": "You are a helpful assistant."},
+ {"type": "text", "content": "Be concise and clear."},
+ ]
+
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES in span["data"]
+ stored_messages = json.loads(span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+
+ assert len(stored_messages) == 1
+ assert stored_messages[0]["role"] == "user"
+ assert stored_messages[0]["content"] == "Hello"
+
+
+# Tests for transform_content_part (shared) and _transform_anthropic_content_block helper functions
+
+
+def test_transform_content_part_anthropic_base64_image():
+ """Test that base64 encoded images are transformed to blob format."""
+ content_block = {
+ "type": "image",
+ "source": {
+ "type": "base64",
+ "media_type": "image/jpeg",
+ "data": "base64encodeddata...",
+ },
+ }
+
+ result = transform_content_part(content_block)
+
+ assert result == {
+ "type": "blob",
+ "modality": "image",
+ "mime_type": "image/jpeg",
+ "content": "base64encodeddata...",
+ }
+
+
+def test_transform_content_part_anthropic_url_image():
+ """Test that URL-referenced images are transformed to uri format."""
+ content_block = {
+ "type": "image",
+ "source": {
+ "type": "url",
+ "url": "https://example.com/image.jpg",
+ },
+ }
+
+ result = transform_content_part(content_block)
+
+ assert result == {
+ "type": "uri",
+ "modality": "image",
+ "mime_type": "",
+ "uri": "https://example.com/image.jpg",
+ }
+
+
+def test_transform_content_part_anthropic_file_image():
+ """Test that file_id-referenced images are transformed to file format."""
+ content_block = {
+ "type": "image",
+ "source": {
+ "type": "file",
+ "file_id": "file_abc123",
+ },
+ }
+
+ result = transform_content_part(content_block)
+
+ assert result == {
+ "type": "file",
+ "modality": "image",
+ "mime_type": "",
+ "file_id": "file_abc123",
+ }
+
+
+def test_transform_content_part_anthropic_base64_document():
+ """Test that base64 encoded PDFs are transformed to blob format."""
+ content_block = {
+ "type": "document",
+ "source": {
+ "type": "base64",
+ "media_type": "application/pdf",
+ "data": "base64encodedpdfdata...",
+ },
+ }
+
+ result = transform_content_part(content_block)
+
+ assert result == {
+ "type": "blob",
+ "modality": "document",
+ "mime_type": "application/pdf",
+ "content": "base64encodedpdfdata...",
+ }
+
+
+def test_transform_content_part_anthropic_url_document():
+ """Test that URL-referenced documents are transformed to uri format."""
+ content_block = {
+ "type": "document",
+ "source": {
+ "type": "url",
+ "url": "https://example.com/document.pdf",
+ },
+ }
+
+ result = transform_content_part(content_block)
+
+ assert result == {
+ "type": "uri",
+ "modality": "document",
+ "mime_type": "",
+ "uri": "https://example.com/document.pdf",
+ }
+
+
+def test_transform_content_part_anthropic_file_document():
+ """Test that file_id-referenced documents are transformed to file format."""
+ content_block = {
+ "type": "document",
+ "source": {
+ "type": "file",
+ "file_id": "file_doc456",
+ "media_type": "application/pdf",
+ },
+ }
+
+ result = transform_content_part(content_block)
+
+ assert result == {
+ "type": "file",
+ "modality": "document",
+ "mime_type": "application/pdf",
+ "file_id": "file_doc456",
+ }
+
+
+def test_transform_anthropic_content_block_text_document():
+ """Test that plain text documents are transformed correctly (Anthropic-specific)."""
+ content_block = {
+ "type": "document",
+ "source": {
+ "type": "text",
+ "media_type": "text/plain",
+ "data": "This is plain text content.",
+ },
+ }
+
+ # Use Anthropic-specific helper for text-type documents
+ result = _transform_anthropic_content_block(content_block)
+
+ assert result == {
+ "type": "text",
+ "text": "This is plain text content.",
+ }
+
+
+def test_transform_content_part_text_block():
+ """Test that regular text blocks return None (not transformed)."""
+ content_block = {
+ "type": "text",
+ "text": "Hello, world!",
+ }
+
+ # Shared transform_content_part returns None for text blocks
+ result = transform_content_part(content_block)
+
+ assert result is None
+
+
+def test_transform_message_content_string():
+ """Test that string content is returned as-is."""
+ result = transform_message_content("Hello, world!")
+ assert result == "Hello, world!"
+
+
+def test_transform_message_content_list_anthropic():
+ """Test that list content with Anthropic format is transformed correctly."""
+ content = [
+ {"type": "text", "text": "Hello!"},
+ {
+ "type": "image",
+ "source": {
+ "type": "base64",
+ "media_type": "image/png",
+ "data": "base64data...",
+ },
+ },
+ ]
+
+ result = transform_message_content(content)
+
+ assert len(result) == 2
+ # Text block stays as-is (transform returns None, keeps original)
+ assert result[0] == {"type": "text", "text": "Hello!"}
+ assert result[1] == {
+ "type": "blob",
+ "modality": "image",
+ "mime_type": "image/png",
+ "content": "base64data...",
+ }
+
+
+# Integration tests for binary data in messages
+
+
+def test_message_with_base64_image(sentry_init, capture_events):
+ """Test that messages with base64 images are properly captured."""
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+ client = Anthropic(api_key="z")
+ client.messages._post = mock.Mock(return_value=EXAMPLE_MESSAGE)
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "What's in this image?"},
+ {
+ "type": "image",
+ "source": {
+ "type": "base64",
+ "media_type": "image/jpeg",
+ "data": "base64encodeddatahere...",
+ },
+ },
+ ],
+ }
+ ]
+
+ with start_transaction(name="anthropic"):
+ client.messages.create(max_tokens=1024, messages=messages, model="model")
+
+ assert len(events) == 1
+ (event,) = events
+ (span,) = event["spans"]
+
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES in span["data"]
+ stored_messages = json.loads(span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+
+ assert len(stored_messages) == 1
+ assert stored_messages[0]["role"] == "user"
+ content = stored_messages[0]["content"]
+ assert len(content) == 2
+ assert content[0] == {"type": "text", "text": "What's in this image?"}
+ assert content[1] == {
+ "type": "blob",
+ "modality": "image",
+ "mime_type": "image/jpeg",
+ "content": BLOB_DATA_SUBSTITUTE,
+ }
+
+
+def test_message_with_url_image(sentry_init, capture_events):
+ """Test that messages with URL-referenced images are properly captured."""
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+ client = Anthropic(api_key="z")
+ client.messages._post = mock.Mock(return_value=EXAMPLE_MESSAGE)
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Describe this image."},
+ {
+ "type": "image",
+ "source": {
+ "type": "url",
+ "url": "https://example.com/photo.png",
+ },
+ },
+ ],
+ }
+ ]
+
+ with start_transaction(name="anthropic"):
+ client.messages.create(max_tokens=1024, messages=messages, model="model")
+
+ assert len(events) == 1
+ (event,) = events
+ (span,) = event["spans"]
+
+ stored_messages = json.loads(span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+ content = stored_messages[0]["content"]
+ assert content[1] == {
+ "type": "uri",
+ "modality": "image",
+ "mime_type": "",
+ "uri": "https://example.com/photo.png",
+ }
+
+
+def test_message_with_file_image(sentry_init, capture_events):
+ """Test that messages with file_id-referenced images are properly captured."""
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+ client = Anthropic(api_key="z")
+ client.messages._post = mock.Mock(return_value=EXAMPLE_MESSAGE)
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "What do you see?"},
+ {
+ "type": "image",
+ "source": {
+ "type": "file",
+ "file_id": "file_img_12345",
+ "media_type": "image/webp",
+ },
+ },
+ ],
+ }
+ ]
+
+ with start_transaction(name="anthropic"):
+ client.messages.create(max_tokens=1024, messages=messages, model="model")
+
+ assert len(events) == 1
+ (event,) = events
+ (span,) = event["spans"]
+
+ stored_messages = json.loads(span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+ content = stored_messages[0]["content"]
+ assert content[1] == {
+ "type": "file",
+ "modality": "image",
+ "mime_type": "image/webp",
+ "file_id": "file_img_12345",
+ }
+
+
+def test_message_with_base64_pdf(sentry_init, capture_events):
+ """Test that messages with base64-encoded PDF documents are properly captured."""
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+ client = Anthropic(api_key="z")
+ client.messages._post = mock.Mock(return_value=EXAMPLE_MESSAGE)
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Summarize this document."},
+ {
+ "type": "document",
+ "source": {
+ "type": "base64",
+ "media_type": "application/pdf",
+ "data": "JVBERi0xLjQKJeLj...base64pdfdata",
+ },
+ },
+ ],
+ }
+ ]
+
+ with start_transaction(name="anthropic"):
+ client.messages.create(max_tokens=1024, messages=messages, model="model")
+
+ assert len(events) == 1
+ (event,) = events
+ (span,) = event["spans"]
+
+ stored_messages = json.loads(span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+ content = stored_messages[0]["content"]
+ assert content[1] == {
+ "type": "blob",
+ "modality": "document",
+ "mime_type": "application/pdf",
+ "content": BLOB_DATA_SUBSTITUTE,
+ }
+
+
+def test_message_with_url_pdf(sentry_init, capture_events):
+ """Test that messages with URL-referenced PDF documents are properly captured."""
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+ client = Anthropic(api_key="z")
+ client.messages._post = mock.Mock(return_value=EXAMPLE_MESSAGE)
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "What is in this PDF?"},
+ {
+ "type": "document",
+ "source": {
+ "type": "url",
+ "url": "https://example.com/report.pdf",
+ },
+ },
+ ],
+ }
+ ]
+
+ with start_transaction(name="anthropic"):
+ client.messages.create(max_tokens=1024, messages=messages, model="model")
+
+ assert len(events) == 1
+ (event,) = events
+ (span,) = event["spans"]
+
+ stored_messages = json.loads(span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+ content = stored_messages[0]["content"]
+ assert content[1] == {
+ "type": "uri",
+ "modality": "document",
+ "mime_type": "",
+ "uri": "https://example.com/report.pdf",
+ }
+
+
+def test_message_with_file_document(sentry_init, capture_events):
+ """Test that messages with file_id-referenced documents are properly captured."""
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+ client = Anthropic(api_key="z")
+ client.messages._post = mock.Mock(return_value=EXAMPLE_MESSAGE)
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Analyze this document."},
+ {
+ "type": "document",
+ "source": {
+ "type": "file",
+ "file_id": "file_doc_67890",
+ "media_type": "application/pdf",
+ },
+ },
+ ],
+ }
+ ]
+
+ with start_transaction(name="anthropic"):
+ client.messages.create(max_tokens=1024, messages=messages, model="model")
+
+ assert len(events) == 1
+ (event,) = events
+ (span,) = event["spans"]
+
+ stored_messages = json.loads(span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+ content = stored_messages[0]["content"]
+ assert content[1] == {
+ "type": "file",
+ "modality": "document",
+ "mime_type": "application/pdf",
+ "file_id": "file_doc_67890",
+ }
+
+
+def test_message_with_mixed_content(sentry_init, capture_events):
+ """Test that messages with mixed content (text, images, documents) are properly captured."""
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+ client = Anthropic(api_key="z")
+ client.messages._post = mock.Mock(return_value=EXAMPLE_MESSAGE)
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Compare this image with the document."},
+ {
+ "type": "image",
+ "source": {
+ "type": "base64",
+ "media_type": "image/png",
+ "data": "iVBORw0KGgo...base64imagedata",
+ },
+ },
+ {
+ "type": "image",
+ "source": {
+ "type": "url",
+ "url": "https://example.com/comparison.jpg",
+ },
+ },
+ {
+ "type": "document",
+ "source": {
+ "type": "base64",
+ "media_type": "application/pdf",
+ "data": "JVBERi0xLjQK...base64pdfdata",
+ },
+ },
+ {"type": "text", "text": "Please provide a detailed analysis."},
+ ],
+ }
+ ]
+
+ with start_transaction(name="anthropic"):
+ client.messages.create(max_tokens=1024, messages=messages, model="model")
+
+ assert len(events) == 1
+ (event,) = events
+ (span,) = event["spans"]
+
+ stored_messages = json.loads(span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+ content = stored_messages[0]["content"]
+
+ assert len(content) == 5
+ assert content[0] == {
+ "type": "text",
+ "text": "Compare this image with the document.",
+ }
+ assert content[1] == {
+ "type": "blob",
+ "modality": "image",
+ "mime_type": "image/png",
+ "content": BLOB_DATA_SUBSTITUTE,
+ }
+ assert content[2] == {
+ "type": "uri",
+ "modality": "image",
+ "mime_type": "",
+ "uri": "https://example.com/comparison.jpg",
+ }
+ assert content[3] == {
+ "type": "blob",
+ "modality": "document",
+ "mime_type": "application/pdf",
+ "content": BLOB_DATA_SUBSTITUTE,
+ }
+ assert content[4] == {
+ "type": "text",
+ "text": "Please provide a detailed analysis.",
+ }
+
+
+def test_message_with_multiple_images_different_formats(sentry_init, capture_events):
+ """Test that messages with multiple images of different source types are handled."""
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+ client = Anthropic(api_key="z")
+ client.messages._post = mock.Mock(return_value=EXAMPLE_MESSAGE)
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "source": {
+ "type": "base64",
+ "media_type": "image/jpeg",
+ "data": "base64data1...",
+ },
+ },
+ {
+ "type": "image",
+ "source": {
+ "type": "url",
+ "url": "https://example.com/img2.gif",
+ },
+ },
+ {
+ "type": "image",
+ "source": {
+ "type": "file",
+ "file_id": "file_img_789",
+ "media_type": "image/webp",
+ },
+ },
+ {"type": "text", "text": "Compare these three images."},
+ ],
+ }
+ ]
+
+ with start_transaction(name="anthropic"):
+ client.messages.create(max_tokens=1024, messages=messages, model="model")
+
+ assert len(events) == 1
+ (event,) = events
+ (span,) = event["spans"]
+
+ stored_messages = json.loads(span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+ content = stored_messages[0]["content"]
+
+ assert len(content) == 4
+ assert content[0] == {
+ "type": "blob",
+ "modality": "image",
+ "mime_type": "image/jpeg",
+ "content": BLOB_DATA_SUBSTITUTE,
+ }
+ assert content[1] == {
+ "type": "uri",
+ "modality": "image",
+ "mime_type": "",
+ "uri": "https://example.com/img2.gif",
+ }
+ assert content[2] == {
+ "type": "file",
+ "modality": "image",
+ "mime_type": "image/webp",
+ "file_id": "file_img_789",
+ }
+ assert content[3] == {"type": "text", "text": "Compare these three images."}
+
+
+def test_binary_content_not_stored_when_pii_disabled(sentry_init, capture_events):
+ """Test that binary content is not stored when send_default_pii is False."""
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=False,
+ )
+ events = capture_events()
+ client = Anthropic(api_key="z")
+ client.messages._post = mock.Mock(return_value=EXAMPLE_MESSAGE)
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "What's in this image?"},
+ {
+ "type": "image",
+ "source": {
+ "type": "base64",
+ "media_type": "image/jpeg",
+ "data": "base64encodeddatahere...",
+ },
+ },
+ ],
+ }
+ ]
+
+ with start_transaction(name="anthropic"):
+ client.messages.create(max_tokens=1024, messages=messages, model="model")
+
+ assert len(events) == 1
+ (event,) = events
+ (span,) = event["spans"]
+
+ # Messages should not be stored
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
+
+
+def test_binary_content_not_stored_when_prompts_disabled(sentry_init, capture_events):
+ """Test that binary content is not stored when include_prompts is False."""
+ sentry_init(
+ integrations=[AnthropicIntegration(include_prompts=False)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+ client = Anthropic(api_key="z")
+ client.messages._post = mock.Mock(return_value=EXAMPLE_MESSAGE)
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "What's in this image?"},
+ {
+ "type": "image",
+ "source": {
+ "type": "base64",
+ "media_type": "image/jpeg",
+ "data": "base64encodeddatahere...",
+ },
+ },
+ ],
+ }
+ ]
+
+ with start_transaction(name="anthropic"):
+ client.messages.create(max_tokens=1024, messages=messages, model="model")
+
+ assert len(events) == 1
+ (event,) = events
+ (span,) = event["spans"]
+
+ # Messages should not be stored
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
+
+
+def test_cache_tokens_nonstreaming(sentry_init, capture_events):
+ """Test cache read/write tokens are tracked for non-streaming responses."""
+ sentry_init(integrations=[AnthropicIntegration()], traces_sample_rate=1.0)
+ events = capture_events()
+ client = Anthropic(api_key="z")
+
+ client.messages._post = mock.Mock(
+ return_value=Message(
+ id="id",
+ model="claude-3-5-sonnet-20241022",
+ role="assistant",
+ content=[TextBlock(type="text", text="Response")],
+ type="message",
+ usage=Usage(
+ input_tokens=100,
+ output_tokens=50,
+ cache_read_input_tokens=80,
+ cache_creation_input_tokens=20,
+ ),
+ )
+ )
+
+ with start_transaction(name="anthropic"):
+ client.messages.create(
+ max_tokens=1024,
+ messages=[{"role": "user", "content": "Hello"}],
+ model="claude-3-5-sonnet-20241022",
+ )
+
+ (span,) = events[0]["spans"]
+ # input_tokens normalized: 100 + 80 (cache_read) + 20 (cache_write) = 200
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 200
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 50
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 250
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 80
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE] == 20
+
+
+def test_input_tokens_include_cache_write_nonstreaming(sentry_init, capture_events):
+ """
+ Test that gen_ai.usage.input_tokens includes cache_write tokens (non-streaming).
+
+ Reproduces a real Anthropic cache-write response. Anthropic's usage.input_tokens
+ only counts non-cached tokens, but gen_ai.usage.input_tokens should be the TOTAL
+ so downstream cost calculations don't produce negative values.
+
+ Real Anthropic response (from E2E test):
+ Usage(input_tokens=19, output_tokens=14,
+ cache_creation_input_tokens=2846, cache_read_input_tokens=0)
+ """
+ sentry_init(integrations=[AnthropicIntegration()], traces_sample_rate=1.0)
+ events = capture_events()
+ client = Anthropic(api_key="z")
+
+ client.messages._post = mock.Mock(
+ return_value=Message(
+ id="id",
+ model="claude-sonnet-4-20250514",
+ role="assistant",
+ content=[TextBlock(type="text", text="3 + 3 equals 6.")],
+ type="message",
+ usage=Usage(
+ input_tokens=19,
+ output_tokens=14,
+ cache_read_input_tokens=0,
+ cache_creation_input_tokens=2846,
+ ),
+ )
+ )
+
+ with start_transaction(name="anthropic"):
+ client.messages.create(
+ max_tokens=1024,
+ messages=[{"role": "user", "content": "What is 3+3?"}],
+ model="claude-sonnet-4-20250514",
+ )
+
+ (span,) = events[0]["spans"]
+
+ # input_tokens should be total: 19 (non-cached) + 2846 (cache_write) = 2865
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 2865
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 2879 # 2865 + 14
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 0
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE] == 2846
+
+
+def test_input_tokens_include_cache_read_nonstreaming(sentry_init, capture_events):
+ """
+ Test that gen_ai.usage.input_tokens includes cache_read tokens (non-streaming).
+
+ Reproduces a real Anthropic cache-hit response. This is the scenario that
+ caused negative gen_ai.cost.input_tokens: input_tokens=19 but cached=2846,
+ so the backend computed 19 - 2846 = -2827 "regular" tokens.
+
+ Real Anthropic response (from E2E test):
+ Usage(input_tokens=19, output_tokens=14,
+ cache_creation_input_tokens=0, cache_read_input_tokens=2846)
+ """
+ sentry_init(integrations=[AnthropicIntegration()], traces_sample_rate=1.0)
+ events = capture_events()
+ client = Anthropic(api_key="z")
+
+ client.messages._post = mock.Mock(
+ return_value=Message(
+ id="id",
+ model="claude-sonnet-4-20250514",
+ role="assistant",
+ content=[TextBlock(type="text", text="5 + 5 = 10.")],
+ type="message",
+ usage=Usage(
+ input_tokens=19,
+ output_tokens=14,
+ cache_read_input_tokens=2846,
+ cache_creation_input_tokens=0,
+ ),
+ )
+ )
+
+ with start_transaction(name="anthropic"):
+ client.messages.create(
+ max_tokens=1024,
+ messages=[{"role": "user", "content": "What is 5+5?"}],
+ model="claude-sonnet-4-20250514",
+ )
+
+ (span,) = events[0]["spans"]
+
+ # input_tokens should be total: 19 (non-cached) + 2846 (cache_read) = 2865
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 2865
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 2879 # 2865 + 14
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 2846
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE] == 0
+
+
+def test_input_tokens_include_cache_read_streaming(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ server_side_event_chunks,
+):
+ """
+ Test that gen_ai.usage.input_tokens includes cache_read tokens (streaming).
+
+ Same cache-hit scenario as non-streaming, using realistic streaming events.
+ """
+ client = Anthropic(api_key="z")
+
+ response = get_model_response(
+ server_side_event_chunks(
+ [
+ MessageStartEvent(
+ type="message_start",
+ message=Message(
+ id="id",
+ model="claude-sonnet-4-20250514",
+ role="assistant",
+ content=[],
+ type="message",
+ usage=Usage(
+ input_tokens=19,
+ output_tokens=0,
+ cache_read_input_tokens=2846,
+ cache_creation_input_tokens=0,
+ ),
+ ),
+ ),
+ MessageDeltaEvent(
+ type="message_delta",
+ delta=Delta(stop_reason="end_turn"),
+ usage=MessageDeltaUsage(output_tokens=14),
+ ),
+ ]
+ )
+ )
+
+ sentry_init(integrations=[AnthropicIntegration()], traces_sample_rate=1.0)
+ events = capture_events()
+
+ with mock.patch.object(
+ client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ with start_transaction(name="anthropic"):
+ for _ in client.messages.create(
+ max_tokens=1024,
+ messages=[{"role": "user", "content": "What is 5+5?"}],
+ model="claude-sonnet-4-20250514",
+ stream=True,
+ ):
+ pass
+
+ (span,) = events[0]["spans"]
+
+ # input_tokens should be total: 19 + 2846 = test_stream_messages_input_tokens_include_cache_read_streaming
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 2865
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 2879 # 2865 + 14
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 2846
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE] == 0
+
+
+def test_stream_messages_input_tokens_include_cache_read_streaming(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ server_side_event_chunks,
+):
+ """
+ Test that gen_ai.usage.input_tokens includes cache_read tokens (streaming).
+
+ Same cache-hit scenario as non-streaming, using realistic streaming events.
+ """
+ client = Anthropic(api_key="z")
+ response = get_model_response(
+ server_side_event_chunks(
+ [
+ MessageStartEvent(
+ type="message_start",
+ message=Message(
+ id="id",
+ model="claude-sonnet-4-20250514",
+ role="assistant",
+ content=[],
+ type="message",
+ usage=Usage(
+ input_tokens=19,
+ output_tokens=0,
+ cache_read_input_tokens=2846,
+ cache_creation_input_tokens=0,
+ ),
+ ),
+ ),
+ MessageDeltaEvent(
+ type="message_delta",
+ delta=Delta(stop_reason="end_turn"),
+ usage=MessageDeltaUsage(output_tokens=14),
+ ),
+ ]
+ )
+ )
+
+ sentry_init(integrations=[AnthropicIntegration()], traces_sample_rate=1.0)
+ events = capture_events()
+
+ with mock.patch.object(
+ client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ with start_transaction(name="anthropic"):
+ with client.messages.stream(
+ max_tokens=1024,
+ messages=[{"role": "user", "content": "What is 5+5?"}],
+ model="claude-sonnet-4-20250514",
+ ) as stream:
+ for event in stream:
+ pass
+
+ (span,) = events[0]["spans"]
+
+ # input_tokens should be total: 19 + 2846 = 2865
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 2865
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 2879 # 2865 + 14
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 2846
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE] == 0
+
+
+def test_input_tokens_unchanged_without_caching(sentry_init, capture_events):
+ """
+ Test that input_tokens is unchanged when there are no cached tokens.
+
+ Real Anthropic response (from E2E test, simple call without caching):
+ Usage(input_tokens=20, output_tokens=12)
+ """
+ sentry_init(integrations=[AnthropicIntegration()], traces_sample_rate=1.0)
+ events = capture_events()
+ client = Anthropic(api_key="z")
+
+ client.messages._post = mock.Mock(
+ return_value=Message(
+ id="id",
+ model="claude-sonnet-4-20250514",
+ role="assistant",
+ content=[TextBlock(type="text", text="2+2 equals 4.")],
+ type="message",
+ usage=Usage(
+ input_tokens=20,
+ output_tokens=12,
+ ),
+ )
+ )
+
+ with start_transaction(name="anthropic"):
+ client.messages.create(
+ max_tokens=1024,
+ messages=[{"role": "user", "content": "What is 2+2?"}],
+ model="claude-sonnet-4-20250514",
+ )
+
+ (span,) = events[0]["spans"]
+
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 20
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 32 # 20 + 12
+
+
+def test_cache_tokens_streaming(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ server_side_event_chunks,
+):
+ """Test cache tokens are tracked for streaming responses."""
+ client = Anthropic(api_key="z")
+
+ response = get_model_response(
+ server_side_event_chunks(
+ [
+ MessageStartEvent(
+ type="message_start",
+ message=Message(
+ id="id",
+ model="claude-3-5-sonnet-20241022",
+ role="assistant",
+ content=[],
+ type="message",
+ usage=Usage(
+ input_tokens=100,
+ output_tokens=0,
+ cache_read_input_tokens=80,
+ cache_creation_input_tokens=20,
+ ),
+ ),
+ ),
+ MessageDeltaEvent(
+ type="message_delta",
+ delta=Delta(stop_reason="end_turn"),
+ usage=MessageDeltaUsage(output_tokens=10),
+ ),
+ ]
+ )
+ )
+
+ sentry_init(integrations=[AnthropicIntegration()], traces_sample_rate=1.0)
+ events = capture_events()
+
+ with mock.patch.object(
+ client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ with start_transaction(name="anthropic"):
+ for _ in client.messages.create(
+ max_tokens=1024,
+ messages=[{"role": "user", "content": "Hello"}],
+ model="claude-3-5-sonnet-20241022",
+ stream=True,
+ ):
+ pass
+
+ (span,) = events[0]["spans"]
+ # input_tokens normalized: 100 + 80 (cache_read) + 20 (cache_write) = 200
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 200
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 210
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 80
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE] == 20
+
+
+def test_stream_messages_cache_tokens(
+ sentry_init, capture_events, get_model_response, server_side_event_chunks
+):
+ """Test cache tokens are tracked for streaming responses."""
+ client = Anthropic(api_key="z")
+
+ response = get_model_response(
+ server_side_event_chunks(
+ [
+ MessageStartEvent(
+ type="message_start",
+ message=Message(
+ id="id",
+ model="claude-3-5-sonnet-20241022",
+ role="assistant",
+ content=[],
+ type="message",
+ usage=Usage(
+ input_tokens=100,
+ output_tokens=0,
+ cache_read_input_tokens=80,
+ cache_creation_input_tokens=20,
+ ),
+ ),
+ ),
+ MessageDeltaEvent(
+ type="message_delta",
+ delta=Delta(stop_reason="end_turn"),
+ usage=MessageDeltaUsage(output_tokens=10),
+ ),
+ ]
+ )
+ )
+
+ sentry_init(integrations=[AnthropicIntegration()], traces_sample_rate=1.0)
+ events = capture_events()
+
+ with mock.patch.object(
+ client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ with start_transaction(name="anthropic"):
+ with client.messages.stream(
+ max_tokens=1024,
+ messages=[{"role": "user", "content": "Hello"}],
+ model="claude-3-5-sonnet-20241022",
+ ) as stream:
+ for event in stream:
+ pass
+
+ (span,) = events[0]["spans"]
+ # input_tokens normalized: 100 + 80 (cache_read) + 20 (cache_write) = 200
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 200
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 210
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 80
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE] == 20
diff --git a/tests/integrations/ariadne/__init__.py b/tests/integrations/ariadne/__init__.py
new file mode 100644
index 0000000000..6d592b7a41
--- /dev/null
+++ b/tests/integrations/ariadne/__init__.py
@@ -0,0 +1,5 @@
+import pytest
+
+pytest.importorskip("ariadne")
+pytest.importorskip("fastapi")
+pytest.importorskip("flask")
diff --git a/tests/integrations/ariadne/test_ariadne.py b/tests/integrations/ariadne/test_ariadne.py
new file mode 100644
index 0000000000..2c3b086aa5
--- /dev/null
+++ b/tests/integrations/ariadne/test_ariadne.py
@@ -0,0 +1,276 @@
+from ariadne import gql, graphql_sync, ObjectType, QueryType, make_executable_schema
+from ariadne.asgi import GraphQL
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
+from flask import Flask, request, jsonify
+
+from sentry_sdk.integrations.ariadne import AriadneIntegration
+from sentry_sdk.integrations.fastapi import FastApiIntegration
+from sentry_sdk.integrations.flask import FlaskIntegration
+from sentry_sdk.integrations.starlette import StarletteIntegration
+
+
+def schema_factory():
+ type_defs = gql(
+ """
+ type Query {
+ greeting(name: String): Greeting
+ error: String
+ }
+
+ type Greeting {
+ name: String
+ }
+ """
+ )
+
+ query = QueryType()
+ greeting = ObjectType("Greeting")
+
+ @query.field("greeting")
+ def resolve_greeting(*_, **kwargs):
+ name = kwargs.pop("name")
+ return {"name": name}
+
+ @query.field("error")
+ def resolve_error(obj, *_):
+ raise RuntimeError("resolver failed")
+
+ @greeting.field("name")
+ def resolve_name(obj, *_):
+ return "Hello, {}!".format(obj["name"])
+
+ return make_executable_schema(type_defs, query)
+
+
+def test_capture_request_and_response_if_send_pii_is_on_async(
+ sentry_init, capture_events
+):
+ sentry_init(
+ send_default_pii=True,
+ integrations=[
+ AriadneIntegration(),
+ FastApiIntegration(),
+ StarletteIntegration(),
+ ],
+ )
+ events = capture_events()
+
+ schema = schema_factory()
+
+ async_app = FastAPI()
+ async_app.mount("/graphql/", GraphQL(schema))
+
+ query = {"query": "query ErrorQuery {error}"}
+ client = TestClient(async_app)
+ client.post("/graphql", json=query)
+
+ assert len(events) == 1
+
+ (event,) = events
+ assert event["exception"]["values"][0]["mechanism"]["type"] == "ariadne"
+ assert event["contexts"]["response"] == {
+ "data": {
+ "data": {"error": None},
+ "errors": [
+ {
+ "locations": [{"column": 19, "line": 1}],
+ "message": "resolver failed",
+ "path": ["error"],
+ }
+ ],
+ }
+ }
+ assert event["request"]["api_target"] == "graphql"
+ assert event["request"]["data"] == query
+
+
+def test_capture_request_and_response_if_send_pii_is_on_sync(
+ sentry_init, capture_events
+):
+ sentry_init(
+ send_default_pii=True,
+ integrations=[AriadneIntegration(), FlaskIntegration()],
+ )
+ events = capture_events()
+
+ schema = schema_factory()
+
+ sync_app = Flask(__name__)
+
+ @sync_app.route("/graphql", methods=["POST"])
+ def graphql_server():
+ data = request.get_json()
+ success, result = graphql_sync(schema, data)
+ return jsonify(result), 200
+
+ query = {"query": "query ErrorQuery {error}"}
+ client = sync_app.test_client()
+ client.post("/graphql", json=query)
+
+ assert len(events) == 1
+
+ (event,) = events
+ assert event["exception"]["values"][0]["mechanism"]["type"] == "ariadne"
+ assert event["contexts"]["response"] == {
+ "data": {
+ "data": {"error": None},
+ "errors": [
+ {
+ "locations": [{"column": 19, "line": 1}],
+ "message": "resolver failed",
+ "path": ["error"],
+ }
+ ],
+ }
+ }
+ assert event["request"]["api_target"] == "graphql"
+ assert event["request"]["data"] == query
+
+
+def test_do_not_capture_request_and_response_if_send_pii_is_off_async(
+ sentry_init, capture_events
+):
+ sentry_init(
+ integrations=[
+ AriadneIntegration(),
+ FastApiIntegration(),
+ StarletteIntegration(),
+ ],
+ )
+ events = capture_events()
+
+ schema = schema_factory()
+
+ async_app = FastAPI()
+ async_app.mount("/graphql/", GraphQL(schema))
+
+ query = {"query": "query ErrorQuery {error}"}
+ client = TestClient(async_app)
+ client.post("/graphql", json=query)
+
+ assert len(events) == 1
+
+ (event,) = events
+ assert event["exception"]["values"][0]["mechanism"]["type"] == "ariadne"
+ assert "data" not in event["request"]
+ assert "response" not in event["contexts"]
+
+
+def test_do_not_capture_request_and_response_if_send_pii_is_off_sync(
+ sentry_init, capture_events
+):
+ sentry_init(
+ integrations=[AriadneIntegration(), FlaskIntegration()],
+ )
+ events = capture_events()
+
+ schema = schema_factory()
+
+ sync_app = Flask(__name__)
+
+ @sync_app.route("/graphql", methods=["POST"])
+ def graphql_server():
+ data = request.get_json()
+ success, result = graphql_sync(schema, data)
+ return jsonify(result), 200
+
+ query = {"query": "query ErrorQuery {error}"}
+ client = sync_app.test_client()
+ client.post("/graphql", json=query)
+
+ assert len(events) == 1
+
+ (event,) = events
+ assert event["exception"]["values"][0]["mechanism"]["type"] == "ariadne"
+ assert "data" not in event["request"]
+ assert "response" not in event["contexts"]
+
+
+def test_capture_validation_error(sentry_init, capture_events):
+ sentry_init(
+ send_default_pii=True,
+ integrations=[
+ AriadneIntegration(),
+ FastApiIntegration(),
+ StarletteIntegration(),
+ ],
+ )
+ events = capture_events()
+
+ schema = schema_factory()
+
+ async_app = FastAPI()
+ async_app.mount("/graphql/", GraphQL(schema))
+
+ query = {"query": "query ErrorQuery {doesnt_exist}"}
+ client = TestClient(async_app)
+ client.post("/graphql", json=query)
+
+ assert len(events) == 1
+
+ (event,) = events
+ assert event["exception"]["values"][0]["mechanism"]["type"] == "ariadne"
+ assert event["contexts"]["response"] == {
+ "data": {
+ "errors": [
+ {
+ "locations": [{"column": 19, "line": 1}],
+ "message": "Cannot query field 'doesnt_exist' on type 'Query'.",
+ }
+ ]
+ }
+ }
+ assert event["request"]["api_target"] == "graphql"
+ assert event["request"]["data"] == query
+
+
+def test_no_event_if_no_errors_async(sentry_init, capture_events):
+ sentry_init(
+ integrations=[
+ AriadneIntegration(),
+ FastApiIntegration(),
+ StarletteIntegration(),
+ ],
+ )
+ events = capture_events()
+
+ schema = schema_factory()
+
+ async_app = FastAPI()
+ async_app.mount("/graphql/", GraphQL(schema))
+
+ query = {
+ "query": "query GreetingQuery($name: String) { greeting(name: $name) {name} }",
+ "variables": {"name": "some name"},
+ }
+ client = TestClient(async_app)
+ client.post("/graphql", json=query)
+
+ assert len(events) == 0
+
+
+def test_no_event_if_no_errors_sync(sentry_init, capture_events):
+ sentry_init(
+ integrations=[AriadneIntegration(), FlaskIntegration()],
+ )
+ events = capture_events()
+
+ schema = schema_factory()
+
+ sync_app = Flask(__name__)
+
+ @sync_app.route("/graphql", methods=["POST"])
+ def graphql_server():
+ data = request.get_json()
+ success, result = graphql_sync(schema, data)
+ return jsonify(result), 200
+
+ query = {
+ "query": "query GreetingQuery($name: String) { greeting(name: $name) {name} }",
+ "variables": {"name": "some name"},
+ }
+ client = sync_app.test_client()
+ client.post("/graphql", json=query)
+
+ assert len(events) == 0
diff --git a/tests/integrations/arq/test_arq.py b/tests/integrations/arq/test_arq.py
index d7e0e8af85..177f047101 100644
--- a/tests/integrations/arq/test_arq.py
+++ b/tests/integrations/arq/test_arq.py
@@ -1,16 +1,30 @@
+import asyncio
+from datetime import timedelta
+
import pytest
-from sentry_sdk import start_transaction
+from sentry_sdk import get_client, start_transaction
from sentry_sdk.integrations.arq import ArqIntegration
+import arq.worker
+from arq import cron
from arq.connections import ArqRedis
from arq.jobs import Job
from arq.utils import timestamp_ms
-from arq.worker import Retry, Worker
from fakeredis.aioredis import FakeRedis
+def async_partial(async_fn, *args, **kwargs):
+ # asyncio.iscoroutinefunction (Used in the integration code) in Python < 3.8
+ # does not detect async functions in functools.partial objects.
+ # This partial implementation returns a coroutine instead.
+ async def wrapped(ctx):
+ return await async_fn(ctx, *args, **kwargs)
+
+ return wrapped
+
+
@pytest.fixture(autouse=True)
def patch_fakeredis_info_command():
from fakeredis._fakesocket import FakeSocket
@@ -28,31 +42,157 @@ def info(self, section):
@pytest.fixture
def init_arq(sentry_init):
- def inner(functions, allow_abort_jobs=False):
+ def inner(
+ cls_functions=None,
+ cls_cron_jobs=None,
+ kw_functions=None,
+ kw_cron_jobs=None,
+ allow_abort_jobs_=False,
+ ):
+ cls_functions = cls_functions or []
+ cls_cron_jobs = cls_cron_jobs or []
+
+ kwargs = {}
+ if kw_functions is not None:
+ kwargs["functions"] = kw_functions
+ if kw_cron_jobs is not None:
+ kwargs["cron_jobs"] = kw_cron_jobs
+
+ sentry_init(
+ integrations=[ArqIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ server = FakeRedis()
+ pool = ArqRedis(pool_or_conn=server.connection_pool)
+
+ class WorkerSettings:
+ functions = cls_functions
+ cron_jobs = cls_cron_jobs
+ redis_pool = pool
+ allow_abort_jobs = allow_abort_jobs_
+
+ if not WorkerSettings.functions:
+ del WorkerSettings.functions
+ if not WorkerSettings.cron_jobs:
+ del WorkerSettings.cron_jobs
+
+ worker = arq.worker.create_worker(WorkerSettings, **kwargs)
+
+ return pool, worker
+
+ return inner
+
+
+@pytest.fixture
+def init_arq_with_dict_settings(sentry_init):
+ def inner(
+ cls_functions=None,
+ cls_cron_jobs=None,
+ kw_functions=None,
+ kw_cron_jobs=None,
+ allow_abort_jobs_=False,
+ ):
+ cls_functions = cls_functions or []
+ cls_cron_jobs = cls_cron_jobs or []
+
+ kwargs = {}
+ if kw_functions is not None:
+ kwargs["functions"] = kw_functions
+ if kw_cron_jobs is not None:
+ kwargs["cron_jobs"] = kw_cron_jobs
+
sentry_init(
integrations=[ArqIntegration()],
traces_sample_rate=1.0,
send_default_pii=True,
- debug=True,
)
server = FakeRedis()
pool = ArqRedis(pool_or_conn=server.connection_pool)
- return pool, Worker(
- functions, redis_pool=pool, allow_abort_jobs=allow_abort_jobs
+
+ worker_settings = {
+ "functions": cls_functions,
+ "cron_jobs": cls_cron_jobs,
+ "redis_pool": pool,
+ "allow_abort_jobs": allow_abort_jobs_,
+ }
+
+ if not worker_settings["functions"]:
+ del worker_settings["functions"]
+ if not worker_settings["cron_jobs"]:
+ del worker_settings["cron_jobs"]
+
+ worker = arq.worker.create_worker(worker_settings, **kwargs)
+
+ return pool, worker
+
+ return inner
+
+
+@pytest.fixture
+def init_arq_with_kwarg_settings(sentry_init):
+ """Test fixture that passes settings_cls as keyword argument only."""
+
+ def inner(
+ cls_functions=None,
+ cls_cron_jobs=None,
+ kw_functions=None,
+ kw_cron_jobs=None,
+ allow_abort_jobs_=False,
+ ):
+ cls_functions = cls_functions or []
+ cls_cron_jobs = cls_cron_jobs or []
+
+ kwargs = {}
+ if kw_functions is not None:
+ kwargs["functions"] = kw_functions
+ if kw_cron_jobs is not None:
+ kwargs["cron_jobs"] = kw_cron_jobs
+
+ sentry_init(
+ integrations=[ArqIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
)
+ server = FakeRedis()
+ pool = ArqRedis(pool_or_conn=server.connection_pool)
+
+ class WorkerSettings:
+ functions = cls_functions
+ cron_jobs = cls_cron_jobs
+ redis_pool = pool
+ allow_abort_jobs = allow_abort_jobs_
+
+ if not WorkerSettings.functions:
+ del WorkerSettings.functions
+ if not WorkerSettings.cron_jobs:
+ del WorkerSettings.cron_jobs
+
+ # Pass settings_cls as keyword argument (not positional)
+ worker = arq.worker.create_worker(settings_cls=WorkerSettings, **kwargs)
+
+ return pool, worker
+
return inner
@pytest.mark.asyncio
-async def test_job_result(init_arq):
+@pytest.mark.parametrize(
+ "init_arq_settings",
+ ["init_arq", "init_arq_with_dict_settings", "init_arq_with_kwarg_settings"],
+)
+async def test_job_result(init_arq_settings, request):
async def increase(ctx, num):
return num + 1
+ init_fixture_method = request.getfixturevalue(init_arq_settings)
+
increase.__qualname__ = increase.__name__
- pool, worker = init_arq([increase])
+ pool, worker = init_fixture_method([increase])
job = await pool.enqueue_job("increase", 3)
@@ -67,14 +207,19 @@ async def increase(ctx, num):
@pytest.mark.asyncio
-async def test_job_retry(capture_events, init_arq):
+@pytest.mark.parametrize(
+ "init_arq_settings", ["init_arq", "init_arq_with_dict_settings"]
+)
+async def test_job_retry(capture_events, init_arq_settings, request):
async def retry_job(ctx):
if ctx["job_try"] < 2:
- raise Retry
+ raise arq.worker.Retry
+
+ init_fixture_method = request.getfixturevalue(init_arq_settings)
retry_job.__qualname__ = retry_job.__name__
- pool, worker = init_arq([retry_job])
+ pool, worker = init_fixture_method([retry_job])
job = await pool.enqueue_job("retry_job")
@@ -97,52 +242,104 @@ async def retry_job(ctx):
assert event["extra"]["arq-job"]["retry"] == 2
+@pytest.mark.parametrize(
+ "source", [("cls_functions", "cls_cron_jobs"), ("kw_functions", "kw_cron_jobs")]
+)
@pytest.mark.parametrize("job_fails", [True, False], ids=["error", "success"])
+@pytest.mark.parametrize(
+ "init_arq_settings", ["init_arq", "init_arq_with_dict_settings"]
+)
@pytest.mark.asyncio
-async def test_job_transaction(capture_events, init_arq, job_fails):
+async def test_job_transaction(
+ capture_events, init_arq_settings, source, job_fails, request
+):
async def division(_, a, b=0):
return a / b
+ init_fixture_method = request.getfixturevalue(init_arq_settings)
+
division.__qualname__ = division.__name__
- pool, worker = init_arq([division])
+ cron_func = async_partial(division, a=1, b=int(not job_fails))
+ cron_func.__qualname__ = division.__name__
+
+ cron_job = cron(cron_func, minute=0, run_at_startup=True)
+
+ functions_key, cron_jobs_key = source
+ pool, worker = init_fixture_method(
+ **{functions_key: [division], cron_jobs_key: [cron_job]}
+ )
events = capture_events()
job = await pool.enqueue_job("division", 1, b=int(not job_fails))
await worker.run_job(job.job_id, timestamp_ms())
- if job_fails:
- error_event = events.pop(0)
- assert error_event["exception"]["values"][0]["type"] == "ZeroDivisionError"
- assert error_event["exception"]["values"][0]["mechanism"]["type"] == "arq"
+ loop = asyncio.get_event_loop()
+ task = loop.create_task(worker.async_run())
+ await asyncio.sleep(1)
- (event,) = events
- assert event["type"] == "transaction"
- assert event["transaction"] == "division"
- assert event["transaction_info"] == {"source": "task"}
+ task.cancel()
+
+ await worker.close()
if job_fails:
- assert event["contexts"]["trace"]["status"] == "internal_error"
- else:
- assert event["contexts"]["trace"]["status"] == "ok"
+ error_func_event = events.pop(0)
+ error_cron_event = events.pop(1)
+
+ assert error_func_event["exception"]["values"][0]["type"] == "ZeroDivisionError"
+ assert error_func_event["exception"]["values"][0]["mechanism"]["type"] == "arq"
+
+ func_extra = error_func_event["extra"]["arq-job"]
+ assert func_extra["task"] == "division"
+
+ assert error_cron_event["exception"]["values"][0]["type"] == "ZeroDivisionError"
+ assert error_cron_event["exception"]["values"][0]["mechanism"]["type"] == "arq"
+
+ cron_extra = error_cron_event["extra"]["arq-job"]
+ assert cron_extra["task"] == "cron:division"
+
+ [func_event, cron_event] = events
+
+ assert func_event["type"] == "transaction"
+ assert func_event["transaction"] == "division"
+ assert func_event["transaction_info"] == {"source": "task"}
+
+ assert "arq_task_id" in func_event["tags"]
+ assert "arq_task_retry" in func_event["tags"]
+
+ func_extra = func_event["extra"]["arq-job"]
- assert "arq_task_id" in event["tags"]
- assert "arq_task_retry" in event["tags"]
+ assert func_extra["task"] == "division"
+ assert func_extra["kwargs"] == {"b": int(not job_fails)}
+ assert func_extra["retry"] == 1
- extra = event["extra"]["arq-job"]
- assert extra["task"] == "division"
- assert extra["args"] == [1]
- assert extra["kwargs"] == {"b": int(not job_fails)}
- assert extra["retry"] == 1
+ assert cron_event["type"] == "transaction"
+ assert cron_event["transaction"] == "cron:division"
+ assert cron_event["transaction_info"] == {"source": "task"}
+ assert "arq_task_id" in cron_event["tags"]
+ assert "arq_task_retry" in cron_event["tags"]
+ cron_extra = cron_event["extra"]["arq-job"]
+
+ assert cron_extra["task"] == "cron:division"
+ assert cron_extra["kwargs"] == {}
+ assert cron_extra["retry"] == 1
+
+
+@pytest.mark.parametrize("source", ["cls_functions", "kw_functions"])
+@pytest.mark.parametrize(
+ "init_arq_settings", ["init_arq", "init_arq_with_dict_settings"]
+)
@pytest.mark.asyncio
-async def test_enqueue_job(capture_events, init_arq):
+async def test_enqueue_job(capture_events, init_arq_settings, source, request):
async def dummy_job(_):
pass
- pool, _ = init_arq([dummy_job])
+ init_fixture_method = request.getfixturevalue(init_arq_settings)
+
+ pool, _ = init_fixture_method(**{source: [dummy_job]})
events = capture_events()
@@ -157,3 +354,121 @@ async def dummy_job(_):
assert len(event["spans"])
assert event["spans"][0]["op"] == "queue.submit.arq"
assert event["spans"][0]["description"] == "dummy_job"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "init_arq_settings", ["init_arq", "init_arq_with_dict_settings"]
+)
+async def test_execute_job_without_integration(init_arq_settings, request):
+ async def dummy_job(_ctx):
+ pass
+
+ init_fixture_method = request.getfixturevalue(init_arq_settings)
+
+ dummy_job.__qualname__ = dummy_job.__name__
+
+ pool, worker = init_fixture_method([dummy_job])
+ # remove the integration to trigger the edge case
+ get_client().integrations.pop("arq")
+
+ job = await pool.enqueue_job("dummy_job")
+
+ await worker.run_job(job.job_id, timestamp_ms())
+
+ assert await job.result() is None
+
+
+@pytest.mark.parametrize("source", ["cls_functions", "kw_functions"])
+@pytest.mark.parametrize(
+ "init_arq_settings", ["init_arq", "init_arq_with_dict_settings"]
+)
+@pytest.mark.asyncio
+async def test_span_origin_producer(capture_events, init_arq_settings, source, request):
+ async def dummy_job(_):
+ pass
+
+ init_fixture_method = request.getfixturevalue(init_arq_settings)
+
+ pool, _ = init_fixture_method(**{source: [dummy_job]})
+
+ events = capture_events()
+
+ with start_transaction():
+ await pool.enqueue_job("dummy_job")
+
+ (event,) = events
+ assert event["contexts"]["trace"]["origin"] == "manual"
+ assert event["spans"][0]["origin"] == "auto.queue.arq"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "init_arq_settings", ["init_arq", "init_arq_with_dict_settings"]
+)
+async def test_span_origin_consumer(capture_events, init_arq_settings, request):
+ async def job(ctx):
+ pass
+
+ init_fixture_method = request.getfixturevalue(init_arq_settings)
+
+ job.__qualname__ = job.__name__
+
+ pool, worker = init_fixture_method([job])
+
+ job = await pool.enqueue_job("retry_job")
+
+ events = capture_events()
+
+ await worker.run_job(job.job_id, timestamp_ms())
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "auto.queue.arq"
+ assert event["spans"][0]["origin"] == "auto.db.redis"
+ assert event["spans"][1]["origin"] == "auto.db.redis"
+
+
+@pytest.mark.asyncio
+async def test_job_concurrency(capture_events, init_arq):
+ """
+ 10 - division starts
+ 70 - sleepy starts
+ 110 - division raises error
+ 120 - sleepy finishes
+
+ """
+
+ async def sleepy(_):
+ await asyncio.sleep(0.05)
+
+ async def division(_):
+ await asyncio.sleep(0.1)
+ return 1 / 0
+
+ sleepy.__qualname__ = sleepy.__name__
+ division.__qualname__ = division.__name__
+
+ pool, worker = init_arq([sleepy, division])
+
+ events = capture_events()
+
+ await pool.enqueue_job(
+ "division", _job_id="123", _defer_by=timedelta(milliseconds=10)
+ )
+ await pool.enqueue_job(
+ "sleepy", _job_id="456", _defer_by=timedelta(milliseconds=70)
+ )
+
+ loop = asyncio.get_event_loop()
+ task = loop.create_task(worker.async_run())
+ await asyncio.sleep(1)
+
+ task.cancel()
+
+ await worker.close()
+
+ exception_event = events[1]
+ assert exception_event["exception"]["values"][0]["type"] == "ZeroDivisionError"
+ assert exception_event["transaction"] == "division"
+ assert exception_event["extra"]["arq-job"]["task"] == "division"
diff --git a/tests/integrations/asgi/__init__.py b/tests/integrations/asgi/__init__.py
index 1fb057c1fc..ecc2bcfe95 100644
--- a/tests/integrations/asgi/__init__.py
+++ b/tests/integrations/asgi/__init__.py
@@ -1,4 +1,5 @@
import pytest
-asyncio = pytest.importorskip("asyncio")
-pytest_asyncio = pytest.importorskip("pytest_asyncio")
+pytest.importorskip("asyncio")
+pytest.importorskip("pytest_asyncio")
+pytest.importorskip("async_asgi_testclient")
diff --git a/tests/integrations/asgi/test_asgi.py b/tests/integrations/asgi/test_asgi.py
index ce28b1e8b9..7f44c9d00a 100644
--- a/tests/integrations/asgi/test_asgi.py
+++ b/tests/integrations/asgi/test_asgi.py
@@ -1,30 +1,32 @@
-import sys
-
from collections import Counter
import pytest
import sentry_sdk
from sentry_sdk import capture_message
+from sentry_sdk.tracing import TransactionSource
+from sentry_sdk.integrations._asgi_common import _get_ip, _get_headers
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware, _looks_like_asgi3
-async_asgi_testclient = pytest.importorskip("async_asgi_testclient")
from async_asgi_testclient import TestClient
-minimum_python_36 = pytest.mark.skipif(
- sys.version_info < (3, 6), reason="ASGI is only supported in Python >= 3.6"
-)
-
-
@pytest.fixture
def asgi3_app():
async def app(scope, receive, send):
- if (
+ if scope["type"] == "lifespan":
+ while True:
+ message = await receive()
+ if message["type"] == "lifespan.startup":
+ await send({"type": "lifespan.startup.complete"})
+ elif message["type"] == "lifespan.shutdown":
+ await send({"type": "lifespan.shutdown.complete"})
+ return
+ elif (
scope["type"] == "http"
and "route" in scope
and scope["route"] == "/trigger/error"
):
- division_by_zero = 1 / 0 # noqa
+ 1 / 0
await send(
{
@@ -48,6 +50,42 @@ async def app(scope, receive, send):
@pytest.fixture
def asgi3_app_with_error():
+ async def send_with_error(event):
+ 1 / 0
+
+ async def app(scope, receive, send):
+ if scope["type"] == "lifespan":
+ while True:
+ message = await receive()
+ if message["type"] == "lifespan.startup":
+ ... # Do some startup here!
+ await send({"type": "lifespan.startup.complete"})
+ elif message["type"] == "lifespan.shutdown":
+ ... # Do some shutdown here!
+ await send({"type": "lifespan.shutdown.complete"})
+ return
+ else:
+ await send_with_error(
+ {
+ "type": "http.response.start",
+ "status": 200,
+ "headers": [
+ [b"content-type", b"text/plain"],
+ ],
+ }
+ )
+ await send_with_error(
+ {
+ "type": "http.response.body",
+ "body": b"Hello, world!",
+ }
+ )
+
+ return app
+
+
+@pytest.fixture
+def asgi3_app_with_error_and_msg():
async def app(scope, receive, send):
await send(
{
@@ -59,7 +97,8 @@ async def app(scope, receive, send):
}
)
- division_by_zero = 1 / 0 # noqa
+ capture_message("Let's try dividing by 0")
+ 1 / 0
await send(
{
@@ -88,7 +127,32 @@ async def app(scope, receive, send):
return app
-@minimum_python_36
+@pytest.fixture
+def asgi3_custom_transaction_app():
+ async def app(scope, receive, send):
+ sentry_sdk.get_current_scope().set_transaction_name(
+ "foobar", source=TransactionSource.CUSTOM
+ )
+ await send(
+ {
+ "type": "http.response.start",
+ "status": 200,
+ "headers": [
+ [b"content-type", b"text/plain"],
+ ],
+ }
+ )
+
+ await send(
+ {
+ "type": "http.response.body",
+ "body": b"Hello, world!",
+ }
+ )
+
+ return app
+
+
def test_invalid_transaction_style(asgi3_app):
with pytest.raises(ValueError) as exp:
SentryAsgiMiddleware(asgi3_app, transaction_style="URL")
@@ -99,105 +163,410 @@ def test_invalid_transaction_style(asgi3_app):
)
-@minimum_python_36
@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "span_streaming",
+ [True, False],
+)
async def test_capture_transaction(
sentry_init,
asgi3_app,
capture_events,
+ capture_items,
+ span_streaming,
):
- sentry_init(send_default_pii=True, traces_sample_rate=1.0)
+ sentry_init(
+ send_default_pii=True,
+ traces_sample_rate=1.0,
+ _experiments={
+ "trace_lifecycle": "stream" if span_streaming else "static",
+ },
+ )
app = SentryAsgiMiddleware(asgi3_app)
async with TestClient(app) as client:
- events = capture_events()
- await client.get("/?somevalue=123")
-
- (transaction_event,) = events
-
- assert transaction_event["type"] == "transaction"
- assert transaction_event["transaction"] == "generic ASGI request"
- assert transaction_event["contexts"]["trace"]["op"] == "http.server"
- assert transaction_event["request"] == {
- "headers": {
- "host": "localhost",
- "remote-addr": "127.0.0.1",
- "user-agent": "ASGI-Test-Client",
- },
- "method": "GET",
- "query_string": "somevalue=123",
- "url": "http://localhost/",
- }
+ if span_streaming:
+ items = capture_items("span")
+ else:
+ events = capture_events()
+ await client.get("/some_url?somevalue=123")
+
+ sentry_sdk.flush()
+
+ if span_streaming:
+ assert len(items) == 1
+ span = items[0].payload
+
+ assert span["is_segment"] is True
+ assert span["name"] == "/some_url"
+
+ assert span["attributes"]["sentry.span.source"] == "url"
+ assert span["attributes"]["sentry.op"] == "http.server"
+
+ assert span["attributes"]["url.full"] == "http://localhost/some_url"
+ assert span["attributes"]["network.protocol.name"] == "http"
+ assert span["attributes"]["http.request.method"] == "GET"
+ assert span["attributes"]["http.query"] == "somevalue=123"
+ assert span["attributes"]["http.request.header.host"] == "localhost"
+ assert span["attributes"]["http.request.header.remote-addr"] == "127.0.0.1"
+ assert (
+ span["attributes"]["http.request.header.user-agent"] == "ASGI-Test-Client"
+ )
+
+ else:
+ (transaction_event,) = events
+
+ assert transaction_event["type"] == "transaction"
+ assert transaction_event["transaction"] == "/some_url"
+ assert transaction_event["transaction_info"] == {"source": "url"}
+ assert transaction_event["contexts"]["trace"]["op"] == "http.server"
+ assert transaction_event["request"] == {
+ "headers": {
+ "host": "localhost",
+ "remote-addr": "127.0.0.1",
+ "user-agent": "ASGI-Test-Client",
+ },
+ "method": "GET",
+ "query_string": "somevalue=123",
+ "url": "http://localhost/some_url",
+ }
-@minimum_python_36
@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "span_streaming",
+ [True, False],
+)
async def test_capture_transaction_with_error(
sentry_init,
asgi3_app_with_error,
capture_events,
+ capture_items,
DictionaryContaining, # noqa: N803
+ span_streaming,
):
- sentry_init(send_default_pii=True, traces_sample_rate=1.0)
+ sentry_init(
+ send_default_pii=True,
+ traces_sample_rate=1.0,
+ _experiments={
+ "trace_lifecycle": "stream" if span_streaming else "static",
+ },
+ )
+
app = SentryAsgiMiddleware(asgi3_app_with_error)
+ if span_streaming:
+ items = capture_items("event", "span")
+ else:
+ events = capture_events()
+
with pytest.raises(ZeroDivisionError):
async with TestClient(app) as client:
- events = capture_events()
- await client.get("/")
+ await client.get("/some_url")
- (error_event, transaction_event) = events
+ sentry_sdk.flush()
+
+ if span_streaming:
+ assert len(items) == 2
+ assert items[0].type == "event"
+ assert items[1].type == "span"
- assert error_event["transaction"] == "generic ASGI request"
+ error_event = items[0].payload
+ span_item = items[1].payload
+ else:
+ (error_event, transaction_event) = events
+
+ assert error_event["transaction"] == "/some_url"
+ assert error_event["transaction_info"] == {"source": "url"}
assert error_event["contexts"]["trace"]["op"] == "http.server"
assert error_event["exception"]["values"][0]["type"] == "ZeroDivisionError"
assert error_event["exception"]["values"][0]["value"] == "division by zero"
assert error_event["exception"]["values"][0]["mechanism"]["handled"] is False
assert error_event["exception"]["values"][0]["mechanism"]["type"] == "asgi"
- assert transaction_event["type"] == "transaction"
- assert transaction_event["contexts"]["trace"] == DictionaryContaining(
- error_event["contexts"]["trace"]
+ if span_streaming:
+ assert span_item["trace_id"] == error_event["contexts"]["trace"]["trace_id"]
+ assert span_item["span_id"] == error_event["contexts"]["trace"]["span_id"]
+ assert span_item.get("parent_span_id") == error_event["contexts"]["trace"].get(
+ "parent_span_id"
+ )
+ assert span_item["status"] == "error"
+
+ else:
+ assert transaction_event["type"] == "transaction"
+ assert transaction_event["contexts"]["trace"] == DictionaryContaining(
+ error_event["contexts"]["trace"]
+ )
+ assert transaction_event["contexts"]["trace"]["status"] == "internal_error"
+ assert transaction_event["transaction"] == error_event["transaction"]
+ assert transaction_event["request"] == error_event["request"]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "span_streaming",
+ [True, False],
+)
+async def test_has_trace_if_performance_enabled(
+ sentry_init,
+ asgi3_app_with_error_and_msg,
+ capture_events,
+ capture_items,
+ span_streaming,
+):
+ sentry_init(
+ traces_sample_rate=1.0,
+ _experiments={
+ "trace_lifecycle": "stream" if span_streaming else "static",
+ },
)
- assert transaction_event["contexts"]["trace"]["status"] == "internal_error"
- assert transaction_event["transaction"] == error_event["transaction"]
- assert transaction_event["request"] == error_event["request"]
+ app = SentryAsgiMiddleware(asgi3_app_with_error_and_msg)
+
+ with pytest.raises(ZeroDivisionError):
+ async with TestClient(app) as client:
+ if span_streaming:
+ items = capture_items("event", "span")
+ else:
+ events = capture_events()
+ await client.get("/")
+
+ sentry_sdk.flush()
+
+ if span_streaming:
+ msg_event, error_event, span = items
+
+ assert msg_event.type == "event"
+ msg_event = msg_event.payload
+ assert msg_event["contexts"]["trace"]
+ assert "trace_id" in msg_event["contexts"]["trace"]
+
+ assert error_event.type == "event"
+ error_event = error_event.payload
+ assert error_event["contexts"]["trace"]
+ assert "trace_id" in error_event["contexts"]["trace"]
+
+ assert span.type == "span"
+ span = span.payload
+ assert span["trace_id"] is not None
+
+ assert (
+ error_event["contexts"]["trace"]["trace_id"]
+ == msg_event["contexts"]["trace"]["trace_id"]
+ == span["trace_id"]
+ )
+
+ else:
+ msg_event, error_event, transaction_event = events
+
+ assert msg_event["contexts"]["trace"]
+ assert "trace_id" in msg_event["contexts"]["trace"]
+
+ assert error_event["contexts"]["trace"]
+ assert "trace_id" in error_event["contexts"]["trace"]
+
+ assert transaction_event["contexts"]["trace"]
+ assert "trace_id" in transaction_event["contexts"]["trace"]
+
+ assert (
+ error_event["contexts"]["trace"]["trace_id"]
+ == transaction_event["contexts"]["trace"]["trace_id"]
+ == msg_event["contexts"]["trace"]["trace_id"]
+ )
+
+
+@pytest.mark.asyncio
+async def test_has_trace_if_performance_disabled(
+ sentry_init,
+ asgi3_app_with_error_and_msg,
+ capture_events,
+):
+ sentry_init()
+ app = SentryAsgiMiddleware(asgi3_app_with_error_and_msg)
+
+ with pytest.raises(ZeroDivisionError):
+ async with TestClient(app) as client:
+ events = capture_events()
+ await client.get("/")
+
+ msg_event, error_event = events
+
+ assert msg_event["contexts"]["trace"]
+ assert "trace_id" in msg_event["contexts"]["trace"]
+
+ assert error_event["contexts"]["trace"]
+ assert "trace_id" in error_event["contexts"]["trace"]
+
+
+@pytest.mark.parametrize(
+ "span_streaming",
+ [True, False],
+)
+@pytest.mark.asyncio
+async def test_trace_from_headers_if_performance_enabled(
+ sentry_init,
+ asgi3_app_with_error_and_msg,
+ capture_events,
+ capture_items,
+ span_streaming,
+):
+ sentry_init(
+ traces_sample_rate=1.0,
+ _experiments={
+ "trace_lifecycle": "stream" if span_streaming else "static",
+ },
+ )
+ app = SentryAsgiMiddleware(asgi3_app_with_error_and_msg)
+
+ trace_id = "582b43a4192642f0b136d5159a501701"
+ sentry_trace_header = "{}-{}-{}".format(trace_id, "6e8f22c393e68f19", 1)
+
+ with pytest.raises(ZeroDivisionError):
+ async with TestClient(app) as client:
+ if span_streaming:
+ items = capture_items("event", "span")
+ else:
+ events = capture_events()
+ await client.get("/", headers={"sentry-trace": sentry_trace_header})
+
+ sentry_sdk.flush()
+
+ if span_streaming:
+ msg_event, error_event, span = items
+
+ assert msg_event.type == "event"
+ msg_event = msg_event.payload
+ assert msg_event["contexts"]["trace"]
+ assert "trace_id" in msg_event["contexts"]["trace"]
+
+ assert error_event.type == "event"
+ error_event = error_event.payload
+ assert error_event["contexts"]["trace"]
+ assert "trace_id" in error_event["contexts"]["trace"]
+
+ assert span.type == "span"
+ span = span.payload
+ assert span["trace_id"] is not None
+
+ assert msg_event["contexts"]["trace"]["trace_id"] == trace_id
+ assert error_event["contexts"]["trace"]["trace_id"] == trace_id
+ assert span["trace_id"] == trace_id
+
+ else:
+ msg_event, error_event, transaction_event = events
+
+ assert msg_event["contexts"]["trace"]
+ assert "trace_id" in msg_event["contexts"]["trace"]
+
+ assert error_event["contexts"]["trace"]
+ assert "trace_id" in error_event["contexts"]["trace"]
+
+ assert transaction_event["contexts"]["trace"]
+ assert "trace_id" in transaction_event["contexts"]["trace"]
+
+ assert msg_event["contexts"]["trace"]["trace_id"] == trace_id
+ assert error_event["contexts"]["trace"]["trace_id"] == trace_id
+ assert transaction_event["contexts"]["trace"]["trace_id"] == trace_id
-@minimum_python_36
@pytest.mark.asyncio
-async def test_websocket(sentry_init, asgi3_ws_app, capture_events, request):
- sentry_init(debug=True, send_default_pii=True)
+async def test_trace_from_headers_if_performance_disabled(
+ sentry_init,
+ asgi3_app_with_error_and_msg,
+ capture_events,
+):
+ sentry_init()
+ app = SentryAsgiMiddleware(asgi3_app_with_error_and_msg)
+
+ trace_id = "582b43a4192642f0b136d5159a501701"
+ sentry_trace_header = "{}-{}-{}".format(trace_id, "6e8f22c393e68f19", 1)
+
+ with pytest.raises(ZeroDivisionError):
+ async with TestClient(app) as client:
+ events = capture_events()
+ await client.get("/", headers={"sentry-trace": sentry_trace_header})
+
+ msg_event, error_event = events
+
+ assert msg_event["contexts"]["trace"]
+ assert "trace_id" in msg_event["contexts"]["trace"]
+ assert msg_event["contexts"]["trace"]["trace_id"] == trace_id
+
+ assert error_event["contexts"]["trace"]
+ assert "trace_id" in error_event["contexts"]["trace"]
+ assert error_event["contexts"]["trace"]["trace_id"] == trace_id
- events = capture_events()
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "span_streaming",
+ [True, False],
+)
+async def test_websocket(
+ sentry_init,
+ asgi3_ws_app,
+ capture_events,
+ capture_items,
+ request,
+ span_streaming,
+):
+ sentry_init(
+ send_default_pii=True,
+ traces_sample_rate=1.0,
+ _experiments={
+ "trace_lifecycle": "stream" if span_streaming else "static",
+ },
+ )
asgi3_ws_app = SentryAsgiMiddleware(asgi3_ws_app)
- scope = {
- "type": "websocket",
- "endpoint": asgi3_app,
- "client": ("127.0.0.1", 60457),
- "route": "some_url",
- "headers": [
- ("accept", "*/*"),
- ],
- }
+ request_url = "/ws"
with pytest.raises(ValueError):
- async with TestClient(asgi3_ws_app, scope=scope) as client:
- async with client.websocket_connect("/ws") as ws:
- await ws.receive_text()
+ client = TestClient(asgi3_ws_app)
+ if span_streaming:
+ items = capture_items("event", "span")
+ else:
+ events = capture_events()
+ async with client.websocket_connect(request_url) as ws:
+ await ws.receive_text()
- msg_event, error_event = events
+ sentry_sdk.flush()
- assert msg_event["message"] == "Some message to the world!"
+ if span_streaming:
+ msg_event, error_event, span = items
- (exc,) = error_event["exception"]["values"]
- assert exc["type"] == "ValueError"
- assert exc["value"] == "Oh no"
+ assert msg_event.type == "event"
+ msg_event = msg_event.payload
+ assert msg_event["transaction"] == request_url
+ assert msg_event["transaction_info"] == {"source": "url"}
+ assert msg_event["message"] == "Some message to the world!"
+
+ assert error_event.type == "event"
+ error_event = error_event.payload
+ (exc,) = error_event["exception"]["values"]
+ assert exc["type"] == "ValueError"
+ assert exc["value"] == "Oh no"
+
+ assert span.type == "span"
+ span = span.payload
+ assert span["name"] == request_url
+ assert span["attributes"]["sentry.span.source"] == "url"
+
+ else:
+ msg_event, error_event, transaction_event = events
+
+ assert msg_event["transaction"] == request_url
+ assert msg_event["transaction_info"] == {"source": "url"}
+ assert msg_event["message"] == "Some message to the world!"
+
+ (exc,) = error_event["exception"]["values"]
+ assert exc["type"] == "ValueError"
+ assert exc["value"] == "Oh no"
+
+ assert transaction_event["transaction"] == request_url
+ assert transaction_event["transaction_info"] == {"source": "url"}
-@minimum_python_36
@pytest.mark.asyncio
async def test_auto_session_tracking_with_aggregates(
sentry_init, asgi3_app, capture_envelopes
@@ -225,18 +594,17 @@ async def test_auto_session_tracking_with_aggregates(
for envelope in envelopes:
count_item_types[envelope.items[0].type] += 1
- assert count_item_types["transaction"] == 4
+ assert count_item_types["transaction"] == 3
assert count_item_types["event"] == 1
assert count_item_types["sessions"] == 1
- assert len(envelopes) == 6
+ assert len(envelopes) == 5
session_aggregates = envelopes[-1].items[0].payload.json["aggregates"]
- assert session_aggregates[0]["exited"] == 3
+ assert session_aggregates[0]["exited"] == 2
assert session_aggregates[0]["crashed"] == 1
assert len(session_aggregates) == 1
-@minimum_python_36
@pytest.mark.parametrize(
"url,transaction_style,expected_transaction,expected_source",
[
@@ -249,41 +617,63 @@ async def test_auto_session_tracking_with_aggregates(
(
"/message",
"endpoint",
- "tests.integrations.asgi.test_asgi.asgi3_app_with_error..app",
+ "tests.integrations.asgi.test_asgi.asgi3_app..app",
"component",
),
],
)
+@pytest.mark.parametrize(
+ "span_streaming",
+ [True, False],
+)
@pytest.mark.asyncio
async def test_transaction_style(
sentry_init,
- asgi3_app_with_error,
+ asgi3_app,
capture_events,
+ capture_items,
url,
transaction_style,
expected_transaction,
expected_source,
+ span_streaming,
):
- sentry_init(send_default_pii=True, traces_sample_rate=1.0)
- app = SentryAsgiMiddleware(
- asgi3_app_with_error, transaction_style=transaction_style
+ sentry_init(
+ send_default_pii=True,
+ traces_sample_rate=1.0,
+ _experiments={
+ "trace_lifecycle": "stream" if span_streaming else "static",
+ },
)
+ app = SentryAsgiMiddleware(asgi3_app, transaction_style=transaction_style)
scope = {
- "endpoint": asgi3_app_with_error,
+ "endpoint": asgi3_app,
"route": url,
"client": ("127.0.0.1", 60457),
}
- with pytest.raises(ZeroDivisionError):
- async with TestClient(app, scope=scope) as client:
+ async with TestClient(app, scope=scope) as client:
+ if span_streaming:
+ items = capture_items("span")
+ else:
events = capture_events()
- await client.get(url)
+ await client.get(url)
+
+ sentry_sdk.flush()
+
+ if span_streaming:
+ assert len(items) == 1
+ span = items[0].payload
- (_, transaction_event) = events
+ assert span["name"] == expected_transaction
+ assert span["attributes"]["sentry.span.source"] == expected_source
- assert transaction_event["transaction"] == expected_transaction
- assert transaction_event["transaction_info"] == {"source": expected_source}
+ else:
+ (transaction_event,) = events
+
+ assert transaction_event["transaction"] == expected_transaction
+ assert transaction_event["transaction_info"] == {"source": expected_source}
def mock_asgi2_app():
@@ -303,7 +693,6 @@ async def __call__():
pass
-@minimum_python_36
def test_looks_like_asgi3(asgi3_app):
# branch: inspect.isclass(app)
assert _looks_like_asgi3(MockAsgi3App)
@@ -320,7 +709,6 @@ def test_looks_like_asgi3(asgi3_app):
assert not _looks_like_asgi3(asgi2)
-@minimum_python_36
def test_get_ip_x_forwarded_for():
headers = [
(b"x-forwarded-for", b"8.8.8.8"),
@@ -329,8 +717,7 @@ def test_get_ip_x_forwarded_for():
"client": ("127.0.0.1", 60457),
"headers": headers,
}
- middleware = SentryAsgiMiddleware({})
- ip = middleware._get_ip(scope)
+ ip = _get_ip(scope)
assert ip == "8.8.8.8"
# x-forwarded-for overrides x-real-ip
@@ -342,8 +729,7 @@ def test_get_ip_x_forwarded_for():
"client": ("127.0.0.1", 60457),
"headers": headers,
}
- middleware = SentryAsgiMiddleware({})
- ip = middleware._get_ip(scope)
+ ip = _get_ip(scope)
assert ip == "8.8.8.8"
# when multiple x-forwarded-for headers are, the first is taken
@@ -356,12 +742,10 @@ def test_get_ip_x_forwarded_for():
"client": ("127.0.0.1", 60457),
"headers": headers,
}
- middleware = SentryAsgiMiddleware({})
- ip = middleware._get_ip(scope)
+ ip = _get_ip(scope)
assert ip == "5.5.5.5"
-@minimum_python_36
def test_get_ip_x_real_ip():
headers = [
(b"x-real-ip", b"10.10.10.10"),
@@ -370,8 +754,7 @@ def test_get_ip_x_real_ip():
"client": ("127.0.0.1", 60457),
"headers": headers,
}
- middleware = SentryAsgiMiddleware({})
- ip = middleware._get_ip(scope)
+ ip = _get_ip(scope)
assert ip == "10.10.10.10"
# x-forwarded-for overrides x-real-ip
@@ -383,12 +766,10 @@ def test_get_ip_x_real_ip():
"client": ("127.0.0.1", 60457),
"headers": headers,
}
- middleware = SentryAsgiMiddleware({})
- ip = middleware._get_ip(scope)
+ ip = _get_ip(scope)
assert ip == "8.8.8.8"
-@minimum_python_36
def test_get_ip():
# if now headers are provided the ip is taken from the client.
headers = []
@@ -396,8 +777,7 @@ def test_get_ip():
"client": ("127.0.0.1", 60457),
"headers": headers,
}
- middleware = SentryAsgiMiddleware({})
- ip = middleware._get_ip(scope)
+ ip = _get_ip(scope)
assert ip == "127.0.0.1"
# x-forwarded-for header overides the ip from client
@@ -408,8 +788,7 @@ def test_get_ip():
"client": ("127.0.0.1", 60457),
"headers": headers,
}
- middleware = SentryAsgiMiddleware({})
- ip = middleware._get_ip(scope)
+ ip = _get_ip(scope)
assert ip == "8.8.8.8"
# x-real-for header overides the ip from client
@@ -420,12 +799,10 @@ def test_get_ip():
"client": ("127.0.0.1", 60457),
"headers": headers,
}
- middleware = SentryAsgiMiddleware({})
- ip = middleware._get_ip(scope)
+ ip = _get_ip(scope)
assert ip == "10.10.10.10"
-@minimum_python_36
def test_get_headers():
headers = [
(b"x-real-ip", b"10.10.10.10"),
@@ -436,9 +813,192 @@ def test_get_headers():
"client": ("127.0.0.1", 60457),
"headers": headers,
}
- middleware = SentryAsgiMiddleware({})
- headers = middleware._get_headers(scope)
+ headers = _get_headers(scope)
assert headers == {
"x-real-ip": "10.10.10.10",
"some_header": "123, abc",
}
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "request_url,transaction_style,expected_transaction_name,expected_transaction_source",
+ [
+ (
+ "/message/123456",
+ "endpoint",
+ "/message/123456",
+ "url",
+ ),
+ (
+ "/message/123456",
+ "url",
+ "/message/123456",
+ "url",
+ ),
+ ],
+)
+@pytest.mark.parametrize(
+ "span_streaming",
+ [True, False],
+)
+async def test_transaction_name(
+ sentry_init,
+ request_url,
+ transaction_style,
+ expected_transaction_name,
+ expected_transaction_source,
+ asgi3_app,
+ capture_envelopes,
+ capture_items,
+ span_streaming,
+):
+ """
+ Tests that the transaction name is something meaningful.
+ """
+ sentry_init(
+ traces_sample_rate=1.0,
+ _experiments={
+ "trace_lifecycle": "stream" if span_streaming else "static",
+ },
+ )
+
+ if span_streaming:
+ items = capture_items("span")
+ else:
+ envelopes = capture_envelopes()
+
+ app = SentryAsgiMiddleware(asgi3_app, transaction_style=transaction_style)
+
+ async with TestClient(app) as client:
+ await client.get(request_url)
+
+ if span_streaming:
+ sentry_sdk.flush()
+
+ assert len(items) == 1
+ span = items[0].payload
+
+ assert span["name"] == expected_transaction_name
+ assert span["attributes"]["sentry.span.source"] == expected_transaction_source
+
+ else:
+ (transaction_envelope,) = envelopes
+ transaction_event = transaction_envelope.get_transaction_event()
+
+ assert transaction_event["transaction"] == expected_transaction_name
+ assert (
+ transaction_event["transaction_info"]["source"]
+ == expected_transaction_source
+ )
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "request_url, transaction_style,expected_transaction_name,expected_transaction_source",
+ [
+ (
+ "/message/123456",
+ "endpoint",
+ "/message/123456",
+ "url",
+ ),
+ (
+ "/message/123456",
+ "url",
+ "/message/123456",
+ "url",
+ ),
+ ],
+)
+@pytest.mark.parametrize(
+ "span_streaming",
+ [True, False],
+)
+async def test_transaction_name_in_traces_sampler(
+ sentry_init,
+ request_url,
+ transaction_style,
+ expected_transaction_name,
+ expected_transaction_source,
+ asgi3_app,
+ span_streaming,
+):
+ """
+ Tests that a custom traces_sampler has a meaningful transaction name.
+ In this case the URL or endpoint, because we do not have the route yet.
+ """
+
+ def dummy_traces_sampler(sampling_context):
+ if span_streaming:
+ assert sampling_context["span_context"]["name"] == expected_transaction_name
+ assert (
+ sampling_context["span_context"]["attributes"]["sentry.span.source"]
+ == expected_transaction_source
+ )
+ else:
+ assert (
+ sampling_context["transaction_context"]["name"]
+ == expected_transaction_name
+ )
+ assert (
+ sampling_context["transaction_context"]["source"]
+ == expected_transaction_source
+ )
+
+ sentry_init(
+ traces_sampler=dummy_traces_sampler,
+ traces_sample_rate=1.0,
+ _experiments={
+ "trace_lifecycle": "stream" if span_streaming else "static",
+ },
+ )
+
+ app = SentryAsgiMiddleware(asgi3_app, transaction_style=transaction_style)
+
+ async with TestClient(app) as client:
+ await client.get(request_url)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "span_streaming",
+ [True, False],
+)
+async def test_custom_transaction_name(
+ sentry_init,
+ asgi3_custom_transaction_app,
+ capture_events,
+ capture_items,
+ span_streaming,
+):
+ sentry_init(
+ traces_sample_rate=1.0,
+ _experiments={
+ "trace_lifecycle": "stream" if span_streaming else "static",
+ },
+ )
+ app = SentryAsgiMiddleware(asgi3_custom_transaction_app)
+
+ async with TestClient(app) as client:
+ if span_streaming:
+ items = capture_items("span")
+ else:
+ events = capture_events()
+ await client.get("/test")
+
+ sentry_sdk.flush()
+
+ if span_streaming:
+ assert len(items) == 1
+ span = items[0].payload
+
+ assert span["is_segment"] is True
+ assert span["name"] == "foobar"
+ assert span["attributes"]["sentry.span.source"] == "custom"
+
+ else:
+ (transaction_event,) = events
+ assert transaction_event["type"] == "transaction"
+ assert transaction_event["transaction"] == "foobar"
+ assert transaction_event["transaction_info"] == {"source": "custom"}
diff --git a/tests/integrations/asyncio/test_asyncio.py b/tests/integrations/asyncio/test_asyncio.py
new file mode 100644
index 0000000000..d32849c7b5
--- /dev/null
+++ b/tests/integrations/asyncio/test_asyncio.py
@@ -0,0 +1,683 @@
+import asyncio
+import inspect
+import sys
+from unittest.mock import MagicMock, Mock, patch
+
+if sys.version_info >= (3, 8):
+ from unittest.mock import AsyncMock
+
+import pytest
+
+import sentry_sdk
+from sentry_sdk.consts import OP
+from sentry_sdk.integrations.asyncio import (
+ AsyncioIntegration,
+ patch_asyncio,
+ enable_asyncio_integration,
+)
+
+try:
+ from contextvars import Context, ContextVar
+except ImportError:
+ pass # All tests will be skipped with incompatible versions
+
+
+minimum_python_38 = pytest.mark.skipif(
+ sys.version_info < (3, 8), reason="Asyncio tests need Python >= 3.8"
+)
+
+
+minimum_python_39 = pytest.mark.skipif(
+ sys.version_info < (3, 9), reason="Test requires Python >= 3.9"
+)
+
+
+minimum_python_311 = pytest.mark.skipif(
+ sys.version_info < (3, 11),
+ reason="Asyncio task context parameter was introduced in Python 3.11",
+)
+
+
+async def foo():
+ await asyncio.sleep(0.01)
+
+
+async def bar():
+ await asyncio.sleep(0.01)
+
+
+async def boom():
+ 1 / 0
+
+
+def get_sentry_task_factory(mock_get_running_loop):
+ """
+ Patches (mocked) asyncio and gets the sentry_task_factory.
+ """
+ mock_loop = mock_get_running_loop.return_value
+ patch_asyncio()
+ patched_factory = mock_loop.set_task_factory.call_args[0][0]
+
+ return patched_factory
+
+
+@minimum_python_38
+@pytest.mark.asyncio(loop_scope="module")
+async def test_create_task(
+ sentry_init,
+ capture_events,
+):
+ sentry_init(
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ integrations=[
+ AsyncioIntegration(),
+ ],
+ )
+
+ events = capture_events()
+
+ with sentry_sdk.start_transaction(name="test_transaction_for_create_task"):
+ with sentry_sdk.start_span(op="root", name="not so important"):
+ foo_task = asyncio.create_task(foo())
+ bar_task = asyncio.create_task(bar())
+
+ if hasattr(foo_task.get_coro(), "__name__"):
+ assert foo_task.get_coro().__name__ == "foo"
+ if hasattr(bar_task.get_coro(), "__name__"):
+ assert bar_task.get_coro().__name__ == "bar"
+
+ tasks = [foo_task, bar_task]
+
+ await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
+
+ sentry_sdk.flush()
+
+ (transaction_event,) = events
+
+ assert transaction_event["spans"][0]["op"] == "root"
+ assert transaction_event["spans"][0]["description"] == "not so important"
+
+ assert transaction_event["spans"][1]["op"] == OP.FUNCTION
+ assert transaction_event["spans"][1]["description"] == "foo"
+ assert (
+ transaction_event["spans"][1]["parent_span_id"]
+ == transaction_event["spans"][0]["span_id"]
+ )
+
+ assert transaction_event["spans"][2]["op"] == OP.FUNCTION
+ assert transaction_event["spans"][2]["description"] == "bar"
+ assert (
+ transaction_event["spans"][2]["parent_span_id"]
+ == transaction_event["spans"][0]["span_id"]
+ )
+
+
+@minimum_python_38
+@pytest.mark.asyncio(loop_scope="module")
+async def test_gather(
+ sentry_init,
+ capture_events,
+):
+ sentry_init(
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ integrations=[
+ AsyncioIntegration(),
+ ],
+ )
+
+ events = capture_events()
+
+ with sentry_sdk.start_transaction(name="test_transaction_for_gather"):
+ with sentry_sdk.start_span(op="root", name="not so important"):
+ await asyncio.gather(foo(), bar(), return_exceptions=True)
+
+ sentry_sdk.flush()
+
+ (transaction_event,) = events
+
+ assert transaction_event["spans"][0]["op"] == "root"
+ assert transaction_event["spans"][0]["description"] == "not so important"
+
+ assert transaction_event["spans"][1]["op"] == OP.FUNCTION
+ assert transaction_event["spans"][1]["description"] == "foo"
+ assert (
+ transaction_event["spans"][1]["parent_span_id"]
+ == transaction_event["spans"][0]["span_id"]
+ )
+
+ assert transaction_event["spans"][2]["op"] == OP.FUNCTION
+ assert transaction_event["spans"][2]["description"] == "bar"
+ assert (
+ transaction_event["spans"][2]["parent_span_id"]
+ == transaction_event["spans"][0]["span_id"]
+ )
+
+
+@minimum_python_38
+@pytest.mark.asyncio(loop_scope="module")
+async def test_exception(
+ sentry_init,
+ capture_events,
+):
+ sentry_init(
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ integrations=[
+ AsyncioIntegration(),
+ ],
+ )
+
+ events = capture_events()
+
+ with sentry_sdk.start_transaction(name="test_exception"):
+ with sentry_sdk.start_span(op="root", name="not so important"):
+ tasks = [asyncio.create_task(boom()), asyncio.create_task(bar())]
+ await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
+
+ sentry_sdk.flush()
+
+ (error_event, _) = events
+
+ assert error_event["transaction"] == "test_exception"
+ assert error_event["contexts"]["trace"]["op"] == "function"
+ assert error_event["exception"]["values"][0]["type"] == "ZeroDivisionError"
+ assert error_event["exception"]["values"][0]["value"] == "division by zero"
+ assert error_event["exception"]["values"][0]["mechanism"]["handled"] is False
+ assert error_event["exception"]["values"][0]["mechanism"]["type"] == "asyncio"
+
+
+@minimum_python_38
+@pytest.mark.asyncio(loop_scope="module")
+async def test_task_result(sentry_init):
+ sentry_init(
+ integrations=[
+ AsyncioIntegration(),
+ ],
+ )
+
+ async def add(a, b):
+ return a + b
+
+ result = await asyncio.create_task(add(1, 2))
+ assert result == 3, result
+
+
+@minimum_python_311
+@pytest.mark.asyncio(loop_scope="module")
+async def test_task_with_context(sentry_init):
+ """
+ Integration test to ensure working context parameter in Python 3.11+
+ """
+ sentry_init(
+ integrations=[
+ AsyncioIntegration(),
+ ],
+ )
+
+ var = ContextVar("var")
+ var.set("original value")
+
+ async def change_value():
+ var.set("changed value")
+
+ async def retrieve_value():
+ return var.get()
+
+ # Create a context and run both tasks within the context
+ ctx = Context()
+ async with asyncio.TaskGroup() as tg:
+ tg.create_task(change_value(), context=ctx)
+ retrieve_task = tg.create_task(retrieve_value(), context=ctx)
+
+ assert retrieve_task.result() == "changed value"
+
+
+@minimum_python_38
+@patch("asyncio.get_running_loop")
+def test_patch_asyncio(mock_get_running_loop):
+ """
+ Test that the patch_asyncio function will patch the task factory.
+ """
+ mock_loop = mock_get_running_loop.return_value
+ mock_loop.get_task_factory.return_value._is_sentry_task_factory = False
+
+ patch_asyncio()
+
+ assert mock_loop.set_task_factory.called
+
+ set_task_factory_args, _ = mock_loop.set_task_factory.call_args
+ assert len(set_task_factory_args) == 1
+
+ sentry_task_factory, *_ = set_task_factory_args
+ assert callable(sentry_task_factory)
+
+
+@minimum_python_38
+@patch("asyncio.get_running_loop")
+@patch("sentry_sdk.integrations.asyncio.Task")
+def test_sentry_task_factory_no_factory(MockTask, mock_get_running_loop): # noqa: N803
+ mock_loop = mock_get_running_loop.return_value
+ mock_coro = MagicMock()
+
+ # Set the original task factory to None
+ mock_loop.get_task_factory.return_value = None
+
+ # Retieve sentry task factory (since it is an inner function within patch_asyncio)
+ sentry_task_factory = get_sentry_task_factory(mock_get_running_loop)
+
+ # The call we are testing
+ ret_val = sentry_task_factory(mock_loop, mock_coro)
+
+ assert MockTask.called
+ assert ret_val == MockTask.return_value
+
+ task_args, task_kwargs = MockTask.call_args
+ assert len(task_args) == 1
+
+ coro_param, *_ = task_args
+ assert inspect.iscoroutine(coro_param)
+
+ assert "loop" in task_kwargs
+ assert task_kwargs["loop"] == mock_loop
+
+
+@minimum_python_38
+@patch("asyncio.get_running_loop")
+def test_sentry_task_factory_with_factory(mock_get_running_loop):
+ mock_loop = mock_get_running_loop.return_value
+ mock_coro = MagicMock()
+
+ # The original task factory will be mocked out here, let's retrieve the value for later
+ orig_task_factory = mock_loop.get_task_factory.return_value
+ orig_task_factory._is_sentry_task_factory = False
+
+ # Retieve sentry task factory (since it is an inner function within patch_asyncio)
+ sentry_task_factory = get_sentry_task_factory(mock_get_running_loop)
+
+ # The call we are testing
+ ret_val = sentry_task_factory(mock_loop, mock_coro)
+
+ assert orig_task_factory.called
+ assert ret_val == orig_task_factory.return_value
+
+ task_factory_args, _ = orig_task_factory.call_args
+ assert len(task_factory_args) == 2
+
+ loop_arg, coro_arg = task_factory_args
+ assert loop_arg == mock_loop
+ assert inspect.iscoroutine(coro_arg)
+
+
+@minimum_python_311
+@patch("asyncio.get_running_loop")
+@patch("sentry_sdk.integrations.asyncio.Task")
+def test_sentry_task_factory_context_no_factory(
+ MockTask,
+ mock_get_running_loop, # noqa: N803
+):
+ mock_loop = mock_get_running_loop.return_value
+ mock_coro = MagicMock()
+ mock_context = MagicMock()
+
+ # Set the original task factory to None
+ mock_loop.get_task_factory.return_value = None
+
+ # Retieve sentry task factory (since it is an inner function within patch_asyncio)
+ sentry_task_factory = get_sentry_task_factory(mock_get_running_loop)
+
+ # The call we are testing
+ ret_val = sentry_task_factory(mock_loop, mock_coro, context=mock_context)
+
+ assert MockTask.called
+ assert ret_val == MockTask.return_value
+
+ task_args, task_kwargs = MockTask.call_args
+ assert len(task_args) == 1
+
+ coro_param, *_ = task_args
+ assert inspect.iscoroutine(coro_param)
+
+ assert "loop" in task_kwargs
+ assert task_kwargs["loop"] == mock_loop
+ assert "context" in task_kwargs
+ assert task_kwargs["context"] == mock_context
+
+
+@minimum_python_311
+@patch("asyncio.get_running_loop")
+def test_sentry_task_factory_context_with_factory(mock_get_running_loop):
+ mock_loop = mock_get_running_loop.return_value
+ mock_coro = MagicMock()
+ mock_context = MagicMock()
+
+ # The original task factory will be mocked out here, let's retrieve the value for later
+ orig_task_factory = mock_loop.get_task_factory.return_value
+ orig_task_factory._is_sentry_task_factory = False
+
+ # Retieve sentry task factory (since it is an inner function within patch_asyncio)
+ sentry_task_factory = get_sentry_task_factory(mock_get_running_loop)
+
+ # The call we are testing
+ ret_val = sentry_task_factory(mock_loop, mock_coro, context=mock_context)
+
+ assert orig_task_factory.called
+ assert ret_val == orig_task_factory.return_value
+
+ task_factory_args, task_factory_kwargs = orig_task_factory.call_args
+ assert len(task_factory_args) == 2
+
+ loop_arg, coro_arg = task_factory_args
+ assert loop_arg == mock_loop
+ assert inspect.iscoroutine(coro_arg)
+
+ assert "context" in task_factory_kwargs
+ assert task_factory_kwargs["context"] == mock_context
+
+
+@minimum_python_38
+@pytest.mark.asyncio(loop_scope="module")
+async def test_span_origin(
+ sentry_init,
+ capture_events,
+):
+ sentry_init(
+ integrations=[AsyncioIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ with sentry_sdk.start_transaction(name="something"):
+ tasks = [
+ asyncio.create_task(foo()),
+ ]
+ await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
+
+ sentry_sdk.flush()
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+ assert event["spans"][0]["origin"] == "auto.function.asyncio"
+
+
+@minimum_python_38
+@pytest.mark.asyncio
+async def test_task_spans_false(
+ sentry_init,
+ capture_events,
+ uninstall_integration,
+):
+ uninstall_integration("asyncio")
+
+ sentry_init(
+ traces_sample_rate=1.0,
+ integrations=[
+ AsyncioIntegration(task_spans=False),
+ ],
+ )
+
+ events = capture_events()
+
+ with sentry_sdk.start_transaction(name="test_no_spans"):
+ tasks = [asyncio.create_task(foo()), asyncio.create_task(bar())]
+ await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
+
+ sentry_sdk.flush()
+
+ (transaction_event,) = events
+
+ assert not transaction_event["spans"]
+
+
+@minimum_python_38
+@pytest.mark.asyncio
+async def test_enable_asyncio_integration_with_task_spans_false(
+ sentry_init,
+ capture_events,
+ uninstall_integration,
+):
+ """
+ Test that enable_asyncio_integration() helper works with task_spans=False.
+ """
+ uninstall_integration("asyncio")
+
+ sentry_init(traces_sample_rate=1.0)
+
+ assert "asyncio" not in sentry_sdk.get_client().integrations
+
+ enable_asyncio_integration(task_spans=False)
+
+ assert "asyncio" in sentry_sdk.get_client().integrations
+ assert sentry_sdk.get_client().integrations["asyncio"].task_spans is False
+
+ events = capture_events()
+
+ with sentry_sdk.start_transaction(name="test"):
+ await asyncio.create_task(foo())
+
+ sentry_sdk.flush()
+
+ (transaction,) = events
+ assert not transaction["spans"]
+
+
+@minimum_python_38
+@pytest.mark.asyncio
+async def test_delayed_enable_integration(sentry_init, capture_events):
+ sentry_init(traces_sample_rate=1.0)
+
+ assert "asyncio" not in sentry_sdk.get_client().integrations
+
+ events = capture_events()
+
+ with sentry_sdk.start_transaction(name="test"):
+ await asyncio.create_task(foo())
+
+ assert len(events) == 1
+ (transaction,) = events
+ assert not transaction["spans"]
+
+ enable_asyncio_integration()
+
+ events = capture_events()
+
+ assert "asyncio" in sentry_sdk.get_client().integrations
+
+ with sentry_sdk.start_transaction(name="test"):
+ await asyncio.create_task(foo())
+
+ assert len(events) == 1
+ (transaction,) = events
+ assert transaction["spans"]
+ assert transaction["spans"][0]["origin"] == "auto.function.asyncio"
+
+
+@minimum_python_38
+@pytest.mark.asyncio
+async def test_delayed_enable_integration_with_options(sentry_init, capture_events):
+ sentry_init(traces_sample_rate=1.0)
+
+ assert "asyncio" not in sentry_sdk.get_client().integrations
+
+ mock_init = MagicMock(return_value=None)
+ mock_setup_once = MagicMock()
+ with patch(
+ "sentry_sdk.integrations.asyncio.AsyncioIntegration.__init__", mock_init
+ ):
+ with patch(
+ "sentry_sdk.integrations.asyncio.AsyncioIntegration.setup_once",
+ mock_setup_once,
+ ):
+ enable_asyncio_integration("arg", kwarg="kwarg")
+
+ assert "asyncio" in sentry_sdk.get_client().integrations
+ mock_init.assert_called_once_with("arg", kwarg="kwarg")
+ mock_setup_once.assert_called_once()
+
+
+@minimum_python_38
+@pytest.mark.asyncio
+async def test_delayed_enable_enabled_integration(sentry_init, uninstall_integration):
+ # Ensure asyncio integration is not already installed from previous tests
+ uninstall_integration("asyncio")
+
+ integration = AsyncioIntegration()
+ sentry_init(integrations=[integration], traces_sample_rate=1.0)
+
+ assert "asyncio" in sentry_sdk.get_client().integrations
+
+ # Get the task factory after initial setup - it should be Sentry's
+ loop = asyncio.get_running_loop()
+ task_factory_before = loop.get_task_factory()
+ assert getattr(task_factory_before, "_is_sentry_task_factory", False) is True
+
+ enable_asyncio_integration()
+
+ assert "asyncio" in sentry_sdk.get_client().integrations
+
+ # The task factory should be the same (loop not re-patched)
+ task_factory_after = loop.get_task_factory()
+ assert task_factory_before is task_factory_after
+
+
+@minimum_python_38
+@pytest.mark.asyncio
+async def test_delayed_enable_integration_after_disabling(sentry_init, capture_events):
+ sentry_init(disabled_integrations=[AsyncioIntegration()], traces_sample_rate=1.0)
+
+ assert "asyncio" not in sentry_sdk.get_client().integrations
+
+ events = capture_events()
+
+ with sentry_sdk.start_transaction(name="test"):
+ await asyncio.create_task(foo())
+
+ assert len(events) == 1
+ (transaction,) = events
+ assert not transaction["spans"]
+
+ enable_asyncio_integration()
+
+ events = capture_events()
+
+ assert "asyncio" in sentry_sdk.get_client().integrations
+
+ with sentry_sdk.start_transaction(name="test"):
+ await asyncio.create_task(foo())
+
+ assert len(events) == 1
+ (transaction,) = events
+ assert transaction["spans"]
+ assert transaction["spans"][0]["origin"] == "auto.function.asyncio"
+
+
+@minimum_python_39
+@pytest.mark.asyncio(loop_scope="module")
+async def test_internal_tasks_not_wrapped(sentry_init, capture_events):
+ from sentry_sdk.utils import mark_sentry_task_internal
+
+ sentry_init(integrations=[AsyncioIntegration()], traces_sample_rate=1.0)
+ events = capture_events()
+
+ # Create a user task that should be wrapped
+ async def user_task():
+ await asyncio.sleep(0.01)
+ return "user_result"
+
+ # Create an internal task that should NOT be wrapped
+ async def internal_task():
+ await asyncio.sleep(0.01)
+ return "internal_result"
+
+ with sentry_sdk.start_transaction(name="test_transaction"):
+ user_task_obj = asyncio.create_task(user_task())
+
+ with mark_sentry_task_internal():
+ internal_task_obj = asyncio.create_task(internal_task())
+
+ user_result = await user_task_obj
+ internal_result = await internal_task_obj
+
+ assert user_result == "user_result"
+ assert internal_result == "internal_result"
+
+ assert len(events) == 1
+ transaction = events[0]
+
+ user_spans = []
+ internal_spans = []
+
+ for span in transaction.get("spans", []):
+ if "user_task" in span.get("description", ""):
+ user_spans.append(span)
+ elif "internal_task" in span.get("description", ""):
+ internal_spans.append(span)
+
+ assert len(user_spans) > 0, (
+ f"User task should have been traced. All spans: {[s.get('description') for s in transaction.get('spans', [])]}"
+ )
+ assert len(internal_spans) == 0, (
+ f"Internal task should NOT have been traced. All spans: {[s.get('description') for s in transaction.get('spans', [])]}"
+ )
+
+
+@minimum_python_38
+def test_loop_close_patching(sentry_init):
+ sentry_init(integrations=[AsyncioIntegration()])
+
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+
+ try:
+ with patch("asyncio.get_running_loop", return_value=loop):
+ assert not hasattr(loop, "_sentry_flush_patched")
+ AsyncioIntegration.setup_once()
+ assert hasattr(loop, "_sentry_flush_patched")
+ assert loop._sentry_flush_patched is True
+
+ finally:
+ if not loop.is_closed():
+ loop.close()
+
+
+@minimum_python_38
+def test_loop_close_flushes_async_transport(sentry_init):
+ from sentry_sdk.transport import ASYNC_TRANSPORT_AVAILABLE, AsyncHttpTransport
+
+ if not ASYNC_TRANSPORT_AVAILABLE:
+ pytest.skip("httpcore[asyncio] not installed")
+
+ sentry_init(integrations=[AsyncioIntegration()])
+
+ # Save the current event loop to restore it later
+ try:
+ original_loop = asyncio.get_event_loop()
+ except RuntimeError:
+ original_loop = None
+
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+
+ try:
+ with patch("asyncio.get_running_loop", return_value=loop):
+ AsyncioIntegration.setup_once()
+
+ mock_client = Mock()
+ mock_transport = Mock(spec=AsyncHttpTransport)
+ mock_client.transport = mock_transport
+ mock_client.close_async = AsyncMock(return_value=None)
+
+ with patch("sentry_sdk.get_client", return_value=mock_client):
+ loop.close()
+
+ mock_client.close_async.assert_called_once()
+ mock_client.close_async.assert_awaited_once()
+
+ finally:
+ if not loop.is_closed():
+ loop.close()
+ if original_loop:
+ asyncio.set_event_loop(original_loop)
diff --git a/tests/integrations/asyncio/test_asyncio_py3.py b/tests/integrations/asyncio/test_asyncio_py3.py
deleted file mode 100644
index 98106ed01f..0000000000
--- a/tests/integrations/asyncio/test_asyncio_py3.py
+++ /dev/null
@@ -1,172 +0,0 @@
-import asyncio
-import sys
-
-import pytest
-
-import sentry_sdk
-from sentry_sdk.consts import OP
-from sentry_sdk.integrations.asyncio import AsyncioIntegration
-
-
-minimum_python_37 = pytest.mark.skipif(
- sys.version_info < (3, 7), reason="Asyncio tests need Python >= 3.7"
-)
-
-
-async def foo():
- await asyncio.sleep(0.01)
-
-
-async def bar():
- await asyncio.sleep(0.01)
-
-
-async def boom():
- 1 / 0
-
-
-@pytest.fixture(scope="session")
-def event_loop(request):
- """Create an instance of the default event loop for each test case."""
- loop = asyncio.get_event_loop_policy().new_event_loop()
- yield loop
- loop.close()
-
-
-@minimum_python_37
-@pytest.mark.asyncio
-async def test_create_task(
- sentry_init,
- capture_events,
- event_loop,
-):
- sentry_init(
- traces_sample_rate=1.0,
- send_default_pii=True,
- debug=True,
- integrations=[
- AsyncioIntegration(),
- ],
- )
-
- events = capture_events()
-
- with sentry_sdk.start_transaction(name="test_transaction_for_create_task"):
- with sentry_sdk.start_span(op="root", description="not so important"):
- tasks = [event_loop.create_task(foo()), event_loop.create_task(bar())]
- await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
-
- sentry_sdk.flush()
-
- (transaction_event,) = events
-
- assert transaction_event["spans"][0]["op"] == "root"
- assert transaction_event["spans"][0]["description"] == "not so important"
-
- assert transaction_event["spans"][1]["op"] == OP.FUNCTION
- assert transaction_event["spans"][1]["description"] == "foo"
- assert (
- transaction_event["spans"][1]["parent_span_id"]
- == transaction_event["spans"][0]["span_id"]
- )
-
- assert transaction_event["spans"][2]["op"] == OP.FUNCTION
- assert transaction_event["spans"][2]["description"] == "bar"
- assert (
- transaction_event["spans"][2]["parent_span_id"]
- == transaction_event["spans"][0]["span_id"]
- )
-
-
-@minimum_python_37
-@pytest.mark.asyncio
-async def test_gather(
- sentry_init,
- capture_events,
-):
- sentry_init(
- traces_sample_rate=1.0,
- send_default_pii=True,
- debug=True,
- integrations=[
- AsyncioIntegration(),
- ],
- )
-
- events = capture_events()
-
- with sentry_sdk.start_transaction(name="test_transaction_for_gather"):
- with sentry_sdk.start_span(op="root", description="not so important"):
- await asyncio.gather(foo(), bar(), return_exceptions=True)
-
- sentry_sdk.flush()
-
- (transaction_event,) = events
-
- assert transaction_event["spans"][0]["op"] == "root"
- assert transaction_event["spans"][0]["description"] == "not so important"
-
- assert transaction_event["spans"][1]["op"] == OP.FUNCTION
- assert transaction_event["spans"][1]["description"] == "foo"
- assert (
- transaction_event["spans"][1]["parent_span_id"]
- == transaction_event["spans"][0]["span_id"]
- )
-
- assert transaction_event["spans"][2]["op"] == OP.FUNCTION
- assert transaction_event["spans"][2]["description"] == "bar"
- assert (
- transaction_event["spans"][2]["parent_span_id"]
- == transaction_event["spans"][0]["span_id"]
- )
-
-
-@minimum_python_37
-@pytest.mark.asyncio
-async def test_exception(
- sentry_init,
- capture_events,
- event_loop,
-):
- sentry_init(
- traces_sample_rate=1.0,
- send_default_pii=True,
- debug=True,
- integrations=[
- AsyncioIntegration(),
- ],
- )
-
- events = capture_events()
-
- with sentry_sdk.start_transaction(name="test_exception"):
- with sentry_sdk.start_span(op="root", description="not so important"):
- tasks = [event_loop.create_task(boom()), event_loop.create_task(bar())]
- await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
-
- sentry_sdk.flush()
-
- (error_event, _) = events
-
- assert error_event["transaction"] == "test_exception"
- assert error_event["contexts"]["trace"]["op"] == "function"
- assert error_event["exception"]["values"][0]["type"] == "ZeroDivisionError"
- assert error_event["exception"]["values"][0]["value"] == "division by zero"
- assert error_event["exception"]["values"][0]["mechanism"]["handled"] is False
- assert error_event["exception"]["values"][0]["mechanism"]["type"] == "asyncio"
-
-
-@minimum_python_37
-@pytest.mark.asyncio
-async def test_task_result(sentry_init):
- sentry_init(
- integrations=[
- AsyncioIntegration(),
- ],
- )
-
- async def add(a, b):
- return a + b
-
- result = await asyncio.create_task(add(1, 2))
- assert result == 3, result
diff --git a/tests/integrations/asyncpg/__init__.py b/tests/integrations/asyncpg/__init__.py
new file mode 100644
index 0000000000..d988407a2d
--- /dev/null
+++ b/tests/integrations/asyncpg/__init__.py
@@ -0,0 +1,10 @@
+import os
+import sys
+import pytest
+
+pytest.importorskip("asyncpg")
+pytest.importorskip("pytest_asyncio")
+
+# Load `asyncpg_helpers` into the module search path to test query source path names relative to module. See
+# `test_query_source_with_module_in_search_path`
+sys.path.insert(0, os.path.join(os.path.dirname(__file__)))
diff --git a/tests/integrations/asyncpg/asyncpg_helpers/__init__.py b/tests/integrations/asyncpg/asyncpg_helpers/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/integrations/asyncpg/asyncpg_helpers/helpers.py b/tests/integrations/asyncpg/asyncpg_helpers/helpers.py
new file mode 100644
index 0000000000..8de809ba1b
--- /dev/null
+++ b/tests/integrations/asyncpg/asyncpg_helpers/helpers.py
@@ -0,0 +1,2 @@
+async def execute_query_in_connection(query, connection):
+ await connection.execute(query)
diff --git a/tests/integrations/asyncpg/test_asyncpg.py b/tests/integrations/asyncpg/test_asyncpg.py
new file mode 100644
index 0000000000..2dcce52070
--- /dev/null
+++ b/tests/integrations/asyncpg/test_asyncpg.py
@@ -0,0 +1,859 @@
+"""
+Tests need pytest-asyncio installed.
+
+Tests need a local postgresql instance running, this can best be done using
+```sh
+docker run --rm --name some-postgres -e POSTGRES_USER=postgres -e POSTGRES_PASSWORD=sentry -d -p 5432:5432 postgres
+```
+
+The tests use the following credentials to establish a database connection.
+"""
+
+import os
+import datetime
+from contextlib import contextmanager
+from unittest import mock
+
+import asyncpg
+import pytest
+import pytest_asyncio
+from asyncpg import connect, Connection
+
+from sentry_sdk import capture_message, start_transaction
+from sentry_sdk.integrations.asyncpg import AsyncPGIntegration
+from sentry_sdk.consts import SPANDATA
+from sentry_sdk.tracing_utils import record_sql_queries
+from tests.conftest import ApproxDict
+
+PG_HOST = os.getenv("SENTRY_PYTHON_TEST_POSTGRES_HOST", "localhost")
+PG_PORT = int(os.getenv("SENTRY_PYTHON_TEST_POSTGRES_PORT", "5432"))
+PG_USER = os.getenv("SENTRY_PYTHON_TEST_POSTGRES_USER", "postgres")
+PG_PASSWORD = os.getenv("SENTRY_PYTHON_TEST_POSTGRES_PASSWORD", "sentry")
+PG_NAME_BASE = os.getenv("SENTRY_PYTHON_TEST_POSTGRES_NAME", "postgres")
+
+
+def _get_db_name():
+ pid = os.getpid()
+ return f"{PG_NAME_BASE}_{pid}"
+
+
+PG_NAME = _get_db_name()
+
+PG_CONNECTION_URI = "postgresql://{}:{}@{}/{}".format(
+ PG_USER, PG_PASSWORD, PG_HOST, PG_NAME
+)
+CRUMBS_CONNECT = {
+ "category": "query",
+ "data": ApproxDict(
+ {
+ "db.name": PG_NAME,
+ "db.system": "postgresql",
+ "db.user": PG_USER,
+ "db.driver.name": "asyncpg",
+ "server.address": PG_HOST,
+ "server.port": PG_PORT,
+ }
+ ),
+ "message": "connect",
+ "type": "default",
+}
+
+
+@pytest_asyncio.fixture(autouse=True)
+async def _clean_pg():
+ # Create the test database if it doesn't exist
+ default_conn = await connect(
+ "postgresql://{}:{}@{}".format(PG_USER, PG_PASSWORD, PG_HOST)
+ )
+ try:
+ # Check if database exists, create if not
+ result = await default_conn.fetchval(
+ "SELECT 1 FROM pg_database WHERE datname = $1", PG_NAME
+ )
+ if not result:
+ await default_conn.execute(f'CREATE DATABASE "{PG_NAME}"')
+ finally:
+ await default_conn.close()
+
+ # Now connect to our test database and set up the table
+ conn = await connect(PG_CONNECTION_URI)
+ await conn.execute("DROP TABLE IF EXISTS users")
+ await conn.execute(
+ """
+ CREATE TABLE users(
+ id serial PRIMARY KEY,
+ name text,
+ password text,
+ dob date
+ )
+ """
+ )
+ await conn.close()
+
+
+@pytest.mark.asyncio
+async def test_connect(sentry_init, capture_events) -> None:
+ sentry_init(
+ integrations=[AsyncPGIntegration()],
+ _experiments={"record_sql_params": True},
+ )
+ events = capture_events()
+
+ conn: Connection = await connect(PG_CONNECTION_URI)
+
+ await conn.close()
+
+ capture_message("hi")
+
+ (event,) = events
+
+ for crumb in event["breadcrumbs"]["values"]:
+ del crumb["timestamp"]
+
+ assert event["breadcrumbs"]["values"] == [CRUMBS_CONNECT]
+
+
+@pytest.mark.asyncio
+async def test_execute(sentry_init, capture_events) -> None:
+ sentry_init(
+ integrations=[AsyncPGIntegration()],
+ _experiments={"record_sql_params": True},
+ )
+ events = capture_events()
+
+ conn: Connection = await connect(PG_CONNECTION_URI)
+
+ await conn.execute(
+ "INSERT INTO users(name, password, dob) VALUES ('Alice', 'pw', '1990-12-25')",
+ )
+
+ await conn.execute(
+ "INSERT INTO users(name, password, dob) VALUES($1, $2, $3)",
+ "Bob",
+ "secret_pw",
+ datetime.date(1984, 3, 1),
+ )
+
+ row = await conn.fetchrow("SELECT * FROM users WHERE name = $1", "Bob")
+ assert row == (2, "Bob", "secret_pw", datetime.date(1984, 3, 1))
+
+ row = await conn.fetchrow("SELECT * FROM users WHERE name = 'Bob'")
+ assert row == (2, "Bob", "secret_pw", datetime.date(1984, 3, 1))
+
+ await conn.close()
+
+ capture_message("hi")
+
+ (event,) = events
+
+ for crumb in event["breadcrumbs"]["values"]:
+ del crumb["timestamp"]
+
+ assert event["breadcrumbs"]["values"] == [
+ CRUMBS_CONNECT,
+ {
+ "category": "query",
+ "data": {},
+ "message": "INSERT INTO users(name, password, dob) VALUES ('Alice', 'pw', '1990-12-25')",
+ "type": "default",
+ },
+ {
+ "category": "query",
+ "data": {},
+ "message": "INSERT INTO users(name, password, dob) VALUES($1, $2, $3)",
+ "type": "default",
+ },
+ {
+ "category": "query",
+ "data": {},
+ "message": "SELECT * FROM users WHERE name = $1",
+ "type": "default",
+ },
+ {
+ "category": "query",
+ "data": {},
+ "message": "SELECT * FROM users WHERE name = 'Bob'",
+ "type": "default",
+ },
+ ]
+
+
+@pytest.mark.asyncio
+async def test_execute_many(sentry_init, capture_events) -> None:
+ sentry_init(
+ integrations=[AsyncPGIntegration()],
+ _experiments={"record_sql_params": True},
+ )
+ events = capture_events()
+
+ conn: Connection = await connect(PG_CONNECTION_URI)
+
+ await conn.executemany(
+ "INSERT INTO users(name, password, dob) VALUES($1, $2, $3)",
+ [
+ ("Bob", "secret_pw", datetime.date(1984, 3, 1)),
+ ("Alice", "pw", datetime.date(1990, 12, 25)),
+ ],
+ )
+
+ await conn.close()
+
+ capture_message("hi")
+
+ (event,) = events
+
+ for crumb in event["breadcrumbs"]["values"]:
+ del crumb["timestamp"]
+
+ assert event["breadcrumbs"]["values"] == [
+ CRUMBS_CONNECT,
+ {
+ "category": "query",
+ "data": {"db.executemany": True},
+ "message": "INSERT INTO users(name, password, dob) VALUES($1, $2, $3)",
+ "type": "default",
+ },
+ ]
+
+
+@pytest.mark.asyncio
+async def test_record_params(sentry_init, capture_events) -> None:
+ sentry_init(
+ integrations=[AsyncPGIntegration(record_params=True)],
+ _experiments={"record_sql_params": True},
+ )
+ events = capture_events()
+
+ conn: Connection = await connect(PG_CONNECTION_URI)
+
+ await conn.execute(
+ "INSERT INTO users(name, password, dob) VALUES($1, $2, $3)",
+ "Bob",
+ "secret_pw",
+ datetime.date(1984, 3, 1),
+ )
+
+ await conn.close()
+
+ capture_message("hi")
+
+ (event,) = events
+
+ for crumb in event["breadcrumbs"]["values"]:
+ del crumb["timestamp"]
+
+ assert event["breadcrumbs"]["values"] == [
+ CRUMBS_CONNECT,
+ {
+ "category": "query",
+ "data": {
+ "db.params": ["Bob", "secret_pw", "datetime.date(1984, 3, 1)"],
+ "db.paramstyle": "format",
+ },
+ "message": "INSERT INTO users(name, password, dob) VALUES($1, $2, $3)",
+ "type": "default",
+ },
+ ]
+
+
+@pytest.mark.asyncio
+async def test_cursor(sentry_init, capture_events) -> None:
+ sentry_init(
+ integrations=[AsyncPGIntegration()],
+ _experiments={"record_sql_params": True},
+ )
+ events = capture_events()
+
+ conn: Connection = await connect(PG_CONNECTION_URI)
+
+ await conn.executemany(
+ "INSERT INTO users(name, password, dob) VALUES($1, $2, $3)",
+ [
+ ("Bob", "secret_pw", datetime.date(1984, 3, 1)),
+ ("Alice", "pw", datetime.date(1990, 12, 25)),
+ ],
+ )
+
+ async with conn.transaction():
+ # Postgres requires non-scrollable cursors to be created
+ # and used in a transaction.
+ async for record in conn.cursor(
+ "SELECT * FROM users WHERE dob > $1", datetime.date(1970, 1, 1)
+ ):
+ print(record)
+
+ await conn.close()
+
+ capture_message("hi")
+
+ (event,) = events
+
+ for crumb in event["breadcrumbs"]["values"]:
+ del crumb["timestamp"]
+
+ assert event["breadcrumbs"]["values"] == [
+ CRUMBS_CONNECT,
+ {
+ "category": "query",
+ "data": {"db.executemany": True},
+ "message": "INSERT INTO users(name, password, dob) VALUES($1, $2, $3)",
+ "type": "default",
+ },
+ {"category": "query", "data": {}, "message": "BEGIN;", "type": "default"},
+ {
+ "category": "query",
+ "data": {},
+ "message": "SELECT * FROM users WHERE dob > $1",
+ "type": "default",
+ },
+ {"category": "query", "data": {}, "message": "COMMIT;", "type": "default"},
+ ]
+
+
+@pytest.mark.asyncio
+async def test_cursor_manual(sentry_init, capture_events) -> None:
+ sentry_init(
+ integrations=[AsyncPGIntegration()],
+ _experiments={"record_sql_params": True},
+ )
+ events = capture_events()
+
+ conn: Connection = await connect(PG_CONNECTION_URI)
+
+ await conn.executemany(
+ "INSERT INTO users(name, password, dob) VALUES($1, $2, $3)",
+ [
+ ("Bob", "secret_pw", datetime.date(1984, 3, 1)),
+ ("Alice", "pw", datetime.date(1990, 12, 25)),
+ ],
+ )
+ #
+ async with conn.transaction():
+ # Postgres requires non-scrollable cursors to be created
+ # and used in a transaction.
+ cur = await conn.cursor(
+ "SELECT * FROM users WHERE dob > $1", datetime.date(1970, 1, 1)
+ )
+ record = await cur.fetchrow()
+ print(record)
+ while await cur.forward(1):
+ record = await cur.fetchrow()
+ print(record)
+
+ await conn.close()
+
+ capture_message("hi")
+
+ (event,) = events
+
+ for crumb in event["breadcrumbs"]["values"]:
+ del crumb["timestamp"]
+
+ assert event["breadcrumbs"]["values"] == [
+ CRUMBS_CONNECT,
+ {
+ "category": "query",
+ "data": {"db.executemany": True},
+ "message": "INSERT INTO users(name, password, dob) VALUES($1, $2, $3)",
+ "type": "default",
+ },
+ {"category": "query", "data": {}, "message": "BEGIN;", "type": "default"},
+ {
+ "category": "query",
+ "data": {},
+ "message": "SELECT * FROM users WHERE dob > $1",
+ "type": "default",
+ },
+ {"category": "query", "data": {}, "message": "COMMIT;", "type": "default"},
+ ]
+
+
+@pytest.mark.asyncio
+async def test_prepared_stmt(sentry_init, capture_events) -> None:
+ sentry_init(
+ integrations=[AsyncPGIntegration()],
+ _experiments={"record_sql_params": True},
+ )
+ events = capture_events()
+
+ conn: Connection = await connect(PG_CONNECTION_URI)
+
+ await conn.executemany(
+ "INSERT INTO users(name, password, dob) VALUES($1, $2, $3)",
+ [
+ ("Bob", "secret_pw", datetime.date(1984, 3, 1)),
+ ("Alice", "pw", datetime.date(1990, 12, 25)),
+ ],
+ )
+
+ stmt = await conn.prepare("SELECT * FROM users WHERE name = $1")
+
+ print(await stmt.fetchval("Bob"))
+ print(await stmt.fetchval("Alice"))
+
+ await conn.close()
+
+ capture_message("hi")
+
+ (event,) = events
+
+ for crumb in event["breadcrumbs"]["values"]:
+ del crumb["timestamp"]
+
+ assert event["breadcrumbs"]["values"] == [
+ CRUMBS_CONNECT,
+ {
+ "category": "query",
+ "data": {"db.executemany": True},
+ "message": "INSERT INTO users(name, password, dob) VALUES($1, $2, $3)",
+ "type": "default",
+ },
+ {
+ "category": "query",
+ "data": {},
+ "message": "SELECT * FROM users WHERE name = $1",
+ "type": "default",
+ },
+ ]
+
+
+@pytest.mark.asyncio
+async def test_connection_pool(sentry_init, capture_events) -> None:
+ sentry_init(
+ integrations=[AsyncPGIntegration()],
+ _experiments={"record_sql_params": True},
+ )
+ events = capture_events()
+
+ pool_size = 2
+
+ pool = await asyncpg.create_pool(
+ PG_CONNECTION_URI, min_size=pool_size, max_size=pool_size
+ )
+
+ async with pool.acquire() as conn:
+ await conn.execute(
+ "INSERT INTO users(name, password, dob) VALUES($1, $2, $3)",
+ "Bob",
+ "secret_pw",
+ datetime.date(1984, 3, 1),
+ )
+
+ async with pool.acquire() as conn:
+ row = await conn.fetchrow("SELECT * FROM users WHERE name = $1", "Bob")
+ assert row == (1, "Bob", "secret_pw", datetime.date(1984, 3, 1))
+
+ await pool.close()
+
+ capture_message("hi")
+
+ (event,) = events
+
+ for crumb in event["breadcrumbs"]["values"]:
+ del crumb["timestamp"]
+
+ assert event["breadcrumbs"]["values"] == [
+ # The connection pool opens pool_size connections so we have the crumbs pool_size times
+ *[CRUMBS_CONNECT] * pool_size,
+ {
+ "category": "query",
+ "data": {},
+ "message": "INSERT INTO users(name, password, dob) VALUES($1, $2, $3)",
+ "type": "default",
+ },
+ {
+ "category": "query",
+ "data": {},
+ "message": "SELECT pg_advisory_unlock_all(); CLOSE ALL; UNLISTEN *; RESET ALL;",
+ "type": "default",
+ },
+ {
+ "category": "query",
+ "data": {},
+ "message": "SELECT * FROM users WHERE name = $1",
+ "type": "default",
+ },
+ {
+ "category": "query",
+ "data": {},
+ "message": "SELECT pg_advisory_unlock_all(); CLOSE ALL; UNLISTEN *; RESET ALL;",
+ "type": "default",
+ },
+ ]
+
+
+@pytest.mark.asyncio
+async def test_query_source_disabled(sentry_init, capture_events):
+ sentry_options = {
+ "integrations": [AsyncPGIntegration()],
+ "traces_sample_rate": 1.0,
+ "enable_db_query_source": False,
+ "db_query_source_threshold_ms": 0,
+ }
+
+ sentry_init(**sentry_options)
+
+ events = capture_events()
+
+ with start_transaction(name="test_transaction", sampled=True):
+ conn: Connection = await connect(PG_CONNECTION_URI)
+
+ await conn.execute(
+ "INSERT INTO users(name, password, dob) VALUES ('Alice', 'secret', '1990-12-25')",
+ )
+
+ await conn.close()
+
+ (event,) = events
+
+ span = event["spans"][-1]
+ assert span["description"].startswith("INSERT INTO")
+
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO not in data
+ assert SPANDATA.CODE_NAMESPACE not in data
+ assert SPANDATA.CODE_FILEPATH not in data
+ assert SPANDATA.CODE_FUNCTION not in data
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("enable_db_query_source", [None, True])
+async def test_query_source_enabled(
+ sentry_init, capture_events, enable_db_query_source
+):
+ sentry_options = {
+ "integrations": [AsyncPGIntegration()],
+ "traces_sample_rate": 1.0,
+ "db_query_source_threshold_ms": 0,
+ }
+ if enable_db_query_source is not None:
+ sentry_options["enable_db_query_source"] = enable_db_query_source
+
+ sentry_init(**sentry_options)
+
+ events = capture_events()
+
+ with start_transaction(name="test_transaction", sampled=True):
+ conn: Connection = await connect(PG_CONNECTION_URI)
+
+ await conn.execute(
+ "INSERT INTO users(name, password, dob) VALUES ('Alice', 'secret', '1990-12-25')",
+ )
+
+ await conn.close()
+
+ (event,) = events
+
+ span = event["spans"][-1]
+ assert span["description"].startswith("INSERT INTO")
+
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO in data
+ assert SPANDATA.CODE_NAMESPACE in data
+ assert SPANDATA.CODE_FILEPATH in data
+ assert SPANDATA.CODE_FUNCTION in data
+
+
+@pytest.mark.asyncio
+async def test_query_source(sentry_init, capture_events):
+ sentry_init(
+ integrations=[AsyncPGIntegration()],
+ traces_sample_rate=1.0,
+ enable_db_query_source=True,
+ db_query_source_threshold_ms=0,
+ )
+
+ events = capture_events()
+
+ with start_transaction(name="test_transaction", sampled=True):
+ conn: Connection = await connect(PG_CONNECTION_URI)
+
+ await conn.execute(
+ "INSERT INTO users(name, password, dob) VALUES ('Alice', 'secret', '1990-12-25')",
+ )
+
+ await conn.close()
+
+ (event,) = events
+
+ span = event["spans"][-1]
+ assert span["description"].startswith("INSERT INTO")
+
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO in data
+ assert SPANDATA.CODE_NAMESPACE in data
+ assert SPANDATA.CODE_FILEPATH in data
+ assert SPANDATA.CODE_FUNCTION in data
+
+ assert type(data.get(SPANDATA.CODE_LINENO)) == int
+ assert data.get(SPANDATA.CODE_LINENO) > 0
+ assert (
+ data.get(SPANDATA.CODE_NAMESPACE) == "tests.integrations.asyncpg.test_asyncpg"
+ )
+ assert data.get(SPANDATA.CODE_FILEPATH).endswith(
+ "tests/integrations/asyncpg/test_asyncpg.py"
+ )
+
+ is_relative_path = data.get(SPANDATA.CODE_FILEPATH)[0] != os.sep
+ assert is_relative_path
+
+ assert data.get(SPANDATA.CODE_FUNCTION) == "test_query_source"
+
+
+@pytest.mark.asyncio
+async def test_query_source_with_module_in_search_path(sentry_init, capture_events):
+ """
+ Test that query source is relative to the path of the module it ran in
+ """
+ sentry_init(
+ integrations=[AsyncPGIntegration()],
+ traces_sample_rate=1.0,
+ enable_db_query_source=True,
+ db_query_source_threshold_ms=0,
+ )
+
+ events = capture_events()
+
+ from asyncpg_helpers.helpers import execute_query_in_connection
+
+ with start_transaction(name="test_transaction", sampled=True):
+ conn: Connection = await connect(PG_CONNECTION_URI)
+
+ await execute_query_in_connection(
+ "INSERT INTO users(name, password, dob) VALUES ('Alice', 'secret', '1990-12-25')",
+ conn,
+ )
+
+ await conn.close()
+
+ (event,) = events
+
+ span = event["spans"][-1]
+ assert span["description"].startswith("INSERT INTO")
+
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO in data
+ assert SPANDATA.CODE_NAMESPACE in data
+ assert SPANDATA.CODE_FILEPATH in data
+ assert SPANDATA.CODE_FUNCTION in data
+
+ assert type(data.get(SPANDATA.CODE_LINENO)) == int
+ assert data.get(SPANDATA.CODE_LINENO) > 0
+ assert data.get(SPANDATA.CODE_NAMESPACE) == "asyncpg_helpers.helpers"
+ assert data.get(SPANDATA.CODE_FILEPATH) == "asyncpg_helpers/helpers.py"
+
+ is_relative_path = data.get(SPANDATA.CODE_FILEPATH)[0] != os.sep
+ assert is_relative_path
+
+ assert data.get(SPANDATA.CODE_FUNCTION) == "execute_query_in_connection"
+
+
+@pytest.mark.asyncio
+async def test_no_query_source_if_duration_too_short(sentry_init, capture_events):
+ sentry_init(
+ integrations=[AsyncPGIntegration()],
+ traces_sample_rate=1.0,
+ enable_db_query_source=True,
+ db_query_source_threshold_ms=100,
+ )
+
+ events = capture_events()
+
+ with start_transaction(name="test_transaction", sampled=True):
+ conn: Connection = await connect(PG_CONNECTION_URI)
+
+ @contextmanager
+ def fake_record_sql_queries(*args, **kwargs):
+ with record_sql_queries(*args, **kwargs) as span:
+ pass
+ span.start_timestamp = datetime.datetime(2024, 1, 1, microsecond=0)
+ span.timestamp = datetime.datetime(2024, 1, 1, microsecond=99999)
+ yield span
+
+ with mock.patch(
+ "sentry_sdk.integrations.asyncpg.record_sql_queries",
+ fake_record_sql_queries,
+ ):
+ await conn.execute(
+ "INSERT INTO users(name, password, dob) VALUES ('Alice', 'secret', '1990-12-25')",
+ )
+
+ await conn.close()
+
+ (event,) = events
+
+ span = event["spans"][-1]
+ assert span["description"].startswith("INSERT INTO")
+
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO not in data
+ assert SPANDATA.CODE_NAMESPACE not in data
+ assert SPANDATA.CODE_FILEPATH not in data
+ assert SPANDATA.CODE_FUNCTION not in data
+
+
+@pytest.mark.asyncio
+async def test_query_source_if_duration_over_threshold(sentry_init, capture_events):
+ sentry_init(
+ integrations=[AsyncPGIntegration()],
+ traces_sample_rate=1.0,
+ enable_db_query_source=True,
+ db_query_source_threshold_ms=100,
+ )
+
+ events = capture_events()
+
+ with start_transaction(name="test_transaction", sampled=True):
+ conn: Connection = await connect(PG_CONNECTION_URI)
+
+ @contextmanager
+ def fake_record_sql_queries(*args, **kwargs):
+ with record_sql_queries(*args, **kwargs) as span:
+ pass
+ span.start_timestamp = datetime.datetime(2024, 1, 1, microsecond=0)
+ span.timestamp = datetime.datetime(2024, 1, 1, microsecond=100001)
+ yield span
+
+ with mock.patch(
+ "sentry_sdk.integrations.asyncpg.record_sql_queries",
+ fake_record_sql_queries,
+ ):
+ await conn.execute(
+ "INSERT INTO users(name, password, dob) VALUES ('Alice', 'secret', '1990-12-25')",
+ )
+
+ await conn.close()
+
+ (event,) = events
+
+ span = event["spans"][-1]
+ assert span["description"].startswith("INSERT INTO")
+
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO in data
+ assert SPANDATA.CODE_NAMESPACE in data
+ assert SPANDATA.CODE_FILEPATH in data
+ assert SPANDATA.CODE_FUNCTION in data
+
+ assert type(data.get(SPANDATA.CODE_LINENO)) == int
+ assert data.get(SPANDATA.CODE_LINENO) > 0
+ assert (
+ data.get(SPANDATA.CODE_NAMESPACE) == "tests.integrations.asyncpg.test_asyncpg"
+ )
+ assert data.get(SPANDATA.CODE_FILEPATH).endswith(
+ "tests/integrations/asyncpg/test_asyncpg.py"
+ )
+
+ is_relative_path = data.get(SPANDATA.CODE_FILEPATH)[0] != os.sep
+ assert is_relative_path
+
+ assert (
+ data.get(SPANDATA.CODE_FUNCTION)
+ == "test_query_source_if_duration_over_threshold"
+ )
+
+
+@pytest.mark.asyncio
+async def test_span_origin(sentry_init, capture_events):
+ sentry_init(
+ integrations=[AsyncPGIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ with start_transaction(name="test_transaction"):
+ conn: Connection = await connect(PG_CONNECTION_URI)
+
+ await conn.execute("SELECT 1")
+ await conn.fetchrow("SELECT 2")
+ await conn.close()
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+
+ for span in event["spans"]:
+ assert span["origin"] == "auto.db.asyncpg"
+
+
+@pytest.mark.asyncio
+async def test_multiline_query_description_normalized(sentry_init, capture_events):
+ sentry_init(
+ integrations=[AsyncPGIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ with start_transaction(name="test_transaction"):
+ conn: Connection = await connect(PG_CONNECTION_URI)
+ await conn.execute(
+ """
+ SELECT
+ id,
+ name
+ FROM
+ users
+ WHERE
+ name = 'Alice'
+ """
+ )
+ await conn.close()
+
+ (event,) = events
+
+ spans = [
+ s
+ for s in event["spans"]
+ if s["op"] == "db" and "SELECT" in s.get("description", "")
+ ]
+ assert len(spans) == 1
+ assert spans[0]["description"] == "SELECT id, name FROM users WHERE name = 'Alice'"
+
+
+@pytest.mark.asyncio
+async def test_before_send_transaction_sees_normalized_description(
+ sentry_init, capture_events
+):
+ def before_send_transaction(event, hint):
+ for span in event.get("spans", []):
+ desc = span.get("description", "")
+ if "SELECT id, name FROM users" in desc:
+ span["description"] = "filtered"
+ return event
+
+ sentry_init(
+ integrations=[AsyncPGIntegration()],
+ traces_sample_rate=1.0,
+ before_send_transaction=before_send_transaction,
+ )
+ events = capture_events()
+
+ with start_transaction(name="test_transaction"):
+ conn: Connection = await connect(PG_CONNECTION_URI)
+ await conn.execute(
+ """
+ SELECT
+ id,
+ name
+ FROM
+ users
+ """
+ )
+ await conn.close()
+
+ (event,) = events
+ spans = [
+ s
+ for s in event["spans"]
+ if s["op"] == "db" and "filtered" in s.get("description", "")
+ ]
+
+ assert len(spans) == 1
+ assert spans[0]["description"] == "filtered"
diff --git a/tests/integrations/aws_lambda/__init__.py b/tests/integrations/aws_lambda/__init__.py
new file mode 100644
index 0000000000..449f4dc95d
--- /dev/null
+++ b/tests/integrations/aws_lambda/__init__.py
@@ -0,0 +1,5 @@
+import pytest
+
+pytest.importorskip("boto3")
+pytest.importorskip("fastapi")
+pytest.importorskip("uvicorn")
diff --git a/tests/integrations/aws_lambda/client.py b/tests/integrations/aws_lambda/client.py
deleted file mode 100644
index d8e430f3d7..0000000000
--- a/tests/integrations/aws_lambda/client.py
+++ /dev/null
@@ -1,239 +0,0 @@
-import sys
-import os
-import shutil
-import tempfile
-import subprocess
-import boto3
-import uuid
-import base64
-
-
-def get_boto_client():
- return boto3.client(
- "lambda",
- aws_access_key_id=os.environ["SENTRY_PYTHON_TEST_AWS_ACCESS_KEY_ID"],
- aws_secret_access_key=os.environ["SENTRY_PYTHON_TEST_AWS_SECRET_ACCESS_KEY"],
- region_name="us-east-1",
- )
-
-
-def build_no_code_serverless_function_and_layer(
- client, tmpdir, fn_name, runtime, timeout, initial_handler
-):
- """
- Util function that auto instruments the no code implementation of the python
- sdk by creating a layer containing the Python-sdk, and then creating a func
- that uses that layer
- """
- from scripts.build_aws_lambda_layer import build_layer_dir
-
- build_layer_dir(dest_abs_path=tmpdir)
-
- with open(os.path.join(tmpdir, "serverless-ball.zip"), "rb") as serverless_zip:
- response = client.publish_layer_version(
- LayerName="python-serverless-sdk-test",
- Description="Created as part of testsuite for getsentry/sentry-python",
- Content={"ZipFile": serverless_zip.read()},
- )
-
- with open(os.path.join(tmpdir, "ball.zip"), "rb") as zip:
- client.create_function(
- FunctionName=fn_name,
- Runtime=runtime,
- Timeout=timeout,
- Environment={
- "Variables": {
- "SENTRY_INITIAL_HANDLER": initial_handler,
- "SENTRY_DSN": "https://123abc@example.com/123",
- "SENTRY_TRACES_SAMPLE_RATE": "1.0",
- }
- },
- Role=os.environ["SENTRY_PYTHON_TEST_AWS_IAM_ROLE"],
- Handler="sentry_sdk.integrations.init_serverless_sdk.sentry_lambda_handler",
- Layers=[response["LayerVersionArn"]],
- Code={"ZipFile": zip.read()},
- Description="Created as part of testsuite for getsentry/sentry-python",
- )
-
-
-def run_lambda_function(
- client,
- runtime,
- code,
- payload,
- add_finalizer,
- syntax_check=True,
- timeout=30,
- layer=None,
- initial_handler=None,
- subprocess_kwargs=(),
-):
- subprocess_kwargs = dict(subprocess_kwargs)
-
- with tempfile.TemporaryDirectory() as tmpdir:
- if initial_handler:
- # If Initial handler value is provided i.e. it is not the default
- # `test_lambda.test_handler`, then create another dir level so that our path is
- # test_dir.test_lambda.test_handler
- test_dir_path = os.path.join(tmpdir, "test_dir")
- python_init_file = os.path.join(test_dir_path, "__init__.py")
- os.makedirs(test_dir_path)
- with open(python_init_file, "w"):
- # Create __init__ file to make it a python package
- pass
-
- test_lambda_py = os.path.join(tmpdir, "test_dir", "test_lambda.py")
- else:
- test_lambda_py = os.path.join(tmpdir, "test_lambda.py")
-
- with open(test_lambda_py, "w") as f:
- f.write(code)
-
- if syntax_check:
- # Check file for valid syntax first, and that the integration does not
- # crash when not running in Lambda (but rather a local deployment tool
- # such as chalice's)
- subprocess.check_call([sys.executable, test_lambda_py])
-
- fn_name = "test_function_{}".format(uuid.uuid4())
-
- if layer is None:
- setup_cfg = os.path.join(tmpdir, "setup.cfg")
- with open(setup_cfg, "w") as f:
- f.write("[install]\nprefix=")
-
- subprocess.check_call(
- [sys.executable, "setup.py", "sdist", "-d", os.path.join(tmpdir, "..")],
- **subprocess_kwargs
- )
-
- subprocess.check_call(
- "pip install mock==3.0.0 funcsigs -t .",
- cwd=tmpdir,
- shell=True,
- **subprocess_kwargs
- )
-
- # https://docs.aws.amazon.com/lambda/latest/dg/lambda-python-how-to-create-deployment-package.html
- subprocess.check_call(
- "pip install ../*.tar.gz -t .",
- cwd=tmpdir,
- shell=True,
- **subprocess_kwargs
- )
-
- shutil.make_archive(os.path.join(tmpdir, "ball"), "zip", tmpdir)
-
- with open(os.path.join(tmpdir, "ball.zip"), "rb") as zip:
- client.create_function(
- FunctionName=fn_name,
- Runtime=runtime,
- Timeout=timeout,
- Role=os.environ["SENTRY_PYTHON_TEST_AWS_IAM_ROLE"],
- Handler="test_lambda.test_handler",
- Code={"ZipFile": zip.read()},
- Description="Created as part of testsuite for getsentry/sentry-python",
- )
- else:
- subprocess.run(
- ["zip", "-q", "-x", "**/__pycache__/*", "-r", "ball.zip", "./"],
- cwd=tmpdir,
- check=True,
- )
-
- # Default initial handler
- if not initial_handler:
- initial_handler = "test_lambda.test_handler"
-
- build_no_code_serverless_function_and_layer(
- client, tmpdir, fn_name, runtime, timeout, initial_handler
- )
-
- @add_finalizer
- def clean_up():
- client.delete_function(FunctionName=fn_name)
-
- # this closes the web socket so we don't get a
- # ResourceWarning: unclosed
- # warning on every test
- # based on https://github.com/boto/botocore/pull/1810
- # (if that's ever merged, this can just become client.close())
- session = client._endpoint.http_session
- managers = [session._manager] + list(session._proxy_managers.values())
- for manager in managers:
- manager.clear()
-
- response = client.invoke(
- FunctionName=fn_name,
- InvocationType="RequestResponse",
- LogType="Tail",
- Payload=payload,
- )
-
- assert 200 <= response["StatusCode"] < 300, response
- return response
-
-
-_REPL_CODE = """
-import os
-
-def test_handler(event, context):
- line = {line!r}
- if line.startswith(">>> "):
- exec(line[4:])
- elif line.startswith("$ "):
- os.system(line[2:])
- else:
- print("Start a line with $ or >>>")
-
- return b""
-"""
-
-try:
- import click
-except ImportError:
- pass
-else:
-
- @click.command()
- @click.option(
- "--runtime", required=True, help="name of the runtime to use, eg python3.8"
- )
- @click.option("--verbose", is_flag=True, default=False)
- def repl(runtime, verbose):
- """
- Launch a "REPL" against AWS Lambda to inspect their runtime.
- """
-
- cleanup = []
- client = get_boto_client()
-
- print("Start a line with `$ ` to run shell commands, or `>>> ` to run Python")
-
- while True:
- line = input()
-
- response = run_lambda_function(
- client,
- runtime,
- _REPL_CODE.format(line=line),
- b"",
- cleanup.append,
- subprocess_kwargs={
- "stdout": subprocess.DEVNULL,
- "stderr": subprocess.DEVNULL,
- }
- if not verbose
- else {},
- )
-
- for line in base64.b64decode(response["LogResult"]).splitlines():
- print(line.decode("utf8"))
-
- for f in cleanup:
- f()
-
- cleanup = []
-
- if __name__ == "__main__":
- repl()
diff --git a/tests/integrations/aws_lambda/lambda_functions/BasicException/index.py b/tests/integrations/aws_lambda/lambda_functions/BasicException/index.py
new file mode 100644
index 0000000000..875b984e2a
--- /dev/null
+++ b/tests/integrations/aws_lambda/lambda_functions/BasicException/index.py
@@ -0,0 +1,6 @@
+def handler(event, context):
+ raise RuntimeError("Oh!")
+
+ return {
+ "event": event,
+ }
diff --git a/tests/integrations/aws_lambda/lambda_functions/BasicOk/index.py b/tests/integrations/aws_lambda/lambda_functions/BasicOk/index.py
new file mode 100644
index 0000000000..257fea04f0
--- /dev/null
+++ b/tests/integrations/aws_lambda/lambda_functions/BasicOk/index.py
@@ -0,0 +1,4 @@
+def handler(event, context):
+ return {
+ "event": event,
+ }
diff --git a/tests/integrations/aws_lambda/lambda_functions/InitError/index.py b/tests/integrations/aws_lambda/lambda_functions/InitError/index.py
new file mode 100644
index 0000000000..20b4fcc111
--- /dev/null
+++ b/tests/integrations/aws_lambda/lambda_functions/InitError/index.py
@@ -0,0 +1,3 @@
+# We have no handler() here and try to call a non-existing function.
+
+func() # noqa: F821
diff --git a/tests/integrations/aws_lambda/lambda_functions/TimeoutError/index.py b/tests/integrations/aws_lambda/lambda_functions/TimeoutError/index.py
new file mode 100644
index 0000000000..01334bbfbc
--- /dev/null
+++ b/tests/integrations/aws_lambda/lambda_functions/TimeoutError/index.py
@@ -0,0 +1,8 @@
+import time
+
+
+def handler(event, context):
+ time.sleep(15)
+ return {
+ "event": event,
+ }
diff --git a/tests/integrations/aws_lambda/lambda_functions_with_embedded_sdk/RaiseErrorPerformanceDisabled/.gitignore b/tests/integrations/aws_lambda/lambda_functions_with_embedded_sdk/RaiseErrorPerformanceDisabled/.gitignore
new file mode 100644
index 0000000000..1c56884372
--- /dev/null
+++ b/tests/integrations/aws_lambda/lambda_functions_with_embedded_sdk/RaiseErrorPerformanceDisabled/.gitignore
@@ -0,0 +1,11 @@
+# Need to add some ignore rules in this directory, because the unit tests will add the Sentry SDK and its dependencies
+# into this directory to create a Lambda function package that contains everything needed to instrument a Lambda function using Sentry.
+
+# Ignore everything
+*
+
+# But not index.py
+!index.py
+
+# And not .gitignore itself
+!.gitignore
diff --git a/tests/integrations/aws_lambda/lambda_functions_with_embedded_sdk/RaiseErrorPerformanceDisabled/index.py b/tests/integrations/aws_lambda/lambda_functions_with_embedded_sdk/RaiseErrorPerformanceDisabled/index.py
new file mode 100644
index 0000000000..12f43f0009
--- /dev/null
+++ b/tests/integrations/aws_lambda/lambda_functions_with_embedded_sdk/RaiseErrorPerformanceDisabled/index.py
@@ -0,0 +1,14 @@
+import os
+import sentry_sdk
+from sentry_sdk.integrations.aws_lambda import AwsLambdaIntegration
+
+
+sentry_sdk.init(
+ dsn=os.environ.get("SENTRY_DSN"),
+ traces_sample_rate=None, # this is the default, just added for clarity
+ integrations=[AwsLambdaIntegration()],
+)
+
+
+def handler(event, context):
+ raise Exception("Oh!")
diff --git a/tests/integrations/aws_lambda/lambda_functions_with_embedded_sdk/RaiseErrorPerformanceEnabled/.gitignore b/tests/integrations/aws_lambda/lambda_functions_with_embedded_sdk/RaiseErrorPerformanceEnabled/.gitignore
new file mode 100644
index 0000000000..1c56884372
--- /dev/null
+++ b/tests/integrations/aws_lambda/lambda_functions_with_embedded_sdk/RaiseErrorPerformanceEnabled/.gitignore
@@ -0,0 +1,11 @@
+# Need to add some ignore rules in this directory, because the unit tests will add the Sentry SDK and its dependencies
+# into this directory to create a Lambda function package that contains everything needed to instrument a Lambda function using Sentry.
+
+# Ignore everything
+*
+
+# But not index.py
+!index.py
+
+# And not .gitignore itself
+!.gitignore
diff --git a/tests/integrations/aws_lambda/lambda_functions_with_embedded_sdk/RaiseErrorPerformanceEnabled/index.py b/tests/integrations/aws_lambda/lambda_functions_with_embedded_sdk/RaiseErrorPerformanceEnabled/index.py
new file mode 100644
index 0000000000..c694299682
--- /dev/null
+++ b/tests/integrations/aws_lambda/lambda_functions_with_embedded_sdk/RaiseErrorPerformanceEnabled/index.py
@@ -0,0 +1,14 @@
+import os
+import sentry_sdk
+from sentry_sdk.integrations.aws_lambda import AwsLambdaIntegration
+
+
+sentry_sdk.init(
+ dsn=os.environ.get("SENTRY_DSN"),
+ traces_sample_rate=1.0,
+ integrations=[AwsLambdaIntegration()],
+)
+
+
+def handler(event, context):
+ raise Exception("Oh!")
diff --git a/tests/integrations/aws_lambda/lambda_functions_with_embedded_sdk/TimeoutErrorScopeModified/.gitignore b/tests/integrations/aws_lambda/lambda_functions_with_embedded_sdk/TimeoutErrorScopeModified/.gitignore
new file mode 100644
index 0000000000..1c56884372
--- /dev/null
+++ b/tests/integrations/aws_lambda/lambda_functions_with_embedded_sdk/TimeoutErrorScopeModified/.gitignore
@@ -0,0 +1,11 @@
+# Need to add some ignore rules in this directory, because the unit tests will add the Sentry SDK and its dependencies
+# into this directory to create a Lambda function package that contains everything needed to instrument a Lambda function using Sentry.
+
+# Ignore everything
+*
+
+# But not index.py
+!index.py
+
+# And not .gitignore itself
+!.gitignore
diff --git a/tests/integrations/aws_lambda/lambda_functions_with_embedded_sdk/TimeoutErrorScopeModified/index.py b/tests/integrations/aws_lambda/lambda_functions_with_embedded_sdk/TimeoutErrorScopeModified/index.py
new file mode 100644
index 0000000000..109245b90d
--- /dev/null
+++ b/tests/integrations/aws_lambda/lambda_functions_with_embedded_sdk/TimeoutErrorScopeModified/index.py
@@ -0,0 +1,19 @@
+import os
+import time
+
+import sentry_sdk
+from sentry_sdk.integrations.aws_lambda import AwsLambdaIntegration
+
+sentry_sdk.init(
+ dsn=os.environ.get("SENTRY_DSN"),
+ traces_sample_rate=1.0,
+ integrations=[AwsLambdaIntegration(timeout_warning=True)],
+)
+
+
+def handler(event, context):
+ sentry_sdk.set_tag("custom_tag", "custom_value")
+ time.sleep(15)
+ return {
+ "event": event,
+ }
diff --git a/tests/integrations/aws_lambda/lambda_functions_with_embedded_sdk/TracesSampler/.gitignore b/tests/integrations/aws_lambda/lambda_functions_with_embedded_sdk/TracesSampler/.gitignore
new file mode 100644
index 0000000000..1c56884372
--- /dev/null
+++ b/tests/integrations/aws_lambda/lambda_functions_with_embedded_sdk/TracesSampler/.gitignore
@@ -0,0 +1,11 @@
+# Need to add some ignore rules in this directory, because the unit tests will add the Sentry SDK and its dependencies
+# into this directory to create a Lambda function package that contains everything needed to instrument a Lambda function using Sentry.
+
+# Ignore everything
+*
+
+# But not index.py
+!index.py
+
+# And not .gitignore itself
+!.gitignore
diff --git a/tests/integrations/aws_lambda/lambda_functions_with_embedded_sdk/TracesSampler/index.py b/tests/integrations/aws_lambda/lambda_functions_with_embedded_sdk/TracesSampler/index.py
new file mode 100644
index 0000000000..ce797faf71
--- /dev/null
+++ b/tests/integrations/aws_lambda/lambda_functions_with_embedded_sdk/TracesSampler/index.py
@@ -0,0 +1,49 @@
+import json
+import os
+import sentry_sdk
+from sentry_sdk.integrations.aws_lambda import AwsLambdaIntegration
+
+# Global variables to store sampling context for verification
+sampling_context_data = {
+ "aws_event_present": False,
+ "aws_context_present": False,
+ "event_data": None,
+}
+
+
+def trace_sampler(sampling_context):
+ # Store the sampling context for verification
+ global sampling_context_data
+
+ # Check if aws_event and aws_context are in the sampling_context
+ if "aws_event" in sampling_context:
+ sampling_context_data["aws_event_present"] = True
+ sampling_context_data["event_data"] = sampling_context["aws_event"]
+
+ if "aws_context" in sampling_context:
+ sampling_context_data["aws_context_present"] = True
+
+ print("Sampling context data:", sampling_context_data)
+ return 1.0 # Always sample
+
+
+sentry_sdk.init(
+ dsn=os.environ.get("SENTRY_DSN"),
+ traces_sample_rate=1.0,
+ traces_sampler=trace_sampler,
+ integrations=[AwsLambdaIntegration()],
+)
+
+
+def handler(event, context):
+ # Return the sampling context data for verification
+ return {
+ "statusCode": 200,
+ "body": json.dumps(
+ {
+ "message": "Hello from Lambda with embedded Sentry SDK!",
+ "event": event,
+ "sampling_context_data": sampling_context_data,
+ }
+ ),
+ }
diff --git a/tests/integrations/aws_lambda/test_aws.py b/tests/integrations/aws_lambda/test_aws.py
deleted file mode 100644
index 78c9770317..0000000000
--- a/tests/integrations/aws_lambda/test_aws.py
+++ /dev/null
@@ -1,666 +0,0 @@
-"""
-# AWS Lambda system tests
-
-This testsuite uses boto3 to upload actual lambda functions to AWS, execute
-them and assert some things about the externally observed behavior. What that
-means for you is that those tests won't run without AWS access keys:
-
- export SENTRY_PYTHON_TEST_AWS_ACCESS_KEY_ID=..
- export SENTRY_PYTHON_TEST_AWS_SECRET_ACCESS_KEY=...
- export SENTRY_PYTHON_TEST_AWS_IAM_ROLE="arn:aws:iam::920901907255:role/service-role/lambda"
-
-If you need to debug a new runtime, use this REPL to figure things out:
-
- pip3 install click
- python3 tests/integrations/aws_lambda/client.py --runtime=python4.0
-"""
-import base64
-import json
-import os
-import re
-from textwrap import dedent
-
-import pytest
-
-boto3 = pytest.importorskip("boto3")
-
-LAMBDA_PRELUDE = """
-from __future__ import print_function
-
-from sentry_sdk.integrations.aws_lambda import AwsLambdaIntegration, get_lambda_bootstrap
-import sentry_sdk
-import json
-import time
-
-from sentry_sdk.transport import HttpTransport
-
-def event_processor(event):
- # AWS Lambda truncates the log output to 4kb, which is small enough to miss
- # parts of even a single error-event/transaction-envelope pair if considered
- # in full, so only grab the data we need.
-
- event_data = {}
- event_data["contexts"] = {}
- event_data["contexts"]["trace"] = event.get("contexts", {}).get("trace")
- event_data["exception"] = event.get("exception")
- event_data["extra"] = event.get("extra")
- event_data["level"] = event.get("level")
- event_data["request"] = event.get("request")
- event_data["tags"] = event.get("tags")
- event_data["transaction"] = event.get("transaction")
-
- return event_data
-
-def envelope_processor(envelope):
- # AWS Lambda truncates the log output to 4kb, which is small enough to miss
- # parts of even a single error-event/transaction-envelope pair if considered
- # in full, so only grab the data we need.
-
- (item,) = envelope.items
- envelope_json = json.loads(item.get_bytes())
-
- envelope_data = {}
- envelope_data["contexts"] = {}
- envelope_data["type"] = envelope_json["type"]
- envelope_data["transaction"] = envelope_json["transaction"]
- envelope_data["contexts"]["trace"] = envelope_json["contexts"]["trace"]
- envelope_data["request"] = envelope_json["request"]
- envelope_data["tags"] = envelope_json["tags"]
-
- return envelope_data
-
-
-class TestTransport(HttpTransport):
- def _send_event(self, event):
- event = event_processor(event)
- # Writing a single string to stdout holds the GIL (seems like) and
- # therefore cannot be interleaved with other threads. This is why we
- # explicitly add a newline at the end even though `print` would provide
- # us one.
- print("\\nEVENT: {}\\n".format(json.dumps(event)))
-
- def _send_envelope(self, envelope):
- envelope = envelope_processor(envelope)
- print("\\nENVELOPE: {}\\n".format(json.dumps(envelope)))
-
-
-def init_sdk(timeout_warning=False, **extra_init_args):
- sentry_sdk.init(
- dsn="https://123abc@example.com/123",
- transport=TestTransport,
- integrations=[AwsLambdaIntegration(timeout_warning=timeout_warning)],
- shutdown_timeout=10,
- **extra_init_args
- )
-"""
-
-
-@pytest.fixture
-def lambda_client():
- if "SENTRY_PYTHON_TEST_AWS_ACCESS_KEY_ID" not in os.environ:
- pytest.skip("AWS environ vars not set")
-
- from tests.integrations.aws_lambda.client import get_boto_client
-
- return get_boto_client()
-
-
-@pytest.fixture(
- params=["python3.6", "python3.7", "python3.8", "python3.9", "python2.7"]
-)
-def lambda_runtime(request):
- return request.param
-
-
-@pytest.fixture
-def run_lambda_function(request, lambda_client, lambda_runtime):
- def inner(
- code, payload, timeout=30, syntax_check=True, layer=None, initial_handler=None
- ):
- from tests.integrations.aws_lambda.client import run_lambda_function
-
- response = run_lambda_function(
- client=lambda_client,
- runtime=lambda_runtime,
- code=code,
- payload=payload,
- add_finalizer=request.addfinalizer,
- timeout=timeout,
- syntax_check=syntax_check,
- layer=layer,
- initial_handler=initial_handler,
- )
-
- # for better debugging
- response["LogResult"] = base64.b64decode(response["LogResult"]).splitlines()
- response["Payload"] = json.loads(response["Payload"].read().decode("utf-8"))
- del response["ResponseMetadata"]
-
- events = []
- envelopes = []
-
- for line in response["LogResult"]:
- print("AWS:", line)
- if line.startswith(b"EVENT: "):
- line = line[len(b"EVENT: ") :]
- events.append(json.loads(line.decode("utf-8")))
- elif line.startswith(b"ENVELOPE: "):
- line = line[len(b"ENVELOPE: ") :]
- envelopes.append(json.loads(line.decode("utf-8")))
- else:
- continue
-
- return envelopes, events, response
-
- return inner
-
-
-def test_basic(run_lambda_function):
- envelopes, events, response = run_lambda_function(
- LAMBDA_PRELUDE
- + dedent(
- """
- init_sdk()
-
- def event_processor(event):
- # Delay event output like this to test proper shutdown
- time.sleep(1)
- return event
-
- def test_handler(event, context):
- raise Exception("something went wrong")
- """
- ),
- b'{"foo": "bar"}',
- )
-
- assert response["FunctionError"] == "Unhandled"
-
- (event,) = events
- assert event["level"] == "error"
- (exception,) = event["exception"]["values"]
- assert exception["type"] == "Exception"
- assert exception["value"] == "something went wrong"
-
- (frame1,) = exception["stacktrace"]["frames"]
- assert frame1["filename"] == "test_lambda.py"
- assert frame1["abs_path"] == "/var/task/test_lambda.py"
- assert frame1["function"] == "test_handler"
-
- assert frame1["in_app"] is True
-
- assert exception["mechanism"] == {"type": "aws_lambda", "handled": False}
-
- assert event["extra"]["lambda"]["function_name"].startswith("test_function_")
-
- logs_url = event["extra"]["cloudwatch logs"]["url"]
- assert logs_url.startswith("https://console.aws.amazon.com/cloudwatch/home?region=")
- assert not re.search("(=;|=$)", logs_url)
- assert event["extra"]["cloudwatch logs"]["log_group"].startswith(
- "/aws/lambda/test_function_"
- )
-
- log_stream_re = "^[0-9]{4}/[0-9]{2}/[0-9]{2}/\\[[^\\]]+][a-f0-9]+$"
- log_stream = event["extra"]["cloudwatch logs"]["log_stream"]
-
- assert re.match(log_stream_re, log_stream)
-
-
-def test_initialization_order(run_lambda_function):
- """Zappa lazily imports our code, so by the time we monkeypatch the handler
- as seen by AWS already runs. At this point at least draining the queue
- should work."""
-
- envelopes, events, _response = run_lambda_function(
- LAMBDA_PRELUDE
- + dedent(
- """
- def test_handler(event, context):
- init_sdk()
- sentry_sdk.capture_exception(Exception("something went wrong"))
- """
- ),
- b'{"foo": "bar"}',
- )
-
- (event,) = events
- assert event["level"] == "error"
- (exception,) = event["exception"]["values"]
- assert exception["type"] == "Exception"
- assert exception["value"] == "something went wrong"
-
-
-def test_request_data(run_lambda_function):
- envelopes, events, _response = run_lambda_function(
- LAMBDA_PRELUDE
- + dedent(
- """
- init_sdk()
- def test_handler(event, context):
- sentry_sdk.capture_message("hi")
- return "ok"
- """
- ),
- payload=b"""
- {
- "resource": "/asd",
- "path": "/asd",
- "httpMethod": "GET",
- "headers": {
- "Host": "iwsz2c7uwi.execute-api.us-east-1.amazonaws.com",
- "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.13; rv:62.0) Gecko/20100101 Firefox/62.0",
- "X-Forwarded-Proto": "https"
- },
- "queryStringParameters": {
- "bonkers": "true"
- },
- "pathParameters": null,
- "stageVariables": null,
- "requestContext": {
- "identity": {
- "sourceIp": "213.47.147.207",
- "userArn": "42"
- }
- },
- "body": null,
- "isBase64Encoded": false
- }
- """,
- )
-
- (event,) = events
-
- assert event["request"] == {
- "headers": {
- "Host": "iwsz2c7uwi.execute-api.us-east-1.amazonaws.com",
- "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.13; rv:62.0) Gecko/20100101 Firefox/62.0",
- "X-Forwarded-Proto": "https",
- },
- "method": "GET",
- "query_string": {"bonkers": "true"},
- "url": "https://iwsz2c7uwi.execute-api.us-east-1.amazonaws.com/asd",
- }
-
-
-def test_init_error(run_lambda_function, lambda_runtime):
- if lambda_runtime == "python2.7":
- pytest.skip("initialization error not supported on Python 2.7")
-
- envelopes, events, response = run_lambda_function(
- LAMBDA_PRELUDE
- + (
- "def event_processor(event):\n"
- ' return event["exception"]["values"][0]["value"]\n'
- "init_sdk()\n"
- "func()"
- ),
- b'{"foo": "bar"}',
- syntax_check=False,
- )
-
- (event,) = events
- assert "name 'func' is not defined" in event
-
-
-def test_timeout_error(run_lambda_function):
- envelopes, events, response = run_lambda_function(
- LAMBDA_PRELUDE
- + dedent(
- """
- init_sdk(timeout_warning=True)
-
- def test_handler(event, context):
- time.sleep(10)
- return 0
- """
- ),
- b'{"foo": "bar"}',
- timeout=3,
- )
-
- (event,) = events
- assert event["level"] == "error"
- (exception,) = event["exception"]["values"]
- assert exception["type"] == "ServerlessTimeoutWarning"
- assert exception["value"] in (
- "WARNING : Function is expected to get timed out. Configured timeout duration = 4 seconds.",
- "WARNING : Function is expected to get timed out. Configured timeout duration = 3 seconds.",
- )
-
- assert exception["mechanism"] == {"type": "threading", "handled": False}
-
- assert event["extra"]["lambda"]["function_name"].startswith("test_function_")
-
- logs_url = event["extra"]["cloudwatch logs"]["url"]
- assert logs_url.startswith("https://console.aws.amazon.com/cloudwatch/home?region=")
- assert not re.search("(=;|=$)", logs_url)
- assert event["extra"]["cloudwatch logs"]["log_group"].startswith(
- "/aws/lambda/test_function_"
- )
-
- log_stream_re = "^[0-9]{4}/[0-9]{2}/[0-9]{2}/\\[[^\\]]+][a-f0-9]+$"
- log_stream = event["extra"]["cloudwatch logs"]["log_stream"]
-
- assert re.match(log_stream_re, log_stream)
-
-
-def test_performance_no_error(run_lambda_function):
- envelopes, events, response = run_lambda_function(
- LAMBDA_PRELUDE
- + dedent(
- """
- init_sdk(traces_sample_rate=1.0)
-
- def test_handler(event, context):
- return "test_string"
- """
- ),
- b'{"foo": "bar"}',
- )
-
- (envelope,) = envelopes
- assert envelope["type"] == "transaction"
- assert envelope["contexts"]["trace"]["op"] == "function.aws.lambda"
- assert envelope["transaction"].startswith("test_function_")
- assert envelope["transaction_info"] == {"source": "component"}
- assert envelope["transaction"] in envelope["request"]["url"]
-
-
-def test_performance_error(run_lambda_function):
- envelopes, events, response = run_lambda_function(
- LAMBDA_PRELUDE
- + dedent(
- """
- init_sdk(traces_sample_rate=1.0)
-
- def test_handler(event, context):
- raise Exception("something went wrong")
- """
- ),
- b'{"foo": "bar"}',
- )
-
- (event,) = events
- assert event["level"] == "error"
- (exception,) = event["exception"]["values"]
- assert exception["type"] == "Exception"
- assert exception["value"] == "something went wrong"
-
- (envelope,) = envelopes
-
- assert envelope["type"] == "transaction"
- assert envelope["contexts"]["trace"]["op"] == "function.aws.lambda"
- assert envelope["transaction"].startswith("test_function_")
- assert envelope["transaction_info"] == {"source": "component"}
- assert envelope["transaction"] in envelope["request"]["url"]
-
-
-@pytest.mark.parametrize(
- "aws_event, has_request_data, batch_size",
- [
- (b"1231", False, 1),
- (b"11.21", False, 1),
- (b'"Good dog!"', False, 1),
- (b"true", False, 1),
- (
- b"""
- [
- {"good dog": "Maisey"},
- {"good dog": "Charlie"},
- {"good dog": "Cory"},
- {"good dog": "Bodhi"}
- ]
- """,
- False,
- 4,
- ),
- (
- b"""
- [
- {
- "headers": {
- "Host": "dogs.are.great",
- "X-Forwarded-Proto": "http"
- },
- "httpMethod": "GET",
- "path": "/tricks/kangaroo",
- "queryStringParameters": {
- "completed_successfully": "true",
- "treat_provided": "true",
- "treat_type": "cheese"
- },
- "dog": "Maisey"
- },
- {
- "headers": {
- "Host": "dogs.are.great",
- "X-Forwarded-Proto": "http"
- },
- "httpMethod": "GET",
- "path": "/tricks/kangaroo",
- "queryStringParameters": {
- "completed_successfully": "true",
- "treat_provided": "true",
- "treat_type": "cheese"
- },
- "dog": "Charlie"
- }
- ]
- """,
- True,
- 2,
- ),
- ],
-)
-def test_non_dict_event(
- run_lambda_function,
- aws_event,
- has_request_data,
- batch_size,
- DictionaryContaining, # noqa:N803
-):
- envelopes, events, response = run_lambda_function(
- LAMBDA_PRELUDE
- + dedent(
- """
- init_sdk(traces_sample_rate=1.0)
-
- def test_handler(event, context):
- raise Exception("More treats, please!")
- """
- ),
- aws_event,
- )
-
- assert response["FunctionError"] == "Unhandled"
-
- error_event = events[0]
- assert error_event["level"] == "error"
- assert error_event["contexts"]["trace"]["op"] == "function.aws.lambda"
-
- function_name = error_event["extra"]["lambda"]["function_name"]
- assert function_name.startswith("test_function_")
- assert error_event["transaction"] == function_name
-
- exception = error_event["exception"]["values"][0]
- assert exception["type"] == "Exception"
- assert exception["value"] == "More treats, please!"
- assert exception["mechanism"]["type"] == "aws_lambda"
-
- envelope = envelopes[0]
- assert envelope["type"] == "transaction"
- assert envelope["contexts"]["trace"] == DictionaryContaining(
- error_event["contexts"]["trace"]
- )
- assert envelope["contexts"]["trace"]["status"] == "internal_error"
- assert envelope["transaction"] == error_event["transaction"]
- assert envelope["request"]["url"] == error_event["request"]["url"]
-
- if has_request_data:
- request_data = {
- "headers": {"Host": "dogs.are.great", "X-Forwarded-Proto": "http"},
- "method": "GET",
- "url": "http://dogs.are.great/tricks/kangaroo",
- "query_string": {
- "completed_successfully": "true",
- "treat_provided": "true",
- "treat_type": "cheese",
- },
- }
- else:
- request_data = {"url": "awslambda:///{}".format(function_name)}
-
- assert error_event["request"] == request_data
- assert envelope["request"] == request_data
-
- if batch_size > 1:
- assert error_event["tags"]["batch_size"] == batch_size
- assert error_event["tags"]["batch_request"] is True
- assert envelope["tags"]["batch_size"] == batch_size
- assert envelope["tags"]["batch_request"] is True
-
-
-def test_traces_sampler_gets_correct_values_in_sampling_context(
- run_lambda_function,
- DictionaryContaining, # noqa:N803
- ObjectDescribedBy,
- StringContaining,
-):
- # TODO: This whole thing is a little hacky, specifically around the need to
- # get `conftest.py` code into the AWS runtime, which is why there's both
- # `inspect.getsource` and a copy of `_safe_is_equal` included directly in
- # the code below. Ideas which have been discussed to fix this:
-
- # - Include the test suite as a module installed in the package which is
- # shot up to AWS
- # - In client.py, copy `conftest.py` (or wherever the necessary code lives)
- # from the test suite into the main SDK directory so it gets included as
- # "part of the SDK"
-
- # It's also worth noting why it's necessary to run the assertions in the AWS
- # runtime rather than asserting on side effects the way we do with events
- # and envelopes. The reasons are two-fold:
-
- # - We're testing against the `LambdaContext` class, which only exists in
- # the AWS runtime
- # - If we were to transmit call args data they way we transmit event and
- # envelope data (through JSON), we'd quickly run into the problem that all
- # sorts of stuff isn't serializable by `json.dumps` out of the box, up to
- # and including `datetime` objects (so anything with a timestamp is
- # automatically out)
-
- # Perhaps these challenges can be solved in a cleaner and more systematic
- # way if we ever decide to refactor the entire AWS testing apparatus.
-
- import inspect
-
- envelopes, events, response = run_lambda_function(
- LAMBDA_PRELUDE
- + dedent(inspect.getsource(StringContaining))
- + dedent(inspect.getsource(DictionaryContaining))
- + dedent(inspect.getsource(ObjectDescribedBy))
- + dedent(
- """
- try:
- from unittest import mock # python 3.3 and above
- except ImportError:
- import mock # python < 3.3
-
- def _safe_is_equal(x, y):
- # copied from conftest.py - see docstring and comments there
- try:
- is_equal = x.__eq__(y)
- except AttributeError:
- is_equal = NotImplemented
-
- if is_equal == NotImplemented:
- # using == smoothes out weird variations exposed by raw __eq__
- return x == y
-
- return is_equal
-
- def test_handler(event, context):
- # this runs after the transaction has started, which means we
- # can make assertions about traces_sampler
- try:
- traces_sampler.assert_any_call(
- DictionaryContaining(
- {
- "aws_event": DictionaryContaining({
- "httpMethod": "GET",
- "path": "/sit/stay/rollover",
- "headers": {"Host": "dogs.are.great", "X-Forwarded-Proto": "http"},
- }),
- "aws_context": ObjectDescribedBy(
- type=get_lambda_bootstrap().LambdaContext,
- attrs={
- 'function_name': StringContaining("test_function"),
- 'function_version': '$LATEST',
- }
- )
- }
- )
- )
- except AssertionError:
- # catch the error and return it because the error itself will
- # get swallowed by the SDK as an "internal exception"
- return {"AssertionError raised": True,}
-
- return {"AssertionError raised": False,}
-
-
- traces_sampler = mock.Mock(return_value=True)
-
- init_sdk(
- traces_sampler=traces_sampler,
- )
- """
- ),
- b'{"httpMethod": "GET", "path": "/sit/stay/rollover", "headers": {"Host": "dogs.are.great", "X-Forwarded-Proto": "http"}}',
- )
-
- assert response["Payload"]["AssertionError raised"] is False
-
-
-def test_serverless_no_code_instrumentation(run_lambda_function):
- """
- Test that ensures that just by adding a lambda layer containing the
- python sdk, with no code changes sentry is able to capture errors
- """
-
- for initial_handler in [
- None,
- "test_dir/test_lambda.test_handler",
- "test_dir.test_lambda.test_handler",
- ]:
- print("Testing Initial Handler ", initial_handler)
- _, _, response = run_lambda_function(
- dedent(
- """
- import sentry_sdk
-
- def test_handler(event, context):
- current_client = sentry_sdk.Hub.current.client
-
- assert current_client is not None
-
- assert len(current_client.options['integrations']) == 1
- assert isinstance(current_client.options['integrations'][0],
- sentry_sdk.integrations.aws_lambda.AwsLambdaIntegration)
-
- raise Exception("something went wrong")
- """
- ),
- b'{"foo": "bar"}',
- layer=True,
- initial_handler=initial_handler,
- )
- assert response["FunctionError"] == "Unhandled"
- assert response["StatusCode"] == 200
-
- assert response["Payload"]["errorType"] != "AssertionError"
-
- assert response["Payload"]["errorType"] == "Exception"
- assert response["Payload"]["errorMessage"] == "something went wrong"
-
- assert "sentry_handler" in response["LogResult"][3].decode("utf-8")
diff --git a/tests/integrations/aws_lambda/test_aws_lambda.py b/tests/integrations/aws_lambda/test_aws_lambda.py
new file mode 100644
index 0000000000..664220464c
--- /dev/null
+++ b/tests/integrations/aws_lambda/test_aws_lambda.py
@@ -0,0 +1,575 @@
+import boto3
+import docker
+import json
+import pytest
+import subprocess
+import tempfile
+import time
+import yaml
+
+from unittest import mock
+
+from aws_cdk import App
+
+from .utils import LocalLambdaStack, SentryServerForTesting, SAM_PORT
+
+
+DOCKER_NETWORK_NAME = "lambda-test-network"
+SAM_TEMPLATE_FILE = "sam.template.yaml"
+
+
+@pytest.fixture(scope="session", autouse=True)
+def test_environment():
+ print("[test_environment fixture] Setting up AWS Lambda test infrastructure")
+
+ # Create a Docker network
+ docker_client = docker.from_env()
+ docker_client.networks.prune()
+ docker_client.networks.create(DOCKER_NETWORK_NAME, driver="bridge")
+
+ # Start Sentry server
+ server = SentryServerForTesting()
+ server.start()
+ time.sleep(1) # Give it a moment to start up
+
+ # Create local AWS SAM stack
+ app = App()
+ stack = LocalLambdaStack(app, "LocalLambdaStack")
+
+ # Write SAM template to file
+ template = app.synth().get_stack_by_name("LocalLambdaStack").template
+ with open(SAM_TEMPLATE_FILE, "w") as f:
+ yaml.dump(template, f)
+
+ # Write SAM debug log to file
+ debug_log_file = tempfile.gettempdir() + "/sentry_aws_lambda_tests_sam_debug.log"
+ debug_log = open(debug_log_file, "w")
+ print("[test_environment fixture] Writing SAM debug log to: %s" % debug_log_file)
+
+ # Start SAM local
+ process = subprocess.Popen(
+ [
+ "sam",
+ "local",
+ "start-lambda",
+ "--debug",
+ "--template",
+ SAM_TEMPLATE_FILE,
+ "--warm-containers",
+ "EAGER",
+ "--docker-network",
+ DOCKER_NETWORK_NAME,
+ ],
+ stdout=debug_log,
+ stderr=debug_log,
+ text=True, # This makes stdout/stderr return strings instead of bytes
+ )
+
+ try:
+ # Wait for SAM to be ready
+ LocalLambdaStack.wait_for_stack()
+
+ def before_test():
+ server.clear_envelopes()
+
+ yield {
+ "stack": stack,
+ "server": server,
+ "before_test": before_test,
+ }
+
+ finally:
+ print("[test_environment fixture] Tearing down AWS Lambda test infrastructure")
+
+ process.terminate()
+ process.wait(timeout=5) # Give it time to shut down gracefully
+
+ # Force kill if still running
+ if process.poll() is None:
+ process.kill()
+
+
+@pytest.fixture(autouse=True)
+def clear_before_test(test_environment):
+ test_environment["before_test"]()
+
+
+@pytest.fixture
+def lambda_client():
+ """
+ Create a boto3 client configured to use the local AWS SAM instance.
+ """
+ return boto3.client(
+ "lambda",
+ endpoint_url=f"http://127.0.0.1:{SAM_PORT}", # noqa: E231
+ aws_access_key_id="dummy",
+ aws_secret_access_key="dummy",
+ region_name="us-east-1",
+ )
+
+
+def test_basic_no_exception(lambda_client, test_environment):
+ lambda_client.invoke(
+ FunctionName="BasicOk",
+ Payload=json.dumps({}),
+ )
+ envelopes = test_environment["server"].envelopes
+
+ (transaction_event,) = envelopes
+
+ assert transaction_event["type"] == "transaction"
+ assert transaction_event["transaction"] == "BasicOk"
+ assert transaction_event["sdk"]["name"] == "sentry.python.aws_lambda"
+ assert transaction_event["tags"] == {"aws_region": "us-east-1"}
+
+ assert transaction_event["extra"]["cloudwatch logs"] == {
+ "log_group": mock.ANY,
+ "log_stream": mock.ANY,
+ "url": mock.ANY,
+ }
+ assert transaction_event["extra"]["lambda"] == {
+ "aws_request_id": mock.ANY,
+ "execution_duration_in_millis": mock.ANY,
+ "function_name": "BasicOk",
+ "function_version": "$LATEST",
+ "invoked_function_arn": "arn:aws:lambda:us-east-1:012345678912:function:BasicOk",
+ "remaining_time_in_millis": mock.ANY,
+ }
+ assert transaction_event["contexts"]["trace"] == {
+ "op": "function.aws",
+ "description": mock.ANY,
+ "span_id": mock.ANY,
+ "parent_span_id": mock.ANY,
+ "trace_id": mock.ANY,
+ "origin": "auto.function.aws_lambda",
+ "data": mock.ANY,
+ }
+
+
+def test_basic_exception(lambda_client, test_environment):
+ lambda_client.invoke(
+ FunctionName="BasicException",
+ Payload=json.dumps({}),
+ )
+ envelopes = test_environment["server"].envelopes
+
+ # The second envelope we ignore.
+ # It is the transaction that we test in test_basic_no_exception.
+ (error_event, _) = envelopes
+
+ assert error_event["level"] == "error"
+ assert error_event["exception"]["values"][0]["type"] == "RuntimeError"
+ assert error_event["exception"]["values"][0]["value"] == "Oh!"
+ assert error_event["sdk"]["name"] == "sentry.python.aws_lambda"
+
+ assert error_event["tags"] == {"aws_region": "us-east-1"}
+ assert error_event["extra"]["cloudwatch logs"] == {
+ "log_group": mock.ANY,
+ "log_stream": mock.ANY,
+ "url": mock.ANY,
+ }
+ assert error_event["extra"]["lambda"] == {
+ "aws_request_id": mock.ANY,
+ "execution_duration_in_millis": mock.ANY,
+ "function_name": "BasicException",
+ "function_version": "$LATEST",
+ "invoked_function_arn": "arn:aws:lambda:us-east-1:012345678912:function:BasicException",
+ "remaining_time_in_millis": mock.ANY,
+ }
+ assert error_event["contexts"]["trace"] == {
+ "op": "function.aws",
+ "description": mock.ANY,
+ "span_id": mock.ANY,
+ "parent_span_id": mock.ANY,
+ "trace_id": mock.ANY,
+ "origin": "auto.function.aws_lambda",
+ "data": mock.ANY,
+ }
+
+
+def test_init_error(lambda_client, test_environment):
+ lambda_client.invoke(
+ FunctionName="InitError",
+ Payload=json.dumps({}),
+ )
+ envelopes = test_environment["server"].envelopes
+
+ (error_event, transaction_event) = envelopes
+
+ assert (
+ error_event["exception"]["values"][0]["value"] == "name 'func' is not defined"
+ )
+ assert transaction_event["transaction"] == "InitError"
+
+
+def test_timeout_error(lambda_client, test_environment):
+ lambda_client.invoke(
+ FunctionName="TimeoutError",
+ Payload=json.dumps({}),
+ )
+ envelopes = test_environment["server"].envelopes
+
+ (error_event,) = envelopes
+
+ assert error_event["level"] == "error"
+ assert error_event["extra"]["lambda"]["function_name"] == "TimeoutError"
+
+ (exception,) = error_event["exception"]["values"]
+ assert not exception["mechanism"]["handled"]
+ assert exception["type"] == "ServerlessTimeoutWarning"
+ assert exception["value"].startswith(
+ "WARNING : Function is expected to get timed out. Configured timeout duration ="
+ )
+ assert exception["mechanism"]["type"] == "threading"
+
+
+def test_timeout_error_scope_modified(lambda_client, test_environment):
+ lambda_client.invoke(
+ FunctionName="TimeoutErrorScopeModified",
+ Payload=json.dumps({}),
+ )
+ envelopes = test_environment["server"].envelopes
+
+ (error_event,) = envelopes
+
+ assert error_event["level"] == "error"
+ assert (
+ error_event["extra"]["lambda"]["function_name"] == "TimeoutErrorScopeModified"
+ )
+
+ (exception,) = error_event["exception"]["values"]
+ assert not exception["mechanism"]["handled"]
+ assert exception["type"] == "ServerlessTimeoutWarning"
+ assert exception["value"].startswith(
+ "WARNING : Function is expected to get timed out. Configured timeout duration ="
+ )
+ assert exception["mechanism"]["type"] == "threading"
+
+ assert error_event["tags"]["custom_tag"] == "custom_value"
+
+
+@pytest.mark.parametrize(
+ "aws_event, has_request_data, batch_size",
+ [
+ (b"1231", False, 1),
+ (b"11.21", False, 1),
+ (b'"Good dog!"', False, 1),
+ (b"true", False, 1),
+ (
+ b"""
+ [
+ {"good dog": "Maisey"},
+ {"good dog": "Charlie"},
+ {"good dog": "Cory"},
+ {"good dog": "Bodhi"}
+ ]
+ """,
+ False,
+ 4,
+ ),
+ (
+ b"""
+ [
+ {
+ "headers": {
+ "Host": "x1.io",
+ "X-Forwarded-Proto": "https"
+ },
+ "httpMethod": "GET",
+ "path": "/1",
+ "queryStringParameters": {
+ "done": "f"
+ },
+ "d": "D1"
+ },
+ {
+ "headers": {
+ "Host": "x2.io",
+ "X-Forwarded-Proto": "http"
+ },
+ "httpMethod": "POST",
+ "path": "/2",
+ "queryStringParameters": {
+ "done": "t"
+ },
+ "d": "D2"
+ }
+ ]
+ """,
+ True,
+ 2,
+ ),
+ (b"[]", False, 1),
+ ],
+ ids=[
+ "event as integer",
+ "event as float",
+ "event as string",
+ "event as bool",
+ "event as list of dicts",
+ "event as dict",
+ "event as empty list",
+ ],
+)
+def test_non_dict_event(
+ lambda_client, test_environment, aws_event, has_request_data, batch_size
+):
+ lambda_client.invoke(
+ FunctionName="BasicException",
+ Payload=aws_event,
+ )
+ envelopes = test_environment["server"].envelopes
+
+ (error_event, transaction_event) = envelopes
+
+ assert transaction_event["type"] == "transaction"
+ assert transaction_event["transaction"] == "BasicException"
+ assert transaction_event["sdk"]["name"] == "sentry.python.aws_lambda"
+ assert transaction_event["contexts"]["trace"]["status"] == "internal_error"
+
+ assert error_event["level"] == "error"
+ assert error_event["transaction"] == "BasicException"
+ assert error_event["sdk"]["name"] == "sentry.python.aws_lambda"
+ assert error_event["exception"]["values"][0]["type"] == "RuntimeError"
+ assert error_event["exception"]["values"][0]["value"] == "Oh!"
+ assert error_event["exception"]["values"][0]["mechanism"]["type"] == "aws_lambda"
+
+ if has_request_data:
+ request_data = {
+ "headers": {"Host": "x1.io", "X-Forwarded-Proto": "https"},
+ "method": "GET",
+ "url": "https://x1.io/1",
+ "query_string": {
+ "done": "f",
+ },
+ }
+ else:
+ request_data = {"url": "awslambda:///BasicException"}
+
+ assert error_event["request"] == request_data
+ assert transaction_event["request"] == request_data
+
+ if batch_size > 1:
+ assert error_event["tags"]["batch_size"] == batch_size
+ assert error_event["tags"]["batch_request"] is True
+ assert transaction_event["tags"]["batch_size"] == batch_size
+ assert transaction_event["tags"]["batch_request"] is True
+
+
+def test_request_data(lambda_client, test_environment):
+ payload = b"""
+ {
+ "resource": "/asd",
+ "path": "/asd",
+ "httpMethod": "GET",
+ "headers": {
+ "Host": "iwsz2c7uwi.execute-api.us-east-1.amazonaws.com",
+ "User-Agent": "custom",
+ "X-Forwarded-Proto": "https"
+ },
+ "queryStringParameters": {
+ "bonkers": "true"
+ },
+ "pathParameters": null,
+ "stageVariables": null,
+ "requestContext": {
+ "identity": {
+ "sourceIp": "213.47.147.207",
+ "userArn": "42"
+ }
+ },
+ "body": null,
+ "isBase64Encoded": false
+ }
+ """
+
+ lambda_client.invoke(
+ FunctionName="BasicOk",
+ Payload=payload,
+ )
+ envelopes = test_environment["server"].envelopes
+
+ (transaction_event,) = envelopes
+
+ assert transaction_event["request"] == {
+ "headers": {
+ "Host": "iwsz2c7uwi.execute-api.us-east-1.amazonaws.com",
+ "User-Agent": "custom",
+ "X-Forwarded-Proto": "https",
+ },
+ "method": "GET",
+ "query_string": {"bonkers": "true"},
+ "url": "https://iwsz2c7uwi.execute-api.us-east-1.amazonaws.com/asd",
+ }
+
+
+def test_trace_continuation(lambda_client, test_environment):
+ trace_id = "471a43a4192642f0b136d5159a501701"
+ parent_span_id = "6e8f22c393e68f19"
+ parent_sampled = 1
+ sentry_trace_header = "{}-{}-{}".format(trace_id, parent_span_id, parent_sampled)
+
+ # We simulate here AWS Api Gateway's behavior of passing HTTP headers
+ # as the `headers` dict in the event passed to the Lambda function.
+ payload = {
+ "headers": {
+ "sentry-trace": sentry_trace_header,
+ }
+ }
+
+ lambda_client.invoke(
+ FunctionName="BasicException",
+ Payload=json.dumps(payload),
+ )
+ envelopes = test_environment["server"].envelopes
+
+ (error_event, transaction_event) = envelopes
+
+ assert (
+ error_event["contexts"]["trace"]["trace_id"]
+ == transaction_event["contexts"]["trace"]["trace_id"]
+ == "471a43a4192642f0b136d5159a501701"
+ )
+
+
+@pytest.mark.parametrize(
+ "payload",
+ [
+ {},
+ {"headers": None},
+ {"headers": ""},
+ {"headers": {}},
+ {"headers": []}, # EventBridge sends an empty list
+ ],
+ ids=[
+ "no headers",
+ "none headers",
+ "empty string headers",
+ "empty dict headers",
+ "empty list headers",
+ ],
+)
+def test_headers(lambda_client, test_environment, payload):
+ lambda_client.invoke(
+ FunctionName="BasicException",
+ Payload=json.dumps(payload),
+ )
+ envelopes = test_environment["server"].envelopes
+
+ (error_event, _) = envelopes
+
+ assert error_event["level"] == "error"
+ assert error_event["exception"]["values"][0]["type"] == "RuntimeError"
+ assert error_event["exception"]["values"][0]["value"] == "Oh!"
+
+
+def test_span_origin(lambda_client, test_environment):
+ lambda_client.invoke(
+ FunctionName="BasicOk",
+ Payload=json.dumps({}),
+ )
+ envelopes = test_environment["server"].envelopes
+
+ (transaction_event,) = envelopes
+
+ assert (
+ transaction_event["contexts"]["trace"]["origin"] == "auto.function.aws_lambda"
+ )
+
+
+def test_traces_sampler_has_correct_sampling_context(lambda_client, test_environment):
+ """
+ Test that aws_event and aws_context are passed in the custom_sampling_context
+ when using the AWS Lambda integration.
+ """
+ test_payload = {"test_key": "test_value"}
+ response = lambda_client.invoke(
+ FunctionName="TracesSampler",
+ Payload=json.dumps(test_payload),
+ )
+ response_payload = json.loads(response["Payload"].read().decode())
+ sampling_context_data = json.loads(response_payload["body"])[
+ "sampling_context_data"
+ ]
+ assert sampling_context_data.get("aws_event_present") is True
+ assert sampling_context_data.get("aws_context_present") is True
+ assert sampling_context_data.get("event_data", {}).get("test_key") == "test_value"
+
+
+@pytest.mark.parametrize(
+ "lambda_function_name",
+ ["RaiseErrorPerformanceEnabled", "RaiseErrorPerformanceDisabled"],
+)
+def test_error_has_new_trace_context(
+ lambda_client, test_environment, lambda_function_name
+):
+ lambda_client.invoke(
+ FunctionName=lambda_function_name,
+ Payload=json.dumps({}),
+ )
+ envelopes = test_environment["server"].envelopes
+
+ if lambda_function_name == "RaiseErrorPerformanceEnabled":
+ (error_event, transaction_event) = envelopes
+ else:
+ (error_event,) = envelopes
+ transaction_event = None
+
+ assert "trace" in error_event["contexts"]
+ assert "trace_id" in error_event["contexts"]["trace"]
+
+ if transaction_event:
+ assert "trace" in transaction_event["contexts"]
+ assert "trace_id" in transaction_event["contexts"]["trace"]
+ assert (
+ error_event["contexts"]["trace"]["trace_id"]
+ == transaction_event["contexts"]["trace"]["trace_id"]
+ )
+
+
+@pytest.mark.parametrize(
+ "lambda_function_name",
+ ["RaiseErrorPerformanceEnabled", "RaiseErrorPerformanceDisabled"],
+)
+def test_error_has_existing_trace_context(
+ lambda_client, test_environment, lambda_function_name
+):
+ trace_id = "471a43a4192642f0b136d5159a501701"
+ parent_span_id = "6e8f22c393e68f19"
+ parent_sampled = 1
+ sentry_trace_header = "{}-{}-{}".format(trace_id, parent_span_id, parent_sampled)
+
+ # We simulate here AWS Api Gateway's behavior of passing HTTP headers
+ # as the `headers` dict in the event passed to the Lambda function.
+ payload = {
+ "headers": {
+ "sentry-trace": sentry_trace_header,
+ }
+ }
+
+ lambda_client.invoke(
+ FunctionName=lambda_function_name,
+ Payload=json.dumps(payload),
+ )
+ envelopes = test_environment["server"].envelopes
+
+ if lambda_function_name == "RaiseErrorPerformanceEnabled":
+ (error_event, transaction_event) = envelopes
+ else:
+ (error_event,) = envelopes
+ transaction_event = None
+
+ assert "trace" in error_event["contexts"]
+ assert "trace_id" in error_event["contexts"]["trace"]
+ assert (
+ error_event["contexts"]["trace"]["trace_id"]
+ == "471a43a4192642f0b136d5159a501701"
+ )
+
+ if transaction_event:
+ assert "trace" in transaction_event["contexts"]
+ assert "trace_id" in transaction_event["contexts"]["trace"]
+ assert (
+ transaction_event["contexts"]["trace"]["trace_id"]
+ == "471a43a4192642f0b136d5159a501701"
+ )
diff --git a/tests/integrations/aws_lambda/utils.py b/tests/integrations/aws_lambda/utils.py
new file mode 100644
index 0000000000..d20c9352e7
--- /dev/null
+++ b/tests/integrations/aws_lambda/utils.py
@@ -0,0 +1,294 @@
+import gzip
+import json
+import os
+import shutil
+import subprocess
+import requests
+import sys
+import time
+import threading
+import socket
+import platform
+
+from aws_cdk import (
+ CfnResource,
+ Stack,
+)
+from constructs import Construct
+from fastapi import FastAPI, Request
+import uvicorn
+
+from scripts.build_aws_lambda_layer import build_packaged_zip, DIST_PATH
+
+
+LAMBDA_FUNCTION_DIR = "./tests/integrations/aws_lambda/lambda_functions/"
+LAMBDA_FUNCTION_WITH_EMBEDDED_SDK_DIR = (
+ "./tests/integrations/aws_lambda/lambda_functions_with_embedded_sdk/"
+)
+LAMBDA_FUNCTION_TIMEOUT = 10
+SAM_PORT = 3001
+
+PYTHON_VERSION = f"python{sys.version_info.major}.{sys.version_info.minor}"
+
+
+def get_host_ip():
+ """
+ Returns the IP address of the host we are running on.
+ """
+ if os.environ.get("GITHUB_ACTIONS"):
+ # Running in GitHub Actions
+ hostname = socket.gethostname()
+ host = socket.gethostbyname(hostname)
+ else:
+ # Running locally
+ if platform.system() in ["Darwin", "Windows"]:
+ # Windows or MacOS
+ host = "host.docker.internal"
+ else:
+ # Linux
+ hostname = socket.gethostname()
+ host = socket.gethostbyname(hostname)
+
+ return host
+
+
+def get_project_root():
+ """
+ Returns the absolute path to the project root directory.
+ """
+ # Start from the current file's directory
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+
+ # Navigate up to the project root (4 levels up from tests/integrations/aws_lambda/)
+ # This is equivalent to the multiple dirname() calls
+ project_root = os.path.abspath(os.path.join(current_dir, "../../../"))
+
+ return project_root
+
+
+class LocalLambdaStack(Stack):
+ """
+ Uses the AWS CDK to create a local SAM stack containing Lambda functions.
+ """
+
+ def __init__(self, scope: Construct, construct_id: str, **kwargs) -> None:
+ print("[LocalLambdaStack] Creating local SAM Lambda Stack")
+ super().__init__(scope, construct_id, **kwargs)
+
+ # Override the template synthesis
+ self.template_options.template_format_version = "2010-09-09"
+ self.template_options.transforms = ["AWS::Serverless-2016-10-31"]
+
+ print("[LocalLambdaStack] Create Sentry Lambda layer package")
+ filename = "sentry-sdk-lambda-layer.zip"
+ build_packaged_zip(
+ make_dist=True,
+ out_zip_filename=filename,
+ )
+
+ print(
+ "[LocalLambdaStack] Add Sentry Lambda layer containing the Sentry SDK to the SAM stack"
+ )
+ self.sentry_layer = CfnResource(
+ self,
+ "SentryPythonServerlessSDK",
+ type="AWS::Serverless::LayerVersion",
+ properties={
+ "ContentUri": os.path.join(DIST_PATH, filename),
+ "CompatibleRuntimes": [
+ PYTHON_VERSION,
+ ],
+ },
+ )
+
+ dsn = f"http://123@{get_host_ip()}:9999/0" # noqa: E231
+ print("[LocalLambdaStack] Using Sentry DSN: %s" % dsn)
+
+ print(
+ "[LocalLambdaStack] Add all Lambda functions defined in "
+ "/tests/integrations/aws_lambda/lambda_functions/ to the SAM stack"
+ )
+ lambda_dirs = [
+ d
+ for d in os.listdir(LAMBDA_FUNCTION_DIR)
+ if os.path.isdir(os.path.join(LAMBDA_FUNCTION_DIR, d))
+ ]
+ for lambda_dir in lambda_dirs:
+ CfnResource(
+ self,
+ lambda_dir,
+ type="AWS::Serverless::Function",
+ properties={
+ "CodeUri": os.path.join(LAMBDA_FUNCTION_DIR, lambda_dir),
+ "Handler": "sentry_sdk.integrations.init_serverless_sdk.sentry_lambda_handler",
+ "Runtime": PYTHON_VERSION,
+ "Timeout": LAMBDA_FUNCTION_TIMEOUT,
+ "Layers": [
+ {"Ref": self.sentry_layer.logical_id}
+ ], # Add layer containing the Sentry SDK to function.
+ "Environment": {
+ "Variables": {
+ "SENTRY_DSN": dsn,
+ "SENTRY_INITIAL_HANDLER": "index.handler",
+ "SENTRY_TRACES_SAMPLE_RATE": "1.0",
+ }
+ },
+ },
+ )
+ print(
+ "[LocalLambdaStack] - Created Lambda function: %s (%s)"
+ % (
+ lambda_dir,
+ os.path.join(LAMBDA_FUNCTION_DIR, lambda_dir),
+ )
+ )
+
+ print(
+ "[LocalLambdaStack] Add all Lambda functions defined in "
+ "/tests/integrations/aws_lambda/lambda_functions_with_embedded_sdk/ to the SAM stack"
+ )
+ lambda_dirs = [
+ d
+ for d in os.listdir(LAMBDA_FUNCTION_WITH_EMBEDDED_SDK_DIR)
+ if os.path.isdir(os.path.join(LAMBDA_FUNCTION_WITH_EMBEDDED_SDK_DIR, d))
+ ]
+ for lambda_dir in lambda_dirs:
+ # Copy the Sentry SDK into the function directory
+ sdk_path = os.path.join(
+ LAMBDA_FUNCTION_WITH_EMBEDDED_SDK_DIR, lambda_dir, "sentry_sdk"
+ )
+ if not os.path.exists(sdk_path):
+ # Find the Sentry SDK in the current environment
+ import sentry_sdk as sdk_module
+
+ sdk_source = os.path.dirname(sdk_module.__file__)
+ shutil.copytree(sdk_source, sdk_path)
+
+ # Install the requirements of Sentry SDK into the function directory
+ requirements_file = os.path.join(
+ get_project_root(), "requirements-aws-lambda-layer.txt"
+ )
+
+ # Install the package using pip
+ subprocess.check_call(
+ [
+ sys.executable,
+ "-m",
+ "pip",
+ "install",
+ "--upgrade",
+ "--target",
+ os.path.join(LAMBDA_FUNCTION_WITH_EMBEDDED_SDK_DIR, lambda_dir),
+ "-r",
+ requirements_file,
+ ]
+ )
+
+ CfnResource(
+ self,
+ lambda_dir,
+ type="AWS::Serverless::Function",
+ properties={
+ "CodeUri": os.path.join(
+ LAMBDA_FUNCTION_WITH_EMBEDDED_SDK_DIR, lambda_dir
+ ),
+ "Handler": "index.handler",
+ "Runtime": PYTHON_VERSION,
+ "Timeout": LAMBDA_FUNCTION_TIMEOUT,
+ "Environment": {
+ "Variables": {
+ "SENTRY_DSN": dsn,
+ }
+ },
+ },
+ )
+ print(
+ "[LocalLambdaStack] - Created Lambda function: %s (%s)"
+ % (
+ lambda_dir,
+ os.path.join(LAMBDA_FUNCTION_DIR, lambda_dir),
+ )
+ )
+
+ @classmethod
+ def wait_for_stack(cls, timeout=60, port=SAM_PORT):
+ """
+ Wait for SAM to be ready, with timeout.
+ """
+ start_time = time.time()
+ while True:
+ if time.time() - start_time > timeout:
+ raise TimeoutError(
+ "AWS SAM failed to start within %s seconds. (Maybe Docker is not running?)"
+ % timeout
+ )
+
+ try:
+ # Try to connect to SAM
+ response = requests.get(f"http://127.0.0.1:{port}/") # noqa: E231
+ if response.status_code == 200 or response.status_code == 404:
+ return
+
+ except requests.exceptions.ConnectionError:
+ time.sleep(1)
+ continue
+
+
+class SentryServerForTesting:
+ """
+ A simple Sentry.io style server that accepts envelopes and stores them in a list.
+ """
+
+ def __init__(self, host="0.0.0.0", port=9999, log_level="warning"):
+ self.envelopes = []
+ self.host = host
+ self.port = port
+ self.log_level = log_level
+ self.app = FastAPI()
+
+ @self.app.post("/api/0/envelope/")
+ async def envelope(request: Request):
+ print("[SentryServerForTesting] Received envelope")
+ try:
+ raw_body = await request.body()
+ except Exception:
+ return {"status": "no body received"}
+
+ try:
+ body = gzip.decompress(raw_body).decode("utf-8")
+ except Exception:
+ # If decompression fails, assume it's plain text
+ body = raw_body.decode("utf-8")
+
+ lines = body.split("\n")
+
+ current_line = 1 # line 0 is envelope header
+ while current_line < len(lines):
+ # skip empty lines
+ if not lines[current_line].strip():
+ current_line += 1
+ continue
+
+ # skip envelope item header
+ current_line += 1
+
+ # add envelope item to store
+ envelope_item = lines[current_line]
+ if envelope_item.strip():
+ self.envelopes.append(json.loads(envelope_item))
+
+ return {"status": "ok"}
+
+ def run_server(self):
+ uvicorn.run(self.app, host=self.host, port=self.port, log_level=self.log_level)
+
+ def start(self):
+ print(
+ "[SentryServerForTesting] Starting server on %s:%s" % (self.host, self.port)
+ )
+ server_thread = threading.Thread(target=self.run_server, daemon=True)
+ server_thread.start()
+
+ def clear_envelopes(self):
+ print("[SentryServerForTesting] Clearing envelopes")
+ self.envelopes = []
diff --git a/tests/integrations/beam/__init__.py b/tests/integrations/beam/__init__.py
new file mode 100644
index 0000000000..f4fe442d63
--- /dev/null
+++ b/tests/integrations/beam/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("apache_beam")
diff --git a/tests/integrations/beam/test_beam.py b/tests/integrations/beam/test_beam.py
index 7aeb617e3c..809c4122e4 100644
--- a/tests/integrations/beam/test_beam.py
+++ b/tests/integrations/beam/test_beam.py
@@ -1,8 +1,6 @@
import pytest
import inspect
-pytest.importorskip("apache_beam")
-
import dill
from sentry_sdk.integrations.beam import (
@@ -14,9 +12,14 @@
from apache_beam.typehints.trivial_inference import instance_to_type
from apache_beam.typehints.decorators import getcallargs_forhints
from apache_beam.transforms.core import DoFn, ParDo, _DoFnParam, CallableWrapperDoFn
-from apache_beam.runners.common import DoFnInvoker, OutputProcessor, DoFnContext
+from apache_beam.runners.common import DoFnInvoker, DoFnContext
from apache_beam.utils.windowed_value import WindowedValue
+try:
+ from apache_beam.runners.common import OutputHandler
+except ImportError:
+ from apache_beam.runners.common import OutputProcessor as OutputHandler
+
def foo():
return True
@@ -42,7 +45,7 @@ def process(self):
return self.fn()
-class B(A, object):
+class B(A):
def fa(self, x, element=False, another_element=False):
if x or (element and not another_element):
# print(self.r)
@@ -52,7 +55,7 @@ def fa(self, x, element=False, another_element=False):
def __init__(self):
self.r = "We are in B"
- super(B, self).__init__(self.fa)
+ super().__init__(self.fa)
class SimpleFunc(DoFn):
@@ -141,19 +144,26 @@ def test_monkey_patch_signature(f, args, kwargs):
try:
expected_signature = inspect.signature(f)
test_signature = inspect.signature(f_temp)
- assert (
- expected_signature == test_signature
- ), "Failed on {}, signature {} does not match {}".format(
- f, expected_signature, test_signature
+ assert expected_signature == test_signature, (
+ "Failed on {}, signature {} does not match {}".format(
+ f, expected_signature, test_signature
+ )
)
except Exception:
# expected to pass for py2.7
pass
-class _OutputProcessor(OutputProcessor):
+class _OutputHandler(OutputHandler):
def process_outputs(
self, windowed_input_element, results, watermark_estimator=None
+ ):
+ self.handle_process_outputs(
+ windowed_input_element, results, watermark_estimator
+ )
+
+ def handle_process_outputs(
+ self, windowed_input_element, results, watermark_estimator=None
):
print(windowed_input_element)
try:
@@ -170,9 +180,13 @@ def inner(fn):
# Little hack to avoid having to run the whole pipeline.
pardo = ParDo(fn)
signature = pardo._signature
- output_processor = _OutputProcessor()
+ output_processor = _OutputHandler()
return DoFnInvoker.create_invoker(
- signature, output_processor, DoFnContext("test")
+ signature,
+ output_processor,
+ DoFnContext("test"),
+ input_args=[],
+ input_kwargs={},
)
return inner
diff --git a/tests/integrations/boto3/aws_mock.py b/tests/integrations/boto3/aws_mock.py
index 84ff23f466..da97570e4c 100644
--- a/tests/integrations/boto3/aws_mock.py
+++ b/tests/integrations/boto3/aws_mock.py
@@ -10,7 +10,7 @@ def stream(self, **kwargs):
contents = self.read()
-class MockResponse(object):
+class MockResponse:
def __init__(self, client, status_code, headers, body):
self._client = client
self._status_code = status_code
diff --git a/tests/integrations/boto3/test_s3.py b/tests/integrations/boto3/test_s3.py
index 7f02d422a0..97a1543b0f 100644
--- a/tests/integrations/boto3/test_s3.py
+++ b/tests/integrations/boto3/test_s3.py
@@ -1,9 +1,14 @@
-from sentry_sdk import Hub
+from unittest import mock
+
+import boto3
+import pytest
+
+import sentry_sdk
from sentry_sdk.integrations.boto3 import Boto3Integration
-from tests.integrations.boto3.aws_mock import MockResponse
+from tests.conftest import ApproxDict
from tests.integrations.boto3 import read_fixture
+from tests.integrations.boto3.aws_mock import MockResponse
-import boto3
session = boto3.Session(
aws_access_key_id="-",
@@ -16,7 +21,7 @@ def test_basic(sentry_init, capture_events):
events = capture_events()
s3 = session.resource("s3")
- with Hub.current.start_transaction() as transaction, MockResponse(
+ with sentry_sdk.start_transaction() as transaction, MockResponse(
s3.meta.client, 200, {}, read_fixture("s3_list.xml")
):
bucket = s3.Bucket("bucket")
@@ -39,7 +44,7 @@ def test_streaming(sentry_init, capture_events):
events = capture_events()
s3 = session.resource("s3")
- with Hub.current.start_transaction() as transaction, MockResponse(
+ with sentry_sdk.start_transaction() as transaction, MockResponse(
s3.meta.client, 200, {}, b"hello"
):
obj = s3.Bucket("bucket").Object("foo.pdf")
@@ -53,9 +58,19 @@ def test_streaming(sentry_init, capture_events):
(event,) = events
assert event["type"] == "transaction"
assert len(event["spans"]) == 2
+
span1 = event["spans"][0]
assert span1["op"] == "http.client"
assert span1["description"] == "aws.s3.GetObject"
+ assert span1["data"] == ApproxDict(
+ {
+ "http.method": "GET",
+ "aws.request.url": "https://bucket.s3.amazonaws.com/foo.pdf",
+ "http.fragment": "",
+ "http.query": "",
+ }
+ )
+
span2 = event["spans"][1]
assert span2["op"] == "http.client.stream"
assert span2["description"] == "aws.s3.GetObject"
@@ -67,7 +82,7 @@ def test_streaming_close(sentry_init, capture_events):
events = capture_events()
s3 = session.resource("s3")
- with Hub.current.start_transaction() as transaction, MockResponse(
+ with sentry_sdk.start_transaction() as transaction, MockResponse(
s3.meta.client, 200, {}, b"hello"
):
obj = s3.Bucket("bucket").Object("foo.pdf")
@@ -83,3 +98,54 @@ def test_streaming_close(sentry_init, capture_events):
assert span1["op"] == "http.client"
span2 = event["spans"][1]
assert span2["op"] == "http.client.stream"
+
+
+@pytest.mark.tests_internal_exceptions
+def test_omit_url_data_if_parsing_fails(sentry_init, capture_events):
+ sentry_init(traces_sample_rate=1.0, integrations=[Boto3Integration()])
+ events = capture_events()
+
+ s3 = session.resource("s3")
+
+ with mock.patch(
+ "sentry_sdk.integrations.boto3.parse_url",
+ side_effect=ValueError,
+ ):
+ with sentry_sdk.start_transaction() as transaction, MockResponse(
+ s3.meta.client, 200, {}, read_fixture("s3_list.xml")
+ ):
+ bucket = s3.Bucket("bucket")
+ items = [obj for obj in bucket.objects.all()]
+ assert len(items) == 2
+ assert items[0].key == "foo.txt"
+ assert items[1].key == "bar.txt"
+ transaction.finish()
+
+ (event,) = events
+ assert event["spans"][0]["data"] == ApproxDict(
+ {
+ "http.method": "GET",
+ # no url data
+ }
+ )
+
+ assert "aws.request.url" not in event["spans"][0]["data"]
+ assert "http.fragment" not in event["spans"][0]["data"]
+ assert "http.query" not in event["spans"][0]["data"]
+
+
+def test_span_origin(sentry_init, capture_events):
+ sentry_init(traces_sample_rate=1.0, integrations=[Boto3Integration()])
+ events = capture_events()
+
+ s3 = session.resource("s3")
+ with sentry_sdk.start_transaction(), MockResponse(
+ s3.meta.client, 200, {}, read_fixture("s3_list.xml")
+ ):
+ bucket = s3.Bucket("bucket")
+ _ = [obj for obj in bucket.objects.all()]
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+ assert event["spans"][0]["origin"] == "auto.http.boto3"
diff --git a/tests/integrations/bottle/__init__.py b/tests/integrations/bottle/__init__.py
new file mode 100644
index 0000000000..39015ee6f2
--- /dev/null
+++ b/tests/integrations/bottle/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("bottle")
diff --git a/tests/integrations/bottle/test_bottle.py b/tests/integrations/bottle/test_bottle.py
index dfd6e52f80..1965691d6c 100644
--- a/tests/integrations/bottle/test_bottle.py
+++ b/tests/integrations/bottle/test_bottle.py
@@ -2,17 +2,16 @@
import pytest
import logging
-
-pytest.importorskip("bottle")
-
from io import BytesIO
-from bottle import Bottle, debug as set_debug, abort, redirect
+from bottle import Bottle, debug as set_debug, abort, redirect, HTTPResponse
from sentry_sdk import capture_message
+from sentry_sdk.consts import DEFAULT_MAX_VALUE_LENGTH
+from sentry_sdk.integrations.bottle import BottleIntegration
+from sentry_sdk.serializer import MAX_DATABAG_BREADTH
from sentry_sdk.integrations.logging import LoggingIntegration
from werkzeug.test import Client
-
-import sentry_sdk.integrations.bottle as bottle_sentry
+from werkzeug.wrappers import Response
@pytest.fixture(scope="function")
@@ -46,7 +45,7 @@ def inner():
def test_has_context(sentry_init, app, capture_events, get_client):
- sentry_init(integrations=[bottle_sentry.BottleIntegration()])
+ sentry_init(integrations=[BottleIntegration()])
events = capture_events()
client = get_client()
@@ -77,11 +76,7 @@ def test_transaction_style(
capture_events,
get_client,
):
- sentry_init(
- integrations=[
- bottle_sentry.BottleIntegration(transaction_style=transaction_style)
- ]
- )
+ sentry_init(integrations=[BottleIntegration(transaction_style=transaction_style)])
events = capture_events()
client = get_client()
@@ -100,7 +95,7 @@ def test_transaction_style(
def test_errors(
sentry_init, capture_exceptions, capture_events, app, debug, catchall, get_client
):
- sentry_init(integrations=[bottle_sentry.BottleIntegration()])
+ sentry_init(integrations=[BottleIntegration()])
app.catchall = catchall
set_debug(mode=debug)
@@ -127,9 +122,9 @@ def index():
def test_large_json_request(sentry_init, capture_events, app, get_client):
- sentry_init(integrations=[bottle_sentry.BottleIntegration()])
+ sentry_init(integrations=[BottleIntegration()], max_request_body_size="always")
- data = {"foo": {"bar": "a" * 2000}}
+ data = {"foo": {"bar": "a" * (DEFAULT_MAX_VALUE_LENGTH + 10)}}
@app.route("/", method="POST")
def index():
@@ -150,14 +145,19 @@ def index():
(event,) = events
assert event["_meta"]["request"]["data"]["foo"]["bar"] == {
- "": {"len": 2000, "rem": [["!limit", "x", 1021, 1024]]}
+ "": {
+ "len": DEFAULT_MAX_VALUE_LENGTH + 10,
+ "rem": [
+ ["!limit", "x", DEFAULT_MAX_VALUE_LENGTH - 3, DEFAULT_MAX_VALUE_LENGTH]
+ ],
+ }
}
- assert len(event["request"]["data"]["foo"]["bar"]) == 1024
+ assert len(event["request"]["data"]["foo"]["bar"]) == DEFAULT_MAX_VALUE_LENGTH
@pytest.mark.parametrize("data", [{}, []], ids=["empty-dict", "empty-list"])
def test_empty_json_request(sentry_init, capture_events, app, data, get_client):
- sentry_init(integrations=[bottle_sentry.BottleIntegration()])
+ sentry_init(integrations=[BottleIntegration()])
@app.route("/", method="POST")
def index():
@@ -180,9 +180,9 @@ def index():
def test_medium_formdata_request(sentry_init, capture_events, app, get_client):
- sentry_init(integrations=[bottle_sentry.BottleIntegration()])
+ sentry_init(integrations=[BottleIntegration()], max_request_body_size="always")
- data = {"foo": "a" * 2000}
+ data = {"foo": "a" * (DEFAULT_MAX_VALUE_LENGTH + 10)}
@app.route("/", method="POST")
def index():
@@ -200,18 +200,21 @@ def index():
(event,) = events
assert event["_meta"]["request"]["data"]["foo"] == {
- "": {"len": 2000, "rem": [["!limit", "x", 1021, 1024]]}
+ "": {
+ "len": DEFAULT_MAX_VALUE_LENGTH + 10,
+ "rem": [
+ ["!limit", "x", DEFAULT_MAX_VALUE_LENGTH - 3, DEFAULT_MAX_VALUE_LENGTH]
+ ],
+ }
}
- assert len(event["request"]["data"]["foo"]) == 1024
+ assert len(event["request"]["data"]["foo"]) == DEFAULT_MAX_VALUE_LENGTH
@pytest.mark.parametrize("input_char", ["a", b"a"])
def test_too_large_raw_request(
sentry_init, input_char, capture_events, app, get_client
):
- sentry_init(
- integrations=[bottle_sentry.BottleIntegration()], request_bodies="small"
- )
+ sentry_init(integrations=[BottleIntegration()], max_request_body_size="small")
data = input_char * 2000
@@ -239,11 +242,12 @@ def index():
def test_files_and_form(sentry_init, capture_events, app, get_client):
- sentry_init(
- integrations=[bottle_sentry.BottleIntegration()], request_bodies="always"
- )
+ sentry_init(integrations=[BottleIntegration()], max_request_body_size="always")
- data = {"foo": "a" * 2000, "file": (BytesIO(b"hello"), "hello.txt")}
+ data = {
+ "foo": "a" * (DEFAULT_MAX_VALUE_LENGTH + 10),
+ "file": (BytesIO(b"hello"), "hello.txt"),
+ }
@app.route("/", method="POST")
def index():
@@ -263,9 +267,14 @@ def index():
(event,) = events
assert event["_meta"]["request"]["data"]["foo"] == {
- "": {"len": 2000, "rem": [["!limit", "x", 1021, 1024]]}
+ "": {
+ "len": DEFAULT_MAX_VALUE_LENGTH + 10,
+ "rem": [
+ ["!limit", "x", DEFAULT_MAX_VALUE_LENGTH - 3, DEFAULT_MAX_VALUE_LENGTH]
+ ],
+ }
}
- assert len(event["request"]["data"]["foo"]) == 1024
+ assert len(event["request"]["data"]["foo"]) == DEFAULT_MAX_VALUE_LENGTH
assert event["_meta"]["request"]["data"]["file"] == {
"": {
@@ -275,11 +284,40 @@ def index():
assert not event["request"]["data"]["file"]
+def test_json_not_truncated_if_max_request_body_size_is_always(
+ sentry_init, capture_events, app, get_client
+):
+ sentry_init(integrations=[BottleIntegration()], max_request_body_size="always")
+
+ data = {
+ "key{}".format(i): "value{}".format(i) for i in range(MAX_DATABAG_BREADTH + 10)
+ }
+
+ @app.route("/", method="POST")
+ def index():
+ import bottle
+
+ assert bottle.request.json == data
+ assert bottle.request.body.read() == json.dumps(data).encode("ascii")
+ capture_message("hi")
+ return "ok"
+
+ events = capture_events()
+
+ client = get_client()
+
+ response = client.post("/", content_type="application/json", data=json.dumps(data))
+ assert response[1] == "200 OK"
+
+ (event,) = events
+ assert event["request"]["data"] == data
+
+
@pytest.mark.parametrize(
"integrations",
[
- [bottle_sentry.BottleIntegration()],
- [bottle_sentry.BottleIntegration(), LoggingIntegration(event_level="ERROR")],
+ [BottleIntegration()],
+ [BottleIntegration(), LoggingIntegration(event_level="ERROR")],
],
)
def test_errors_not_reported_twice(
@@ -293,46 +331,24 @@ def test_errors_not_reported_twice(
@app.route("/")
def index():
- try:
- 1 / 0
- except Exception as e:
- logger.exception(e)
- raise e
+ 1 / 0
events = capture_events()
client = get_client()
+
with pytest.raises(ZeroDivisionError):
- client.get("/")
+ try:
+ client.get("/")
+ except ZeroDivisionError as e:
+ logger.exception(e)
+ raise e
assert len(events) == 1
-def test_logging(sentry_init, capture_events, app, get_client):
- # ensure that Bottle's logger magic doesn't break ours
- sentry_init(
- integrations=[
- bottle_sentry.BottleIntegration(),
- LoggingIntegration(event_level="ERROR"),
- ]
- )
-
- @app.route("/")
- def index():
- app.logger.error("hi")
- return "ok"
-
- events = capture_events()
-
- client = get_client()
- client.get("/")
-
- (event,) = events
- assert event["level"] == "error"
-
-
def test_mount(app, capture_exceptions, capture_events, sentry_init, get_client):
- sentry_init(integrations=[bottle_sentry.BottleIntegration()])
+ sentry_init(integrations=[BottleIntegration()])
app.catchall = False
@@ -354,39 +370,12 @@ def crashing_app(environ, start_response):
assert error is exc.value
(event,) = events
- assert event["exception"]["values"][0]["mechanism"] == {
- "type": "bottle",
- "handled": False,
- }
-
-
-def test_500(sentry_init, capture_events, app, get_client):
- sentry_init(integrations=[bottle_sentry.BottleIntegration()])
-
- set_debug(False)
- app.catchall = True
-
- @app.route("/")
- def index():
- 1 / 0
-
- @app.error(500)
- def error_handler(err):
- capture_message("error_msg")
- return "My error"
-
- events = capture_events()
-
- client = get_client()
- response = client.get("/")
- assert response[1] == "500 Internal Server Error"
-
- _, event = events
- assert event["message"] == "error_msg"
+ assert event["exception"]["values"][0]["mechanism"]["type"] == "bottle"
+ assert event["exception"]["values"][0]["mechanism"]["handled"] is False
def test_error_in_errorhandler(sentry_init, capture_events, app, get_client):
- sentry_init(integrations=[bottle_sentry.BottleIntegration()])
+ sentry_init(integrations=[BottleIntegration()])
set_debug(False)
app.catchall = True
@@ -416,7 +405,7 @@ def error_handler(err):
def test_bad_request_not_captured(sentry_init, capture_events, app, get_client):
- sentry_init(integrations=[bottle_sentry.BottleIntegration()])
+ sentry_init(integrations=[BottleIntegration()])
events = capture_events()
@app.route("/")
@@ -431,7 +420,7 @@ def index():
def test_no_exception_on_redirect(sentry_init, capture_events, app, get_client):
- sentry_init(integrations=[bottle_sentry.BottleIntegration()])
+ sentry_init(integrations=[BottleIntegration()])
events = capture_events()
@app.route("/")
@@ -447,3 +436,99 @@ def here():
client.get("/")
assert not events
+
+
+def test_span_origin(
+ sentry_init,
+ get_client,
+ capture_events,
+):
+ sentry_init(
+ integrations=[BottleIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client = get_client()
+ client.get("/message")
+
+ (_, event) = events
+
+ assert event["contexts"]["trace"]["origin"] == "auto.http.bottle"
+
+
+@pytest.mark.parametrize("raise_error", [True, False])
+@pytest.mark.parametrize(
+ ("integration_kwargs", "status_code", "should_capture"),
+ (
+ ({}, None, False),
+ ({}, 400, False),
+ ({}, 451, False), # Highest 4xx status code
+ ({}, 500, True),
+ ({}, 511, True), # Highest 5xx status code
+ ({"failed_request_status_codes": set()}, 500, False),
+ ({"failed_request_status_codes": set()}, 511, False),
+ ({"failed_request_status_codes": {404, *range(500, 600)}}, 404, True),
+ ({"failed_request_status_codes": {404, *range(500, 600)}}, 500, True),
+ ({"failed_request_status_codes": {404, *range(500, 600)}}, 400, False),
+ ),
+)
+def test_failed_request_status_codes(
+ sentry_init,
+ capture_events,
+ integration_kwargs,
+ status_code,
+ should_capture,
+ raise_error,
+):
+ sentry_init(integrations=[BottleIntegration(**integration_kwargs)])
+ events = capture_events()
+
+ app = Bottle()
+
+ @app.route("/")
+ def handle():
+ if status_code is not None:
+ response = HTTPResponse(status=status_code)
+ if raise_error:
+ raise response
+ else:
+ return response
+ return "OK"
+
+ client = Client(app, Response)
+ response = client.get("/")
+
+ expected_status = 200 if status_code is None else status_code
+ assert response.status_code == expected_status
+
+ if should_capture:
+ (event,) = events
+ assert event["exception"]["values"][0]["type"] == "HTTPResponse"
+ else:
+ assert not events
+
+
+def test_failed_request_status_codes_non_http_exception(sentry_init, capture_events):
+ """
+ If an exception, which is not an instance of HTTPResponse, is raised, it should be captured, even if
+ failed_request_status_codes is empty.
+ """
+ sentry_init(integrations=[BottleIntegration(failed_request_status_codes=set())])
+ events = capture_events()
+
+ app = Bottle()
+
+ @app.route("/")
+ def handle():
+ 1 / 0
+
+ client = Client(app, Response)
+
+ try:
+ client.get("/")
+ except ZeroDivisionError:
+ pass
+
+ (event,) = events
+ assert event["exception"]["values"][0]["type"] == "ZeroDivisionError"
diff --git a/tests/integrations/celery/__init__.py b/tests/integrations/celery/__init__.py
index e69de29bb2..e37dfbf00e 100644
--- a/tests/integrations/celery/__init__.py
+++ b/tests/integrations/celery/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("celery")
diff --git a/tests/integrations/celery/integration_tests/__init__.py b/tests/integrations/celery/integration_tests/__init__.py
new file mode 100644
index 0000000000..2dfe2ddcf7
--- /dev/null
+++ b/tests/integrations/celery/integration_tests/__init__.py
@@ -0,0 +1,58 @@
+import os
+import signal
+import tempfile
+import threading
+import time
+
+from celery.beat import Scheduler
+
+from sentry_sdk.utils import logger
+
+
+class ImmediateScheduler(Scheduler):
+ """
+ A custom scheduler that starts tasks immediately after starting Celery beat.
+ """
+
+ def setup_schedule(self):
+ super().setup_schedule()
+ for _, entry in self.schedule.items():
+ self.apply_entry(entry)
+
+ def tick(self):
+ # Override tick to prevent the normal schedule cycle
+ return 1
+
+
+def kill_beat(beat_pid_file, delay_seconds=1):
+ """
+ Terminates Celery Beat after the given `delay_seconds`.
+ """
+ logger.info("Starting Celery Beat killer...")
+ time.sleep(delay_seconds)
+ pid = int(open(beat_pid_file, "r").read())
+ logger.info("Terminating Celery Beat...")
+ os.kill(pid, signal.SIGTERM)
+
+
+def run_beat(celery_app, runtime_seconds=1, loglevel="warning", quiet=True):
+ """
+ Run Celery Beat that immediately starts tasks.
+ The Celery Beat instance is automatically terminated after `runtime_seconds`.
+ """
+ logger.info("Starting Celery Beat...")
+ pid_file = os.path.join(tempfile.mkdtemp(), f"celery-beat-{os.getpid()}.pid")
+
+ t = threading.Thread(
+ target=kill_beat,
+ args=(pid_file,),
+ kwargs={"delay_seconds": runtime_seconds},
+ )
+ t.start()
+
+ beat_instance = celery_app.Beat(
+ loglevel=loglevel,
+ quiet=quiet,
+ pidfile=pid_file,
+ )
+ beat_instance.run()
diff --git a/tests/integrations/celery/integration_tests/test_celery_beat_cron_monitoring.py b/tests/integrations/celery/integration_tests/test_celery_beat_cron_monitoring.py
new file mode 100644
index 0000000000..e7d8197439
--- /dev/null
+++ b/tests/integrations/celery/integration_tests/test_celery_beat_cron_monitoring.py
@@ -0,0 +1,157 @@
+import os
+import sys
+import pytest
+
+from celery.contrib.testing.worker import start_worker
+
+from sentry_sdk.utils import logger
+
+from tests.integrations.celery.integration_tests import run_beat
+
+
+REDIS_SERVER = "redis://127.0.0.1:6379"
+REDIS_DB = 15
+
+
+@pytest.fixture()
+def celery_config():
+ return {
+ "worker_concurrency": 1,
+ "broker_url": f"{REDIS_SERVER}/{REDIS_DB}",
+ "result_backend": f"{REDIS_SERVER}/{REDIS_DB}",
+ "beat_scheduler": "tests.integrations.celery.integration_tests:ImmediateScheduler",
+ "task_always_eager": False,
+ "task_create_missing_queues": True,
+ "task_default_queue": f"queue_{os.getpid()}",
+ }
+
+
+@pytest.fixture
+def celery_init(sentry_init, celery_config):
+ """
+ Create a Sentry instrumented Celery app.
+ """
+ from celery import Celery
+
+ from sentry_sdk.integrations.celery import CeleryIntegration
+
+ def inner(propagate_traces=True, monitor_beat_tasks=False, **kwargs):
+ sentry_init(
+ integrations=[
+ CeleryIntegration(
+ propagate_traces=propagate_traces,
+ monitor_beat_tasks=monitor_beat_tasks,
+ )
+ ],
+ **kwargs,
+ )
+ app = Celery("tasks")
+ app.conf.update(celery_config)
+
+ return app
+
+ return inner
+
+
+@pytest.mark.skipif(sys.version_info < (3, 7), reason="Requires Python 3.7+")
+@pytest.mark.forked
+def test_explanation(celery_init, capture_envelopes):
+ """
+ This is a dummy test for explaining how to test using Celery Beat
+ """
+
+ # First initialize a Celery app.
+ # You can give the options of CeleryIntegrations
+ # and the options for `sentry_dks.init` as keyword arguments.
+ # See the celery_init fixture for details.
+ app = celery_init(
+ monitor_beat_tasks=True,
+ )
+
+ # Capture envelopes.
+ envelopes = capture_envelopes()
+
+ # Define the task you want to run
+ @app.task
+ def test_task():
+ logger.info("Running test_task")
+
+ # Add the task to the beat schedule
+ app.add_periodic_task(60.0, test_task.s(), name="success_from_beat")
+
+ # Start a Celery worker
+ with start_worker(app, perform_ping_check=False):
+ # And start a Celery Beat instance
+ # This Celery Beat will start the task above immediately
+ # after start for the first time
+ # By default Celery Beat is terminated after 1 second.
+ # See `run_beat` function on how to change this.
+ run_beat(app)
+
+ # After the Celery Beat is terminated, you can check the envelopes
+ assert len(envelopes) >= 0
+
+
+@pytest.mark.skipif(sys.version_info < (3, 7), reason="Requires Python 3.7+")
+@pytest.mark.forked
+def test_beat_task_crons_success(celery_init, capture_envelopes):
+ app = celery_init(
+ monitor_beat_tasks=True,
+ )
+ envelopes = capture_envelopes()
+
+ @app.task
+ def test_task():
+ logger.info("Running test_task")
+
+ app.add_periodic_task(60.0, test_task.s(), name="success_from_beat")
+
+ with start_worker(app, perform_ping_check=False):
+ run_beat(app)
+
+ assert len(envelopes) == 2
+ (envelop_in_progress, envelope_ok) = envelopes
+
+ assert envelop_in_progress.items[0].headers["type"] == "check_in"
+ check_in = envelop_in_progress.items[0].payload.json
+ assert check_in["type"] == "check_in"
+ assert check_in["monitor_slug"] == "success_from_beat"
+ assert check_in["status"] == "in_progress"
+
+ assert envelope_ok.items[0].headers["type"] == "check_in"
+ check_in = envelope_ok.items[0].payload.json
+ assert check_in["type"] == "check_in"
+ assert check_in["monitor_slug"] == "success_from_beat"
+ assert check_in["status"] == "ok"
+
+
+@pytest.mark.skipif(sys.version_info < (3, 7), reason="Requires Python 3.7+")
+@pytest.mark.forked
+def test_beat_task_crons_error(celery_init, capture_envelopes):
+ app = celery_init(
+ monitor_beat_tasks=True,
+ )
+ envelopes = capture_envelopes()
+
+ @app.task
+ def test_task():
+ logger.info("Running test_task")
+ 1 / 0
+
+ app.add_periodic_task(60.0, test_task.s(), name="failure_from_beat")
+
+ with start_worker(app, perform_ping_check=False):
+ run_beat(app)
+
+ envelop_in_progress = envelopes[0]
+ envelope_error = envelopes[-1]
+
+ check_in = envelop_in_progress.items[0].payload.json
+ assert check_in["type"] == "check_in"
+ assert check_in["monitor_slug"] == "failure_from_beat"
+ assert check_in["status"] == "in_progress"
+
+ check_in = envelope_error.items[0].payload.json
+ assert check_in["type"] == "check_in"
+ assert check_in["monitor_slug"] == "failure_from_beat"
+ assert check_in["status"] == "error"
diff --git a/tests/integrations/celery/test_celery.py b/tests/integrations/celery/test_celery.py
index a2c8fa1594..5d2d19c06a 100644
--- a/tests/integrations/celery/test_celery.py
+++ b/tests/integrations/celery/test_celery.py
@@ -1,20 +1,19 @@
import threading
+import kombu
+from unittest import mock
import pytest
-
-pytest.importorskip("celery")
-
-from sentry_sdk import Hub, configure_scope, start_transaction
-from sentry_sdk.integrations.celery import CeleryIntegration
-from sentry_sdk._compat import text_type
-
from celery import Celery, VERSION
from celery.bin import worker
-try:
- from unittest import mock # python 3.3 and above
-except ImportError:
- import mock # python < 3.3
+import sentry_sdk
+from sentry_sdk import start_transaction, get_current_span
+from sentry_sdk.integrations.celery import (
+ CeleryIntegration,
+ _wrap_task_run,
+)
+from sentry_sdk.integrations.celery.beat import _get_headers
+from tests.conftest import ApproxDict
@pytest.fixture
@@ -28,10 +27,20 @@ def inner(signal, f):
@pytest.fixture
def init_celery(sentry_init, request):
- def inner(propagate_traces=True, backend="always_eager", **kwargs):
+ def inner(
+ propagate_traces=True,
+ backend="always_eager",
+ monitor_beat_tasks=False,
+ **kwargs,
+ ):
sentry_init(
- integrations=[CeleryIntegration(propagate_traces=propagate_traces)],
- **kwargs
+ integrations=[
+ CeleryIntegration(
+ propagate_traces=propagate_traces,
+ monitor_beat_tasks=monitor_beat_tasks,
+ )
+ ],
+ **kwargs,
)
celery = Celery(__name__)
@@ -52,9 +61,6 @@ def inner(propagate_traces=True, backend="always_eager", **kwargs):
celery.conf.result_backend = "redis://127.0.0.1:6379"
celery.conf.task_always_eager = False
- Hub.main.bind_client(Hub.current.client)
- request.addfinalizer(lambda: Hub.main.bind_client(None))
-
# Once we drop celery 3 we can use the celery_worker fixture
if VERSION < (5,):
worker_fn = worker.worker(app=celery).run
@@ -84,8 +90,14 @@ def celery(init_celery):
@pytest.fixture(
params=[
- lambda task, x, y: (task.delay(x, y), {"args": [x, y], "kwargs": {}}),
- lambda task, x, y: (task.apply_async((x, y)), {"args": [x, y], "kwargs": {}}),
+ lambda task, x, y: (
+ task.delay(x, y),
+ {"args": [x, y], "kwargs": {}},
+ ),
+ lambda task, x, y: (
+ task.apply_async((x, y)),
+ {"args": [x, y], "kwargs": {}},
+ ),
lambda task, x, y: (
task.apply_async(args=(x, y)),
{"args": [x, y], "kwargs": {}},
@@ -105,7 +117,8 @@ def celery_invocation(request):
return request.param
-def test_simple(capture_events, celery, celery_invocation):
+def test_simple_with_performance(capture_events, init_celery, celery_invocation):
+ celery = init_celery(traces_sample_rate=1.0)
events = capture_events()
@celery.task(name="dummy_task")
@@ -113,21 +126,57 @@ def dummy_task(x, y):
foo = 42 # noqa
return x / y
- with start_transaction() as transaction:
+ with start_transaction(op="unit test transaction") as transaction:
celery_invocation(dummy_task, 1, 2)
_, expected_context = celery_invocation(dummy_task, 1, 0)
- (event,) = events
+ (_, error_event, _, _) = events
- assert event["contexts"]["trace"]["trace_id"] == transaction.trace_id
- assert event["contexts"]["trace"]["span_id"] != transaction.span_id
- assert event["transaction"] == "dummy_task"
- assert "celery_task_id" in event["tags"]
- assert event["extra"]["celery-job"] == dict(
+ assert error_event["contexts"]["trace"]["trace_id"] == transaction.trace_id
+ assert error_event["contexts"]["trace"]["span_id"] != transaction.span_id
+ assert error_event["transaction"] == "dummy_task"
+ assert "celery_task_id" in error_event["tags"]
+ assert error_event["extra"]["celery-job"] == dict(
task_name="dummy_task", **expected_context
)
- (exception,) = event["exception"]["values"]
+ (exception,) = error_event["exception"]["values"]
+ assert exception["type"] == "ZeroDivisionError"
+ assert exception["mechanism"]["type"] == "celery"
+ assert exception["stacktrace"]["frames"][0]["vars"]["foo"] == "42"
+
+
+def test_simple_without_performance(capture_events, init_celery, celery_invocation):
+ celery = init_celery(traces_sample_rate=None)
+ events = capture_events()
+
+ @celery.task(name="dummy_task")
+ def dummy_task(x, y):
+ foo = 42 # noqa
+ return x / y
+
+ scope = sentry_sdk.get_isolation_scope()
+
+ celery_invocation(dummy_task, 1, 2)
+ _, expected_context = celery_invocation(dummy_task, 1, 0)
+
+ (error_event,) = events
+
+ assert (
+ error_event["contexts"]["trace"]["trace_id"]
+ == scope._propagation_context.trace_id
+ )
+ assert (
+ error_event["contexts"]["trace"]["span_id"]
+ != scope._propagation_context.span_id
+ )
+ assert error_event["transaction"] == "dummy_task"
+ assert "celery_task_id" in error_event["tags"]
+ assert error_event["extra"]["celery-job"] == dict(
+ task_name="dummy_task", **expected_context
+ )
+
+ (exception,) = error_event["exception"]["values"]
assert exception["type"] == "ZeroDivisionError"
assert exception["mechanism"]["type"] == "celery"
assert exception["stacktrace"]["frames"][0]["vars"]["foo"] == "42"
@@ -170,44 +219,61 @@ def dummy_task(x, y):
else:
assert execution_event["contexts"]["trace"]["status"] == "ok"
- assert execution_event["spans"] == []
+ assert len(execution_event["spans"]) == 1
+ assert (
+ execution_event["spans"][0].items()
+ >= {
+ "trace_id": str(transaction.trace_id),
+ "same_process_as_parent": True,
+ "op": "queue.process",
+ "description": "dummy_task",
+ "data": ApproxDict(),
+ }.items()
+ )
assert submission_event["spans"] == [
{
+ "data": ApproxDict(),
"description": "dummy_task",
"op": "queue.submit.celery",
+ "origin": "auto.queue.celery",
"parent_span_id": submission_event["contexts"]["trace"]["span_id"],
"same_process_as_parent": True,
"span_id": submission_event["spans"][0]["span_id"],
"start_timestamp": submission_event["spans"][0]["start_timestamp"],
"timestamp": submission_event["spans"][0]["timestamp"],
- "trace_id": text_type(transaction.trace_id),
+ "trace_id": str(transaction.trace_id),
}
]
-def test_no_stackoverflows(celery):
- """We used to have a bug in the Celery integration where its monkeypatching
+def test_no_double_patching(celery):
+ """Ensure that Celery tasks are only patched once to prevent stack overflows.
+
+ We used to have a bug in the Celery integration where its monkeypatching
was repeated for every task invocation, leading to stackoverflows.
See https://github.com/getsentry/sentry-python/issues/265
"""
- results = []
-
@celery.task(name="dummy_task")
def dummy_task():
- with configure_scope() as scope:
- scope.set_tag("foo", "bar")
+ return 42
- results.append(42)
+ # Initially, the task should not be marked as patched
+ assert not hasattr(dummy_task, "_sentry_is_patched")
- for _ in range(10000):
- dummy_task.delay()
+ # First invocation should trigger patching
+ result1 = dummy_task.delay()
+ assert result1.get() == 42
+ assert getattr(dummy_task, "_sentry_is_patched", False) is True
- assert results == [42] * 10000
+ patched_run = dummy_task.run
- with configure_scope() as scope:
- assert not scope._tags
+ # Second invocation should not re-patch
+ result2 = dummy_task.delay()
+ assert result2.get() == 42
+ assert dummy_task.run is patched_run
+ assert getattr(dummy_task, "_sentry_is_patched", False) is True
def test_simple_no_propagation(capture_events, init_celery):
@@ -240,42 +306,6 @@ def dummy_task(x, y):
assert not events
-def test_broken_prerun(init_celery, connect_signal):
- from celery.signals import task_prerun
-
- stack_lengths = []
-
- def crash(*args, **kwargs):
- # scope should exist in prerun
- stack_lengths.append(len(Hub.current._stack))
- 1 / 0
-
- # Order here is important to reproduce the bug: In Celery 3, a crashing
- # prerun would prevent other preruns from running.
-
- connect_signal(task_prerun, crash)
- celery = init_celery()
-
- assert len(Hub.current._stack) == 1
-
- @celery.task(name="dummy_task")
- def dummy_task(x, y):
- stack_lengths.append(len(Hub.current._stack))
- return x / y
-
- if VERSION >= (4,):
- dummy_task.delay(2, 2)
- else:
- with pytest.raises(ZeroDivisionError):
- dummy_task.delay(2, 2)
-
- assert len(Hub.current._stack) == 1
- if VERSION < (4,):
- assert stack_lengths == [2]
- else:
- assert stack_lengths == [2, 2]
-
-
@pytest.mark.xfail(
(4, 2, 0) <= VERSION < (4, 4, 3),
strict=True,
@@ -313,11 +343,12 @@ def dummy_task(self):
assert e["type"] == "ZeroDivisionError"
-# TODO: This test is hanging when running test with `tox --parallel auto`. Find out why and fix it!
-@pytest.mark.skip
+@pytest.mark.skip(
+ reason="This test is hanging when running test with `tox --parallel auto`. TODO: Figure out why and fix it!"
+)
@pytest.mark.forked
-def test_redis_backend_trace_propagation(init_celery, capture_events_forksafe, tmpdir):
- celery = init_celery(traces_sample_rate=1.0, backend="redis", debug=True)
+def test_redis_backend_trace_propagation(init_celery, capture_events_forksafe):
+ celery = init_celery(traces_sample_rate=1.0, backend="redis")
events = capture_events_forksafe()
@@ -332,7 +363,7 @@ def dummy_task(self):
# Curious: Cannot use delay() here or py2.7-celery-4.2 crashes
res = dummy_task.apply_async()
- with pytest.raises(Exception):
+ with pytest.raises(Exception): # noqa: B017
# Celery 4.1 raises a gibberish exception
res.wait()
@@ -343,9 +374,9 @@ def dummy_task(self):
assert submit_transaction["type"] == "transaction"
assert submit_transaction["transaction"] == "submit_celery"
- assert len(
- submit_transaction["spans"]
- ), 4 # Because redis integration was auto enabled
+ assert len(submit_transaction["spans"]), (
+ 4
+ ) # Because redis integration was auto enabled
span = submit_transaction["spans"][0]
assert span["op"] == "queue.submit.celery"
assert span["description"] == "dummy_task"
@@ -371,11 +402,24 @@ def dummy_task(self):
@pytest.mark.parametrize("newrelic_order", ["sentry_first", "sentry_last"])
def test_newrelic_interference(init_celery, newrelic_order, celery_invocation):
def instrument_newrelic():
- import celery.app.trace as celery_mod
- from newrelic.hooks.application_celery import instrument_celery_execute_trace
+ try:
+ # older newrelic versions
+ from newrelic.hooks.application_celery import (
+ instrument_celery_execute_trace,
+ )
+ import celery.app.trace as celery_trace_module
+
+ assert hasattr(celery_trace_module, "build_tracer")
+ instrument_celery_execute_trace(celery_trace_module)
- assert hasattr(celery_mod, "build_tracer")
- instrument_celery_execute_trace(celery_mod)
+ except ImportError:
+ # newer newrelic versions
+ from newrelic.hooks.application_celery import instrument_celery_app_base
+ import celery.app as celery_app_module
+
+ assert hasattr(celery_app_module, "Celery")
+ assert hasattr(celery_app_module.Celery, "send_task")
+ instrument_celery_app_base(celery_app_module)
if newrelic_order == "sentry_first":
celery = init_celery()
@@ -395,7 +439,9 @@ def dummy_task(self, x, y):
def test_traces_sampler_gets_task_info_in_sampling_context(
- init_celery, celery_invocation, DictionaryContaining # noqa:N803
+ init_celery,
+ celery_invocation,
+ DictionaryContaining, # noqa:N803
):
traces_sampler = mock.Mock()
celery = init_celery(traces_sampler=traces_sampler)
@@ -437,3 +483,402 @@ def dummy_task(x, y):
celery_invocation(dummy_task, 1, 0)
assert not events
+
+
+def test_task_headers(celery):
+ """
+ Test that the headers set in the Celery Beat auto-instrumentation are passed to the celery signal handlers
+ """
+ sentry_crons_setup = {
+ "sentry-monitor-slug": "some-slug",
+ "sentry-monitor-config": {"some": "config"},
+ "sentry-monitor-check-in-id": "123abc",
+ }
+
+ @celery.task(name="dummy_task", bind=True)
+ def dummy_task(self, x, y):
+ return _get_headers(self)
+
+ # This is how the Celery Beat auto-instrumentation starts a task
+ # in the monkey patched version of `apply_async`
+ # in `sentry_sdk/integrations/celery.py::_wrap_apply_async()`
+ result = dummy_task.apply_async(args=(1, 0), headers=sentry_crons_setup)
+
+ expected_headers = sentry_crons_setup.copy()
+ # Newly added headers
+ expected_headers["sentry-trace"] = mock.ANY
+ expected_headers["baggage"] = mock.ANY
+ expected_headers["sentry-task-enqueued-time"] = mock.ANY
+
+ assert result.get() == expected_headers
+
+
+def test_baggage_propagation(init_celery):
+ celery = init_celery(traces_sample_rate=1.0, release="abcdef")
+
+ @celery.task(name="dummy_task", bind=True)
+ def dummy_task(self, x, y):
+ return _get_headers(self)
+
+ # patch random.randrange to return a predictable sample_rand value
+ with mock.patch("sentry_sdk.tracing_utils.Random.randrange", return_value=500000):
+ with start_transaction() as transaction:
+ result = dummy_task.apply_async(
+ args=(1, 0),
+ headers={"baggage": "custom=value"},
+ ).get()
+
+ assert sorted(result["baggage"].split(",")) == sorted(
+ [
+ "sentry-release=abcdef",
+ "sentry-trace_id={}".format(transaction.trace_id),
+ "sentry-environment=production",
+ "sentry-sample_rand=0.500000",
+ "sentry-sample_rate=1.0",
+ "sentry-sampled=true",
+ "custom=value",
+ ]
+ )
+
+
+def test_sentry_propagate_traces_override(init_celery):
+ """
+ Test if the `sentry-propagate-traces` header given to `apply_async`
+ overrides the `propagate_traces` parameter in the integration constructor.
+ """
+ celery = init_celery(
+ propagate_traces=True, traces_sample_rate=1.0, release="abcdef"
+ )
+
+ @celery.task(name="dummy_task", bind=True)
+ def dummy_task(self, message):
+ trace_id = get_current_span().trace_id
+ return trace_id
+
+ with start_transaction() as transaction:
+ transaction_trace_id = transaction.trace_id
+
+ # should propagate trace
+ task_transaction_id = dummy_task.apply_async(
+ args=("some message",),
+ ).get()
+ assert transaction_trace_id == task_transaction_id
+
+ # should NOT propagate trace (overrides `propagate_traces` parameter in integration constructor)
+ task_transaction_id = dummy_task.apply_async(
+ args=("another message",),
+ headers={"sentry-propagate-traces": False},
+ ).get()
+ assert transaction_trace_id != task_transaction_id
+
+
+def test_apply_async_manually_span(sentry_init):
+ sentry_init(
+ integrations=[CeleryIntegration()],
+ )
+
+ def dummy_function(*args, **kwargs):
+ headers = kwargs.get("headers")
+ assert "sentry-trace" in headers
+ assert "baggage" in headers
+
+ wrapped = _wrap_task_run(dummy_function)
+ wrapped(mock.MagicMock(), (), headers={})
+
+
+def test_apply_async_no_args(init_celery):
+ celery = init_celery()
+
+ @celery.task
+ def example_task():
+ return "success"
+
+ try:
+ result = example_task.apply_async(None, {})
+ except TypeError:
+ pytest.fail("Calling `apply_async` without arguments raised a TypeError")
+
+ assert result.get() == "success"
+
+
+@pytest.mark.parametrize("routing_key", ("celery", "custom"))
+@mock.patch("celery.app.task.Task.request")
+def test_messaging_destination_name_default_exchange(
+ mock_request, routing_key, init_celery, capture_events
+):
+ celery_app = init_celery(traces_sample_rate=1.0)
+ events = capture_events()
+ mock_request.delivery_info = {"routing_key": routing_key, "exchange": ""}
+
+ @celery_app.task()
+ def task(): ...
+
+ task.apply_async()
+
+ (event,) = events
+ (span,) = event["spans"]
+ assert span["data"]["messaging.destination.name"] == routing_key
+
+
+@mock.patch("celery.app.task.Task.request")
+def test_messaging_destination_name_nondefault_exchange(
+ mock_request, init_celery, capture_events
+):
+ """
+ Currently, we only capture the routing key as the messaging.destination.name when
+ we are using the default exchange (""). This is because the default exchange ensures
+ that the routing key is the queue name. Other exchanges may not guarantee this
+ behavior.
+ """
+ celery_app = init_celery(traces_sample_rate=1.0)
+ events = capture_events()
+ mock_request.delivery_info = {"routing_key": "celery", "exchange": "custom"}
+
+ @celery_app.task()
+ def task(): ...
+
+ task.apply_async()
+
+ (event,) = events
+ (span,) = event["spans"]
+ assert "messaging.destination.name" not in span["data"]
+
+
+def test_messaging_id(init_celery, capture_events):
+ celery = init_celery(traces_sample_rate=1.0)
+ events = capture_events()
+
+ @celery.task
+ def example_task(): ...
+
+ example_task.apply_async()
+
+ (event,) = events
+ (span,) = event["spans"]
+ assert "messaging.message.id" in span["data"]
+
+
+def test_retry_count_zero(init_celery, capture_events):
+ celery = init_celery(traces_sample_rate=1.0)
+ events = capture_events()
+
+ @celery.task()
+ def task(): ...
+
+ task.apply_async()
+
+ (event,) = events
+ (span,) = event["spans"]
+ assert span["data"]["messaging.message.retry.count"] == 0
+
+
+@mock.patch("celery.app.task.Task.request")
+def test_retry_count_nonzero(mock_request, init_celery, capture_events):
+ mock_request.retries = 3
+
+ celery = init_celery(traces_sample_rate=1.0)
+ events = capture_events()
+
+ @celery.task()
+ def task(): ...
+
+ task.apply_async()
+
+ (event,) = events
+ (span,) = event["spans"]
+ assert span["data"]["messaging.message.retry.count"] == 3
+
+
+@pytest.mark.parametrize("system", ("redis", "amqp"))
+def test_messaging_system(system, init_celery, capture_events):
+ celery = init_celery(traces_sample_rate=1.0)
+ events = capture_events()
+
+ # Does not need to be a real URL, since we use always eager
+ celery.conf.broker_url = f"{system}://example.com" # noqa: E231
+
+ @celery.task()
+ def task(): ...
+
+ task.apply_async()
+
+ (event,) = events
+ (span,) = event["spans"]
+ assert span["data"]["messaging.system"] == system
+
+
+@pytest.mark.parametrize("system", ("amqp", "redis"))
+def test_producer_span_data(system, monkeypatch, sentry_init, capture_events):
+ old_publish = kombu.messaging.Producer._publish
+
+ def publish(*args, **kwargs):
+ pass
+
+ monkeypatch.setattr(kombu.messaging.Producer, "_publish", publish)
+
+ sentry_init(integrations=[CeleryIntegration()], traces_sample_rate=1.0)
+ celery = Celery(__name__, broker=f"{system}://example.com") # noqa: E231
+ events = capture_events()
+
+ @celery.task()
+ def task(): ...
+
+ with start_transaction():
+ task.apply_async()
+
+ (event,) = events
+ span = next(span for span in event["spans"] if span["op"] == "queue.publish")
+
+ assert span["data"]["messaging.system"] == system
+
+ assert span["data"]["messaging.destination.name"] == "celery"
+ assert "messaging.message.id" in span["data"]
+ assert span["data"]["messaging.message.retry.count"] == 0
+
+ monkeypatch.setattr(kombu.messaging.Producer, "_publish", old_publish)
+
+
+def test_receive_latency(init_celery, capture_events):
+ celery = init_celery(traces_sample_rate=1.0)
+ events = capture_events()
+
+ @celery.task()
+ def task(): ...
+
+ task.apply_async()
+
+ (event,) = events
+ (span,) = event["spans"]
+ assert "messaging.message.receive.latency" in span["data"]
+ assert span["data"]["messaging.message.receive.latency"] > 0
+
+
+def tests_span_origin_consumer(init_celery, capture_events):
+ celery = init_celery(traces_sample_rate=1.0)
+ celery.conf.broker_url = "redis://example.com" # noqa: E231
+
+ events = capture_events()
+
+ @celery.task()
+ def task(): ...
+
+ task.apply_async()
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "auto.queue.celery"
+ assert event["spans"][0]["origin"] == "auto.queue.celery"
+
+
+def tests_span_origin_producer(monkeypatch, sentry_init, capture_events):
+ old_publish = kombu.messaging.Producer._publish
+
+ def publish(*args, **kwargs):
+ pass
+
+ monkeypatch.setattr(kombu.messaging.Producer, "_publish", publish)
+
+ sentry_init(integrations=[CeleryIntegration()], traces_sample_rate=1.0)
+ celery = Celery(__name__, broker="redis://example.com") # noqa: E231
+
+ events = capture_events()
+
+ @celery.task()
+ def task(): ...
+
+ with start_transaction(name="custom_transaction"):
+ task.apply_async()
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+
+ for span in event["spans"]:
+ assert span["origin"] == "auto.queue.celery"
+
+ monkeypatch.setattr(kombu.messaging.Producer, "_publish", old_publish)
+
+
+@pytest.mark.forked
+@mock.patch("celery.Celery.send_task")
+def test_send_task_wrapped(
+ patched_send_task,
+ sentry_init,
+ capture_events,
+ reset_integrations,
+):
+ sentry_init(integrations=[CeleryIntegration()], traces_sample_rate=1.0)
+ celery = Celery(__name__, broker="redis://example.com") # noqa: E231
+
+ events = capture_events()
+
+ with sentry_sdk.start_transaction(name="custom_transaction"):
+ celery.send_task("very_creative_task_name", args=(1, 2), kwargs={"foo": "bar"})
+
+ (call,) = patched_send_task.call_args_list # We should have exactly one call
+ (args, kwargs) = call
+
+ assert args == (celery, "very_creative_task_name")
+ assert kwargs["args"] == (1, 2)
+ assert kwargs["kwargs"] == {"foo": "bar"}
+ assert set(kwargs["headers"].keys()) == {
+ "sentry-task-enqueued-time",
+ "sentry-trace",
+ "baggage",
+ "headers",
+ }
+ assert set(kwargs["headers"]["headers"].keys()) == {
+ "sentry-trace",
+ "baggage",
+ "sentry-task-enqueued-time",
+ }
+ assert (
+ kwargs["headers"]["sentry-trace"]
+ == kwargs["headers"]["headers"]["sentry-trace"]
+ )
+
+ (event,) = events # We should have exactly one event (the transaction)
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "custom_transaction"
+
+ (span,) = event["spans"] # We should have exactly one span
+ assert span["description"] == "very_creative_task_name"
+ assert span["op"] == "queue.submit.celery"
+ assert span["trace_id"] == kwargs["headers"]["sentry-trace"].split("-")[0]
+
+
+def test_user_custom_headers_accessible_in_task(init_celery):
+ """
+ Regression test for https://github.com/getsentry/sentry-python/issues/5566
+
+ User-provided custom headers passed to apply_async() must be accessible
+ via task.request.headers on the worker side.
+ """
+ celery = init_celery(traces_sample_rate=1.0)
+
+ @celery.task(name="custom_headers_task", bind=True)
+ def custom_headers_task(self):
+ return dict(self.request.headers or {})
+
+ custom_headers = {
+ "my_custom_key": "my_value",
+ "correlation_id": "abc-123",
+ "tenant_id": "tenant-42",
+ }
+
+ with start_transaction(name="test"):
+ result = custom_headers_task.apply_async(headers=custom_headers)
+
+ received_headers = result.get()
+ for key, value in custom_headers.items():
+ assert received_headers.get(key) == value, (
+ f"Custom header {key!r} not found in task.request.headers"
+ )
+
+
+@pytest.mark.skip(reason="placeholder so that forked test does not come last")
+def test_placeholder():
+ """Forked tests must not come last in the module.
+ See https://github.com/pytest-dev/pytest-forked/issues/67#issuecomment-1964718720.
+ """
+ pass
diff --git a/tests/integrations/celery/test_celery_beat_crons.py b/tests/integrations/celery/test_celery_beat_crons.py
index fd90196c8e..17b4a5e73d 100644
--- a/tests/integrations/celery/test_celery_beat_crons.py
+++ b/tests/integrations/celery/test_celery_beat_crons.py
@@ -1,26 +1,25 @@
-import tempfile
-import mock
+import datetime
+from unittest import mock
+from unittest.mock import MagicMock
import pytest
+from celery.schedules import crontab, schedule
-pytest.importorskip("celery")
-
-from sentry_sdk.integrations.celery import (
+from sentry_sdk.crons import MonitorStatus
+from sentry_sdk.integrations.celery.beat import (
_get_headers,
- _get_humanized_interval,
_get_monitor_config,
- _reinstall_patched_tasks,
- crons_task_before_run,
- crons_task_success,
+ _patch_beat_apply_entry,
+ _patch_redbeat_apply_async,
crons_task_failure,
crons_task_retry,
+ crons_task_success,
)
-from sentry_sdk.crons import MonitorStatus
-from celery.schedules import crontab, schedule
+from sentry_sdk.integrations.celery.utils import _get_humanized_interval
def test_get_headers():
- fake_task = mock.MagicMock()
+ fake_task = MagicMock()
fake_task.request = {
"bla": "blub",
"foo": "bar",
@@ -56,9 +55,11 @@ def test_get_headers():
@pytest.mark.parametrize(
"seconds, expected_tuple",
[
- (0, (1, "minute")),
- (0.00001, (1, "minute")),
- (1, (1, "minute")),
+ (0, (0, "second")),
+ (1, (1, "second")),
+ (0.00001, (0, "second")),
+ (59, (59, "second")),
+ (60, (1, "minute")),
(100, (1, "minute")),
(1000, (16, "minute")),
(10000, (2, "hour")),
@@ -70,44 +71,8 @@ def test_get_humanized_interval(seconds, expected_tuple):
assert _get_humanized_interval(seconds) == expected_tuple
-def test_crons_task_before_run():
- fake_task = mock.MagicMock()
- fake_task.request = {
- "headers": {
- "sentry-monitor-slug": "test123",
- "sentry-monitor-config": {
- "schedule": {
- "type": "interval",
- "value": 3,
- "unit": "day",
- },
- "timezone": "Europe/Vienna",
- },
- "sentry-monitor-some-future-key": "some-future-value",
- },
- }
-
- with mock.patch(
- "sentry_sdk.integrations.celery.capture_checkin"
- ) as mock_capture_checkin:
- crons_task_before_run(fake_task)
-
- mock_capture_checkin.assert_called_once_with(
- monitor_slug="test123",
- monitor_config={
- "schedule": {
- "type": "interval",
- "value": 3,
- "unit": "day",
- },
- "timezone": "Europe/Vienna",
- },
- status=MonitorStatus.IN_PROGRESS,
- )
-
-
def test_crons_task_success():
- fake_task = mock.MagicMock()
+ fake_task = MagicMock()
fake_task.request = {
"headers": {
"sentry-monitor-slug": "test123",
@@ -126,9 +91,12 @@ def test_crons_task_success():
}
with mock.patch(
- "sentry_sdk.integrations.celery.capture_checkin"
+ "sentry_sdk.integrations.celery.beat.capture_checkin"
) as mock_capture_checkin:
- with mock.patch("sentry_sdk.integrations.celery.now", return_value=500.5):
+ with mock.patch(
+ "sentry_sdk.integrations.celery.beat._now_seconds_since_epoch",
+ return_value=500.5,
+ ):
crons_task_success(fake_task)
mock_capture_checkin.assert_called_once_with(
@@ -148,7 +116,7 @@ def test_crons_task_success():
def test_crons_task_failure():
- fake_task = mock.MagicMock()
+ fake_task = MagicMock()
fake_task.request = {
"headers": {
"sentry-monitor-slug": "test123",
@@ -167,9 +135,12 @@ def test_crons_task_failure():
}
with mock.patch(
- "sentry_sdk.integrations.celery.capture_checkin"
+ "sentry_sdk.integrations.celery.beat.capture_checkin"
) as mock_capture_checkin:
- with mock.patch("sentry_sdk.integrations.celery.now", return_value=500.5):
+ with mock.patch(
+ "sentry_sdk.integrations.celery.beat._now_seconds_since_epoch",
+ return_value=500.5,
+ ):
crons_task_failure(fake_task)
mock_capture_checkin.assert_called_once_with(
@@ -189,7 +160,7 @@ def test_crons_task_failure():
def test_crons_task_retry():
- fake_task = mock.MagicMock()
+ fake_task = MagicMock()
fake_task.request = {
"headers": {
"sentry-monitor-slug": "test123",
@@ -208,9 +179,12 @@ def test_crons_task_retry():
}
with mock.patch(
- "sentry_sdk.integrations.celery.capture_checkin"
+ "sentry_sdk.integrations.celery.beat.capture_checkin"
) as mock_capture_checkin:
- with mock.patch("sentry_sdk.integrations.celery.now", return_value=500.5):
+ with mock.patch(
+ "sentry_sdk.integrations.celery.beat._now_seconds_since_epoch",
+ return_value=500.5,
+ ):
crons_task_retry(fake_task)
mock_capture_checkin.assert_called_once_with(
@@ -229,79 +203,297 @@ def test_crons_task_retry():
)
-def test_get_monitor_config():
- app = mock.MagicMock()
- app.conf = mock.MagicMock()
- app.conf.timezone = "Europe/Vienna"
+def test_get_monitor_config_crontab():
+ app = MagicMock()
+ app.timezone = "Europe/Vienna"
+ # schedule with the default timezone
celery_schedule = crontab(day_of_month="3", hour="12", minute="*/10")
- monitor_config = _get_monitor_config(celery_schedule, app)
+ monitor_config = _get_monitor_config(celery_schedule, app, "foo")
assert monitor_config == {
"schedule": {
"type": "crontab",
"value": "*/10 12 3 * *",
},
- "timezone": "Europe/Vienna",
+ "timezone": "UTC", # the default because `crontab` does not know about the app
}
assert "unit" not in monitor_config["schedule"]
- celery_schedule = schedule(run_every=3)
+ # schedule with the timezone from the app
+ celery_schedule = crontab(day_of_month="3", hour="12", minute="*/10", app=app)
+
+ monitor_config = _get_monitor_config(celery_schedule, app, "foo")
+ assert monitor_config == {
+ "schedule": {
+ "type": "crontab",
+ "value": "*/10 12 3 * *",
+ },
+ "timezone": "Europe/Vienna", # the timezone from the app
+ }
+
+ # schedule without a timezone, the celery integration will read the config from the app
+ celery_schedule = crontab(day_of_month="3", hour="12", minute="*/10")
+ celery_schedule.tz = None
+
+ monitor_config = _get_monitor_config(celery_schedule, app, "foo")
+ assert monitor_config == {
+ "schedule": {
+ "type": "crontab",
+ "value": "*/10 12 3 * *",
+ },
+ "timezone": "Europe/Vienna", # the timezone from the app
+ }
+
+ # schedule without a timezone, and an app without timezone, the celery integration will fall back to UTC
+ app = MagicMock()
+ app.timezone = None
+
+ celery_schedule = crontab(day_of_month="3", hour="12", minute="*/10")
+ celery_schedule.tz = None
+ monitor_config = _get_monitor_config(celery_schedule, app, "foo")
+ assert monitor_config == {
+ "schedule": {
+ "type": "crontab",
+ "value": "*/10 12 3 * *",
+ },
+ "timezone": "UTC", # default timezone from celery integration
+ }
+
+
+def test_get_monitor_config_seconds():
+ app = MagicMock()
+ app.timezone = "Europe/Vienna"
- monitor_config = _get_monitor_config(celery_schedule, app)
+ celery_schedule = schedule(run_every=3) # seconds
+
+ with mock.patch("sentry_sdk.integrations.logger.warning") as mock_logger_warning:
+ monitor_config = _get_monitor_config(celery_schedule, app, "foo")
+ mock_logger_warning.assert_called_with(
+ "Intervals shorter than one minute are not supported by Sentry Crons. Monitor '%s' has an interval of %s seconds. Use the `exclude_beat_tasks` option in the celery integration to exclude it.",
+ "foo",
+ 3,
+ )
+ assert monitor_config == {}
+
+
+def test_get_monitor_config_minutes():
+ app = MagicMock()
+ app.timezone = "Europe/Vienna"
+
+ # schedule with the default timezone
+ celery_schedule = schedule(run_every=60) # seconds
+
+ monitor_config = _get_monitor_config(celery_schedule, app, "foo")
assert monitor_config == {
"schedule": {
"type": "interval",
"value": 1,
"unit": "minute",
},
- "timezone": "Europe/Vienna",
+ "timezone": "UTC",
}
- unknown_celery_schedule = mock.MagicMock()
- monitor_config = _get_monitor_config(unknown_celery_schedule, app)
+ # schedule with the timezone from the app
+ celery_schedule = schedule(run_every=60, app=app) # seconds
+
+ monitor_config = _get_monitor_config(celery_schedule, app, "foo")
+ assert monitor_config == {
+ "schedule": {
+ "type": "interval",
+ "value": 1,
+ "unit": "minute",
+ },
+ "timezone": "Europe/Vienna", # the timezone from the app
+ }
+
+ # schedule without a timezone, the celery integration will read the config from the app
+ celery_schedule = schedule(run_every=60) # seconds
+ celery_schedule.tz = None
+
+ monitor_config = _get_monitor_config(celery_schedule, app, "foo")
+ assert monitor_config == {
+ "schedule": {
+ "type": "interval",
+ "value": 1,
+ "unit": "minute",
+ },
+ "timezone": "Europe/Vienna", # the timezone from the app
+ }
+
+ # schedule without a timezone, and an app without timezone, the celery integration will fall back to UTC
+ app = MagicMock()
+ app.timezone = None
+
+ celery_schedule = schedule(run_every=60) # seconds
+ celery_schedule.tz = None
+
+ monitor_config = _get_monitor_config(celery_schedule, app, "foo")
+ assert monitor_config == {
+ "schedule": {
+ "type": "interval",
+ "value": 1,
+ "unit": "minute",
+ },
+ "timezone": "UTC", # default timezone from celery integration
+ }
+
+
+def test_get_monitor_config_unknown():
+ app = MagicMock()
+ app.timezone = "Europe/Vienna"
+
+ unknown_celery_schedule = MagicMock()
+ monitor_config = _get_monitor_config(unknown_celery_schedule, app, "foo")
assert monitor_config == {}
def test_get_monitor_config_default_timezone():
- app = mock.MagicMock()
- app.conf = mock.MagicMock()
- app.conf.timezone = None
+ app = MagicMock()
+ app.timezone = None
celery_schedule = crontab(day_of_month="3", hour="12", minute="*/10")
- monitor_config = _get_monitor_config(celery_schedule, app)
+ monitor_config = _get_monitor_config(celery_schedule, app, "dummy_monitor_name")
assert monitor_config["timezone"] == "UTC"
-def test_reinstall_patched_tasks():
- fake_beat = mock.MagicMock()
- fake_beat.run = mock.MagicMock()
+def test_get_monitor_config_timezone_in_app_conf():
+ app = MagicMock()
+ app.timezone = "Asia/Karachi"
+
+ celery_schedule = crontab(day_of_month="3", hour="12", minute="*/10")
+ celery_schedule.tz = None
+
+ monitor_config = _get_monitor_config(celery_schedule, app, "dummy_monitor_name")
+
+ assert monitor_config["timezone"] == "Asia/Karachi"
+
+
+def test_get_monitor_config_timezone_in_celery_schedule():
+ app = MagicMock()
+ app.timezone = "Asia/Karachi"
+
+ panama_tz = datetime.timezone(datetime.timedelta(hours=-5), name="America/Panama")
+
+ celery_schedule = crontab(day_of_month="3", hour="12", minute="*/10")
+ celery_schedule.tz = panama_tz
+
+ monitor_config = _get_monitor_config(celery_schedule, app, "dummy_monitor_name")
+
+ assert monitor_config["timezone"] == str(panama_tz)
+
+
+@pytest.mark.parametrize(
+ "task_name,exclude_beat_tasks,task_in_excluded_beat_tasks",
+ [
+ ["some_task_name", ["xxx", "some_task.*"], True],
+ ["some_task_name", ["xxx", "some_other_task.*"], False],
+ ],
+)
+def test_exclude_beat_tasks_option(
+ task_name, exclude_beat_tasks, task_in_excluded_beat_tasks
+):
+ """
+ Test excluding Celery Beat tasks from automatic instrumentation.
+ """
+ fake_apply_entry = MagicMock()
- app = mock.MagicMock()
- app.Beat = mock.MagicMock(return_value=fake_beat)
+ fake_scheduler = MagicMock()
+ fake_scheduler.apply_entry = fake_apply_entry
- sender = mock.MagicMock()
- sender.schedule_filename = "test_schedule_filename"
- sender.stop = mock.MagicMock()
+ fake_integration = MagicMock()
+ fake_integration.exclude_beat_tasks = exclude_beat_tasks
- add_updated_periodic_tasks = [mock.MagicMock(), mock.MagicMock(), mock.MagicMock()]
+ fake_client = MagicMock()
+ fake_client.get_integration.return_value = fake_integration
- mock_open = mock.Mock(return_value=tempfile.NamedTemporaryFile())
+ fake_schedule_entry = MagicMock()
+ fake_schedule_entry.name = task_name
- with mock.patch("sentry_sdk.integrations.celery.open", mock_open):
+ fake_get_monitor_config = MagicMock()
+
+ with mock.patch(
+ "sentry_sdk.integrations.celery.beat.Scheduler", fake_scheduler
+ ) as Scheduler: # noqa: N806
with mock.patch(
- "sentry_sdk.integrations.celery.shutil.copyfileobj"
- ) as mock_copyfileobj:
- _reinstall_patched_tasks(app, sender, add_updated_periodic_tasks)
+ "sentry_sdk.integrations.celery.sentry_sdk.get_client",
+ return_value=fake_client,
+ ):
+ with mock.patch(
+ "sentry_sdk.integrations.celery.beat._get_monitor_config",
+ fake_get_monitor_config,
+ ) as _get_monitor_config:
+ # Mimic CeleryIntegration patching of Scheduler.apply_entry()
+ _patch_beat_apply_entry()
+ # Mimic Celery Beat calling a task from the Beat schedule
+ Scheduler.apply_entry(fake_scheduler, fake_schedule_entry)
+
+ if task_in_excluded_beat_tasks:
+ # Only the original Scheduler.apply_entry() is called, _get_monitor_config is NOT called.
+ assert fake_apply_entry.call_count == 1
+ _get_monitor_config.assert_not_called()
+
+ else:
+ # The original Scheduler.apply_entry() is called, AND _get_monitor_config is called.
+ assert fake_apply_entry.call_count == 1
+ assert _get_monitor_config.call_count == 1
+
+
+@pytest.mark.parametrize(
+ "task_name,exclude_beat_tasks,task_in_excluded_beat_tasks",
+ [
+ ["some_task_name", ["xxx", "some_task.*"], True],
+ ["some_task_name", ["xxx", "some_other_task.*"], False],
+ ],
+)
+def test_exclude_redbeat_tasks_option(
+ task_name, exclude_beat_tasks, task_in_excluded_beat_tasks
+):
+ """
+ Test excluding Celery RedBeat tasks from automatic instrumentation.
+ """
+ fake_apply_async = MagicMock()
- sender.stop.assert_called_once_with()
+ fake_redbeat_scheduler = MagicMock()
+ fake_redbeat_scheduler.apply_async = fake_apply_async
- add_updated_periodic_tasks[0].assert_called_once_with()
- add_updated_periodic_tasks[1].assert_called_once_with()
- add_updated_periodic_tasks[2].assert_called_once_with()
+ fake_integration = MagicMock()
+ fake_integration.exclude_beat_tasks = exclude_beat_tasks
- mock_copyfileobj.assert_called_once()
+ fake_client = MagicMock()
+ fake_client.get_integration.return_value = fake_integration
- fake_beat.run.assert_called_once_with()
+ fake_schedule_entry = MagicMock()
+ fake_schedule_entry.name = task_name
+
+ fake_get_monitor_config = MagicMock()
+
+ with mock.patch(
+ "sentry_sdk.integrations.celery.beat.RedBeatScheduler", fake_redbeat_scheduler
+ ) as RedBeatScheduler: # noqa: N806
+ with mock.patch(
+ "sentry_sdk.integrations.celery.sentry_sdk.get_client",
+ return_value=fake_client,
+ ):
+ with mock.patch(
+ "sentry_sdk.integrations.celery.beat._get_monitor_config",
+ fake_get_monitor_config,
+ ) as _get_monitor_config:
+ # Mimic CeleryIntegration patching of RedBeatScheduler.apply_async()
+ _patch_redbeat_apply_async()
+ # Mimic Celery RedBeat calling a task from the RedBeat schedule
+ RedBeatScheduler.apply_async(
+ fake_redbeat_scheduler, fake_schedule_entry
+ )
+
+ if task_in_excluded_beat_tasks:
+ # Only the original RedBeatScheduler.maybe_due() is called, _get_monitor_config is NOT called.
+ assert fake_apply_async.call_count == 1
+ _get_monitor_config.assert_not_called()
+
+ else:
+ # The original RedBeatScheduler.maybe_due() is called, AND _get_monitor_config is called.
+ assert fake_apply_async.call_count == 1
+ assert _get_monitor_config.call_count == 1
diff --git a/tests/integrations/celery/test_update_celery_task_headers.py b/tests/integrations/celery/test_update_celery_task_headers.py
new file mode 100644
index 0000000000..950b13826c
--- /dev/null
+++ b/tests/integrations/celery/test_update_celery_task_headers.py
@@ -0,0 +1,228 @@
+from copy import copy
+import itertools
+import pytest
+
+from unittest import mock
+
+from sentry_sdk.integrations.celery import _update_celery_task_headers
+import sentry_sdk
+from sentry_sdk.tracing_utils import Baggage
+
+
+BAGGAGE_VALUE = (
+ "sentry-trace_id=771a43a4192642f0b136d5159a501700,"
+ "sentry-public_key=49d0f7386ad645858ae85020e393bef3,"
+ "sentry-sample_rate=0.1337,"
+ "custom=value"
+)
+
+SENTRY_TRACE_VALUE = "771a43a4192642f0b136d5159a501700-1234567890abcdef-1"
+
+
+@pytest.mark.parametrize("monitor_beat_tasks", [True, False, None, "", "bla", 1, 0])
+def test_monitor_beat_tasks(monitor_beat_tasks):
+ headers = {}
+ span = None
+
+ outgoing_headers = _update_celery_task_headers(headers, span, monitor_beat_tasks)
+
+ assert headers == {} # left unchanged
+
+ if monitor_beat_tasks:
+ assert outgoing_headers["sentry-monitor-start-timestamp-s"] == mock.ANY
+ assert (
+ outgoing_headers["headers"]["sentry-monitor-start-timestamp-s"] == mock.ANY
+ )
+ else:
+ assert "sentry-monitor-start-timestamp-s" not in outgoing_headers
+ assert "sentry-monitor-start-timestamp-s" not in outgoing_headers["headers"]
+
+
+@pytest.mark.parametrize("monitor_beat_tasks", [True, False, None, "", "bla", 1, 0])
+def test_monitor_beat_tasks_with_headers(monitor_beat_tasks):
+ headers = {
+ "blub": "foo",
+ "sentry-something": "bar",
+ "sentry-task-enqueued-time": mock.ANY,
+ }
+ span = None
+
+ outgoing_headers = _update_celery_task_headers(headers, span, monitor_beat_tasks)
+
+ assert headers == {
+ "blub": "foo",
+ "sentry-something": "bar",
+ "sentry-task-enqueued-time": mock.ANY,
+ } # left unchanged
+
+ if monitor_beat_tasks:
+ assert outgoing_headers["blub"] == "foo"
+ assert outgoing_headers["sentry-something"] == "bar"
+ assert outgoing_headers["sentry-monitor-start-timestamp-s"] == mock.ANY
+ assert outgoing_headers["headers"]["sentry-something"] == "bar"
+ assert (
+ outgoing_headers["headers"]["sentry-monitor-start-timestamp-s"] == mock.ANY
+ )
+ else:
+ assert outgoing_headers["blub"] == "foo"
+ assert outgoing_headers["sentry-something"] == "bar"
+ assert "sentry-monitor-start-timestamp-s" not in outgoing_headers
+ assert "sentry-monitor-start-timestamp-s" not in outgoing_headers["headers"]
+
+
+def test_span_with_transaction(sentry_init):
+ sentry_init(traces_sample_rate=1.0)
+ headers = {}
+ monitor_beat_tasks = False
+
+ with sentry_sdk.start_transaction(name="test_transaction") as transaction:
+ with sentry_sdk.start_span(op="test_span") as span:
+ outgoing_headers = _update_celery_task_headers(
+ headers, span, monitor_beat_tasks
+ )
+
+ assert outgoing_headers["sentry-trace"] == span.to_traceparent()
+ assert outgoing_headers["headers"]["sentry-trace"] == span.to_traceparent()
+ assert outgoing_headers["baggage"] == transaction.get_baggage().serialize()
+ assert (
+ outgoing_headers["headers"]["baggage"]
+ == transaction.get_baggage().serialize()
+ )
+
+
+def test_span_with_transaction_custom_headers(sentry_init):
+ sentry_init(traces_sample_rate=1.0)
+ headers = {
+ "baggage": BAGGAGE_VALUE,
+ "sentry-trace": SENTRY_TRACE_VALUE,
+ }
+
+ with sentry_sdk.start_transaction(name="test_transaction") as transaction:
+ with sentry_sdk.start_span(op="test_span") as span:
+ outgoing_headers = _update_celery_task_headers(headers, span, False)
+
+ assert outgoing_headers["sentry-trace"] == span.to_traceparent()
+ assert outgoing_headers["headers"]["sentry-trace"] == span.to_traceparent()
+
+ incoming_baggage = Baggage.from_incoming_header(headers["baggage"])
+ combined_baggage = copy(transaction.get_baggage())
+ combined_baggage.sentry_items.update(incoming_baggage.sentry_items)
+ combined_baggage.third_party_items = ",".join(
+ [
+ x
+ for x in [
+ combined_baggage.third_party_items,
+ incoming_baggage.third_party_items,
+ ]
+ if x is not None and x != ""
+ ]
+ )
+ assert outgoing_headers["baggage"] == combined_baggage.serialize(
+ include_third_party=True
+ )
+ assert outgoing_headers["headers"]["baggage"] == combined_baggage.serialize(
+ include_third_party=True
+ )
+
+
+@pytest.mark.parametrize("monitor_beat_tasks", [True, False])
+def test_celery_trace_propagation_default(sentry_init, monitor_beat_tasks):
+ """
+ The celery integration does not check the traces_sample_rate.
+ By default traces_sample_rate is None which means "do not propagate traces".
+ But the celery integration does not check this value.
+ The Celery integration has its own mechanism to propagate traces:
+ https://docs.sentry.io/platforms/python/integrations/celery/#distributed-traces
+ """
+ sentry_init()
+
+ headers = {}
+ span = None
+
+ scope = sentry_sdk.get_isolation_scope()
+
+ outgoing_headers = _update_celery_task_headers(headers, span, monitor_beat_tasks)
+
+ assert outgoing_headers["sentry-trace"] == scope.get_traceparent()
+ assert outgoing_headers["headers"]["sentry-trace"] == scope.get_traceparent()
+ assert outgoing_headers["baggage"] == scope.get_baggage().serialize()
+ assert outgoing_headers["headers"]["baggage"] == scope.get_baggage().serialize()
+
+ if monitor_beat_tasks:
+ assert "sentry-monitor-start-timestamp-s" in outgoing_headers
+ assert "sentry-monitor-start-timestamp-s" in outgoing_headers["headers"]
+ else:
+ assert "sentry-monitor-start-timestamp-s" not in outgoing_headers
+ assert "sentry-monitor-start-timestamp-s" not in outgoing_headers["headers"]
+
+
+@pytest.mark.parametrize(
+ "traces_sample_rate,monitor_beat_tasks",
+ list(itertools.product([None, 0, 0.0, 0.5, 1.0, 1, 2], [True, False])),
+)
+def test_celery_trace_propagation_traces_sample_rate(
+ sentry_init, traces_sample_rate, monitor_beat_tasks
+):
+ """
+ The celery integration does not check the traces_sample_rate.
+ By default traces_sample_rate is None which means "do not propagate traces".
+ But the celery integration does not check this value.
+ The Celery integration has its own mechanism to propagate traces:
+ https://docs.sentry.io/platforms/python/integrations/celery/#distributed-traces
+ """
+ sentry_init(traces_sample_rate=traces_sample_rate)
+
+ headers = {}
+ span = None
+
+ scope = sentry_sdk.get_isolation_scope()
+
+ outgoing_headers = _update_celery_task_headers(headers, span, monitor_beat_tasks)
+
+ assert outgoing_headers["sentry-trace"] == scope.get_traceparent()
+ assert outgoing_headers["headers"]["sentry-trace"] == scope.get_traceparent()
+ assert outgoing_headers["baggage"] == scope.get_baggage().serialize()
+ assert outgoing_headers["headers"]["baggage"] == scope.get_baggage().serialize()
+
+ if monitor_beat_tasks:
+ assert "sentry-monitor-start-timestamp-s" in outgoing_headers
+ assert "sentry-monitor-start-timestamp-s" in outgoing_headers["headers"]
+ else:
+ assert "sentry-monitor-start-timestamp-s" not in outgoing_headers
+ assert "sentry-monitor-start-timestamp-s" not in outgoing_headers["headers"]
+
+
+@pytest.mark.parametrize(
+ "enable_tracing,monitor_beat_tasks",
+ list(itertools.product([None, True, False], [True, False])),
+)
+def test_celery_trace_propagation_enable_tracing(
+ sentry_init, enable_tracing, monitor_beat_tasks
+):
+ """
+ The celery integration does not check the traces_sample_rate.
+ By default traces_sample_rate is None which means "do not propagate traces".
+ But the celery integration does not check this value.
+ The Celery integration has its own mechanism to propagate traces:
+ https://docs.sentry.io/platforms/python/integrations/celery/#distributed-traces
+ """
+ sentry_init(enable_tracing=enable_tracing)
+
+ headers = {}
+ span = None
+
+ scope = sentry_sdk.get_isolation_scope()
+
+ outgoing_headers = _update_celery_task_headers(headers, span, monitor_beat_tasks)
+
+ assert outgoing_headers["sentry-trace"] == scope.get_traceparent()
+ assert outgoing_headers["headers"]["sentry-trace"] == scope.get_traceparent()
+ assert outgoing_headers["baggage"] == scope.get_baggage().serialize()
+ assert outgoing_headers["headers"]["baggage"] == scope.get_baggage().serialize()
+
+ if monitor_beat_tasks:
+ assert "sentry-monitor-start-timestamp-s" in outgoing_headers
+ assert "sentry-monitor-start-timestamp-s" in outgoing_headers["headers"]
+ else:
+ assert "sentry-monitor-start-timestamp-s" not in outgoing_headers
+ assert "sentry-monitor-start-timestamp-s" not in outgoing_headers["headers"]
diff --git a/tests/integrations/chalice/test_chalice.py b/tests/integrations/chalice/test_chalice.py
index 4162a55623..ec8106eb5f 100644
--- a/tests/integrations/chalice/test_chalice.py
+++ b/tests/integrations/chalice/test_chalice.py
@@ -3,8 +3,9 @@
from chalice import Chalice, BadRequestError
from chalice.local import LambdaContext, LocalGateway
-from sentry_sdk.integrations.chalice import ChaliceIntegration
from sentry_sdk import capture_message
+from sentry_sdk.integrations.chalice import CHALICE_VERSION, ChaliceIntegration
+from sentry_sdk.utils import parse_version
from pytest_chalice.handlers import RequestHandler
@@ -65,12 +66,10 @@ def lambda_context_args():
def test_exception_boom(app, client: RequestHandler) -> None:
response = client.get("/boom")
assert response.status_code == 500
- assert response.json == dict(
- [
- ("Code", "InternalServerError"),
- ("Message", "An internal server error occurred."),
- ]
- )
+ assert response.json == {
+ "Code": "InternalServerError",
+ "Message": "An internal server error occurred.",
+ }
def test_has_request(app, capture_events, client: RequestHandler):
@@ -110,16 +109,32 @@ def every_hour(event):
assert str(exc_info.value) == "schedule event!"
-def test_bad_reques(client: RequestHandler) -> None:
+@pytest.mark.skipif(
+ parse_version(CHALICE_VERSION) >= (1, 26, 0),
+ reason="different behavior based on chalice version",
+)
+def test_bad_request_old(client: RequestHandler) -> None:
response = client.get("/badrequest")
assert response.status_code == 400
- assert response.json == dict(
- [
- ("Code", "BadRequestError"),
- ("Message", "BadRequestError: bad-request"),
- ]
- )
+ assert response.json == {
+ "Code": "BadRequestError",
+ "Message": "BadRequestError: bad-request",
+ }
+
+
+@pytest.mark.skipif(
+ parse_version(CHALICE_VERSION) < (1, 26, 0),
+ reason="different behavior based on chalice version",
+)
+def test_bad_request(client: RequestHandler) -> None:
+ response = client.get("/badrequest")
+
+ assert response.status_code == 400
+ assert response.json == {
+ "Code": "BadRequestError",
+ "Message": "bad-request",
+ }
@pytest.mark.parametrize(
diff --git a/tests/integrations/clickhouse_driver/__init__.py b/tests/integrations/clickhouse_driver/__init__.py
new file mode 100644
index 0000000000..602c4e553c
--- /dev/null
+++ b/tests/integrations/clickhouse_driver/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("clickhouse_driver")
diff --git a/tests/integrations/clickhouse_driver/test_clickhouse_driver.py b/tests/integrations/clickhouse_driver/test_clickhouse_driver.py
new file mode 100644
index 0000000000..b501aa3531
--- /dev/null
+++ b/tests/integrations/clickhouse_driver/test_clickhouse_driver.py
@@ -0,0 +1,990 @@
+"""
+Tests need a local clickhouse instance running, this can best be done using
+```sh
+docker run -d -p 18123:8123 -p9000:9000 --name clickhouse-test --ulimit nofile=262144:262144 --rm clickhouse/clickhouse-server
+```
+"""
+
+import clickhouse_driver
+from clickhouse_driver import Client, connect
+
+from sentry_sdk import start_transaction, capture_message
+from sentry_sdk.integrations.clickhouse_driver import ClickhouseDriverIntegration
+from tests.conftest import ApproxDict
+
+EXPECT_PARAMS_IN_SELECT = True
+if clickhouse_driver.VERSION < (0, 2, 6):
+ EXPECT_PARAMS_IN_SELECT = False
+
+
+def test_clickhouse_client_breadcrumbs(sentry_init, capture_events) -> None:
+ sentry_init(
+ integrations=[ClickhouseDriverIntegration()],
+ _experiments={"record_sql_params": True},
+ )
+ events = capture_events()
+
+ client = Client("localhost")
+ client.execute("DROP TABLE IF EXISTS test")
+ client.execute("CREATE TABLE test (x Int32) ENGINE = Memory")
+ client.execute("INSERT INTO test (x) VALUES", [{"x": 100}])
+ client.execute("INSERT INTO test (x) VALUES", [[170], [200]])
+
+ res = client.execute("SELECT sum(x) FROM test WHERE x > %(minv)i", {"minv": 150})
+ assert res[0][0] == 370
+
+ capture_message("hi")
+
+ (event,) = events
+
+ expected_breadcrumbs = [
+ {
+ "category": "query",
+ "data": {
+ "db.system": "clickhouse",
+ "db.driver.name": "clickhouse-driver",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ },
+ "message": "DROP TABLE IF EXISTS test",
+ "type": "default",
+ },
+ {
+ "category": "query",
+ "data": {
+ "db.system": "clickhouse",
+ "db.driver.name": "clickhouse-driver",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ },
+ "message": "CREATE TABLE test (x Int32) ENGINE = Memory",
+ "type": "default",
+ },
+ {
+ "category": "query",
+ "data": {
+ "db.system": "clickhouse",
+ "db.driver.name": "clickhouse-driver",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ },
+ "message": "INSERT INTO test (x) VALUES",
+ "type": "default",
+ },
+ {
+ "category": "query",
+ "data": {
+ "db.system": "clickhouse",
+ "db.driver.name": "clickhouse-driver",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ },
+ "message": "INSERT INTO test (x) VALUES",
+ "type": "default",
+ },
+ {
+ "category": "query",
+ "data": {
+ "db.system": "clickhouse",
+ "db.driver.name": "clickhouse-driver",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ },
+ "message": "SELECT sum(x) FROM test WHERE x > 150",
+ "type": "default",
+ },
+ ]
+
+ if not EXPECT_PARAMS_IN_SELECT:
+ expected_breadcrumbs[-1]["data"].pop("db.params", None)
+
+ for crumb in expected_breadcrumbs:
+ crumb["data"] = ApproxDict(crumb["data"])
+
+ for crumb in event["breadcrumbs"]["values"]:
+ crumb.pop("timestamp", None)
+
+ actual_query_breadcrumbs = [
+ breadcrumb
+ for breadcrumb in event["breadcrumbs"]["values"]
+ if breadcrumb["category"] == "query"
+ ]
+
+ assert actual_query_breadcrumbs == expected_breadcrumbs
+
+
+def test_clickhouse_client_breadcrumbs_with_pii(sentry_init, capture_events) -> None:
+ sentry_init(
+ integrations=[ClickhouseDriverIntegration()],
+ send_default_pii=True,
+ _experiments={"record_sql_params": True},
+ )
+ events = capture_events()
+
+ client = Client("localhost")
+ client.execute("DROP TABLE IF EXISTS test")
+ client.execute("CREATE TABLE test (x Int32) ENGINE = Memory")
+ client.execute("INSERT INTO test (x) VALUES", [{"x": 100}])
+ client.execute("INSERT INTO test (x) VALUES", [[170], [200]])
+
+ res = client.execute("SELECT sum(x) FROM test WHERE x > %(minv)i", {"minv": 150})
+ assert res[0][0] == 370
+
+ capture_message("hi")
+
+ (event,) = events
+
+ expected_breadcrumbs = [
+ {
+ "category": "query",
+ "data": {
+ "db.system": "clickhouse",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ "db.result": [],
+ },
+ "message": "DROP TABLE IF EXISTS test",
+ "type": "default",
+ },
+ {
+ "category": "query",
+ "data": {
+ "db.system": "clickhouse",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ "db.result": [],
+ },
+ "message": "CREATE TABLE test (x Int32) ENGINE = Memory",
+ "type": "default",
+ },
+ {
+ "category": "query",
+ "data": {
+ "db.system": "clickhouse",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ "db.params": [{"x": 100}],
+ },
+ "message": "INSERT INTO test (x) VALUES",
+ "type": "default",
+ },
+ {
+ "category": "query",
+ "data": {
+ "db.system": "clickhouse",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ "db.params": [[170], [200]],
+ },
+ "message": "INSERT INTO test (x) VALUES",
+ "type": "default",
+ },
+ {
+ "category": "query",
+ "data": {
+ "db.system": "clickhouse",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ "db.result": [[370]],
+ "db.params": {"minv": 150},
+ },
+ "message": "SELECT sum(x) FROM test WHERE x > 150",
+ "type": "default",
+ },
+ ]
+
+ if not EXPECT_PARAMS_IN_SELECT:
+ expected_breadcrumbs[-1]["data"].pop("db.params", None)
+
+ for crumb in expected_breadcrumbs:
+ crumb["data"] = ApproxDict(crumb["data"])
+
+ for crumb in event["breadcrumbs"]["values"]:
+ crumb.pop("timestamp", None)
+
+ assert event["breadcrumbs"]["values"] == expected_breadcrumbs
+
+
+def test_clickhouse_client_spans(
+ sentry_init, capture_events, capture_envelopes
+) -> None:
+ sentry_init(
+ integrations=[ClickhouseDriverIntegration()],
+ _experiments={"record_sql_params": True},
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ transaction_trace_id = None
+ transaction_span_id = None
+
+ with start_transaction(name="test_clickhouse_transaction") as transaction:
+ transaction_trace_id = transaction.trace_id
+ transaction_span_id = transaction.span_id
+
+ client = Client("localhost")
+ client.execute("DROP TABLE IF EXISTS test")
+ client.execute("CREATE TABLE test (x Int32) ENGINE = Memory")
+ client.execute("INSERT INTO test (x) VALUES", [{"x": 100}])
+ client.execute("INSERT INTO test (x) VALUES", [[170], [200]])
+
+ res = client.execute(
+ "SELECT sum(x) FROM test WHERE x > %(minv)i", {"minv": 150}
+ )
+ assert res[0][0] == 370
+
+ (event,) = events
+
+ expected_spans = [
+ {
+ "op": "db",
+ "origin": "auto.db.clickhouse_driver",
+ "description": "DROP TABLE IF EXISTS test",
+ "data": {
+ "db.system": "clickhouse",
+ "db.driver.name": "clickhouse-driver",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ },
+ "same_process_as_parent": True,
+ "trace_id": transaction_trace_id,
+ "parent_span_id": transaction_span_id,
+ },
+ {
+ "op": "db",
+ "origin": "auto.db.clickhouse_driver",
+ "description": "CREATE TABLE test (x Int32) ENGINE = Memory",
+ "data": {
+ "db.system": "clickhouse",
+ "db.driver.name": "clickhouse-driver",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ },
+ "same_process_as_parent": True,
+ "trace_id": transaction_trace_id,
+ "parent_span_id": transaction_span_id,
+ },
+ {
+ "op": "db",
+ "origin": "auto.db.clickhouse_driver",
+ "description": "INSERT INTO test (x) VALUES",
+ "data": {
+ "db.system": "clickhouse",
+ "db.driver.name": "clickhouse-driver",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ },
+ "same_process_as_parent": True,
+ "trace_id": transaction_trace_id,
+ "parent_span_id": transaction_span_id,
+ },
+ {
+ "op": "db",
+ "origin": "auto.db.clickhouse_driver",
+ "description": "INSERT INTO test (x) VALUES",
+ "data": {
+ "db.system": "clickhouse",
+ "db.driver.name": "clickhouse-driver",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ },
+ "same_process_as_parent": True,
+ "trace_id": transaction_trace_id,
+ "parent_span_id": transaction_span_id,
+ },
+ {
+ "op": "db",
+ "origin": "auto.db.clickhouse_driver",
+ "description": "SELECT sum(x) FROM test WHERE x > 150",
+ "data": {
+ "db.system": "clickhouse",
+ "db.driver.name": "clickhouse-driver",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ },
+ "same_process_as_parent": True,
+ "trace_id": transaction_trace_id,
+ "parent_span_id": transaction_span_id,
+ },
+ ]
+
+ if not EXPECT_PARAMS_IN_SELECT:
+ expected_spans[-1]["data"].pop("db.params", None)
+
+ for span in expected_spans:
+ span["data"] = ApproxDict(span["data"])
+
+ for span in event["spans"]:
+ span.pop("span_id", None)
+ span.pop("start_timestamp", None)
+ span.pop("timestamp", None)
+
+ assert event["spans"] == expected_spans
+
+
+def test_clickhouse_spans_with_generator(sentry_init, capture_events):
+ sentry_init(
+ integrations=[ClickhouseDriverIntegration()],
+ send_default_pii=True,
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ # Use a generator to test that the integration obtains values from the generator,
+ # without consuming the generator.
+ values = ({"x": i} for i in range(3))
+
+ with start_transaction(name="test_clickhouse_transaction"):
+ client = Client("localhost")
+ client.execute("DROP TABLE IF EXISTS test")
+ client.execute("CREATE TABLE test (x Int32) ENGINE = Memory")
+ client.execute("INSERT INTO test (x) VALUES", values)
+ res = client.execute("SELECT x FROM test")
+
+ # Verify that the integration did not consume the generator
+ assert res == [(0,), (1,), (2,)]
+
+ (event,) = events
+ spans = event["spans"]
+
+ [span] = [
+ span for span in spans if span["description"] == "INSERT INTO test (x) VALUES"
+ ]
+
+ assert span["data"]["db.params"] == [{"x": 0}, {"x": 1}, {"x": 2}]
+
+
+def test_clickhouse_client_spans_with_pii(
+ sentry_init, capture_events, capture_envelopes
+) -> None:
+ sentry_init(
+ integrations=[ClickhouseDriverIntegration()],
+ _experiments={"record_sql_params": True},
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ transaction_trace_id = None
+ transaction_span_id = None
+
+ with start_transaction(name="test_clickhouse_transaction") as transaction:
+ transaction_trace_id = transaction.trace_id
+ transaction_span_id = transaction.span_id
+
+ client = Client("localhost")
+ client.execute("DROP TABLE IF EXISTS test")
+ client.execute("CREATE TABLE test (x Int32) ENGINE = Memory")
+ client.execute("INSERT INTO test (x) VALUES", [{"x": 100}])
+ client.execute("INSERT INTO test (x) VALUES", [[170], [200]])
+
+ res = client.execute(
+ "SELECT sum(x) FROM test WHERE x > %(minv)i", {"minv": 150}
+ )
+ assert res[0][0] == 370
+
+ (event,) = events
+
+ expected_spans = [
+ {
+ "op": "db",
+ "origin": "auto.db.clickhouse_driver",
+ "description": "DROP TABLE IF EXISTS test",
+ "data": {
+ "db.system": "clickhouse",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ "db.result": [],
+ },
+ "same_process_as_parent": True,
+ "trace_id": transaction_trace_id,
+ "parent_span_id": transaction_span_id,
+ },
+ {
+ "op": "db",
+ "origin": "auto.db.clickhouse_driver",
+ "description": "CREATE TABLE test (x Int32) ENGINE = Memory",
+ "data": {
+ "db.system": "clickhouse",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ "db.result": [],
+ },
+ "same_process_as_parent": True,
+ "trace_id": transaction_trace_id,
+ "parent_span_id": transaction_span_id,
+ },
+ {
+ "op": "db",
+ "origin": "auto.db.clickhouse_driver",
+ "description": "INSERT INTO test (x) VALUES",
+ "data": {
+ "db.system": "clickhouse",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ "db.params": [{"x": 100}],
+ },
+ "same_process_as_parent": True,
+ "trace_id": transaction_trace_id,
+ "parent_span_id": transaction_span_id,
+ },
+ {
+ "op": "db",
+ "origin": "auto.db.clickhouse_driver",
+ "description": "INSERT INTO test (x) VALUES",
+ "data": {
+ "db.system": "clickhouse",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ "db.params": [[170], [200]],
+ },
+ "same_process_as_parent": True,
+ "trace_id": transaction_trace_id,
+ "parent_span_id": transaction_span_id,
+ },
+ {
+ "op": "db",
+ "origin": "auto.db.clickhouse_driver",
+ "description": "SELECT sum(x) FROM test WHERE x > 150",
+ "data": {
+ "db.system": "clickhouse",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ "db.params": {"minv": 150},
+ "db.result": [[370]],
+ },
+ "same_process_as_parent": True,
+ "trace_id": transaction_trace_id,
+ "parent_span_id": transaction_span_id,
+ },
+ ]
+
+ if not EXPECT_PARAMS_IN_SELECT:
+ expected_spans[-1]["data"].pop("db.params", None)
+
+ for span in expected_spans:
+ span["data"] = ApproxDict(span["data"])
+
+ for span in event["spans"]:
+ span.pop("span_id", None)
+ span.pop("start_timestamp", None)
+ span.pop("timestamp", None)
+
+ assert event["spans"] == expected_spans
+
+
+def test_clickhouse_dbapi_breadcrumbs(sentry_init, capture_events) -> None:
+ sentry_init(
+ integrations=[ClickhouseDriverIntegration()],
+ )
+ events = capture_events()
+
+ conn = connect("clickhouse://localhost")
+ cursor = conn.cursor()
+ cursor.execute("DROP TABLE IF EXISTS test")
+ cursor.execute("CREATE TABLE test (x Int32) ENGINE = Memory")
+ cursor.executemany("INSERT INTO test (x) VALUES", [{"x": 100}])
+ cursor.executemany("INSERT INTO test (x) VALUES", [[170], [200]])
+ cursor.execute("SELECT sum(x) FROM test WHERE x > %(minv)i", {"minv": 150})
+ res = cursor.fetchall()
+
+ assert res[0][0] == 370
+
+ capture_message("hi")
+
+ (event,) = events
+
+ expected_breadcrumbs = [
+ {
+ "category": "query",
+ "data": {
+ "db.system": "clickhouse",
+ "db.driver.name": "clickhouse-driver",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ },
+ "message": "DROP TABLE IF EXISTS test",
+ "type": "default",
+ },
+ {
+ "category": "query",
+ "data": {
+ "db.system": "clickhouse",
+ "db.driver.name": "clickhouse-driver",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ },
+ "message": "CREATE TABLE test (x Int32) ENGINE = Memory",
+ "type": "default",
+ },
+ {
+ "category": "query",
+ "data": {
+ "db.system": "clickhouse",
+ "db.driver.name": "clickhouse-driver",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ },
+ "message": "INSERT INTO test (x) VALUES",
+ "type": "default",
+ },
+ {
+ "category": "query",
+ "data": {
+ "db.system": "clickhouse",
+ "db.driver.name": "clickhouse-driver",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ },
+ "message": "INSERT INTO test (x) VALUES",
+ "type": "default",
+ },
+ {
+ "category": "query",
+ "data": {
+ "db.system": "clickhouse",
+ "db.driver.name": "clickhouse-driver",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ },
+ "message": "SELECT sum(x) FROM test WHERE x > 150",
+ "type": "default",
+ },
+ ]
+
+ if not EXPECT_PARAMS_IN_SELECT:
+ expected_breadcrumbs[-1]["data"].pop("db.params", None)
+
+ for crumb in expected_breadcrumbs:
+ crumb["data"] = ApproxDict(crumb["data"])
+
+ for crumb in event["breadcrumbs"]["values"]:
+ crumb.pop("timestamp", None)
+
+ assert event["breadcrumbs"]["values"] == expected_breadcrumbs
+
+
+def test_clickhouse_dbapi_breadcrumbs_with_pii(sentry_init, capture_events) -> None:
+ sentry_init(
+ integrations=[ClickhouseDriverIntegration()],
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ conn = connect("clickhouse://localhost")
+ cursor = conn.cursor()
+ cursor.execute("DROP TABLE IF EXISTS test")
+ cursor.execute("CREATE TABLE test (x Int32) ENGINE = Memory")
+ cursor.executemany("INSERT INTO test (x) VALUES", [{"x": 100}])
+ cursor.executemany("INSERT INTO test (x) VALUES", [[170], [200]])
+ cursor.execute("SELECT sum(x) FROM test WHERE x > %(minv)i", {"minv": 150})
+ res = cursor.fetchall()
+
+ assert res[0][0] == 370
+
+ capture_message("hi")
+
+ (event,) = events
+
+ expected_breadcrumbs = [
+ {
+ "category": "query",
+ "data": {
+ "db.system": "clickhouse",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ "db.result": [[], []],
+ },
+ "message": "DROP TABLE IF EXISTS test",
+ "type": "default",
+ },
+ {
+ "category": "query",
+ "data": {
+ "db.system": "clickhouse",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ "db.result": [[], []],
+ },
+ "message": "CREATE TABLE test (x Int32) ENGINE = Memory",
+ "type": "default",
+ },
+ {
+ "category": "query",
+ "data": {
+ "db.system": "clickhouse",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ "db.params": [{"x": 100}],
+ },
+ "message": "INSERT INTO test (x) VALUES",
+ "type": "default",
+ },
+ {
+ "category": "query",
+ "data": {
+ "db.system": "clickhouse",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ "db.params": [[170], [200]],
+ },
+ "message": "INSERT INTO test (x) VALUES",
+ "type": "default",
+ },
+ {
+ "category": "query",
+ "data": {
+ "db.system": "clickhouse",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ "db.params": {"minv": 150},
+ "db.result": [[["370"]], [["'sum(x)'", "'Int64'"]]],
+ },
+ "message": "SELECT sum(x) FROM test WHERE x > 150",
+ "type": "default",
+ },
+ ]
+
+ if not EXPECT_PARAMS_IN_SELECT:
+ expected_breadcrumbs[-1]["data"].pop("db.params", None)
+
+ for crumb in expected_breadcrumbs:
+ crumb["data"] = ApproxDict(crumb["data"])
+
+ for crumb in event["breadcrumbs"]["values"]:
+ crumb.pop("timestamp", None)
+
+ assert event["breadcrumbs"]["values"] == expected_breadcrumbs
+
+
+def test_clickhouse_dbapi_spans(sentry_init, capture_events, capture_envelopes) -> None:
+ sentry_init(
+ integrations=[ClickhouseDriverIntegration()],
+ _experiments={"record_sql_params": True},
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ transaction_trace_id = None
+ transaction_span_id = None
+
+ with start_transaction(name="test_clickhouse_transaction") as transaction:
+ transaction_trace_id = transaction.trace_id
+ transaction_span_id = transaction.span_id
+
+ conn = connect("clickhouse://localhost")
+ cursor = conn.cursor()
+ cursor.execute("DROP TABLE IF EXISTS test")
+ cursor.execute("CREATE TABLE test (x Int32) ENGINE = Memory")
+ cursor.executemany("INSERT INTO test (x) VALUES", [{"x": 100}])
+ cursor.executemany("INSERT INTO test (x) VALUES", [[170], [200]])
+ cursor.execute("SELECT sum(x) FROM test WHERE x > %(minv)i", {"minv": 150})
+ res = cursor.fetchall()
+
+ assert res[0][0] == 370
+
+ (event,) = events
+
+ expected_spans = [
+ {
+ "op": "db",
+ "origin": "auto.db.clickhouse_driver",
+ "description": "DROP TABLE IF EXISTS test",
+ "data": {
+ "db.system": "clickhouse",
+ "db.driver.name": "clickhouse-driver",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ },
+ "same_process_as_parent": True,
+ "trace_id": transaction_trace_id,
+ "parent_span_id": transaction_span_id,
+ },
+ {
+ "op": "db",
+ "origin": "auto.db.clickhouse_driver",
+ "description": "CREATE TABLE test (x Int32) ENGINE = Memory",
+ "data": {
+ "db.system": "clickhouse",
+ "db.driver.name": "clickhouse-driver",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ },
+ "same_process_as_parent": True,
+ "trace_id": transaction_trace_id,
+ "parent_span_id": transaction_span_id,
+ },
+ {
+ "op": "db",
+ "origin": "auto.db.clickhouse_driver",
+ "description": "INSERT INTO test (x) VALUES",
+ "data": {
+ "db.system": "clickhouse",
+ "db.driver.name": "clickhouse-driver",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ },
+ "same_process_as_parent": True,
+ "trace_id": transaction_trace_id,
+ "parent_span_id": transaction_span_id,
+ },
+ {
+ "op": "db",
+ "origin": "auto.db.clickhouse_driver",
+ "description": "INSERT INTO test (x) VALUES",
+ "data": {
+ "db.system": "clickhouse",
+ "db.driver.name": "clickhouse-driver",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ },
+ "same_process_as_parent": True,
+ "trace_id": transaction_trace_id,
+ "parent_span_id": transaction_span_id,
+ },
+ {
+ "op": "db",
+ "origin": "auto.db.clickhouse_driver",
+ "description": "SELECT sum(x) FROM test WHERE x > 150",
+ "data": {
+ "db.system": "clickhouse",
+ "db.driver.name": "clickhouse-driver",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ },
+ "same_process_as_parent": True,
+ "trace_id": transaction_trace_id,
+ "parent_span_id": transaction_span_id,
+ },
+ ]
+
+ if not EXPECT_PARAMS_IN_SELECT:
+ expected_spans[-1]["data"].pop("db.params", None)
+
+ for span in expected_spans:
+ span["data"] = ApproxDict(span["data"])
+
+ for span in event["spans"]:
+ span.pop("span_id", None)
+ span.pop("start_timestamp", None)
+ span.pop("timestamp", None)
+
+ assert event["spans"] == expected_spans
+
+
+def test_clickhouse_dbapi_spans_with_pii(
+ sentry_init, capture_events, capture_envelopes
+) -> None:
+ sentry_init(
+ integrations=[ClickhouseDriverIntegration()],
+ _experiments={"record_sql_params": True},
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ transaction_trace_id = None
+ transaction_span_id = None
+
+ with start_transaction(name="test_clickhouse_transaction") as transaction:
+ transaction_trace_id = transaction.trace_id
+ transaction_span_id = transaction.span_id
+
+ conn = connect("clickhouse://localhost")
+ cursor = conn.cursor()
+ cursor.execute("DROP TABLE IF EXISTS test")
+ cursor.execute("CREATE TABLE test (x Int32) ENGINE = Memory")
+ cursor.executemany("INSERT INTO test (x) VALUES", [{"x": 100}])
+ cursor.executemany("INSERT INTO test (x) VALUES", [[170], [200]])
+ cursor.execute("SELECT sum(x) FROM test WHERE x > %(minv)i", {"minv": 150})
+ res = cursor.fetchall()
+
+ assert res[0][0] == 370
+
+ (event,) = events
+
+ expected_spans = [
+ {
+ "op": "db",
+ "origin": "auto.db.clickhouse_driver",
+ "description": "DROP TABLE IF EXISTS test",
+ "data": {
+ "db.system": "clickhouse",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ "db.result": [[], []],
+ },
+ "same_process_as_parent": True,
+ "trace_id": transaction_trace_id,
+ "parent_span_id": transaction_span_id,
+ },
+ {
+ "op": "db",
+ "origin": "auto.db.clickhouse_driver",
+ "description": "CREATE TABLE test (x Int32) ENGINE = Memory",
+ "data": {
+ "db.system": "clickhouse",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ "db.result": [[], []],
+ },
+ "same_process_as_parent": True,
+ "trace_id": transaction_trace_id,
+ "parent_span_id": transaction_span_id,
+ },
+ {
+ "op": "db",
+ "origin": "auto.db.clickhouse_driver",
+ "description": "INSERT INTO test (x) VALUES",
+ "data": {
+ "db.system": "clickhouse",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ "db.params": [{"x": 100}],
+ },
+ "same_process_as_parent": True,
+ "trace_id": transaction_trace_id,
+ "parent_span_id": transaction_span_id,
+ },
+ {
+ "op": "db",
+ "origin": "auto.db.clickhouse_driver",
+ "description": "INSERT INTO test (x) VALUES",
+ "data": {
+ "db.system": "clickhouse",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ "db.params": [[170], [200]],
+ },
+ "same_process_as_parent": True,
+ "trace_id": transaction_trace_id,
+ "parent_span_id": transaction_span_id,
+ },
+ {
+ "op": "db",
+ "origin": "auto.db.clickhouse_driver",
+ "description": "SELECT sum(x) FROM test WHERE x > 150",
+ "data": {
+ "db.system": "clickhouse",
+ "db.name": "",
+ "db.user": "default",
+ "server.address": "localhost",
+ "server.port": 9000,
+ "db.params": {"minv": 150},
+ "db.result": [[[370]], [["sum(x)", "Int64"]]],
+ },
+ "same_process_as_parent": True,
+ "trace_id": transaction_trace_id,
+ "parent_span_id": transaction_span_id,
+ },
+ ]
+
+ if not EXPECT_PARAMS_IN_SELECT:
+ expected_spans[-1]["data"].pop("db.params", None)
+
+ for span in expected_spans:
+ span["data"] = ApproxDict(span["data"])
+
+ for span in event["spans"]:
+ span.pop("span_id", None)
+ span.pop("start_timestamp", None)
+ span.pop("timestamp", None)
+
+ assert event["spans"] == expected_spans
+
+
+def test_span_origin(sentry_init, capture_events, capture_envelopes) -> None:
+ sentry_init(
+ integrations=[ClickhouseDriverIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ with start_transaction(name="test_clickhouse_transaction"):
+ conn = connect("clickhouse://localhost")
+ cursor = conn.cursor()
+ cursor.execute("SELECT 1")
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+ assert event["spans"][0]["origin"] == "auto.db.clickhouse_driver"
diff --git a/tests/integrations/cloud_resource_context/test_cloud_resource_context.py b/tests/integrations/cloud_resource_context/test_cloud_resource_context.py
index b1efd97f3f..49732b00a5 100644
--- a/tests/integrations/cloud_resource_context/test_cloud_resource_context.py
+++ b/tests/integrations/cloud_resource_context/test_cloud_resource_context.py
@@ -1,8 +1,8 @@
import json
+from unittest import mock
+from unittest.mock import MagicMock
import pytest
-import mock
-from mock import MagicMock
from sentry_sdk.integrations.cloud_resource_context import (
CLOUD_PLATFORM,
@@ -27,16 +27,11 @@
"version": "2017-09-30",
}
-try:
- # Python 3
- AWS_EC2_EXAMPLE_IMDSv2_PAYLOAD_BYTES = bytes(
- json.dumps(AWS_EC2_EXAMPLE_IMDSv2_PAYLOAD), "utf-8"
- )
-except TypeError:
- # Python 2
- AWS_EC2_EXAMPLE_IMDSv2_PAYLOAD_BYTES = bytes(
- json.dumps(AWS_EC2_EXAMPLE_IMDSv2_PAYLOAD)
- ).encode("utf-8")
+
+AWS_EC2_EXAMPLE_IMDSv2_PAYLOAD_BYTES = bytes(
+ json.dumps(AWS_EC2_EXAMPLE_IMDSv2_PAYLOAD), "utf-8"
+)
+
GCP_GCE_EXAMPLE_METADATA_PLAYLOAD = {
"instance": {
@@ -136,7 +131,7 @@ def test_is_aws_ok():
CloudResourceContextIntegration.http.request = MagicMock(return_value=response)
assert CloudResourceContextIntegration._is_aws() is True
- assert CloudResourceContextIntegration.aws_token == b"something"
+ assert CloudResourceContextIntegration.aws_token == "something"
CloudResourceContextIntegration.http.request = MagicMock(
side_effect=Exception("Test")
@@ -399,7 +394,17 @@ def test_setup_once(
else:
fake_set_context.assert_not_called()
- if warning_called:
- fake_warning.assert_called_once()
- else:
- fake_warning.assert_not_called()
+ def invalid_value_warning_calls():
+ """
+ Iterator that yields True if the warning was called with the expected message.
+ Written as a generator function, rather than a list comprehension, to allow
+ us to handle exceptions that might be raised during the iteration if the
+ warning call was not as expected.
+ """
+ for call in fake_warning.call_args_list:
+ try:
+ yield call[0][0].startswith("Invalid value for cloud_provider:")
+ except (IndexError, KeyError, TypeError, AttributeError):
+ ...
+
+ assert warning_called == any(invalid_value_warning_calls())
diff --git a/tests/integrations/cohere/__init__.py b/tests/integrations/cohere/__init__.py
new file mode 100644
index 0000000000..3484a6dc41
--- /dev/null
+++ b/tests/integrations/cohere/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("cohere")
diff --git a/tests/integrations/cohere/test_cohere.py b/tests/integrations/cohere/test_cohere.py
new file mode 100644
index 0000000000..9ff56ed697
--- /dev/null
+++ b/tests/integrations/cohere/test_cohere.py
@@ -0,0 +1,304 @@
+import json
+
+import httpx
+import pytest
+from cohere import Client, ChatMessage
+
+from sentry_sdk import start_transaction
+from sentry_sdk.consts import SPANDATA
+from sentry_sdk.integrations.cohere import CohereIntegration
+
+from unittest import mock # python 3.3 and above
+from httpx import Client as HTTPXClient
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [(True, True), (True, False), (False, True), (False, False)],
+)
+def test_nonstreaming_chat(
+ sentry_init, capture_events, send_default_pii, include_prompts
+):
+ sentry_init(
+ integrations=[CohereIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ client = Client(api_key="z")
+ HTTPXClient.request = mock.Mock(
+ return_value=httpx.Response(
+ 200,
+ json={
+ "text": "the model response",
+ "meta": {
+ "billed_units": {
+ "output_tokens": 10,
+ "input_tokens": 20,
+ }
+ },
+ },
+ )
+ )
+
+ with start_transaction(name="cohere tx"):
+ response = client.chat(
+ model="some-model",
+ chat_history=[ChatMessage(role="SYSTEM", message="some context")],
+ message="hello",
+ ).text
+
+ assert response == "the model response"
+ tx = events[0]
+ assert tx["type"] == "transaction"
+ span = tx["spans"][0]
+ assert span["op"] == "ai.chat_completions.create.cohere"
+ assert span["data"][SPANDATA.AI_MODEL_ID] == "some-model"
+
+ if send_default_pii and include_prompts:
+ assert (
+ '{"role": "system", "content": "some context"}'
+ in span["data"][SPANDATA.AI_INPUT_MESSAGES]
+ )
+ assert (
+ '{"role": "user", "content": "hello"}'
+ in span["data"][SPANDATA.AI_INPUT_MESSAGES]
+ )
+ assert "the model response" in span["data"][SPANDATA.AI_RESPONSES]
+ else:
+ assert SPANDATA.AI_INPUT_MESSAGES not in span["data"]
+ assert SPANDATA.AI_RESPONSES not in span["data"]
+
+ assert span["data"]["gen_ai.usage.output_tokens"] == 10
+ assert span["data"]["gen_ai.usage.input_tokens"] == 20
+ assert span["data"]["gen_ai.usage.total_tokens"] == 30
+
+
+# noinspection PyTypeChecker
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [(True, True), (True, False), (False, True), (False, False)],
+)
+def test_streaming_chat(sentry_init, capture_events, send_default_pii, include_prompts):
+ sentry_init(
+ integrations=[CohereIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ client = Client(api_key="z")
+ HTTPXClient.send = mock.Mock(
+ return_value=httpx.Response(
+ 200,
+ content="\n".join(
+ [
+ json.dumps({"event_type": "text-generation", "text": "the model "}),
+ json.dumps({"event_type": "text-generation", "text": "response"}),
+ json.dumps(
+ {
+ "event_type": "stream-end",
+ "finish_reason": "COMPLETE",
+ "response": {
+ "text": "the model response",
+ "meta": {
+ "billed_units": {
+ "output_tokens": 10,
+ "input_tokens": 20,
+ }
+ },
+ },
+ }
+ ),
+ ]
+ ),
+ )
+ )
+
+ with start_transaction(name="cohere tx"):
+ responses = list(
+ client.chat_stream(
+ model="some-model",
+ chat_history=[ChatMessage(role="SYSTEM", message="some context")],
+ message="hello",
+ )
+ )
+ response_string = responses[-1].response.text
+
+ assert response_string == "the model response"
+ tx = events[0]
+ assert tx["type"] == "transaction"
+ span = tx["spans"][0]
+ assert span["op"] == "ai.chat_completions.create.cohere"
+ assert span["data"][SPANDATA.AI_MODEL_ID] == "some-model"
+
+ if send_default_pii and include_prompts:
+ assert (
+ '{"role": "system", "content": "some context"}'
+ in span["data"][SPANDATA.AI_INPUT_MESSAGES]
+ )
+ assert (
+ '{"role": "user", "content": "hello"}'
+ in span["data"][SPANDATA.AI_INPUT_MESSAGES]
+ )
+ assert "the model response" in span["data"][SPANDATA.AI_RESPONSES]
+ else:
+ assert SPANDATA.AI_INPUT_MESSAGES not in span["data"]
+ assert SPANDATA.AI_RESPONSES not in span["data"]
+
+ assert span["data"]["gen_ai.usage.output_tokens"] == 10
+ assert span["data"]["gen_ai.usage.input_tokens"] == 20
+ assert span["data"]["gen_ai.usage.total_tokens"] == 30
+
+
+def test_bad_chat(sentry_init, capture_events):
+ sentry_init(integrations=[CohereIntegration()], traces_sample_rate=1.0)
+ events = capture_events()
+
+ client = Client(api_key="z")
+ HTTPXClient.request = mock.Mock(
+ side_effect=httpx.HTTPError("API rate limit reached")
+ )
+ with pytest.raises(httpx.HTTPError):
+ client.chat(model="some-model", message="hello")
+
+ (event,) = events
+ assert event["level"] == "error"
+
+
+def test_span_status_error(sentry_init, capture_events):
+ sentry_init(integrations=[CohereIntegration()], traces_sample_rate=1.0)
+ events = capture_events()
+
+ with start_transaction(name="test"):
+ client = Client(api_key="z")
+ HTTPXClient.request = mock.Mock(
+ side_effect=httpx.HTTPError("API rate limit reached")
+ )
+ with pytest.raises(httpx.HTTPError):
+ client.chat(model="some-model", message="hello")
+
+ (error, transaction) = events
+ assert error["level"] == "error"
+ assert transaction["spans"][0]["status"] == "internal_error"
+ assert transaction["spans"][0]["tags"]["status"] == "internal_error"
+ assert transaction["contexts"]["trace"]["status"] == "internal_error"
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [(True, True), (True, False), (False, True), (False, False)],
+)
+def test_embed(sentry_init, capture_events, send_default_pii, include_prompts):
+ sentry_init(
+ integrations=[CohereIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ client = Client(api_key="z")
+ HTTPXClient.request = mock.Mock(
+ return_value=httpx.Response(
+ 200,
+ json={
+ "response_type": "embeddings_floats",
+ "id": "1",
+ "texts": ["hello"],
+ "embeddings": [[1.0, 2.0, 3.0]],
+ "meta": {
+ "billed_units": {
+ "input_tokens": 10,
+ }
+ },
+ },
+ )
+ )
+
+ with start_transaction(name="cohere tx"):
+ response = client.embed(texts=["hello"], model="text-embedding-3-large")
+
+ assert len(response.embeddings[0]) == 3
+
+ tx = events[0]
+ assert tx["type"] == "transaction"
+ span = tx["spans"][0]
+ assert span["op"] == "ai.embeddings.create.cohere"
+ if send_default_pii and include_prompts:
+ assert "hello" in span["data"][SPANDATA.AI_INPUT_MESSAGES]
+ else:
+ assert SPANDATA.AI_INPUT_MESSAGES not in span["data"]
+
+ assert span["data"]["gen_ai.usage.input_tokens"] == 10
+ assert span["data"]["gen_ai.usage.total_tokens"] == 10
+
+
+def test_span_origin_chat(sentry_init, capture_events):
+ sentry_init(
+ integrations=[CohereIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client = Client(api_key="z")
+ HTTPXClient.request = mock.Mock(
+ return_value=httpx.Response(
+ 200,
+ json={
+ "text": "the model response",
+ "meta": {
+ "billed_units": {
+ "output_tokens": 10,
+ "input_tokens": 20,
+ }
+ },
+ },
+ )
+ )
+
+ with start_transaction(name="cohere tx"):
+ client.chat(
+ model="some-model",
+ chat_history=[ChatMessage(role="SYSTEM", message="some context")],
+ message="hello",
+ ).text
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+ assert event["spans"][0]["origin"] == "auto.ai.cohere"
+
+
+def test_span_origin_embed(sentry_init, capture_events):
+ sentry_init(
+ integrations=[CohereIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client = Client(api_key="z")
+ HTTPXClient.request = mock.Mock(
+ return_value=httpx.Response(
+ 200,
+ json={
+ "response_type": "embeddings_floats",
+ "id": "1",
+ "texts": ["hello"],
+ "embeddings": [[1.0, 2.0, 3.0]],
+ "meta": {
+ "billed_units": {
+ "input_tokens": 10,
+ }
+ },
+ },
+ )
+ )
+
+ with start_transaction(name="cohere tx"):
+ client.embed(texts=["hello"], model="text-embedding-3-large")
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+ assert event["spans"][0]["origin"] == "auto.ai.cohere"
diff --git a/tests/integrations/conftest.py b/tests/integrations/conftest.py
index cffb278d70..7ac43b0efe 100644
--- a/tests/integrations/conftest.py
+++ b/tests/integrations/conftest.py
@@ -6,16 +6,50 @@
def capture_exceptions(monkeypatch):
def inner():
errors = set()
- old_capture_event = sentry_sdk.Hub.capture_event
+ old_capture_event_hub = sentry_sdk.Hub.capture_event
+ old_capture_event_scope = sentry_sdk.Scope.capture_event
- def capture_event(self, event, hint=None):
+ def capture_event_hub(self, event, hint=None, scope=None):
+ """
+ Can be removed when we remove push_scope and the Hub from the SDK.
+ """
if hint:
if "exc_info" in hint:
error = hint["exc_info"][1]
errors.add(error)
- return old_capture_event(self, event, hint=hint)
+ return old_capture_event_hub(self, event, hint=hint, scope=scope)
+
+ def capture_event_scope(self, event, hint=None, scope=None):
+ if hint:
+ if "exc_info" in hint:
+ error = hint["exc_info"][1]
+ errors.add(error)
+ return old_capture_event_scope(self, event, hint=hint, scope=scope)
+
+ monkeypatch.setattr(sentry_sdk.Hub, "capture_event", capture_event_hub)
+ monkeypatch.setattr(sentry_sdk.Scope, "capture_event", capture_event_scope)
- monkeypatch.setattr(sentry_sdk.Hub, "capture_event", capture_event)
return errors
return inner
+
+
+parametrize_test_configurable_status_codes = pytest.mark.parametrize(
+ ("failed_request_status_codes", "status_code", "expected_error"),
+ (
+ (None, 500, True),
+ (None, 400, False),
+ ({500, 501}, 500, True),
+ ({500, 501}, 401, False),
+ ({*range(400, 500)}, 401, True),
+ ({*range(400, 500)}, 500, False),
+ ({*range(400, 600)}, 300, False),
+ ({*range(400, 600)}, 403, True),
+ ({*range(400, 600)}, 503, True),
+ ({*range(400, 403), 500, 501}, 401, True),
+ ({*range(400, 403), 500, 501}, 405, False),
+ ({*range(400, 403), 500, 501}, 501, True),
+ ({*range(400, 403), 500, 501}, 503, False),
+ (set(), 500, False),
+ ),
+)
diff --git a/tests/integrations/django/__init__.py b/tests/integrations/django/__init__.py
index d2555a8d48..41d72f92a5 100644
--- a/tests/integrations/django/__init__.py
+++ b/tests/integrations/django/__init__.py
@@ -1,3 +1,9 @@
+import os
+import sys
import pytest
-django = pytest.importorskip("django")
+pytest.importorskip("django")
+
+# Load `django_helpers` into the module search path to test query source path names relative to module. See
+# `test_query_source_with_module_in_search_path`
+sys.path.insert(0, os.path.join(os.path.dirname(__file__)))
diff --git a/tests/integrations/django/asgi/image.png b/tests/integrations/django/asgi/image.png
new file mode 100644
index 0000000000..8db277a9fc
Binary files /dev/null and b/tests/integrations/django/asgi/image.png differ
diff --git a/tests/integrations/django/asgi/test_asgi.py b/tests/integrations/django/asgi/test_asgi.py
index d7ea06d85a..f956d12f82 100644
--- a/tests/integrations/django/asgi/test_asgi.py
+++ b/tests/integrations/django/asgi/test_asgi.py
@@ -1,16 +1,24 @@
+import base64
+import sys
import json
+import inspect
+import asyncio
+import os
+from unittest import mock
import django
import pytest
from channels.testing import HttpCommunicator
from sentry_sdk import capture_message
from sentry_sdk.integrations.django import DjangoIntegration
+from sentry_sdk.integrations.django.asgi import _asgi_middleware_mixin_factory
from tests.integrations.django.myapp.asgi import channels_application
try:
- from unittest import mock # python 3.3 and above
+ from django.urls import reverse
except ImportError:
- import mock # python < 3.3
+ from django.core.urlresolvers import reverse
+
APPS = [channels_application]
if django.VERSION >= (3, 0):
@@ -21,13 +29,38 @@
@pytest.mark.parametrize("application", APPS)
@pytest.mark.asyncio
+@pytest.mark.forked
+@pytest.mark.skipif(
+ django.VERSION < (3, 0), reason="Django ASGI support shipped in 3.0"
+)
async def test_basic(sentry_init, capture_events, application):
- sentry_init(integrations=[DjangoIntegration()], send_default_pii=True)
+ sentry_init(
+ integrations=[DjangoIntegration()],
+ send_default_pii=True,
+ )
events = capture_events()
- comm = HttpCommunicator(application, "GET", "/view-exc?test=query")
- response = await comm.get_response()
+ import channels # type: ignore[import-not-found]
+
+ if (
+ sys.version_info < (3, 9)
+ and channels.__version__ < "4.0.0"
+ and django.VERSION >= (3, 0)
+ and django.VERSION < (4, 0)
+ ):
+ # We emit a UserWarning for channels 2.x and 3.x on Python 3.8 and older
+ # because the async support was not really good back then and there is a known issue.
+ # See the TreadingIntegration for details.
+ with pytest.warns(UserWarning):
+ comm = HttpCommunicator(application, "GET", "/view-exc?test=query")
+ response = await comm.get_response()
+ await comm.wait()
+ else:
+ comm = HttpCommunicator(application, "GET", "/view-exc?test=query")
+ response = await comm.get_response()
+ await comm.wait()
+
assert response["status"] == 500
(event,) = events
@@ -53,16 +86,22 @@ async def test_basic(sentry_init, capture_events, application):
@pytest.mark.parametrize("application", APPS)
@pytest.mark.asyncio
+@pytest.mark.forked
@pytest.mark.skipif(
django.VERSION < (3, 1), reason="async views have been introduced in Django 3.1"
)
async def test_async_views(sentry_init, capture_events, application):
- sentry_init(integrations=[DjangoIntegration()], send_default_pii=True)
+ sentry_init(
+ integrations=[DjangoIntegration()],
+ send_default_pii=True,
+ )
events = capture_events()
comm = HttpCommunicator(application, "GET", "/async_message")
response = await comm.get_response()
+ await comm.wait()
+
assert response["status"] == 200
(event,) = events
@@ -79,57 +118,81 @@ async def test_async_views(sentry_init, capture_events, application):
@pytest.mark.parametrize("application", APPS)
@pytest.mark.parametrize("endpoint", ["/sync/thread_ids", "/async/thread_ids"])
+@pytest.mark.parametrize("middleware_spans", [False, True])
@pytest.mark.asyncio
+@pytest.mark.forked
@pytest.mark.skipif(
django.VERSION < (3, 1), reason="async views have been introduced in Django 3.1"
)
async def test_active_thread_id(
- sentry_init, capture_envelopes, teardown_profiling, endpoint, application
+ sentry_init,
+ capture_envelopes,
+ teardown_profiling,
+ endpoint,
+ application,
+ middleware_spans,
):
- with mock.patch("sentry_sdk.profiler.PROFILE_MINIMUM_SAMPLES", 0):
+ with mock.patch(
+ "sentry_sdk.profiler.transaction_profiler.PROFILE_MINIMUM_SAMPLES", 0
+ ):
sentry_init(
- integrations=[DjangoIntegration()],
+ integrations=[DjangoIntegration(middleware_spans=middleware_spans)],
traces_sample_rate=1.0,
- _experiments={"profiles_sample_rate": 1.0},
+ profiles_sample_rate=1.0,
)
envelopes = capture_envelopes()
comm = HttpCommunicator(application, "GET", endpoint)
response = await comm.get_response()
+ await comm.wait()
+
assert response["status"] == 200, response["body"]
- await comm.wait()
+ assert len(envelopes) == 1
- data = json.loads(response["body"])
+ profiles = [item for item in envelopes[0].items if item.type == "profile"]
+ assert len(profiles) == 1
- envelopes = [envelope for envelope in envelopes]
- assert len(envelopes) == 1
+ data = json.loads(response["body"])
- profiles = [item for item in envelopes[0].items if item.type == "profile"]
- assert len(profiles) == 1
+ for item in profiles:
+ transactions = item.payload.json["transactions"]
+ assert len(transactions) == 1
+ assert str(data["active"]) == transactions[0]["active_thread_id"]
- for profile in profiles:
- transactions = profile.payload.json["transactions"]
- assert len(transactions) == 1
- assert str(data["active"]) == transactions[0]["active_thread_id"]
+ transactions = [item for item in envelopes[0].items if item.type == "transaction"]
+ assert len(transactions) == 1
+
+ for item in transactions:
+ transaction = item.payload.json
+ trace_context = transaction["contexts"]["trace"]
+ assert str(data["active"]) == trace_context["data"]["thread.id"]
@pytest.mark.asyncio
+@pytest.mark.forked
@pytest.mark.skipif(
django.VERSION < (3, 1), reason="async views have been introduced in Django 3.1"
)
-async def test_async_views_concurrent_execution(sentry_init, capture_events, settings):
+async def test_async_views_concurrent_execution(sentry_init, settings):
import asyncio
import time
settings.MIDDLEWARE = []
asgi_application.load_middleware(is_async=True)
- sentry_init(integrations=[DjangoIntegration()], send_default_pii=True)
+ sentry_init(
+ integrations=[DjangoIntegration()],
+ send_default_pii=True,
+ )
- comm = HttpCommunicator(asgi_application, "GET", "/my_async_view")
- comm2 = HttpCommunicator(asgi_application, "GET", "/my_async_view")
+ comm = HttpCommunicator(
+ asgi_application, "GET", "/my_async_view"
+ ) # sleeps for 1 second
+ comm2 = HttpCommunicator(
+ asgi_application, "GET", "/my_async_view"
+ ) # sleeps for 1 second
loop = asyncio.get_event_loop()
@@ -145,15 +208,18 @@ async def test_async_views_concurrent_execution(sentry_init, capture_events, set
assert resp1.result()["status"] == 200
assert resp2.result()["status"] == 200
- assert end - start < 1.5
+ assert (
+ end - start < 2
+ ) # it takes less than 2 seconds so it was ececuting concurrently
@pytest.mark.asyncio
+@pytest.mark.forked
@pytest.mark.skipif(
django.VERSION < (3, 1), reason="async views have been introduced in Django 3.1"
)
async def test_async_middleware_that_is_function_concurrent_execution(
- sentry_init, capture_events, settings
+ sentry_init, settings
):
import asyncio
import time
@@ -163,10 +229,17 @@ async def test_async_middleware_that_is_function_concurrent_execution(
]
asgi_application.load_middleware(is_async=True)
- sentry_init(integrations=[DjangoIntegration()], send_default_pii=True)
+ sentry_init(
+ integrations=[DjangoIntegration()],
+ send_default_pii=True,
+ )
- comm = HttpCommunicator(asgi_application, "GET", "/my_async_view")
- comm2 = HttpCommunicator(asgi_application, "GET", "/my_async_view")
+ comm = HttpCommunicator(
+ asgi_application, "GET", "/my_async_view"
+ ) # sleeps for 1 second
+ comm2 = HttpCommunicator(
+ asgi_application, "GET", "/my_async_view"
+ ) # sleeps for 1 second
loop = asyncio.get_event_loop()
@@ -182,10 +255,13 @@ async def test_async_middleware_that_is_function_concurrent_execution(
assert resp1.result()["status"] == 200
assert resp2.result()["status"] == 200
- assert end - start < 1.5
+ assert (
+ end - start < 2
+ ) # it takes less than 2 seconds so it was ececuting concurrently
@pytest.mark.asyncio
+@pytest.mark.forked
@pytest.mark.skipif(
django.VERSION < (3, 1), reason="async views have been introduced in Django 3.1"
)
@@ -208,13 +284,13 @@ async def test_async_middleware_spans(
events = capture_events()
- comm = HttpCommunicator(asgi_application, "GET", "/async_message")
+ comm = HttpCommunicator(asgi_application, "GET", "/simple_async_view")
response = await comm.get_response()
- assert response["status"] == 200
-
await comm.wait()
- message, transaction = events
+ assert response["status"] == 200
+
+ (transaction,) = events
assert (
render_span_tree(transaction)
@@ -227,8 +303,437 @@ async def test_async_middleware_spans(
- op="middleware.django": description="django.middleware.csrf.CsrfViewMiddleware.__acall__"
- op="middleware.django": description="tests.integrations.django.myapp.settings.TestMiddleware.__acall__"
- op="middleware.django": description="django.middleware.csrf.CsrfViewMiddleware.process_view"
- - op="view.render": description="async_message"
+ - op="view.render": description="simple_async_view"
- op="event.django": description="django.db.close_old_connections"
- op="event.django": description="django.core.cache.close_caches"
- op="event.django": description="django.core.handlers.base.reset_urlconf\""""
)
+
+
+@pytest.mark.asyncio
+@pytest.mark.forked
+@pytest.mark.skipif(
+ django.VERSION < (3, 1), reason="async views have been introduced in Django 3.1"
+)
+async def test_has_trace_if_performance_enabled(sentry_init, capture_events):
+ sentry_init(
+ integrations=[DjangoIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ comm = HttpCommunicator(asgi_application, "GET", "/view-exc-with-msg")
+ response = await comm.get_response()
+ await comm.wait()
+
+ assert response["status"] == 500
+
+ (msg_event, error_event, transaction_event) = events
+
+ assert (
+ msg_event["contexts"]["trace"]["trace_id"]
+ == error_event["contexts"]["trace"]["trace_id"]
+ == transaction_event["contexts"]["trace"]["trace_id"]
+ )
+
+
+@pytest.mark.asyncio
+@pytest.mark.forked
+@pytest.mark.skipif(
+ django.VERSION < (3, 1), reason="async views have been introduced in Django 3.1"
+)
+async def test_has_trace_if_performance_disabled(sentry_init, capture_events):
+ sentry_init(
+ integrations=[DjangoIntegration()],
+ )
+
+ events = capture_events()
+
+ comm = HttpCommunicator(asgi_application, "GET", "/view-exc-with-msg")
+ response = await comm.get_response()
+ await comm.wait()
+
+ assert response["status"] == 500
+
+ (msg_event, error_event) = events
+
+ assert msg_event["contexts"]["trace"]
+ assert "trace_id" in msg_event["contexts"]["trace"]
+
+ assert error_event["contexts"]["trace"]
+ assert "trace_id" in error_event["contexts"]["trace"]
+ assert (
+ msg_event["contexts"]["trace"]["trace_id"]
+ == error_event["contexts"]["trace"]["trace_id"]
+ )
+
+
+@pytest.mark.asyncio
+@pytest.mark.forked
+@pytest.mark.skipif(
+ django.VERSION < (3, 1), reason="async views have been introduced in Django 3.1"
+)
+async def test_trace_from_headers_if_performance_enabled(sentry_init, capture_events):
+ sentry_init(
+ integrations=[DjangoIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ trace_id = "582b43a4192642f0b136d5159a501701"
+ sentry_trace_header = "{}-{}-{}".format(trace_id, "6e8f22c393e68f19", 1)
+
+ comm = HttpCommunicator(
+ asgi_application,
+ "GET",
+ "/view-exc-with-msg",
+ headers=[(b"sentry-trace", sentry_trace_header.encode())],
+ )
+ response = await comm.get_response()
+ await comm.wait()
+
+ assert response["status"] == 500
+
+ (msg_event, error_event, transaction_event) = events
+
+ assert msg_event["contexts"]["trace"]["trace_id"] == trace_id
+ assert error_event["contexts"]["trace"]["trace_id"] == trace_id
+ assert transaction_event["contexts"]["trace"]["trace_id"] == trace_id
+
+
+@pytest.mark.asyncio
+@pytest.mark.forked
+@pytest.mark.skipif(
+ django.VERSION < (3, 1), reason="async views have been introduced in Django 3.1"
+)
+async def test_trace_from_headers_if_performance_disabled(sentry_init, capture_events):
+ sentry_init(
+ integrations=[DjangoIntegration()],
+ )
+
+ events = capture_events()
+
+ trace_id = "582b43a4192642f0b136d5159a501701"
+ sentry_trace_header = "{}-{}-{}".format(trace_id, "6e8f22c393e68f19", 1)
+
+ comm = HttpCommunicator(
+ asgi_application,
+ "GET",
+ "/view-exc-with-msg",
+ headers=[(b"sentry-trace", sentry_trace_header.encode())],
+ )
+ response = await comm.get_response()
+ await comm.wait()
+
+ assert response["status"] == 500
+
+ (msg_event, error_event) = events
+
+ assert msg_event["contexts"]["trace"]["trace_id"] == trace_id
+ assert error_event["contexts"]["trace"]["trace_id"] == trace_id
+
+
+PICTURE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "image.png")
+BODY_FORM = """--fd721ef49ea403a6\r\nContent-Disposition: form-data; name="username"\r\n\r\nJane\r\n--fd721ef49ea403a6\r\nContent-Disposition: form-data; name="password"\r\n\r\nhello123\r\n--fd721ef49ea403a6\r\nContent-Disposition: form-data; name="photo"; filename="image.png"\r\nContent-Type: image/png\r\nContent-Transfer-Encoding: base64\r\n\r\n{{image_data}}\r\n--fd721ef49ea403a6--\r\n""".replace(
+ "{{image_data}}", base64.b64encode(open(PICTURE, "rb").read()).decode("utf-8")
+).encode("utf-8")
+BODY_FORM_CONTENT_LENGTH = str(len(BODY_FORM)).encode("utf-8")
+
+
+@pytest.mark.parametrize("application", APPS)
+@pytest.mark.parametrize(
+ "send_default_pii,method,headers,url_name,body,expected_data",
+ [
+ (
+ True,
+ "POST",
+ [(b"content-type", b"text/plain")],
+ "post_echo_async",
+ b"",
+ None,
+ ),
+ (
+ True,
+ "POST",
+ [(b"content-type", b"text/plain")],
+ "post_echo_async",
+ b"some raw text body",
+ "",
+ ),
+ (
+ True,
+ "POST",
+ [(b"content-type", b"application/json")],
+ "post_echo_async",
+ b'{"username":"xyz","password":"xyz"}',
+ {"username": "xyz", "password": "[Filtered]"},
+ ),
+ (
+ True,
+ "POST",
+ [(b"content-type", b"application/xml")],
+ "post_echo_async",
+ b'',
+ "",
+ ),
+ (
+ True,
+ "POST",
+ [
+ (b"content-type", b"multipart/form-data; boundary=fd721ef49ea403a6"),
+ (b"content-length", BODY_FORM_CONTENT_LENGTH),
+ ],
+ "post_echo_async",
+ BODY_FORM,
+ {"password": "[Filtered]", "photo": "", "username": "Jane"},
+ ),
+ (
+ False,
+ "POST",
+ [(b"content-type", b"text/plain")],
+ "post_echo_async",
+ b"",
+ None,
+ ),
+ (
+ False,
+ "POST",
+ [(b"content-type", b"text/plain")],
+ "post_echo_async",
+ b"some raw text body",
+ "",
+ ),
+ (
+ False,
+ "POST",
+ [(b"content-type", b"application/json")],
+ "post_echo_async",
+ b'{"username":"xyz","password":"xyz"}',
+ {"username": "xyz", "password": "[Filtered]"},
+ ),
+ (
+ False,
+ "POST",
+ [(b"content-type", b"application/xml")],
+ "post_echo_async",
+ b'',
+ "",
+ ),
+ (
+ False,
+ "POST",
+ [
+ (b"content-type", b"multipart/form-data; boundary=fd721ef49ea403a6"),
+ (b"content-length", BODY_FORM_CONTENT_LENGTH),
+ ],
+ "post_echo_async",
+ BODY_FORM,
+ {"password": "[Filtered]", "photo": "", "username": "Jane"},
+ ),
+ ],
+)
+@pytest.mark.asyncio
+@pytest.mark.forked
+@pytest.mark.skipif(
+ django.VERSION < (3, 1), reason="async views have been introduced in Django 3.1"
+)
+async def test_asgi_request_body(
+ sentry_init,
+ capture_envelopes,
+ application,
+ send_default_pii,
+ method,
+ headers,
+ url_name,
+ body,
+ expected_data,
+):
+ sentry_init(
+ integrations=[DjangoIntegration()],
+ send_default_pii=send_default_pii,
+ )
+
+ envelopes = capture_envelopes()
+
+ comm = HttpCommunicator(
+ application,
+ method=method,
+ headers=headers,
+ path=reverse(url_name),
+ body=body,
+ )
+ response = await comm.get_response()
+ await comm.wait()
+
+ assert response["status"] == 200
+ assert response["body"] == body
+
+ (envelope,) = envelopes
+ event = envelope.get_event()
+
+ if expected_data is not None:
+ assert event["request"]["data"] == expected_data
+ else:
+ assert "data" not in event["request"]
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(
+ sys.version_info >= (3, 12),
+ reason=(
+ "asyncio.iscoroutinefunction has been replaced in 3.12 by inspect.iscoroutinefunction"
+ ),
+)
+@pytest.mark.skipif(
+ django.VERSION < (3, 0), reason="Django ASGI support shipped in 3.0"
+)
+async def test_asgi_mixin_iscoroutinefunction_before_3_12():
+ sentry_asgi_mixin = _asgi_middleware_mixin_factory(lambda: None)
+
+ async def get_response(): ...
+
+ instance = sentry_asgi_mixin(get_response)
+ assert asyncio.iscoroutinefunction(instance)
+
+
+@pytest.mark.skipif(
+ sys.version_info >= (3, 12),
+ reason=(
+ "asyncio.iscoroutinefunction has been replaced in 3.12 by inspect.iscoroutinefunction"
+ ),
+)
+def test_asgi_mixin_iscoroutinefunction_when_not_async_before_3_12():
+ sentry_asgi_mixin = _asgi_middleware_mixin_factory(lambda: None)
+
+ def get_response(): ...
+
+ instance = sentry_asgi_mixin(get_response)
+ assert not asyncio.iscoroutinefunction(instance)
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(
+ sys.version_info < (3, 12),
+ reason=(
+ "asyncio.iscoroutinefunction has been replaced in 3.12 by inspect.iscoroutinefunction"
+ ),
+)
+async def test_asgi_mixin_iscoroutinefunction_after_3_12():
+ sentry_asgi_mixin = _asgi_middleware_mixin_factory(lambda: None)
+
+ async def get_response(): ...
+
+ instance = sentry_asgi_mixin(get_response)
+ assert inspect.iscoroutinefunction(instance)
+
+
+@pytest.mark.skipif(
+ sys.version_info < (3, 12),
+ reason=(
+ "asyncio.iscoroutinefunction has been replaced in 3.12 by inspect.iscoroutinefunction"
+ ),
+)
+def test_asgi_mixin_iscoroutinefunction_when_not_async_after_3_12():
+ sentry_asgi_mixin = _asgi_middleware_mixin_factory(lambda: None)
+
+ def get_response(): ...
+
+ instance = sentry_asgi_mixin(get_response)
+ assert not inspect.iscoroutinefunction(instance)
+
+
+@pytest.mark.parametrize("application", APPS)
+@pytest.mark.asyncio
+@pytest.mark.skipif(
+ django.VERSION < (3, 1), reason="async views have been introduced in Django 3.1"
+)
+async def test_async_view(sentry_init, capture_events, application):
+ sentry_init(
+ integrations=[DjangoIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ comm = HttpCommunicator(application, "GET", "/simple_async_view")
+ await comm.get_response()
+ await comm.wait()
+
+ (event,) = events
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "/simple_async_view"
+
+
+@pytest.mark.parametrize("application", APPS)
+@pytest.mark.asyncio
+@pytest.mark.skipif(
+ django.VERSION < (3, 0), reason="Django ASGI support shipped in 3.0"
+)
+async def test_transaction_http_method_default(
+ sentry_init, capture_events, application
+):
+ """
+ By default OPTIONS and HEAD requests do not create a transaction.
+ """
+ sentry_init(
+ integrations=[DjangoIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ comm = HttpCommunicator(application, "GET", "/simple_async_view")
+ await comm.get_response()
+ await comm.wait()
+
+ comm = HttpCommunicator(application, "OPTIONS", "/simple_async_view")
+ await comm.get_response()
+ await comm.wait()
+
+ comm = HttpCommunicator(application, "HEAD", "/simple_async_view")
+ await comm.get_response()
+ await comm.wait()
+
+ (event,) = events
+
+ assert len(events) == 1
+ assert event["request"]["method"] == "GET"
+
+
+@pytest.mark.parametrize("application", APPS)
+@pytest.mark.asyncio
+@pytest.mark.skipif(
+ django.VERSION < (3, 0), reason="Django ASGI support shipped in 3.0"
+)
+async def test_transaction_http_method_custom(sentry_init, capture_events, application):
+ sentry_init(
+ integrations=[
+ DjangoIntegration(
+ http_methods_to_capture=(
+ "OPTIONS",
+ "head",
+ ), # capitalization does not matter
+ )
+ ],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ comm = HttpCommunicator(application, "GET", "/simple_async_view")
+ await comm.get_response()
+ await comm.wait()
+
+ comm = HttpCommunicator(application, "OPTIONS", "/simple_async_view")
+ await comm.get_response()
+ await comm.wait()
+
+ comm = HttpCommunicator(application, "HEAD", "/simple_async_view")
+ await comm.get_response()
+ await comm.wait()
+
+ assert len(events) == 2
+
+ (event1, event2) = events
+ assert event1["request"]["method"] == "OPTIONS"
+ assert event2["request"]["method"] == "HEAD"
diff --git a/tests/integrations/django/django_helpers/__init__.py b/tests/integrations/django/django_helpers/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/integrations/django/django_helpers/views.py b/tests/integrations/django/django_helpers/views.py
new file mode 100644
index 0000000000..a5759a5199
--- /dev/null
+++ b/tests/integrations/django/django_helpers/views.py
@@ -0,0 +1,9 @@
+from django.contrib.auth.models import User
+from django.http import HttpResponse
+from django.views.decorators.csrf import csrf_exempt
+
+
+@csrf_exempt
+def postgres_select_orm(request, *args, **kwargs):
+ user = User.objects.using("postgres").all().first()
+ return HttpResponse("ok {}".format(user))
diff --git a/tests/integrations/django/myapp/custom_urls.py b/tests/integrations/django/myapp/custom_urls.py
index 6dfa2ed2f1..5b2a1e428b 100644
--- a/tests/integrations/django/myapp/custom_urls.py
+++ b/tests/integrations/django/myapp/custom_urls.py
@@ -13,7 +13,6 @@
1. Import the include() function: from django.urls import include, path
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
"""
-from __future__ import absolute_import
try:
from django.urls import path
diff --git a/tests/integrations/django/myapp/settings.py b/tests/integrations/django/myapp/settings.py
index cc4d249082..d70adf63ec 100644
--- a/tests/integrations/django/myapp/settings.py
+++ b/tests/integrations/django/myapp/settings.py
@@ -10,7 +10,6 @@
https://docs.djangoproject.com/en/2.0/ref/settings/
"""
-
# We shouldn't access settings while setting up integrations. Initialize SDK
# here to provoke any errors that might occur.
import sentry_sdk
@@ -18,16 +17,9 @@
sentry_sdk.init(integrations=[DjangoIntegration()])
-
import os
-try:
- # Django >= 1.10
- from django.utils.deprecation import MiddlewareMixin
-except ImportError:
- # Not required for Django <= 1.9, see:
- # https://docs.djangoproject.com/en/1.10/topics/http/middleware/#upgrading-pre-django-1-10-style-middleware
- MiddlewareMixin = object
+from django.utils.deprecation import MiddlewareMixin
# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -121,16 +113,26 @@ def middleware(request):
try:
import psycopg2 # noqa
+ db_engine = "django.db.backends.postgresql"
+ try:
+ from django.db.backends import postgresql # noqa: F401
+ except ImportError:
+ db_engine = "django.db.backends.postgresql_psycopg2"
+
DATABASES["postgres"] = {
- "ENGINE": "django.db.backends.postgresql_psycopg2",
- "NAME": os.environ["SENTRY_PYTHON_TEST_POSTGRES_NAME"],
- "USER": os.environ["SENTRY_PYTHON_TEST_POSTGRES_USER"],
- "PASSWORD": os.environ["SENTRY_PYTHON_TEST_POSTGRES_PASSWORD"],
- "HOST": "localhost",
- "PORT": 5432,
+ "ENGINE": db_engine,
+ "HOST": os.environ.get("SENTRY_PYTHON_TEST_POSTGRES_HOST", "localhost"),
+ "PORT": int(os.environ.get("SENTRY_PYTHON_TEST_POSTGRES_PORT", "5432")),
+ "USER": os.environ.get("SENTRY_PYTHON_TEST_POSTGRES_USER", "postgres"),
+ "PASSWORD": os.environ.get("SENTRY_PYTHON_TEST_POSTGRES_PASSWORD", "sentry"),
+ "NAME": os.environ.get(
+ "SENTRY_PYTHON_TEST_POSTGRES_NAME", f"myapp_db_{os.getpid()}"
+ ),
}
except (ImportError, KeyError):
- pass
+ from sentry_sdk.utils import logger
+
+ logger.warning("No psycopg2 found, testing with SQLite.")
# Password validation
diff --git a/tests/integrations/django/myapp/signals.py b/tests/integrations/django/myapp/signals.py
new file mode 100644
index 0000000000..3dab92b8d9
--- /dev/null
+++ b/tests/integrations/django/myapp/signals.py
@@ -0,0 +1,15 @@
+from django.core import signals
+from django.dispatch import receiver
+
+myapp_custom_signal = signals.Signal()
+myapp_custom_signal_silenced = signals.Signal()
+
+
+@receiver(myapp_custom_signal)
+def signal_handler(sender, **kwargs):
+ assert sender == "hello"
+
+
+@receiver(myapp_custom_signal_silenced)
+def signal_handler_silenced(sender, **kwargs):
+ assert sender == "hello"
diff --git a/tests/integrations/django/myapp/templates/trace_meta.html b/tests/integrations/django/myapp/templates/trace_meta.html
new file mode 100644
index 0000000000..139fd16101
--- /dev/null
+++ b/tests/integrations/django/myapp/templates/trace_meta.html
@@ -0,0 +1 @@
+{{ sentry_trace_meta }}
diff --git a/tests/integrations/django/myapp/urls.py b/tests/integrations/django/myapp/urls.py
index ee357c843b..26d5a1bf2c 100644
--- a/tests/integrations/django/myapp/urls.py
+++ b/tests/integrations/django/myapp/urls.py
@@ -13,7 +13,6 @@
1. Import the include() function: from django.urls import include, path
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
"""
-from __future__ import absolute_import
try:
from django.urls import path
@@ -25,9 +24,18 @@ def path(path, *args, **kwargs):
from . import views
+from django_helpers import views as helper_views
urlpatterns = [
path("view-exc", views.view_exc, name="view_exc"),
+ path("view-exc-with-msg", views.view_exc_with_msg, name="view_exc_with_msg"),
+ path("cached-view", views.cached_view, name="cached_view"),
+ path("not-cached-view", views.not_cached_view, name="not_cached_view"),
+ path(
+ "view-with-cached-template-fragment",
+ views.view_with_cached_template_fragment,
+ name="view_with_cached_template_fragment",
+ ),
path(
"read-body-and-view-exc",
views.read_body_and_view_exc,
@@ -35,6 +43,8 @@ def path(path, *args, **kwargs):
),
path("middleware-exc", views.message, name="middleware_exc"),
path("message", views.message, name="message"),
+ path("nomessage", views.nomessage, name="nomessage"),
+ path("view-with-signal", views.view_with_signal, name="view_with_signal"),
path("mylogin", views.mylogin, name="mylogin"),
path("classbased", views.ClassBasedView.as_view(), name="classbased"),
path("sentryclass", views.SentryClassBasedView(), name="sentryclass"),
@@ -47,7 +57,40 @@ def path(path, *args, **kwargs):
path("template-exc", views.template_exc, name="template_exc"),
path("template-test", views.template_test, name="template_test"),
path("template-test2", views.template_test2, name="template_test2"),
+ path("template-test3", views.template_test3, name="template_test3"),
+ path("template-test4", views.template_test4, name="template_test4"),
path("postgres-select", views.postgres_select, name="postgres_select"),
+ path("postgres-select-slow", views.postgres_select_orm, name="postgres_select_orm"),
+ path(
+ "postgres-insert-no-autocommit",
+ views.postgres_insert_orm_no_autocommit,
+ name="postgres_insert_orm_no_autocommit",
+ ),
+ path(
+ "postgres-insert-no-autocommit-rollback",
+ views.postgres_insert_orm_no_autocommit_rollback,
+ name="postgres_insert_orm_no_autocommit_rollback",
+ ),
+ path(
+ "postgres-insert-atomic",
+ views.postgres_insert_orm_atomic,
+ name="postgres_insert_orm_atomic",
+ ),
+ path(
+ "postgres-insert-atomic-rollback",
+ views.postgres_insert_orm_atomic_rollback,
+ name="postgres_insert_orm_atomic_rollback",
+ ),
+ path(
+ "postgres-insert-atomic-exception",
+ views.postgres_insert_orm_atomic_exception,
+ name="postgres_insert_orm_atomic_exception",
+ ),
+ path(
+ "postgres-select-slow-from-supplement",
+ helper_views.postgres_select_orm,
+ name="postgres_select_slow_from_supplement",
+ ),
path(
"permission-denied-exc",
views.permission_denied_exc,
@@ -59,6 +102,11 @@ def path(path, *args, **kwargs):
name="csrf_hello_not_exempt",
),
path("sync/thread_ids", views.thread_ids_sync, name="thread_ids_sync"),
+ path(
+ "send-myapp-custom-signal",
+ views.send_myapp_custom_signal,
+ name="send_myapp_custom_signal",
+ ),
]
# async views
@@ -68,11 +116,21 @@ def path(path, *args, **kwargs):
if views.my_async_view is not None:
urlpatterns.append(path("my_async_view", views.my_async_view, name="my_async_view"))
+if views.my_async_view is not None:
+ urlpatterns.append(
+ path("simple_async_view", views.simple_async_view, name="simple_async_view")
+ )
+
if views.thread_ids_async is not None:
urlpatterns.append(
path("async/thread_ids", views.thread_ids_async, name="thread_ids_async")
)
+if views.post_echo_async is not None:
+ urlpatterns.append(
+ path("post_echo_async", views.post_echo_async, name="post_echo_async")
+ )
+
# rest framework
try:
urlpatterns.append(
diff --git a/tests/integrations/django/myapp/views.py b/tests/integrations/django/myapp/views.py
index dbf266e1ab..6d199a3740 100644
--- a/tests/integrations/django/myapp/views.py
+++ b/tests/integrations/django/myapp/views.py
@@ -1,17 +1,27 @@
+import asyncio
import json
import threading
-from django import VERSION
+from django.db import transaction
from django.contrib.auth import login
from django.contrib.auth.models import User
from django.core.exceptions import PermissionDenied
+from django.dispatch import Signal
from django.http import HttpResponse, HttpResponseNotFound, HttpResponseServerError
from django.shortcuts import render
+from django.template import Context, Template
from django.template.response import TemplateResponse
from django.utils.decorators import method_decorator
+from django.views.decorators.cache import cache_page
from django.views.decorators.csrf import csrf_exempt
from django.views.generic import ListView
+
+from tests.integrations.django.myapp.signals import (
+ myapp_custom_signal,
+ myapp_custom_signal_silenced,
+)
+
try:
from rest_framework.decorators import api_view
from rest_framework.response import Response
@@ -42,6 +52,7 @@ def rest_json_response(request):
import sentry_sdk
+from sentry_sdk import capture_message
@csrf_exempt
@@ -49,18 +60,46 @@ def view_exc(request):
1 / 0
+@csrf_exempt
+def view_exc_with_msg(request):
+ capture_message("oops")
+ 1 / 0
+
+
+@cache_page(60)
+def cached_view(request):
+ return HttpResponse("ok")
+
+
+def not_cached_view(request):
+ return HttpResponse("ok")
+
+
+def view_with_cached_template_fragment(request):
+ template = Template(
+ """{% load cache %}
+ Not cached content goes here.
+ {% cache 500 some_identifier %}
+ And here some cached content.
+ {% endcache %}
+ """
+ )
+ rendered = template.render(Context({}))
+ return HttpResponse(rendered)
+
+
# This is a "class based view" as previously found in the sentry codebase. The
# interesting property of this one is that csrf_exempt, as a class attribute,
# is not in __dict__, so regular use of functools.wraps will not forward the
# attribute.
-class SentryClassBasedView(object):
+class SentryClassBasedView:
csrf_exempt = True
def __call__(self, request):
return HttpResponse("ok")
-class SentryClassBasedViewWithCsrf(object):
+class SentryClassBasedViewWithCsrf:
def __call__(self, request):
return HttpResponse("ok")
@@ -77,6 +116,18 @@ def message(request):
return HttpResponse("ok")
+@csrf_exempt
+def nomessage(request):
+ return HttpResponse("ok")
+
+
+@csrf_exempt
+def view_with_signal(request):
+ custom_signal = Signal()
+ custom_signal.send(sender="hello")
+ return HttpResponse("ok")
+
+
@csrf_exempt
def mylogin(request):
user = User.objects.create_user("john", "lennon@thebeatles.com", "johnpassword")
@@ -87,7 +138,7 @@ def mylogin(request):
@csrf_exempt
def handler500(request):
- return HttpResponseServerError("Sentry error: %s" % sentry_sdk.last_event_id())
+ return HttpResponseServerError("Sentry error.")
class ClassBasedView(ListView):
@@ -95,7 +146,7 @@ class ClassBasedView(ListView):
@method_decorator(csrf_exempt)
def dispatch(self, request, *args, **kwargs):
- return super(ClassBasedView, self).dispatch(request, *args, **kwargs)
+ return super().dispatch(request, *args, **kwargs)
def head(self, *args, **kwargs):
sentry_sdk.capture_message("hi")
@@ -144,6 +195,43 @@ def template_test2(request, *args, **kwargs):
)
+@csrf_exempt
+def template_test3(request, *args, **kwargs):
+ traceparent = sentry_sdk.get_current_scope().get_traceparent()
+ if traceparent is None:
+ traceparent = sentry_sdk.get_isolation_scope().get_traceparent()
+
+ baggage = sentry_sdk.get_current_scope().get_baggage()
+ if baggage is None:
+ baggage = sentry_sdk.get_isolation_scope().get_baggage()
+
+ capture_message(traceparent + "\n" + baggage.serialize())
+ return render(request, "trace_meta.html", {})
+
+
+@csrf_exempt
+def template_test4(request, *args, **kwargs):
+ User.objects.create_user("john", "lennon@thebeatles.com", "johnpassword")
+ my_queryset = User.objects.all() # noqa
+
+ template_context = {
+ "user_age": 25,
+ "complex_context": my_queryset,
+ "complex_list": [1, 2, 3, my_queryset],
+ "complex_dict": {
+ "a": 1,
+ "d": my_queryset,
+ },
+ "none_context": None,
+ }
+
+ return TemplateResponse(
+ request,
+ "user_name.html",
+ template_context,
+ )
+
+
@csrf_exempt
def postgres_select(request, *args, **kwargs):
from django.db import connections
@@ -153,6 +241,79 @@ def postgres_select(request, *args, **kwargs):
return HttpResponse("ok")
+@csrf_exempt
+def postgres_select_orm(request, *args, **kwargs):
+ user = User.objects.using("postgres").all().first()
+ return HttpResponse("ok {}".format(user))
+
+
+@csrf_exempt
+def postgres_insert_orm_no_autocommit(request, *args, **kwargs):
+ transaction.set_autocommit(False, using="postgres")
+ try:
+ user = User.objects.db_manager("postgres").create_user(
+ username="user1",
+ )
+ transaction.commit(using="postgres")
+ except Exception:
+ transaction.rollback(using="postgres")
+ transaction.set_autocommit(True, using="postgres")
+ raise
+
+ transaction.set_autocommit(True, using="postgres")
+ return HttpResponse("ok {}".format(user))
+
+
+@csrf_exempt
+def postgres_insert_orm_no_autocommit_rollback(request, *args, **kwargs):
+ transaction.set_autocommit(False, using="postgres")
+ try:
+ user = User.objects.db_manager("postgres").create_user(
+ username="user1",
+ )
+ transaction.rollback(using="postgres")
+ except Exception:
+ transaction.rollback(using="postgres")
+ transaction.set_autocommit(True, using="postgres")
+ raise
+
+ transaction.set_autocommit(True, using="postgres")
+ return HttpResponse("ok {}".format(user))
+
+
+@csrf_exempt
+def postgres_insert_orm_atomic(request, *args, **kwargs):
+ with transaction.atomic(using="postgres"):
+ user = User.objects.db_manager("postgres").create_user(
+ username="user1",
+ )
+ return HttpResponse("ok {}".format(user))
+
+
+@csrf_exempt
+def postgres_insert_orm_atomic_rollback(request, *args, **kwargs):
+ with transaction.atomic(using="postgres"):
+ user = User.objects.db_manager("postgres").create_user(
+ username="user1",
+ )
+ transaction.set_rollback(True, using="postgres")
+ return HttpResponse("ok {}".format(user))
+
+
+@csrf_exempt
+def postgres_insert_orm_atomic_exception(request, *args, **kwargs):
+ try:
+ with transaction.atomic(using="postgres"):
+ user = User.objects.db_manager("postgres").create_user(
+ username="user1",
+ )
+ transaction.set_rollback(True, using="postgres")
+ 1 / 0
+ except ZeroDivisionError:
+ pass
+ return HttpResponse("ok {}".format(user))
+
+
@csrf_exempt
def permission_denied_exc(*args, **kwargs):
raise PermissionDenied("bye")
@@ -172,30 +333,40 @@ def thread_ids_sync(*args, **kwargs):
return HttpResponse(response)
-if VERSION >= (3, 1):
- # Use exec to produce valid Python 2
- exec(
- """async def async_message(request):
+async def async_message(request):
sentry_sdk.capture_message("hi")
- return HttpResponse("ok")"""
- )
+ return HttpResponse("ok")
+
- exec(
- """async def my_async_view(request):
- import asyncio
+async def my_async_view(request):
await asyncio.sleep(1)
- return HttpResponse('Hello World')"""
- )
+ return HttpResponse("Hello World")
+
- exec(
- """async def thread_ids_async(request):
- response = json.dumps({
- "main": threading.main_thread().ident,
- "active": threading.current_thread().ident,
- })
- return HttpResponse(response)"""
+async def simple_async_view(request):
+ return HttpResponse("Simple Hello World")
+
+
+async def thread_ids_async(request):
+ response = json.dumps(
+ {
+ "main": threading.main_thread().ident,
+ "active": threading.current_thread().ident,
+ }
)
-else:
- async_message = None
- my_async_view = None
- thread_ids_async = None
+ return HttpResponse(response)
+
+
+async def post_echo_async(request):
+ sentry_sdk.capture_message("hi")
+ return HttpResponse(request.body)
+
+
+post_echo_async.csrf_exempt = True
+
+
+@csrf_exempt
+def send_myapp_custom_signal(request):
+ myapp_custom_signal.send(sender="hello")
+ myapp_custom_signal_silenced.send(sender="hello")
+ return HttpResponse("ok")
diff --git a/tests/integrations/django/test_basic.py b/tests/integrations/django/test_basic.py
index bc464af836..1c6bb141bd 100644
--- a/tests/integrations/django/test_basic.py
+++ b/tests/integrations/django/test_basic.py
@@ -1,45 +1,48 @@
-from __future__ import absolute_import
-
+import inspect
import json
+import os
import pytest
-import pytest_django
+import re
+import sys
+
from functools import partial
+from unittest.mock import patch
from werkzeug.test import Client
+
from django import VERSION as DJANGO_VERSION
+
from django.contrib.auth.models import User
from django.core.management import execute_from_command_line
from django.db.utils import OperationalError, ProgrammingError, DataError
+from django.http.request import RawPostDataException
+from django.template.context import make_context
+from django.utils.functional import SimpleLazyObject
try:
from django.urls import reverse
except ImportError:
from django.core.urlresolvers import reverse
-from sentry_sdk._compat import PY2, PY310
-from sentry_sdk import capture_message, capture_exception, configure_scope
-from sentry_sdk.integrations.django import DjangoIntegration
+import sentry_sdk
+from sentry_sdk._compat import PY310
+from sentry_sdk import capture_message, capture_exception
+from sentry_sdk.consts import SPANDATA
+from sentry_sdk.integrations.django import (
+ DjangoIntegration,
+ DjangoRequestExtractor,
+ _set_db_data,
+)
from sentry_sdk.integrations.django.signals_handlers import _get_receiver_name
from sentry_sdk.integrations.executing import ExecutingIntegration
-
+from sentry_sdk.profiler.utils import get_frame_name
+from sentry_sdk.tracing import Span
+from tests.conftest import unpack_werkzeug_response
from tests.integrations.django.myapp.wsgi import application
+from tests.integrations.django.myapp.signals import myapp_custom_signal_silenced
+from tests.integrations.django.utils import pytest_mark_django_db_decorator
-# Hack to prevent from experimental feature introduced in version `4.3.0` in `pytest-django` that
-# requires explicit database allow from failing the test
-pytest_mark_django_db_decorator = partial(pytest.mark.django_db)
-try:
- pytest_version = tuple(map(int, pytest_django.__version__.split(".")))
- if pytest_version > (4, 2, 0):
- pytest_mark_django_db_decorator = partial(
- pytest.mark.django_db, databases="__all__"
- )
-except ValueError:
- if "dev" in pytest_django.__version__:
- pytest_mark_django_db_decorator = partial(
- pytest.mark.django_db, databases="__all__"
- )
-except AttributeError:
- pass
+DJANGO_VERSION = DJANGO_VERSION[:2]
@pytest.fixture
@@ -112,8 +115,9 @@ def test_middleware_exceptions(sentry_init, client, capture_exceptions):
def test_request_captured(sentry_init, client, capture_events):
sentry_init(integrations=[DjangoIntegration()], send_default_pii=True)
events = capture_events()
- content, status, headers = client.get(reverse("message"))
- assert b"".join(content) == b"ok"
+ content, status, headers = unpack_werkzeug_response(client.get(reverse("message")))
+
+ assert content == b"ok"
(event,) = events
assert event["transaction"] == "/message"
@@ -133,7 +137,9 @@ def test_transaction_with_class_view(sentry_init, client, capture_events):
send_default_pii=True,
)
events = capture_events()
- content, status, headers = client.head(reverse("classbased"))
+ content, status, headers = unpack_werkzeug_response(
+ client.head(reverse("classbased"))
+ )
assert status.lower() == "200 ok"
(event,) = events
@@ -144,18 +150,136 @@ def test_transaction_with_class_view(sentry_init, client, capture_events):
assert event["message"] == "hi"
+def test_has_trace_if_performance_enabled(sentry_init, client, capture_events):
+ sentry_init(
+ integrations=[
+ DjangoIntegration(
+ http_methods_to_capture=("HEAD",),
+ )
+ ],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+ client.head(reverse("view_exc_with_msg"))
+
+ (msg_event, error_event, transaction_event) = events
+
+ assert msg_event["contexts"]["trace"]
+ assert "trace_id" in msg_event["contexts"]["trace"]
+
+ assert error_event["contexts"]["trace"]
+ assert "trace_id" in error_event["contexts"]["trace"]
+
+ assert transaction_event["contexts"]["trace"]
+ assert "trace_id" in transaction_event["contexts"]["trace"]
+
+ assert (
+ msg_event["contexts"]["trace"]["trace_id"]
+ == error_event["contexts"]["trace"]["trace_id"]
+ == transaction_event["contexts"]["trace"]["trace_id"]
+ )
+
+
+def test_has_trace_if_performance_disabled(sentry_init, client, capture_events):
+ sentry_init(
+ integrations=[DjangoIntegration()],
+ )
+ events = capture_events()
+ client.head(reverse("view_exc_with_msg"))
+
+ (msg_event, error_event) = events
+
+ assert msg_event["contexts"]["trace"]
+ assert "trace_id" in msg_event["contexts"]["trace"]
+
+ assert error_event["contexts"]["trace"]
+ assert "trace_id" in error_event["contexts"]["trace"]
+
+ assert (
+ msg_event["contexts"]["trace"]["trace_id"]
+ == error_event["contexts"]["trace"]["trace_id"]
+ )
+
+
+def test_trace_from_headers_if_performance_enabled(sentry_init, client, capture_events):
+ sentry_init(
+ integrations=[
+ DjangoIntegration(
+ http_methods_to_capture=("HEAD",),
+ )
+ ],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ trace_id = "582b43a4192642f0b136d5159a501701"
+ sentry_trace_header = "{}-{}-{}".format(trace_id, "6e8f22c393e68f19", 1)
+
+ client.head(
+ reverse("view_exc_with_msg"), headers={"sentry-trace": sentry_trace_header}
+ )
+
+ (msg_event, error_event, transaction_event) = events
+
+ assert msg_event["contexts"]["trace"]
+ assert "trace_id" in msg_event["contexts"]["trace"]
+
+ assert error_event["contexts"]["trace"]
+ assert "trace_id" in error_event["contexts"]["trace"]
+
+ assert transaction_event["contexts"]["trace"]
+ assert "trace_id" in transaction_event["contexts"]["trace"]
+
+ assert msg_event["contexts"]["trace"]["trace_id"] == trace_id
+ assert error_event["contexts"]["trace"]["trace_id"] == trace_id
+ assert transaction_event["contexts"]["trace"]["trace_id"] == trace_id
+
+
+def test_trace_from_headers_if_performance_disabled(
+ sentry_init, client, capture_events
+):
+ sentry_init(
+ integrations=[
+ DjangoIntegration(
+ http_methods_to_capture=("HEAD",),
+ )
+ ],
+ )
+
+ events = capture_events()
+
+ trace_id = "582b43a4192642f0b136d5159a501701"
+ sentry_trace_header = "{}-{}-{}".format(trace_id, "6e8f22c393e68f19", 1)
+
+ client.head(
+ reverse("view_exc_with_msg"), headers={"sentry-trace": sentry_trace_header}
+ )
+
+ (msg_event, error_event) = events
+
+ assert msg_event["contexts"]["trace"]
+ assert "trace_id" in msg_event["contexts"]["trace"]
+
+ assert error_event["contexts"]["trace"]
+ assert "trace_id" in error_event["contexts"]["trace"]
+
+ assert msg_event["contexts"]["trace"]["trace_id"] == trace_id
+ assert error_event["contexts"]["trace"]["trace_id"] == trace_id
+
+
@pytest.mark.forked
-@pytest.mark.django_db
+@pytest_mark_django_db_decorator()
def test_user_captured(sentry_init, client, capture_events):
sentry_init(integrations=[DjangoIntegration()], send_default_pii=True)
events = capture_events()
- content, status, headers = client.get(reverse("mylogin"))
- assert b"".join(content) == b"ok"
+ content, status, headers = unpack_werkzeug_response(client.get(reverse("mylogin")))
+ assert content == b"ok"
assert not events
- content, status, headers = client.get(reverse("message"))
- assert b"".join(content) == b"ok"
+ content, status, headers = unpack_werkzeug_response(client.get(reverse("message")))
+ assert content == b"ok"
(event,) = events
@@ -167,7 +291,7 @@ def test_user_captured(sentry_init, client, capture_events):
@pytest.mark.forked
-@pytest.mark.django_db
+@pytest_mark_django_db_decorator()
def test_queryset_repr(sentry_init, capture_events):
sentry_init(integrations=[DjangoIntegration()])
events = capture_events()
@@ -189,10 +313,31 @@ def test_queryset_repr(sentry_init, capture_events):
)
+@pytest.mark.forked
+@pytest_mark_django_db_decorator()
+def test_context_nested_queryset_repr(sentry_init, capture_events):
+ sentry_init(integrations=[DjangoIntegration()])
+ events = capture_events()
+ User.objects.create_user("john", "lennon@thebeatles.com", "johnpassword")
+
+ try:
+ context = make_context({"entries": User.objects.all()}) # noqa
+ 1 / 0
+ except Exception:
+ capture_exception()
+
+ (event,) = events
+
+ (exception,) = event["exception"]["values"]
+ assert exception["type"] == "ZeroDivisionError"
+ (frame,) = exception["stacktrace"]["frames"]
+ assert "\n',
+ rendered_meta,
+ )
+ assert match is not None
+ assert match.group(1) == traceparent
+
+ rendered_baggage = match.group(2)
+ assert rendered_baggage == baggage
+
+
@pytest.mark.parametrize("with_executing_integration", [[], [ExecutingIntegration()]])
def test_template_exception(
sentry_init, client, capture_events, with_executing_integration
@@ -582,7 +835,9 @@ def test_template_exception(
sentry_init(integrations=[DjangoIntegration()] + with_executing_integration)
events = capture_events()
- content, status, headers = client.get(reverse("template_exc"))
+ content, status, headers = unpack_werkzeug_response(
+ client.get(reverse("template_exc"))
+ )
assert status.lower() == "500 internal server error"
(event,) = events
@@ -670,7 +925,7 @@ def test_does_not_capture_403(sentry_init, client, capture_events, endpoint):
sentry_init(integrations=[DjangoIntegration()])
events = capture_events()
- _, status, _ = client.get(reverse(endpoint))
+ _, status, _ = unpack_werkzeug_response(client.get(reverse(endpoint)))
assert status.lower() == "403 forbidden"
assert not events
@@ -702,6 +957,44 @@ def test_render_spans(sentry_init, client, capture_events, render_span_tree):
assert expected_line in render_span_tree(transaction)
+@pytest.mark.skipif(DJANGO_VERSION < (1, 9), reason="Requires Django >= 1.9")
+@pytest.mark.forked
+@pytest_mark_django_db_decorator()
+def test_render_spans_queryset_in_data(sentry_init, client, capture_events):
+ sentry_init(
+ integrations=[
+ DjangoIntegration(
+ cache_spans=False,
+ middleware_spans=False,
+ signals_spans=False,
+ )
+ ],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client.get(reverse("template_test4"))
+
+ (transaction,) = events
+ template_context = transaction["spans"][-1]["data"]["context"]
+
+ assert template_context["user_age"] == 25
+ assert template_context["complex_context"].startswith(
+ "= (1, 10):
EXPECTED_MIDDLEWARE_SPANS = """\
- op="http.server": description=null
@@ -730,7 +1023,7 @@ def test_render_spans(sentry_init, client, capture_events, render_span_tree):
def test_middleware_spans(sentry_init, client, capture_events, render_span_tree):
sentry_init(
integrations=[
- DjangoIntegration(signals_spans=False),
+ DjangoIntegration(middleware_spans=True, signals_spans=False),
],
traces_sample_rate=1.0,
)
@@ -747,7 +1040,7 @@ def test_middleware_spans(sentry_init, client, capture_events, render_span_tree)
def test_middleware_spans_disabled(sentry_init, client, capture_events):
sentry_init(
integrations=[
- DjangoIntegration(middleware_spans=False, signals_spans=False),
+ DjangoIntegration(signals_spans=False),
],
traces_sample_rate=1.0,
)
@@ -761,14 +1054,7 @@ def test_middleware_spans_disabled(sentry_init, client, capture_events):
assert not len(transaction["spans"])
-if DJANGO_VERSION >= (1, 10):
- EXPECTED_SIGNALS_SPANS = """\
-- op="http.server": description=null
- - op="event.django": description="django.db.reset_queries"
- - op="event.django": description="django.db.close_old_connections"\
-"""
-else:
- EXPECTED_SIGNALS_SPANS = """\
+EXPECTED_SIGNALS_SPANS = """\
- op="http.server": description=null
- op="event.django": description="django.db.reset_queries"
- op="event.django": description="django.db.close_old_connections"\
@@ -815,6 +1101,47 @@ def test_signals_spans_disabled(sentry_init, client, capture_events):
assert not transaction["spans"]
+EXPECTED_SIGNALS_SPANS_FILTERED = """\
+- op="http.server": description=null
+ - op="event.django": description="django.db.reset_queries"
+ - op="event.django": description="django.db.close_old_connections"
+ - op="event.django": description="tests.integrations.django.myapp.signals.signal_handler"\
+"""
+
+
+def test_signals_spans_filtering(sentry_init, client, capture_events, render_span_tree):
+ sentry_init(
+ integrations=[
+ DjangoIntegration(
+ middleware_spans=False,
+ signals_denylist=[
+ myapp_custom_signal_silenced,
+ ],
+ ),
+ ],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client.get(reverse("send_myapp_custom_signal"))
+
+ (transaction,) = events
+
+ assert render_span_tree(transaction) == EXPECTED_SIGNALS_SPANS_FILTERED
+
+ assert transaction["spans"][0]["op"] == "event.django"
+ assert transaction["spans"][0]["description"] == "django.db.reset_queries"
+
+ assert transaction["spans"][1]["op"] == "event.django"
+ assert transaction["spans"][1]["description"] == "django.db.close_old_connections"
+
+ assert transaction["spans"][2]["op"] == "event.django"
+ assert (
+ transaction["spans"][2]["description"]
+ == "tests.integrations.django.myapp.signals.signal_handler"
+ )
+
+
def test_csrf(sentry_init, client):
"""
Assert that CSRF view decorator works even with the view wrapped in our own
@@ -823,28 +1150,39 @@ def test_csrf(sentry_init, client):
sentry_init(integrations=[DjangoIntegration()])
- content, status, _headers = client.post(reverse("csrf_hello_not_exempt"))
+ content, status, _headers = unpack_werkzeug_response(
+ client.post(reverse("csrf_hello_not_exempt"))
+ )
assert status.lower() == "403 forbidden"
- content, status, _headers = client.post(reverse("sentryclass_csrf"))
+ content, status, _headers = unpack_werkzeug_response(
+ client.post(reverse("sentryclass_csrf"))
+ )
assert status.lower() == "403 forbidden"
- content, status, _headers = client.post(reverse("sentryclass"))
+ content, status, _headers = unpack_werkzeug_response(
+ client.post(reverse("sentryclass"))
+ )
assert status.lower() == "200 ok"
- assert b"".join(content) == b"ok"
+ assert content == b"ok"
- content, status, _headers = client.post(reverse("classbased"))
+ content, status, _headers = unpack_werkzeug_response(
+ client.post(reverse("classbased"))
+ )
assert status.lower() == "200 ok"
- assert b"".join(content) == b"ok"
+ assert content == b"ok"
- content, status, _headers = client.post(reverse("message"))
+ content, status, _headers = unpack_werkzeug_response(
+ client.post(reverse("message"))
+ )
assert status.lower() == "200 ok"
- assert b"".join(content) == b"ok"
+ assert content == b"ok"
@pytest.mark.skipif(DJANGO_VERSION < (2, 0), reason="Requires Django > 2.0")
+@pytest.mark.parametrize("middleware_spans", [False, True])
def test_custom_urlconf_middleware(
- settings, sentry_init, client, capture_events, render_span_tree
+ settings, sentry_init, client, capture_events, render_span_tree, middleware_spans
):
"""
Some middlewares (for instance in django-tenants) overwrite request.urlconf.
@@ -855,25 +1193,30 @@ def test_custom_urlconf_middleware(
settings.MIDDLEWARE.insert(0, urlconf)
client.application.load_middleware()
- sentry_init(integrations=[DjangoIntegration()], traces_sample_rate=1.0)
+ sentry_init(
+ integrations=[DjangoIntegration(middleware_spans=middleware_spans)],
+ traces_sample_rate=1.0,
+ )
events = capture_events()
- content, status, _headers = client.get("/custom/ok")
+ content, status, _headers = unpack_werkzeug_response(client.get("/custom/ok"))
assert status.lower() == "200 ok"
- assert b"".join(content) == b"custom ok"
+ assert content == b"custom ok"
event = events.pop(0)
assert event["transaction"] == "/custom/ok"
- assert "custom_urlconf_middleware" in render_span_tree(event)
+ if middleware_spans:
+ assert "custom_urlconf_middleware" in render_span_tree(event)
- _content, status, _headers = client.get("/custom/exc")
+ _content, status, _headers = unpack_werkzeug_response(client.get("/custom/exc"))
assert status.lower() == "500 internal server error"
error_event, transaction_event = events
assert error_event["transaction"] == "/custom/exc"
assert error_event["exception"]["values"][-1]["mechanism"]["type"] == "django"
assert transaction_event["transaction"] == "/custom/exc"
- assert "custom_urlconf_middleware" in render_span_tree(transaction_event)
+ if middleware_spans:
+ assert "custom_urlconf_middleware" in render_span_tree(transaction_event)
settings.MIDDLEWARE.pop(0)
@@ -884,13 +1227,10 @@ def dummy(a, b):
name = _get_receiver_name(dummy)
- if PY2:
- assert name == "tests.integrations.django.test_basic.dummy"
- else:
- assert (
- name
- == "tests.integrations.django.test_basic.test_get_receiver_name..dummy"
- )
+ assert (
+ name
+ == "tests.integrations.django.test_basic.test_get_receiver_name..dummy"
+ )
a_partial = partial(dummy)
name = _get_receiver_name(a_partial)
@@ -898,3 +1238,176 @@ def dummy(a, b):
assert name == "functools.partial()"
else:
assert name == "partial()"
+
+
+@pytest.mark.skipif(DJANGO_VERSION <= (1, 11), reason="Requires Django > 1.11")
+def test_span_origin(sentry_init, client, capture_events):
+ sentry_init(
+ integrations=[
+ DjangoIntegration(
+ middleware_spans=True,
+ signals_spans=True,
+ cache_spans=True,
+ )
+ ],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client.get(reverse("view_with_signal"))
+
+ (transaction,) = events
+
+ assert transaction["contexts"]["trace"]["origin"] == "auto.http.django"
+
+ signal_span_found = False
+ for span in transaction["spans"]:
+ assert span["origin"] == "auto.http.django"
+ if span["op"] == "event.django":
+ signal_span_found = True
+
+ assert signal_span_found
+
+
+def test_transaction_http_method_default(sentry_init, client, capture_events):
+ """
+ By default OPTIONS and HEAD requests do not create a transaction.
+ """
+ sentry_init(
+ integrations=[DjangoIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client.get("/nomessage")
+ client.options("/nomessage")
+ client.head("/nomessage")
+
+ (event,) = events
+
+ assert len(events) == 1
+ assert event["request"]["method"] == "GET"
+
+
+def test_transaction_http_method_custom(sentry_init, client, capture_events):
+ sentry_init(
+ integrations=[
+ DjangoIntegration(
+ http_methods_to_capture=(
+ "OPTIONS",
+ "head",
+ ), # capitalization does not matter
+ )
+ ],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client.get("/nomessage")
+ client.options("/nomessage")
+ client.head("/nomessage")
+
+ assert len(events) == 2
+
+ (event1, event2) = events
+ assert event1["request"]["method"] == "OPTIONS"
+ assert event2["request"]["method"] == "HEAD"
+
+
+def test_ensures_spotlight_middleware_when_spotlight_is_enabled(sentry_init, settings):
+ """
+ Test that ensures if Spotlight is enabled, relevant SpotlightMiddleware
+ is added to middleware list in settings.
+ """
+ settings.DEBUG = True
+ original_middleware = frozenset(settings.MIDDLEWARE)
+
+ sentry_init(integrations=[DjangoIntegration()], spotlight=True)
+
+ added = frozenset(settings.MIDDLEWARE) ^ original_middleware
+
+ assert "sentry_sdk.spotlight.SpotlightMiddleware" in added
+
+
+def test_ensures_no_spotlight_middleware_when_env_killswitch_is_false(
+ monkeypatch, sentry_init, settings
+):
+ """
+ Test that ensures if Spotlight is enabled, but is set to a falsy value
+ the relevant SpotlightMiddleware is NOT added to middleware list in settings.
+ """
+ settings.DEBUG = True
+ monkeypatch.setenv("SENTRY_SPOTLIGHT_ON_ERROR", "no")
+
+ original_middleware = frozenset(settings.MIDDLEWARE)
+
+ sentry_init(integrations=[DjangoIntegration()], spotlight=True)
+
+ added = frozenset(settings.MIDDLEWARE) ^ original_middleware
+
+ assert "sentry_sdk.spotlight.SpotlightMiddleware" not in added
+
+
+def test_ensures_no_spotlight_middleware_when_no_spotlight(
+ monkeypatch, sentry_init, settings
+):
+ """
+ Test that ensures if Spotlight is not enabled
+ the relevant SpotlightMiddleware is NOT added to middleware list in settings.
+ """
+ settings.DEBUG = True
+
+ # We should NOT have the middleware even if the env var is truthy if Spotlight is off
+ monkeypatch.setenv("SENTRY_SPOTLIGHT_ON_ERROR", "1")
+
+ original_middleware = frozenset(settings.MIDDLEWARE)
+
+ sentry_init(integrations=[DjangoIntegration()], spotlight=False)
+
+ added = frozenset(settings.MIDDLEWARE) ^ original_middleware
+
+ assert "sentry_sdk.spotlight.SpotlightMiddleware" not in added
+
+
+def test_get_frame_name_when_in_lazy_object():
+ allowed_to_init = False
+
+ class SimpleLazyObjectWrapper(SimpleLazyObject):
+ def unproxied_method(self):
+ """
+ For testing purposes. We inject a method on the SimpleLazyObject
+ class so if python is executing this method, we should get
+ this class instead of the wrapped class and avoid evaluating
+ the wrapped object too early.
+ """
+ return inspect.currentframe()
+
+ class GetFrame:
+ def __init__(self):
+ assert allowed_to_init, "GetFrame not permitted to initialize yet"
+
+ def proxied_method(self):
+ """
+ For testing purposes. We add an proxied method on the instance
+ class so if python is executing this method, we should get
+ this class instead of the wrapper class.
+ """
+ return inspect.currentframe()
+
+ instance = SimpleLazyObjectWrapper(lambda: GetFrame())
+
+ assert get_frame_name(instance.unproxied_method()) == (
+ "SimpleLazyObjectWrapper.unproxied_method"
+ if sys.version_info < (3, 11)
+ else "test_get_frame_name_when_in_lazy_object..SimpleLazyObjectWrapper.unproxied_method"
+ )
+
+ # Now that we're about to access an instance method on the wrapped class,
+ # we should permit initializing it
+ allowed_to_init = True
+
+ assert get_frame_name(instance.proxied_method()) == (
+ "GetFrame.proxied_method"
+ if sys.version_info < (3, 11)
+ else "test_get_frame_name_when_in_lazy_object..GetFrame.proxied_method"
+ )
diff --git a/tests/integrations/django/test_cache_module.py b/tests/integrations/django/test_cache_module.py
new file mode 100644
index 0000000000..01b97c1302
--- /dev/null
+++ b/tests/integrations/django/test_cache_module.py
@@ -0,0 +1,696 @@
+import os
+import random
+import uuid
+
+import pytest
+from django import VERSION as DJANGO_VERSION
+from werkzeug.test import Client
+
+try:
+ from django.urls import reverse
+except ImportError:
+ from django.core.urlresolvers import reverse
+
+import sentry_sdk
+from sentry_sdk.integrations.django import DjangoIntegration
+from sentry_sdk.integrations.django.caching import _get_span_description
+from tests.integrations.django.myapp.wsgi import application
+from tests.integrations.django.utils import pytest_mark_django_db_decorator
+
+
+DJANGO_VERSION = DJANGO_VERSION[:2]
+
+
+@pytest.fixture
+def client():
+ return Client(application)
+
+
+@pytest.fixture
+def use_django_caching(settings):
+ settings.CACHES = {
+ "default": {
+ "BACKEND": "django.core.cache.backends.locmem.LocMemCache",
+ "LOCATION": "unique-snowflake-%s" % random.randint(1, 1000000),
+ }
+ }
+
+
+@pytest.fixture
+def use_django_caching_with_middlewares(settings):
+ settings.CACHES = {
+ "default": {
+ "BACKEND": "django.core.cache.backends.locmem.LocMemCache",
+ "LOCATION": "unique-snowflake-%s" % random.randint(1, 1000000),
+ }
+ }
+ if hasattr(settings, "MIDDLEWARE"):
+ middleware = settings.MIDDLEWARE
+ elif hasattr(settings, "MIDDLEWARE_CLASSES"):
+ middleware = settings.MIDDLEWARE_CLASSES
+ else:
+ middleware = None
+
+ if middleware is not None:
+ middleware.insert(0, "django.middleware.cache.UpdateCacheMiddleware")
+ middleware.append("django.middleware.cache.FetchFromCacheMiddleware")
+
+
+@pytest.fixture
+def use_django_caching_with_port(settings):
+ settings.CACHES = {
+ "default": {
+ "BACKEND": "django.core.cache.backends.dummy.DummyCache",
+ "LOCATION": "redis://username:password@127.0.0.1:6379",
+ }
+ }
+
+
+@pytest.fixture
+def use_django_caching_without_port(settings):
+ settings.CACHES = {
+ "default": {
+ "BACKEND": "django.core.cache.backends.dummy.DummyCache",
+ "LOCATION": "redis://example.com",
+ }
+ }
+
+
+@pytest.fixture
+def use_django_caching_with_cluster(settings):
+ settings.CACHES = {
+ "default": {
+ "BACKEND": "django.core.cache.backends.dummy.DummyCache",
+ "LOCATION": [
+ "redis://127.0.0.1:6379",
+ "redis://127.0.0.2:6378",
+ "redis://127.0.0.3:6377",
+ ],
+ }
+ }
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator()
+@pytest.mark.skipif(DJANGO_VERSION < (1, 9), reason="Requires Django >= 1.9")
+def test_cache_spans_disabled_middleware(
+ sentry_init, client, capture_events, use_django_caching_with_middlewares
+):
+ sentry_init(
+ integrations=[
+ DjangoIntegration(
+ cache_spans=False,
+ middleware_spans=False,
+ signals_spans=False,
+ )
+ ],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client.get(reverse("not_cached_view"))
+ client.get(reverse("not_cached_view"))
+
+ (first_event, second_event) = events
+ assert len(first_event["spans"]) == 0
+ assert len(second_event["spans"]) == 0
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator()
+@pytest.mark.skipif(DJANGO_VERSION < (1, 9), reason="Requires Django >= 1.9")
+def test_cache_spans_disabled_decorator(
+ sentry_init, client, capture_events, use_django_caching
+):
+ sentry_init(
+ integrations=[
+ DjangoIntegration(
+ cache_spans=False,
+ middleware_spans=False,
+ signals_spans=False,
+ )
+ ],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client.get(reverse("cached_view"))
+ client.get(reverse("cached_view"))
+
+ (first_event, second_event) = events
+ assert len(first_event["spans"]) == 0
+ assert len(second_event["spans"]) == 0
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator()
+@pytest.mark.skipif(DJANGO_VERSION < (1, 9), reason="Requires Django >= 1.9")
+def test_cache_spans_disabled_templatetag(
+ sentry_init, client, capture_events, use_django_caching
+):
+ sentry_init(
+ integrations=[
+ DjangoIntegration(
+ cache_spans=False,
+ middleware_spans=False,
+ signals_spans=False,
+ )
+ ],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client.get(reverse("view_with_cached_template_fragment"))
+ client.get(reverse("view_with_cached_template_fragment"))
+
+ (first_event, second_event) = events
+ assert len(first_event["spans"]) == 0
+ assert len(second_event["spans"]) == 0
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator()
+@pytest.mark.skipif(DJANGO_VERSION < (1, 9), reason="Requires Django >= 1.9")
+def test_cache_spans_middleware(
+ sentry_init, client, capture_events, use_django_caching_with_middlewares
+):
+ sentry_init(
+ integrations=[
+ DjangoIntegration(
+ cache_spans=True,
+ middleware_spans=False,
+ signals_spans=False,
+ )
+ ],
+ traces_sample_rate=1.0,
+ )
+
+ client.application.load_middleware()
+ events = capture_events()
+
+ client.get(reverse("not_cached_view"))
+ client.get(reverse("not_cached_view"))
+
+ (first_event, second_event) = events
+ # first_event - cache.get
+ assert first_event["spans"][0]["op"] == "cache.get"
+ assert first_event["spans"][0]["description"].startswith(
+ "views.decorators.cache.cache_header."
+ )
+ assert first_event["spans"][0]["data"]["network.peer.address"] is not None
+ assert first_event["spans"][0]["data"]["cache.key"][0].startswith(
+ "views.decorators.cache.cache_header."
+ )
+ assert not first_event["spans"][0]["data"]["cache.hit"]
+ assert "cache.item_size" not in first_event["spans"][0]["data"]
+ # first_event - cache.put
+ assert first_event["spans"][1]["op"] == "cache.put"
+ assert first_event["spans"][1]["description"].startswith(
+ "views.decorators.cache.cache_header."
+ )
+ assert first_event["spans"][1]["data"]["network.peer.address"] is not None
+ assert first_event["spans"][1]["data"]["cache.key"][0].startswith(
+ "views.decorators.cache.cache_header."
+ )
+ assert "cache.hit" not in first_event["spans"][1]["data"]
+ assert first_event["spans"][1]["data"]["cache.item_size"] == 2
+ # second_event - cache.get
+ assert second_event["spans"][0]["op"] == "cache.get"
+ assert second_event["spans"][0]["description"].startswith(
+ "views.decorators.cache.cache_header."
+ )
+ assert second_event["spans"][0]["data"]["network.peer.address"] is not None
+ assert second_event["spans"][0]["data"]["cache.key"][0].startswith(
+ "views.decorators.cache.cache_header."
+ )
+ assert second_event["spans"][0]["data"]["cache.hit"]
+ assert second_event["spans"][0]["data"]["cache.item_size"] == 2
+ # second_event - cache.get 2
+ assert second_event["spans"][1]["op"] == "cache.get"
+ assert second_event["spans"][1]["description"].startswith(
+ "views.decorators.cache.cache_page."
+ )
+ assert second_event["spans"][1]["data"]["network.peer.address"] is not None
+ assert second_event["spans"][1]["data"]["cache.key"][0].startswith(
+ "views.decorators.cache.cache_page."
+ )
+ assert second_event["spans"][1]["data"]["cache.hit"]
+ assert second_event["spans"][1]["data"]["cache.item_size"] == 58
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator()
+@pytest.mark.skipif(DJANGO_VERSION < (1, 9), reason="Requires Django >= 1.9")
+def test_cache_spans_decorator(sentry_init, client, capture_events, use_django_caching):
+ sentry_init(
+ integrations=[
+ DjangoIntegration(
+ cache_spans=True,
+ middleware_spans=False,
+ signals_spans=False,
+ )
+ ],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client.get(reverse("cached_view"))
+ client.get(reverse("cached_view"))
+
+ (first_event, second_event) = events
+ # first_event - cache.get
+ assert first_event["spans"][0]["op"] == "cache.get"
+ assert first_event["spans"][0]["description"].startswith(
+ "views.decorators.cache.cache_header."
+ )
+ assert first_event["spans"][0]["data"]["network.peer.address"] is not None
+ assert first_event["spans"][0]["data"]["cache.key"][0].startswith(
+ "views.decorators.cache.cache_header."
+ )
+ assert not first_event["spans"][0]["data"]["cache.hit"]
+ assert "cache.item_size" not in first_event["spans"][0]["data"]
+ # first_event - cache.put
+ assert first_event["spans"][1]["op"] == "cache.put"
+ assert first_event["spans"][1]["description"].startswith(
+ "views.decorators.cache.cache_header."
+ )
+ assert first_event["spans"][1]["data"]["network.peer.address"] is not None
+ assert first_event["spans"][1]["data"]["cache.key"][0].startswith(
+ "views.decorators.cache.cache_header."
+ )
+ assert "cache.hit" not in first_event["spans"][1]["data"]
+ assert first_event["spans"][1]["data"]["cache.item_size"] == 2
+ # second_event - cache.get
+ assert second_event["spans"][1]["op"] == "cache.get"
+ assert second_event["spans"][1]["description"].startswith(
+ "views.decorators.cache.cache_page."
+ )
+ assert second_event["spans"][1]["data"]["network.peer.address"] is not None
+ assert second_event["spans"][1]["data"]["cache.key"][0].startswith(
+ "views.decorators.cache.cache_page."
+ )
+ assert second_event["spans"][1]["data"]["cache.hit"]
+ assert second_event["spans"][1]["data"]["cache.item_size"] == 58
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator()
+@pytest.mark.skipif(DJANGO_VERSION < (1, 9), reason="Requires Django >= 1.9")
+def test_cache_spans_templatetag(
+ sentry_init, client, capture_events, use_django_caching
+):
+ sentry_init(
+ integrations=[
+ DjangoIntegration(
+ cache_spans=True,
+ middleware_spans=False,
+ signals_spans=False,
+ )
+ ],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client.get(reverse("view_with_cached_template_fragment"))
+ client.get(reverse("view_with_cached_template_fragment"))
+
+ (first_event, second_event) = events
+ assert len(first_event["spans"]) == 2
+ # first_event - cache.get
+ assert first_event["spans"][0]["op"] == "cache.get"
+ assert first_event["spans"][0]["description"].startswith(
+ "template.cache.some_identifier."
+ )
+ assert first_event["spans"][0]["data"]["network.peer.address"] is not None
+ assert first_event["spans"][0]["data"]["cache.key"][0].startswith(
+ "template.cache.some_identifier."
+ )
+ assert not first_event["spans"][0]["data"]["cache.hit"]
+ assert "cache.item_size" not in first_event["spans"][0]["data"]
+ # first_event - cache.put
+ assert first_event["spans"][1]["op"] == "cache.put"
+ assert first_event["spans"][1]["description"].startswith(
+ "template.cache.some_identifier."
+ )
+ assert first_event["spans"][1]["data"]["network.peer.address"] is not None
+ assert first_event["spans"][1]["data"]["cache.key"][0].startswith(
+ "template.cache.some_identifier."
+ )
+ assert "cache.hit" not in first_event["spans"][1]["data"]
+ assert first_event["spans"][1]["data"]["cache.item_size"] == 51
+ # second_event - cache.get
+ assert second_event["spans"][0]["op"] == "cache.get"
+ assert second_event["spans"][0]["description"].startswith(
+ "template.cache.some_identifier."
+ )
+ assert second_event["spans"][0]["data"]["network.peer.address"] is not None
+ assert second_event["spans"][0]["data"]["cache.key"][0].startswith(
+ "template.cache.some_identifier."
+ )
+ assert second_event["spans"][0]["data"]["cache.hit"]
+ assert second_event["spans"][0]["data"]["cache.item_size"] == 51
+
+
+@pytest.mark.parametrize(
+ "method_name, args, kwargs, expected_description",
+ [
+ (None, None, None, ""),
+ ("get", None, None, ""),
+ ("get", [], {}, ""),
+ ("get", ["bla", "blub", "foo"], {}, "bla"),
+ ("get", [uuid.uuid4().bytes], {}, ""),
+ (
+ "get_many",
+ [["bla1", "bla2", "bla3"], "blub", "foo"],
+ {},
+ "bla1, bla2, bla3",
+ ),
+ (
+ "get_many",
+ [["bla:1", "bla:2", "bla:3"], "blub", "foo"],
+ {"key": "bar"},
+ "bla:1, bla:2, bla:3",
+ ),
+ ("get", [], {"key": "bar"}, "bar"),
+ (
+ "get",
+ "something",
+ {},
+ "s",
+ ), # this case should never happen, just making sure that we are not raising an exception in that case.
+ ],
+)
+def test_cache_spans_get_span_description(
+ method_name, args, kwargs, expected_description
+):
+ assert _get_span_description(method_name, args, kwargs) == expected_description
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator()
+def test_cache_spans_location_with_port(
+ sentry_init, client, capture_events, use_django_caching_with_port
+):
+ sentry_init(
+ integrations=[
+ DjangoIntegration(
+ cache_spans=True,
+ middleware_spans=False,
+ signals_spans=False,
+ )
+ ],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client.get(reverse("cached_view"))
+ client.get(reverse("cached_view"))
+
+ for event in events:
+ for span in event["spans"]:
+ assert (
+ span["data"]["network.peer.address"] == "redis://127.0.0.1"
+ ) # Note: the username/password are not included in the address
+ assert span["data"]["network.peer.port"] == 6379
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator()
+def test_cache_spans_location_without_port(
+ sentry_init, client, capture_events, use_django_caching_without_port
+):
+ sentry_init(
+ integrations=[
+ DjangoIntegration(
+ cache_spans=True,
+ middleware_spans=False,
+ signals_spans=False,
+ )
+ ],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client.get(reverse("cached_view"))
+ client.get(reverse("cached_view"))
+
+ for event in events:
+ for span in event["spans"]:
+ assert span["data"]["network.peer.address"] == "redis://example.com"
+ assert "network.peer.port" not in span["data"]
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator()
+def test_cache_spans_location_with_cluster(
+ sentry_init, client, capture_events, use_django_caching_with_cluster
+):
+ sentry_init(
+ integrations=[
+ DjangoIntegration(
+ cache_spans=True,
+ middleware_spans=False,
+ signals_spans=False,
+ )
+ ],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client.get(reverse("cached_view"))
+ client.get(reverse("cached_view"))
+
+ for event in events:
+ for span in event["spans"]:
+ # because it is a cluster we do not know what host is actually accessed, so we omit the data
+ assert "network.peer.address" not in span["data"].keys()
+ assert "network.peer.port" not in span["data"].keys()
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator()
+def test_cache_spans_item_size(sentry_init, client, capture_events, use_django_caching):
+ sentry_init(
+ integrations=[
+ DjangoIntegration(
+ cache_spans=True,
+ middleware_spans=False,
+ signals_spans=False,
+ )
+ ],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client.get(reverse("cached_view"))
+ client.get(reverse("cached_view"))
+
+ (first_event, second_event) = events
+ assert len(first_event["spans"]) == 3
+ assert first_event["spans"][0]["op"] == "cache.get"
+ assert not first_event["spans"][0]["data"]["cache.hit"]
+ assert "cache.item_size" not in first_event["spans"][0]["data"]
+
+ assert first_event["spans"][1]["op"] == "cache.put"
+ assert "cache.hit" not in first_event["spans"][1]["data"]
+ assert first_event["spans"][1]["data"]["cache.item_size"] == 2
+
+ assert first_event["spans"][2]["op"] == "cache.put"
+ assert "cache.hit" not in first_event["spans"][2]["data"]
+ assert first_event["spans"][2]["data"]["cache.item_size"] == 58
+
+ assert len(second_event["spans"]) == 2
+ assert second_event["spans"][0]["op"] == "cache.get"
+ assert second_event["spans"][0]["data"]["cache.hit"]
+ assert second_event["spans"][0]["data"]["cache.item_size"] == 2
+
+ assert second_event["spans"][1]["op"] == "cache.get"
+ assert second_event["spans"][1]["data"]["cache.hit"]
+ assert second_event["spans"][1]["data"]["cache.item_size"] == 58
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator()
+def test_cache_spans_get_custom_default(
+ sentry_init, capture_events, use_django_caching
+):
+ sentry_init(
+ integrations=[
+ DjangoIntegration(
+ cache_spans=True,
+ middleware_spans=False,
+ signals_spans=False,
+ )
+ ],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ id = os.getpid()
+
+ from django.core.cache import cache
+
+ with sentry_sdk.start_transaction():
+ cache.set(f"S{id}", "Sensitive1")
+ cache.set(f"S{id + 1}", "")
+
+ cache.get(f"S{id}", "null")
+ cache.get(f"S{id}", default="null")
+
+ cache.get(f"S{id + 1}", "null")
+ cache.get(f"S{id + 1}", default="null")
+
+ cache.get(f"S{id + 2}", "null")
+ cache.get(f"S{id + 2}", default="null")
+
+ (transaction,) = events
+ assert len(transaction["spans"]) == 8
+
+ assert transaction["spans"][0]["op"] == "cache.put"
+ assert transaction["spans"][0]["description"] == f"S{id}"
+
+ assert transaction["spans"][1]["op"] == "cache.put"
+ assert transaction["spans"][1]["description"] == f"S{id + 1}"
+
+ for span in (transaction["spans"][2], transaction["spans"][3]):
+ assert span["op"] == "cache.get"
+ assert span["description"] == f"S{id}"
+ assert span["data"]["cache.hit"]
+ assert span["data"]["cache.item_size"] == 10
+
+ for span in (transaction["spans"][4], transaction["spans"][5]):
+ assert span["op"] == "cache.get"
+ assert span["description"] == f"S{id + 1}"
+ assert span["data"]["cache.hit"]
+ assert span["data"]["cache.item_size"] == 0
+
+ for span in (transaction["spans"][6], transaction["spans"][7]):
+ assert span["op"] == "cache.get"
+ assert span["description"] == f"S{id + 2}"
+ assert not span["data"]["cache.hit"]
+ assert "cache.item_size" not in span["data"]
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator()
+def test_cache_spans_get_many(sentry_init, capture_events, use_django_caching):
+ sentry_init(
+ integrations=[
+ DjangoIntegration(
+ cache_spans=True,
+ middleware_spans=False,
+ signals_spans=False,
+ )
+ ],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ id = os.getpid()
+
+ from django.core.cache import cache
+
+ with sentry_sdk.start_transaction():
+ cache.get_many([f"S{id}", f"S{id + 1}"])
+ cache.set(f"S{id}", "Sensitive1")
+ cache.get_many([f"S{id}", f"S{id + 1}"])
+
+ (transaction,) = events
+ assert len(transaction["spans"]) == 7
+
+ assert transaction["spans"][0]["op"] == "cache.get"
+ assert transaction["spans"][0]["description"] == f"S{id}, S{id + 1}"
+ assert not transaction["spans"][0]["data"]["cache.hit"]
+
+ assert transaction["spans"][1]["op"] == "cache.get"
+ assert transaction["spans"][1]["description"] == f"S{id}"
+ assert not transaction["spans"][1]["data"]["cache.hit"]
+
+ assert transaction["spans"][2]["op"] == "cache.get"
+ assert transaction["spans"][2]["description"] == f"S{id + 1}"
+ assert not transaction["spans"][2]["data"]["cache.hit"]
+
+ assert transaction["spans"][3]["op"] == "cache.put"
+ assert transaction["spans"][3]["description"] == f"S{id}"
+
+ assert transaction["spans"][4]["op"] == "cache.get"
+ assert transaction["spans"][4]["description"] == f"S{id}, S{id + 1}"
+ assert transaction["spans"][4]["data"]["cache.hit"]
+
+ assert transaction["spans"][5]["op"] == "cache.get"
+ assert transaction["spans"][5]["description"] == f"S{id}"
+ assert transaction["spans"][5]["data"]["cache.hit"]
+
+ assert transaction["spans"][6]["op"] == "cache.get"
+ assert transaction["spans"][6]["description"] == f"S{id + 1}"
+ assert not transaction["spans"][6]["data"]["cache.hit"]
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator()
+def test_cache_spans_set_many(sentry_init, capture_events, use_django_caching):
+ sentry_init(
+ integrations=[
+ DjangoIntegration(
+ cache_spans=True,
+ middleware_spans=False,
+ signals_spans=False,
+ )
+ ],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ id = os.getpid()
+
+ from django.core.cache import cache
+
+ with sentry_sdk.start_transaction():
+ cache.set_many({f"S{id}": "Sensitive1", f"S{id + 1}": "Sensitive2"})
+ cache.get(f"S{id}")
+
+ (transaction,) = events
+ assert len(transaction["spans"]) == 4
+
+ assert transaction["spans"][0]["op"] == "cache.put"
+ assert transaction["spans"][0]["description"] == f"S{id}, S{id + 1}"
+
+ assert transaction["spans"][1]["op"] == "cache.put"
+ assert transaction["spans"][1]["description"] == f"S{id}"
+
+ assert transaction["spans"][2]["op"] == "cache.put"
+ assert transaction["spans"][2]["description"] == f"S{id + 1}"
+
+ assert transaction["spans"][3]["op"] == "cache.get"
+ assert transaction["spans"][3]["description"] == f"S{id}"
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator()
+@pytest.mark.skipif(DJANGO_VERSION <= (1, 11), reason="Requires Django > 1.11")
+def test_span_origin_cache(sentry_init, client, capture_events, use_django_caching):
+ sentry_init(
+ integrations=[
+ DjangoIntegration(
+ middleware_spans=True,
+ signals_spans=True,
+ cache_spans=True,
+ )
+ ],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client.get(reverse("cached_view"))
+
+ (transaction,) = events
+
+ assert transaction["contexts"]["trace"]["origin"] == "auto.http.django"
+
+ cache_span_found = False
+ for span in transaction["spans"]:
+ assert span["origin"] == "auto.http.django"
+ if span["op"].startswith("cache."):
+ cache_span_found = True
+
+ assert cache_span_found
diff --git a/tests/integrations/django/test_data_scrubbing.py b/tests/integrations/django/test_data_scrubbing.py
index c0ab14ae63..128da9b97e 100644
--- a/tests/integrations/django/test_data_scrubbing.py
+++ b/tests/integrations/django/test_data_scrubbing.py
@@ -1,12 +1,11 @@
-from functools import partial
import pytest
-import pytest_django
from werkzeug.test import Client
from sentry_sdk.integrations.django import DjangoIntegration
-
+from tests.conftest import werkzeug_set_cookie
from tests.integrations.django.myapp.wsgi import application
+from tests.integrations.django.utils import pytest_mark_django_db_decorator
try:
from django.urls import reverse
@@ -14,24 +13,6 @@
from django.core.urlresolvers import reverse
-# Hack to prevent from experimental feature introduced in version `4.3.0` in `pytest-django` that
-# requires explicit database allow from failing the test
-pytest_mark_django_db_decorator = partial(pytest.mark.django_db)
-try:
- pytest_version = tuple(map(int, pytest_django.__version__.split(".")))
- if pytest_version > (4, 2, 0):
- pytest_mark_django_db_decorator = partial(
- pytest.mark.django_db, databases="__all__"
- )
-except ValueError:
- if "dev" in pytest_django.__version__:
- pytest_mark_django_db_decorator = partial(
- pytest.mark.django_db, databases="__all__"
- )
-except AttributeError:
- pass
-
-
@pytest.fixture
def client():
return Client(application)
@@ -46,9 +27,9 @@ def test_scrub_django_session_cookies_removed(
):
sentry_init(integrations=[DjangoIntegration()], send_default_pii=False)
events = capture_events()
- client.set_cookie("localhost", "sessionid", "123")
- client.set_cookie("localhost", "csrftoken", "456")
- client.set_cookie("localhost", "foo", "bar")
+ werkzeug_set_cookie(client, "localhost", "sessionid", "123")
+ werkzeug_set_cookie(client, "localhost", "csrftoken", "456")
+ werkzeug_set_cookie(client, "localhost", "foo", "bar")
client.get(reverse("view_exc"))
(event,) = events
@@ -64,9 +45,9 @@ def test_scrub_django_session_cookies_filtered(
):
sentry_init(integrations=[DjangoIntegration()], send_default_pii=True)
events = capture_events()
- client.set_cookie("localhost", "sessionid", "123")
- client.set_cookie("localhost", "csrftoken", "456")
- client.set_cookie("localhost", "foo", "bar")
+ werkzeug_set_cookie(client, "localhost", "sessionid", "123")
+ werkzeug_set_cookie(client, "localhost", "csrftoken", "456")
+ werkzeug_set_cookie(client, "localhost", "foo", "bar")
client.get(reverse("view_exc"))
(event,) = events
@@ -90,9 +71,9 @@ def test_scrub_django_custom_session_cookies_filtered(
sentry_init(integrations=[DjangoIntegration()], send_default_pii=True)
events = capture_events()
- client.set_cookie("localhost", "my_sess", "123")
- client.set_cookie("localhost", "csrf_secret", "456")
- client.set_cookie("localhost", "foo", "bar")
+ werkzeug_set_cookie(client, "localhost", "my_sess", "123")
+ werkzeug_set_cookie(client, "localhost", "csrf_secret", "456")
+ werkzeug_set_cookie(client, "localhost", "foo", "bar")
client.get(reverse("view_exc"))
(event,) = events
diff --git a/tests/integrations/django/test_db_query_data.py b/tests/integrations/django/test_db_query_data.py
new file mode 100644
index 0000000000..41ad9d5e1c
--- /dev/null
+++ b/tests/integrations/django/test_db_query_data.py
@@ -0,0 +1,526 @@
+import os
+
+import pytest
+from datetime import datetime
+from unittest import mock
+
+from django import VERSION as DJANGO_VERSION
+from django.db import connections
+
+try:
+ from django.urls import reverse
+except ImportError:
+ from django.core.urlresolvers import reverse
+
+from werkzeug.test import Client
+
+from sentry_sdk import start_transaction
+from sentry_sdk.consts import SPANDATA
+from sentry_sdk.integrations.django import DjangoIntegration
+from sentry_sdk.tracing_utils import record_sql_queries
+
+from tests.conftest import unpack_werkzeug_response
+from tests.integrations.django.utils import pytest_mark_django_db_decorator
+from tests.integrations.django.myapp.wsgi import application
+
+
+@pytest.fixture
+def client():
+ return Client(application)
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator(transaction=True)
+def test_query_source_disabled(sentry_init, client, capture_events):
+ sentry_options = {
+ "integrations": [DjangoIntegration()],
+ "send_default_pii": True,
+ "traces_sample_rate": 1.0,
+ "enable_db_query_source": False,
+ "db_query_source_threshold_ms": 0,
+ }
+
+ sentry_init(**sentry_options)
+
+ if "postgres" not in connections:
+ pytest.skip("postgres tests disabled")
+
+ # trigger Django to open a new connection by marking the existing one as None.
+ connections["postgres"].connection = None
+
+ events = capture_events()
+
+ _, status, _ = unpack_werkzeug_response(client.get(reverse("postgres_select_orm")))
+ assert status == "200 OK"
+
+ (event,) = events
+ for span in event["spans"]:
+ if span.get("op") == "db" and "auth_user" in span.get("description"):
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO not in data
+ assert SPANDATA.CODE_NAMESPACE not in data
+ assert SPANDATA.CODE_FILEPATH not in data
+ assert SPANDATA.CODE_FUNCTION not in data
+ break
+ else:
+ raise AssertionError("No db span found")
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator(transaction=True)
+@pytest.mark.parametrize("enable_db_query_source", [None, True])
+def test_query_source_enabled(
+ sentry_init, client, capture_events, enable_db_query_source
+):
+ sentry_options = {
+ "integrations": [DjangoIntegration()],
+ "send_default_pii": True,
+ "traces_sample_rate": 1.0,
+ "db_query_source_threshold_ms": 0,
+ }
+
+ if enable_db_query_source is not None:
+ sentry_options["enable_db_query_source"] = enable_db_query_source
+
+ sentry_init(**sentry_options)
+
+ if "postgres" not in connections:
+ pytest.skip("postgres tests disabled")
+
+ # trigger Django to open a new connection by marking the existing one as None.
+ connections["postgres"].connection = None
+
+ events = capture_events()
+
+ _, status, _ = unpack_werkzeug_response(client.get(reverse("postgres_select_orm")))
+ assert status == "200 OK"
+
+ (event,) = events
+ for span in event["spans"]:
+ if span.get("op") == "db" and "auth_user" in span.get("description"):
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO in data
+ assert SPANDATA.CODE_NAMESPACE in data
+ assert SPANDATA.CODE_FILEPATH in data
+ assert SPANDATA.CODE_FUNCTION in data
+
+ break
+ else:
+ raise AssertionError("No db span found")
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator(transaction=True)
+def test_query_source(sentry_init, client, capture_events):
+ sentry_init(
+ integrations=[DjangoIntegration()],
+ send_default_pii=True,
+ traces_sample_rate=1.0,
+ enable_db_query_source=True,
+ db_query_source_threshold_ms=0,
+ )
+
+ if "postgres" not in connections:
+ pytest.skip("postgres tests disabled")
+
+ # trigger Django to open a new connection by marking the existing one as None.
+ connections["postgres"].connection = None
+
+ events = capture_events()
+
+ _, status, _ = unpack_werkzeug_response(client.get(reverse("postgres_select_orm")))
+ assert status == "200 OK"
+
+ (event,) = events
+ for span in event["spans"]:
+ if span.get("op") == "db" and "auth_user" in span.get("description"):
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO in data
+ assert SPANDATA.CODE_NAMESPACE in data
+ assert SPANDATA.CODE_FILEPATH in data
+ assert SPANDATA.CODE_FUNCTION in data
+
+ assert type(data.get(SPANDATA.CODE_LINENO)) == int
+ assert data.get(SPANDATA.CODE_LINENO) > 0
+
+ assert (
+ data.get(SPANDATA.CODE_NAMESPACE)
+ == "tests.integrations.django.myapp.views"
+ )
+ assert data.get(SPANDATA.CODE_FILEPATH).endswith(
+ "tests/integrations/django/myapp/views.py"
+ )
+
+ is_relative_path = data.get(SPANDATA.CODE_FILEPATH)[0] != os.sep
+ assert is_relative_path
+
+ assert data.get(SPANDATA.CODE_FUNCTION) == "postgres_select_orm"
+
+ break
+ else:
+ raise AssertionError("No db span found")
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator(transaction=True)
+def test_query_source_with_module_in_search_path(sentry_init, client, capture_events):
+ """
+ Test that query source is relative to the path of the module it ran in
+ """
+ client = Client(application)
+
+ sentry_init(
+ integrations=[DjangoIntegration()],
+ send_default_pii=True,
+ traces_sample_rate=1.0,
+ enable_db_query_source=True,
+ db_query_source_threshold_ms=0,
+ )
+
+ if "postgres" not in connections:
+ pytest.skip("postgres tests disabled")
+
+ # trigger Django to open a new connection by marking the existing one as None.
+ connections["postgres"].connection = None
+
+ events = capture_events()
+
+ _, status, _ = unpack_werkzeug_response(
+ client.get(reverse("postgres_select_slow_from_supplement"))
+ )
+ assert status == "200 OK"
+
+ (event,) = events
+ for span in event["spans"]:
+ if span.get("op") == "db" and "auth_user" in span.get("description"):
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO in data
+ assert SPANDATA.CODE_NAMESPACE in data
+ assert SPANDATA.CODE_FILEPATH in data
+ assert SPANDATA.CODE_FUNCTION in data
+
+ assert type(data.get(SPANDATA.CODE_LINENO)) == int
+ assert data.get(SPANDATA.CODE_LINENO) > 0
+ assert data.get(SPANDATA.CODE_NAMESPACE) == "django_helpers.views"
+ assert data.get(SPANDATA.CODE_FILEPATH) == "django_helpers/views.py"
+
+ is_relative_path = data.get(SPANDATA.CODE_FILEPATH)[0] != os.sep
+ assert is_relative_path
+
+ assert data.get(SPANDATA.CODE_FUNCTION) == "postgres_select_orm"
+
+ break
+ else:
+ raise AssertionError("No db span found")
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator(transaction=True)
+def test_query_source_with_in_app_exclude(sentry_init, client, capture_events):
+ sentry_init(
+ integrations=[DjangoIntegration()],
+ send_default_pii=True,
+ traces_sample_rate=1.0,
+ enable_db_query_source=True,
+ db_query_source_threshold_ms=0,
+ in_app_exclude=["tests.integrations.django.myapp.views"],
+ )
+
+ if "postgres" not in connections:
+ pytest.skip("postgres tests disabled")
+
+ # trigger Django to open a new connection by marking the existing one as None.
+ connections["postgres"].connection = None
+
+ events = capture_events()
+
+ _, status, _ = unpack_werkzeug_response(client.get(reverse("postgres_select_orm")))
+ assert status == "200 OK"
+
+ (event,) = events
+ for span in event["spans"]:
+ if span.get("op") == "db" and "auth_user" in span.get("description"):
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO in data
+ assert SPANDATA.CODE_NAMESPACE in data
+ assert SPANDATA.CODE_FILEPATH in data
+ assert SPANDATA.CODE_FUNCTION in data
+
+ assert type(data.get(SPANDATA.CODE_LINENO)) == int
+ assert data.get(SPANDATA.CODE_LINENO) > 0
+
+ if DJANGO_VERSION >= (1, 11):
+ assert (
+ data.get(SPANDATA.CODE_NAMESPACE)
+ == "tests.integrations.django.myapp.settings"
+ )
+ assert data.get(SPANDATA.CODE_FILEPATH).endswith(
+ "tests/integrations/django/myapp/settings.py"
+ )
+ assert data.get(SPANDATA.CODE_FUNCTION) == "middleware"
+ else:
+ assert (
+ data.get(SPANDATA.CODE_NAMESPACE)
+ == "tests.integrations.django.test_db_query_data"
+ )
+ assert data.get(SPANDATA.CODE_FILEPATH).endswith(
+ "tests/integrations/django/test_db_query_data.py"
+ )
+ assert (
+ data.get(SPANDATA.CODE_FUNCTION)
+ == "test_query_source_with_in_app_exclude"
+ )
+
+ break
+ else:
+ raise AssertionError("No db span found")
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator(transaction=True)
+def test_query_source_with_in_app_include(sentry_init, client, capture_events):
+ sentry_init(
+ integrations=[DjangoIntegration()],
+ send_default_pii=True,
+ traces_sample_rate=1.0,
+ enable_db_query_source=True,
+ db_query_source_threshold_ms=0,
+ in_app_include=["django"],
+ )
+
+ if "postgres" not in connections:
+ pytest.skip("postgres tests disabled")
+
+ # trigger Django to open a new connection by marking the existing one as None.
+ connections["postgres"].connection = None
+
+ events = capture_events()
+
+ _, status, _ = unpack_werkzeug_response(client.get(reverse("postgres_select_orm")))
+ assert status == "200 OK"
+
+ (event,) = events
+ for span in event["spans"]:
+ if span.get("op") == "db" and "auth_user" in span.get("description"):
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO in data
+ assert SPANDATA.CODE_NAMESPACE in data
+ assert SPANDATA.CODE_FILEPATH in data
+ assert SPANDATA.CODE_FUNCTION in data
+
+ assert type(data.get(SPANDATA.CODE_LINENO)) == int
+ assert data.get(SPANDATA.CODE_LINENO) > 0
+
+ assert data.get(SPANDATA.CODE_NAMESPACE) == "django.db.models.sql.compiler"
+ assert data.get(SPANDATA.CODE_FILEPATH).endswith(
+ "django/db/models/sql/compiler.py"
+ )
+ assert data.get(SPANDATA.CODE_FUNCTION) == "execute_sql"
+ break
+ else:
+ raise AssertionError("No db span found")
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator(transaction=True)
+def test_no_query_source_if_duration_too_short(sentry_init, client, capture_events):
+ sentry_init(
+ integrations=[DjangoIntegration()],
+ send_default_pii=True,
+ traces_sample_rate=1.0,
+ enable_db_query_source=True,
+ db_query_source_threshold_ms=100,
+ )
+
+ if "postgres" not in connections:
+ pytest.skip("postgres tests disabled")
+
+ # trigger Django to open a new connection by marking the existing one as None.
+ connections["postgres"].connection = None
+
+ events = capture_events()
+
+ class fake_record_sql_queries: # noqa: N801
+ def __init__(self, *args, **kwargs):
+ with record_sql_queries(*args, **kwargs) as span:
+ self.span = span
+
+ self.span.start_timestamp = datetime(2024, 1, 1, microsecond=0)
+ self.span.timestamp = datetime(2024, 1, 1, microsecond=99999)
+
+ def __enter__(self):
+ return self.span
+
+ def __exit__(self, type, value, traceback):
+ pass
+
+ with mock.patch(
+ "sentry_sdk.integrations.django.record_sql_queries",
+ fake_record_sql_queries,
+ ):
+ _, status, _ = unpack_werkzeug_response(
+ client.get(reverse("postgres_select_orm"))
+ )
+
+ assert status == "200 OK"
+
+ (event,) = events
+ for span in event["spans"]:
+ if span.get("op") == "db" and "auth_user" in span.get("description"):
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO not in data
+ assert SPANDATA.CODE_NAMESPACE not in data
+ assert SPANDATA.CODE_FILEPATH not in data
+ assert SPANDATA.CODE_FUNCTION not in data
+
+ break
+ else:
+ raise AssertionError("No db span found")
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator(transaction=True)
+def test_query_source_if_duration_over_threshold(sentry_init, client, capture_events):
+ sentry_init(
+ integrations=[DjangoIntegration()],
+ send_default_pii=True,
+ traces_sample_rate=1.0,
+ enable_db_query_source=True,
+ db_query_source_threshold_ms=100,
+ )
+
+ if "postgres" not in connections:
+ pytest.skip("postgres tests disabled")
+
+ # trigger Django to open a new connection by marking the existing one as None.
+ connections["postgres"].connection = None
+
+ events = capture_events()
+
+ class fake_record_sql_queries: # noqa: N801
+ def __init__(self, *args, **kwargs):
+ with record_sql_queries(*args, **kwargs) as span:
+ self.span = span
+
+ self.span.start_timestamp = datetime(2024, 1, 1, microsecond=0)
+ self.span.timestamp = datetime(2024, 1, 1, microsecond=101000)
+
+ def __enter__(self):
+ return self.span
+
+ def __exit__(self, type, value, traceback):
+ pass
+
+ with mock.patch(
+ "sentry_sdk.integrations.django.record_sql_queries",
+ fake_record_sql_queries,
+ ):
+ _, status, _ = unpack_werkzeug_response(
+ client.get(reverse("postgres_select_orm"))
+ )
+
+ assert status == "200 OK"
+
+ (event,) = events
+ for span in event["spans"]:
+ if span.get("op") == "db" and "auth_user" in span.get("description"):
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO in data
+ assert SPANDATA.CODE_NAMESPACE in data
+ assert SPANDATA.CODE_FILEPATH in data
+ assert SPANDATA.CODE_FUNCTION in data
+
+ assert type(data.get(SPANDATA.CODE_LINENO)) == int
+ assert data.get(SPANDATA.CODE_LINENO) > 0
+
+ assert (
+ data.get(SPANDATA.CODE_NAMESPACE)
+ == "tests.integrations.django.myapp.views"
+ )
+ assert data.get(SPANDATA.CODE_FILEPATH).endswith(
+ "tests/integrations/django/myapp/views.py"
+ )
+
+ is_relative_path = data.get(SPANDATA.CODE_FILEPATH)[0] != os.sep
+ assert is_relative_path
+
+ assert data.get(SPANDATA.CODE_FUNCTION) == "postgres_select_orm"
+ break
+ else:
+ raise AssertionError("No db span found")
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator(transaction=True)
+def test_db_span_origin_execute(sentry_init, client, capture_events):
+ sentry_init(
+ integrations=[DjangoIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ if "postgres" not in connections:
+ pytest.skip("postgres tests disabled")
+
+ # trigger Django to open a new connection by marking the existing one as None.
+ connections["postgres"].connection = None
+
+ events = capture_events()
+
+ client.get(reverse("postgres_select_orm"))
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "auto.http.django"
+
+ for span in event["spans"]:
+ if span["op"] == "db":
+ assert span["origin"] == "auto.db.django"
+ else:
+ assert span["origin"] == "auto.http.django"
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator(transaction=True)
+def test_db_span_origin_executemany(sentry_init, client, capture_events):
+ sentry_init(
+ integrations=[DjangoIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ if "postgres" not in connections:
+ pytest.skip("postgres tests disabled")
+
+ with start_transaction(name="test_transaction"):
+ from django.db import connection, transaction
+
+ cursor = connection.cursor()
+
+ query = """UPDATE auth_user SET username = %s where id = %s;"""
+ query_list = (
+ (
+ "test1",
+ 1,
+ ),
+ (
+ "test2",
+ 2,
+ ),
+ )
+ cursor.executemany(query, query_list)
+
+ transaction.commit()
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+ assert event["spans"][0]["origin"] == "auto.db.django"
diff --git a/tests/integrations/django/test_db_transactions.py b/tests/integrations/django/test_db_transactions.py
new file mode 100644
index 0000000000..2750397b0e
--- /dev/null
+++ b/tests/integrations/django/test_db_transactions.py
@@ -0,0 +1,977 @@
+import os
+import pytest
+import itertools
+from datetime import datetime
+
+from django.db import connections
+from django.contrib.auth.models import User
+
+try:
+ from django.urls import reverse
+except ImportError:
+ from django.core.urlresolvers import reverse
+
+from werkzeug.test import Client
+
+from sentry_sdk import start_transaction
+from sentry_sdk.consts import SPANDATA, SPANNAME
+from sentry_sdk.integrations.django import DjangoIntegration
+
+from tests.integrations.django.utils import pytest_mark_django_db_decorator
+from tests.integrations.django.myapp.wsgi import application
+
+
+@pytest.fixture
+def client():
+ return Client(application)
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator(transaction=True)
+def test_db_transaction_spans_disabled_no_autocommit(
+ sentry_init, client, capture_events
+):
+ sentry_init(
+ integrations=[DjangoIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ if "postgres" not in connections:
+ pytest.skip("postgres tests disabled")
+
+ # trigger Django to open a new connection by marking the existing one as None.
+ connections["postgres"].connection = None
+
+ events = capture_events()
+
+ client.get(reverse("postgres_insert_orm_no_autocommit_rollback"))
+ client.get(reverse("postgres_insert_orm_no_autocommit"))
+
+ with start_transaction(name="test_transaction"):
+ from django.db import connection, transaction
+
+ cursor = connection.cursor()
+
+ query = """INSERT INTO auth_user (
+ password,
+ is_superuser,
+ username,
+ first_name,
+ last_name,
+ email,
+ is_staff,
+ is_active,
+ date_joined
+)
+VALUES ('password', false, %s, %s, %s, %s, false, true, %s);"""
+
+ query_list = (
+ (
+ "user1",
+ "John",
+ "Doe",
+ "user1@example.com",
+ datetime(1970, 1, 1),
+ ),
+ (
+ "user2",
+ "Max",
+ "Mustermann",
+ "user2@example.com",
+ datetime(1970, 1, 1),
+ ),
+ )
+
+ transaction.set_autocommit(False)
+ cursor.executemany(query, query_list)
+ transaction.rollback()
+ transaction.set_autocommit(True)
+
+ with start_transaction(name="test_transaction"):
+ from django.db import connection, transaction
+
+ cursor = connection.cursor()
+
+ query = """INSERT INTO auth_user (
+ password,
+ is_superuser,
+ username,
+ first_name,
+ last_name,
+ email,
+ is_staff,
+ is_active,
+ date_joined
+)
+VALUES ('password', false, %s, %s, %s, %s, false, true, %s);"""
+
+ query_list = (
+ (
+ "user1",
+ "John",
+ "Doe",
+ "user1@example.com",
+ datetime(1970, 1, 1),
+ ),
+ (
+ "user2",
+ "Max",
+ "Mustermann",
+ "user2@example.com",
+ datetime(1970, 1, 1),
+ ),
+ )
+
+ transaction.set_autocommit(False)
+ cursor.executemany(query, query_list)
+ transaction.commit()
+ transaction.set_autocommit(True)
+
+ (postgres_rollback, postgres_commit, sqlite_rollback, sqlite_commit) = events
+
+ # Ensure operation is persisted
+ assert User.objects.using("postgres").exists()
+
+ assert postgres_rollback["contexts"]["trace"]["origin"] == "auto.http.django"
+ assert postgres_commit["contexts"]["trace"]["origin"] == "auto.http.django"
+ assert sqlite_rollback["contexts"]["trace"]["origin"] == "manual"
+ assert sqlite_commit["contexts"]["trace"]["origin"] == "manual"
+
+ commit_spans = [
+ span
+ for span in itertools.chain(
+ postgres_rollback["spans"],
+ postgres_commit["spans"],
+ sqlite_rollback["spans"],
+ sqlite_commit["spans"],
+ )
+ if span["data"].get(SPANDATA.DB_OPERATION) == SPANNAME.DB_COMMIT
+ or span["data"].get(SPANDATA.DB_OPERATION) == SPANNAME.DB_ROLLBACK
+ ]
+ assert len(commit_spans) == 0
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator(transaction=True)
+def test_db_transaction_spans_disabled_atomic(sentry_init, client, capture_events):
+ sentry_init(
+ integrations=[DjangoIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ if "postgres" not in connections:
+ pytest.skip("postgres tests disabled")
+
+ # trigger Django to open a new connection by marking the existing one as None.
+ connections["postgres"].connection = None
+
+ events = capture_events()
+
+ client.get(reverse("postgres_insert_orm_atomic_rollback"))
+ client.get(reverse("postgres_insert_orm_atomic"))
+
+ with start_transaction(name="test_transaction"):
+ from django.db import connection, transaction
+
+ with transaction.atomic():
+ cursor = connection.cursor()
+
+ query = """INSERT INTO auth_user (
+ password,
+ is_superuser,
+ username,
+ first_name,
+ last_name,
+ email,
+ is_staff,
+ is_active,
+ date_joined
+)
+VALUES ('password', false, %s, %s, %s, %s, false, true, %s);"""
+
+ query_list = (
+ (
+ "user1",
+ "John",
+ "Doe",
+ "user1@example.com",
+ datetime(1970, 1, 1),
+ ),
+ (
+ "user2",
+ "Max",
+ "Mustermann",
+ "user2@example.com",
+ datetime(1970, 1, 1),
+ ),
+ )
+ cursor.executemany(query, query_list)
+ transaction.set_rollback(True)
+
+ with start_transaction(name="test_transaction"):
+ from django.db import connection, transaction
+
+ with transaction.atomic():
+ cursor = connection.cursor()
+
+ query = """INSERT INTO auth_user (
+ password,
+ is_superuser,
+ username,
+ first_name,
+ last_name,
+ email,
+ is_staff,
+ is_active,
+ date_joined
+)
+VALUES ('password', false, %s, %s, %s, %s, false, true, %s);"""
+
+ query_list = (
+ (
+ "user1",
+ "John",
+ "Doe",
+ "user1@example.com",
+ datetime(1970, 1, 1),
+ ),
+ (
+ "user2",
+ "Max",
+ "Mustermann",
+ "user2@example.com",
+ datetime(1970, 1, 1),
+ ),
+ )
+ cursor.executemany(query, query_list)
+
+ (postgres_rollback, postgres_commit, sqlite_rollback, sqlite_commit) = events
+
+ # Ensure operation is persisted
+ assert User.objects.using("postgres").exists()
+
+ assert postgres_rollback["contexts"]["trace"]["origin"] == "auto.http.django"
+ assert postgres_commit["contexts"]["trace"]["origin"] == "auto.http.django"
+ assert sqlite_rollback["contexts"]["trace"]["origin"] == "manual"
+ assert sqlite_commit["contexts"]["trace"]["origin"] == "manual"
+
+ commit_spans = [
+ span
+ for span in itertools.chain(
+ postgres_rollback["spans"],
+ postgres_commit["spans"],
+ sqlite_rollback["spans"],
+ sqlite_commit["spans"],
+ )
+ if span["data"].get(SPANDATA.DB_OPERATION) == SPANNAME.DB_COMMIT
+ or span["data"].get(SPANDATA.DB_OPERATION) == SPANNAME.DB_ROLLBACK
+ ]
+ assert len(commit_spans) == 0
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator(transaction=True)
+def test_db_no_autocommit_execute(sentry_init, client, capture_events):
+ sentry_init(
+ integrations=[DjangoIntegration(db_transaction_spans=True)],
+ traces_sample_rate=1.0,
+ )
+
+ if "postgres" not in connections:
+ pytest.skip("postgres tests disabled")
+
+ # trigger Django to open a new connection by marking the existing one as None.
+ connections["postgres"].connection = None
+
+ events = capture_events()
+
+ client.get(reverse("postgres_insert_orm_no_autocommit"))
+
+ (event,) = events
+
+ # Ensure operation is persisted
+ assert User.objects.using("postgres").exists()
+
+ assert event["contexts"]["trace"]["origin"] == "auto.http.django"
+
+ commit_spans = [
+ span
+ for span in event["spans"]
+ if span["data"].get(SPANDATA.DB_OPERATION) == SPANNAME.DB_COMMIT
+ ]
+ assert len(commit_spans) == 1
+ commit_span = commit_spans[0]
+ assert commit_span["origin"] == "auto.db.django"
+
+ # Verify other database attributes
+ assert commit_span["data"].get(SPANDATA.DB_SYSTEM) == "postgresql"
+ conn_params = connections["postgres"].get_connection_params()
+ assert commit_span["data"].get(SPANDATA.DB_NAME) is not None
+ assert commit_span["data"].get(SPANDATA.DB_NAME) == conn_params.get(
+ "database"
+ ) or conn_params.get("dbname")
+ assert commit_span["data"].get(SPANDATA.SERVER_ADDRESS) == os.environ.get(
+ "SENTRY_PYTHON_TEST_POSTGRES_HOST", "localhost"
+ )
+ assert commit_span["data"].get(SPANDATA.SERVER_PORT) == os.environ.get(
+ "SENTRY_PYTHON_TEST_POSTGRES_PORT", "5432"
+ )
+
+ insert_spans = [
+ span for span in event["spans"] if span["description"].startswith("INSERT INTO")
+ ]
+ assert len(insert_spans) == 1
+ insert_span = insert_spans[0]
+
+ # Verify query and commit statements are siblings
+ assert commit_span["parent_span_id"] == insert_span["parent_span_id"]
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator(transaction=True)
+def test_db_no_autocommit_executemany(sentry_init, client, capture_events):
+ sentry_init(
+ integrations=[DjangoIntegration(db_transaction_spans=True)],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ with start_transaction(name="test_transaction"):
+ from django.db import connection, transaction
+
+ cursor = connection.cursor()
+
+ query = """INSERT INTO auth_user (
+ password,
+ is_superuser,
+ username,
+ first_name,
+ last_name,
+ email,
+ is_staff,
+ is_active,
+ date_joined
+)
+VALUES ('password', false, %s, %s, %s, %s, false, true, %s);"""
+
+ query_list = (
+ (
+ "user1",
+ "John",
+ "Doe",
+ "user1@example.com",
+ datetime(1970, 1, 1),
+ ),
+ (
+ "user2",
+ "Max",
+ "Mustermann",
+ "user2@example.com",
+ datetime(1970, 1, 1),
+ ),
+ )
+
+ transaction.set_autocommit(False)
+ cursor.executemany(query, query_list)
+ transaction.commit()
+ transaction.set_autocommit(True)
+
+ (event,) = events
+
+ # Ensure operation is persisted
+ assert User.objects.exists()
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+ assert event["spans"][0]["origin"] == "auto.db.django"
+
+ commit_spans = [
+ span
+ for span in event["spans"]
+ if span["data"].get(SPANDATA.DB_OPERATION) == SPANNAME.DB_COMMIT
+ ]
+ assert len(commit_spans) == 1
+ commit_span = commit_spans[0]
+ assert commit_span["origin"] == "auto.db.django"
+
+ # Verify other database attributes
+ assert commit_span["data"].get(SPANDATA.DB_SYSTEM) == "sqlite"
+ conn_params = connection.get_connection_params()
+ assert commit_span["data"].get(SPANDATA.DB_NAME) is not None
+ assert commit_span["data"].get(SPANDATA.DB_NAME) == conn_params.get(
+ "database"
+ ) or conn_params.get("dbname")
+
+ insert_spans = [
+ span for span in event["spans"] if span["description"].startswith("INSERT INTO")
+ ]
+
+ # Verify queries and commit statements are siblings
+ for insert_span in insert_spans:
+ assert commit_span["parent_span_id"] == insert_span["parent_span_id"]
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator(transaction=True)
+def test_db_no_autocommit_rollback_execute(sentry_init, client, capture_events):
+ sentry_init(
+ integrations=[DjangoIntegration(db_transaction_spans=True)],
+ traces_sample_rate=1.0,
+ )
+
+ if "postgres" not in connections:
+ pytest.skip("postgres tests disabled")
+
+ # trigger Django to open a new connection by marking the existing one as None.
+ connections["postgres"].connection = None
+
+ events = capture_events()
+
+ client.get(reverse("postgres_insert_orm_no_autocommit_rollback"))
+
+ (event,) = events
+
+ # Ensure operation is rolled back
+ assert not User.objects.using("postgres").exists()
+
+ assert event["contexts"]["trace"]["origin"] == "auto.http.django"
+
+ rollback_spans = [
+ span
+ for span in event["spans"]
+ if span["data"].get(SPANDATA.DB_OPERATION) == SPANNAME.DB_ROLLBACK
+ ]
+ assert len(rollback_spans) == 1
+ rollback_span = rollback_spans[0]
+ assert rollback_span["origin"] == "auto.db.django"
+
+ # Verify other database attributes
+ assert rollback_span["data"].get(SPANDATA.DB_SYSTEM) == "postgresql"
+ conn_params = connections["postgres"].get_connection_params()
+ assert rollback_span["data"].get(SPANDATA.DB_NAME) is not None
+ assert rollback_span["data"].get(SPANDATA.DB_NAME) == conn_params.get(
+ "database"
+ ) or conn_params.get("dbname")
+ assert rollback_span["data"].get(SPANDATA.SERVER_ADDRESS) == os.environ.get(
+ "SENTRY_PYTHON_TEST_POSTGRES_HOST", "localhost"
+ )
+ assert rollback_span["data"].get(SPANDATA.SERVER_PORT) == os.environ.get(
+ "SENTRY_PYTHON_TEST_POSTGRES_PORT", "5432"
+ )
+
+ insert_spans = [
+ span for span in event["spans"] if span["description"].startswith("INSERT INTO")
+ ]
+ assert len(insert_spans) == 1
+ insert_span = insert_spans[0]
+
+ # Verify query and rollback statements are siblings
+ assert rollback_span["parent_span_id"] == insert_span["parent_span_id"]
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator(transaction=True)
+def test_db_no_autocommit_rollback_executemany(sentry_init, client, capture_events):
+ sentry_init(
+ integrations=[DjangoIntegration(db_transaction_spans=True)],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ with start_transaction(name="test_transaction"):
+ from django.db import connection, transaction
+
+ cursor = connection.cursor()
+
+ query = """INSERT INTO auth_user (
+ password,
+ is_superuser,
+ username,
+ first_name,
+ last_name,
+ email,
+ is_staff,
+ is_active,
+ date_joined
+)
+VALUES ('password', false, %s, %s, %s, %s, false, true, %s);"""
+
+ query_list = (
+ (
+ "user1",
+ "John",
+ "Doe",
+ "user1@example.com",
+ datetime(1970, 1, 1),
+ ),
+ (
+ "user2",
+ "Max",
+ "Mustermann",
+ "user2@example.com",
+ datetime(1970, 1, 1),
+ ),
+ )
+
+ transaction.set_autocommit(False)
+ cursor.executemany(query, query_list)
+ transaction.rollback()
+ transaction.set_autocommit(True)
+
+ (event,) = events
+
+ # Ensure operation is rolled back
+ assert not User.objects.exists()
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+ assert event["spans"][0]["origin"] == "auto.db.django"
+
+ rollback_spans = [
+ span
+ for span in event["spans"]
+ if span["data"].get(SPANDATA.DB_OPERATION) == SPANNAME.DB_ROLLBACK
+ ]
+ assert len(rollback_spans) == 1
+ rollback_span = rollback_spans[0]
+ assert rollback_span["origin"] == "auto.db.django"
+
+ # Verify other database attributes
+ assert rollback_span["data"].get(SPANDATA.DB_SYSTEM) == "sqlite"
+ conn_params = connection.get_connection_params()
+ assert rollback_span["data"].get(SPANDATA.DB_NAME) is not None
+ assert rollback_span["data"].get(SPANDATA.DB_NAME) == conn_params.get(
+ "database"
+ ) or conn_params.get("dbname")
+
+ insert_spans = [
+ span for span in event["spans"] if span["description"].startswith("INSERT INTO")
+ ]
+
+ # Verify queries and rollback statements are siblings
+ for insert_span in insert_spans:
+ assert rollback_span["parent_span_id"] == insert_span["parent_span_id"]
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator(transaction=True)
+def test_db_atomic_execute(sentry_init, client, capture_events):
+ sentry_init(
+ integrations=[DjangoIntegration(db_transaction_spans=True)],
+ traces_sample_rate=1.0,
+ )
+
+ if "postgres" not in connections:
+ pytest.skip("postgres tests disabled")
+
+ # trigger Django to open a new connection by marking the existing one as None.
+ connections["postgres"].connection = None
+
+ events = capture_events()
+
+ client.get(reverse("postgres_insert_orm_atomic"))
+
+ (event,) = events
+
+ # Ensure operation is persisted
+ assert User.objects.using("postgres").exists()
+
+ assert event["contexts"]["trace"]["origin"] == "auto.http.django"
+
+ commit_spans = [
+ span
+ for span in event["spans"]
+ if span["data"].get(SPANDATA.DB_OPERATION) == SPANNAME.DB_COMMIT
+ ]
+ assert len(commit_spans) == 1
+ commit_span = commit_spans[0]
+ assert commit_span["origin"] == "auto.db.django"
+
+ # Verify other database attributes
+ assert commit_span["data"].get(SPANDATA.DB_SYSTEM) == "postgresql"
+ conn_params = connections["postgres"].get_connection_params()
+ assert commit_span["data"].get(SPANDATA.DB_NAME) is not None
+ assert commit_span["data"].get(SPANDATA.DB_NAME) == conn_params.get(
+ "database"
+ ) or conn_params.get("dbname")
+ assert commit_span["data"].get(SPANDATA.SERVER_ADDRESS) == os.environ.get(
+ "SENTRY_PYTHON_TEST_POSTGRES_HOST", "localhost"
+ )
+ assert commit_span["data"].get(SPANDATA.SERVER_PORT) == os.environ.get(
+ "SENTRY_PYTHON_TEST_POSTGRES_PORT", "5432"
+ )
+
+ insert_spans = [
+ span for span in event["spans"] if span["description"].startswith("INSERT INTO")
+ ]
+ assert len(insert_spans) == 1
+ insert_span = insert_spans[0]
+
+ # Verify query and commit statements are siblings
+ assert commit_span["parent_span_id"] == insert_span["parent_span_id"]
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator(transaction=True)
+def test_db_atomic_executemany(sentry_init, client, capture_events):
+ sentry_init(
+ integrations=[DjangoIntegration(db_transaction_spans=True)],
+ send_default_pii=True,
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ with start_transaction(name="test_transaction"):
+ from django.db import connection, transaction
+
+ with transaction.atomic():
+ cursor = connection.cursor()
+
+ query = """INSERT INTO auth_user (
+ password,
+ is_superuser,
+ username,
+ first_name,
+ last_name,
+ email,
+ is_staff,
+ is_active,
+ date_joined
+)
+VALUES ('password', false, %s, %s, %s, %s, false, true, %s);"""
+
+ query_list = (
+ (
+ "user1",
+ "John",
+ "Doe",
+ "user1@example.com",
+ datetime(1970, 1, 1),
+ ),
+ (
+ "user2",
+ "Max",
+ "Mustermann",
+ "user2@example.com",
+ datetime(1970, 1, 1),
+ ),
+ )
+ cursor.executemany(query, query_list)
+
+ (event,) = events
+
+ # Ensure operation is persisted
+ assert User.objects.exists()
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+
+ commit_spans = [
+ span
+ for span in event["spans"]
+ if span["data"].get(SPANDATA.DB_OPERATION) == SPANNAME.DB_COMMIT
+ ]
+ assert len(commit_spans) == 1
+ commit_span = commit_spans[0]
+ assert commit_span["origin"] == "auto.db.django"
+
+ # Verify other database attributes
+ assert commit_span["data"].get(SPANDATA.DB_SYSTEM) == "sqlite"
+ conn_params = connection.get_connection_params()
+ assert commit_span["data"].get(SPANDATA.DB_NAME) is not None
+ assert commit_span["data"].get(SPANDATA.DB_NAME) == conn_params.get(
+ "database"
+ ) or conn_params.get("dbname")
+
+ insert_spans = [
+ span for span in event["spans"] if span["description"].startswith("INSERT INTO")
+ ]
+
+ # Verify queries and commit statements are siblings
+ for insert_span in insert_spans:
+ assert commit_span["parent_span_id"] == insert_span["parent_span_id"]
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator(transaction=True)
+def test_db_atomic_rollback_execute(sentry_init, client, capture_events):
+ sentry_init(
+ integrations=[DjangoIntegration(db_transaction_spans=True)],
+ send_default_pii=True,
+ traces_sample_rate=1.0,
+ )
+
+ if "postgres" not in connections:
+ pytest.skip("postgres tests disabled")
+
+ # trigger Django to open a new connection by marking the existing one as None.
+ connections["postgres"].connection = None
+
+ events = capture_events()
+
+ client.get(reverse("postgres_insert_orm_atomic_rollback"))
+
+ (event,) = events
+
+ # Ensure operation is rolled back
+ assert not User.objects.using("postgres").exists()
+
+ assert event["contexts"]["trace"]["origin"] == "auto.http.django"
+
+ rollback_spans = [
+ span
+ for span in event["spans"]
+ if span["data"].get(SPANDATA.DB_OPERATION) == SPANNAME.DB_ROLLBACK
+ ]
+ assert len(rollback_spans) == 1
+ rollback_span = rollback_spans[0]
+ assert rollback_span["origin"] == "auto.db.django"
+
+ # Verify other database attributes
+ assert rollback_span["data"].get(SPANDATA.DB_SYSTEM) == "postgresql"
+ conn_params = connections["postgres"].get_connection_params()
+ assert rollback_span["data"].get(SPANDATA.DB_NAME) is not None
+ assert rollback_span["data"].get(SPANDATA.DB_NAME) == conn_params.get(
+ "database"
+ ) or conn_params.get("dbname")
+ assert rollback_span["data"].get(SPANDATA.SERVER_ADDRESS) == os.environ.get(
+ "SENTRY_PYTHON_TEST_POSTGRES_HOST", "localhost"
+ )
+ assert rollback_span["data"].get(SPANDATA.SERVER_PORT) == os.environ.get(
+ "SENTRY_PYTHON_TEST_POSTGRES_PORT", "5432"
+ )
+
+ insert_spans = [
+ span for span in event["spans"] if span["description"].startswith("INSERT INTO")
+ ]
+ assert len(insert_spans) == 1
+ insert_span = insert_spans[0]
+
+ # Verify query and rollback statements are siblings
+ assert rollback_span["parent_span_id"] == insert_span["parent_span_id"]
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator(transaction=True)
+def test_db_atomic_rollback_executemany(sentry_init, client, capture_events):
+ sentry_init(
+ integrations=[DjangoIntegration(db_transaction_spans=True)],
+ send_default_pii=True,
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ with start_transaction(name="test_transaction"):
+ from django.db import connection, transaction
+
+ with transaction.atomic():
+ cursor = connection.cursor()
+
+ query = """INSERT INTO auth_user (
+ password,
+ is_superuser,
+ username,
+ first_name,
+ last_name,
+ email,
+ is_staff,
+ is_active,
+ date_joined
+)
+VALUES ('password', false, %s, %s, %s, %s, false, true, %s);"""
+
+ query_list = (
+ (
+ "user1",
+ "John",
+ "Doe",
+ "user1@example.com",
+ datetime(1970, 1, 1),
+ ),
+ (
+ "user2",
+ "Max",
+ "Mustermann",
+ "user2@example.com",
+ datetime(1970, 1, 1),
+ ),
+ )
+ cursor.executemany(query, query_list)
+ transaction.set_rollback(True)
+
+ (event,) = events
+
+ # Ensure operation is rolled back
+ assert not User.objects.exists()
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+
+ rollback_spans = [
+ span
+ for span in event["spans"]
+ if span["data"].get(SPANDATA.DB_OPERATION) == SPANNAME.DB_ROLLBACK
+ ]
+ assert len(rollback_spans) == 1
+ rollback_span = rollback_spans[0]
+ assert rollback_span["origin"] == "auto.db.django"
+
+ # Verify other database attributes
+ assert rollback_span["data"].get(SPANDATA.DB_SYSTEM) == "sqlite"
+ conn_params = connection.get_connection_params()
+ assert rollback_span["data"].get(SPANDATA.DB_NAME) is not None
+ assert rollback_span["data"].get(SPANDATA.DB_NAME) == conn_params.get(
+ "database"
+ ) or conn_params.get("dbname")
+
+ insert_spans = [
+ span for span in event["spans"] if span["description"].startswith("INSERT INTO")
+ ]
+
+ # Verify queries and rollback statements are siblings
+ for insert_span in insert_spans:
+ assert rollback_span["parent_span_id"] == insert_span["parent_span_id"]
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator(transaction=True)
+def test_db_atomic_execute_exception(sentry_init, client, capture_events):
+ sentry_init(
+ integrations=[DjangoIntegration(db_transaction_spans=True)],
+ send_default_pii=True,
+ traces_sample_rate=1.0,
+ )
+
+ if "postgres" not in connections:
+ pytest.skip("postgres tests disabled")
+
+ # trigger Django to open a new connection by marking the existing one as None.
+ connections["postgres"].connection = None
+
+ events = capture_events()
+
+ client.get(reverse("postgres_insert_orm_atomic_exception"))
+
+ (event,) = events
+
+ # Ensure operation is rolled back
+ assert not User.objects.using("postgres").exists()
+
+ assert event["contexts"]["trace"]["origin"] == "auto.http.django"
+
+ rollback_spans = [
+ span
+ for span in event["spans"]
+ if span["data"].get(SPANDATA.DB_OPERATION) == SPANNAME.DB_ROLLBACK
+ ]
+ assert len(rollback_spans) == 1
+ rollback_span = rollback_spans[0]
+ assert rollback_span["origin"] == "auto.db.django"
+
+ # Verify other database attributes
+ assert rollback_span["data"].get(SPANDATA.DB_SYSTEM) == "postgresql"
+ conn_params = connections["postgres"].get_connection_params()
+ assert rollback_span["data"].get(SPANDATA.DB_NAME) is not None
+ assert rollback_span["data"].get(SPANDATA.DB_NAME) == conn_params.get(
+ "database"
+ ) or conn_params.get("dbname")
+ assert rollback_span["data"].get(SPANDATA.SERVER_ADDRESS) == os.environ.get(
+ "SENTRY_PYTHON_TEST_POSTGRES_HOST", "localhost"
+ )
+ assert rollback_span["data"].get(SPANDATA.SERVER_PORT) == os.environ.get(
+ "SENTRY_PYTHON_TEST_POSTGRES_PORT", "5432"
+ )
+
+ insert_spans = [
+ span for span in event["spans"] if span["description"].startswith("INSERT INTO")
+ ]
+ assert len(insert_spans) == 1
+ insert_span = insert_spans[0]
+
+ # Verify query and rollback statements are siblings
+ assert rollback_span["parent_span_id"] == insert_span["parent_span_id"]
+
+
+@pytest.mark.forked
+@pytest_mark_django_db_decorator(transaction=True)
+def test_db_atomic_executemany_exception(sentry_init, client, capture_events):
+ sentry_init(
+ integrations=[DjangoIntegration(db_transaction_spans=True)],
+ send_default_pii=True,
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ with start_transaction(name="test_transaction"):
+ from django.db import connection, transaction
+
+ try:
+ with transaction.atomic():
+ cursor = connection.cursor()
+
+ query = """INSERT INTO auth_user (
+ password,
+ is_superuser,
+ username,
+ first_name,
+ last_name,
+ email,
+ is_staff,
+ is_active,
+ date_joined
+)
+VALUES ('password', false, %s, %s, %s, %s, false, true, %s);"""
+
+ query_list = (
+ (
+ "user1",
+ "John",
+ "Doe",
+ "user1@example.com",
+ datetime(1970, 1, 1),
+ ),
+ (
+ "user2",
+ "Max",
+ "Mustermann",
+ "user2@example.com",
+ datetime(1970, 1, 1),
+ ),
+ )
+ cursor.executemany(query, query_list)
+ 1 / 0
+ except ZeroDivisionError:
+ pass
+
+ (event,) = events
+
+ # Ensure operation is rolled back
+ assert not User.objects.exists()
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+
+ rollback_spans = [
+ span
+ for span in event["spans"]
+ if span["data"].get(SPANDATA.DB_OPERATION) == SPANNAME.DB_ROLLBACK
+ ]
+ assert len(rollback_spans) == 1
+ rollback_span = rollback_spans[0]
+ assert rollback_span["origin"] == "auto.db.django"
+
+ # Verify other database attributes
+ assert rollback_span["data"].get(SPANDATA.DB_SYSTEM) == "sqlite"
+ conn_params = connection.get_connection_params()
+ assert rollback_span["data"].get(SPANDATA.DB_NAME) is not None
+ assert rollback_span["data"].get(SPANDATA.DB_NAME) == conn_params.get(
+ "database"
+ ) or conn_params.get("dbname")
+
+ insert_spans = [
+ span for span in event["spans"] if span["description"].startswith("INSERT INTO")
+ ]
+
+ # Verify queries and rollback statements are siblings
+ for insert_span in insert_spans:
+ assert rollback_span["parent_span_id"] == insert_span["parent_span_id"]
diff --git a/tests/integrations/django/test_middleware.py b/tests/integrations/django/test_middleware.py
new file mode 100644
index 0000000000..9c4c1ddfd1
--- /dev/null
+++ b/tests/integrations/django/test_middleware.py
@@ -0,0 +1,33 @@
+from typing import Optional
+
+import pytest
+
+from sentry_sdk.integrations.django.middleware import _wrap_middleware
+
+
+def _sync_capable_middleware_factory(sync_capable: "Optional[bool]") -> type:
+ """Create a middleware class with a sync_capable attribute set to the value passed to the factory.
+ If the factory is called with None, the middleware class will not have a sync_capable attribute.
+ """
+ sc = sync_capable # rename so we can set sync_capable in the class
+
+ class TestMiddleware:
+ nonlocal sc
+ if sc is not None:
+ sync_capable = sc
+
+ return TestMiddleware
+
+
+@pytest.mark.parametrize(
+ ("middleware", "sync_capable"),
+ (
+ (_sync_capable_middleware_factory(True), True),
+ (_sync_capable_middleware_factory(False), False),
+ (_sync_capable_middleware_factory(None), True),
+ ),
+)
+def test_wrap_middleware_sync_capable_attribute(middleware, sync_capable):
+ wrapped_middleware = _wrap_middleware(middleware, "test_middleware")
+
+ assert wrapped_middleware.sync_capable is sync_capable
diff --git a/tests/integrations/django/test_tasks.py b/tests/integrations/django/test_tasks.py
new file mode 100644
index 0000000000..56c68b807f
--- /dev/null
+++ b/tests/integrations/django/test_tasks.py
@@ -0,0 +1,186 @@
+import pytest
+
+import sentry_sdk
+from sentry_sdk.integrations.django import DjangoIntegration
+from sentry_sdk.consts import OP
+
+
+try:
+ from django.tasks import task
+
+ HAS_DJANGO_TASKS = True
+except ImportError:
+ HAS_DJANGO_TASKS = False
+
+
+@pytest.fixture
+def immediate_backend(settings):
+ """Configure Django to use the immediate task backend for synchronous testing."""
+ settings.TASKS = {
+ "default": {"BACKEND": "django.tasks.backends.immediate.ImmediateBackend"}
+ }
+
+
+if HAS_DJANGO_TASKS:
+
+ @task
+ def simple_task():
+ return "result"
+
+ @task
+ def add_numbers(a, b):
+ return a + b
+
+ @task
+ def greet(name, greeting="Hello"):
+ return f"{greeting}, {name}!"
+
+ @task
+ def failing_task():
+ raise ValueError("Task failed!")
+
+ @task
+ def task_one():
+ return 1
+
+ @task
+ def task_two():
+ return 2
+
+
+@pytest.mark.skipif(
+ not HAS_DJANGO_TASKS,
+ reason="Django tasks are only available in Django 6.0+",
+)
+def test_task_span_is_created(sentry_init, capture_events, immediate_backend):
+ """Test that the queue.submit.django span is created when a task is enqueued."""
+ sentry_init(
+ integrations=[DjangoIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ with sentry_sdk.start_transaction(name="test_transaction"):
+ simple_task.enqueue()
+
+ (event,) = events
+ assert event["type"] == "transaction"
+
+ queue_submit_spans = [
+ span for span in event["spans"] if span["op"] == OP.QUEUE_SUBMIT_DJANGO
+ ]
+ assert len(queue_submit_spans) == 1
+ assert (
+ queue_submit_spans[0]["description"]
+ == "tests.integrations.django.test_tasks.simple_task"
+ )
+ assert queue_submit_spans[0]["origin"] == "auto.http.django"
+
+
+@pytest.mark.skipif(
+ not HAS_DJANGO_TASKS,
+ reason="Django tasks are only available in Django 6.0+",
+)
+def test_task_enqueue_returns_result(sentry_init, immediate_backend):
+ """Test that the task enqueuing behavior is unchanged from the user perspective."""
+ sentry_init(
+ integrations=[DjangoIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ result = add_numbers.enqueue(3, 5)
+
+ assert result is not None
+ assert result.return_value == 8
+
+
+@pytest.mark.skipif(
+ not HAS_DJANGO_TASKS,
+ reason="Django tasks are only available in Django 6.0+",
+)
+def test_task_enqueue_with_kwargs(sentry_init, immediate_backend, capture_events):
+ """Test that task enqueuing works correctly with keyword arguments."""
+ sentry_init(
+ integrations=[DjangoIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ with sentry_sdk.start_transaction(name="test_transaction"):
+ result = greet.enqueue(name="World", greeting="Hi")
+
+ assert result.return_value == "Hi, World!"
+
+ (event,) = events
+ queue_submit_spans = [
+ span for span in event["spans"] if span["op"] == OP.QUEUE_SUBMIT_DJANGO
+ ]
+ assert len(queue_submit_spans) == 1
+ assert (
+ queue_submit_spans[0]["description"]
+ == "tests.integrations.django.test_tasks.greet"
+ )
+
+
+@pytest.mark.skipif(
+ not HAS_DJANGO_TASKS,
+ reason="Django tasks are only available in Django 6.0+",
+)
+def test_task_error_reporting(sentry_init, immediate_backend, capture_events):
+ """Test that errors in tasks are correctly reported and don't break the span."""
+ sentry_init(
+ integrations=[DjangoIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ with sentry_sdk.start_transaction(name="test_transaction"):
+ result = failing_task.enqueue()
+
+ with pytest.raises(ValueError, match="Task failed"):
+ _ = result.return_value
+
+ assert len(events) == 2
+ transaction_event = events[-1]
+ assert transaction_event["type"] == "transaction"
+
+ queue_submit_spans = [
+ span
+ for span in transaction_event["spans"]
+ if span["op"] == OP.QUEUE_SUBMIT_DJANGO
+ ]
+ assert len(queue_submit_spans) == 1
+ assert (
+ queue_submit_spans[0]["description"]
+ == "tests.integrations.django.test_tasks.failing_task"
+ )
+
+
+@pytest.mark.skipif(
+ not HAS_DJANGO_TASKS,
+ reason="Django tasks are only available in Django 6.0+",
+)
+def test_multiple_task_enqueues_create_multiple_spans(
+ sentry_init, capture_events, immediate_backend
+):
+ """Test that enqueueing multiple tasks creates multiple spans."""
+ sentry_init(
+ integrations=[DjangoIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ with sentry_sdk.start_transaction(name="test_transaction"):
+ task_one.enqueue()
+ task_two.enqueue()
+ task_one.enqueue()
+
+ (event,) = events
+ queue_submit_spans = [
+ span for span in event["spans"] if span["op"] == OP.QUEUE_SUBMIT_DJANGO
+ ]
+ assert len(queue_submit_spans) == 3
+
+ span_names = [span["description"] for span in queue_submit_spans]
+ assert span_names.count("tests.integrations.django.test_tasks.task_one") == 2
+ assert span_names.count("tests.integrations.django.test_tasks.task_two") == 1
diff --git a/tests/integrations/django/test_transactions.py b/tests/integrations/django/test_transactions.py
index 6f16d88cec..14f8170fc3 100644
--- a/tests/integrations/django/test_transactions.py
+++ b/tests/integrations/django/test_transactions.py
@@ -1,45 +1,53 @@
-from __future__ import absolute_import
+from unittest import mock
import pytest
import django
+from django.utils.translation import pgettext_lazy
+
+# django<2.0 has only `url` with regex based patterns.
+# django>=2.0 renames `url` to `re_path`, and additionally introduces `path`
+# for new style URL patterns, e.g. .
if django.VERSION >= (2, 0):
- # TODO: once we stop supporting django < 2, use the real name of this
- # function (re_path)
- from django.urls import re_path as url
+ from django.urls import path, re_path
+ from django.urls.converters import PathConverter
from django.conf.urls import include
else:
- from django.conf.urls import url, include
+ from django.conf.urls import url as re_path, include
if django.VERSION < (1, 9):
- included_url_conf = (url(r"^foo/bar/(?P[\w]+)", lambda x: ""),), "", ""
+ included_url_conf = (re_path(r"^foo/bar/(?P[\w]+)", lambda x: ""),), "", ""
else:
- included_url_conf = ((url(r"^foo/bar/(?P[\w]+)", lambda x: ""),), "")
+ included_url_conf = ((re_path(r"^foo/bar/(?P[\w]+)", lambda x: ""),), "")
from sentry_sdk.integrations.django.transactions import RavenResolver
example_url_conf = (
- url(r"^api/(?P[\w_-]+)/store/$", lambda x: ""),
- url(r"^api/(?P(v1|v2))/author/$", lambda x: ""),
- url(r"^report/", lambda x: ""),
- url(r"^example/", include(included_url_conf)),
+ re_path(r"^api/(?P[\w_-]+)/store/$", lambda x: ""),
+ re_path(r"^api/(?P(v1|v2))/author/$", lambda x: ""),
+ re_path(
+ r"^api/(?P[^\/]+)/product/(?P(?:\d+|[A-Fa-f0-9-]{32,36}))/$",
+ lambda x: "",
+ ),
+ re_path(r"^report/", lambda x: ""),
+ re_path(r"^example/", include(included_url_conf)),
)
-def test_legacy_resolver_no_match():
+def test_resolver_no_match():
resolver = RavenResolver()
result = resolver.resolve("/foo/bar", example_url_conf)
assert result is None
-def test_legacy_resolver_complex_match():
+def test_resolver_re_path_complex_match():
resolver = RavenResolver()
result = resolver.resolve("/api/1234/store/", example_url_conf)
assert result == "/api/{project_id}/store/"
-def test_legacy_resolver_complex_either_match():
+def test_resolver_re_path_complex_either_match():
resolver = RavenResolver()
result = resolver.resolve("/api/v1/author/", example_url_conf)
assert result == "/api/{version}/author/"
@@ -47,17 +55,99 @@ def test_legacy_resolver_complex_either_match():
assert result == "/api/{version}/author/"
-def test_legacy_resolver_included_match():
+def test_resolver_re_path_included_match():
resolver = RavenResolver()
result = resolver.resolve("/example/foo/bar/baz", example_url_conf)
assert result == "/example/foo/bar/{param}"
-@pytest.mark.skipif(django.VERSION < (2, 0), reason="Requires Django > 2.0")
-def test_legacy_resolver_newstyle_django20_urlconf():
- from django.urls import path
+def test_resolver_re_path_multiple_groups():
+ resolver = RavenResolver()
+ result = resolver.resolve(
+ "/api/myproject/product/cb4ef1caf3554c34ae134f3c1b3d605f/", example_url_conf
+ )
+ assert result == "/api/{project_id}/product/{pid}/"
+
+@pytest.mark.skipif(
+ django.VERSION < (2, 0),
+ reason="Django>=2.0 required for patterns",
+)
+def test_resolver_path_group():
url_conf = (path("api/v2//store/", lambda x: ""),)
resolver = RavenResolver()
result = resolver.resolve("/api/v2/1234/store/", url_conf)
assert result == "/api/v2/{project_id}/store/"
+
+
+@pytest.mark.skipif(
+ django.VERSION < (2, 0),
+ reason="Django>=2.0 required for patterns",
+)
+def test_resolver_path_multiple_groups():
+ url_conf = (path("api/v2//product/", lambda x: ""),)
+ resolver = RavenResolver()
+ result = resolver.resolve("/api/v2/myproject/product/5689", url_conf)
+ assert result == "/api/v2/{project_id}/product/{pid}"
+
+
+@pytest.mark.skipif(
+ django.VERSION < (2, 0),
+ reason="Django>=2.0 required for patterns",
+)
+@pytest.mark.skipif(
+ django.VERSION > (5, 1),
+ reason="get_converter removed in 5.1",
+)
+def test_resolver_path_complex_path_legacy():
+ class CustomPathConverter(PathConverter):
+ regex = r"[^/]+(/[^/]+){0,2}"
+
+ with mock.patch(
+ "django.urls.resolvers.get_converter",
+ return_value=CustomPathConverter,
+ ):
+ url_conf = (path("api/v3/", lambda x: ""),)
+ resolver = RavenResolver()
+ result = resolver.resolve("/api/v3/abc/def/ghi", url_conf)
+ assert result == "/api/v3/{my_path}"
+
+
+@pytest.mark.skipif(
+ django.VERSION < (5, 1),
+ reason="get_converters is used in 5.1",
+)
+def test_resolver_path_complex_path():
+ class CustomPathConverter(PathConverter):
+ regex = r"[^/]+(/[^/]+){0,2}"
+
+ with mock.patch(
+ "django.urls.resolvers.get_converters",
+ return_value={"custom_path": CustomPathConverter},
+ ):
+ url_conf = (path("api/v3/", lambda x: ""),)
+ resolver = RavenResolver()
+ result = resolver.resolve("/api/v3/abc/def/ghi", url_conf)
+ assert result == "/api/v3/{my_path}"
+
+
+@pytest.mark.skipif(
+ django.VERSION < (2, 0),
+ reason="Django>=2.0 required for patterns",
+)
+def test_resolver_path_no_converter():
+ url_conf = (path("api/v4/", lambda x: ""),)
+ resolver = RavenResolver()
+ result = resolver.resolve("/api/v4/myproject", url_conf)
+ assert result == "/api/v4/{project_id}"
+
+
+@pytest.mark.skipif(
+ django.VERSION < (2, 0),
+ reason="Django>=2.0 required for path patterns",
+)
+def test_resolver_path_with_i18n():
+ url_conf = (path(pgettext_lazy("url", "pgettext"), lambda x: ""),)
+ resolver = RavenResolver()
+ result = resolver.resolve("/pgettext", url_conf)
+ assert result == "/pgettext"
diff --git a/tests/integrations/django/utils.py b/tests/integrations/django/utils.py
new file mode 100644
index 0000000000..8f68c8fa14
--- /dev/null
+++ b/tests/integrations/django/utils.py
@@ -0,0 +1,22 @@
+from functools import partial
+
+import pytest
+import pytest_django
+
+
+# Hack to prevent from experimental feature introduced in version `4.3.0` in `pytest-django` that
+# requires explicit database allow from failing the test
+pytest_mark_django_db_decorator = partial(pytest.mark.django_db)
+try:
+ pytest_version = tuple(map(int, pytest_django.__version__.split(".")))
+ if pytest_version > (4, 2, 0):
+ pytest_mark_django_db_decorator = partial(
+ pytest.mark.django_db, databases="__all__"
+ )
+except ValueError:
+ if "dev" in pytest_django.__version__:
+ pytest_mark_django_db_decorator = partial(
+ pytest.mark.django_db, databases="__all__"
+ )
+except AttributeError:
+ pass
diff --git a/tests/integrations/dramatiq/__init__.py b/tests/integrations/dramatiq/__init__.py
new file mode 100644
index 0000000000..70bbf21db4
--- /dev/null
+++ b/tests/integrations/dramatiq/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("dramatiq")
diff --git a/tests/integrations/dramatiq/test_dramatiq.py b/tests/integrations/dramatiq/test_dramatiq.py
new file mode 100644
index 0000000000..a9d3966839
--- /dev/null
+++ b/tests/integrations/dramatiq/test_dramatiq.py
@@ -0,0 +1,414 @@
+import uuid
+
+import dramatiq
+import pytest
+from dramatiq.brokers.stub import StubBroker
+from dramatiq.middleware import Middleware, SkipMessage
+
+import sentry_sdk
+from sentry_sdk import start_transaction
+from sentry_sdk.consts import SPANSTATUS
+from sentry_sdk.integrations.dramatiq import DramatiqIntegration
+from sentry_sdk.integrations.logging import ignore_logger
+from sentry_sdk.tracing import Transaction, TransactionSource
+
+ignore_logger("dramatiq.worker.WorkerThread")
+
+
+@pytest.fixture(scope="function")
+def broker(request, sentry_init):
+ sentry_init(
+ integrations=[DramatiqIntegration()],
+ traces_sample_rate=getattr(request, "param", None),
+ )
+ broker = StubBroker()
+ broker.emit_after("process_boot")
+ dramatiq.set_broker(broker)
+ yield broker
+ broker.flush_all()
+ broker.close()
+
+
+@pytest.fixture
+def worker(broker):
+ worker = dramatiq.Worker(broker, worker_timeout=100, worker_threads=1)
+ worker.start()
+ yield worker
+ worker.stop()
+
+
+@pytest.mark.parametrize(
+ "fail_fast",
+ [
+ False,
+ True,
+ ],
+)
+def test_that_a_single_error_is_captured(broker, worker, capture_events, fail_fast):
+ events = capture_events()
+
+ @dramatiq.actor(max_retries=0)
+ def dummy_actor(x, y):
+ return x / y
+
+ dummy_actor.send(1, 2)
+ dummy_actor.send(1, 0)
+ if fail_fast:
+ with pytest.raises(ZeroDivisionError):
+ broker.join(dummy_actor.queue_name, fail_fast=fail_fast)
+ else:
+ broker.join(dummy_actor.queue_name, fail_fast=fail_fast)
+ worker.join()
+
+ (event,) = events
+ exception = event["exception"]["values"][0]
+ assert exception["type"] == "ZeroDivisionError"
+
+
+@pytest.mark.parametrize(
+ "broker,expected_span_status,fail_fast",
+ [
+ (1.0, SPANSTATUS.INTERNAL_ERROR, False),
+ (1.0, SPANSTATUS.OK, False),
+ (1.0, SPANSTATUS.INTERNAL_ERROR, True),
+ (1.0, SPANSTATUS.OK, True),
+ ],
+ ids=["error", "success", "error_fail_fast", "success_fail_fast"],
+ indirect=["broker"],
+)
+def test_task_transaction(
+ broker, worker, capture_events, expected_span_status, fail_fast
+):
+ events = capture_events()
+ task_fails = expected_span_status == SPANSTATUS.INTERNAL_ERROR
+
+ @dramatiq.actor(max_retries=0)
+ def dummy_actor(x, y):
+ return x / y
+
+ dummy_actor.send(1, int(not task_fails))
+
+ if expected_span_status == SPANSTATUS.INTERNAL_ERROR and fail_fast:
+ with pytest.raises(ZeroDivisionError):
+ broker.join(dummy_actor.queue_name, fail_fast=fail_fast)
+ else:
+ broker.join(dummy_actor.queue_name, fail_fast=fail_fast)
+
+ worker.join()
+
+ if task_fails:
+ error_event = events.pop(0)
+ exception = error_event["exception"]["values"][0]
+ assert exception["type"] == "ZeroDivisionError"
+ assert exception["mechanism"]["type"] == DramatiqIntegration.identifier
+
+ (event,) = events
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "dummy_actor"
+ assert event["transaction_info"] == {"source": TransactionSource.TASK}
+ assert event["contexts"]["trace"]["status"] == expected_span_status
+
+
+@pytest.mark.parametrize("broker", [1.0], indirect=True)
+def test_dramatiq_propagate_trace(broker, worker, capture_events):
+ events = capture_events()
+
+ @dramatiq.actor(max_retries=0)
+ def propagated_trace_task():
+ pass
+
+ with start_transaction() as outer_transaction:
+ propagated_trace_task.send()
+ broker.join(propagated_trace_task.queue_name)
+ worker.join()
+
+ assert (
+ events[0]["transaction"] == "propagated_trace_task"
+ ) # the "inner" transaction
+ assert events[0]["contexts"]["trace"]["trace_id"] == outer_transaction.trace_id
+
+
+@pytest.mark.parametrize(
+ "fail_fast",
+ [
+ False,
+ True,
+ ],
+)
+def test_that_dramatiq_message_id_is_set_as_extra(
+ broker, worker, capture_events, fail_fast
+):
+ events = capture_events()
+
+ @dramatiq.actor(max_retries=0)
+ def dummy_actor(x, y):
+ sentry_sdk.capture_message("hi")
+ return x / y
+
+ dummy_actor.send(1, 0)
+ if fail_fast:
+ with pytest.raises(ZeroDivisionError):
+ broker.join(dummy_actor.queue_name, fail_fast=fail_fast)
+ else:
+ broker.join(dummy_actor.queue_name, fail_fast=fail_fast)
+ worker.join()
+
+ event_message, event_error = events
+ assert "dramatiq_message_id" in event_message["extra"]
+ assert "dramatiq_message_id" in event_error["extra"]
+ assert (
+ event_message["extra"]["dramatiq_message_id"]
+ == event_error["extra"]["dramatiq_message_id"]
+ )
+ msg_ids = [e["extra"]["dramatiq_message_id"] for e in events]
+ assert all(uuid.UUID(msg_id) and isinstance(msg_id, str) for msg_id in msg_ids)
+
+
+@pytest.mark.parametrize(
+ "fail_fast",
+ [
+ False,
+ True,
+ ],
+)
+def test_that_local_variables_are_captured(broker, worker, capture_events, fail_fast):
+ events = capture_events()
+
+ @dramatiq.actor(max_retries=0)
+ def dummy_actor(x, y):
+ foo = 42 # noqa
+ return x / y
+
+ dummy_actor.send(1, 2)
+ dummy_actor.send(1, 0)
+ if fail_fast:
+ with pytest.raises(ZeroDivisionError):
+ broker.join(dummy_actor.queue_name, fail_fast=fail_fast)
+ else:
+ broker.join(dummy_actor.queue_name, fail_fast=fail_fast)
+ worker.join()
+
+ (event,) = events
+ exception = event["exception"]["values"][0]
+ assert exception["stacktrace"]["frames"][-1]["vars"] == {
+ "x": "1",
+ "y": "0",
+ "foo": "42",
+ }
+
+
+def test_that_messages_are_captured(broker, worker, capture_events):
+ events = capture_events()
+
+ @dramatiq.actor(max_retries=0)
+ def dummy_actor():
+ sentry_sdk.capture_message("hi")
+
+ dummy_actor.send()
+ broker.join(dummy_actor.queue_name)
+ worker.join()
+
+ (event,) = events
+ assert event["message"] == "hi"
+ assert event["level"] == "info"
+ assert event["transaction"] == "dummy_actor"
+
+
+@pytest.mark.parametrize(
+ "fail_fast",
+ [
+ False,
+ True,
+ ],
+)
+def test_that_sub_actor_errors_are_captured(broker, worker, capture_events, fail_fast):
+ events = capture_events()
+
+ @dramatiq.actor(max_retries=0)
+ def dummy_actor(x, y):
+ sub_actor.send(x, y)
+
+ @dramatiq.actor(max_retries=0)
+ def sub_actor(x, y):
+ return x / y
+
+ dummy_actor.send(1, 2)
+ dummy_actor.send(1, 0)
+ if fail_fast:
+ with pytest.raises(ZeroDivisionError):
+ broker.join(dummy_actor.queue_name, fail_fast=fail_fast)
+ else:
+ broker.join(dummy_actor.queue_name, fail_fast=fail_fast)
+ worker.join()
+
+ (event,) = events
+ assert event["transaction"] == "sub_actor"
+
+ exception = event["exception"]["values"][0]
+ assert exception["type"] == "ZeroDivisionError"
+
+
+@pytest.mark.parametrize(
+ "fail_fast",
+ [
+ False,
+ True,
+ ],
+)
+def test_that_multiple_errors_are_captured(broker, worker, capture_events, fail_fast):
+ events = capture_events()
+
+ @dramatiq.actor(max_retries=0)
+ def dummy_actor(x, y):
+ return x / y
+
+ dummy_actor.send(1, 0)
+ if fail_fast:
+ with pytest.raises(ZeroDivisionError):
+ broker.join(dummy_actor.queue_name, fail_fast=fail_fast)
+ else:
+ broker.join(dummy_actor.queue_name, fail_fast=fail_fast)
+ worker.join()
+
+ dummy_actor.send(1, None)
+ if fail_fast:
+ with pytest.raises(ZeroDivisionError):
+ broker.join(dummy_actor.queue_name, fail_fast=fail_fast)
+ else:
+ broker.join(dummy_actor.queue_name, fail_fast=fail_fast)
+ worker.join()
+
+ event1, event2 = events
+
+ assert event1["transaction"] == "dummy_actor"
+ exception = event1["exception"]["values"][0]
+ assert exception["type"] == "ZeroDivisionError"
+
+ assert event2["transaction"] == "dummy_actor"
+ exception = event2["exception"]["values"][0]
+ assert exception["type"] == "TypeError"
+
+
+@pytest.mark.parametrize(
+ "fail_fast",
+ [
+ False,
+ True,
+ ],
+)
+def test_that_message_data_is_added_as_request(
+ broker, worker, capture_events, fail_fast
+):
+ events = capture_events()
+
+ @dramatiq.actor(max_retries=0)
+ def dummy_actor(x, y):
+ return x / y
+
+ dummy_actor.send_with_options(
+ args=(
+ 1,
+ 0,
+ ),
+ max_retries=0,
+ )
+ if fail_fast:
+ with pytest.raises(ZeroDivisionError):
+ broker.join(dummy_actor.queue_name, fail_fast=fail_fast)
+ else:
+ broker.join(dummy_actor.queue_name, fail_fast=fail_fast)
+ worker.join()
+
+ (event,) = events
+
+ assert event["transaction"] == "dummy_actor"
+ request_data = event["contexts"]["dramatiq"]["data"]
+ assert request_data["queue_name"] == "default"
+ assert request_data["actor_name"] == "dummy_actor"
+ assert request_data["args"] == [1, 0]
+ assert request_data["kwargs"] == {}
+ assert request_data["options"]["max_retries"] == 0
+ assert uuid.UUID(request_data["message_id"])
+ assert isinstance(request_data["message_timestamp"], int)
+
+
+@pytest.mark.parametrize(
+ "fail_fast",
+ [
+ False,
+ True,
+ ],
+)
+def test_that_expected_exceptions_are_not_captured(
+ broker, worker, capture_events, fail_fast
+):
+ events = capture_events()
+
+ class ExpectedException(Exception):
+ pass
+
+ @dramatiq.actor(max_retries=0, throws=ExpectedException)
+ def dummy_actor():
+ raise ExpectedException
+
+ dummy_actor.send()
+ if fail_fast:
+ with pytest.raises(ExpectedException):
+ broker.join(dummy_actor.queue_name, fail_fast=fail_fast)
+ else:
+ broker.join(dummy_actor.queue_name, fail_fast=fail_fast)
+ worker.join()
+
+ assert events == []
+
+
+@pytest.mark.parametrize(
+ "fail_fast",
+ [
+ False,
+ True,
+ ],
+)
+def test_that_retry_exceptions_are_not_captured(
+ broker, worker, capture_events, fail_fast
+):
+ events = capture_events()
+
+ @dramatiq.actor(max_retries=2)
+ def dummy_actor():
+ raise dramatiq.errors.Retry("Retrying", delay=100)
+
+ dummy_actor.send()
+ if fail_fast:
+ with pytest.raises(dramatiq.errors.Retry):
+ broker.join(dummy_actor.queue_name, fail_fast=fail_fast)
+ else:
+ broker.join(dummy_actor.queue_name, fail_fast=fail_fast)
+ worker.join()
+
+ assert events == []
+
+
+@pytest.mark.parametrize("broker", [1.0], indirect=True)
+def test_that_skip_message_cleans_up_scope_and_transaction(
+ broker, worker, capture_events
+):
+ transactions: list[Transaction] = []
+
+ class SkipMessageMiddleware(Middleware):
+ def before_process_message(self, broker, message):
+ transactions.append(sentry_sdk.get_current_scope().transaction)
+ raise SkipMessage()
+
+ broker.add_middleware(SkipMessageMiddleware())
+
+ @dramatiq.actor(max_retries=0)
+ def skipped_actor(): ...
+
+ skipped_actor.send()
+
+ broker.join(skipped_actor.queue_name)
+ worker.join()
+
+ (transaction,) = transactions
+ assert transaction.timestamp is not None
diff --git a/tests/integrations/excepthook/test_excepthook.py b/tests/integrations/excepthook/test_excepthook.py
index 18deccd76e..5a19b4f985 100644
--- a/tests/integrations/excepthook/test_excepthook.py
+++ b/tests/integrations/excepthook/test_excepthook.py
@@ -5,25 +5,34 @@
from textwrap import dedent
-def test_excepthook(tmpdir):
+TEST_PARAMETERS = [("", "HttpTransport")]
+
+if sys.version_info >= (3, 8):
+ TEST_PARAMETERS.append(('_experiments={"transport_http2": True}', "Http2Transport"))
+
+
+@pytest.mark.parametrize("options, transport", TEST_PARAMETERS)
+def test_excepthook(tmpdir, options, transport):
app = tmpdir.join("app.py")
app.write(
dedent(
"""
from sentry_sdk import init, transport
- def send_event(self, event):
- print("capture event was called")
- print(event)
+ def capture_envelope(self, envelope):
+ print("capture_envelope was called")
+ event = envelope.get_event()
+ if event is not None:
+ print(event)
- transport.HttpTransport._send_event = send_event
+ transport.{transport}.capture_envelope = capture_envelope
- init("http://foobar@localhost/123")
+ init("http://foobar@localhost/123", {options})
frame_value = "LOL"
1/0
- """
+ """.format(transport=transport, options=options)
)
)
@@ -31,14 +40,14 @@ def send_event(self, event):
subprocess.check_output([sys.executable, str(app)], stderr=subprocess.STDOUT)
output = excinfo.value.output
- print(output)
assert b"ZeroDivisionError" in output
assert b"LOL" in output
- assert b"capture event was called" in output
+ assert b"capture_envelope was called" in output
-def test_always_value_excepthook(tmpdir):
+@pytest.mark.parametrize("options, transport", TEST_PARAMETERS)
+def test_always_value_excepthook(tmpdir, options, transport):
app = tmpdir.join("app.py")
app.write(
dedent(
@@ -47,21 +56,24 @@ def test_always_value_excepthook(tmpdir):
from sentry_sdk import init, transport
from sentry_sdk.integrations.excepthook import ExcepthookIntegration
- def send_event(self, event):
- print("capture event was called")
- print(event)
+ def capture_envelope(self, envelope):
+ print("capture_envelope was called")
+ event = envelope.get_event()
+ if event is not None:
+ print(event)
- transport.HttpTransport._send_event = send_event
+ transport.{transport}.capture_envelope = capture_envelope
sys.ps1 = "always_value_test"
init("http://foobar@localhost/123",
- integrations=[ExcepthookIntegration(always_run=True)]
+ integrations=[ExcepthookIntegration(always_run=True)],
+ {options}
)
frame_value = "LOL"
1/0
- """
+ """.format(transport=transport, options=options)
)
)
@@ -69,8 +81,7 @@ def send_event(self, event):
subprocess.check_output([sys.executable, str(app)], stderr=subprocess.STDOUT)
output = excinfo.value.output
- print(output)
assert b"ZeroDivisionError" in output
assert b"LOL" in output
- assert b"capture event was called" in output
+ assert b"capture_envelope was called" in output
diff --git a/tests/integrations/falcon/__init__.py b/tests/integrations/falcon/__init__.py
new file mode 100644
index 0000000000..2319937c18
--- /dev/null
+++ b/tests/integrations/falcon/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("falcon")
diff --git a/tests/integrations/falcon/test_falcon.py b/tests/integrations/falcon/test_falcon.py
index dd7aa80dfe..f972419092 100644
--- a/tests/integrations/falcon/test_falcon.py
+++ b/tests/integrations/falcon/test_falcon.py
@@ -1,16 +1,25 @@
-from __future__ import absolute_import
-
import logging
import pytest
-pytest.importorskip("falcon")
-
import falcon
import falcon.testing
import sentry_sdk
+from sentry_sdk.consts import DEFAULT_MAX_VALUE_LENGTH
from sentry_sdk.integrations.falcon import FalconIntegration
from sentry_sdk.integrations.logging import LoggingIntegration
+from sentry_sdk.utils import parse_version
+
+
+try:
+ import falcon.asgi
+except ImportError:
+ pass
+else:
+ import falcon.inspect # We only need this module for the ASGI test
+
+
+FALCON_VERSION = parse_version(falcon.__version__)
@pytest.fixture
@@ -26,9 +35,22 @@ def on_get(self, req, resp, message_id):
sentry_sdk.capture_message("hi")
resp.media = "hi"
+ class CustomError(Exception):
+ pass
+
+ class CustomErrorResource:
+ def on_get(self, req, resp):
+ raise CustomError()
+
+ def custom_error_handler(*args, **kwargs):
+ raise falcon.HTTPError(status=falcon.HTTP_400)
+
app = falcon.API()
app.add_route("/message", MessageResource())
app.add_route("/message/{message_id:int}", MessageByIdResource())
+ app.add_route("/custom-error", CustomErrorResource())
+
+ app.add_error_handler(CustomError, custom_error_handler)
return app
@@ -90,7 +112,7 @@ def test_transaction_style(
def test_unhandled_errors(sentry_init, capture_exceptions, capture_events):
- sentry_init(integrations=[FalconIntegration()], debug=True)
+ sentry_init(integrations=[FalconIntegration()])
class Resource:
def on_get(self, req, resp):
@@ -118,7 +140,7 @@ def on_get(self, req, resp):
def test_raised_5xx_errors(sentry_init, capture_exceptions, capture_events):
- sentry_init(integrations=[FalconIntegration()], debug=True)
+ sentry_init(integrations=[FalconIntegration()])
class Resource:
def on_get(self, req, resp):
@@ -142,7 +164,7 @@ def on_get(self, req, resp):
def test_raised_4xx_errors(sentry_init, capture_exceptions, capture_events):
- sentry_init(integrations=[FalconIntegration()], debug=True)
+ sentry_init(integrations=[FalconIntegration()])
class Resource:
def on_get(self, req, resp):
@@ -166,7 +188,7 @@ def test_http_status(sentry_init, capture_exceptions, capture_events):
This just demonstrates, that if Falcon raises a HTTPStatus with code 500
(instead of a HTTPError with code 500) Sentry will not capture it.
"""
- sentry_init(integrations=[FalconIntegration()], debug=True)
+ sentry_init(integrations=[FalconIntegration()])
class Resource:
def on_get(self, req, resp):
@@ -186,9 +208,9 @@ def on_get(self, req, resp):
def test_falcon_large_json_request(sentry_init, capture_events):
- sentry_init(integrations=[FalconIntegration()])
+ sentry_init(integrations=[FalconIntegration()], max_request_body_size="always")
- data = {"foo": {"bar": "a" * 2000}}
+ data = {"foo": {"bar": "a" * (DEFAULT_MAX_VALUE_LENGTH + 10)}}
class Resource:
def on_post(self, req, resp):
@@ -207,9 +229,14 @@ def on_post(self, req, resp):
(event,) = events
assert event["_meta"]["request"]["data"]["foo"]["bar"] == {
- "": {"len": 2000, "rem": [["!limit", "x", 1021, 1024]]}
+ "": {
+ "len": DEFAULT_MAX_VALUE_LENGTH + 10,
+ "rem": [
+ ["!limit", "x", DEFAULT_MAX_VALUE_LENGTH - 3, DEFAULT_MAX_VALUE_LENGTH]
+ ],
+ }
}
- assert len(event["request"]["data"]["foo"]["bar"]) == 1024
+ assert len(event["request"]["data"]["foo"]["bar"]) == DEFAULT_MAX_VALUE_LENGTH
@pytest.mark.parametrize("data", [{}, []], ids=["empty-dict", "empty-list"])
@@ -282,7 +309,7 @@ def on_get(self, req, resp):
assert event["level"] == "error"
-def test_500(sentry_init, capture_events):
+def test_500(sentry_init):
sentry_init(integrations=[FalconIntegration()])
app = falcon.API()
@@ -295,17 +322,14 @@ def on_get(self, req, resp):
def http500_handler(ex, req, resp, params):
sentry_sdk.capture_exception(ex)
- resp.media = {"message": "Sentry error: %s" % sentry_sdk.last_event_id()}
+ resp.media = {"message": "Sentry error."}
app.add_error_handler(Exception, http500_handler)
- events = capture_events()
-
client = falcon.testing.TestClient(app)
response = client.simulate_get("/")
- (event,) = events
- assert response.json == {"message": "Sentry error: %s" % event["event_id"]}
+ assert response.json == {"message": "Sentry error."}
def test_error_in_errorhandler(sentry_init, capture_events):
@@ -361,20 +385,17 @@ def test_does_not_leak_scope(sentry_init, capture_events):
sentry_init(integrations=[FalconIntegration()])
events = capture_events()
- with sentry_sdk.configure_scope() as scope:
- scope.set_tag("request_data", False)
+ sentry_sdk.get_isolation_scope().set_tag("request_data", False)
app = falcon.API()
class Resource:
def on_get(self, req, resp):
- with sentry_sdk.configure_scope() as scope:
- scope.set_tag("request_data", True)
+ sentry_sdk.get_isolation_scope().set_tag("request_data", True)
def generator():
for row in range(1000):
- with sentry_sdk.configure_scope() as scope:
- assert scope._tags["request_data"]
+ assert sentry_sdk.get_isolation_scope()._tags["request_data"]
yield (str(row) + "\n").encode()
@@ -388,6 +409,105 @@ def generator():
expected_response = "".join(str(row) + "\n" for row in range(1000))
assert response.text == expected_response
assert not events
+ assert not sentry_sdk.get_isolation_scope()._tags["request_data"]
+
+
+@pytest.mark.skipif(
+ not hasattr(falcon, "asgi"), reason="This Falcon version lacks ASGI support."
+)
+def test_falcon_not_breaking_asgi(sentry_init):
+ """
+ This test simply verifies that the Falcon integration does not break ASGI
+ Falcon apps.
+
+ The test does not verify ASGI Falcon support, since our Falcon integration
+ currently lacks support for ASGI Falcon apps.
+ """
+ sentry_init(integrations=[FalconIntegration()])
+
+ asgi_app = falcon.asgi.App()
+
+ try:
+ falcon.inspect.inspect_app(asgi_app)
+ except TypeError:
+ pytest.fail("Falcon integration causing errors in ASGI apps.")
+
+
+@pytest.mark.skipif(
+ (FALCON_VERSION or ()) < (3,),
+ reason="The Sentry Falcon integration only supports custom error handlers on Falcon 3+",
+)
+def test_falcon_custom_error_handler(sentry_init, make_app, capture_events):
+ """
+ When a custom error handler handles what otherwise would have resulted in a 5xx error,
+ changing the HTTP status to a non-5xx status, no error event should be sent to Sentry.
+ """
+ sentry_init(integrations=[FalconIntegration()])
+ events = capture_events()
+
+ app = make_app()
+ client = falcon.testing.TestClient(app)
+
+ client.simulate_get("/custom-error")
+
+ assert len(events) == 0
+
+
+def test_span_origin(sentry_init, capture_events, make_client):
+ sentry_init(
+ integrations=[FalconIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client = make_client()
+ client.simulate_get("/message")
+
+ (_, event) = events
+
+ assert event["contexts"]["trace"]["origin"] == "auto.http.falcon"
+
+
+def test_falcon_request_media(sentry_init):
+ # test_passed stores whether the test has passed.
+ test_passed = False
+
+ # test_failure_reason stores the reason why the test failed
+ # if test_passed is False. The value is meaningless when
+ # test_passed is True.
+ test_failure_reason = "test endpoint did not get called"
+
+ class SentryCaptureMiddleware:
+ def process_request(self, _req, _resp):
+ # This capture message forces Falcon event processors to run
+ # before the request handler runs
+ sentry_sdk.capture_message("Processing request")
+
+ class RequestMediaResource:
+ def on_post(self, req, _):
+ nonlocal test_passed, test_failure_reason
+ raw_data = req.bounded_stream.read()
+
+ # If the raw_data is empty, the request body stream
+ # has been exhausted by the SDK. Test should fail in
+ # this case.
+ test_passed = raw_data != b""
+ test_failure_reason = "request body has been read"
+
+ sentry_init(integrations=[FalconIntegration()])
+
+ try:
+ app_class = falcon.App # Falcon ≥3.0
+ except AttributeError:
+ app_class = falcon.API # Falcon <3.0
+
+ app = app_class(middleware=[SentryCaptureMiddleware()])
+ app.add_route("/read_body", RequestMediaResource())
+
+ client = falcon.testing.TestClient(app)
+
+ client.simulate_post("/read_body", json={"foo": "bar"})
- with sentry_sdk.configure_scope() as scope:
- assert not scope._tags["request_data"]
+ # Check that simulate_post actually calls the resource, and
+ # that the SDK does not exhaust the request body stream.
+ assert test_passed, test_failure_reason
diff --git a/tests/integrations/fastapi/test_fastapi.py b/tests/integrations/fastapi/test_fastapi.py
index 17b1cecd52..005189f00c 100644
--- a/tests/integrations/fastapi/test_fastapi.py
+++ b/tests/integrations/fastapi/test_fastapi.py
@@ -1,31 +1,55 @@
import json
-import threading
-
+import logging
import pytest
-from sentry_sdk.integrations.fastapi import FastApiIntegration
-
-fastapi = pytest.importorskip("fastapi")
+import threading
+import warnings
+from unittest import mock
-from fastapi import FastAPI
+import fastapi
+from fastapi import FastAPI, HTTPException, Request
from fastapi.testclient import TestClient
+from fastapi.middleware.trustedhost import TrustedHostMiddleware
+
+import sentry_sdk
from sentry_sdk import capture_message
-from sentry_sdk.integrations.starlette import StarletteIntegration
+from sentry_sdk.feature_flags import add_feature_flag
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
+from sentry_sdk.integrations.fastapi import FastApiIntegration
+from sentry_sdk.integrations.starlette import StarletteIntegration
+from sentry_sdk.utils import parse_version
+
-try:
- from unittest import mock # python 3.3 and above
-except ImportError:
- import mock # python < 3.3
+FASTAPI_VERSION = parse_version(fastapi.__version__)
+
+from tests.integrations.conftest import parametrize_test_configurable_status_codes
+from tests.integrations.starlette import test_starlette
def fastapi_app_factory():
app = FastAPI()
+ @app.get("/error")
+ async def _error():
+ capture_message("Hi")
+ 1 / 0
+ return {"message": "Hi"}
+
@app.get("/message")
async def _message():
capture_message("Hi")
return {"message": "Hi"}
+ @app.delete("/nomessage")
+ @app.get("/nomessage")
+ @app.head("/nomessage")
+ @app.options("/nomessage")
+ @app.patch("/nomessage")
+ @app.post("/nomessage")
+ @app.put("/nomessage")
+ @app.trace("/nomessage")
+ async def _nomessage():
+ return {"message": "nothing here..."}
+
@app.get("/message/{message_id}")
async def _message_with_id(message_id):
capture_message("Hi")
@@ -57,7 +81,6 @@ async def test_response(sentry_init, capture_events):
integrations=[StarletteIntegration(), FastApiIntegration()],
traces_sample_rate=1.0,
send_default_pii=True,
- debug=True,
)
app = fastapi_app_factory()
@@ -160,11 +183,11 @@ def test_legacy_setup(
@pytest.mark.parametrize("endpoint", ["/sync/thread_ids", "/async/thread_ids"])
-@mock.patch("sentry_sdk.profiler.PROFILE_MINIMUM_SAMPLES", 0)
+@mock.patch("sentry_sdk.profiler.transaction_profiler.PROFILE_MINIMUM_SAMPLES", 0)
def test_active_thread_id(sentry_init, capture_envelopes, teardown_profiling, endpoint):
sentry_init(
traces_sample_rate=1.0,
- _experiments={"profiles_sample_rate": 1.0},
+ profiles_sample_rate=1.0,
)
app = fastapi_app_factory()
asgi_app = SentryAsgiMiddleware(app)
@@ -183,7 +206,557 @@ def test_active_thread_id(sentry_init, capture_envelopes, teardown_profiling, en
profiles = [item for item in envelopes[0].items if item.type == "profile"]
assert len(profiles) == 1
- for profile in profiles:
- transactions = profile.payload.json["transactions"]
+ for item in profiles:
+ transactions = item.payload.json["transactions"]
assert len(transactions) == 1
assert str(data["active"]) == transactions[0]["active_thread_id"]
+
+ transactions = [item for item in envelopes[0].items if item.type == "transaction"]
+ assert len(transactions) == 1
+
+ for item in transactions:
+ transaction = item.payload.json
+ trace_context = transaction["contexts"]["trace"]
+ assert str(data["active"]) == trace_context["data"]["thread.id"]
+
+
+@pytest.mark.asyncio
+async def test_original_request_not_scrubbed(sentry_init, capture_events):
+ sentry_init(
+ integrations=[StarletteIntegration(), FastApiIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ app = FastAPI()
+
+ @app.post("/error")
+ async def _error(request: Request):
+ logging.critical("Oh no!")
+ assert request.headers["Authorization"] == "Bearer ohno"
+ assert await request.json() == {"password": "secret"}
+
+ return {"error": "Oh no!"}
+
+ events = capture_events()
+
+ client = TestClient(app)
+ client.post(
+ "/error", json={"password": "secret"}, headers={"Authorization": "Bearer ohno"}
+ )
+
+ event = events[0]
+ assert event["request"]["data"] == {"password": "[Filtered]"}
+ assert event["request"]["headers"]["authorization"] == "[Filtered]"
+
+
+def test_response_status_code_ok_in_transaction_context(sentry_init, capture_envelopes):
+ """
+ Tests that the response status code is added to the transaction "response" context.
+ """
+ sentry_init(
+ integrations=[StarletteIntegration(), FastApiIntegration()],
+ traces_sample_rate=1.0,
+ release="demo-release",
+ )
+
+ envelopes = capture_envelopes()
+
+ app = fastapi_app_factory()
+
+ client = TestClient(app)
+ client.get("/message")
+
+ (_, transaction_envelope) = envelopes
+ transaction = transaction_envelope.get_transaction_event()
+
+ assert transaction["type"] == "transaction"
+ assert len(transaction["contexts"]) > 0
+ assert "response" in transaction["contexts"].keys(), (
+ "Response context not found in transaction"
+ )
+ assert transaction["contexts"]["response"]["status_code"] == 200
+
+
+def test_response_status_code_error_in_transaction_context(
+ sentry_init,
+ capture_envelopes,
+):
+ """
+ Tests that the response status code is added to the transaction "response" context.
+ """
+ sentry_init(
+ integrations=[StarletteIntegration(), FastApiIntegration()],
+ traces_sample_rate=1.0,
+ release="demo-release",
+ )
+
+ envelopes = capture_envelopes()
+
+ app = fastapi_app_factory()
+
+ client = TestClient(app)
+ with pytest.raises(ZeroDivisionError):
+ client.get("/error")
+
+ (
+ _,
+ _,
+ transaction_envelope,
+ ) = envelopes
+ transaction = transaction_envelope.get_transaction_event()
+
+ assert transaction["type"] == "transaction"
+ assert len(transaction["contexts"]) > 0
+ assert "response" in transaction["contexts"].keys(), (
+ "Response context not found in transaction"
+ )
+ assert transaction["contexts"]["response"]["status_code"] == 500
+
+
+def test_response_status_code_not_found_in_transaction_context(
+ sentry_init,
+ capture_envelopes,
+):
+ """
+ Tests that the response status code is added to the transaction "response" context.
+ """
+ sentry_init(
+ integrations=[StarletteIntegration(), FastApiIntegration()],
+ traces_sample_rate=1.0,
+ release="demo-release",
+ )
+
+ envelopes = capture_envelopes()
+
+ app = fastapi_app_factory()
+
+ client = TestClient(app)
+ client.get("/non-existing-route-123")
+
+ (transaction_envelope,) = envelopes
+ transaction = transaction_envelope.get_transaction_event()
+
+ assert transaction["type"] == "transaction"
+ assert len(transaction["contexts"]) > 0
+ assert "response" in transaction["contexts"].keys(), (
+ "Response context not found in transaction"
+ )
+ assert transaction["contexts"]["response"]["status_code"] == 404
+
+
+@pytest.mark.parametrize(
+ "request_url,transaction_style,expected_transaction_name,expected_transaction_source",
+ [
+ (
+ "/message/123456",
+ "endpoint",
+ "tests.integrations.fastapi.test_fastapi.fastapi_app_factory.._message_with_id",
+ "component",
+ ),
+ (
+ "/message/123456",
+ "url",
+ "/message/{message_id}",
+ "route",
+ ),
+ ],
+)
+def test_transaction_name(
+ sentry_init,
+ request_url,
+ transaction_style,
+ expected_transaction_name,
+ expected_transaction_source,
+ capture_envelopes,
+):
+ """
+ Tests that the transaction name is something meaningful.
+ """
+ sentry_init(
+ auto_enabling_integrations=False, # Make sure that httpx integration is not added, because it adds tracing information to the starlette test clients request.
+ integrations=[
+ StarletteIntegration(transaction_style=transaction_style),
+ FastApiIntegration(transaction_style=transaction_style),
+ ],
+ traces_sample_rate=1.0,
+ )
+
+ envelopes = capture_envelopes()
+
+ app = fastapi_app_factory()
+
+ client = TestClient(app)
+ client.get(request_url)
+
+ (_, transaction_envelope) = envelopes
+ transaction_event = transaction_envelope.get_transaction_event()
+
+ assert transaction_event["transaction"] == expected_transaction_name
+ assert (
+ transaction_event["transaction_info"]["source"] == expected_transaction_source
+ )
+
+
+def test_route_endpoint_equal_dependant_call(sentry_init):
+ """
+ Tests that the route endpoint name is equal to the wrapped dependant call name.
+ """
+ sentry_init(
+ auto_enabling_integrations=False, # Make sure that httpx integration is not added, because it adds tracing information to the starlette test clients request.
+ integrations=[
+ StarletteIntegration(),
+ FastApiIntegration(),
+ ],
+ traces_sample_rate=1.0,
+ )
+
+ app = fastapi_app_factory()
+
+ for route in app.router.routes:
+ if not hasattr(route, "dependant"):
+ continue
+ assert route.endpoint.__qualname__ == route.dependant.call.__qualname__
+
+
+@pytest.mark.parametrize(
+ "request_url,transaction_style,expected_transaction_name,expected_transaction_source",
+ [
+ (
+ "/message/123456",
+ "endpoint",
+ "http://testserver/message/123456",
+ "url",
+ ),
+ (
+ "/message/123456",
+ "url",
+ "http://testserver/message/123456",
+ "url",
+ ),
+ ],
+)
+def test_transaction_name_in_traces_sampler(
+ sentry_init,
+ request_url,
+ transaction_style,
+ expected_transaction_name,
+ expected_transaction_source,
+):
+ """
+ Tests that a custom traces_sampler retrieves a meaningful transaction name.
+ In this case the URL or endpoint, because we do not have the route yet.
+ """
+
+ def dummy_traces_sampler(sampling_context):
+ assert (
+ sampling_context["transaction_context"]["name"] == expected_transaction_name
+ )
+ assert (
+ sampling_context["transaction_context"]["source"]
+ == expected_transaction_source
+ )
+
+ sentry_init(
+ auto_enabling_integrations=False, # Make sure that httpx integration is not added, because it adds tracing information to the starlette test clients request.
+ integrations=[StarletteIntegration(transaction_style=transaction_style)],
+ traces_sampler=dummy_traces_sampler,
+ traces_sample_rate=1.0,
+ )
+
+ app = fastapi_app_factory()
+
+ client = TestClient(app)
+ client.get(request_url)
+
+
+@pytest.mark.parametrize("middleware_spans", [False, True])
+@pytest.mark.parametrize(
+ "request_url,transaction_style,expected_transaction_name,expected_transaction_source",
+ [
+ (
+ "/message/123456",
+ "endpoint",
+ "starlette.middleware.trustedhost.TrustedHostMiddleware",
+ "component",
+ ),
+ (
+ "/message/123456",
+ "url",
+ "http://testserver/message/123456",
+ "url",
+ ),
+ ],
+)
+def test_transaction_name_in_middleware(
+ sentry_init,
+ middleware_spans,
+ request_url,
+ transaction_style,
+ expected_transaction_name,
+ expected_transaction_source,
+ capture_envelopes,
+):
+ """
+ Tests that the transaction name is something meaningful.
+ """
+ sentry_init(
+ auto_enabling_integrations=False, # Make sure that httpx integration is not added, because it adds tracing information to the starlette test clients request.
+ integrations=[
+ StarletteIntegration(
+ transaction_style=transaction_style, middleware_spans=middleware_spans
+ ),
+ FastApiIntegration(
+ transaction_style=transaction_style, middleware_spans=middleware_spans
+ ),
+ ],
+ traces_sample_rate=1.0,
+ )
+
+ envelopes = capture_envelopes()
+
+ app = fastapi_app_factory()
+
+ app.add_middleware(
+ TrustedHostMiddleware,
+ allowed_hosts=[
+ "example.com",
+ ],
+ )
+
+ client = TestClient(app)
+ client.get(request_url)
+
+ (transaction_envelope,) = envelopes
+ transaction_event = transaction_envelope.get_transaction_event()
+
+ assert transaction_event["contexts"]["response"]["status_code"] == 400
+ assert transaction_event["transaction"] == expected_transaction_name
+ assert (
+ transaction_event["transaction_info"]["source"] == expected_transaction_source
+ )
+
+
+@test_starlette.parametrize_test_configurable_status_codes_deprecated
+def test_configurable_status_codes_deprecated(
+ sentry_init,
+ capture_events,
+ failed_request_status_codes,
+ status_code,
+ expected_error,
+):
+ with pytest.warns(DeprecationWarning):
+ starlette_integration = StarletteIntegration(
+ failed_request_status_codes=failed_request_status_codes
+ )
+
+ with pytest.warns(DeprecationWarning):
+ fast_api_integration = FastApiIntegration(
+ failed_request_status_codes=failed_request_status_codes
+ )
+
+ sentry_init(
+ integrations=[
+ starlette_integration,
+ fast_api_integration,
+ ]
+ )
+
+ events = capture_events()
+
+ app = FastAPI()
+
+ @app.get("/error")
+ async def _error():
+ raise HTTPException(status_code)
+
+ client = TestClient(app)
+ client.get("/error")
+
+ if expected_error:
+ assert len(events) == 1
+ else:
+ assert not events
+
+
+@pytest.mark.skipif(
+ FASTAPI_VERSION < (0, 80),
+ reason="Requires FastAPI >= 0.80, because earlier versions do not support HTTP 'HEAD' requests",
+)
+def test_transaction_http_method_default(sentry_init, capture_events):
+ """
+ By default OPTIONS and HEAD requests do not create a transaction.
+ """
+ # FastAPI is heavily based on Starlette so we also need
+ # to enable StarletteIntegration.
+ # In the future this will be auto enabled.
+ sentry_init(
+ traces_sample_rate=1.0,
+ integrations=[
+ StarletteIntegration(),
+ FastApiIntegration(),
+ ],
+ )
+
+ app = fastapi_app_factory()
+
+ events = capture_events()
+
+ client = TestClient(app)
+ client.get("/nomessage")
+ client.options("/nomessage")
+ client.head("/nomessage")
+
+ assert len(events) == 1
+
+ (event,) = events
+
+ assert event["request"]["method"] == "GET"
+
+
+@pytest.mark.skipif(
+ FASTAPI_VERSION < (0, 80),
+ reason="Requires FastAPI >= 0.80, because earlier versions do not support HTTP 'HEAD' requests",
+)
+def test_transaction_http_method_custom(sentry_init, capture_events):
+ # FastAPI is heavily based on Starlette so we also need
+ # to enable StarletteIntegration.
+ # In the future this will be auto enabled.
+ sentry_init(
+ traces_sample_rate=1.0,
+ integrations=[
+ StarletteIntegration(
+ http_methods_to_capture=(
+ "OPTIONS",
+ "head",
+ ), # capitalization does not matter
+ ),
+ FastApiIntegration(
+ http_methods_to_capture=(
+ "OPTIONS",
+ "head",
+ ), # capitalization does not matter
+ ),
+ ],
+ )
+
+ app = fastapi_app_factory()
+
+ events = capture_events()
+
+ client = TestClient(app)
+ client.get("/nomessage")
+ client.options("/nomessage")
+ client.head("/nomessage")
+
+ assert len(events) == 2
+
+ (event1, event2) = events
+
+ assert event1["request"]["method"] == "OPTIONS"
+ assert event2["request"]["method"] == "HEAD"
+
+
+@parametrize_test_configurable_status_codes
+def test_configurable_status_codes(
+ sentry_init,
+ capture_events,
+ failed_request_status_codes,
+ status_code,
+ expected_error,
+):
+ integration_kwargs = {}
+ if failed_request_status_codes is not None:
+ integration_kwargs["failed_request_status_codes"] = failed_request_status_codes
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("error", DeprecationWarning)
+ starlette_integration = StarletteIntegration(**integration_kwargs)
+ fastapi_integration = FastApiIntegration(**integration_kwargs)
+
+ sentry_init(integrations=[starlette_integration, fastapi_integration])
+
+ events = capture_events()
+
+ app = FastAPI()
+
+ @app.get("/error")
+ async def _error():
+ raise HTTPException(status_code)
+
+ client = TestClient(app)
+ client.get("/error")
+
+ assert len(events) == int(expected_error)
+
+
+@pytest.mark.parametrize("transaction_style", ["endpoint", "url"])
+def test_app_host(sentry_init, capture_events, transaction_style):
+ sentry_init(
+ traces_sample_rate=1.0,
+ integrations=[
+ StarletteIntegration(transaction_style=transaction_style),
+ FastApiIntegration(transaction_style=transaction_style),
+ ],
+ )
+
+ app = FastAPI()
+ subapp = FastAPI()
+
+ @subapp.get("/subapp")
+ async def subapp_route():
+ return {"message": "Hello world!"}
+
+ app.host("subapp", subapp)
+
+ events = capture_events()
+
+ client = TestClient(app)
+ client.get("/subapp", headers={"Host": "subapp"})
+
+ assert len(events) == 1
+
+ (event,) = events
+ assert "transaction" in event
+
+ if transaction_style == "url":
+ assert event["transaction"] == "/subapp"
+ else:
+ assert event["transaction"].endswith("subapp_route")
+
+
+@pytest.mark.asyncio
+async def test_feature_flags(sentry_init, capture_events):
+ sentry_init(
+ traces_sample_rate=1.0,
+ integrations=[StarletteIntegration(), FastApiIntegration()],
+ )
+
+ events = capture_events()
+
+ app = FastAPI()
+
+ @app.get("/error")
+ async def _error():
+ add_feature_flag("hello", False)
+
+ with sentry_sdk.start_span(name="test-span"):
+ with sentry_sdk.start_span(name="test-span-2"):
+ raise ValueError("something is wrong!")
+
+ try:
+ client = TestClient(app)
+ client.get("/error")
+ except ValueError:
+ pass
+
+ found = False
+ for event in events:
+ if "exception" in event.keys():
+ assert event["contexts"]["flags"] == {
+ "values": [
+ {"flag": "hello", "result": False},
+ ]
+ }
+ found = True
+
+ assert found, "No event with exception found"
diff --git a/tests/integrations/fastmcp/__init__.py b/tests/integrations/fastmcp/__init__.py
new file mode 100644
index 0000000000..01ef442500
--- /dev/null
+++ b/tests/integrations/fastmcp/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("mcp")
diff --git a/tests/integrations/fastmcp/test_fastmcp.py b/tests/integrations/fastmcp/test_fastmcp.py
new file mode 100644
index 0000000000..bcfd9a62d1
--- /dev/null
+++ b/tests/integrations/fastmcp/test_fastmcp.py
@@ -0,0 +1,1315 @@
+"""
+Unit tests for the Sentry MCP integration with FastMCP.
+
+This test suite verifies that Sentry's MCPIntegration properly instruments
+both FastMCP implementations:
+- mcp.server.fastmcp.FastMCP (FastMCP from the mcp package)
+- fastmcp.FastMCP (standalone fastmcp package)
+
+Tests focus on verifying Sentry integration behavior:
+- Integration doesn't break FastMCP functionality
+- Span creation when tools/prompts/resources are called through MCP protocol
+- Span data accuracy (operation, description, origin, etc.)
+- Error capture and instrumentation
+- PII and include_prompts flag behavior
+- Request context data extraction
+- Transport detection (stdio, http, sse)
+
+All tests invoke tools/prompts/resources through the MCP Server's low-level
+request handlers (via CallToolRequest, GetPromptRequest, ReadResourceRequest)
+to properly trigger Sentry instrumentation and span creation. This ensures
+accurate testing of the integration's behavior in real MCP Server scenarios.
+"""
+
+import anyio
+import asyncio
+import json
+import pytest
+from unittest import mock
+
+try:
+ from unittest.mock import AsyncMock
+except ImportError:
+
+ class AsyncMock(mock.MagicMock):
+ async def __call__(self, *args, **kwargs):
+ return super(AsyncMock, self).__call__(*args, **kwargs)
+
+
+from sentry_sdk import start_transaction
+from sentry_sdk.consts import SPANDATA, OP
+from sentry_sdk.integrations.mcp import MCPIntegration
+
+from mcp.server.sse import SseServerTransport
+from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
+
+try:
+ from fastmcp.prompts import Message
+except ImportError:
+ Message = None
+
+
+from starlette.responses import Response
+from starlette.routing import Mount, Route
+from starlette.applications import Starlette
+
+# Try to import both FastMCP implementations
+try:
+ from mcp.server.fastmcp import FastMCP as MCPFastMCP
+
+ HAS_MCP_FASTMCP = True
+except ImportError:
+ HAS_MCP_FASTMCP = False
+ MCPFastMCP = None
+
+try:
+ from fastmcp import FastMCP as StandaloneFastMCP
+
+ HAS_STANDALONE_FASTMCP = True
+except ImportError:
+ HAS_STANDALONE_FASTMCP = False
+ StandaloneFastMCP = None
+
+# Try to import request_ctx for context testing
+try:
+ from mcp.server.lowlevel.server import request_ctx
+except ImportError:
+ request_ctx = None
+
+# Try to import MCP types for helper functions
+try:
+ from mcp.types import CallToolRequest, GetPromptRequest, ReadResourceRequest
+except ImportError:
+ # If mcp.types not available, tests will be skipped anyway
+ CallToolRequest = None
+ GetPromptRequest = None
+ ReadResourceRequest = None
+
+try:
+ from fastmcp import __version__ as FASTMCP_VERSION
+except ImportError:
+ FASTMCP_VERSION = None
+
+# Collect available FastMCP implementations for parametrization
+fastmcp_implementations = []
+fastmcp_ids = []
+
+if HAS_MCP_FASTMCP:
+ fastmcp_implementations.append(MCPFastMCP)
+ fastmcp_ids.append("mcp.server.fastmcp")
+
+if HAS_STANDALONE_FASTMCP:
+ fastmcp_implementations.append(StandaloneFastMCP)
+ fastmcp_ids.append("fastmcp")
+
+
+# Helper functions to call tools through MCP Server protocol
+def call_tool_through_mcp(mcp_instance, tool_name, arguments):
+ """
+ Call a tool through MCP Server's low-level handler.
+ This properly triggers Sentry instrumentation.
+
+ Args:
+ mcp_instance: The FastMCP instance
+ tool_name: Name of the tool to call
+ arguments: Dictionary of arguments to pass to the tool
+
+ Returns:
+ The tool result normalized to {"result": value} format
+ """
+ handler = mcp_instance._mcp_server.request_handlers[CallToolRequest]
+ request = CallToolRequest(
+ method="tools/call", params={"name": tool_name, "arguments": arguments}
+ )
+
+ result = asyncio.run(handler(request))
+
+ if hasattr(result, "root"):
+ result = result.root
+ if hasattr(result, "structuredContent") and result.structuredContent:
+ result = result.structuredContent
+ elif hasattr(result, "content"):
+ if result.content:
+ text = result.content[0].text
+ try:
+ result = json.loads(text)
+ except (json.JSONDecodeError, TypeError):
+ result = text
+ else:
+ # Empty content means None return
+ result = None
+
+ # Normalize return value to consistent format
+ # If already a dict, return as-is (tool functions return dicts directly)
+ if isinstance(result, dict):
+ return result
+
+ # Handle string "None" or "null" as actual None
+ if isinstance(result, str) and result in ("None", "null"):
+ result = None
+
+ # Wrap primitive values (int, str, bool, None) in dict format for consistency
+ return {"result": result}
+
+
+async def call_tool_through_mcp_async(mcp_instance, tool_name, arguments):
+ """Async version of call_tool_through_mcp."""
+ handler = mcp_instance._mcp_server.request_handlers[CallToolRequest]
+ request = CallToolRequest(
+ method="tools/call", params={"name": tool_name, "arguments": arguments}
+ )
+
+ result = await handler(request)
+
+ if hasattr(result, "root"):
+ result = result.root
+ if hasattr(result, "structuredContent") and result.structuredContent:
+ result = result.structuredContent
+ elif hasattr(result, "content"):
+ if result.content:
+ text = result.content[0].text
+ try:
+ result = json.loads(text)
+ except (json.JSONDecodeError, TypeError):
+ result = text
+ else:
+ # Empty content means None return
+ result = None
+
+ # Normalize return value to consistent format
+ # If already a dict, return as-is (tool functions return dicts directly)
+ if isinstance(result, dict):
+ return result
+
+ # Handle string "None" or "null" as actual None
+ if isinstance(result, str) and result in ("None", "null"):
+ result = None
+
+ # Wrap primitive values (int, str, bool, None) in dict format for consistency
+ return {"result": result}
+
+
+def call_prompt_through_mcp(mcp_instance, prompt_name, arguments=None):
+ """Call a prompt through MCP Server's low-level handler."""
+ handler = mcp_instance._mcp_server.request_handlers[GetPromptRequest]
+ request = GetPromptRequest(
+ method="prompts/get", params={"name": prompt_name, "arguments": arguments or {}}
+ )
+
+ result = asyncio.run(handler(request))
+ if hasattr(result, "root"):
+ result = result.root
+ return result
+
+
+async def call_prompt_through_mcp_async(mcp_instance, prompt_name, arguments=None):
+ """Async version of call_prompt_through_mcp."""
+ handler = mcp_instance._mcp_server.request_handlers[GetPromptRequest]
+ request = GetPromptRequest(
+ method="prompts/get", params={"name": prompt_name, "arguments": arguments or {}}
+ )
+
+ result = await handler(request)
+ if hasattr(result, "root"):
+ result = result.root
+ return result
+
+
+def call_resource_through_mcp(mcp_instance, uri):
+ """Call a resource through MCP Server's low-level handler."""
+ handler = mcp_instance._mcp_server.request_handlers[ReadResourceRequest]
+ request = ReadResourceRequest(method="resources/read", params={"uri": str(uri)})
+
+ result = asyncio.run(handler(request))
+ if hasattr(result, "root"):
+ result = result.root
+ return result
+
+
+async def call_resource_through_mcp_async(mcp_instance, uri):
+ """Async version of call_resource_through_mcp."""
+ handler = mcp_instance._mcp_server.request_handlers[ReadResourceRequest]
+ request = ReadResourceRequest(method="resources/read", params={"uri": str(uri)})
+
+ result = await handler(request)
+ if hasattr(result, "root"):
+ result = result.root
+ return result
+
+
+# Skip all tests if neither implementation is available
+pytestmark = pytest.mark.skipif(
+ not (HAS_MCP_FASTMCP or HAS_STANDALONE_FASTMCP),
+ reason="Neither mcp.fastmcp nor standalone fastmcp is installed",
+)
+
+
+@pytest.fixture(autouse=True)
+def reset_request_ctx():
+ """Reset request context before and after each test"""
+ if request_ctx is not None:
+ try:
+ if request_ctx.get() is not None:
+ request_ctx.set(None)
+ except LookupError:
+ pass
+
+ yield
+
+ if request_ctx is not None:
+ try:
+ request_ctx.set(None)
+ except LookupError:
+ pass
+
+
+# =============================================================================
+# Tool Handler Tests - Verifying Sentry Integration
+# =============================================================================
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("FastMCP", fastmcp_implementations, ids=fastmcp_ids)
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [(True, True), (True, False), (False, True), (False, False)],
+)
+async def test_fastmcp_tool_sync(
+ sentry_init, capture_events, FastMCP, send_default_pii, include_prompts, stdio
+):
+ """Test that FastMCP synchronous tool handlers create proper spans"""
+ sentry_init(
+ integrations=[MCPIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ mcp = FastMCP("Test Server")
+
+ @mcp.tool()
+ def add_numbers(a: int, b: int) -> dict:
+ """Add two numbers together"""
+ return {"result": a + b, "operation": "addition"}
+
+ with start_transaction(name="fastmcp tx"):
+ # Call through MCP protocol to trigger instrumentation
+ result = await stdio(
+ mcp._mcp_server,
+ method="tools/call",
+ params={
+ "name": "add_numbers",
+ "arguments": {"a": 10, "b": 5},
+ },
+ request_id="req-123",
+ )
+
+ assert json.loads(result.message.root.result["content"][0]["text"]) == {
+ "result": 15,
+ "operation": "addition",
+ }
+
+ (tx,) = events
+ assert tx["type"] == "transaction"
+ assert len(tx["spans"]) == 1
+
+ # Verify span structure
+ span = tx["spans"][0]
+ assert span["op"] == OP.MCP_SERVER
+ assert span["origin"] == "auto.ai.mcp"
+ assert span["description"] == "tools/call add_numbers"
+ assert span["data"][SPANDATA.MCP_TOOL_NAME] == "add_numbers"
+ assert span["data"][SPANDATA.MCP_METHOD_NAME] == "tools/call"
+ assert span["data"][SPANDATA.MCP_TRANSPORT] == "stdio"
+ assert span["data"][SPANDATA.MCP_REQUEST_ID] == "req-123"
+
+ # Check PII-sensitive data
+ if send_default_pii and include_prompts:
+ assert SPANDATA.MCP_TOOL_RESULT_CONTENT in span["data"]
+ else:
+ assert SPANDATA.MCP_TOOL_RESULT_CONTENT not in span["data"]
+
+
+@pytest.mark.parametrize("FastMCP", fastmcp_implementations, ids=fastmcp_ids)
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [(True, True), (True, False), (False, True), (False, False)],
+)
+async def test_fastmcp_tool_async(
+ sentry_init,
+ capture_events,
+ FastMCP,
+ send_default_pii,
+ include_prompts,
+ json_rpc,
+ select_transactions_with_mcp_spans,
+):
+ """Test that FastMCP async tool handlers create proper spans"""
+ sentry_init(
+ integrations=[MCPIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ mcp = FastMCP("Test Server")
+
+ session_manager = StreamableHTTPSessionManager(
+ app=mcp._mcp_server,
+ json_response=True,
+ )
+
+ app = Starlette(
+ routes=[
+ Mount("/mcp", app=session_manager.handle_request),
+ ],
+ lifespan=lambda app: session_manager.run(),
+ )
+
+ @mcp.tool()
+ async def multiply_numbers(x: int, y: int) -> dict:
+ """Multiply two numbers together"""
+ return {"result": x * y, "operation": "multiplication"}
+
+ session_id, result = json_rpc(
+ app,
+ method="tools/call",
+ params={
+ "name": "multiply_numbers",
+ "arguments": {"x": 7, "y": 6},
+ },
+ request_id="req-456",
+ )
+
+ assert json.loads(result.json()["result"]["content"][0]["text"]) == {
+ "result": 42,
+ "operation": "multiplication",
+ }
+
+ transactions = select_transactions_with_mcp_spans(events, method_name="tools/call")
+ assert len(transactions) == 1
+ tx = transactions[0]
+ assert len(tx["spans"]) == 1
+ span = tx["spans"][0]
+
+ assert span["op"] == OP.MCP_SERVER
+ assert span["origin"] == "auto.ai.mcp"
+ assert span["description"] == "tools/call multiply_numbers"
+ assert span["data"][SPANDATA.MCP_TOOL_NAME] == "multiply_numbers"
+ assert span["data"][SPANDATA.MCP_METHOD_NAME] == "tools/call"
+ assert span["data"][SPANDATA.MCP_TRANSPORT] == "http"
+ assert span["data"][SPANDATA.MCP_REQUEST_ID] == "req-456"
+ assert span["data"][SPANDATA.MCP_SESSION_ID] == session_id
+
+ # Check PII-sensitive data
+ if send_default_pii and include_prompts:
+ assert SPANDATA.MCP_TOOL_RESULT_CONTENT in span["data"]
+ else:
+ assert SPANDATA.MCP_TOOL_RESULT_CONTENT not in span["data"]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("FastMCP", fastmcp_implementations, ids=fastmcp_ids)
+async def test_fastmcp_tool_with_error(sentry_init, capture_events, FastMCP, stdio):
+ """Test that FastMCP tool handler errors are captured properly"""
+ sentry_init(
+ integrations=[MCPIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ mcp = FastMCP("Test Server")
+
+ @mcp.tool()
+ def failing_tool(value: int) -> int:
+ """A tool that always fails"""
+ raise ValueError("Tool execution failed")
+
+ with start_transaction(name="fastmcp tx"):
+ result = await stdio(
+ mcp._mcp_server,
+ method="tools/call",
+ params={
+ "name": "failing_tool",
+ "arguments": {"value": 42},
+ },
+ request_id="req-error",
+ )
+ # If no exception raised, check if result indicates error
+ assert result.message.root.result["isError"] is True
+
+ # Should have transaction and error events
+ assert len(events) >= 1
+
+ # Check span was created
+ tx = [e for e in events if e.get("type") == "transaction"][0]
+ tool_spans = [s for s in tx["spans"] if s["op"] == OP.MCP_SERVER]
+ assert len(tool_spans) == 1
+
+ # Check error event was captured
+ error_events = [e for e in events if e.get("level") == "error"]
+ assert len(error_events) >= 1
+ error_event = error_events[0]
+ assert error_event["exception"]["values"][0]["type"] == "ValueError"
+ assert error_event["exception"]["values"][0]["value"] == "Tool execution failed"
+ # Verify span is marked with error
+ assert tool_spans[0]["data"][SPANDATA.MCP_TOOL_RESULT_IS_ERROR] is True
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("FastMCP", fastmcp_implementations, ids=fastmcp_ids)
+async def test_fastmcp_multiple_tools(sentry_init, capture_events, FastMCP, stdio):
+ """Test that multiple FastMCP tool calls create multiple spans"""
+ sentry_init(
+ integrations=[MCPIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ mcp = FastMCP("Test Server")
+
+ @mcp.tool()
+ def tool_one(x: int) -> int:
+ """First tool"""
+ return x * 2
+
+ @mcp.tool()
+ def tool_two(y: int) -> int:
+ """Second tool"""
+ return y + 10
+
+ @mcp.tool()
+ def tool_three(z: int) -> int:
+ """Third tool"""
+ return z - 5
+
+ with start_transaction(name="fastmcp tx"):
+ result1 = await stdio(
+ mcp._mcp_server,
+ method="tools/call",
+ params={
+ "name": "tool_one",
+ "arguments": {"x": 5},
+ },
+ request_id="req-multi",
+ )
+
+ result2 = await stdio(
+ mcp._mcp_server,
+ method="tools/call",
+ params={
+ "name": "tool_two",
+ "arguments": {
+ "y": int(result1.message.root.result["content"][0]["text"])
+ },
+ },
+ request_id="req-multi",
+ )
+
+ result3 = await stdio(
+ mcp._mcp_server,
+ method="tools/call",
+ params={
+ "name": "tool_three",
+ "arguments": {
+ "z": int(result2.message.root.result["content"][0]["text"])
+ },
+ },
+ request_id="req-multi",
+ )
+
+ assert result1.message.root.result["content"][0]["text"] == "10"
+ assert result2.message.root.result["content"][0]["text"] == "20"
+ assert result3.message.root.result["content"][0]["text"] == "15"
+
+ (tx,) = events
+ assert tx["type"] == "transaction"
+
+ # Verify three spans were created
+ tool_spans = [s for s in tx["spans"] if s["op"] == OP.MCP_SERVER]
+ assert len(tool_spans) == 3
+ assert tool_spans[0]["data"][SPANDATA.MCP_TOOL_NAME] == "tool_one"
+ assert tool_spans[1]["data"][SPANDATA.MCP_TOOL_NAME] == "tool_two"
+ assert tool_spans[2]["data"][SPANDATA.MCP_TOOL_NAME] == "tool_three"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("FastMCP", fastmcp_implementations, ids=fastmcp_ids)
+async def test_fastmcp_tool_with_complex_return(
+ sentry_init, capture_events, FastMCP, stdio
+):
+ """Test FastMCP tool with complex nested return value"""
+ sentry_init(
+ integrations=[MCPIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ mcp = FastMCP("Test Server")
+
+ @mcp.tool()
+ def get_user_data(user_id: int) -> dict:
+ """Get complex user data"""
+ return {
+ "id": user_id,
+ "name": "Alice",
+ "nested": {"preferences": {"theme": "dark", "notifications": True}},
+ "tags": ["admin", "verified"],
+ }
+
+ with start_transaction(name="fastmcp tx"):
+ result = await stdio(
+ mcp._mcp_server,
+ method="tools/call",
+ params={
+ "name": "get_user_data",
+ "arguments": {"user_id": 123},
+ },
+ request_id="req-complex",
+ )
+
+ assert json.loads(result.message.root.result["content"][0]["text"]) == {
+ "id": 123,
+ "name": "Alice",
+ "nested": {"preferences": {"theme": "dark", "notifications": True}},
+ "tags": ["admin", "verified"],
+ }
+
+ (tx,) = events
+ assert tx["type"] == "transaction"
+
+ # Verify span was created with complex data
+ tool_spans = [s for s in tx["spans"] if s["op"] == OP.MCP_SERVER]
+ assert len(tool_spans) == 1
+ assert tool_spans[0]["op"] == OP.MCP_SERVER
+ assert tool_spans[0]["data"][SPANDATA.MCP_TOOL_NAME] == "get_user_data"
+ # Complex return value should be captured since include_prompts=True and send_default_pii=True
+ assert SPANDATA.MCP_TOOL_RESULT_CONTENT in tool_spans[0]["data"]
+
+
+# =============================================================================
+# Prompt Handler Tests (if supported)
+# =============================================================================
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("FastMCP", fastmcp_implementations, ids=fastmcp_ids)
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [(True, True), (False, False)],
+)
+async def test_fastmcp_prompt_sync(
+ sentry_init, capture_events, FastMCP, send_default_pii, include_prompts, stdio
+):
+ """Test that FastMCP synchronous prompt handlers create proper spans"""
+ sentry_init(
+ integrations=[MCPIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ mcp = FastMCP("Test Server")
+
+ # Try to register a prompt handler (may not be supported in all versions)
+ if hasattr(mcp, "prompt"):
+
+ @mcp.prompt()
+ def code_help_prompt(language: str):
+ """Get help for a programming language"""
+ message = {
+ "role": "user",
+ "content": {
+ "type": "text",
+ "text": f"Tell me about {language}",
+ },
+ }
+
+ if FASTMCP_VERSION is not None and FASTMCP_VERSION.startswith("3"):
+ message = Message(message)
+
+ return [message]
+
+ with start_transaction(name="fastmcp tx"):
+ result = await stdio(
+ mcp._mcp_server,
+ method="prompts/get",
+ params={
+ "name": "code_help_prompt",
+ "arguments": {"language": "python"},
+ },
+ request_id="req-prompt",
+ )
+
+ assert result.message.root.result["messages"][0]["role"] == "user"
+ assert (
+ "python"
+ in result.message.root.result["messages"][0]["content"]["text"].lower()
+ )
+
+ (tx,) = events
+ assert tx["type"] == "transaction"
+
+ # Verify prompt span was created
+ prompt_spans = [s for s in tx["spans"] if s["op"] == OP.MCP_SERVER]
+ assert len(prompt_spans) == 1
+ span = prompt_spans[0]
+ assert span["origin"] == "auto.ai.mcp"
+ assert span["description"] == "prompts/get code_help_prompt"
+ assert span["data"][SPANDATA.MCP_PROMPT_NAME] == "code_help_prompt"
+
+ # Check PII-sensitive data
+ if send_default_pii and include_prompts:
+ assert SPANDATA.MCP_PROMPT_RESULT_MESSAGE_CONTENT in span["data"]
+ else:
+ assert SPANDATA.MCP_PROMPT_RESULT_MESSAGE_CONTENT not in span["data"]
+
+
+@pytest.mark.parametrize("FastMCP", fastmcp_implementations, ids=fastmcp_ids)
+@pytest.mark.asyncio
+async def test_fastmcp_prompt_async(
+ sentry_init,
+ capture_events,
+ FastMCP,
+ json_rpc,
+ select_transactions_with_mcp_spans,
+):
+ """Test that FastMCP async prompt handlers create proper spans"""
+ sentry_init(
+ integrations=[MCPIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ mcp = FastMCP("Test Server")
+
+ session_manager = StreamableHTTPSessionManager(
+ app=mcp._mcp_server,
+ json_response=True,
+ )
+
+ app = Starlette(
+ routes=[
+ Mount("/mcp", app=session_manager.handle_request),
+ ],
+ lifespan=lambda app: session_manager.run(),
+ )
+
+ # Try to register an async prompt handler
+ if hasattr(mcp, "prompt"):
+
+ @mcp.prompt()
+ async def async_prompt(topic: str):
+ """Get async prompt for a topic"""
+ message1 = {
+ "role": "user",
+ "content": {"type": "text", "text": f"What is {topic}?"},
+ }
+
+ message2 = {
+ "role": "assistant",
+ "content": {
+ "type": "text",
+ "text": "Let me explain that",
+ },
+ }
+
+ if FASTMCP_VERSION is not None and FASTMCP_VERSION.startswith("3"):
+ message1 = Message(message1)
+ message2 = Message(message2)
+
+ return [message1, message2]
+
+ _, result = json_rpc(
+ app,
+ method="prompts/get",
+ params={
+ "name": "async_prompt",
+ "arguments": {"topic": "MCP"},
+ },
+ request_id="req-async-prompt",
+ )
+
+ assert len(result.json()["result"]["messages"]) == 2
+
+ transactions = select_transactions_with_mcp_spans(
+ events, method_name="prompts/get"
+ )
+ assert len(transactions) == 1
+
+
+# =============================================================================
+# Resource Handler Tests (if supported)
+# =============================================================================
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("FastMCP", fastmcp_implementations, ids=fastmcp_ids)
+async def test_fastmcp_resource_sync(sentry_init, capture_events, FastMCP, stdio):
+ """Test that FastMCP synchronous resource handlers create proper spans"""
+ sentry_init(
+ integrations=[MCPIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ mcp = FastMCP("Test Server")
+
+ # Try to register a resource handler
+ try:
+ if hasattr(mcp, "resource"):
+
+ @mcp.resource("file:///{path}")
+ def read_file(path: str):
+ """Read a file resource"""
+ return "file contents"
+
+ with start_transaction(name="fastmcp tx"):
+ try:
+ result = await stdio(
+ mcp._mcp_server,
+ method="resources/read",
+ params={
+ "uri": "file:///test.txt",
+ },
+ request_id="req-resource",
+ )
+ except ValueError as e:
+ # Older FastMCP versions may not support this URI pattern
+ if "Unknown resource" in str(e):
+ pytest.skip(
+ f"Resource URI not supported in this FastMCP version: {e}"
+ )
+ raise
+
+ # Resource content is returned as-is
+ assert "file contents" in result.message.root.result["contents"][0]["text"]
+
+ (tx,) = events
+ assert tx["type"] == "transaction"
+
+ # Verify resource span was created
+ resource_spans = [s for s in tx["spans"] if s["op"] == OP.MCP_SERVER]
+ assert len(resource_spans) == 1
+ span = resource_spans[0]
+ assert span["origin"] == "auto.ai.mcp"
+ assert span["description"] == "resources/read file:///test.txt"
+ assert span["data"][SPANDATA.MCP_RESOURCE_PROTOCOL] == "file"
+ except (AttributeError, TypeError):
+ # Resource handler not supported in this version
+ pytest.skip("Resource handlers not supported in this FastMCP version")
+
+
+@pytest.mark.parametrize("FastMCP", fastmcp_implementations, ids=fastmcp_ids)
+@pytest.mark.asyncio
+async def test_fastmcp_resource_async(
+ sentry_init,
+ capture_events,
+ FastMCP,
+ json_rpc,
+ select_transactions_with_mcp_spans,
+):
+ """Test that FastMCP async resource handlers create proper spans"""
+ sentry_init(
+ integrations=[MCPIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ mcp = FastMCP("Test Server")
+
+ session_manager = StreamableHTTPSessionManager(
+ app=mcp._mcp_server,
+ json_response=True,
+ )
+
+ app = Starlette(
+ routes=[
+ Mount("/mcp", app=session_manager.handle_request),
+ ],
+ lifespan=lambda app: session_manager.run(),
+ )
+
+ # Try to register an async resource handler
+ try:
+ if hasattr(mcp, "resource"):
+
+ @mcp.resource("https://example.com/{resource}")
+ async def read_url(resource: str):
+ """Read a URL resource"""
+ return "resource data"
+
+ _, result = json_rpc(
+ app,
+ method="resources/read",
+ params={
+ "uri": "https://example.com/resource",
+ },
+ request_id="req-async-resource",
+ )
+ # Older FastMCP versions may not support this URI pattern
+ if (
+ "error" in result.json()
+ and "Unknown resource" in result.json()["error"]["message"]
+ ):
+ pytest.skip("Resource URI not supported in this FastMCP version.")
+ return
+
+ assert "resource data" in result.json()["result"]["contents"][0]["text"]
+
+ transactions = select_transactions_with_mcp_spans(
+ events, method_name="resources/read"
+ )
+ assert len(transactions) == 1
+ tx = transactions[0]
+ assert len(tx["spans"]) == 1
+ span = tx["spans"][0]
+
+ assert span["data"][SPANDATA.MCP_RESOURCE_PROTOCOL] == "https"
+ except (AttributeError, TypeError):
+ # Resource handler not supported in this version
+ pytest.skip("Resource handlers not supported in this FastMCP version")
+
+
+# =============================================================================
+# Span Origin and Metadata Tests
+# =============================================================================
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("FastMCP", fastmcp_implementations, ids=fastmcp_ids)
+async def test_fastmcp_span_origin(sentry_init, capture_events, FastMCP, stdio):
+ """Test that FastMCP span origin is set correctly"""
+ sentry_init(
+ integrations=[MCPIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ mcp = FastMCP("Test Server")
+
+ @mcp.tool()
+ def test_tool(value: int) -> int:
+ """Test tool for origin checking"""
+ return value * 2
+
+ with start_transaction(name="fastmcp tx"):
+ await stdio(
+ mcp._mcp_server,
+ method="tools/call",
+ params={
+ "name": "test_tool",
+ "arguments": {"value": 21},
+ },
+ request_id="req-origin",
+ )
+
+ (tx,) = events
+
+ assert tx["contexts"]["trace"]["origin"] == "manual"
+
+ # Verify MCP span has correct origin
+ mcp_spans = [s for s in tx["spans"] if s["op"] == OP.MCP_SERVER]
+ assert len(mcp_spans) == 1
+ assert mcp_spans[0]["origin"] == "auto.ai.mcp"
+
+
+# =============================================================================
+# Transport Detection Tests
+# =============================================================================
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("FastMCP", fastmcp_implementations, ids=fastmcp_ids)
+async def test_fastmcp_sse_transport(
+ sentry_init, capture_events, FastMCP, json_rpc_sse
+):
+ """Test that FastMCP correctly detects SSE transport"""
+ sentry_init(
+ integrations=[MCPIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ mcp = FastMCP("Test Server")
+ sse = SseServerTransport("/messages/")
+
+ sse_connection_closed = asyncio.Event()
+
+ async def handle_sse(request):
+ async with sse.connect_sse(
+ request.scope, request.receive, request._send
+ ) as streams:
+ async with anyio.create_task_group() as tg:
+
+ async def run_server():
+ await mcp._mcp_server.run(
+ streams[0],
+ streams[1],
+ mcp._mcp_server.create_initialization_options(),
+ )
+
+ tg.start_soon(run_server)
+
+ sse_connection_closed.set()
+ return Response()
+
+ app = Starlette(
+ routes=[
+ Route("/sse", endpoint=handle_sse, methods=["GET"]),
+ Mount("/messages/", app=sse.handle_post_message),
+ ],
+ )
+
+ @mcp.tool()
+ def sse_tool(value: str) -> dict:
+ """Tool for SSE transport test"""
+ return {"message": f"Received: {value}"}
+
+ keep_sse_alive = asyncio.Event()
+ app_task, _, result = await json_rpc_sse(
+ app,
+ method="tools/call",
+ params={
+ "name": "sse_tool",
+ "arguments": {"value": "hello"},
+ },
+ request_id="req-sse",
+ keep_sse_alive=keep_sse_alive,
+ )
+
+ await sse_connection_closed.wait()
+ await app_task
+
+ assert json.loads(result["result"]["content"][0]["text"]) == {
+ "message": "Received: hello"
+ }
+
+ transactions = [
+ event
+ for event in events
+ if event["type"] == "transaction" and event["transaction"] == "/sse"
+ ]
+ assert len(transactions) == 1
+ tx = transactions[0]
+
+ # Find MCP spans
+ mcp_spans = [s for s in tx["spans"] if s["op"] == OP.MCP_SERVER]
+ assert len(mcp_spans) >= 1
+ span = mcp_spans[0]
+ # Check that SSE transport is detected
+ assert span["data"].get(SPANDATA.MCP_TRANSPORT) == "sse"
+
+
+@pytest.mark.parametrize("FastMCP", fastmcp_implementations, ids=fastmcp_ids)
+def test_fastmcp_http_transport(
+ sentry_init,
+ capture_events,
+ FastMCP,
+ json_rpc,
+ select_transactions_with_mcp_spans,
+):
+ """Test that FastMCP correctly detects HTTP transport"""
+ sentry_init(
+ integrations=[MCPIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ mcp = FastMCP("Test Server")
+
+ session_manager = StreamableHTTPSessionManager(
+ app=mcp._mcp_server,
+ json_response=True,
+ )
+
+ app = Starlette(
+ routes=[
+ Mount("/mcp", app=session_manager.handle_request),
+ ],
+ lifespan=lambda app: session_manager.run(),
+ )
+
+ @mcp.tool()
+ def http_tool(data: str) -> dict:
+ """Tool for HTTP transport test"""
+ return {"processed": data.upper()}
+
+ _, result = json_rpc(
+ app,
+ method="tools/call",
+ params={
+ "name": "http_tool",
+ "arguments": {"data": "test"},
+ },
+ request_id="req-http",
+ )
+
+ assert json.loads(result.json()["result"]["content"][0]["text"]) == {
+ "processed": "TEST"
+ }
+
+ transactions = select_transactions_with_mcp_spans(events, method_name="tools/call")
+ assert len(transactions) == 1
+ tx = transactions[0]
+ assert len(tx["spans"]) == 1
+ span = tx["spans"][0]
+
+ # Check that HTTP transport is detected
+ assert span["data"].get(SPANDATA.MCP_TRANSPORT) == "http"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("FastMCP", fastmcp_implementations, ids=fastmcp_ids)
+async def test_fastmcp_stdio_transport(sentry_init, capture_events, FastMCP, stdio):
+ """Test that FastMCP correctly detects stdio transport"""
+ sentry_init(
+ integrations=[MCPIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ mcp = FastMCP("Test Server")
+
+ @mcp.tool()
+ def stdio_tool(n: int) -> dict:
+ """Tool for stdio transport test"""
+ return {"squared": n * n}
+
+ with start_transaction(name="fastmcp tx"):
+ result = await stdio(
+ mcp._mcp_server,
+ method="tools/call",
+ params={
+ "name": "stdio_tool",
+ "arguments": {"n": 7},
+ },
+ request_id="req-stdio",
+ )
+
+ assert json.loads(result.message.root.result["content"][0]["text"]) == {
+ "squared": 49
+ }
+
+ (tx,) = events
+
+ # Find MCP spans
+ mcp_spans = [s for s in tx["spans"] if s["op"] == OP.MCP_SERVER]
+ assert len(mcp_spans) >= 1
+ span = mcp_spans[0]
+ # Check that stdio transport is detected
+ assert span["data"].get(SPANDATA.MCP_TRANSPORT) == "stdio"
+
+
+# =============================================================================
+# Integration-specific Tests
+# =============================================================================
+
+
+@pytest.mark.skipif(not HAS_MCP_FASTMCP, reason="mcp.server.fastmcp not installed")
+def test_mcp_fastmcp_specific_features(sentry_init, capture_events):
+ """Test features specific to mcp.server.fastmcp (from mcp package)"""
+ sentry_init(
+ integrations=[MCPIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ from mcp.server.fastmcp import FastMCP
+
+ mcp = FastMCP("MCP Package Server")
+
+ @mcp.tool()
+ def package_specific_tool(x: int) -> int:
+ """Tool for mcp.server.fastmcp package"""
+ return x + 100
+
+ with start_transaction(name="mcp.server.fastmcp tx"):
+ result = call_tool_through_mcp(mcp, "package_specific_tool", {"x": 50})
+
+ assert result["result"] == 150
+
+ (tx,) = events
+ assert tx["type"] == "transaction"
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(
+ not HAS_STANDALONE_FASTMCP, reason="standalone fastmcp not installed"
+)
+async def test_standalone_fastmcp_specific_features(sentry_init, capture_events, stdio):
+ """Test features specific to standalone fastmcp package"""
+ sentry_init(
+ integrations=[MCPIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ from fastmcp import FastMCP
+
+ mcp = FastMCP("Standalone FastMCP Server")
+
+ @mcp.tool()
+ def standalone_specific_tool(message: str) -> dict:
+ """Tool for standalone fastmcp package"""
+ return {"echo": message, "length": len(message)}
+
+ with start_transaction(name="standalone fastmcp tx"):
+ result = await stdio(
+ mcp._mcp_server,
+ method="tools/call",
+ params={
+ "name": "standalone_specific_tool",
+ "arguments": {"message": "Hello FastMCP"},
+ },
+ )
+
+ assert json.loads(result.message.root.result["content"][0]["text"]) == {
+ "echo": "Hello FastMCP",
+ "length": 13,
+ }
+
+ (tx,) = events
+ assert tx["type"] == "transaction"
+
+
+# =============================================================================
+# Edge Cases and Robustness Tests
+# =============================================================================
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("FastMCP", fastmcp_implementations, ids=fastmcp_ids)
+async def test_fastmcp_tool_with_no_arguments(
+ sentry_init, capture_events, FastMCP, stdio
+):
+ """Test FastMCP tool with no arguments"""
+ sentry_init(
+ integrations=[MCPIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ mcp = FastMCP("Test Server")
+
+ @mcp.tool()
+ def no_args_tool() -> str:
+ """Tool that takes no arguments"""
+ return "success"
+
+ with start_transaction(name="fastmcp tx"):
+ result = await stdio(
+ mcp._mcp_server,
+ method="tools/call",
+ params={
+ "name": "no_args_tool",
+ "arguments": {},
+ },
+ )
+
+ assert result.message.root.result["content"][0]["text"] == "success"
+
+ (tx,) = events
+ assert tx["type"] == "transaction"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("FastMCP", fastmcp_implementations, ids=fastmcp_ids)
+async def test_fastmcp_tool_with_none_return(
+ sentry_init, capture_events, FastMCP, stdio
+):
+ """Test FastMCP tool that returns None"""
+ sentry_init(
+ integrations=[MCPIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ mcp = FastMCP("Test Server")
+
+ @mcp.tool()
+ def none_return_tool(action: str) -> None:
+ """Tool that returns None"""
+ pass
+
+ with start_transaction(name="fastmcp tx"):
+ result = await stdio(
+ mcp._mcp_server,
+ method="tools/call",
+ params={
+ "name": "none_return_tool",
+ "arguments": {"action": "log"},
+ },
+ )
+
+ if (
+ isinstance(mcp, StandaloneFastMCP) and FASTMCP_VERSION is not None
+ ) or isinstance(mcp, MCPFastMCP):
+ assert len(result.message.root.result["content"]) == 0
+ else:
+ assert result.message.root.result["content"] == [
+ {"type": "text", "text": "None"}
+ ]
+
+ (tx,) = events
+ assert tx["type"] == "transaction"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("FastMCP", fastmcp_implementations, ids=fastmcp_ids)
+async def test_fastmcp_mixed_sync_async_tools(
+ sentry_init, capture_events, FastMCP, stdio
+):
+ """Test mixing sync and async tools in FastMCP"""
+ sentry_init(
+ integrations=[MCPIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ mcp = FastMCP("Test Server")
+
+ @mcp.tool()
+ def sync_add(a: int, b: int) -> int:
+ """Sync addition"""
+ return a + b
+
+ @mcp.tool()
+ async def async_multiply(x: int, y: int) -> int:
+ """Async multiplication"""
+ return x * y
+
+ with start_transaction(name="fastmcp tx"):
+ # Use async version for both since we're in an async context
+ result1 = await stdio(
+ mcp._mcp_server,
+ method="tools/call",
+ params={
+ "name": "sync_add",
+ "arguments": {"a": 3, "b": 4},
+ },
+ request_id="req-mixed",
+ )
+ result2 = await stdio(
+ mcp._mcp_server,
+ method="tools/call",
+ params={
+ "name": "async_multiply",
+ "arguments": {"x": 5, "y": 6},
+ },
+ request_id="req-mixed",
+ )
+
+ assert result1.message.root.result["content"][0]["text"] == "7"
+ assert result2.message.root.result["content"][0]["text"] == "30"
+
+ (tx,) = events
+ assert tx["type"] == "transaction"
+
+ # Verify both sync and async tool spans were created
+ mcp_spans = [s for s in tx["spans"] if s["op"] == OP.MCP_SERVER]
+ assert len(mcp_spans) == 2
+ assert mcp_spans[0]["data"][SPANDATA.MCP_TOOL_NAME] == "sync_add"
+ assert mcp_spans[1]["data"][SPANDATA.MCP_TOOL_NAME] == "async_multiply"
diff --git a/tests/integrations/flask/__init__.py b/tests/integrations/flask/__init__.py
new file mode 100644
index 0000000000..601f9ed8d5
--- /dev/null
+++ b/tests/integrations/flask/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("flask")
diff --git a/tests/integrations/flask/test_flask.py b/tests/integrations/flask/test_flask.py
index 8983c4e5ff..e117b98ca9 100644
--- a/tests/integrations/flask/test_flask.py
+++ b/tests/integrations/flask/test_flask.py
@@ -1,11 +1,9 @@
import json
-import pytest
+import re
import logging
-
from io import BytesIO
-flask = pytest.importorskip("flask")
-
+import pytest
from flask import (
Flask,
Response,
@@ -15,19 +13,23 @@
render_template_string,
)
from flask.views import View
-
from flask_login import LoginManager, login_user
+try:
+ from werkzeug.wrappers.request import UnsupportedMediaType
+except ImportError:
+ UnsupportedMediaType = None
+
+import sentry_sdk
+import sentry_sdk.integrations.flask as flask_sentry
from sentry_sdk import (
set_tag,
- configure_scope,
capture_message,
capture_exception,
- last_event_id,
- Hub,
)
+from sentry_sdk.consts import DEFAULT_MAX_VALUE_LENGTH
from sentry_sdk.integrations.logging import LoggingIntegration
-import sentry_sdk.integrations.flask as flask_sentry
+from sentry_sdk.serializer import MAX_DATABAG_BREADTH
login_manager = LoginManager()
@@ -46,6 +48,10 @@ def hi():
capture_message("hi")
return "ok"
+ @app.route("/nomessage")
+ def nohi():
+ return "ok"
+
@app.route("/message/")
def hi_with_id(message_id):
capture_message("hi again")
@@ -123,7 +129,7 @@ def test_errors(
testing,
integration_enabled_params,
):
- sentry_init(debug=True, **integration_enabled_params)
+ sentry_init(**integration_enabled_params)
app.debug = debug
app.testing = testing
@@ -209,7 +215,7 @@ def test_flask_login_configured(
):
sentry_init(send_default_pii=send_default_pii, **integration_enabled_params)
- class User(object):
+ class User:
is_authenticated = is_active = True
is_anonymous = user_id is not None
@@ -243,9 +249,11 @@ def login():
def test_flask_large_json_request(sentry_init, capture_events, app):
- sentry_init(integrations=[flask_sentry.FlaskIntegration()])
+ sentry_init(
+ integrations=[flask_sentry.FlaskIntegration()], max_request_body_size="always"
+ )
- data = {"foo": {"bar": "a" * 2000}}
+ data = {"foo": {"bar": "a" * (DEFAULT_MAX_VALUE_LENGTH + 10)}}
@app.route("/", methods=["POST"])
def index():
@@ -263,9 +271,14 @@ def index():
(event,) = events
assert event["_meta"]["request"]["data"]["foo"]["bar"] == {
- "": {"len": 2000, "rem": [["!limit", "x", 1021, 1024]]}
+ "": {
+ "len": DEFAULT_MAX_VALUE_LENGTH + 10,
+ "rem": [
+ ["!limit", "x", DEFAULT_MAX_VALUE_LENGTH - 3, DEFAULT_MAX_VALUE_LENGTH]
+ ],
+ }
}
- assert len(event["request"]["data"]["foo"]["bar"]) == 1024
+ assert len(event["request"]["data"]["foo"]["bar"]) == DEFAULT_MAX_VALUE_LENGTH
def test_flask_session_tracking(sentry_init, capture_envelopes, app):
@@ -276,8 +289,7 @@ def test_flask_session_tracking(sentry_init, capture_envelopes, app):
@app.route("/")
def index():
- with configure_scope() as scope:
- scope.set_user({"ip_address": "1.2.3.4", "id": "42"})
+ sentry_sdk.get_isolation_scope().set_user({"ip_address": "1.2.3.4", "id": "42"})
try:
raise ValueError("stuff")
except Exception:
@@ -292,7 +304,7 @@ def index():
except ZeroDivisionError:
pass
- Hub.current.client.flush()
+ sentry_sdk.get_client().flush()
(first_event, error_event, session) = envelopes
first_event = first_event.get_event()
@@ -332,15 +344,21 @@ def index():
def test_flask_medium_formdata_request(sentry_init, capture_events, app):
- sentry_init(integrations=[flask_sentry.FlaskIntegration()])
+ sentry_init(
+ integrations=[flask_sentry.FlaskIntegration()], max_request_body_size="always"
+ )
- data = {"foo": "a" * 2000}
+ data = {"foo": "a" * (DEFAULT_MAX_VALUE_LENGTH + 10)}
@app.route("/", methods=["POST"])
def index():
assert request.form["foo"] == data["foo"]
assert not request.get_data()
- assert not request.get_json()
+ try:
+ assert not request.get_json()
+ except UnsupportedMediaType:
+ # flask/werkzeug 3
+ pass
capture_message("hi")
return "ok"
@@ -352,9 +370,14 @@ def index():
(event,) = events
assert event["_meta"]["request"]["data"]["foo"] == {
- "": {"len": 2000, "rem": [["!limit", "x", 1021, 1024]]}
+ "": {
+ "len": DEFAULT_MAX_VALUE_LENGTH + 10,
+ "rem": [
+ ["!limit", "x", DEFAULT_MAX_VALUE_LENGTH - 3, DEFAULT_MAX_VALUE_LENGTH]
+ ],
+ }
}
- assert len(event["request"]["data"]["foo"]) == 1024
+ assert len(event["request"]["data"]["foo"]) == DEFAULT_MAX_VALUE_LENGTH
def test_flask_formdata_request_appear_transaction_body(
@@ -372,7 +395,11 @@ def index():
assert request.form["username"] == data["username"]
assert request.form["age"] == data["age"]
assert not request.get_data()
- assert not request.get_json()
+ try:
+ assert not request.get_json()
+ except UnsupportedMediaType:
+ # flask/werkzeug 3
+ pass
set_tag("view", "yes")
capture_message("hi")
return "ok"
@@ -392,7 +419,9 @@ def index():
@pytest.mark.parametrize("input_char", ["a", b"a"])
def test_flask_too_large_raw_request(sentry_init, input_char, capture_events, app):
- sentry_init(integrations=[flask_sentry.FlaskIntegration()], request_bodies="small")
+ sentry_init(
+ integrations=[flask_sentry.FlaskIntegration()], max_request_body_size="small"
+ )
data = input_char * 2000
@@ -403,7 +432,11 @@ def index():
assert request.get_data() == data
else:
assert request.get_data() == data.encode("ascii")
- assert not request.get_json()
+ try:
+ assert not request.get_json()
+ except UnsupportedMediaType:
+ # flask/werkzeug 3
+ pass
capture_message("hi")
return "ok"
@@ -419,15 +452,24 @@ def index():
def test_flask_files_and_form(sentry_init, capture_events, app):
- sentry_init(integrations=[flask_sentry.FlaskIntegration()], request_bodies="always")
+ sentry_init(
+ integrations=[flask_sentry.FlaskIntegration()], max_request_body_size="always"
+ )
- data = {"foo": "a" * 2000, "file": (BytesIO(b"hello"), "hello.txt")}
+ data = {
+ "foo": "a" * (DEFAULT_MAX_VALUE_LENGTH + 10),
+ "file": (BytesIO(b"hello"), "hello.txt"),
+ }
@app.route("/", methods=["POST"])
def index():
assert list(request.form) == ["foo"]
assert list(request.files) == ["file"]
- assert not request.get_json()
+ try:
+ assert not request.get_json()
+ except UnsupportedMediaType:
+ # flask/werkzeug 3
+ pass
capture_message("hi")
return "ok"
@@ -439,14 +481,47 @@ def index():
(event,) = events
assert event["_meta"]["request"]["data"]["foo"] == {
- "": {"len": 2000, "rem": [["!limit", "x", 1021, 1024]]}
+ "": {
+ "len": DEFAULT_MAX_VALUE_LENGTH + 10,
+ "rem": [
+ ["!limit", "x", DEFAULT_MAX_VALUE_LENGTH - 3, DEFAULT_MAX_VALUE_LENGTH]
+ ],
+ }
}
- assert len(event["request"]["data"]["foo"]) == 1024
+ assert len(event["request"]["data"]["foo"]) == DEFAULT_MAX_VALUE_LENGTH
assert event["_meta"]["request"]["data"]["file"] == {"": {"rem": [["!raw", "x"]]}}
assert not event["request"]["data"]["file"]
+def test_json_not_truncated_if_max_request_body_size_is_always(
+ sentry_init, capture_events, app
+):
+ sentry_init(
+ integrations=[flask_sentry.FlaskIntegration()], max_request_body_size="always"
+ )
+
+ data = {
+ "key{}".format(i): "value{}".format(i) for i in range(MAX_DATABAG_BREADTH + 10)
+ }
+
+ @app.route("/", methods=["POST"])
+ def index():
+ assert request.get_json() == data
+ assert request.get_data() == json.dumps(data).encode("ascii")
+ capture_message("hi")
+ return "ok"
+
+ events = capture_events()
+
+ client = app.test_client()
+ response = client.post("/", content_type="application/json", data=json.dumps(data))
+ assert response.status_code == 200
+
+ (event,) = events
+ assert event["request"]["data"] == data
+
+
@pytest.mark.parametrize(
"integrations",
[
@@ -513,9 +588,12 @@ def test_cli_commands_raise(app):
def foo():
1 / 0
+ def create_app(*_):
+ return app
+
with pytest.raises(ZeroDivisionError):
app.cli.main(
- args=["foo"], prog_name="myapp", obj=ScriptInfo(create_app=lambda _: app)
+ args=["foo"], prog_name="myapp", obj=ScriptInfo(create_app=create_app)
)
@@ -545,7 +623,7 @@ def wsgi_app(environ, start_response):
assert event["exception"]["values"][0]["mechanism"]["type"] == "wsgi"
-def test_500(sentry_init, capture_events, app):
+def test_500(sentry_init, app):
sentry_init(integrations=[flask_sentry.FlaskIntegration()])
app.debug = False
@@ -557,15 +635,12 @@ def index():
@app.errorhandler(500)
def error_handler(err):
- return "Sentry error: %s" % last_event_id()
-
- events = capture_events()
+ return "Sentry error."
client = app.test_client()
response = client.get("/")
- (event,) = events
- assert response.data.decode("utf-8") == "Sentry error: %s" % event["event_id"]
+ assert response.data.decode("utf-8") == "Sentry error."
def test_error_in_errorhandler(sentry_init, capture_events, app):
@@ -617,18 +692,15 @@ def test_does_not_leak_scope(sentry_init, capture_events, app):
sentry_init(integrations=[flask_sentry.FlaskIntegration()])
events = capture_events()
- with configure_scope() as scope:
- scope.set_tag("request_data", False)
+ sentry_sdk.get_isolation_scope().set_tag("request_data", False)
@app.route("/")
def index():
- with configure_scope() as scope:
- scope.set_tag("request_data", True)
+ sentry_sdk.get_isolation_scope().set_tag("request_data", True)
def generate():
for row in range(1000):
- with configure_scope() as scope:
- assert scope._tags["request_data"]
+ assert sentry_sdk.get_isolation_scope()._tags["request_data"]
yield str(row) + "\n"
@@ -639,8 +711,7 @@ def generate():
assert response.data.decode() == "".join(str(row) + "\n" for row in range(1000))
assert not events
- with configure_scope() as scope:
- assert not scope._tags["request_data"]
+ assert not sentry_sdk.get_isolation_scope()._tags["request_data"]
def test_scoped_test_client(sentry_init, app):
@@ -738,6 +809,25 @@ def error():
assert exception["type"] == "ZeroDivisionError"
+def test_error_has_trace_context_if_tracing_disabled(sentry_init, capture_events, app):
+ sentry_init(integrations=[flask_sentry.FlaskIntegration()])
+
+ events = capture_events()
+
+ @app.route("/error")
+ def error():
+ 1 / 0
+
+ with pytest.raises(ZeroDivisionError):
+ with app.test_client() as client:
+ response = client.get("/error")
+ assert response.status_code == 500
+
+ (error_event,) = events
+
+ assert error_event["contexts"]["trace"]
+
+
def test_class_based_views(sentry_init, app, capture_events):
sentry_init(integrations=[flask_sentry.FlaskIntegration()])
events = capture_events()
@@ -760,22 +850,36 @@ def dispatch_request(self):
assert event["transaction"] == "hello_class"
-def test_sentry_trace_context(sentry_init, app, capture_events):
+@pytest.mark.parametrize(
+ "template_string", ["{{ sentry_trace }}", "{{ sentry_trace_meta }}"]
+)
+def test_template_tracing_meta(sentry_init, app, capture_events, template_string):
sentry_init(integrations=[flask_sentry.FlaskIntegration()])
events = capture_events()
@app.route("/")
def index():
- sentry_span = Hub.current.scope.span
- capture_message(sentry_span.to_traceparent())
- return render_template_string("{{ sentry_trace }}")
+ capture_message(sentry_sdk.get_traceparent() + "\n" + sentry_sdk.get_baggage())
+ return render_template_string(template_string)
with app.test_client() as client:
response = client.get("/")
assert response.status_code == 200
- assert response.data.decode(
- "utf-8"
- ) == '' % (events[0]["message"],)
+
+ rendered_meta = response.data.decode("utf-8")
+ traceparent, baggage = events[0]["message"].split("\n")
+ assert traceparent != ""
+ assert baggage != ""
+
+ match = re.match(
+ r'^',
+ rendered_meta,
+ )
+ assert match is not None
+ assert match.group(1) == traceparent
+
+ rendered_baggage = match.group(2)
+ assert rendered_baggage == baggage
def test_dont_override_sentry_trace_context(sentry_init, app):
@@ -789,3 +893,167 @@ def index():
response = client.get("/")
assert response.status_code == 200
assert response.data == b"hi"
+
+
+def test_request_not_modified_by_reference(sentry_init, capture_events, app):
+ sentry_init(integrations=[flask_sentry.FlaskIntegration()])
+
+ @app.route("/", methods=["POST"])
+ def index():
+ logging.critical("oops")
+ assert request.get_json() == {"password": "ohno"}
+ assert request.headers["Authorization"] == "Bearer ohno"
+ return "ok"
+
+ events = capture_events()
+
+ client = app.test_client()
+ client.post(
+ "/", json={"password": "ohno"}, headers={"Authorization": "Bearer ohno"}
+ )
+
+ (event,) = events
+
+ assert event["request"]["data"]["password"] == "[Filtered]"
+ assert event["request"]["headers"]["Authorization"] == "[Filtered]"
+
+
+def test_response_status_code_ok_in_transaction_context(
+ sentry_init, capture_envelopes, app
+):
+ """
+ Tests that the response status code is added to the transaction context.
+ This also works for when there is an Exception during the request, but somehow the test flask app doesn't seem to trigger that.
+ """
+ sentry_init(
+ integrations=[flask_sentry.FlaskIntegration()],
+ traces_sample_rate=1.0,
+ release="demo-release",
+ )
+
+ envelopes = capture_envelopes()
+
+ client = app.test_client()
+ client.get("/message")
+
+ sentry_sdk.get_client().flush()
+
+ (_, transaction_envelope, _) = envelopes
+ transaction = transaction_envelope.get_transaction_event()
+
+ assert transaction["type"] == "transaction"
+ assert len(transaction["contexts"]) > 0
+ assert "response" in transaction["contexts"].keys(), (
+ "Response context not found in transaction"
+ )
+ assert transaction["contexts"]["response"]["status_code"] == 200
+
+
+def test_response_status_code_not_found_in_transaction_context(
+ sentry_init, capture_envelopes, app
+):
+ sentry_init(
+ integrations=[flask_sentry.FlaskIntegration()],
+ traces_sample_rate=1.0,
+ release="demo-release",
+ )
+
+ envelopes = capture_envelopes()
+
+ client = app.test_client()
+ client.get("/not-existing-route")
+
+ sentry_sdk.get_client().flush()
+
+ (transaction_envelope, _) = envelopes
+ transaction = transaction_envelope.get_transaction_event()
+
+ assert transaction["type"] == "transaction"
+ assert len(transaction["contexts"]) > 0
+ assert "response" in transaction["contexts"].keys(), (
+ "Response context not found in transaction"
+ )
+ assert transaction["contexts"]["response"]["status_code"] == 404
+
+
+def test_span_origin(sentry_init, app, capture_events):
+ sentry_init(
+ integrations=[flask_sentry.FlaskIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client = app.test_client()
+ client.get("/message")
+
+ (_, event) = events
+
+ assert event["contexts"]["trace"]["origin"] == "auto.http.flask"
+
+
+def test_transaction_http_method_default(
+ sentry_init,
+ app,
+ capture_events,
+):
+ """
+ By default OPTIONS and HEAD requests do not create a transaction.
+ """
+ sentry_init(
+ traces_sample_rate=1.0,
+ integrations=[flask_sentry.FlaskIntegration()],
+ )
+ events = capture_events()
+
+ client = app.test_client()
+ response = client.get("/nomessage")
+ assert response.status_code == 200
+
+ response = client.options("/nomessage")
+ assert response.status_code == 200
+
+ response = client.head("/nomessage")
+ assert response.status_code == 200
+
+ (event,) = events
+
+ assert len(events) == 1
+ assert event["request"]["method"] == "GET"
+
+
+def test_transaction_http_method_custom(
+ sentry_init,
+ app,
+ capture_events,
+):
+ """
+ Configure FlaskIntegration to ONLY capture OPTIONS and HEAD requests.
+ """
+ sentry_init(
+ traces_sample_rate=1.0,
+ integrations=[
+ flask_sentry.FlaskIntegration(
+ http_methods_to_capture=(
+ "OPTIONS",
+ "head",
+ ) # capitalization does not matter
+ ) # case does not matter
+ ],
+ )
+ events = capture_events()
+
+ client = app.test_client()
+ response = client.get("/nomessage")
+ assert response.status_code == 200
+
+ response = client.options("/nomessage")
+ assert response.status_code == 200
+
+ response = client.head("/nomessage")
+ assert response.status_code == 200
+
+ assert len(events) == 2
+
+ (event1, event2) = events
+ assert event1["request"]["method"] == "OPTIONS"
+ assert event2["request"]["method"] == "HEAD"
diff --git a/tests/integrations/gcp/__init__.py b/tests/integrations/gcp/__init__.py
new file mode 100644
index 0000000000..eaf1ba89bb
--- /dev/null
+++ b/tests/integrations/gcp/__init__.py
@@ -0,0 +1,6 @@
+import pytest
+import os
+
+
+if "gcp" not in os.environ.get("TOX_ENV_NAME", ""):
+ pytest.skip("GCP tests only run in GCP environment", allow_module_level=True)
diff --git a/tests/integrations/gcp/test_gcp.py b/tests/integrations/gcp/test_gcp.py
index 3ccdbd752a..c27c7653aa 100644
--- a/tests/integrations/gcp/test_gcp.py
+++ b/tests/integrations/gcp/test_gcp.py
@@ -2,6 +2,7 @@
# GCP Cloud Functions unit tests
"""
+
import json
from textwrap import dedent
import tempfile
@@ -12,10 +13,6 @@
import os.path
import os
-pytestmark = pytest.mark.skipif(
- not hasattr(tempfile, "TemporaryDirectory"), reason="need Python 3.2+"
-)
-
FUNCTIONS_PRELUDE = """
from unittest.mock import Mock
@@ -62,17 +59,9 @@ def envelope_processor(envelope):
return item.get_bytes()
class TestTransport(HttpTransport):
- def _send_event(self, event):
- event = event_processor(event)
- # Writing a single string to stdout holds the GIL (seems like) and
- # therefore cannot be interleaved with other threads. This is why we
- # explicitly add a newline at the end even though `print` would provide
- # us one.
- print("\\nEVENT: {}\\n".format(json.dumps(event)))
-
- def _send_envelope(self, envelope):
- envelope = envelope_processor(envelope)
- print("\\nENVELOPE: {}\\n".format(envelope.decode(\"utf-8\")))
+ def capture_envelope(self, envelope):
+ envelope_item = envelope_processor(envelope)
+ print("\\nENVELOPE: {}\\n".format(envelope_item.decode(\"utf-8\")))
def init_sdk(timeout_warning=False, **extra_init_args):
@@ -93,9 +82,7 @@ def init_sdk(timeout_warning=False, **extra_init_args):
@pytest.fixture
def run_cloud_function():
def inner(code, subprocess_kwargs=()):
-
- event = []
- envelope = []
+ envelope_items = []
return_value = None
# STEP : Create a zip of cloud function
@@ -114,14 +101,14 @@ def inner(code, subprocess_kwargs=()):
subprocess.check_call(
[sys.executable, "setup.py", "sdist", "-d", os.path.join(tmpdir, "..")],
- **subprocess_kwargs
+ **subprocess_kwargs,
)
subprocess.check_call(
"pip install ../*.tar.gz -t .",
cwd=tmpdir,
shell=True,
- **subprocess_kwargs
+ **subprocess_kwargs,
)
stream = os.popen("python {}/main.py".format(tmpdir))
@@ -131,12 +118,9 @@ def inner(code, subprocess_kwargs=()):
for line in stream_data.splitlines():
print("GCP:", line)
- if line.startswith("EVENT: "):
- line = line[len("EVENT: ") :]
- event = json.loads(line)
- elif line.startswith("ENVELOPE: "):
+ if line.startswith("ENVELOPE: "):
line = line[len("ENVELOPE: ") :]
- envelope = json.loads(line)
+ envelope_items.append(json.loads(line))
elif line.startswith("RETURN VALUE: "):
line = line[len("RETURN VALUE: ") :]
return_value = json.loads(line)
@@ -145,13 +129,13 @@ def inner(code, subprocess_kwargs=()):
stream.close()
- return envelope, event, return_value
+ return envelope_items, return_value
return inner
def test_handled_exception(run_cloud_function):
- envelope, event, return_value = run_cloud_function(
+ envelope_items, return_value = run_cloud_function(
dedent(
"""
functionhandler = None
@@ -168,16 +152,17 @@ def cloud_function(functionhandler, event):
"""
)
)
- assert event["level"] == "error"
- (exception,) = event["exception"]["values"]
+ assert envelope_items[0]["level"] == "error"
+ (exception,) = envelope_items[0]["exception"]["values"]
assert exception["type"] == "Exception"
assert exception["value"] == "something went wrong"
- assert exception["mechanism"] == {"type": "gcp", "handled": False}
+ assert exception["mechanism"]["type"] == "gcp"
+ assert not exception["mechanism"]["handled"]
def test_unhandled_exception(run_cloud_function):
- envelope, event, return_value = run_cloud_function(
+ envelope_items, _ = run_cloud_function(
dedent(
"""
functionhandler = None
@@ -195,21 +180,23 @@ def cloud_function(functionhandler, event):
"""
)
)
- assert event["level"] == "error"
- (exception,) = event["exception"]["values"]
+ assert envelope_items[0]["level"] == "error"
+ (exception,) = envelope_items[0]["exception"]["values"]
assert exception["type"] == "ZeroDivisionError"
assert exception["value"] == "division by zero"
- assert exception["mechanism"] == {"type": "gcp", "handled": False}
+ assert exception["mechanism"]["type"] == "gcp"
+ assert not exception["mechanism"]["handled"]
def test_timeout_error(run_cloud_function):
- envelope, event, return_value = run_cloud_function(
+ envelope_items, _ = run_cloud_function(
dedent(
"""
functionhandler = None
event = {}
def cloud_function(functionhandler, event):
+ sentry_sdk.set_tag("cloud_function", "true")
time.sleep(10)
return "3"
"""
@@ -222,19 +209,22 @@ def cloud_function(functionhandler, event):
"""
)
)
- assert event["level"] == "error"
- (exception,) = event["exception"]["values"]
+ assert envelope_items[0]["level"] == "error"
+ (exception,) = envelope_items[0]["exception"]["values"]
assert exception["type"] == "ServerlessTimeoutWarning"
assert (
exception["value"]
== "WARNING : Function is expected to get timed out. Configured timeout duration = 3 seconds."
)
- assert exception["mechanism"] == {"type": "threading", "handled": False}
+ assert exception["mechanism"]["type"] == "threading"
+ assert not exception["mechanism"]["handled"]
+
+ assert envelope_items[0]["tags"]["cloud_function"] == "true"
def test_performance_no_error(run_cloud_function):
- envelope, event, return_value = run_cloud_function(
+ envelope_items, _ = run_cloud_function(
dedent(
"""
functionhandler = None
@@ -252,15 +242,15 @@ def cloud_function(functionhandler, event):
)
)
- assert envelope["type"] == "transaction"
- assert envelope["contexts"]["trace"]["op"] == "function.gcp"
- assert envelope["transaction"].startswith("Google Cloud function")
- assert envelope["transaction_info"] == {"source": "component"}
- assert envelope["transaction"] in envelope["request"]["url"]
+ assert envelope_items[0]["type"] == "transaction"
+ assert envelope_items[0]["contexts"]["trace"]["op"] == "function.gcp"
+ assert envelope_items[0]["transaction"].startswith("Google Cloud function")
+ assert envelope_items[0]["transaction_info"] == {"source": "component"}
+ assert envelope_items[0]["transaction"] in envelope_items[0]["request"]["url"]
def test_performance_error(run_cloud_function):
- envelope, event, return_value = run_cloud_function(
+ envelope_items, _ = run_cloud_function(
dedent(
"""
functionhandler = None
@@ -278,20 +268,23 @@ def cloud_function(functionhandler, event):
)
)
- assert envelope["type"] == "transaction"
- assert envelope["contexts"]["trace"]["op"] == "function.gcp"
- assert envelope["transaction"].startswith("Google Cloud function")
- assert envelope["transaction"] in envelope["request"]["url"]
- assert event["level"] == "error"
- (exception,) = event["exception"]["values"]
+ assert envelope_items[0]["level"] == "error"
+ (exception,) = envelope_items[0]["exception"]["values"]
assert exception["type"] == "Exception"
assert exception["value"] == "something went wrong"
- assert exception["mechanism"] == {"type": "gcp", "handled": False}
+ assert exception["mechanism"]["type"] == "gcp"
+ assert not exception["mechanism"]["handled"]
+
+ assert envelope_items[1]["type"] == "transaction"
+ assert envelope_items[1]["contexts"]["trace"]["op"] == "function.gcp"
+ assert envelope_items[1]["transaction"].startswith("Google Cloud function")
+ assert envelope_items[1]["transaction"] in envelope_items[0]["request"]["url"]
def test_traces_sampler_gets_correct_values_in_sampling_context(
- run_cloud_function, DictionaryContaining # noqa:N803
+ run_cloud_function,
+ DictionaryContaining, # noqa:N803
):
# TODO: There are some decent sized hacks below. For more context, see the
# long comment in the test of the same name in the AWS integration. The
@@ -300,7 +293,7 @@ def test_traces_sampler_gets_correct_values_in_sampling_context(
import inspect
- envelopes, events, return_value = run_cloud_function(
+ _, return_value = run_cloud_function(
dedent(
"""
functionhandler = None
@@ -367,3 +360,208 @@ def _safe_is_equal(x, y):
)
assert return_value["AssertionError raised"] is False
+
+
+def test_error_has_new_trace_context_performance_enabled(run_cloud_function):
+ """
+ Check if an 'trace' context is added to errros and transactions when performance monitoring is enabled.
+ """
+ envelope_items, _ = run_cloud_function(
+ dedent(
+ """
+ functionhandler = None
+ event = {}
+ def cloud_function(functionhandler, event):
+ sentry_sdk.capture_message("hi")
+ x = 3/0
+ return "3"
+ """
+ )
+ + FUNCTIONS_PRELUDE
+ + dedent(
+ """
+ init_sdk(traces_sample_rate=1.0)
+ gcp_functions.worker_v1.FunctionHandler.invoke_user_function(functionhandler, event)
+ """
+ )
+ )
+ (msg_event, error_event, transaction_event) = envelope_items
+
+ assert "trace" in msg_event["contexts"]
+ assert "trace_id" in msg_event["contexts"]["trace"]
+
+ assert "trace" in error_event["contexts"]
+ assert "trace_id" in error_event["contexts"]["trace"]
+
+ assert "trace" in transaction_event["contexts"]
+ assert "trace_id" in transaction_event["contexts"]["trace"]
+
+ assert (
+ msg_event["contexts"]["trace"]["trace_id"]
+ == error_event["contexts"]["trace"]["trace_id"]
+ == transaction_event["contexts"]["trace"]["trace_id"]
+ )
+
+
+def test_error_has_new_trace_context_performance_disabled(run_cloud_function):
+ """
+ Check if an 'trace' context is added to errros and transactions when performance monitoring is disabled.
+ """
+ envelope_items, _ = run_cloud_function(
+ dedent(
+ """
+ functionhandler = None
+ event = {}
+ def cloud_function(functionhandler, event):
+ sentry_sdk.capture_message("hi")
+ x = 3/0
+ return "3"
+ """
+ )
+ + FUNCTIONS_PRELUDE
+ + dedent(
+ """
+ init_sdk(traces_sample_rate=None), # this is the default, just added for clarity
+ gcp_functions.worker_v1.FunctionHandler.invoke_user_function(functionhandler, event)
+ """
+ )
+ )
+
+ (msg_event, error_event) = envelope_items
+
+ assert "trace" in msg_event["contexts"]
+ assert "trace_id" in msg_event["contexts"]["trace"]
+
+ assert "trace" in error_event["contexts"]
+ assert "trace_id" in error_event["contexts"]["trace"]
+
+ assert (
+ msg_event["contexts"]["trace"]["trace_id"]
+ == error_event["contexts"]["trace"]["trace_id"]
+ )
+
+
+def test_error_has_existing_trace_context_performance_enabled(run_cloud_function):
+ """
+ Check if an 'trace' context is added to errros and transactions
+ from the incoming 'sentry-trace' header when performance monitoring is enabled.
+ """
+ trace_id = "471a43a4192642f0b136d5159a501701"
+ parent_span_id = "6e8f22c393e68f19"
+ parent_sampled = 1
+ sentry_trace_header = "{}-{}-{}".format(trace_id, parent_span_id, parent_sampled)
+
+ envelope_items, _ = run_cloud_function(
+ dedent(
+ """
+ functionhandler = None
+
+ from collections import namedtuple
+ GCPEvent = namedtuple("GCPEvent", ["headers"])
+ event = GCPEvent(headers={"sentry-trace": "%s"})
+
+ def cloud_function(functionhandler, event):
+ sentry_sdk.capture_message("hi")
+ x = 3/0
+ return "3"
+ """
+ % sentry_trace_header
+ )
+ + FUNCTIONS_PRELUDE
+ + dedent(
+ """
+ init_sdk(traces_sample_rate=1.0)
+ gcp_functions.worker_v1.FunctionHandler.invoke_user_function(functionhandler, event)
+ """
+ )
+ )
+ (msg_event, error_event, transaction_event) = envelope_items
+
+ assert "trace" in msg_event["contexts"]
+ assert "trace_id" in msg_event["contexts"]["trace"]
+
+ assert "trace" in error_event["contexts"]
+ assert "trace_id" in error_event["contexts"]["trace"]
+
+ assert "trace" in transaction_event["contexts"]
+ assert "trace_id" in transaction_event["contexts"]["trace"]
+
+ assert (
+ msg_event["contexts"]["trace"]["trace_id"]
+ == error_event["contexts"]["trace"]["trace_id"]
+ == transaction_event["contexts"]["trace"]["trace_id"]
+ == "471a43a4192642f0b136d5159a501701"
+ )
+
+
+def test_error_has_existing_trace_context_performance_disabled(run_cloud_function):
+ """
+ Check if an 'trace' context is added to errros and transactions
+ from the incoming 'sentry-trace' header when performance monitoring is disabled.
+ """
+ trace_id = "471a43a4192642f0b136d5159a501701"
+ parent_span_id = "6e8f22c393e68f19"
+ parent_sampled = 1
+ sentry_trace_header = "{}-{}-{}".format(trace_id, parent_span_id, parent_sampled)
+
+ envelope_items, _ = run_cloud_function(
+ dedent(
+ """
+ functionhandler = None
+
+ from collections import namedtuple
+ GCPEvent = namedtuple("GCPEvent", ["headers"])
+ event = GCPEvent(headers={"sentry-trace": "%s"})
+
+ def cloud_function(functionhandler, event):
+ sentry_sdk.capture_message("hi")
+ x = 3/0
+ return "3"
+ """
+ % sentry_trace_header
+ )
+ + FUNCTIONS_PRELUDE
+ + dedent(
+ """
+ init_sdk(traces_sample_rate=None), # this is the default, just added for clarity
+ gcp_functions.worker_v1.FunctionHandler.invoke_user_function(functionhandler, event)
+ """
+ )
+ )
+ (msg_event, error_event) = envelope_items
+
+ assert "trace" in msg_event["contexts"]
+ assert "trace_id" in msg_event["contexts"]["trace"]
+
+ assert "trace" in error_event["contexts"]
+ assert "trace_id" in error_event["contexts"]["trace"]
+
+ assert (
+ msg_event["contexts"]["trace"]["trace_id"]
+ == error_event["contexts"]["trace"]["trace_id"]
+ == "471a43a4192642f0b136d5159a501701"
+ )
+
+
+def test_span_origin(run_cloud_function):
+ events, _ = run_cloud_function(
+ dedent(
+ """
+ functionhandler = None
+ event = {}
+ def cloud_function(functionhandler, event):
+ return "test_string"
+ """
+ )
+ + FUNCTIONS_PRELUDE
+ + dedent(
+ """
+ init_sdk(traces_sample_rate=1.0)
+ gcp_functions.worker_v1.FunctionHandler.invoke_user_function(functionhandler, event)
+ """
+ )
+ )
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "auto.function.gcp"
diff --git a/tests/integrations/google_genai/__init__.py b/tests/integrations/google_genai/__init__.py
new file mode 100644
index 0000000000..5143bf4536
--- /dev/null
+++ b/tests/integrations/google_genai/__init__.py
@@ -0,0 +1,4 @@
+import pytest
+
+pytest.importorskip("google")
+pytest.importorskip("google.genai")
diff --git a/tests/integrations/google_genai/test_google_genai.py b/tests/integrations/google_genai/test_google_genai.py
new file mode 100644
index 0000000000..6e91ba6634
--- /dev/null
+++ b/tests/integrations/google_genai/test_google_genai.py
@@ -0,0 +1,2213 @@
+import json
+import pytest
+from unittest import mock
+
+from google import genai
+from google.genai import types as genai_types
+from google.genai.types import Content, Part
+
+from sentry_sdk import start_transaction
+from sentry_sdk._types import BLOB_DATA_SUBSTITUTE
+from sentry_sdk.consts import OP, SPANDATA
+from sentry_sdk.integrations.google_genai import GoogleGenAIIntegration
+from sentry_sdk.integrations.google_genai.utils import extract_contents_messages
+
+
+@pytest.fixture
+def mock_genai_client():
+ """Fixture that creates a real genai.Client with mocked HTTP responses."""
+ client = genai.Client(api_key="test-api-key")
+ return client
+
+
+def create_mock_http_response(response_body):
+ """
+ Create a mock HTTP response that the API client's request() method would return.
+
+ Args:
+ response_body: The JSON body as a string or dict
+
+ Returns:
+ An HttpResponse object with headers and body
+ """
+ if isinstance(response_body, dict):
+ response_body = json.dumps(response_body)
+
+ return genai_types.HttpResponse(
+ headers={
+ "content-type": "application/json; charset=UTF-8",
+ },
+ body=response_body,
+ )
+
+
+def create_mock_streaming_responses(response_chunks):
+ """
+ Create a generator that yields mock HTTP responses for streaming.
+
+ Args:
+ response_chunks: List of dicts, each representing a chunk's JSON body
+
+ Returns:
+ A generator that yields HttpResponse objects
+ """
+ for chunk in response_chunks:
+ yield create_mock_http_response(chunk)
+
+
+# Sample API response JSON (based on real API format from user)
+EXAMPLE_API_RESPONSE_JSON = {
+ "candidates": [
+ {
+ "content": {
+ "role": "model",
+ "parts": [{"text": "Hello! How can I help you today?"}],
+ },
+ "finishReason": "STOP",
+ }
+ ],
+ "usageMetadata": {
+ "promptTokenCount": 10,
+ "candidatesTokenCount": 20,
+ "totalTokenCount": 30,
+ "cachedContentTokenCount": 5,
+ "thoughtsTokenCount": 3,
+ },
+ "modelVersion": "gemini-1.5-flash",
+ "responseId": "response-id-123",
+}
+
+
+def create_test_config(
+ temperature=None,
+ top_p=None,
+ top_k=None,
+ max_output_tokens=None,
+ presence_penalty=None,
+ frequency_penalty=None,
+ seed=None,
+ system_instruction=None,
+ tools=None,
+):
+ """Create a GenerateContentConfig."""
+ config_dict = {}
+
+ if temperature is not None:
+ config_dict["temperature"] = temperature
+ if top_p is not None:
+ config_dict["top_p"] = top_p
+ if top_k is not None:
+ config_dict["top_k"] = top_k
+ if max_output_tokens is not None:
+ config_dict["max_output_tokens"] = max_output_tokens
+ if presence_penalty is not None:
+ config_dict["presence_penalty"] = presence_penalty
+ if frequency_penalty is not None:
+ config_dict["frequency_penalty"] = frequency_penalty
+ if seed is not None:
+ config_dict["seed"] = seed
+ if system_instruction is not None:
+ config_dict["system_instruction"] = system_instruction
+ if tools is not None:
+ config_dict["tools"] = tools
+
+ return genai_types.GenerateContentConfig(**config_dict)
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+def test_nonstreaming_generate_content(
+ sentry_init, capture_events, send_default_pii, include_prompts, mock_genai_client
+):
+ sentry_init(
+ integrations=[GoogleGenAIIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ # Mock the HTTP response at the _api_client.request() level
+ mock_http_response = create_mock_http_response(EXAMPLE_API_RESPONSE_JSON)
+
+ with mock.patch.object(
+ mock_genai_client._api_client,
+ "request",
+ return_value=mock_http_response,
+ ):
+ with start_transaction(name="google_genai"):
+ config = create_test_config(temperature=0.7, max_output_tokens=100)
+ mock_genai_client.models.generate_content(
+ model="gemini-1.5-flash", contents="Tell me a joke", config=config
+ )
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "google_genai"
+
+ assert len(event["spans"]) == 1
+ chat_span = event["spans"][0]
+
+ # Check chat span
+ assert chat_span["op"] == OP.GEN_AI_CHAT
+ assert chat_span["description"] == "chat gemini-1.5-flash"
+ assert chat_span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+ assert chat_span["data"][SPANDATA.GEN_AI_SYSTEM] == "gcp.gemini"
+ assert chat_span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "gemini-1.5-flash"
+
+ if send_default_pii and include_prompts:
+ # Response text is stored as a JSON array
+ response_text = chat_span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT]
+ # Parse the JSON array
+ response_texts = json.loads(response_text)
+ assert response_texts == ["Hello! How can I help you today?"]
+ else:
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in chat_span["data"]
+
+ # Check token usage
+ assert chat_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
+ # Output tokens now include reasoning tokens: candidates_token_count (20) + thoughts_token_count (3) = 23
+ assert chat_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 23
+ assert chat_span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 30
+ assert chat_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 5
+ assert chat_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING] == 3
+
+
+@pytest.mark.parametrize("generate_content_config", (False, True))
+@pytest.mark.parametrize(
+ "system_instructions,expected_texts",
+ [
+ (None, None),
+ ({}, []),
+ (Content(role="system", parts=[]), []),
+ ({"parts": []}, []),
+ ("You are a helpful assistant.", ["You are a helpful assistant."]),
+ (Part(text="You are a helpful assistant."), ["You are a helpful assistant."]),
+ (
+ Content(role="system", parts=[Part(text="You are a helpful assistant.")]),
+ ["You are a helpful assistant."],
+ ),
+ ({"text": "You are a helpful assistant."}, ["You are a helpful assistant."]),
+ (
+ {"parts": [Part(text="You are a helpful assistant.")]},
+ ["You are a helpful assistant."],
+ ),
+ (
+ {"parts": [{"text": "You are a helpful assistant."}]},
+ ["You are a helpful assistant."],
+ ),
+ (["You are a helpful assistant."], ["You are a helpful assistant."]),
+ ([Part(text="You are a helpful assistant.")], ["You are a helpful assistant."]),
+ ([{"text": "You are a helpful assistant."}], ["You are a helpful assistant."]),
+ ],
+)
+def test_generate_content_with_system_instruction(
+ sentry_init,
+ capture_events,
+ mock_genai_client,
+ generate_content_config,
+ system_instructions,
+ expected_texts,
+):
+ sentry_init(
+ integrations=[GoogleGenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ mock_http_response = create_mock_http_response(EXAMPLE_API_RESPONSE_JSON)
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "request", return_value=mock_http_response
+ ):
+ with start_transaction(name="google_genai"):
+ config = {
+ "system_instruction": system_instructions,
+ "temperature": 0.5,
+ }
+
+ if generate_content_config:
+ config = create_test_config(**config)
+
+ mock_genai_client.models.generate_content(
+ model="gemini-1.5-flash",
+ contents="What is 2+2?",
+ config=config,
+ )
+
+ (event,) = events
+ invoke_span = event["spans"][0]
+
+ if expected_texts is None:
+ assert SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS not in invoke_span["data"]
+ return
+
+ # (PII is enabled and include_prompts is True in this test)
+ system_instructions = json.loads(
+ invoke_span["data"][SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS]
+ )
+
+ assert system_instructions == [
+ {"type": "text", "content": text} for text in expected_texts
+ ]
+
+
+def test_generate_content_with_tools(sentry_init, capture_events, mock_genai_client):
+ sentry_init(
+ integrations=[GoogleGenAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ # Create a mock tool function
+ def get_weather(location: str) -> str:
+ """Get the weather for a location"""
+ return f"The weather in {location} is sunny"
+
+ # Create a tool with function declarations using real types
+ function_declaration = genai_types.FunctionDeclaration(
+ name="get_weather_tool",
+ description="Get weather information (tool object)",
+ parameters=genai_types.Schema(
+ type=genai_types.Type.OBJECT,
+ properties={
+ "location": genai_types.Schema(
+ type=genai_types.Type.STRING,
+ description="The location to get weather for",
+ )
+ },
+ required=["location"],
+ ),
+ )
+
+ mock_tool = genai_types.Tool(function_declarations=[function_declaration])
+
+ # API response for tool usage
+ tool_response_json = {
+ "candidates": [
+ {
+ "content": {
+ "role": "model",
+ "parts": [{"text": "I'll check the weather."}],
+ },
+ "finishReason": "STOP",
+ }
+ ],
+ "usageMetadata": {
+ "promptTokenCount": 15,
+ "candidatesTokenCount": 10,
+ "totalTokenCount": 25,
+ },
+ }
+
+ mock_http_response = create_mock_http_response(tool_response_json)
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "request", return_value=mock_http_response
+ ):
+ with start_transaction(name="google_genai"):
+ config = create_test_config(tools=[get_weather, mock_tool])
+ mock_genai_client.models.generate_content(
+ model="gemini-1.5-flash", contents="What's the weather?", config=config
+ )
+
+ (event,) = events
+ invoke_span = event["spans"][0]
+
+ # Check that tools are recorded (data is serialized as a string)
+ tools_data_str = invoke_span["data"][SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS]
+ # Parse the JSON string to verify content
+ tools_data = json.loads(tools_data_str)
+ assert len(tools_data) == 2
+
+ # The order of tools may not be guaranteed, so sort by name and description for comparison
+ sorted_tools = sorted(
+ tools_data, key=lambda t: (t.get("name", ""), t.get("description", ""))
+ )
+
+ # The function tool
+ assert sorted_tools[0]["name"] == "get_weather"
+ assert sorted_tools[0]["description"] == "Get the weather for a location"
+
+ # The FunctionDeclaration tool
+ assert sorted_tools[1]["name"] == "get_weather_tool"
+ assert sorted_tools[1]["description"] == "Get weather information (tool object)"
+
+
+def test_tool_execution(sentry_init, capture_events):
+ sentry_init(
+ integrations=[GoogleGenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ # Create a mock tool function
+ def get_weather(location: str) -> str:
+ """Get the weather for a location"""
+ return f"The weather in {location} is sunny"
+
+ # Create wrapped version of the tool
+ from sentry_sdk.integrations.google_genai.utils import wrapped_tool
+
+ wrapped_weather = wrapped_tool(get_weather)
+
+ # Execute the wrapped tool
+ with start_transaction(name="test_tool"):
+ result = wrapped_weather("San Francisco")
+
+ assert result == "The weather in San Francisco is sunny"
+
+ (event,) = events
+ assert len(event["spans"]) == 1
+ tool_span = event["spans"][0]
+
+ assert tool_span["op"] == OP.GEN_AI_EXECUTE_TOOL
+ assert tool_span["description"] == "execute_tool get_weather"
+ assert tool_span["data"][SPANDATA.GEN_AI_TOOL_NAME] == "get_weather"
+ assert (
+ tool_span["data"][SPANDATA.GEN_AI_TOOL_DESCRIPTION]
+ == "Get the weather for a location"
+ )
+
+
+def test_error_handling(sentry_init, capture_events, mock_genai_client):
+ sentry_init(
+ integrations=[GoogleGenAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ # Mock an error at the HTTP level
+ with mock.patch.object(
+ mock_genai_client._api_client, "request", side_effect=Exception("API Error")
+ ):
+ with start_transaction(name="google_genai"):
+ with pytest.raises(Exception, match="API Error"):
+ mock_genai_client.models.generate_content(
+ model="gemini-1.5-flash",
+ contents="This will fail",
+ config=create_test_config(),
+ )
+
+ # Should have both transaction and error events
+ assert len(events) == 2
+ error_event, transaction_event = events
+
+ assert error_event["level"] == "error"
+ assert error_event["exception"]["values"][0]["type"] == "Exception"
+ assert error_event["exception"]["values"][0]["value"] == "API Error"
+ assert error_event["exception"]["values"][0]["mechanism"]["type"] == "google_genai"
+
+
+def test_streaming_generate_content(sentry_init, capture_events, mock_genai_client):
+ """Test streaming with generate_content_stream, verifying chunk accumulation."""
+ sentry_init(
+ integrations=[GoogleGenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ # Create streaming chunks - simulating a multi-chunk response
+ # Chunk 1: First part of text with partial usage metadata
+ chunk1_json = {
+ "candidates": [
+ {
+ "content": {
+ "role": "model",
+ "parts": [{"text": "Hello! "}],
+ },
+ # No finishReason in intermediate chunks
+ }
+ ],
+ "usageMetadata": {
+ "promptTokenCount": 10,
+ "candidatesTokenCount": 2,
+ "totalTokenCount": 12,
+ },
+ "responseId": "response-id-stream-123",
+ "modelVersion": "gemini-1.5-flash",
+ }
+
+ # Chunk 2: Second part of text with intermediate usage metadata
+ chunk2_json = {
+ "candidates": [
+ {
+ "content": {
+ "role": "model",
+ "parts": [{"text": "How can I "}],
+ },
+ }
+ ],
+ "usageMetadata": {
+ "promptTokenCount": 10,
+ "candidatesTokenCount": 3,
+ "totalTokenCount": 13,
+ },
+ }
+
+ # Chunk 3: Final part with finish reason and complete usage metadata
+ chunk3_json = {
+ "candidates": [
+ {
+ "content": {
+ "role": "model",
+ "parts": [{"text": "help you today?"}],
+ },
+ "finishReason": "STOP",
+ }
+ ],
+ "usageMetadata": {
+ "promptTokenCount": 10,
+ "candidatesTokenCount": 7,
+ "totalTokenCount": 25,
+ "cachedContentTokenCount": 5,
+ "thoughtsTokenCount": 3,
+ },
+ }
+
+ # Create streaming mock responses
+ stream_chunks = [chunk1_json, chunk2_json, chunk3_json]
+ mock_stream = create_mock_streaming_responses(stream_chunks)
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "request_streamed", return_value=mock_stream
+ ):
+ with start_transaction(name="google_genai"):
+ config = create_test_config()
+ stream = mock_genai_client.models.generate_content_stream(
+ model="gemini-1.5-flash", contents="Stream me a response", config=config
+ )
+
+ # Consume the stream (this is what users do with the integration wrapper)
+ collected_chunks = list(stream)
+
+ # Verify we got all chunks
+ assert len(collected_chunks) == 3
+ assert collected_chunks[0].candidates[0].content.parts[0].text == "Hello! "
+ assert collected_chunks[1].candidates[0].content.parts[0].text == "How can I "
+ assert collected_chunks[2].candidates[0].content.parts[0].text == "help you today?"
+
+ (event,) = events
+
+ assert len(event["spans"]) == 1
+ chat_span = event["spans"][0]
+
+ # Check that streaming flag is set on both spans
+ assert chat_span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
+
+ # Verify accumulated response text (all chunks combined)
+ expected_full_text = "Hello! How can I help you today?"
+ # Response text is stored as a JSON string
+ chat_response_text = json.loads(chat_span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT])
+ assert chat_response_text == [expected_full_text]
+
+ # Verify finish reasons (only the final chunk has a finish reason)
+ # When there's a single finish reason, it's stored as a plain string (not JSON)
+ assert SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS in chat_span["data"]
+ assert chat_span["data"][SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS] == "STOP"
+ assert chat_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
+ assert chat_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10
+ assert chat_span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 25
+ assert chat_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 5
+ assert chat_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING] == 3
+
+ # Verify model name
+ assert chat_span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "gemini-1.5-flash"
+
+
+def test_span_origin(sentry_init, capture_events, mock_genai_client):
+ sentry_init(
+ integrations=[GoogleGenAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ mock_http_response = create_mock_http_response(EXAMPLE_API_RESPONSE_JSON)
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "request", return_value=mock_http_response
+ ):
+ with start_transaction(name="google_genai"):
+ config = create_test_config()
+ mock_genai_client.models.generate_content(
+ model="gemini-1.5-flash", contents="Test origin", config=config
+ )
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+ for span in event["spans"]:
+ assert span["origin"] == "auto.ai.google_genai"
+
+
+def test_response_without_usage_metadata(
+ sentry_init, capture_events, mock_genai_client
+):
+ """Test handling of responses without usage metadata"""
+ sentry_init(
+ integrations=[GoogleGenAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ # Response without usage metadata
+ response_json = {
+ "candidates": [
+ {
+ "content": {
+ "role": "model",
+ "parts": [{"text": "No usage data"}],
+ },
+ "finishReason": "STOP",
+ }
+ ],
+ }
+
+ mock_http_response = create_mock_http_response(response_json)
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "request", return_value=mock_http_response
+ ):
+ with start_transaction(name="google_genai"):
+ config = create_test_config()
+ mock_genai_client.models.generate_content(
+ model="gemini-1.5-flash", contents="Test", config=config
+ )
+
+ (event,) = events
+ chat_span = event["spans"][0]
+
+ # Usage data should not be present
+ assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS not in chat_span["data"]
+ assert SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS not in chat_span["data"]
+ assert SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS not in chat_span["data"]
+
+
+def test_multiple_candidates(sentry_init, capture_events, mock_genai_client):
+ """Test handling of multiple response candidates"""
+ sentry_init(
+ integrations=[GoogleGenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ # Response with multiple candidates
+ multi_candidate_json = {
+ "candidates": [
+ {
+ "content": {
+ "role": "model",
+ "parts": [{"text": "Response 1"}],
+ },
+ "finishReason": "STOP",
+ },
+ {
+ "content": {
+ "role": "model",
+ "parts": [{"text": "Response 2"}],
+ },
+ "finishReason": "MAX_TOKENS",
+ },
+ ],
+ "usageMetadata": {
+ "promptTokenCount": 5,
+ "candidatesTokenCount": 15,
+ "totalTokenCount": 20,
+ },
+ }
+
+ mock_http_response = create_mock_http_response(multi_candidate_json)
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "request", return_value=mock_http_response
+ ):
+ with start_transaction(name="google_genai"):
+ config = create_test_config()
+ mock_genai_client.models.generate_content(
+ model="gemini-1.5-flash", contents="Generate multiple", config=config
+ )
+
+ (event,) = events
+ chat_span = event["spans"][0]
+
+ # Should capture all responses
+ # Response text is stored as a JSON string when there are multiple responses
+ response_text = chat_span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT]
+ if isinstance(response_text, str) and response_text.startswith("["):
+ # It's a JSON array
+ response_list = json.loads(response_text)
+ assert response_list == ["Response 1", "Response 2"]
+ else:
+ # It's concatenated
+ assert response_text == "Response 1\nResponse 2"
+
+ # Finish reasons are serialized as JSON
+ finish_reasons = json.loads(
+ chat_span["data"][SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS]
+ )
+ assert finish_reasons == ["STOP", "MAX_TOKENS"]
+
+
+def test_all_configuration_parameters(sentry_init, capture_events, mock_genai_client):
+ """Test that all configuration parameters are properly recorded"""
+ sentry_init(
+ integrations=[GoogleGenAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ mock_http_response = create_mock_http_response(EXAMPLE_API_RESPONSE_JSON)
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "request", return_value=mock_http_response
+ ):
+ with start_transaction(name="google_genai"):
+ config = create_test_config(
+ temperature=0.8,
+ top_p=0.95,
+ top_k=40,
+ max_output_tokens=2048,
+ presence_penalty=0.1,
+ frequency_penalty=0.2,
+ seed=12345,
+ )
+ mock_genai_client.models.generate_content(
+ model="gemini-1.5-flash", contents="Test all params", config=config
+ )
+
+ (event,) = events
+ invoke_span = event["spans"][0]
+
+ # Check all parameters are recorded
+ assert invoke_span["data"][SPANDATA.GEN_AI_REQUEST_TEMPERATURE] == 0.8
+ assert invoke_span["data"][SPANDATA.GEN_AI_REQUEST_TOP_P] == 0.95
+ assert invoke_span["data"][SPANDATA.GEN_AI_REQUEST_TOP_K] == 40
+ assert invoke_span["data"][SPANDATA.GEN_AI_REQUEST_MAX_TOKENS] == 2048
+ assert invoke_span["data"][SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY] == 0.1
+ assert invoke_span["data"][SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY] == 0.2
+ assert invoke_span["data"][SPANDATA.GEN_AI_REQUEST_SEED] == 12345
+
+
+def test_empty_response(sentry_init, capture_events, mock_genai_client):
+ """Test handling of minimal response with no content"""
+ sentry_init(
+ integrations=[GoogleGenAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ # Minimal response with empty candidates array
+ minimal_response_json = {"candidates": []}
+ mock_http_response = create_mock_http_response(minimal_response_json)
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "request", return_value=mock_http_response
+ ):
+ with start_transaction(name="google_genai"):
+ response = mock_genai_client.models.generate_content(
+ model="gemini-1.5-flash", contents="Test", config=create_test_config()
+ )
+
+ # Response will have an empty candidates list
+ assert response is not None
+ assert len(response.candidates) == 0
+
+ (event,) = events
+ # Should still create spans even with empty candidates
+ assert len(event["spans"]) == 1
+
+
+def test_response_with_different_id_fields(
+ sentry_init, capture_events, mock_genai_client
+):
+ """Test handling of different response ID field names"""
+ sentry_init(
+ integrations=[GoogleGenAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ # Response with response_id and model_version
+ response_json = {
+ "candidates": [
+ {
+ "content": {
+ "role": "model",
+ "parts": [{"text": "Test"}],
+ },
+ "finishReason": "STOP",
+ }
+ ],
+ "responseId": "resp-456",
+ "modelVersion": "gemini-1.5-flash-001",
+ }
+
+ mock_http_response = create_mock_http_response(response_json)
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "request", return_value=mock_http_response
+ ):
+ with start_transaction(name="google_genai"):
+ mock_genai_client.models.generate_content(
+ model="gemini-1.5-flash", contents="Test", config=create_test_config()
+ )
+
+ (event,) = events
+ chat_span = event["spans"][0]
+
+ assert chat_span["data"][SPANDATA.GEN_AI_RESPONSE_ID] == "resp-456"
+ assert chat_span["data"][SPANDATA.GEN_AI_RESPONSE_MODEL] == "gemini-1.5-flash-001"
+
+
+def test_tool_with_async_function(sentry_init, capture_events):
+ """Test that async tool functions are properly wrapped"""
+ sentry_init(
+ integrations=[GoogleGenAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+ capture_events()
+
+ # Create an async tool function
+ async def async_tool(param: str) -> str:
+ """An async tool"""
+ return f"Async result: {param}"
+
+ # Import is skipped in sync tests, but we can test the wrapping logic
+ from sentry_sdk.integrations.google_genai.utils import wrapped_tool
+
+ # The wrapper should handle async functions
+ wrapped_async_tool = wrapped_tool(async_tool)
+ assert wrapped_async_tool != async_tool # Should be wrapped
+ assert hasattr(wrapped_async_tool, "__wrapped__") # Should preserve original
+
+
+def test_contents_as_none(sentry_init, capture_events, mock_genai_client):
+ """Test handling when contents parameter is None"""
+ sentry_init(
+ integrations=[GoogleGenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ mock_http_response = create_mock_http_response(EXAMPLE_API_RESPONSE_JSON)
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "request", return_value=mock_http_response
+ ):
+ with start_transaction(name="google_genai"):
+ mock_genai_client.models.generate_content(
+ model="gemini-1.5-flash", contents=None, config=create_test_config()
+ )
+
+ (event,) = events
+ invoke_span = event["spans"][0]
+
+ # Should handle None contents gracefully
+ messages = invoke_span["data"].get(SPANDATA.GEN_AI_REQUEST_MESSAGES, [])
+ # Should only have system message if any, not user message
+ assert all(msg["role"] != "user" or msg["content"] is not None for msg in messages)
+
+
+def test_tool_calls_extraction(sentry_init, capture_events, mock_genai_client):
+ """Test extraction of tool/function calls from response"""
+ sentry_init(
+ integrations=[GoogleGenAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ # Response with function calls
+ function_call_response_json = {
+ "candidates": [
+ {
+ "content": {
+ "role": "model",
+ "parts": [
+ {"text": "I'll help you with that."},
+ {
+ "functionCall": {
+ "name": "get_weather",
+ "args": {
+ "location": "San Francisco",
+ "unit": "celsius",
+ },
+ }
+ },
+ {
+ "functionCall": {
+ "name": "get_time",
+ "args": {"timezone": "PST"},
+ }
+ },
+ ],
+ },
+ "finishReason": "STOP",
+ }
+ ],
+ "usageMetadata": {
+ "promptTokenCount": 20,
+ "candidatesTokenCount": 30,
+ "totalTokenCount": 50,
+ },
+ }
+
+ mock_http_response = create_mock_http_response(function_call_response_json)
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "request", return_value=mock_http_response
+ ):
+ with start_transaction(name="google_genai"):
+ mock_genai_client.models.generate_content(
+ model="gemini-1.5-flash",
+ contents="What's the weather and time?",
+ config=create_test_config(),
+ )
+
+ (event,) = events
+ chat_span = event["spans"][0] # The chat span
+
+ # Check that tool calls are extracted and stored
+ assert SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS in chat_span["data"]
+
+ # Parse the JSON string to verify content
+ tool_calls = json.loads(chat_span["data"][SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS])
+
+ assert len(tool_calls) == 2
+
+ # First tool call
+ assert tool_calls[0]["name"] == "get_weather"
+ assert tool_calls[0]["type"] == "function_call"
+ # Arguments are serialized as JSON strings
+ assert json.loads(tool_calls[0]["arguments"]) == {
+ "location": "San Francisco",
+ "unit": "celsius",
+ }
+
+ # Second tool call
+ assert tool_calls[1]["name"] == "get_time"
+ assert tool_calls[1]["type"] == "function_call"
+ # Arguments are serialized as JSON strings
+ assert json.loads(tool_calls[1]["arguments"]) == {"timezone": "PST"}
+
+
+def test_google_genai_message_truncation(
+ sentry_init, capture_events, mock_genai_client
+):
+ """Test that large messages are truncated properly in Google GenAI integration."""
+ sentry_init(
+ integrations=[GoogleGenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ large_content = (
+ "This is a very long message that will exceed our size limits. " * 1000
+ )
+ small_content = "This is a small user message"
+
+ mock_http_response = create_mock_http_response(EXAMPLE_API_RESPONSE_JSON)
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "request", return_value=mock_http_response
+ ):
+ with start_transaction(name="google_genai"):
+ mock_genai_client.models.generate_content(
+ model="gemini-1.5-flash",
+ contents=[large_content, small_content],
+ config=create_test_config(),
+ )
+
+ (event,) = events
+ invoke_span = event["spans"][0]
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES in invoke_span["data"]
+
+ messages_data = invoke_span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ assert isinstance(messages_data, str)
+
+ parsed_messages = json.loads(messages_data)
+ assert isinstance(parsed_messages, list)
+ assert len(parsed_messages) == 1
+ assert parsed_messages[0]["role"] == "user"
+
+ # What "small content" becomes because the large message used the entire character limit
+ assert "..." in parsed_messages[0]["content"][1]["text"]
+
+
+# Sample embed content API response JSON
+EXAMPLE_EMBED_RESPONSE_JSON = {
+ "embeddings": [
+ {
+ "values": [0.1, 0.2, 0.3, 0.4, 0.5], # Simplified embedding vector
+ "statistics": {
+ "tokenCount": 10,
+ "truncated": False,
+ },
+ },
+ {
+ "values": [0.2, 0.3, 0.4, 0.5, 0.6],
+ "statistics": {
+ "tokenCount": 15,
+ "truncated": False,
+ },
+ },
+ ],
+ "metadata": {
+ "billableCharacterCount": 42,
+ },
+}
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+def test_embed_content(
+ sentry_init, capture_events, send_default_pii, include_prompts, mock_genai_client
+):
+ sentry_init(
+ integrations=[GoogleGenAIIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ # Mock the HTTP response at the _api_client.request() level
+ mock_http_response = create_mock_http_response(EXAMPLE_EMBED_RESPONSE_JSON)
+
+ with mock.patch.object(
+ mock_genai_client._api_client,
+ "request",
+ return_value=mock_http_response,
+ ):
+ with start_transaction(name="google_genai_embeddings"):
+ mock_genai_client.models.embed_content(
+ model="text-embedding-004",
+ contents=[
+ "What is your name?",
+ "What is your favorite color?",
+ ],
+ )
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "google_genai_embeddings"
+
+ # Should have 1 span for embeddings
+ assert len(event["spans"]) == 1
+ (embed_span,) = event["spans"]
+
+ # Check embeddings span
+ assert embed_span["op"] == OP.GEN_AI_EMBEDDINGS
+ assert embed_span["description"] == "embeddings text-embedding-004"
+ assert embed_span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "embeddings"
+ assert embed_span["data"][SPANDATA.GEN_AI_SYSTEM] == "gcp.gemini"
+ assert embed_span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "text-embedding-004"
+
+ # Check input texts if PII is allowed
+ if send_default_pii and include_prompts:
+ input_texts = json.loads(embed_span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT])
+ assert input_texts == [
+ "What is your name?",
+ "What is your favorite color?",
+ ]
+ else:
+ assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT not in embed_span["data"]
+
+ # Check usage data (sum of token counts from statistics: 10 + 15 = 25)
+ # Note: Only available in newer versions with ContentEmbeddingStatistics
+ if SPANDATA.GEN_AI_USAGE_INPUT_TOKENS in embed_span["data"]:
+ assert embed_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 25
+
+
+def test_embed_content_string_input(sentry_init, capture_events, mock_genai_client):
+ """Test embed_content with a single string instead of list."""
+ sentry_init(
+ integrations=[GoogleGenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ # Mock response with single embedding
+ single_embed_response = {
+ "embeddings": [
+ {
+ "values": [0.1, 0.2, 0.3],
+ "statistics": {
+ "tokenCount": 5,
+ "truncated": False,
+ },
+ },
+ ],
+ "metadata": {
+ "billableCharacterCount": 10,
+ },
+ }
+ mock_http_response = create_mock_http_response(single_embed_response)
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "request", return_value=mock_http_response
+ ):
+ with start_transaction(name="google_genai_embeddings"):
+ mock_genai_client.models.embed_content(
+ model="text-embedding-004",
+ contents="Single text input",
+ )
+
+ (event,) = events
+ (embed_span,) = event["spans"]
+
+ # Check that single string is handled correctly
+ input_texts = json.loads(embed_span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT])
+ assert input_texts == ["Single text input"]
+ # Should use token_count from statistics (5), not billable_character_count (10)
+ # Note: Only available in newer versions with ContentEmbeddingStatistics
+ if SPANDATA.GEN_AI_USAGE_INPUT_TOKENS in embed_span["data"]:
+ assert embed_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 5
+
+
+def test_embed_content_error_handling(sentry_init, capture_events, mock_genai_client):
+ """Test error handling in embed_content."""
+ sentry_init(
+ integrations=[GoogleGenAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ # Mock an error at the HTTP level
+ with mock.patch.object(
+ mock_genai_client._api_client,
+ "request",
+ side_effect=Exception("Embedding API Error"),
+ ):
+ with start_transaction(name="google_genai_embeddings"):
+ with pytest.raises(Exception, match="Embedding API Error"):
+ mock_genai_client.models.embed_content(
+ model="text-embedding-004",
+ contents=["This will fail"],
+ )
+
+ # Should have both transaction and error events
+ assert len(events) == 2
+ error_event, _ = events
+
+ assert error_event["level"] == "error"
+ assert error_event["exception"]["values"][0]["type"] == "Exception"
+ assert error_event["exception"]["values"][0]["value"] == "Embedding API Error"
+ assert error_event["exception"]["values"][0]["mechanism"]["type"] == "google_genai"
+
+
+def test_embed_content_without_statistics(
+ sentry_init, capture_events, mock_genai_client
+):
+ """Test embed_content response without statistics (older package versions)."""
+ sentry_init(
+ integrations=[GoogleGenAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ # Response without statistics (typical for older google-genai versions)
+ # Embeddings exist but don't have the statistics field
+ old_version_response = {
+ "embeddings": [
+ {
+ "values": [0.1, 0.2, 0.3],
+ },
+ {
+ "values": [0.2, 0.3, 0.4],
+ },
+ ],
+ }
+ mock_http_response = create_mock_http_response(old_version_response)
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "request", return_value=mock_http_response
+ ):
+ with start_transaction(name="google_genai_embeddings"):
+ mock_genai_client.models.embed_content(
+ model="text-embedding-004",
+ contents=["Test without statistics", "Another test"],
+ )
+
+ (event,) = events
+ (embed_span,) = event["spans"]
+
+ # No usage tokens since there are no statistics in older versions
+ # This is expected and the integration should handle it gracefully
+ assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS not in embed_span["data"]
+
+
+def test_embed_content_span_origin(sentry_init, capture_events, mock_genai_client):
+ """Test that embed_content spans have correct origin."""
+ sentry_init(
+ integrations=[GoogleGenAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ mock_http_response = create_mock_http_response(EXAMPLE_EMBED_RESPONSE_JSON)
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "request", return_value=mock_http_response
+ ):
+ with start_transaction(name="google_genai_embeddings"):
+ mock_genai_client.models.embed_content(
+ model="text-embedding-004",
+ contents=["Test origin"],
+ )
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+ for span in event["spans"]:
+ assert span["origin"] == "auto.ai.google_genai"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+async def test_async_embed_content(
+ sentry_init, capture_events, send_default_pii, include_prompts, mock_genai_client
+):
+ """Test async embed_content method."""
+ sentry_init(
+ integrations=[GoogleGenAIIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ # Mock the async HTTP response
+ mock_http_response = create_mock_http_response(EXAMPLE_EMBED_RESPONSE_JSON)
+
+ with mock.patch.object(
+ mock_genai_client._api_client,
+ "async_request",
+ return_value=mock_http_response,
+ ):
+ with start_transaction(name="google_genai_embeddings_async"):
+ await mock_genai_client.aio.models.embed_content(
+ model="text-embedding-004",
+ contents=[
+ "What is your name?",
+ "What is your favorite color?",
+ ],
+ )
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "google_genai_embeddings_async"
+
+ # Should have 1 span for embeddings
+ assert len(event["spans"]) == 1
+ (embed_span,) = event["spans"]
+
+ # Check embeddings span
+ assert embed_span["op"] == OP.GEN_AI_EMBEDDINGS
+ assert embed_span["description"] == "embeddings text-embedding-004"
+ assert embed_span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "embeddings"
+ assert embed_span["data"][SPANDATA.GEN_AI_SYSTEM] == "gcp.gemini"
+ assert embed_span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "text-embedding-004"
+
+ # Check input texts if PII is allowed
+ if send_default_pii and include_prompts:
+ input_texts = json.loads(embed_span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT])
+ assert input_texts == [
+ "What is your name?",
+ "What is your favorite color?",
+ ]
+ else:
+ assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT not in embed_span["data"]
+
+ # Check usage data (sum of token counts from statistics: 10 + 15 = 25)
+ # Note: Only available in newer versions with ContentEmbeddingStatistics
+ if SPANDATA.GEN_AI_USAGE_INPUT_TOKENS in embed_span["data"]:
+ assert embed_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 25
+
+
+@pytest.mark.asyncio
+async def test_async_embed_content_string_input(
+ sentry_init, capture_events, mock_genai_client
+):
+ """Test async embed_content with a single string instead of list."""
+ sentry_init(
+ integrations=[GoogleGenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ # Mock response with single embedding
+ single_embed_response = {
+ "embeddings": [
+ {
+ "values": [0.1, 0.2, 0.3],
+ "statistics": {
+ "tokenCount": 5,
+ "truncated": False,
+ },
+ },
+ ],
+ "metadata": {
+ "billableCharacterCount": 10,
+ },
+ }
+ mock_http_response = create_mock_http_response(single_embed_response)
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "async_request", return_value=mock_http_response
+ ):
+ with start_transaction(name="google_genai_embeddings_async"):
+ await mock_genai_client.aio.models.embed_content(
+ model="text-embedding-004",
+ contents="Single text input",
+ )
+
+ (event,) = events
+ (embed_span,) = event["spans"]
+
+ # Check that single string is handled correctly
+ input_texts = json.loads(embed_span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT])
+ assert input_texts == ["Single text input"]
+ # Should use token_count from statistics (5), not billable_character_count (10)
+ # Note: Only available in newer versions with ContentEmbeddingStatistics
+ if SPANDATA.GEN_AI_USAGE_INPUT_TOKENS in embed_span["data"]:
+ assert embed_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 5
+
+
+@pytest.mark.asyncio
+async def test_async_embed_content_error_handling(
+ sentry_init, capture_events, mock_genai_client
+):
+ """Test error handling in async embed_content."""
+ sentry_init(
+ integrations=[GoogleGenAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ # Mock an error at the HTTP level
+ with mock.patch.object(
+ mock_genai_client._api_client,
+ "async_request",
+ side_effect=Exception("Async Embedding API Error"),
+ ):
+ with start_transaction(name="google_genai_embeddings_async"):
+ with pytest.raises(Exception, match="Async Embedding API Error"):
+ await mock_genai_client.aio.models.embed_content(
+ model="text-embedding-004",
+ contents=["This will fail"],
+ )
+
+ # Should have both transaction and error events
+ assert len(events) == 2
+ error_event, _ = events
+
+ assert error_event["level"] == "error"
+ assert error_event["exception"]["values"][0]["type"] == "Exception"
+ assert error_event["exception"]["values"][0]["value"] == "Async Embedding API Error"
+ assert error_event["exception"]["values"][0]["mechanism"]["type"] == "google_genai"
+
+
+@pytest.mark.asyncio
+async def test_async_embed_content_without_statistics(
+ sentry_init, capture_events, mock_genai_client
+):
+ """Test async embed_content response without statistics (older package versions)."""
+ sentry_init(
+ integrations=[GoogleGenAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ # Response without statistics (typical for older google-genai versions)
+ # Embeddings exist but don't have the statistics field
+ old_version_response = {
+ "embeddings": [
+ {
+ "values": [0.1, 0.2, 0.3],
+ },
+ {
+ "values": [0.2, 0.3, 0.4],
+ },
+ ],
+ }
+ mock_http_response = create_mock_http_response(old_version_response)
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "async_request", return_value=mock_http_response
+ ):
+ with start_transaction(name="google_genai_embeddings_async"):
+ await mock_genai_client.aio.models.embed_content(
+ model="text-embedding-004",
+ contents=["Test without statistics", "Another test"],
+ )
+
+ (event,) = events
+ (embed_span,) = event["spans"]
+
+ # No usage tokens since there are no statistics in older versions
+ # This is expected and the integration should handle it gracefully
+ assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS not in embed_span["data"]
+
+
+@pytest.mark.asyncio
+async def test_async_embed_content_span_origin(
+ sentry_init, capture_events, mock_genai_client
+):
+ """Test that async embed_content spans have correct origin."""
+ sentry_init(
+ integrations=[GoogleGenAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ mock_http_response = create_mock_http_response(EXAMPLE_EMBED_RESPONSE_JSON)
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "async_request", return_value=mock_http_response
+ ):
+ with start_transaction(name="google_genai_embeddings_async"):
+ await mock_genai_client.aio.models.embed_content(
+ model="text-embedding-004",
+ contents=["Test origin"],
+ )
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+ for span in event["spans"]:
+ assert span["origin"] == "auto.ai.google_genai"
+
+
+# Integration tests for generate_content with different input message formats
+def test_generate_content_with_content_object(
+ sentry_init, capture_events, mock_genai_client
+):
+ """Test generate_content with Content object input."""
+ sentry_init(
+ integrations=[GoogleGenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ mock_http_response = create_mock_http_response(EXAMPLE_API_RESPONSE_JSON)
+
+ # Create Content object
+ content = genai_types.Content(
+ role="user", parts=[genai_types.Part(text="Hello from Content object")]
+ )
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "request", return_value=mock_http_response
+ ):
+ with start_transaction(name="google_genai"):
+ mock_genai_client.models.generate_content(
+ model="gemini-1.5-flash", contents=content, config=create_test_config()
+ )
+
+ (event,) = events
+ invoke_span = event["spans"][0]
+
+ messages = json.loads(invoke_span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+ assert len(messages) == 1
+ assert messages[0]["role"] == "user"
+ assert messages[0]["content"] == [
+ {"text": "Hello from Content object", "type": "text"}
+ ]
+
+
+def test_generate_content_with_dict_format(
+ sentry_init, capture_events, mock_genai_client
+):
+ """Test generate_content with dict format input (ContentDict)."""
+ sentry_init(
+ integrations=[GoogleGenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ mock_http_response = create_mock_http_response(EXAMPLE_API_RESPONSE_JSON)
+
+ # Dict format content
+ contents = {"role": "user", "parts": [{"text": "Hello from dict format"}]}
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "request", return_value=mock_http_response
+ ):
+ with start_transaction(name="google_genai"):
+ mock_genai_client.models.generate_content(
+ model="gemini-1.5-flash", contents=contents, config=create_test_config()
+ )
+
+ (event,) = events
+ invoke_span = event["spans"][0]
+
+ messages = json.loads(invoke_span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+ assert len(messages) == 1
+ assert messages[0]["role"] == "user"
+ assert messages[0]["content"] == [
+ {"text": "Hello from dict format", "type": "text"}
+ ]
+
+
+def test_generate_content_with_file_data(
+ sentry_init, capture_events, mock_genai_client
+):
+ """Test generate_content with file_data (external file reference)."""
+ sentry_init(
+ integrations=[GoogleGenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ mock_http_response = create_mock_http_response(EXAMPLE_API_RESPONSE_JSON)
+
+ # Content with file_data
+ file_data = genai_types.FileData(
+ file_uri="gs://bucket/image.jpg", mime_type="image/jpeg"
+ )
+ content = genai_types.Content(
+ role="user",
+ parts=[
+ genai_types.Part(text="What's in this image?"),
+ genai_types.Part(file_data=file_data),
+ ],
+ )
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "request", return_value=mock_http_response
+ ):
+ with start_transaction(name="google_genai"):
+ mock_genai_client.models.generate_content(
+ model="gemini-1.5-flash", contents=content, config=create_test_config()
+ )
+
+ (event,) = events
+ invoke_span = event["spans"][0]
+
+ messages = json.loads(invoke_span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+ assert len(messages) == 1
+ assert messages[0]["role"] == "user"
+ assert len(messages[0]["content"]) == 2
+ assert messages[0]["content"][0] == {
+ "text": "What's in this image?",
+ "type": "text",
+ }
+ assert messages[0]["content"][1]["type"] == "uri"
+ assert messages[0]["content"][1]["modality"] == "image"
+ assert messages[0]["content"][1]["mime_type"] == "image/jpeg"
+ assert messages[0]["content"][1]["uri"] == "gs://bucket/image.jpg"
+
+
+def test_generate_content_with_inline_data(
+ sentry_init, capture_events, mock_genai_client
+):
+ """Test generate_content with inline_data (binary data)."""
+ sentry_init(
+ integrations=[GoogleGenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ mock_http_response = create_mock_http_response(EXAMPLE_API_RESPONSE_JSON)
+
+ # Content with inline binary data
+ image_bytes = b"fake_image_binary_data"
+ blob = genai_types.Blob(data=image_bytes, mime_type="image/png")
+ content = genai_types.Content(
+ role="user",
+ parts=[
+ genai_types.Part(text="Describe this image"),
+ genai_types.Part(inline_data=blob),
+ ],
+ )
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "request", return_value=mock_http_response
+ ):
+ with start_transaction(name="google_genai"):
+ mock_genai_client.models.generate_content(
+ model="gemini-1.5-flash", contents=content, config=create_test_config()
+ )
+
+ (event,) = events
+ invoke_span = event["spans"][0]
+
+ messages = json.loads(invoke_span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+ assert len(messages) == 1
+ assert messages[0]["role"] == "user"
+ assert len(messages[0]["content"]) == 2
+ assert messages[0]["content"][0] == {"text": "Describe this image", "type": "text"}
+ assert messages[0]["content"][1]["type"] == "blob"
+ assert messages[0]["content"][1]["mime_type"] == "image/png"
+ # Binary data should be substituted for privacy
+ assert messages[0]["content"][1]["content"] == BLOB_DATA_SUBSTITUTE
+
+
+def test_generate_content_with_function_response(
+ sentry_init, capture_events, mock_genai_client
+):
+ """Test generate_content with function_response (tool result)."""
+ sentry_init(
+ integrations=[GoogleGenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ mock_http_response = create_mock_http_response(EXAMPLE_API_RESPONSE_JSON)
+
+ # Conversation with the function call from the model
+ function_call = genai_types.FunctionCall(
+ name="get_weather",
+ args={"location": "Paris"},
+ )
+
+ # Conversation with function response (tool result)
+ function_response = genai_types.FunctionResponse(
+ id="call_123", name="get_weather", response={"output": "Sunny, 72F"}
+ )
+ contents = [
+ genai_types.Content(
+ role="user", parts=[genai_types.Part(text="What's the weather in Paris?")]
+ ),
+ genai_types.Content(
+ role="model", parts=[genai_types.Part(function_call=function_call)]
+ ),
+ genai_types.Content(
+ role="user", parts=[genai_types.Part(function_response=function_response)]
+ ),
+ ]
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "request", return_value=mock_http_response
+ ):
+ with start_transaction(name="google_genai"):
+ mock_genai_client.models.generate_content(
+ model="gemini-1.5-flash", contents=contents, config=create_test_config()
+ )
+
+ (event,) = events
+ invoke_span = event["spans"][0]
+
+ messages = json.loads(invoke_span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+ assert len(messages) == 1
+ # First message is user message
+ assert messages[0]["role"] == "tool"
+ assert messages[0]["content"]["toolCallId"] == "call_123"
+ assert messages[0]["content"]["toolName"] == "get_weather"
+ assert messages[0]["content"]["output"] == "Sunny, 72F"
+
+
+def test_generate_content_with_mixed_string_and_content(
+ sentry_init, capture_events, mock_genai_client
+):
+ """Test generate_content with mixed string and Content objects in list."""
+ sentry_init(
+ integrations=[GoogleGenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ mock_http_response = create_mock_http_response(EXAMPLE_API_RESPONSE_JSON)
+
+ # Mix of strings and Content objects
+ contents = [
+ "Hello, this is a string message",
+ genai_types.Content(
+ role="model",
+ parts=[genai_types.Part(text="Hi! How can I help you?")],
+ ),
+ genai_types.Content(
+ role="user",
+ parts=[genai_types.Part(text="Tell me a joke")],
+ ),
+ ]
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "request", return_value=mock_http_response
+ ):
+ with start_transaction(name="google_genai"):
+ mock_genai_client.models.generate_content(
+ model="gemini-1.5-flash", contents=contents, config=create_test_config()
+ )
+
+ (event,) = events
+ invoke_span = event["spans"][0]
+
+ messages = json.loads(invoke_span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+ assert len(messages) == 1
+ # User message
+ assert messages[0]["role"] == "user"
+ assert messages[0]["content"] == [{"text": "Tell me a joke", "type": "text"}]
+
+
+def test_generate_content_with_part_object_directly(
+ sentry_init, capture_events, mock_genai_client
+):
+ """Test generate_content with Part object directly (not wrapped in Content)."""
+ sentry_init(
+ integrations=[GoogleGenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ mock_http_response = create_mock_http_response(EXAMPLE_API_RESPONSE_JSON)
+
+ # Part object directly
+ part = genai_types.Part(text="Direct Part object")
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "request", return_value=mock_http_response
+ ):
+ with start_transaction(name="google_genai"):
+ mock_genai_client.models.generate_content(
+ model="gemini-1.5-flash", contents=part, config=create_test_config()
+ )
+
+ (event,) = events
+ invoke_span = event["spans"][0]
+
+ messages = json.loads(invoke_span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+ assert len(messages) == 1
+ assert messages[0]["role"] == "user"
+ assert messages[0]["content"] == [{"text": "Direct Part object", "type": "text"}]
+
+
+def test_generate_content_with_list_of_dicts(
+ sentry_init, capture_events, mock_genai_client
+):
+ """
+ Test generate_content with list of dict format inputs.
+
+ We only keep (and assert) the last dict in `content` because we've made popping the last message a form of
+ message truncation to keep the span size within limits. If we were following OTEL conventions, all 3 dicts
+ would be present.
+ """
+ sentry_init(
+ integrations=[GoogleGenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ mock_http_response = create_mock_http_response(EXAMPLE_API_RESPONSE_JSON)
+
+ # List of dicts (conversation in dict format)
+ contents = [
+ {"role": "user", "parts": [{"text": "First user message"}]},
+ {"role": "model", "parts": [{"text": "First model response"}]},
+ {"role": "user", "parts": [{"text": "Second user message"}]},
+ ]
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "request", return_value=mock_http_response
+ ):
+ with start_transaction(name="google_genai"):
+ mock_genai_client.models.generate_content(
+ model="gemini-1.5-flash", contents=contents, config=create_test_config()
+ )
+
+ (event,) = events
+ invoke_span = event["spans"][0]
+
+ messages = json.loads(invoke_span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+ assert len(messages) == 1
+ assert messages[0]["role"] == "user"
+ assert messages[0]["content"] == [{"text": "Second user message", "type": "text"}]
+
+
+def test_generate_content_with_dict_inline_data(
+ sentry_init, capture_events, mock_genai_client
+):
+ """Test generate_content with dict format containing inline_data."""
+ sentry_init(
+ integrations=[GoogleGenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ mock_http_response = create_mock_http_response(EXAMPLE_API_RESPONSE_JSON)
+
+ # Dict with inline_data
+ contents = {
+ "role": "user",
+ "parts": [
+ {"text": "What's in this image?"},
+ {"inline_data": {"data": b"fake_binary_data", "mime_type": "image/gif"}},
+ ],
+ }
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "request", return_value=mock_http_response
+ ):
+ with start_transaction(name="google_genai"):
+ mock_genai_client.models.generate_content(
+ model="gemini-1.5-flash", contents=contents, config=create_test_config()
+ )
+
+ (event,) = events
+ invoke_span = event["spans"][0]
+
+ messages = json.loads(invoke_span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+ assert len(messages) == 1
+ assert messages[0]["role"] == "user"
+ assert len(messages[0]["content"]) == 2
+ assert messages[0]["content"][0] == {
+ "text": "What's in this image?",
+ "type": "text",
+ }
+ assert messages[0]["content"][1]["type"] == "blob"
+ assert messages[0]["content"][1]["mime_type"] == "image/gif"
+ assert messages[0]["content"][1]["content"] == BLOB_DATA_SUBSTITUTE
+
+
+def test_generate_content_without_parts_property_inline_data(
+ sentry_init, capture_events, mock_genai_client
+):
+ sentry_init(
+ integrations=[GoogleGenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ mock_http_response = create_mock_http_response(EXAMPLE_API_RESPONSE_JSON)
+
+ contents = [
+ {"text": "What's in this image?"},
+ {"inline_data": {"data": b"fake_binary_data", "mime_type": "image/gif"}},
+ ]
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "request", return_value=mock_http_response
+ ):
+ with start_transaction(name="google_genai"):
+ mock_genai_client.models.generate_content(
+ model="gemini-1.5-flash", contents=contents, config=create_test_config()
+ )
+
+ (event,) = events
+ invoke_span = event["spans"][0]
+
+ messages = json.loads(invoke_span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+
+ assert len(messages) == 1
+
+ assert len(messages[0]["content"]) == 2
+ assert messages[0]["role"] == "user"
+ assert messages[0]["content"][0] == {
+ "text": "What's in this image?",
+ "type": "text",
+ }
+ assert messages[0]["content"][1]["inline_data"]
+
+ assert messages[0]["content"][1]["inline_data"]["data"] == BLOB_DATA_SUBSTITUTE
+ assert messages[0]["content"][1]["inline_data"]["mime_type"] == "image/gif"
+
+
+def test_generate_content_without_parts_property_inline_data_and_binary_data_within_string(
+ sentry_init, capture_events, mock_genai_client
+):
+ sentry_init(
+ integrations=[GoogleGenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ mock_http_response = create_mock_http_response(EXAMPLE_API_RESPONSE_JSON)
+
+ contents = [
+ {"text": "What's in this image?"},
+ {
+ "inline_data": {
+ "data": "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mP8z8BQz0AEYBxVSF+FABJADveWkH6oAAAAAElFTkSuQmCC",
+ "mime_type": "image/png",
+ }
+ },
+ ]
+
+ with mock.patch.object(
+ mock_genai_client._api_client, "request", return_value=mock_http_response
+ ):
+ with start_transaction(name="google_genai"):
+ mock_genai_client.models.generate_content(
+ model="gemini-1.5-flash", contents=contents, config=create_test_config()
+ )
+
+ (event,) = events
+ invoke_span = event["spans"][0]
+
+ messages = json.loads(invoke_span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+ assert len(messages) == 1
+ assert messages[0]["role"] == "user"
+
+ assert len(messages[0]["content"]) == 2
+ assert messages[0]["content"][0] == {
+ "text": "What's in this image?",
+ "type": "text",
+ }
+ assert messages[0]["content"][1]["inline_data"]
+
+ assert messages[0]["content"][1]["inline_data"]["data"] == BLOB_DATA_SUBSTITUTE
+ assert messages[0]["content"][1]["inline_data"]["mime_type"] == "image/png"
+
+
+# Tests for extract_contents_messages function
+def test_extract_contents_messages_none():
+ """Test extract_contents_messages with None input"""
+ result = extract_contents_messages(None)
+ assert result == []
+
+
+def test_extract_contents_messages_string():
+ """Test extract_contents_messages with string input"""
+ result = extract_contents_messages("Hello world")
+ assert result == [{"role": "user", "content": "Hello world"}]
+
+
+def test_extract_contents_messages_content_object():
+ """Test extract_contents_messages with Content object"""
+ content = genai_types.Content(
+ role="user", parts=[genai_types.Part(text="Test message")]
+ )
+ result = extract_contents_messages(content)
+ assert len(result) == 1
+ assert result[0]["role"] == "user"
+ assert result[0]["content"] == [{"text": "Test message", "type": "text"}]
+
+
+def test_extract_contents_messages_content_object_model_role():
+ """Test extract_contents_messages with Content object having model role"""
+ content = genai_types.Content(
+ role="model", parts=[genai_types.Part(text="Assistant response")]
+ )
+ result = extract_contents_messages(content)
+ assert len(result) == 1
+ assert result[0]["role"] == "assistant"
+ assert result[0]["content"] == [{"text": "Assistant response", "type": "text"}]
+
+
+def test_extract_contents_messages_content_object_no_role():
+ """Test extract_contents_messages with Content object without role"""
+ content = genai_types.Content(parts=[genai_types.Part(text="No role message")])
+ result = extract_contents_messages(content)
+ assert len(result) == 1
+ assert result[0]["role"] == "user"
+ assert result[0]["content"] == [{"text": "No role message", "type": "text"}]
+
+
+def test_extract_contents_messages_part_object():
+ """Test extract_contents_messages with Part object"""
+ part = genai_types.Part(text="Direct part")
+ result = extract_contents_messages(part)
+ assert len(result) == 1
+ assert result[0]["role"] == "user"
+ assert result[0]["content"] == [{"text": "Direct part", "type": "text"}]
+
+
+def test_extract_contents_messages_file_data():
+ """Test extract_contents_messages with file_data"""
+ file_data = genai_types.FileData(
+ file_uri="gs://bucket/file.jpg", mime_type="image/jpeg"
+ )
+ part = genai_types.Part(file_data=file_data)
+ content = genai_types.Content(parts=[part])
+ result = extract_contents_messages(content)
+
+ assert len(result) == 1
+ assert result[0]["role"] == "user"
+ assert len(result[0]["content"]) == 1
+ blob_part = result[0]["content"][0]
+ assert blob_part["type"] == "uri"
+ assert blob_part["modality"] == "image"
+ assert blob_part["mime_type"] == "image/jpeg"
+ assert blob_part["uri"] == "gs://bucket/file.jpg"
+
+
+def test_extract_contents_messages_inline_data():
+ """Test extract_contents_messages with inline_data (binary)"""
+ # Create inline data with bytes
+ image_bytes = b"fake_image_data"
+ blob = genai_types.Blob(data=image_bytes, mime_type="image/png")
+ part = genai_types.Part(inline_data=blob)
+ content = genai_types.Content(parts=[part])
+ result = extract_contents_messages(content)
+
+ assert len(result) == 1
+ assert result[0]["role"] == "user"
+ assert len(result[0]["content"]) == 1
+ blob_part = result[0]["content"][0]
+ assert blob_part["type"] == "blob"
+ assert blob_part["mime_type"] == "image/png"
+ assert blob_part["content"] == BLOB_DATA_SUBSTITUTE
+
+
+def test_extract_contents_messages_function_response():
+ """Test extract_contents_messages with function_response (tool message)"""
+ function_response = genai_types.FunctionResponse(
+ id="call_123", name="get_weather", response={"output": "sunny"}
+ )
+ part = genai_types.Part(function_response=function_response)
+ content = genai_types.Content(parts=[part])
+ result = extract_contents_messages(content)
+
+ assert len(result) == 1
+ assert result[0]["role"] == "tool"
+ assert result[0]["content"]["toolCallId"] == "call_123"
+ assert result[0]["content"]["toolName"] == "get_weather"
+ assert result[0]["content"]["output"] == "sunny"
+
+
+def test_extract_contents_messages_function_response_with_output_key():
+ """Test extract_contents_messages with function_response that has output key"""
+ function_response = genai_types.FunctionResponse(
+ id="call_456", name="get_time", response={"output": "3:00 PM", "error": None}
+ )
+ part = genai_types.Part(function_response=function_response)
+ content = genai_types.Content(parts=[part])
+ result = extract_contents_messages(content)
+
+ assert len(result) == 1
+ assert result[0]["role"] == "tool"
+ assert result[0]["content"]["toolCallId"] == "call_456"
+ assert result[0]["content"]["toolName"] == "get_time"
+ # Should prefer "output" key
+ assert result[0]["content"]["output"] == "3:00 PM"
+
+
+def test_extract_contents_messages_mixed_parts():
+ """Test extract_contents_messages with mixed content parts"""
+ content = genai_types.Content(
+ role="user",
+ parts=[
+ genai_types.Part(text="Text part"),
+ genai_types.Part(
+ file_data=genai_types.FileData(
+ file_uri="gs://bucket/image.jpg", mime_type="image/jpeg"
+ )
+ ),
+ ],
+ )
+ result = extract_contents_messages(content)
+
+ assert len(result) == 1
+ assert result[0]["role"] == "user"
+ assert len(result[0]["content"]) == 2
+ assert result[0]["content"][0] == {"text": "Text part", "type": "text"}
+ assert result[0]["content"][1]["type"] == "uri"
+ assert result[0]["content"][1]["modality"] == "image"
+ assert result[0]["content"][1]["uri"] == "gs://bucket/image.jpg"
+
+
+def test_extract_contents_messages_list():
+ """Test extract_contents_messages with list input"""
+ contents = [
+ "First message",
+ genai_types.Content(
+ role="user", parts=[genai_types.Part(text="Second message")]
+ ),
+ ]
+ result = extract_contents_messages(contents)
+
+ assert len(result) == 2
+ assert result[0] == {"role": "user", "content": "First message"}
+ assert result[1]["role"] == "user"
+ assert result[1]["content"] == [{"text": "Second message", "type": "text"}]
+
+
+def test_extract_contents_messages_dict_content():
+ """Test extract_contents_messages with dict (ContentDict)"""
+ content_dict = {"role": "user", "parts": [{"text": "Dict message"}]}
+ result = extract_contents_messages(content_dict)
+
+ assert len(result) == 1
+ assert result[0]["role"] == "user"
+ assert result[0]["content"] == [{"text": "Dict message", "type": "text"}]
+
+
+def test_extract_contents_messages_dict_with_text():
+ """Test extract_contents_messages with dict containing text key"""
+ content_dict = {"role": "user", "text": "Simple text"}
+ result = extract_contents_messages(content_dict)
+
+ assert len(result) == 1
+ assert result[0]["role"] == "user"
+ assert result[0]["content"] == [{"text": "Simple text", "type": "text"}]
+
+
+def test_extract_contents_messages_file_object():
+ """Test extract_contents_messages with File object"""
+ file_obj = genai_types.File(
+ name="files/123", uri="gs://bucket/file.pdf", mime_type="application/pdf"
+ )
+ result = extract_contents_messages(file_obj)
+
+ assert len(result) == 1
+ assert result[0]["role"] == "user"
+ assert len(result[0]["content"]) == 1
+ blob_part = result[0]["content"][0]
+ assert blob_part["type"] == "uri"
+ assert blob_part["modality"] == "document"
+ assert blob_part["mime_type"] == "application/pdf"
+ assert blob_part["uri"] == "gs://bucket/file.pdf"
+
+
+@pytest.mark.skipif(
+ not hasattr(genai_types, "PIL_Image") or genai_types.PIL_Image is None,
+ reason="PIL not available",
+)
+def test_extract_contents_messages_pil_image():
+ """Test extract_contents_messages with PIL.Image.Image"""
+ try:
+ from PIL import Image as PILImage
+
+ # Create a simple test image
+ img = PILImage.new("RGB", (10, 10), color="red")
+ result = extract_contents_messages(img)
+
+ assert len(result) == 1
+ assert result[0]["role"] == "user"
+ assert len(result[0]["content"]) == 1
+ blob_part = result[0]["content"][0]
+ assert blob_part["type"] == "blob"
+ assert blob_part["mime_type"].startswith("image/")
+ assert "content" in blob_part
+ # Binary content is substituted with placeholder for privacy
+ assert blob_part["content"] == "[Blob substitute]"
+ except ImportError:
+ pytest.skip("PIL not available")
+
+
+def test_extract_contents_messages_tool_and_text():
+ """Test extract_contents_messages with both tool message and text"""
+ content = genai_types.Content(
+ role="user",
+ parts=[
+ genai_types.Part(text="User question"),
+ genai_types.Part(
+ function_response=genai_types.FunctionResponse(
+ id="call_789", name="search", response={"output": "results"}
+ )
+ ),
+ ],
+ )
+ result = extract_contents_messages(content)
+
+ # Should have two messages: one user message and one tool message
+ assert len(result) == 2
+ # First should be user message with text
+ assert result[0]["role"] == "user"
+ assert result[0]["content"] == [{"text": "User question", "type": "text"}]
+ # Second should be tool message
+ assert result[1]["role"] == "tool"
+ assert result[1]["content"]["toolCallId"] == "call_789"
+ assert result[1]["content"]["toolName"] == "search"
+
+
+def test_extract_contents_messages_empty_parts():
+ """Test extract_contents_messages with Content object with empty parts"""
+ content = genai_types.Content(role="user", parts=[])
+ result = extract_contents_messages(content)
+
+ assert result == []
+
+
+def test_extract_contents_messages_empty_list():
+ """Test extract_contents_messages with empty list"""
+ result = extract_contents_messages([])
+ assert result == []
+
+
+def test_extract_contents_messages_dict_inline_data():
+ """Test extract_contents_messages with dict containing inline_data"""
+ content_dict = {
+ "role": "user",
+ "parts": [{"inline_data": {"data": b"binary_data", "mime_type": "image/gif"}}],
+ }
+ result = extract_contents_messages(content_dict)
+
+ assert len(result) == 1
+ assert result[0]["role"] == "user"
+ assert len(result[0]["content"]) == 1
+ blob_part = result[0]["content"][0]
+ assert blob_part["type"] == "blob"
+ assert blob_part["mime_type"] == "image/gif"
+ assert blob_part["content"] == BLOB_DATA_SUBSTITUTE
+
+
+def test_extract_contents_messages_dict_function_response():
+ """Test extract_contents_messages with dict containing function_response"""
+ content_dict = {
+ "role": "user",
+ "parts": [
+ {
+ "function_response": {
+ "id": "dict_call_1",
+ "name": "dict_tool",
+ "response": {"result": "success"},
+ }
+ }
+ ],
+ }
+ result = extract_contents_messages(content_dict)
+
+ assert len(result) == 1
+ assert result[0]["role"] == "tool"
+ assert result[0]["content"]["toolCallId"] == "dict_call_1"
+ assert result[0]["content"]["toolName"] == "dict_tool"
+ assert result[0]["content"]["output"] == '{"result": "success"}'
+
+
+def test_extract_contents_messages_object_with_text_attribute():
+ """Test extract_contents_messages with object that has text attribute"""
+
+ class TextObject:
+ def __init__(self):
+ self.text = "Object text"
+
+ obj = TextObject()
+ result = extract_contents_messages(obj)
+
+ assert len(result) == 1
+ assert result[0]["role"] == "user"
+ assert result[0]["content"] == [{"text": "Object text", "type": "text"}]
diff --git a/tests/integrations/gql/__init__.py b/tests/integrations/gql/__init__.py
new file mode 100644
index 0000000000..c3361b42f3
--- /dev/null
+++ b/tests/integrations/gql/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("gql")
diff --git a/tests/integrations/gql/test_gql.py b/tests/integrations/gql/test_gql.py
new file mode 100644
index 0000000000..2785c63e2c
--- /dev/null
+++ b/tests/integrations/gql/test_gql.py
@@ -0,0 +1,151 @@
+import pytest
+
+import responses
+from gql import gql
+from gql import Client
+from gql import __version__
+from gql.transport.exceptions import TransportQueryError
+from gql.transport.requests import RequestsHTTPTransport
+from sentry_sdk.integrations.gql import GQLIntegration
+from sentry_sdk.utils import parse_version
+
+GQL_VERSION = parse_version(__version__)
+
+
+@responses.activate
+def _execute_mock_query(response_json):
+ url = "http://example.com/graphql"
+ query_string = """
+ query Example {
+ example
+ }
+ """
+
+ # Mock the GraphQL server response
+ responses.add(
+ method=responses.POST,
+ url=url,
+ json=response_json,
+ status=200,
+ )
+
+ transport = RequestsHTTPTransport(url=url)
+ client = Client(transport=transport)
+ query = gql(query_string)
+
+ return client.execute(query)
+
+
+@responses.activate
+def _execute_mock_query_with_keyword_document(response_json):
+ url = "http://example.com/graphql"
+ query_string = """
+ query Example {
+ example
+ }
+ """
+
+ # Mock the GraphQL server response
+ responses.add(
+ method=responses.POST,
+ url=url,
+ json=response_json,
+ status=200,
+ )
+
+ transport = RequestsHTTPTransport(url=url)
+ client = Client(transport=transport)
+ query = gql(query_string)
+
+ return client.execute(document=query)
+
+
+_execute_query_funcs = [_execute_mock_query]
+if GQL_VERSION < (4,):
+ _execute_query_funcs.append(_execute_mock_query_with_keyword_document)
+
+
+def _make_erroneous_query(capture_events, execute_query):
+ """
+ Make an erroneous GraphQL query, and assert that the error was reraised, that
+ exactly one event was recorded, and that the exception recorded was a
+ TransportQueryError. Then, return the event to allow further verifications.
+ """
+ events = capture_events()
+ response_json = {"errors": ["something bad happened"]}
+
+ with pytest.raises(TransportQueryError):
+ execute_query(response_json)
+
+ assert len(events) == 1, (
+ "the sdk captured %d events, but 1 event was expected" % len(events)
+ )
+
+ (event,) = events
+ (exception,) = event["exception"]["values"]
+
+ assert exception["type"] == "TransportQueryError", (
+ "%s was captured, but we expected a TransportQueryError" % exception(type)
+ )
+
+ assert "request" in event
+
+ return event
+
+
+def test_gql_init(sentry_init):
+ """
+ Integration test to ensure we can initialize the SDK with the GQL Integration
+ """
+ sentry_init(integrations=[GQLIntegration()])
+
+
+@pytest.mark.parametrize("execute_query", _execute_query_funcs)
+def test_real_gql_request_no_error(sentry_init, capture_events, execute_query):
+ """
+ Integration test verifying that the GQLIntegration works as expected with successful query.
+ """
+ sentry_init(integrations=[GQLIntegration()])
+ events = capture_events()
+
+ response_data = {"example": "This is the example"}
+ response_json = {"data": response_data}
+
+ result = execute_query(response_json)
+
+ assert result == response_data, (
+ "client.execute returned a different value from what it received from the server"
+ )
+ assert len(events) == 0, (
+ "the sdk captured an event, even though the query was successful"
+ )
+
+
+@pytest.mark.parametrize("execute_query", _execute_query_funcs)
+def test_real_gql_request_with_error_no_pii(sentry_init, capture_events, execute_query):
+ """
+ Integration test verifying that the GQLIntegration works as expected with query resulting
+ in a GraphQL error, and that PII is not sent.
+ """
+ sentry_init(integrations=[GQLIntegration()])
+
+ event = _make_erroneous_query(capture_events, execute_query)
+
+ assert "data" not in event["request"]
+ assert "response" not in event["contexts"]
+
+
+@pytest.mark.parametrize("execute_query", _execute_query_funcs)
+def test_real_gql_request_with_error_with_pii(
+ sentry_init, capture_events, execute_query
+):
+ """
+ Integration test verifying that the GQLIntegration works as expected with query resulting
+ in a GraphQL error, and that PII is not sent.
+ """
+ sentry_init(integrations=[GQLIntegration()], send_default_pii=True)
+
+ event = _make_erroneous_query(capture_events, execute_query)
+
+ assert "data" in event["request"]
+ assert "response" in event["contexts"]
diff --git a/tests/integrations/graphene/__init__.py b/tests/integrations/graphene/__init__.py
new file mode 100644
index 0000000000..f81854aed5
--- /dev/null
+++ b/tests/integrations/graphene/__init__.py
@@ -0,0 +1,5 @@
+import pytest
+
+pytest.importorskip("graphene")
+pytest.importorskip("fastapi")
+pytest.importorskip("flask")
diff --git a/tests/integrations/graphene/test_graphene.py b/tests/integrations/graphene/test_graphene.py
new file mode 100644
index 0000000000..63bc5de5d2
--- /dev/null
+++ b/tests/integrations/graphene/test_graphene.py
@@ -0,0 +1,283 @@
+from fastapi import FastAPI, Request
+from fastapi.testclient import TestClient
+from flask import Flask, request, jsonify
+from graphene import ObjectType, String, Schema
+
+from sentry_sdk.consts import OP
+from sentry_sdk.integrations.fastapi import FastApiIntegration
+from sentry_sdk.integrations.flask import FlaskIntegration
+from sentry_sdk.integrations.graphene import GrapheneIntegration
+from sentry_sdk.integrations.starlette import StarletteIntegration
+
+
+class Query(ObjectType):
+ hello = String(first_name=String(default_value="stranger"))
+ goodbye = String()
+
+ def resolve_hello(root, info, first_name): # noqa: N805
+ return "Hello {}!".format(first_name)
+
+ def resolve_goodbye(root, info): # noqa: N805
+ raise RuntimeError("oh no!")
+
+
+def test_capture_request_if_available_and_send_pii_is_on_async(
+ sentry_init, capture_events
+):
+ sentry_init(
+ send_default_pii=True,
+ integrations=[
+ GrapheneIntegration(),
+ FastApiIntegration(),
+ StarletteIntegration(),
+ ],
+ )
+ events = capture_events()
+
+ schema = Schema(query=Query)
+
+ async_app = FastAPI()
+
+ @async_app.post("/graphql")
+ async def graphql_server_async(request: Request):
+ data = await request.json()
+ result = await schema.execute_async(data["query"])
+ return result.data
+
+ query = {"query": "query ErrorQuery {goodbye}"}
+ client = TestClient(async_app)
+ client.post("/graphql", json=query)
+
+ assert len(events) == 1
+
+ (event,) = events
+ assert event["exception"]["values"][0]["mechanism"]["type"] == "graphene"
+ assert event["request"]["api_target"] == "graphql"
+ assert event["request"]["data"] == query
+
+
+def test_capture_request_if_available_and_send_pii_is_on_sync(
+ sentry_init, capture_events
+):
+ sentry_init(
+ send_default_pii=True,
+ integrations=[GrapheneIntegration(), FlaskIntegration()],
+ )
+ events = capture_events()
+
+ schema = Schema(query=Query)
+
+ sync_app = Flask(__name__)
+
+ @sync_app.route("/graphql", methods=["POST"])
+ def graphql_server_sync():
+ data = request.get_json()
+ result = schema.execute(data["query"])
+ return jsonify(result.data), 200
+
+ query = {"query": "query ErrorQuery {goodbye}"}
+ client = sync_app.test_client()
+ client.post("/graphql", json=query)
+
+ assert len(events) == 1
+
+ (event,) = events
+ assert event["exception"]["values"][0]["mechanism"]["type"] == "graphene"
+ assert event["request"]["api_target"] == "graphql"
+ assert event["request"]["data"] == query
+
+
+def test_do_not_capture_request_if_send_pii_is_off_async(sentry_init, capture_events):
+ sentry_init(
+ integrations=[
+ GrapheneIntegration(),
+ FastApiIntegration(),
+ StarletteIntegration(),
+ ],
+ )
+ events = capture_events()
+
+ schema = Schema(query=Query)
+
+ async_app = FastAPI()
+
+ @async_app.post("/graphql")
+ async def graphql_server_async(request: Request):
+ data = await request.json()
+ result = await schema.execute_async(data["query"])
+ return result.data
+
+ query = {"query": "query ErrorQuery {goodbye}"}
+ client = TestClient(async_app)
+ client.post("/graphql", json=query)
+
+ assert len(events) == 1
+
+ (event,) = events
+ assert event["exception"]["values"][0]["mechanism"]["type"] == "graphene"
+ assert "data" not in event["request"]
+ assert "response" not in event["contexts"]
+
+
+def test_do_not_capture_request_if_send_pii_is_off_sync(sentry_init, capture_events):
+ sentry_init(
+ integrations=[GrapheneIntegration(), FlaskIntegration()],
+ )
+ events = capture_events()
+
+ schema = Schema(query=Query)
+
+ sync_app = Flask(__name__)
+
+ @sync_app.route("/graphql", methods=["POST"])
+ def graphql_server_sync():
+ data = request.get_json()
+ result = schema.execute(data["query"])
+ return jsonify(result.data), 200
+
+ query = {"query": "query ErrorQuery {goodbye}"}
+ client = sync_app.test_client()
+ client.post("/graphql", json=query)
+
+ assert len(events) == 1
+
+ (event,) = events
+ assert event["exception"]["values"][0]["mechanism"]["type"] == "graphene"
+ assert "data" not in event["request"]
+ assert "response" not in event["contexts"]
+
+
+def test_no_event_if_no_errors_async(sentry_init, capture_events):
+ sentry_init(
+ integrations=[
+ GrapheneIntegration(),
+ FastApiIntegration(),
+ StarletteIntegration(),
+ ],
+ )
+ events = capture_events()
+
+ schema = Schema(query=Query)
+
+ async_app = FastAPI()
+
+ @async_app.post("/graphql")
+ async def graphql_server_async(request: Request):
+ data = await request.json()
+ result = await schema.execute_async(data["query"])
+ return result.data
+
+ query = {
+ "query": "query GreetingQuery { hello }",
+ }
+ client = TestClient(async_app)
+ client.post("/graphql", json=query)
+
+ assert len(events) == 0
+
+
+def test_no_event_if_no_errors_sync(sentry_init, capture_events):
+ sentry_init(
+ integrations=[
+ GrapheneIntegration(),
+ FlaskIntegration(),
+ ],
+ )
+ events = capture_events()
+
+ schema = Schema(query=Query)
+
+ sync_app = Flask(__name__)
+
+ @sync_app.route("/graphql", methods=["POST"])
+ def graphql_server_sync():
+ data = request.get_json()
+ result = schema.execute(data["query"])
+ return jsonify(result.data), 200
+
+ query = {
+ "query": "query GreetingQuery { hello }",
+ }
+ client = sync_app.test_client()
+ client.post("/graphql", json=query)
+
+ assert len(events) == 0
+
+
+def test_graphql_span_holds_query_information(sentry_init, capture_events):
+ sentry_init(
+ integrations=[GrapheneIntegration(), FlaskIntegration()],
+ traces_sample_rate=1.0,
+ default_integrations=False,
+ )
+ events = capture_events()
+
+ schema = Schema(query=Query)
+
+ sync_app = Flask(__name__)
+
+ @sync_app.route("/graphql", methods=["POST"])
+ def graphql_server_sync():
+ data = request.get_json()
+ result = schema.execute(data["query"], operation_name=data.get("operationName"))
+ return jsonify(result.data), 200
+
+ query = {
+ "query": "query GreetingQuery { hello }",
+ "operationName": "GreetingQuery",
+ }
+ client = sync_app.test_client()
+ client.post("/graphql", json=query)
+
+ assert len(events) == 1
+
+ (event,) = events
+ assert len(event["spans"]) == 1
+
+ (span,) = event["spans"]
+ assert span["op"] == OP.GRAPHQL_QUERY
+ assert span["description"] == query["operationName"]
+ assert span["data"]["graphql.document"] == query["query"]
+ assert span["data"]["graphql.operation.name"] == query["operationName"]
+ assert span["data"]["graphql.operation.type"] == "query"
+
+
+def test_breadcrumbs_hold_query_information_on_error(sentry_init, capture_events):
+ sentry_init(
+ integrations=[
+ GrapheneIntegration(),
+ ],
+ default_integrations=False,
+ )
+ events = capture_events()
+
+ schema = Schema(query=Query)
+
+ sync_app = Flask(__name__)
+
+ @sync_app.route("/graphql", methods=["POST"])
+ def graphql_server_sync():
+ data = request.get_json()
+ result = schema.execute(data["query"], operation_name=data.get("operationName"))
+ return jsonify(result.data), 200
+
+ query = {
+ "query": "query ErrorQuery { goodbye }",
+ "operationName": "ErrorQuery",
+ }
+ client = sync_app.test_client()
+ client.post("/graphql", json=query)
+
+ assert len(events) == 1
+
+ (event,) = events
+ assert len(event["breadcrumbs"]) == 1
+
+ breadcrumbs = event["breadcrumbs"]["values"]
+ assert len(breadcrumbs) == 1
+
+ (breadcrumb,) = breadcrumbs
+ assert breadcrumb["category"] == "graphql.operation"
+ assert breadcrumb["data"]["operation_name"] == query["operationName"]
+ assert breadcrumb["data"]["operation_type"] == "query"
+ assert breadcrumb["type"] == "default"
diff --git a/tests/integrations/grpc/__init__.py b/tests/integrations/grpc/__init__.py
index 88a0a201e4..f18dce91e2 100644
--- a/tests/integrations/grpc/__init__.py
+++ b/tests/integrations/grpc/__init__.py
@@ -1,3 +1,8 @@
+import sys
+from pathlib import Path
+
import pytest
+# For imports inside gRPC autogenerated code to work
+sys.path.append(str(Path(__file__).parent))
pytest.importorskip("grpc")
diff --git a/tests/integrations/grpc/compile_test_services.sh b/tests/integrations/grpc/compile_test_services.sh
new file mode 100755
index 0000000000..777a27e6e5
--- /dev/null
+++ b/tests/integrations/grpc/compile_test_services.sh
@@ -0,0 +1,15 @@
+#!/usr/bin/env bash
+
+# Run this script from the project root to generate the python code
+
+TARGET_PATH=./tests/integrations/grpc
+
+# Create python file
+python -m grpc_tools.protoc \
+ --proto_path=$TARGET_PATH/protos/ \
+ --python_out=$TARGET_PATH/ \
+ --pyi_out=$TARGET_PATH/ \
+ --grpc_python_out=$TARGET_PATH/ \
+ $TARGET_PATH/protos/grpc_test_service.proto
+
+echo Code generation successfull
diff --git a/tests/integrations/grpc/grpc_test_service.proto b/tests/integrations/grpc/grpc_test_service.proto
deleted file mode 100644
index 43497c7129..0000000000
--- a/tests/integrations/grpc/grpc_test_service.proto
+++ /dev/null
@@ -1,11 +0,0 @@
-syntax = "proto3";
-
-package grpc_test_server;
-
-service gRPCTestService{
- rpc TestServe(gRPCTestMessage) returns (gRPCTestMessage);
-}
-
-message gRPCTestMessage {
- string text = 1;
-}
diff --git a/tests/integrations/grpc/grpc_test_service_pb2.py b/tests/integrations/grpc/grpc_test_service_pb2.py
index c68f255b4a..84ea7f632a 100644
--- a/tests/integrations/grpc/grpc_test_service_pb2.py
+++ b/tests/integrations/grpc/grpc_test_service_pb2.py
@@ -2,27 +2,26 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: grpc_test_service.proto
"""Generated protocol buffer code."""
-from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
-
+from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x17grpc_test_service.proto\x12\x10grpc_test_server"\x1f\n\x0fgRPCTestMessage\x12\x0c\n\x04text\x18\x01 \x01(\t2d\n\x0fgRPCTestService\x12Q\n\tTestServe\x12!.grpc_test_server.gRPCTestMessage\x1a!.grpc_test_server.gRPCTestMessageb\x06proto3'
-)
-_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
-_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "grpc_test_service_pb2", globals())
-if _descriptor._USE_C_DESCRIPTORS == False:
- DESCRIPTOR._options = None
- _GRPCTESTMESSAGE._serialized_start = 45
- _GRPCTESTMESSAGE._serialized_end = 76
- _GRPCTESTSERVICE._serialized_start = 78
- _GRPCTESTSERVICE._serialized_end = 178
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17grpc_test_service.proto\x12\x10grpc_test_server\"\x1f\n\x0fgRPCTestMessage\x12\x0c\n\x04text\x18\x01 \x01(\t2\xf8\x02\n\x0fgRPCTestService\x12Q\n\tTestServe\x12!.grpc_test_server.gRPCTestMessage\x1a!.grpc_test_server.gRPCTestMessage\x12Y\n\x0fTestUnaryStream\x12!.grpc_test_server.gRPCTestMessage\x1a!.grpc_test_server.gRPCTestMessage0\x01\x12\\\n\x10TestStreamStream\x12!.grpc_test_server.gRPCTestMessage\x1a!.grpc_test_server.gRPCTestMessage(\x01\x30\x01\x12Y\n\x0fTestStreamUnary\x12!.grpc_test_server.gRPCTestMessage\x1a!.grpc_test_server.gRPCTestMessage(\x01\x62\x06proto3')
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'grpc_test_service_pb2', _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+ DESCRIPTOR._options = None
+ _globals['_GRPCTESTMESSAGE']._serialized_start=45
+ _globals['_GRPCTESTMESSAGE']._serialized_end=76
+ _globals['_GRPCTESTSERVICE']._serialized_start=79
+ _globals['_GRPCTESTSERVICE']._serialized_end=455
# @@protoc_insertion_point(module_scope)
diff --git a/tests/integrations/grpc/grpc_test_service_pb2.pyi b/tests/integrations/grpc/grpc_test_service_pb2.pyi
index 02a0b7045b..f16d8a2d65 100644
--- a/tests/integrations/grpc/grpc_test_service_pb2.pyi
+++ b/tests/integrations/grpc/grpc_test_service_pb2.pyi
@@ -1,32 +1,11 @@
-"""
-@generated by mypy-protobuf. Do not edit manually!
-isort:skip_file
-"""
-import builtins
-import google.protobuf.descriptor
-import google.protobuf.message
-import sys
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import message as _message
+from typing import ClassVar as _ClassVar, Optional as _Optional
-if sys.version_info >= (3, 8):
- import typing as typing_extensions
-else:
- import typing_extensions
+DESCRIPTOR: _descriptor.FileDescriptor
-DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
-
-@typing_extensions.final
-class gRPCTestMessage(google.protobuf.message.Message):
- DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
- TEXT_FIELD_NUMBER: builtins.int
- text: builtins.str
- def __init__(
- self,
- *,
- text: builtins.str = ...,
- ) -> None: ...
- def ClearField(
- self, field_name: typing_extensions.Literal["text", b"text"]
- ) -> None: ...
-
-global___gRPCTestMessage = gRPCTestMessage
+class gRPCTestMessage(_message.Message):
+ __slots__ = ["text"]
+ TEXT_FIELD_NUMBER: _ClassVar[int]
+ text: str
+ def __init__(self, text: _Optional[str] = ...) -> None: ...
diff --git a/tests/integrations/grpc/grpc_test_service_pb2_grpc.py b/tests/integrations/grpc/grpc_test_service_pb2_grpc.py
index 73b7d94c16..ad897608ca 100644
--- a/tests/integrations/grpc/grpc_test_service_pb2_grpc.py
+++ b/tests/integrations/grpc/grpc_test_service_pb2_grpc.py
@@ -2,7 +2,7 @@
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
-import tests.integrations.grpc.grpc_test_service_pb2 as grpc__test__service__pb2
+import grpc_test_service_pb2 as grpc__test__service__pb2
class gRPCTestServiceStub(object):
@@ -15,10 +15,25 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.TestServe = channel.unary_unary(
- "/grpc_test_server.gRPCTestService/TestServe",
- request_serializer=grpc__test__service__pb2.gRPCTestMessage.SerializeToString,
- response_deserializer=grpc__test__service__pb2.gRPCTestMessage.FromString,
- )
+ '/grpc_test_server.gRPCTestService/TestServe',
+ request_serializer=grpc__test__service__pb2.gRPCTestMessage.SerializeToString,
+ response_deserializer=grpc__test__service__pb2.gRPCTestMessage.FromString,
+ )
+ self.TestUnaryStream = channel.unary_stream(
+ '/grpc_test_server.gRPCTestService/TestUnaryStream',
+ request_serializer=grpc__test__service__pb2.gRPCTestMessage.SerializeToString,
+ response_deserializer=grpc__test__service__pb2.gRPCTestMessage.FromString,
+ )
+ self.TestStreamStream = channel.stream_stream(
+ '/grpc_test_server.gRPCTestService/TestStreamStream',
+ request_serializer=grpc__test__service__pb2.gRPCTestMessage.SerializeToString,
+ response_deserializer=grpc__test__service__pb2.gRPCTestMessage.FromString,
+ )
+ self.TestStreamUnary = channel.stream_unary(
+ '/grpc_test_server.gRPCTestService/TestStreamUnary',
+ request_serializer=grpc__test__service__pb2.gRPCTestMessage.SerializeToString,
+ response_deserializer=grpc__test__service__pb2.gRPCTestMessage.FromString,
+ )
class gRPCTestServiceServicer(object):
@@ -27,53 +42,124 @@ class gRPCTestServiceServicer(object):
def TestServe(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
- context.set_details("Method not implemented!")
- raise NotImplementedError("Method not implemented!")
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def TestUnaryStream(self, request, context):
+ """Missing associated documentation comment in .proto file."""
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def TestStreamStream(self, request_iterator, context):
+ """Missing associated documentation comment in .proto file."""
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def TestStreamUnary(self, request_iterator, context):
+ """Missing associated documentation comment in .proto file."""
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
def add_gRPCTestServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
- "TestServe": grpc.unary_unary_rpc_method_handler(
- servicer.TestServe,
- request_deserializer=grpc__test__service__pb2.gRPCTestMessage.FromString,
- response_serializer=grpc__test__service__pb2.gRPCTestMessage.SerializeToString,
- ),
+ 'TestServe': grpc.unary_unary_rpc_method_handler(
+ servicer.TestServe,
+ request_deserializer=grpc__test__service__pb2.gRPCTestMessage.FromString,
+ response_serializer=grpc__test__service__pb2.gRPCTestMessage.SerializeToString,
+ ),
+ 'TestUnaryStream': grpc.unary_stream_rpc_method_handler(
+ servicer.TestUnaryStream,
+ request_deserializer=grpc__test__service__pb2.gRPCTestMessage.FromString,
+ response_serializer=grpc__test__service__pb2.gRPCTestMessage.SerializeToString,
+ ),
+ 'TestStreamStream': grpc.stream_stream_rpc_method_handler(
+ servicer.TestStreamStream,
+ request_deserializer=grpc__test__service__pb2.gRPCTestMessage.FromString,
+ response_serializer=grpc__test__service__pb2.gRPCTestMessage.SerializeToString,
+ ),
+ 'TestStreamUnary': grpc.stream_unary_rpc_method_handler(
+ servicer.TestStreamUnary,
+ request_deserializer=grpc__test__service__pb2.gRPCTestMessage.FromString,
+ response_serializer=grpc__test__service__pb2.gRPCTestMessage.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- "grpc_test_server.gRPCTestService", rpc_method_handlers
- )
+ 'grpc_test_server.gRPCTestService', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
-# This class is part of an EXPERIMENTAL API.
+ # This class is part of an EXPERIMENTAL API.
class gRPCTestService(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
- def TestServe(
- request,
- target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None,
- ):
- return grpc.experimental.unary_unary(
- request,
+ def TestServe(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(request, target, '/grpc_test_server.gRPCTestService/TestServe',
+ grpc__test__service__pb2.gRPCTestMessage.SerializeToString,
+ grpc__test__service__pb2.gRPCTestMessage.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+ @staticmethod
+ def TestUnaryStream(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_stream(request, target, '/grpc_test_server.gRPCTestService/TestUnaryStream',
+ grpc__test__service__pb2.gRPCTestMessage.SerializeToString,
+ grpc__test__service__pb2.gRPCTestMessage.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+ @staticmethod
+ def TestStreamStream(request_iterator,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.stream_stream(request_iterator, target, '/grpc_test_server.gRPCTestService/TestStreamStream',
+ grpc__test__service__pb2.gRPCTestMessage.SerializeToString,
+ grpc__test__service__pb2.gRPCTestMessage.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+ @staticmethod
+ def TestStreamUnary(request_iterator,
target,
- "/grpc_test_server.gRPCTestService/TestServe",
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.stream_unary(request_iterator, target, '/grpc_test_server.gRPCTestService/TestStreamUnary',
grpc__test__service__pb2.gRPCTestMessage.SerializeToString,
grpc__test__service__pb2.gRPCTestMessage.FromString,
- options,
- channel_credentials,
- insecure,
- call_credentials,
- compression,
- wait_for_ready,
- timeout,
- metadata,
- )
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
diff --git a/tests/integrations/grpc/protos/grpc_test_service.proto b/tests/integrations/grpc/protos/grpc_test_service.proto
new file mode 100644
index 0000000000..9eba747218
--- /dev/null
+++ b/tests/integrations/grpc/protos/grpc_test_service.proto
@@ -0,0 +1,14 @@
+syntax = "proto3";
+
+package grpc_test_server;
+
+service gRPCTestService{
+ rpc TestServe(gRPCTestMessage) returns (gRPCTestMessage);
+ rpc TestUnaryStream(gRPCTestMessage) returns (stream gRPCTestMessage);
+ rpc TestStreamStream(stream gRPCTestMessage) returns (stream gRPCTestMessage);
+ rpc TestStreamUnary(stream gRPCTestMessage) returns (gRPCTestMessage);
+}
+
+message gRPCTestMessage {
+ string text = 1;
+}
diff --git a/tests/integrations/grpc/test_grpc.py b/tests/integrations/grpc/test_grpc.py
index 92883e9256..25436d9feb 100644
--- a/tests/integrations/grpc/test_grpc.py
+++ b/tests/integrations/grpc/test_grpc.py
@@ -1,40 +1,101 @@
-from __future__ import absolute_import
-
-import os
-
-from concurrent import futures
-
import grpc
import pytest
-from sentry_sdk import Hub, start_transaction
+from concurrent import futures
+from typing import List, Optional, Tuple
+from unittest.mock import Mock
+
+from sentry_sdk import start_span, start_transaction
from sentry_sdk.consts import OP
+from sentry_sdk.integrations.grpc import GRPCIntegration
from sentry_sdk.integrations.grpc.client import ClientInterceptor
-from sentry_sdk.integrations.grpc.server import ServerInterceptor
+from tests.conftest import ApproxDict
from tests.integrations.grpc.grpc_test_service_pb2 import gRPCTestMessage
from tests.integrations.grpc.grpc_test_service_pb2_grpc import (
- gRPCTestServiceServicer,
add_gRPCTestServiceServicer_to_server,
+ gRPCTestServiceServicer,
gRPCTestServiceStub,
)
-PORT = 50051
-PORT += os.getpid() % 100 # avoid port conflicts when running tests in parallel
+
+# Set up in-memory channel instead of network-based
+def _set_up(
+ interceptors: Optional[List[grpc.ServerInterceptor]] = None,
+) -> Tuple[grpc.Server, grpc.Channel]:
+ """
+ Sets up a gRPC server and returns both the server and a channel connected to it.
+ This eliminates network dependencies and makes tests more reliable.
+ """
+ # Create server with thread pool
+ server = grpc.server(
+ futures.ThreadPoolExecutor(max_workers=2),
+ interceptors=interceptors,
+ )
+
+ # Add our test service to the server
+ servicer = TestService()
+ add_gRPCTestServiceServicer_to_server(servicer, server)
+
+ # Use dynamic port allocation instead of hardcoded port
+ port = server.add_insecure_port("[::]:0") # Let gRPC choose an available port
+ server.start()
+
+ # Create channel connected to our server
+ channel = grpc.insecure_channel(f"localhost:{port}") # noqa: E231
+
+ return server, channel
+
+
+def _tear_down(server: grpc.Server):
+ server.stop(grace=None) # Immediate shutdown
@pytest.mark.forked
def test_grpc_server_starts_transaction(sentry_init, capture_events_forksafe):
- sentry_init(traces_sample_rate=1.0)
+ sentry_init(traces_sample_rate=1.0, integrations=[GRPCIntegration()])
events = capture_events_forksafe()
- server = _set_up()
+ server, channel = _set_up()
- with grpc.insecure_channel(f"localhost:{PORT}") as channel:
- stub = gRPCTestServiceStub(channel)
- stub.TestServe(gRPCTestMessage(text="test"))
+ # Use the provided channel
+ stub = gRPCTestServiceStub(channel)
+ stub.TestServe(gRPCTestMessage(text="test"))
+
+ _tear_down(server=server)
+
+ events.write_file.close()
+ event = events.read_event()
+ span = event["spans"][0]
+
+ assert event["type"] == "transaction"
+ assert event["transaction_info"] == {
+ "source": "custom",
+ }
+ assert event["contexts"]["trace"]["op"] == OP.GRPC_SERVER
+ assert span["op"] == "test"
+
+
+@pytest.mark.forked
+def test_grpc_server_other_interceptors(sentry_init, capture_events_forksafe):
+ """Ensure compatibility with additional server interceptors."""
+ sentry_init(traces_sample_rate=1.0, integrations=[GRPCIntegration()])
+ events = capture_events_forksafe()
+ mock_intercept = lambda continuation, handler_call_details: continuation(
+ handler_call_details
+ )
+ mock_interceptor = Mock()
+ mock_interceptor.intercept_service.side_effect = mock_intercept
+
+ server, channel = _set_up(interceptors=[mock_interceptor])
+
+ # Use the provided channel
+ stub = gRPCTestServiceStub(channel)
+ stub.TestServe(gRPCTestMessage(text="test"))
_tear_down(server=server)
+ mock_interceptor.intercept_service.assert_called_once()
+
events.write_file.close()
event = events.read_event()
span = event["spans"][0]
@@ -49,33 +110,33 @@ def test_grpc_server_starts_transaction(sentry_init, capture_events_forksafe):
@pytest.mark.forked
def test_grpc_server_continues_transaction(sentry_init, capture_events_forksafe):
- sentry_init(traces_sample_rate=1.0)
+ sentry_init(traces_sample_rate=1.0, integrations=[GRPCIntegration()])
events = capture_events_forksafe()
- server = _set_up()
+ server, channel = _set_up()
- with grpc.insecure_channel(f"localhost:{PORT}") as channel:
- stub = gRPCTestServiceStub(channel)
+ # Use the provided channel
+ stub = gRPCTestServiceStub(channel)
- with start_transaction() as transaction:
- metadata = (
- (
- "baggage",
- "sentry-trace_id={trace_id},sentry-environment=test,"
- "sentry-transaction=test-transaction,sentry-sample_rate=1.0".format(
- trace_id=transaction.trace_id
- ),
+ with start_transaction() as transaction:
+ metadata = (
+ (
+ "baggage",
+ "sentry-trace_id={trace_id},sentry-environment=test,"
+ "sentry-transaction=test-transaction,sentry-sample_rate=1.0".format(
+ trace_id=transaction.trace_id
),
- (
- "sentry-trace",
- "{trace_id}-{parent_span_id}-{sampled}".format(
- trace_id=transaction.trace_id,
- parent_span_id=transaction.span_id,
- sampled=1,
- ),
+ ),
+ (
+ "sentry-trace",
+ "{trace_id}-{parent_span_id}-{sampled}".format(
+ trace_id=transaction.trace_id,
+ parent_span_id=transaction.span_id,
+ sampled=1,
),
- )
- stub.TestServe(gRPCTestMessage(text="test"), metadata=metadata)
+ ),
+ )
+ stub.TestServe(gRPCTestMessage(text="test"), metadata=metadata)
_tear_down(server=server)
@@ -94,18 +155,16 @@ def test_grpc_server_continues_transaction(sentry_init, capture_events_forksafe)
@pytest.mark.forked
def test_grpc_client_starts_span(sentry_init, capture_events_forksafe):
- sentry_init(traces_sample_rate=1.0)
+ sentry_init(traces_sample_rate=1.0, integrations=[GRPCIntegration()])
events = capture_events_forksafe()
- interceptors = [ClientInterceptor()]
- server = _set_up()
+ server, channel = _set_up()
- with grpc.insecure_channel(f"localhost:{PORT}") as channel:
- channel = grpc.intercept_channel(channel, *interceptors)
- stub = gRPCTestServiceStub(channel)
+ # Use the provided channel
+ stub = gRPCTestServiceStub(channel)
- with start_transaction():
- stub.TestServe(gRPCTestMessage(text="test"))
+ with start_transaction():
+ stub.TestServe(gRPCTestMessage(text="test"))
_tear_down(server=server)
@@ -120,29 +179,147 @@ def test_grpc_client_starts_span(sentry_init, capture_events_forksafe):
span["description"]
== "unary unary call to /grpc_test_server.gRPCTestService/TestServe"
)
- assert span["data"] == {
- "type": "unary unary",
- "method": "/grpc_test_server.gRPCTestService/TestServe",
- "code": "OK",
- }
+ assert span["data"] == ApproxDict(
+ {
+ "type": "unary unary",
+ "method": "/grpc_test_server.gRPCTestService/TestServe",
+ "code": "OK",
+ }
+ )
+
+
+@pytest.mark.forked
+def test_grpc_client_unary_stream_starts_span(sentry_init, capture_events_forksafe):
+ sentry_init(traces_sample_rate=1.0, integrations=[GRPCIntegration()])
+ events = capture_events_forksafe()
+
+ server, channel = _set_up()
+
+ # Use the provided channel
+ stub = gRPCTestServiceStub(channel)
+
+ with start_transaction():
+ [el for el in stub.TestUnaryStream(gRPCTestMessage(text="test"))]
+
+ _tear_down(server=server)
+
+ events.write_file.close()
+ local_transaction = events.read_event()
+ span = local_transaction["spans"][0]
+
+ assert len(local_transaction["spans"]) == 1
+ assert span["op"] == OP.GRPC_CLIENT
+ assert (
+ span["description"]
+ == "unary stream call to /grpc_test_server.gRPCTestService/TestUnaryStream"
+ )
+ assert span["data"] == ApproxDict(
+ {
+ "type": "unary stream",
+ "method": "/grpc_test_server.gRPCTestService/TestUnaryStream",
+ }
+ )
+
+
+# using unittest.mock.Mock not possible because grpc verifies
+# that the interceptor is of the correct type
+class MockClientInterceptor(grpc.UnaryUnaryClientInterceptor):
+ call_counter = 0
+
+ def intercept_unary_unary(self, continuation, client_call_details, request):
+ self.__class__.call_counter += 1
+ return continuation(client_call_details, request)
+
+
+@pytest.mark.forked
+def test_grpc_client_other_interceptor(sentry_init, capture_events_forksafe):
+ """Ensure compatibility with additional client interceptors."""
+ sentry_init(traces_sample_rate=1.0, integrations=[GRPCIntegration()])
+ events = capture_events_forksafe()
+
+ server, channel = _set_up()
+
+ # Intercept the channel
+ channel = grpc.intercept_channel(channel, MockClientInterceptor())
+ stub = gRPCTestServiceStub(channel)
+
+ with start_transaction():
+ stub.TestServe(gRPCTestMessage(text="test"))
+
+ _tear_down(server=server)
+
+ assert MockClientInterceptor.call_counter == 1
+
+ events.write_file.close()
+ events.read_event()
+ local_transaction = events.read_event()
+ span = local_transaction["spans"][0]
+
+ assert len(local_transaction["spans"]) == 1
+ assert span["op"] == OP.GRPC_CLIENT
+ assert (
+ span["description"]
+ == "unary unary call to /grpc_test_server.gRPCTestService/TestServe"
+ )
+ assert span["data"] == ApproxDict(
+ {
+ "type": "unary unary",
+ "method": "/grpc_test_server.gRPCTestService/TestServe",
+ "code": "OK",
+ }
+ )
+
+
+@pytest.mark.forked
+def test_prevent_dual_client_interceptor(sentry_init, capture_events_forksafe):
+ sentry_init(traces_sample_rate=1.0, integrations=[GRPCIntegration()])
+ events = capture_events_forksafe()
+
+ server, channel = _set_up()
+
+ # Intercept the channel
+ channel = grpc.intercept_channel(channel, ClientInterceptor())
+ stub = gRPCTestServiceStub(channel)
+
+ with start_transaction():
+ stub.TestServe(gRPCTestMessage(text="test"))
+
+ _tear_down(server=server)
+
+ events.write_file.close()
+ events.read_event()
+ local_transaction = events.read_event()
+ span = local_transaction["spans"][0]
+
+ assert len(local_transaction["spans"]) == 1
+ assert span["op"] == OP.GRPC_CLIENT
+ assert (
+ span["description"]
+ == "unary unary call to /grpc_test_server.gRPCTestService/TestServe"
+ )
+ assert span["data"] == ApproxDict(
+ {
+ "type": "unary unary",
+ "method": "/grpc_test_server.gRPCTestService/TestServe",
+ "code": "OK",
+ }
+ )
@pytest.mark.forked
def test_grpc_client_and_servers_interceptors_integration(
sentry_init, capture_events_forksafe
):
- sentry_init(traces_sample_rate=1.0)
+ sentry_init(traces_sample_rate=1.0, integrations=[GRPCIntegration()])
events = capture_events_forksafe()
- interceptors = [ClientInterceptor()]
- server = _set_up()
+ server, channel = _set_up()
- with grpc.insecure_channel(f"localhost:{PORT}") as channel:
- channel = grpc.intercept_channel(channel, *interceptors)
- stub = gRPCTestServiceStub(channel)
+ # Use the provided channel
+ stub = gRPCTestServiceStub(channel)
- with start_transaction():
- stub.TestServe(gRPCTestMessage(text="test"))
+ with start_transaction():
+ stub.TestServe(gRPCTestMessage(text="test"))
_tear_down(server=server)
@@ -156,25 +333,67 @@ def test_grpc_client_and_servers_interceptors_integration(
)
-def _set_up():
- server = grpc.server(
- futures.ThreadPoolExecutor(max_workers=2),
- interceptors=[ServerInterceptor(find_name=_find_name)],
- )
+@pytest.mark.forked
+def test_stream_stream(sentry_init):
+ sentry_init(traces_sample_rate=1.0, integrations=[GRPCIntegration()])
+ server, channel = _set_up()
- add_gRPCTestServiceServicer_to_server(TestService, server)
- server.add_insecure_port(f"[::]:{PORT}")
- server.start()
+ # Use the provided channel
+ stub = gRPCTestServiceStub(channel)
+ response_iterator = stub.TestStreamStream(iter((gRPCTestMessage(text="test"),)))
+ for response in response_iterator:
+ assert response.text == "test"
- return server
+ _tear_down(server=server)
-def _tear_down(server: grpc.Server):
- server.stop(None)
+@pytest.mark.forked
+def test_stream_unary(sentry_init):
+ """
+ Test to verify stream-stream works.
+ Tracing not supported for it yet.
+ """
+ sentry_init(traces_sample_rate=1.0, integrations=[GRPCIntegration()])
+ server, channel = _set_up()
+
+ # Use the provided channel
+ stub = gRPCTestServiceStub(channel)
+ response = stub.TestStreamUnary(iter((gRPCTestMessage(text="test"),)))
+ assert response.text == "test"
+ _tear_down(server=server)
-def _find_name(request):
- return request.__class__
+
+@pytest.mark.forked
+def test_span_origin(sentry_init, capture_events_forksafe):
+ sentry_init(traces_sample_rate=1.0, integrations=[GRPCIntegration()])
+ events = capture_events_forksafe()
+
+ server, channel = _set_up()
+
+ # Use the provided channel
+ stub = gRPCTestServiceStub(channel)
+
+ with start_transaction(name="custom_transaction"):
+ stub.TestServe(gRPCTestMessage(text="test"))
+
+ _tear_down(server=server)
+
+ events.write_file.close()
+
+ transaction_from_integration = events.read_event()
+ custom_transaction = events.read_event()
+
+ assert (
+ transaction_from_integration["contexts"]["trace"]["origin"] == "auto.grpc.grpc"
+ )
+ assert (
+ transaction_from_integration["spans"][0]["origin"]
+ == "auto.grpc.grpc.TestService"
+ ) # manually created in TestService, not the instrumentation
+
+ assert custom_transaction["contexts"]["trace"]["origin"] == "manual"
+ assert custom_transaction["spans"][0]["origin"] == "auto.grpc.grpc"
class TestService(gRPCTestServiceServicer):
@@ -182,8 +401,26 @@ class TestService(gRPCTestServiceServicer):
@staticmethod
def TestServe(request, context): # noqa: N802
- hub = Hub.current
- with hub.start_span(op="test", description="test"):
+ with start_span(
+ op="test",
+ name="test",
+ origin="auto.grpc.grpc.TestService",
+ ):
pass
return gRPCTestMessage(text=request.text)
+
+ @staticmethod
+ def TestUnaryStream(request, context): # noqa: N802
+ for _ in range(3):
+ yield gRPCTestMessage(text=request.text)
+
+ @staticmethod
+ def TestStreamStream(request, context): # noqa: N802
+ for r in request:
+ yield r
+
+ @staticmethod
+ def TestStreamUnary(request, context): # noqa: N802
+ requests = [r for r in request]
+ return requests.pop()
diff --git a/tests/integrations/grpc/test_grpc_aio.py b/tests/integrations/grpc/test_grpc_aio.py
new file mode 100644
index 0000000000..96e9a4dba8
--- /dev/null
+++ b/tests/integrations/grpc/test_grpc_aio.py
@@ -0,0 +1,335 @@
+import asyncio
+
+import grpc
+import pytest
+import pytest_asyncio
+import sentry_sdk
+
+from sentry_sdk import start_span, start_transaction
+from sentry_sdk.consts import OP
+from sentry_sdk.integrations.grpc import GRPCIntegration
+from tests.conftest import ApproxDict
+from tests.integrations.grpc.grpc_test_service_pb2 import gRPCTestMessage
+from tests.integrations.grpc.grpc_test_service_pb2_grpc import (
+ add_gRPCTestServiceServicer_to_server,
+ gRPCTestServiceServicer,
+ gRPCTestServiceStub,
+)
+
+
+@pytest_asyncio.fixture(scope="function")
+async def grpc_server_and_channel(sentry_init):
+ """
+ Creates an async gRPC server and a channel connected to it.
+ Returns both for use in tests, and cleans up afterward.
+ """
+ sentry_init(traces_sample_rate=1.0, integrations=[GRPCIntegration()])
+
+ # Create server
+ server = grpc.aio.server()
+
+ # Let gRPC choose a free port instead of hardcoding it
+ port = server.add_insecure_port("[::]:0")
+
+ # Add service implementation
+ add_gRPCTestServiceServicer_to_server(TestService, server)
+
+ # Start the server
+ await asyncio.create_task(server.start())
+
+ # Create channel connected to our server
+ channel = grpc.aio.insecure_channel(f"localhost:{port}") # noqa: E231
+
+ try:
+ yield server, channel
+ finally:
+ # Clean up resources
+ await channel.close()
+ await server.stop(None)
+
+
+@pytest.mark.asyncio
+async def test_noop_for_unimplemented_method(sentry_init, capture_events):
+ sentry_init(traces_sample_rate=1.0, integrations=[GRPCIntegration()])
+
+ # Create empty server with no services
+ server = grpc.aio.server()
+ port = server.add_insecure_port("[::]:0") # Let gRPC choose a free port
+ await asyncio.create_task(server.start())
+
+ events = capture_events()
+
+ try:
+ async with grpc.aio.insecure_channel(
+ f"localhost:{port}" # noqa: E231
+ ) as channel:
+ stub = gRPCTestServiceStub(channel)
+ with pytest.raises(grpc.RpcError) as exc:
+ await stub.TestServe(gRPCTestMessage(text="test"))
+ assert exc.value.details() == "Method not found!"
+ finally:
+ await server.stop(None)
+
+ assert not events
+
+
+@pytest.mark.asyncio
+async def test_grpc_server_starts_transaction(grpc_server_and_channel, capture_events):
+ _, channel = grpc_server_and_channel
+ events = capture_events()
+
+ # Use the provided channel
+ stub = gRPCTestServiceStub(channel)
+ await stub.TestServe(gRPCTestMessage(text="test"))
+
+ (event,) = events
+ span = event["spans"][0]
+
+ assert event["type"] == "transaction"
+ assert event["transaction_info"] == {
+ "source": "custom",
+ }
+ assert event["contexts"]["trace"]["op"] == OP.GRPC_SERVER
+ assert span["op"] == "test"
+
+
+@pytest.mark.asyncio
+async def test_grpc_server_continues_transaction(
+ grpc_server_and_channel, capture_events
+):
+ _, channel = grpc_server_and_channel
+ events = capture_events()
+
+ # Use the provided channel
+ stub = gRPCTestServiceStub(channel)
+
+ with sentry_sdk.start_transaction() as transaction:
+ metadata = (
+ (
+ "baggage",
+ "sentry-trace_id={trace_id},sentry-environment=test,"
+ "sentry-transaction=test-transaction,sentry-sample_rate=1.0".format(
+ trace_id=transaction.trace_id
+ ),
+ ),
+ (
+ "sentry-trace",
+ "{trace_id}-{parent_span_id}-{sampled}".format(
+ trace_id=transaction.trace_id,
+ parent_span_id=transaction.span_id,
+ sampled=1,
+ ),
+ ),
+ )
+
+ await stub.TestServe(gRPCTestMessage(text="test"), metadata=metadata)
+
+ (event, _) = events
+ span = event["spans"][0]
+
+ assert event["type"] == "transaction"
+ assert event["transaction_info"] == {
+ "source": "custom",
+ }
+ assert event["contexts"]["trace"]["op"] == OP.GRPC_SERVER
+ assert event["contexts"]["trace"]["trace_id"] == transaction.trace_id
+ assert span["op"] == "test"
+
+
+@pytest.mark.asyncio
+async def test_grpc_server_exception(grpc_server_and_channel, capture_events):
+ _, channel = grpc_server_and_channel
+ events = capture_events()
+
+ # Use the provided channel
+ stub = gRPCTestServiceStub(channel)
+ try:
+ await stub.TestServe(gRPCTestMessage(text="exception"))
+ raise AssertionError()
+ except Exception:
+ pass
+
+ (event, _) = events
+
+ assert event["exception"]["values"][0]["type"] == "TestService.TestException"
+ assert event["exception"]["values"][0]["value"] == "test"
+ assert event["exception"]["values"][0]["mechanism"]["handled"] is False
+ assert event["exception"]["values"][0]["mechanism"]["type"] == "grpc"
+
+
+@pytest.mark.asyncio
+async def test_grpc_server_abort(grpc_server_and_channel, capture_events):
+ _, channel = grpc_server_and_channel
+ events = capture_events()
+
+ # Use the provided channel
+ stub = gRPCTestServiceStub(channel)
+ try:
+ await stub.TestServe(gRPCTestMessage(text="abort"))
+ raise AssertionError()
+ except Exception:
+ pass
+
+ # Add a small delay to allow events to be collected
+ await asyncio.sleep(0.1)
+
+ assert len(events) == 1
+
+
+@pytest.mark.asyncio
+async def test_grpc_client_starts_span(
+ grpc_server_and_channel, capture_events_forksafe
+):
+ _, channel = grpc_server_and_channel
+ events = capture_events_forksafe()
+
+ # Use the provided channel
+ stub = gRPCTestServiceStub(channel)
+ with start_transaction():
+ await stub.TestServe(gRPCTestMessage(text="test"))
+
+ events.write_file.close()
+ events.read_event()
+ local_transaction = events.read_event()
+ span = local_transaction["spans"][0]
+
+ assert len(local_transaction["spans"]) == 1
+ assert span["op"] == OP.GRPC_CLIENT
+ assert (
+ span["description"]
+ == "unary unary call to /grpc_test_server.gRPCTestService/TestServe"
+ )
+ assert span["data"] == ApproxDict(
+ {
+ "type": "unary unary",
+ "method": "/grpc_test_server.gRPCTestService/TestServe",
+ "code": "OK",
+ }
+ )
+
+
+@pytest.mark.asyncio
+async def test_grpc_client_unary_stream_starts_span(
+ grpc_server_and_channel, capture_events_forksafe
+):
+ _, channel = grpc_server_and_channel
+ events = capture_events_forksafe()
+
+ # Use the provided channel
+ stub = gRPCTestServiceStub(channel)
+ with start_transaction():
+ response = stub.TestUnaryStream(gRPCTestMessage(text="test"))
+ [_ async for _ in response]
+
+ events.write_file.close()
+ local_transaction = events.read_event()
+ span = local_transaction["spans"][0]
+
+ assert len(local_transaction["spans"]) == 1
+ assert span["op"] == OP.GRPC_CLIENT
+ assert (
+ span["description"]
+ == "unary stream call to /grpc_test_server.gRPCTestService/TestUnaryStream"
+ )
+ assert span["data"] == ApproxDict(
+ {
+ "type": "unary stream",
+ "method": "/grpc_test_server.gRPCTestService/TestUnaryStream",
+ }
+ )
+
+
+@pytest.mark.asyncio
+async def test_stream_stream(grpc_server_and_channel):
+ """
+ Test to verify stream-stream works.
+ Tracing not supported for it yet.
+ """
+ _, channel = grpc_server_and_channel
+
+ # Use the provided channel
+ stub = gRPCTestServiceStub(channel)
+ response = stub.TestStreamStream((gRPCTestMessage(text="test"),))
+ async for r in response:
+ assert r.text == "test"
+
+
+@pytest.mark.asyncio
+async def test_stream_unary(grpc_server_and_channel):
+ """
+ Test to verify stream-stream works.
+ Tracing not supported for it yet.
+ """
+ _, channel = grpc_server_and_channel
+
+ # Use the provided channel
+ stub = gRPCTestServiceStub(channel)
+ response = await stub.TestStreamUnary((gRPCTestMessage(text="test"),))
+ assert response.text == "test"
+
+
+@pytest.mark.asyncio
+async def test_span_origin(grpc_server_and_channel, capture_events_forksafe):
+ _, channel = grpc_server_and_channel
+ events = capture_events_forksafe()
+
+ # Use the provided channel
+ stub = gRPCTestServiceStub(channel)
+ with start_transaction(name="custom_transaction"):
+ await stub.TestServe(gRPCTestMessage(text="test"))
+
+ events.write_file.close()
+
+ transaction_from_integration = events.read_event()
+ custom_transaction = events.read_event()
+
+ assert (
+ transaction_from_integration["contexts"]["trace"]["origin"] == "auto.grpc.grpc"
+ )
+ assert (
+ transaction_from_integration["spans"][0]["origin"]
+ == "auto.grpc.grpc.TestService.aio"
+ ) # manually created in TestService, not the instrumentation
+
+ assert custom_transaction["contexts"]["trace"]["origin"] == "manual"
+ assert custom_transaction["spans"][0]["origin"] == "auto.grpc.grpc"
+
+
+class TestService(gRPCTestServiceServicer):
+ class TestException(Exception):
+ __test__ = False
+
+ def __init__(self):
+ super().__init__("test")
+
+ @classmethod
+ async def TestServe(cls, request, context): # noqa: N802
+ with start_span(
+ op="test",
+ name="test",
+ origin="auto.grpc.grpc.TestService.aio",
+ ):
+ pass
+
+ if request.text == "exception":
+ raise cls.TestException()
+
+ if request.text == "abort":
+ await context.abort(grpc.StatusCode.ABORTED, "Aborted!")
+
+ return gRPCTestMessage(text=request.text)
+
+ @classmethod
+ async def TestUnaryStream(cls, request, context): # noqa: N802
+ for _ in range(3):
+ yield gRPCTestMessage(text=request.text)
+
+ @classmethod
+ async def TestStreamStream(cls, request, context): # noqa: N802
+ async for r in request:
+ yield r
+
+ @classmethod
+ async def TestStreamUnary(cls, request, context): # noqa: N802
+ requests = [r async for r in request]
+ return requests.pop()
diff --git a/tests/integrations/httpx/__init__.py b/tests/integrations/httpx/__init__.py
index 1afd90ea3a..e524321b8b 100644
--- a/tests/integrations/httpx/__init__.py
+++ b/tests/integrations/httpx/__init__.py
@@ -1,3 +1,9 @@
+import os
+import sys
import pytest
pytest.importorskip("httpx")
+
+# Load `httpx_helpers` into the module search path to test request source path names relative to module. See
+# `test_request_source_with_module_in_search_path`
+sys.path.insert(0, os.path.join(os.path.dirname(__file__)))
diff --git a/tests/integrations/httpx/httpx_helpers/__init__.py b/tests/integrations/httpx/httpx_helpers/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/integrations/httpx/httpx_helpers/helpers.py b/tests/integrations/httpx/httpx_helpers/helpers.py
new file mode 100644
index 0000000000..f1d4f3c98b
--- /dev/null
+++ b/tests/integrations/httpx/httpx_helpers/helpers.py
@@ -0,0 +1,6 @@
+def get_request_with_client(client, url):
+ client.get(url)
+
+
+async def async_get_request_with_client(client, url):
+ await client.get(url)
diff --git a/tests/integrations/httpx/test_httpx.py b/tests/integrations/httpx/test_httpx.py
index 74b15b8958..33bdc93c73 100644
--- a/tests/integrations/httpx/test_httpx.py
+++ b/tests/integrations/httpx/test_httpx.py
@@ -1,19 +1,26 @@
+import os
+import datetime
import asyncio
+from unittest import mock
-import pytest
import httpx
-import responses
+import pytest
+from contextlib import contextmanager
+import sentry_sdk
from sentry_sdk import capture_message, start_transaction
-from sentry_sdk.consts import MATCH_ALL
+from sentry_sdk.consts import MATCH_ALL, SPANDATA
from sentry_sdk.integrations.httpx import HttpxIntegration
+from tests.conftest import ApproxDict
@pytest.mark.parametrize(
"httpx_client",
(httpx.Client(), httpx.AsyncClient()),
)
-def test_crumb_capture_and_hint(sentry_init, capture_events, httpx_client):
+def test_crumb_capture_and_hint(sentry_init, capture_events, httpx_client, httpx_mock):
+ httpx_mock.add_response()
+
def before_breadcrumb(crumb, hint):
crumb["data"]["extra"] = "foo"
return crumb
@@ -21,7 +28,6 @@ def before_breadcrumb(crumb, hint):
sentry_init(integrations=[HttpxIntegration()], before_breadcrumb=before_breadcrumb)
url = "http://example.com/"
- responses.add(responses.GET, url, status=200)
with start_transaction():
events = capture_events()
@@ -41,26 +47,90 @@ def before_breadcrumb(crumb, hint):
crumb = event["breadcrumbs"]["values"][0]
assert crumb["type"] == "http"
assert crumb["category"] == "httplib"
- assert crumb["data"] == {
- "url": url,
- "method": "GET",
- "http.fragment": "",
- "http.query": "",
- "status_code": 200,
- "reason": "OK",
- "extra": "foo",
- }
+ assert crumb["data"] == ApproxDict(
+ {
+ "url": url,
+ SPANDATA.HTTP_METHOD: "GET",
+ SPANDATA.HTTP_FRAGMENT: "",
+ SPANDATA.HTTP_QUERY: "",
+ SPANDATA.HTTP_STATUS_CODE: 200,
+ "reason": "OK",
+ "extra": "foo",
+ }
+ )
+
+
+@pytest.mark.parametrize(
+ "httpx_client",
+ (httpx.Client(), httpx.AsyncClient()),
+)
+@pytest.mark.parametrize(
+ "status_code,level",
+ [
+ (200, None),
+ (301, None),
+ (403, "warning"),
+ (405, "warning"),
+ (500, "error"),
+ ],
+)
+def test_crumb_capture_client_error(
+ sentry_init, capture_events, httpx_client, httpx_mock, status_code, level
+):
+ httpx_mock.add_response(status_code=status_code)
+
+ sentry_init(integrations=[HttpxIntegration()])
+
+ url = "http://example.com/"
+
+ with start_transaction():
+ events = capture_events()
+
+ if asyncio.iscoroutinefunction(httpx_client.get):
+ response = asyncio.get_event_loop().run_until_complete(
+ httpx_client.get(url)
+ )
+ else:
+ response = httpx_client.get(url)
+
+ assert response.status_code == status_code
+ capture_message("Testing!")
+
+ (event,) = events
+
+ crumb = event["breadcrumbs"]["values"][0]
+ assert crumb["type"] == "http"
+ assert crumb["category"] == "httplib"
+
+ if level is None:
+ assert "level" not in crumb
+ else:
+ assert crumb["level"] == level
+
+ assert crumb["data"] == ApproxDict(
+ {
+ "url": url,
+ SPANDATA.HTTP_METHOD: "GET",
+ SPANDATA.HTTP_FRAGMENT: "",
+ SPANDATA.HTTP_QUERY: "",
+ SPANDATA.HTTP_STATUS_CODE: status_code,
+ }
+ )
@pytest.mark.parametrize(
"httpx_client",
(httpx.Client(), httpx.AsyncClient()),
)
-def test_outgoing_trace_headers(sentry_init, httpx_client):
- sentry_init(traces_sample_rate=1.0, integrations=[HttpxIntegration()])
+def test_outgoing_trace_headers(sentry_init, httpx_client, httpx_mock):
+ httpx_mock.add_response()
+
+ sentry_init(
+ traces_sample_rate=1.0,
+ integrations=[HttpxIntegration()],
+ )
url = "http://example.com/"
- responses.add(responses.GET, url, status=200)
with start_transaction(
name="/interactions/other-dogs/new-dog",
@@ -84,6 +154,53 @@ def test_outgoing_trace_headers(sentry_init, httpx_client):
)
+@pytest.mark.parametrize(
+ "httpx_client",
+ (httpx.Client(), httpx.AsyncClient()),
+)
+def test_outgoing_trace_headers_append_to_baggage(
+ sentry_init,
+ httpx_client,
+ httpx_mock,
+):
+ httpx_mock.add_response()
+
+ sentry_init(
+ traces_sample_rate=1.0,
+ integrations=[HttpxIntegration()],
+ release="d08ebdb9309e1b004c6f52202de58a09c2268e42",
+ )
+
+ url = "http://example.com/"
+
+ # patch random.randrange to return a predictable sample_rand value
+ with mock.patch("sentry_sdk.tracing_utils.Random.randrange", return_value=500000):
+ with start_transaction(
+ name="/interactions/other-dogs/new-dog",
+ op="greeting.sniff",
+ trace_id="01234567890123456789012345678901",
+ ) as transaction:
+ if asyncio.iscoroutinefunction(httpx_client.get):
+ response = asyncio.get_event_loop().run_until_complete(
+ httpx_client.get(url, headers={"baGGage": "custom=data"})
+ )
+ else:
+ response = httpx_client.get(url, headers={"baGGage": "custom=data"})
+
+ request_span = transaction._span_recorder.spans[-1]
+ assert response.request.headers[
+ "sentry-trace"
+ ] == "{trace_id}-{parent_span_id}-{sampled}".format(
+ trace_id=transaction.trace_id,
+ parent_span_id=request_span.span_id,
+ sampled=1,
+ )
+ assert (
+ response.request.headers["baggage"]
+ == "custom=data,sentry-trace_id=01234567890123456789012345678901,sentry-sample_rand=0.500000,sentry-environment=production,sentry-release=d08ebdb9309e1b004c6f52202de58a09c2268e42,sentry-transaction=/interactions/other-dogs/new-dog,sentry-sample_rate=1.0,sentry-sampled=true"
+ )
+
+
@pytest.mark.parametrize(
"httpx_client,trace_propagation_targets,url,trace_propagated",
[
@@ -214,10 +331,12 @@ def test_option_trace_propagation_targets(
integrations=[HttpxIntegration()],
)
- if asyncio.iscoroutinefunction(httpx_client.get):
- asyncio.get_event_loop().run_until_complete(httpx_client.get(url))
- else:
- httpx_client.get(url)
+ # Must be in a transaction to propagate headers
+ with sentry_sdk.start_transaction():
+ if asyncio.iscoroutinefunction(httpx_client.get):
+ asyncio.get_event_loop().run_until_complete(httpx_client.get(url))
+ else:
+ httpx_client.get(url)
request_headers = httpx_mock.get_request().headers
@@ -225,3 +344,389 @@ def test_option_trace_propagation_targets(
assert "sentry-trace" in request_headers
else:
assert "sentry-trace" not in request_headers
+
+
+def test_do_not_propagate_outside_transaction(sentry_init, httpx_mock):
+ httpx_mock.add_response()
+
+ sentry_init(
+ traces_sample_rate=1.0,
+ trace_propagation_targets=[MATCH_ALL],
+ integrations=[HttpxIntegration()],
+ )
+
+ httpx_client = httpx.Client()
+ httpx_client.get("http://example.com/")
+
+ request_headers = httpx_mock.get_request().headers
+ assert "sentry-trace" not in request_headers
+
+
+@pytest.mark.tests_internal_exceptions
+def test_omit_url_data_if_parsing_fails(sentry_init, capture_events, httpx_mock):
+ httpx_mock.add_response()
+
+ sentry_init(integrations=[HttpxIntegration()])
+
+ httpx_client = httpx.Client()
+ url = "http://example.com"
+
+ events = capture_events()
+ with mock.patch(
+ "sentry_sdk.integrations.httpx.parse_url",
+ side_effect=ValueError,
+ ):
+ response = httpx_client.get(url)
+
+ assert response.status_code == 200
+ capture_message("Testing!")
+
+ (event,) = events
+ assert event["breadcrumbs"]["values"][0]["data"] == ApproxDict(
+ {
+ SPANDATA.HTTP_METHOD: "GET",
+ SPANDATA.HTTP_STATUS_CODE: 200,
+ "reason": "OK",
+ # no url related data
+ }
+ )
+
+ assert "url" not in event["breadcrumbs"]["values"][0]["data"]
+ assert SPANDATA.HTTP_FRAGMENT not in event["breadcrumbs"]["values"][0]["data"]
+ assert SPANDATA.HTTP_QUERY not in event["breadcrumbs"]["values"][0]["data"]
+
+
+@pytest.mark.parametrize(
+ "httpx_client",
+ (httpx.Client(), httpx.AsyncClient()),
+)
+def test_request_source_disabled(sentry_init, capture_events, httpx_client, httpx_mock):
+ httpx_mock.add_response()
+ sentry_options = {
+ "integrations": [HttpxIntegration()],
+ "traces_sample_rate": 1.0,
+ "enable_http_request_source": False,
+ "http_request_source_threshold_ms": 0,
+ }
+
+ sentry_init(**sentry_options)
+
+ events = capture_events()
+
+ url = "http://example.com/"
+
+ with start_transaction(name="test_transaction"):
+ if asyncio.iscoroutinefunction(httpx_client.get):
+ asyncio.get_event_loop().run_until_complete(httpx_client.get(url))
+ else:
+ httpx_client.get(url)
+
+ (event,) = events
+
+ span = event["spans"][-1]
+ assert span["description"].startswith("GET")
+
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO not in data
+ assert SPANDATA.CODE_NAMESPACE not in data
+ assert SPANDATA.CODE_FILEPATH not in data
+ assert SPANDATA.CODE_FUNCTION not in data
+
+
+@pytest.mark.parametrize("enable_http_request_source", [None, True])
+@pytest.mark.parametrize(
+ "httpx_client",
+ (httpx.Client(), httpx.AsyncClient()),
+)
+def test_request_source_enabled(
+ sentry_init, capture_events, enable_http_request_source, httpx_client, httpx_mock
+):
+ httpx_mock.add_response()
+ sentry_options = {
+ "integrations": [HttpxIntegration()],
+ "traces_sample_rate": 1.0,
+ "http_request_source_threshold_ms": 0,
+ }
+ if enable_http_request_source is not None:
+ sentry_options["enable_http_request_source"] = enable_http_request_source
+
+ sentry_init(**sentry_options)
+
+ events = capture_events()
+
+ url = "http://example.com/"
+
+ with start_transaction(name="test_transaction"):
+ if asyncio.iscoroutinefunction(httpx_client.get):
+ asyncio.get_event_loop().run_until_complete(httpx_client.get(url))
+ else:
+ httpx_client.get(url)
+
+ (event,) = events
+
+ span = event["spans"][-1]
+ assert span["description"].startswith("GET")
+
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO in data
+ assert SPANDATA.CODE_NAMESPACE in data
+ assert SPANDATA.CODE_FILEPATH in data
+ assert SPANDATA.CODE_FUNCTION in data
+
+
+@pytest.mark.parametrize(
+ "httpx_client",
+ (httpx.Client(), httpx.AsyncClient()),
+)
+def test_request_source(sentry_init, capture_events, httpx_client, httpx_mock):
+ httpx_mock.add_response()
+
+ sentry_init(
+ integrations=[HttpxIntegration()],
+ traces_sample_rate=1.0,
+ enable_http_request_source=True,
+ http_request_source_threshold_ms=0,
+ )
+
+ events = capture_events()
+
+ url = "http://example.com/"
+
+ with start_transaction(name="test_transaction"):
+ if asyncio.iscoroutinefunction(httpx_client.get):
+ asyncio.get_event_loop().run_until_complete(httpx_client.get(url))
+ else:
+ httpx_client.get(url)
+
+ (event,) = events
+
+ span = event["spans"][-1]
+ assert span["description"].startswith("GET")
+
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO in data
+ assert SPANDATA.CODE_NAMESPACE in data
+ assert SPANDATA.CODE_FILEPATH in data
+ assert SPANDATA.CODE_FUNCTION in data
+
+ assert type(data.get(SPANDATA.CODE_LINENO)) == int
+ assert data.get(SPANDATA.CODE_LINENO) > 0
+ assert data.get(SPANDATA.CODE_NAMESPACE) == "tests.integrations.httpx.test_httpx"
+ assert data.get(SPANDATA.CODE_FILEPATH).endswith(
+ "tests/integrations/httpx/test_httpx.py"
+ )
+
+ is_relative_path = data.get(SPANDATA.CODE_FILEPATH)[0] != os.sep
+ assert is_relative_path
+
+ assert data.get(SPANDATA.CODE_FUNCTION) == "test_request_source"
+
+
+@pytest.mark.parametrize(
+ "httpx_client",
+ (httpx.Client(), httpx.AsyncClient()),
+)
+def test_request_source_with_module_in_search_path(
+ sentry_init, capture_events, httpx_client, httpx_mock
+):
+ """
+ Test that request source is relative to the path of the module it ran in
+ """
+ httpx_mock.add_response()
+ sentry_init(
+ integrations=[HttpxIntegration()],
+ traces_sample_rate=1.0,
+ enable_http_request_source=True,
+ http_request_source_threshold_ms=0,
+ )
+
+ events = capture_events()
+
+ url = "http://example.com/"
+
+ with start_transaction(name="test_transaction"):
+ if asyncio.iscoroutinefunction(httpx_client.get):
+ from httpx_helpers.helpers import async_get_request_with_client
+
+ asyncio.get_event_loop().run_until_complete(
+ async_get_request_with_client(httpx_client, url)
+ )
+ else:
+ from httpx_helpers.helpers import get_request_with_client
+
+ get_request_with_client(httpx_client, url)
+
+ (event,) = events
+
+ span = event["spans"][-1]
+ assert span["description"].startswith("GET")
+
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO in data
+ assert SPANDATA.CODE_NAMESPACE in data
+ assert SPANDATA.CODE_FILEPATH in data
+ assert SPANDATA.CODE_FUNCTION in data
+
+ assert type(data.get(SPANDATA.CODE_LINENO)) == int
+ assert data.get(SPANDATA.CODE_LINENO) > 0
+ assert data.get(SPANDATA.CODE_NAMESPACE) == "httpx_helpers.helpers"
+ assert data.get(SPANDATA.CODE_FILEPATH) == "httpx_helpers/helpers.py"
+
+ is_relative_path = data.get(SPANDATA.CODE_FILEPATH)[0] != os.sep
+ assert is_relative_path
+
+ if asyncio.iscoroutinefunction(httpx_client.get):
+ assert data.get(SPANDATA.CODE_FUNCTION) == "async_get_request_with_client"
+ else:
+ assert data.get(SPANDATA.CODE_FUNCTION) == "get_request_with_client"
+
+
+@pytest.mark.parametrize(
+ "httpx_client",
+ (httpx.Client(), httpx.AsyncClient()),
+)
+def test_no_request_source_if_duration_too_short(
+ sentry_init, capture_events, httpx_client, httpx_mock
+):
+ httpx_mock.add_response()
+
+ sentry_init(
+ integrations=[HttpxIntegration()],
+ traces_sample_rate=1.0,
+ enable_http_request_source=True,
+ http_request_source_threshold_ms=100,
+ )
+
+ events = capture_events()
+
+ url = "http://example.com/"
+
+ with start_transaction(name="test_transaction"):
+
+ @contextmanager
+ def fake_start_span(*args, **kwargs):
+ with sentry_sdk.start_span(*args, **kwargs) as span:
+ pass
+ span.start_timestamp = datetime.datetime(2024, 1, 1, microsecond=0)
+ span.timestamp = datetime.datetime(2024, 1, 1, microsecond=99999)
+ yield span
+
+ with mock.patch(
+ "sentry_sdk.integrations.httpx.start_span",
+ fake_start_span,
+ ):
+ if asyncio.iscoroutinefunction(httpx_client.get):
+ asyncio.get_event_loop().run_until_complete(httpx_client.get(url))
+ else:
+ httpx_client.get(url)
+
+ (event,) = events
+
+ span = event["spans"][-1]
+ assert span["description"].startswith("GET")
+
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO not in data
+ assert SPANDATA.CODE_NAMESPACE not in data
+ assert SPANDATA.CODE_FILEPATH not in data
+ assert SPANDATA.CODE_FUNCTION not in data
+
+
+@pytest.mark.parametrize(
+ "httpx_client",
+ (httpx.Client(), httpx.AsyncClient()),
+)
+def test_request_source_if_duration_over_threshold(
+ sentry_init, capture_events, httpx_client, httpx_mock
+):
+ httpx_mock.add_response()
+
+ sentry_init(
+ integrations=[HttpxIntegration()],
+ traces_sample_rate=1.0,
+ enable_http_request_source=True,
+ http_request_source_threshold_ms=100,
+ )
+
+ events = capture_events()
+
+ url = "http://example.com/"
+
+ with start_transaction(name="test_transaction"):
+
+ @contextmanager
+ def fake_start_span(*args, **kwargs):
+ with sentry_sdk.start_span(*args, **kwargs) as span:
+ pass
+ span.start_timestamp = datetime.datetime(2024, 1, 1, microsecond=0)
+ span.timestamp = datetime.datetime(2024, 1, 1, microsecond=100001)
+ yield span
+
+ with mock.patch(
+ "sentry_sdk.integrations.httpx.start_span",
+ fake_start_span,
+ ):
+ if asyncio.iscoroutinefunction(httpx_client.get):
+ asyncio.get_event_loop().run_until_complete(httpx_client.get(url))
+ else:
+ httpx_client.get(url)
+
+ (event,) = events
+
+ span = event["spans"][-1]
+ assert span["description"].startswith("GET")
+
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO in data
+ assert SPANDATA.CODE_NAMESPACE in data
+ assert SPANDATA.CODE_FILEPATH in data
+ assert SPANDATA.CODE_FUNCTION in data
+
+ assert type(data.get(SPANDATA.CODE_LINENO)) == int
+ assert data.get(SPANDATA.CODE_LINENO) > 0
+ assert data.get(SPANDATA.CODE_NAMESPACE) == "tests.integrations.httpx.test_httpx"
+ assert data.get(SPANDATA.CODE_FILEPATH).endswith(
+ "tests/integrations/httpx/test_httpx.py"
+ )
+
+ is_relative_path = data.get(SPANDATA.CODE_FILEPATH)[0] != os.sep
+ assert is_relative_path
+
+ assert (
+ data.get(SPANDATA.CODE_FUNCTION)
+ == "test_request_source_if_duration_over_threshold"
+ )
+
+
+@pytest.mark.parametrize(
+ "httpx_client",
+ (httpx.Client(), httpx.AsyncClient()),
+)
+def test_span_origin(sentry_init, capture_events, httpx_client, httpx_mock):
+ httpx_mock.add_response()
+
+ sentry_init(
+ integrations=[HttpxIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ url = "http://example.com/"
+
+ with start_transaction(name="test_transaction"):
+ if asyncio.iscoroutinefunction(httpx_client.get):
+ asyncio.get_event_loop().run_until_complete(httpx_client.get(url))
+ else:
+ httpx_client.get(url)
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+ assert event["spans"][0]["origin"] == "auto.http.httpx"
diff --git a/tests/integrations/huey/test_huey.py b/tests/integrations/huey/test_huey.py
index 819a4816d7..143a369348 100644
--- a/tests/integrations/huey/test_huey.py
+++ b/tests/integrations/huey/test_huey.py
@@ -3,11 +3,16 @@
from sentry_sdk import start_transaction
from sentry_sdk.integrations.huey import HueyIntegration
+from sentry_sdk.utils import parse_version
+from huey import __version__ as HUEY_VERSION
from huey.api import MemoryHuey, Result
from huey.exceptions import RetryTask
+HUEY_VERSION = parse_version(HUEY_VERSION)
+
+
@pytest.fixture
def init_huey(sentry_init):
def inner():
@@ -15,7 +20,6 @@ def inner():
integrations=[HueyIntegration()],
traces_sample_rate=1.0,
send_default_pii=True,
- debug=True,
)
return MemoryHuey(name="sentry_sdk")
@@ -118,6 +122,35 @@ def retry_task(context):
assert len(huey) == 0
+@pytest.mark.parametrize("lock_name", ["lock.a", "lock.b"], ids=["locked", "unlocked"])
+@pytest.mark.skipif(HUEY_VERSION < (2, 5), reason="is_locked was added in 2.5")
+def test_task_lock(capture_events, init_huey, lock_name):
+ huey = init_huey()
+
+ task_lock_name = "lock.a"
+ should_be_locked = task_lock_name == lock_name
+
+ @huey.task()
+ @huey.lock_task(task_lock_name)
+ def maybe_locked_task():
+ pass
+
+ events = capture_events()
+
+ with huey.lock_task(lock_name):
+ assert huey.is_locked(task_lock_name) == should_be_locked
+ result = execute_huey_task(huey, maybe_locked_task)
+
+ (event,) = events
+
+ assert event["transaction"] == "maybe_locked_task"
+ assert event["tags"]["huey_task_id"] == result.task.id
+ assert (
+ event["contexts"]["trace"]["status"] == "aborted" if should_be_locked else "ok"
+ )
+ assert len(huey) == 0
+
+
def test_huey_enqueue(init_huey, capture_events):
huey = init_huey()
@@ -138,3 +171,55 @@ def dummy_task():
assert len(event["spans"])
assert event["spans"][0]["op"] == "queue.submit.huey"
assert event["spans"][0]["description"] == "different_task_name"
+
+
+def test_huey_propagate_trace(init_huey, capture_events):
+ huey = init_huey()
+
+ events = capture_events()
+
+ @huey.task()
+ def propagated_trace_task():
+ pass
+
+ with start_transaction() as outer_transaction:
+ execute_huey_task(huey, propagated_trace_task)
+
+ assert (
+ events[0]["transaction"] == "propagated_trace_task"
+ ) # the "inner" transaction
+ assert events[0]["contexts"]["trace"]["trace_id"] == outer_transaction.trace_id
+
+
+def test_span_origin_producer(init_huey, capture_events):
+ huey = init_huey()
+
+ @huey.task(name="different_task_name")
+ def dummy_task():
+ pass
+
+ events = capture_events()
+
+ with start_transaction():
+ dummy_task()
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+ assert event["spans"][0]["origin"] == "auto.queue.huey"
+
+
+def test_span_origin_consumer(init_huey, capture_events):
+ huey = init_huey()
+
+ events = capture_events()
+
+ @huey.task()
+ def propagated_trace_task():
+ pass
+
+ execute_huey_task(huey, propagated_trace_task)
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "auto.queue.huey"
diff --git a/tests/integrations/huggingface_hub/__init__.py b/tests/integrations/huggingface_hub/__init__.py
new file mode 100644
index 0000000000..fe1fa0af50
--- /dev/null
+++ b/tests/integrations/huggingface_hub/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("huggingface_hub")
diff --git a/tests/integrations/huggingface_hub/test_huggingface_hub.py b/tests/integrations/huggingface_hub/test_huggingface_hub.py
new file mode 100644
index 0000000000..9dd15ca4b5
--- /dev/null
+++ b/tests/integrations/huggingface_hub/test_huggingface_hub.py
@@ -0,0 +1,1017 @@
+import re
+from typing import TYPE_CHECKING
+from unittest import mock
+
+import pytest
+import responses
+from huggingface_hub import InferenceClient
+
+import sentry_sdk
+from sentry_sdk.integrations.huggingface_hub import HuggingfaceHubIntegration
+from sentry_sdk.utils import package_version
+
+try:
+ from huggingface_hub.utils._errors import HfHubHTTPError
+except ImportError:
+ from huggingface_hub.errors import HfHubHTTPError
+
+
+if TYPE_CHECKING:
+ from typing import Any
+
+
+HF_VERSION = package_version("huggingface-hub")
+
+if HF_VERSION and HF_VERSION < (0, 30, 0):
+ MODEL_ENDPOINT = "https://api-inference.huggingface.co/models/{model_name}"
+ INFERENCE_ENDPOINT = "https://api-inference.huggingface.co/models/{model_name}"
+else:
+ MODEL_ENDPOINT = "https://huggingface.co/api/models/{model_name}"
+ INFERENCE_ENDPOINT = (
+ "https://router.huggingface.co/hf-inference/models/{model_name}"
+ )
+
+
+def get_hf_provider_inference_client():
+ # The provider parameter was added in version 0.28.0 of huggingface_hub
+ return (
+ InferenceClient(model="test-model", provider="hf-inference")
+ if HF_VERSION >= (0, 28, 0)
+ else InferenceClient(model="test-model")
+ )
+
+
+def _add_mock_response(
+ httpx_mock, rsps, method, url, json=None, status=200, body=None, headers=None
+):
+ # HF v1+ uses httpx for making requests to their API, while <1 uses requests.
+ # Since we have to test both, we need mocks for both httpx and requests.
+ if HF_VERSION >= (1, 0, 0):
+ httpx_mock.add_response(
+ method=method,
+ url=url,
+ json=json,
+ content=body,
+ status_code=status,
+ headers=headers,
+ is_optional=True,
+ is_reusable=True,
+ )
+ else:
+ rsps.add(
+ method=method,
+ url=url,
+ json=json,
+ body=body,
+ status=status,
+ headers=headers,
+ )
+
+
+@pytest.fixture
+def mock_hf_text_generation_api(httpx_mock):
+ # type: () -> Any
+ """Mock HuggingFace text generation API"""
+
+ with responses.RequestsMock(assert_all_requests_are_fired=False) as rsps:
+ model_name = "test-model"
+
+ _add_mock_response(
+ httpx_mock,
+ rsps,
+ "GET",
+ re.compile(
+ MODEL_ENDPOINT.format(model_name=model_name)
+ + r"(\?expand=inferenceProviderMapping)?"
+ ),
+ json={
+ "id": model_name,
+ "pipeline_tag": "text-generation",
+ "inferenceProviderMapping": {
+ "hf-inference": {
+ "status": "live",
+ "providerId": model_name,
+ "task": "text-generation",
+ }
+ },
+ },
+ status=200,
+ )
+
+ _add_mock_response(
+ httpx_mock,
+ rsps,
+ "POST",
+ INFERENCE_ENDPOINT.format(model_name=model_name),
+ json={
+ "generated_text": "[mocked] Hello! How can i help you?",
+ "details": {
+ "finish_reason": "length",
+ "generated_tokens": 10,
+ "prefill": [],
+ "tokens": [],
+ },
+ },
+ status=200,
+ )
+
+ if HF_VERSION >= (1, 0, 0):
+ yield httpx_mock
+ else:
+ yield rsps
+
+
+@pytest.fixture
+def mock_hf_api_with_errors(httpx_mock):
+ # type: () -> Any
+ """Mock HuggingFace API that always raises errors for any request"""
+
+ with responses.RequestsMock(assert_all_requests_are_fired=False) as rsps:
+ model_name = "test-model"
+
+ # Mock model info endpoint with error
+ _add_mock_response(
+ httpx_mock,
+ rsps,
+ "GET",
+ MODEL_ENDPOINT.format(model_name=model_name),
+ json={"error": "Model not found"},
+ status=404,
+ )
+
+ # Mock text generation endpoint with error
+ _add_mock_response(
+ httpx_mock,
+ rsps,
+ "POST",
+ INFERENCE_ENDPOINT.format(model_name=model_name),
+ json={"error": "Internal server error", "message": "Something went wrong"},
+ status=500,
+ )
+
+ # Mock chat completion endpoint with error
+ _add_mock_response(
+ httpx_mock,
+ rsps,
+ "POST",
+ INFERENCE_ENDPOINT.format(model_name=model_name) + "/v1/chat/completions",
+ json={"error": "Internal server error", "message": "Something went wrong"},
+ status=500,
+ )
+
+ # Catch-all pattern for any other model requests
+ _add_mock_response(
+ httpx_mock,
+ rsps,
+ "GET",
+ "https://huggingface.co/api/models/test-model-error",
+ json={"error": "Generic model error"},
+ status=500,
+ )
+
+ if HF_VERSION >= (1, 0, 0):
+ yield httpx_mock
+ else:
+ yield rsps
+
+
+@pytest.fixture
+def mock_hf_text_generation_api_streaming(httpx_mock):
+ # type: () -> Any
+ """Mock streaming HuggingFace text generation API"""
+ with responses.RequestsMock(assert_all_requests_are_fired=False) as rsps:
+ model_name = "test-model"
+
+ # Mock model info endpoint
+ _add_mock_response(
+ httpx_mock,
+ rsps,
+ "GET",
+ MODEL_ENDPOINT.format(model_name=model_name),
+ json={
+ "id": model_name,
+ "pipeline_tag": "text-generation",
+ "inferenceProviderMapping": {
+ "hf-inference": {
+ "status": "live",
+ "providerId": model_name,
+ "task": "text-generation",
+ }
+ },
+ },
+ status=200,
+ )
+
+ # Mock text generation endpoint for streaming
+ streaming_response = b'data:{"token":{"id":1, "special": false, "text": "the mocked "}}\n\ndata:{"token":{"id":2, "special": false, "text": "model response"}, "details":{"finish_reason": "length", "generated_tokens": 10, "seed": 0}}\n\n'
+
+ _add_mock_response(
+ httpx_mock,
+ rsps,
+ "POST",
+ INFERENCE_ENDPOINT.format(model_name=model_name),
+ body=streaming_response,
+ status=200,
+ headers={
+ "Content-Type": "text/event-stream",
+ "Cache-Control": "no-cache",
+ "Connection": "keep-alive",
+ },
+ )
+
+ if HF_VERSION >= (1, 0, 0):
+ yield httpx_mock
+ else:
+ yield rsps
+
+
+@pytest.fixture
+def mock_hf_chat_completion_api(httpx_mock):
+ # type: () -> Any
+ """Mock HuggingFace chat completion API"""
+ with responses.RequestsMock(assert_all_requests_are_fired=False) as rsps:
+ model_name = "test-model"
+
+ # Mock model info endpoint
+ _add_mock_response(
+ httpx_mock,
+ rsps,
+ "GET",
+ MODEL_ENDPOINT.format(model_name=model_name),
+ json={
+ "id": model_name,
+ "pipeline_tag": "conversational",
+ "inferenceProviderMapping": {
+ "hf-inference": {
+ "status": "live",
+ "providerId": model_name,
+ "task": "conversational",
+ }
+ },
+ },
+ status=200,
+ )
+
+ # Mock chat completion endpoint
+ _add_mock_response(
+ httpx_mock,
+ rsps,
+ "POST",
+ INFERENCE_ENDPOINT.format(model_name=model_name) + "/v1/chat/completions",
+ json={
+ "id": "xyz-123",
+ "created": 1234567890,
+ "model": f"{model_name}-123",
+ "system_fingerprint": "fp_123",
+ "choices": [
+ {
+ "index": 0,
+ "finish_reason": "stop",
+ "message": {
+ "role": "assistant",
+ "content": "[mocked] Hello! How can I help you today?",
+ },
+ }
+ ],
+ "usage": {
+ "completion_tokens": 8,
+ "prompt_tokens": 10,
+ "total_tokens": 18,
+ },
+ },
+ status=200,
+ )
+
+ if HF_VERSION >= (1, 0, 0):
+ yield httpx_mock
+ else:
+ yield rsps
+
+
+@pytest.fixture
+def mock_hf_chat_completion_api_tools(httpx_mock):
+ # type: () -> Any
+ """Mock HuggingFace chat completion API with tool calls."""
+ with responses.RequestsMock(assert_all_requests_are_fired=False) as rsps:
+ model_name = "test-model"
+
+ # Mock model info endpoint
+ _add_mock_response(
+ httpx_mock,
+ rsps,
+ "GET",
+ MODEL_ENDPOINT.format(model_name=model_name),
+ json={
+ "id": model_name,
+ "pipeline_tag": "conversational",
+ "inferenceProviderMapping": {
+ "hf-inference": {
+ "status": "live",
+ "providerId": model_name,
+ "task": "conversational",
+ }
+ },
+ },
+ status=200,
+ )
+
+ # Mock chat completion endpoint
+ _add_mock_response(
+ httpx_mock,
+ rsps,
+ "POST",
+ INFERENCE_ENDPOINT.format(model_name=model_name) + "/v1/chat/completions",
+ json={
+ "id": "xyz-123",
+ "created": 1234567890,
+ "model": f"{model_name}-123",
+ "system_fingerprint": "fp_123",
+ "choices": [
+ {
+ "index": 0,
+ "finish_reason": "tool_calls",
+ "message": {
+ "role": "assistant",
+ "tool_calls": [
+ {
+ "id": "call_123",
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "arguments": {"location": "Paris"},
+ },
+ }
+ ],
+ },
+ }
+ ],
+ "usage": {
+ "completion_tokens": 8,
+ "prompt_tokens": 10,
+ "total_tokens": 18,
+ },
+ },
+ status=200,
+ )
+
+ if HF_VERSION >= (1, 0, 0):
+ yield httpx_mock
+ else:
+ yield rsps
+
+
+@pytest.fixture
+def mock_hf_chat_completion_api_streaming(httpx_mock):
+ # type: () -> Any
+ """Mock streaming HuggingFace chat completion API"""
+ with responses.RequestsMock(assert_all_requests_are_fired=False) as rsps:
+ model_name = "test-model"
+
+ # Mock model info endpoint
+ _add_mock_response(
+ httpx_mock,
+ rsps,
+ "GET",
+ MODEL_ENDPOINT.format(model_name=model_name),
+ json={
+ "id": model_name,
+ "pipeline_tag": "conversational",
+ "inferenceProviderMapping": {
+ "hf-inference": {
+ "status": "live",
+ "providerId": model_name,
+ "task": "conversational",
+ }
+ },
+ },
+ status=200,
+ )
+
+ # Mock chat completion streaming endpoint
+ streaming_chat_response = (
+ b'data:{"id":"xyz-123","created":1234567890,"model":"test-model-123","system_fingerprint":"fp_123","choices":[{"delta":{"role":"assistant","content":"the mocked "},"index":0,"finish_reason":null}],"usage":null}\n\n'
+ b'data:{"id":"xyz-124","created":1234567890,"model":"test-model-123","system_fingerprint":"fp_123","choices":[{"delta":{"role":"assistant","content":"model response"},"index":0,"finish_reason":"stop"}],"usage":{"prompt_tokens":183,"completion_tokens":14,"total_tokens":197}}\n\n'
+ )
+
+ _add_mock_response(
+ httpx_mock,
+ rsps,
+ "POST",
+ INFERENCE_ENDPOINT.format(model_name=model_name) + "/v1/chat/completions",
+ body=streaming_chat_response,
+ status=200,
+ headers={
+ "Content-Type": "text/event-stream",
+ "Cache-Control": "no-cache",
+ "Connection": "keep-alive",
+ },
+ )
+
+ if HF_VERSION >= (1, 0, 0):
+ yield httpx_mock
+ else:
+ yield rsps
+
+
+@pytest.fixture
+def mock_hf_chat_completion_api_streaming_tools(httpx_mock):
+ # type: () -> Any
+ """Mock streaming HuggingFace chat completion API with tool calls."""
+ with responses.RequestsMock(assert_all_requests_are_fired=False) as rsps:
+ model_name = "test-model"
+
+ # Mock model info endpoint
+ _add_mock_response(
+ httpx_mock,
+ rsps,
+ "GET",
+ MODEL_ENDPOINT.format(model_name=model_name),
+ json={
+ "id": model_name,
+ "pipeline_tag": "conversational",
+ "inferenceProviderMapping": {
+ "hf-inference": {
+ "status": "live",
+ "providerId": model_name,
+ "task": "conversational",
+ }
+ },
+ },
+ status=200,
+ )
+
+ # Mock chat completion streaming endpoint
+ streaming_chat_response = (
+ b'data:{"id":"xyz-123","created":1234567890,"model":"test-model-123","system_fingerprint":"fp_123","choices":[{"delta":{"role":"assistant","content":"response with tool calls follows"},"index":0,"finish_reason":null}],"usage":null}\n\n'
+ b'data:{"id":"xyz-124","created":1234567890,"model":"test-model-123","system_fingerprint":"fp_123","choices":[{"delta":{"role":"assistant","tool_calls": [{"id": "call_123","type": "function","function": {"name": "get_weather", "arguments": {"location": "Paris"}}}]},"index":0,"finish_reason":"tool_calls"}],"usage":{"prompt_tokens":183,"completion_tokens":14,"total_tokens":197}}\n\n'
+ )
+
+ _add_mock_response(
+ httpx_mock,
+ rsps,
+ "POST",
+ INFERENCE_ENDPOINT.format(model_name=model_name) + "/v1/chat/completions",
+ body=streaming_chat_response,
+ status=200,
+ headers={
+ "Content-Type": "text/event-stream",
+ "Cache-Control": "no-cache",
+ "Connection": "keep-alive",
+ },
+ )
+
+ if HF_VERSION >= (1, 0, 0):
+ yield httpx_mock
+ else:
+ yield rsps
+
+
+@pytest.mark.httpx_mock(assert_all_requests_were_expected=False)
+@pytest.mark.parametrize("send_default_pii", [True, False])
+@pytest.mark.parametrize("include_prompts", [True, False])
+def test_text_generation(
+ sentry_init: "Any",
+ capture_events: "Any",
+ send_default_pii: "Any",
+ include_prompts: "Any",
+ mock_hf_text_generation_api: "Any",
+) -> None:
+ sentry_init(
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ integrations=[HuggingfaceHubIntegration(include_prompts=include_prompts)],
+ )
+ events = capture_events()
+
+ client = InferenceClient(model="test-model")
+
+ with sentry_sdk.start_transaction(name="test"):
+ client.text_generation(
+ "Hello",
+ stream=False,
+ details=True,
+ )
+
+ (transaction,) = events
+
+ span = None
+ for sp in transaction["spans"]:
+ if sp["op"].startswith("gen_ai"):
+ assert span is None, "there is exactly one gen_ai span"
+ span = sp
+ else:
+ # there should be no other spans, just the gen_ai span
+ # and optionally some http.client spans from talking to the hf api
+ assert sp["op"] == "http.client"
+
+ assert span is not None
+
+ assert span["op"] == "gen_ai.text_completion"
+ assert span["description"] == "text_completion test-model"
+ assert span["origin"] == "auto.ai.huggingface_hub"
+
+ expected_data = {
+ "gen_ai.operation.name": "text_completion",
+ "gen_ai.request.model": "test-model",
+ "gen_ai.response.finish_reasons": "length",
+ "gen_ai.response.streaming": False,
+ "gen_ai.usage.total_tokens": 10,
+ "thread.id": mock.ANY,
+ "thread.name": mock.ANY,
+ }
+
+ if send_default_pii and include_prompts:
+ expected_data["gen_ai.request.messages"] = "Hello"
+ expected_data["gen_ai.response.text"] = "[mocked] Hello! How can i help you?"
+
+ if not send_default_pii or not include_prompts:
+ assert "gen_ai.request.messages" not in expected_data
+ assert "gen_ai.response.text" not in expected_data
+
+ assert span["data"] == expected_data
+
+ # text generation does not set the response model
+ assert "gen_ai.response.model" not in span["data"]
+
+
+@pytest.mark.httpx_mock(assert_all_requests_were_expected=False)
+@pytest.mark.parametrize("send_default_pii", [True, False])
+@pytest.mark.parametrize("include_prompts", [True, False])
+def test_text_generation_streaming(
+ sentry_init: "Any",
+ capture_events: "Any",
+ send_default_pii: "Any",
+ include_prompts: "Any",
+ mock_hf_text_generation_api_streaming: "Any",
+) -> None:
+ sentry_init(
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ integrations=[HuggingfaceHubIntegration(include_prompts=include_prompts)],
+ )
+ events = capture_events()
+
+ client = InferenceClient(model="test-model")
+
+ with sentry_sdk.start_transaction(name="test"):
+ for _ in client.text_generation(
+ prompt="Hello",
+ stream=True,
+ details=True,
+ ):
+ pass
+
+ (transaction,) = events
+
+ span = None
+ for sp in transaction["spans"]:
+ if sp["op"].startswith("gen_ai"):
+ assert span is None, "there is exactly one gen_ai span"
+ span = sp
+ else:
+ # there should be no other spans, just the gen_ai span
+ # and optionally some http.client spans from talking to the hf api
+ assert sp["op"] == "http.client"
+
+ assert span is not None
+
+ assert span["op"] == "gen_ai.text_completion"
+ assert span["description"] == "text_completion test-model"
+ assert span["origin"] == "auto.ai.huggingface_hub"
+
+ expected_data = {
+ "gen_ai.operation.name": "text_completion",
+ "gen_ai.request.model": "test-model",
+ "gen_ai.response.finish_reasons": "length",
+ "gen_ai.response.streaming": True,
+ "gen_ai.usage.total_tokens": 10,
+ "thread.id": mock.ANY,
+ "thread.name": mock.ANY,
+ }
+
+ if send_default_pii and include_prompts:
+ expected_data["gen_ai.request.messages"] = "Hello"
+ expected_data["gen_ai.response.text"] = "the mocked model response"
+
+ if not send_default_pii or not include_prompts:
+ assert "gen_ai.request.messages" not in expected_data
+ assert "gen_ai.response.text" not in expected_data
+
+ assert span["data"] == expected_data
+
+ # text generation does not set the response model
+ assert "gen_ai.response.model" not in span["data"]
+
+
+@pytest.mark.httpx_mock(assert_all_requests_were_expected=False)
+@pytest.mark.parametrize("send_default_pii", [True, False])
+@pytest.mark.parametrize("include_prompts", [True, False])
+def test_chat_completion(
+ sentry_init: "Any",
+ capture_events: "Any",
+ send_default_pii: "Any",
+ include_prompts: "Any",
+ mock_hf_chat_completion_api: "Any",
+) -> None:
+ sentry_init(
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ integrations=[HuggingfaceHubIntegration(include_prompts=include_prompts)],
+ )
+ events = capture_events()
+
+ client = get_hf_provider_inference_client()
+
+ with sentry_sdk.start_transaction(name="test"):
+ client.chat_completion(
+ messages=[{"role": "user", "content": "Hello!"}],
+ stream=False,
+ )
+
+ (transaction,) = events
+
+ span = None
+ for sp in transaction["spans"]:
+ if sp["op"].startswith("gen_ai"):
+ assert span is None, "there is exactly one gen_ai span"
+ span = sp
+ else:
+ # there should be no other spans, just the gen_ai span
+ # and optionally some http.client spans from talking to the hf api
+ assert sp["op"] == "http.client"
+
+ assert span is not None
+
+ assert span["op"] == "gen_ai.chat"
+ assert span["description"] == "chat test-model"
+ assert span["origin"] == "auto.ai.huggingface_hub"
+
+ expected_data = {
+ "gen_ai.operation.name": "chat",
+ "gen_ai.request.model": "test-model",
+ "gen_ai.response.finish_reasons": "stop",
+ "gen_ai.response.model": "test-model-123",
+ "gen_ai.response.streaming": False,
+ "gen_ai.usage.input_tokens": 10,
+ "gen_ai.usage.output_tokens": 8,
+ "gen_ai.usage.total_tokens": 18,
+ "thread.id": mock.ANY,
+ "thread.name": mock.ANY,
+ }
+
+ if send_default_pii and include_prompts:
+ expected_data["gen_ai.request.messages"] = (
+ '[{"role": "user", "content": "Hello!"}]'
+ )
+ expected_data["gen_ai.response.text"] = (
+ "[mocked] Hello! How can I help you today?"
+ )
+
+ if not send_default_pii or not include_prompts:
+ assert "gen_ai.request.messages" not in expected_data
+ assert "gen_ai.response.text" not in expected_data
+
+ assert span["data"] == expected_data
+
+
+@pytest.mark.httpx_mock(assert_all_requests_were_expected=False)
+@pytest.mark.parametrize("send_default_pii", [True, False])
+@pytest.mark.parametrize("include_prompts", [True, False])
+def test_chat_completion_streaming(
+ sentry_init: "Any",
+ capture_events: "Any",
+ send_default_pii: "Any",
+ include_prompts: "Any",
+ mock_hf_chat_completion_api_streaming: "Any",
+) -> None:
+ sentry_init(
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ integrations=[HuggingfaceHubIntegration(include_prompts=include_prompts)],
+ )
+ events = capture_events()
+
+ client = get_hf_provider_inference_client()
+
+ with sentry_sdk.start_transaction(name="test"):
+ _ = list(
+ client.chat_completion(
+ [{"role": "user", "content": "Hello!"}],
+ stream=True,
+ )
+ )
+
+ (transaction,) = events
+
+ span = None
+ for sp in transaction["spans"]:
+ if sp["op"].startswith("gen_ai"):
+ assert span is None, "there is exactly one gen_ai span"
+ span = sp
+ else:
+ # there should be no other spans, just the gen_ai span
+ # and optionally some http.client spans from talking to the hf api
+ assert sp["op"] == "http.client"
+
+ assert span is not None
+
+ assert span["op"] == "gen_ai.chat"
+ assert span["description"] == "chat test-model"
+ assert span["origin"] == "auto.ai.huggingface_hub"
+
+ expected_data = {
+ "gen_ai.operation.name": "chat",
+ "gen_ai.request.model": "test-model",
+ "gen_ai.response.finish_reasons": "stop",
+ "gen_ai.response.model": "test-model-123",
+ "gen_ai.response.streaming": True,
+ "thread.id": mock.ANY,
+ "thread.name": mock.ANY,
+ }
+ # usage is not available in older versions of the library
+ if HF_VERSION and HF_VERSION >= (0, 26, 0):
+ expected_data["gen_ai.usage.input_tokens"] = 183
+ expected_data["gen_ai.usage.output_tokens"] = 14
+ expected_data["gen_ai.usage.total_tokens"] = 197
+
+ if send_default_pii and include_prompts:
+ expected_data["gen_ai.request.messages"] = (
+ '[{"role": "user", "content": "Hello!"}]'
+ )
+ expected_data["gen_ai.response.text"] = "the mocked model response"
+
+ if not send_default_pii or not include_prompts:
+ assert "gen_ai.request.messages" not in expected_data
+ assert "gen_ai.response.text" not in expected_data
+
+ assert span["data"] == expected_data
+
+
+@pytest.mark.httpx_mock(assert_all_requests_were_expected=False)
+def test_chat_completion_api_error(
+ sentry_init: "Any", capture_events: "Any", mock_hf_api_with_errors: "Any"
+) -> None:
+ sentry_init(traces_sample_rate=1.0)
+ events = capture_events()
+
+ client = get_hf_provider_inference_client()
+
+ with sentry_sdk.start_transaction(name="test"):
+ with pytest.raises(HfHubHTTPError):
+ client.chat_completion(
+ messages=[{"role": "user", "content": "Hello!"}],
+ )
+
+ (
+ error,
+ transaction,
+ ) = events
+
+ assert error["exception"]["values"][0]["mechanism"]["type"] == "huggingface_hub"
+ assert not error["exception"]["values"][0]["mechanism"]["handled"]
+
+ span = None
+ for sp in transaction["spans"]:
+ if sp["op"].startswith("gen_ai"):
+ assert span is None, "there is exactly one gen_ai span"
+ span = sp
+ else:
+ # there should be no other spans, just the gen_ai span
+ # and optionally some http.client spans from talking to the hf api
+ assert sp["op"] == "http.client"
+
+ assert span is not None
+
+ assert span["op"] == "gen_ai.chat"
+ assert span["description"] == "chat test-model"
+ assert span["origin"] == "auto.ai.huggingface_hub"
+ assert span["status"] == "internal_error"
+ assert span.get("tags", {}).get("status") == "internal_error"
+
+ assert (
+ error["contexts"]["trace"]["trace_id"]
+ == transaction["contexts"]["trace"]["trace_id"]
+ )
+ expected_data = {
+ "gen_ai.operation.name": "chat",
+ "gen_ai.request.model": "test-model",
+ "thread.id": mock.ANY,
+ "thread.name": mock.ANY,
+ }
+ assert span["data"] == expected_data
+
+
+@pytest.mark.httpx_mock(assert_all_requests_were_expected=False)
+def test_span_status_error(
+ sentry_init: "Any", capture_events: "Any", mock_hf_api_with_errors: "Any"
+) -> None:
+ sentry_init(traces_sample_rate=1.0)
+ events = capture_events()
+
+ client = get_hf_provider_inference_client()
+
+ with sentry_sdk.start_transaction(name="test"):
+ with pytest.raises(HfHubHTTPError):
+ client.chat_completion(
+ messages=[{"role": "user", "content": "Hello!"}],
+ )
+
+ (error, transaction) = events
+ assert error["level"] == "error"
+
+ span = None
+ for sp in transaction["spans"]:
+ if sp["op"].startswith("gen_ai"):
+ assert span is None, "there is exactly one gen_ai span"
+ span = sp
+ else:
+ # there should be no other spans, just the gen_ai span
+ # and optionally some http.client spans from talking to the hf api
+ assert sp["op"] == "http.client"
+
+ assert span is not None
+ assert span["status"] == "internal_error"
+ assert span["tags"]["status"] == "internal_error"
+
+
+@pytest.mark.httpx_mock(assert_all_requests_were_expected=False)
+@pytest.mark.parametrize("send_default_pii", [True, False])
+@pytest.mark.parametrize("include_prompts", [True, False])
+def test_chat_completion_with_tools(
+ sentry_init: "Any",
+ capture_events: "Any",
+ send_default_pii: "Any",
+ include_prompts: "Any",
+ mock_hf_chat_completion_api_tools: "Any",
+) -> None:
+ sentry_init(
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ integrations=[HuggingfaceHubIntegration(include_prompts=include_prompts)],
+ )
+ events = capture_events()
+
+ client = get_hf_provider_inference_client()
+
+ tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "description": "Get current weather",
+ "parameters": {
+ "type": "object",
+ "properties": {"location": {"type": "string"}},
+ "required": ["location"],
+ },
+ },
+ }
+ ]
+
+ with sentry_sdk.start_transaction(name="test"):
+ client.chat_completion(
+ messages=[{"role": "user", "content": "What is the weather in Paris?"}],
+ tools=tools,
+ tool_choice="auto",
+ )
+
+ (transaction,) = events
+
+ span = None
+ for sp in transaction["spans"]:
+ if sp["op"].startswith("gen_ai"):
+ assert span is None, "there is exactly one gen_ai span"
+ span = sp
+ else:
+ # there should be no other spans, just the gen_ai span
+ # and optionally some http.client spans from talking to the hf api
+ assert sp["op"] == "http.client"
+
+ assert span is not None
+
+ assert span["op"] == "gen_ai.chat"
+ assert span["description"] == "chat test-model"
+ assert span["origin"] == "auto.ai.huggingface_hub"
+
+ expected_data = {
+ "gen_ai.operation.name": "chat",
+ "gen_ai.request.available_tools": '[{"type": "function", "function": {"name": "get_weather", "description": "Get current weather", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}}}]',
+ "gen_ai.request.model": "test-model",
+ "gen_ai.response.finish_reasons": "tool_calls",
+ "gen_ai.response.model": "test-model-123",
+ "gen_ai.usage.input_tokens": 10,
+ "gen_ai.usage.output_tokens": 8,
+ "gen_ai.usage.total_tokens": 18,
+ "thread.id": mock.ANY,
+ "thread.name": mock.ANY,
+ }
+
+ if send_default_pii and include_prompts:
+ expected_data["gen_ai.request.messages"] = (
+ '[{"role": "user", "content": "What is the weather in Paris?"}]'
+ )
+ expected_data["gen_ai.response.tool_calls"] = (
+ '[{"function": {"arguments": {"location": "Paris"}, "name": "get_weather", "description": "None"}, "id": "call_123", "type": "function"}]'
+ )
+
+ if not send_default_pii or not include_prompts:
+ assert "gen_ai.request.messages" not in expected_data
+ assert "gen_ai.response.text" not in expected_data
+ assert "gen_ai.response.tool_calls" not in expected_data
+
+ assert span["data"] == expected_data
+
+
+@pytest.mark.httpx_mock(assert_all_requests_were_expected=False)
+@pytest.mark.parametrize("send_default_pii", [True, False])
+@pytest.mark.parametrize("include_prompts", [True, False])
+def test_chat_completion_streaming_with_tools(
+ sentry_init: "Any",
+ capture_events: "Any",
+ send_default_pii: "Any",
+ include_prompts: "Any",
+ mock_hf_chat_completion_api_streaming_tools: "Any",
+) -> None:
+ sentry_init(
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ integrations=[HuggingfaceHubIntegration(include_prompts=include_prompts)],
+ )
+ events = capture_events()
+
+ client = get_hf_provider_inference_client()
+
+ tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "description": "Get current weather",
+ "parameters": {
+ "type": "object",
+ "properties": {"location": {"type": "string"}},
+ "required": ["location"],
+ },
+ },
+ }
+ ]
+
+ with sentry_sdk.start_transaction(name="test"):
+ _ = list(
+ client.chat_completion(
+ messages=[{"role": "user", "content": "What is the weather in Paris?"}],
+ stream=True,
+ tools=tools,
+ tool_choice="auto",
+ )
+ )
+
+ (transaction,) = events
+
+ span = None
+ for sp in transaction["spans"]:
+ if sp["op"].startswith("gen_ai"):
+ assert span is None, "there is exactly one gen_ai span"
+ span = sp
+ else:
+ # there should be no other spans, just the gen_ai span
+ # and optionally some http.client spans from talking to the hf api
+ assert sp["op"] == "http.client"
+
+ assert span is not None
+
+ assert span["op"] == "gen_ai.chat"
+ assert span["description"] == "chat test-model"
+ assert span["origin"] == "auto.ai.huggingface_hub"
+
+ expected_data = {
+ "gen_ai.operation.name": "chat",
+ "gen_ai.request.available_tools": '[{"type": "function", "function": {"name": "get_weather", "description": "Get current weather", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}}}]',
+ "gen_ai.request.model": "test-model",
+ "gen_ai.response.finish_reasons": "tool_calls",
+ "gen_ai.response.model": "test-model-123",
+ "gen_ai.response.streaming": True,
+ "thread.id": mock.ANY,
+ "thread.name": mock.ANY,
+ }
+
+ if HF_VERSION and HF_VERSION >= (0, 26, 0):
+ expected_data["gen_ai.usage.input_tokens"] = 183
+ expected_data["gen_ai.usage.output_tokens"] = 14
+ expected_data["gen_ai.usage.total_tokens"] = 197
+
+ if send_default_pii and include_prompts:
+ expected_data["gen_ai.request.messages"] = (
+ '[{"role": "user", "content": "What is the weather in Paris?"}]'
+ )
+ expected_data["gen_ai.response.text"] = "response with tool calls follows"
+ expected_data["gen_ai.response.tool_calls"] = (
+ '[{"function": {"arguments": {"location": "Paris"}, "name": "get_weather"}, "id": "call_123", "type": "function", "index": "None"}]'
+ )
+
+ if not send_default_pii or not include_prompts:
+ assert "gen_ai.request.messages" not in expected_data
+ assert "gen_ai.response.text" not in expected_data
+ assert "gen_ai.response.tool_calls" not in expected_data
+
+ assert span["data"] == expected_data
diff --git a/tests/integrations/langchain/__init__.py b/tests/integrations/langchain/__init__.py
new file mode 100644
index 0000000000..a286454a56
--- /dev/null
+++ b/tests/integrations/langchain/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("langchain_core")
diff --git a/tests/integrations/langchain/test_langchain.py b/tests/integrations/langchain/test_langchain.py
new file mode 100644
index 0000000000..498a5d6f4a
--- /dev/null
+++ b/tests/integrations/langchain/test_langchain.py
@@ -0,0 +1,2447 @@
+import json
+from typing import List, Optional, Any, Iterator
+from unittest import mock
+from unittest.mock import Mock, patch
+
+import pytest
+
+from sentry_sdk.consts import SPANDATA
+
+try:
+ # Langchain >= 0.2
+ from langchain_openai import ChatOpenAI, OpenAI
+except ImportError:
+ # Langchain < 0.2
+ from langchain_community.llms import OpenAI
+ from langchain_community.chat_models import ChatOpenAI
+
+from langchain_core.callbacks import BaseCallbackManager, CallbackManagerForLLMRun
+from langchain_core.messages import BaseMessage, AIMessageChunk
+from langchain_core.outputs import ChatGenerationChunk, ChatResult
+from langchain_core.runnables import RunnableConfig
+from langchain_core.language_models.chat_models import BaseChatModel
+
+import sentry_sdk
+from sentry_sdk import start_transaction
+from sentry_sdk.utils import package_version
+from sentry_sdk.integrations.langchain import (
+ LangchainIntegration,
+ SentryLangchainCallback,
+ _transform_langchain_content_block,
+ _transform_langchain_message_content,
+)
+
+try:
+ # langchain v1+
+ from langchain.tools import tool
+ from langchain.agents import create_agent
+ from langchain_classic.agents import AgentExecutor, create_openai_tools_agent # type: ignore[import-not-found]
+except ImportError:
+ # langchain int:
+ """Returns the length of a word."""
+ return len(word)
+
+
+global stream_result_mock # type: Mock
+global llm_type # type: str
+
+
+class MockOpenAI(ChatOpenAI):
+ def _stream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[ChatGenerationChunk]:
+ for x in stream_result_mock():
+ yield x
+
+ @property
+ def _llm_type(self) -> str:
+ return llm_type
+
+
+def test_langchain_text_completion(
+ sentry_init,
+ capture_events,
+ get_model_response,
+):
+ sentry_init(
+ integrations=[
+ LangchainIntegration(
+ include_prompts=True,
+ )
+ ],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ model_response = get_model_response(
+ Completion(
+ id="completion-id",
+ object="text_completion",
+ created=10000000,
+ model="gpt-3.5-turbo",
+ choices=[
+ CompletionChoice(
+ index=0,
+ finish_reason="stop",
+ text="The capital of France is Paris.",
+ )
+ ],
+ usage=CompletionUsage(
+ prompt_tokens=10,
+ completion_tokens=15,
+ total_tokens=25,
+ ),
+ ),
+ serialize_pydantic=True,
+ )
+
+ model = OpenAI(
+ model_name="gpt-3.5-turbo",
+ temperature=0.7,
+ max_tokens=100,
+ openai_api_key="badkey",
+ )
+
+ with patch.object(
+ model.client._client._client,
+ "send",
+ return_value=model_response,
+ ) as _:
+ with start_transaction():
+ input_text = "What is the capital of France?"
+ model.invoke(input_text, config={"run_name": "my-snazzy-pipeline"})
+
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ llm_spans = [
+ span
+ for span in tx.get("spans", [])
+ if span.get("op") == "gen_ai.text_completion"
+ ]
+ assert len(llm_spans) > 0
+
+ llm_span = llm_spans[0]
+ assert llm_span["description"] == "text_completion gpt-3.5-turbo"
+ assert llm_span["data"]["gen_ai.system"] == "openai"
+ assert llm_span["data"]["gen_ai.pipeline.name"] == "my-snazzy-pipeline"
+ assert llm_span["data"]["gen_ai.request.model"] == "gpt-3.5-turbo"
+ assert llm_span["data"]["gen_ai.response.text"] == "The capital of France is Paris."
+ assert llm_span["data"]["gen_ai.usage.total_tokens"] == 25
+ assert llm_span["data"]["gen_ai.usage.input_tokens"] == 10
+ assert llm_span["data"]["gen_ai.usage.output_tokens"] == 15
+
+
+@pytest.mark.skipif(
+ LANGCHAIN_VERSION < (1,),
+ reason="LangChain 1.0+ required (ONE AGENT refactor)",
+)
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+@pytest.mark.parametrize(
+ "system_instructions_content",
+ [
+ "You are very powerful assistant, but don't know current events",
+ [
+ {"type": "text", "text": "You are a helpful assistant."},
+ {"type": "text", "text": "Be concise and clear."},
+ ],
+ ],
+ ids=["string", "blocks"],
+)
+def test_langchain_create_agent(
+ sentry_init,
+ capture_events,
+ send_default_pii,
+ include_prompts,
+ system_instructions_content,
+ request,
+ get_model_response,
+ nonstreaming_responses_model_response,
+):
+ sentry_init(
+ integrations=[
+ LangchainIntegration(
+ include_prompts=include_prompts,
+ )
+ ],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ model_response = get_model_response(
+ nonstreaming_responses_model_response,
+ serialize_pydantic=True,
+ request_headers={
+ "X-Stainless-Raw-Response": "True",
+ },
+ )
+
+ llm = ChatOpenAI(
+ model_name="gpt-3.5-turbo",
+ temperature=0,
+ openai_api_key="badkey",
+ use_responses_api=True,
+ )
+ agent = create_agent(
+ model=llm,
+ tools=[get_word_length],
+ system_prompt=SystemMessage(content=system_instructions_content),
+ name="word_length_agent",
+ )
+
+ with patch.object(
+ llm.client._client._client,
+ "send",
+ return_value=model_response,
+ ) as _:
+ with start_transaction():
+ agent.invoke(
+ {
+ "messages": [
+ HumanMessage(content="How many letters in the word eudca"),
+ ],
+ },
+ )
+
+ tx = events[0]
+ assert tx["type"] == "transaction"
+ assert tx["contexts"]["trace"]["origin"] == "manual"
+
+ chat_spans = list(x for x in tx["spans"] if x["op"] == "gen_ai.chat")
+ assert len(chat_spans) == 1
+ assert chat_spans[0]["origin"] == "auto.ai.langchain"
+
+ assert chat_spans[0]["data"]["gen_ai.system"] == "openai-chat"
+ assert chat_spans[0]["data"]["gen_ai.usage.input_tokens"] == 10
+ assert chat_spans[0]["data"]["gen_ai.usage.output_tokens"] == 20
+ assert chat_spans[0]["data"]["gen_ai.usage.total_tokens"] == 30
+
+ if send_default_pii and include_prompts:
+ assert (
+ chat_spans[0]["data"][SPANDATA.GEN_AI_RESPONSE_TEXT]
+ == "Hello, how can I help you?"
+ )
+
+ param_id = request.node.callspec.id
+ if "string" in param_id:
+ assert [
+ {
+ "type": "text",
+ "content": "You are very powerful assistant, but don't know current events",
+ }
+ ] == json.loads(chat_spans[0]["data"][SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS])
+ else:
+ assert [
+ {
+ "type": "text",
+ "content": "You are a helpful assistant.",
+ },
+ {
+ "type": "text",
+ "content": "Be concise and clear.",
+ },
+ ] == json.loads(chat_spans[0]["data"][SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS])
+ else:
+ assert SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS not in chat_spans[0].get("data", {})
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in chat_spans[0].get("data", {})
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in chat_spans[0].get("data", {})
+
+
+@pytest.mark.skipif(
+ LANGCHAIN_VERSION < (1,),
+ reason="LangChain 1.0+ required (ONE AGENT refactor)",
+)
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+def test_tool_execution_span(
+ sentry_init,
+ capture_events,
+ send_default_pii,
+ include_prompts,
+ get_model_response,
+ responses_tool_call_model_responses,
+):
+ sentry_init(
+ integrations=[
+ LangchainIntegration(
+ include_prompts=include_prompts,
+ )
+ ],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ responses = responses_tool_call_model_responses(
+ tool_name="get_word_length",
+ arguments='{"word": "eudca"}',
+ response_model="gpt-4-0613",
+ response_text="The word eudca has 5 letters.",
+ response_ids=iter(["resp_1", "resp_2"]),
+ usages=iter(
+ [
+ ResponseUsage(
+ input_tokens=142,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=0,
+ ),
+ output_tokens=50,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=0,
+ ),
+ total_tokens=192,
+ ),
+ ResponseUsage(
+ input_tokens=89,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=0,
+ ),
+ output_tokens=28,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=0,
+ ),
+ total_tokens=117,
+ ),
+ ]
+ ),
+ )
+ tool_response = get_model_response(
+ next(responses),
+ serialize_pydantic=True,
+ request_headers={
+ "X-Stainless-Raw-Response": "True",
+ },
+ )
+ final_response = get_model_response(
+ next(responses),
+ serialize_pydantic=True,
+ request_headers={
+ "X-Stainless-Raw-Response": "True",
+ },
+ )
+
+ llm = ChatOpenAI(
+ model_name="gpt-4",
+ temperature=0,
+ openai_api_key="badkey",
+ use_responses_api=True,
+ )
+ agent = create_agent(
+ model=llm,
+ tools=[get_word_length],
+ name="word_length_agent",
+ )
+
+ with patch.object(
+ llm.client._client._client,
+ "send",
+ side_effect=[tool_response, final_response],
+ ) as _:
+ with start_transaction():
+ agent.invoke(
+ {
+ "messages": [
+ HumanMessage(content="How many letters in the word eudca"),
+ ],
+ },
+ )
+
+ tx = events[0]
+ assert tx["type"] == "transaction"
+ assert tx["contexts"]["trace"]["origin"] == "manual"
+
+ chat_spans = list(x for x in tx["spans"] if x["op"] == "gen_ai.chat")
+ assert len(chat_spans) == 2
+
+ tool_exec_spans = list(x for x in tx["spans"] if x["op"] == "gen_ai.execute_tool")
+ assert len(tool_exec_spans) == 1
+ tool_exec_span = tool_exec_spans[0]
+
+ assert chat_spans[0]["origin"] == "auto.ai.langchain"
+ assert chat_spans[1]["origin"] == "auto.ai.langchain"
+ assert tool_exec_span["origin"] == "auto.ai.langchain"
+
+ assert chat_spans[0]["data"]["gen_ai.usage.input_tokens"] == 142
+ assert chat_spans[0]["data"]["gen_ai.usage.output_tokens"] == 50
+ assert chat_spans[0]["data"]["gen_ai.usage.total_tokens"] == 192
+ assert chat_spans[0]["data"]["gen_ai.system"] == "openai-chat"
+
+ assert chat_spans[1]["data"]["gen_ai.usage.input_tokens"] == 89
+ assert chat_spans[1]["data"]["gen_ai.usage.output_tokens"] == 28
+ assert chat_spans[1]["data"]["gen_ai.usage.total_tokens"] == 117
+ assert chat_spans[1]["data"]["gen_ai.system"] == "openai-chat"
+
+ if send_default_pii and include_prompts:
+ assert "word" in tool_exec_span["data"][SPANDATA.GEN_AI_TOOL_INPUT]
+
+ assert "5" in chat_spans[1]["data"][SPANDATA.GEN_AI_RESPONSE_TEXT]
+
+ # Verify tool calls are recorded when PII is enabled
+ assert SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS in chat_spans[0].get("data", {}), (
+ "Tool calls should be recorded when send_default_pii=True and include_prompts=True"
+ )
+ tool_calls_data = chat_spans[0]["data"][SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS]
+ assert isinstance(tool_calls_data, str)
+ assert "get_word_length" in tool_calls_data
+ else:
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in chat_spans[0].get("data", {})
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in chat_spans[0].get("data", {})
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in chat_spans[1].get("data", {})
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in chat_spans[1].get("data", {})
+ assert SPANDATA.GEN_AI_TOOL_INPUT not in tool_exec_span.get("data", {})
+ assert SPANDATA.GEN_AI_TOOL_OUTPUT not in tool_exec_span.get("data", {})
+
+ # Verify tool calls are NOT recorded when PII is disabled
+ assert SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS not in chat_spans[0].get(
+ "data", {}
+ ), (
+ f"Tool calls should NOT be recorded when send_default_pii={send_default_pii} "
+ f"and include_prompts={include_prompts}"
+ )
+ assert SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS not in chat_spans[1].get(
+ "data", {}
+ ), (
+ f"Tool calls should NOT be recorded when send_default_pii={send_default_pii} "
+ f"and include_prompts={include_prompts}"
+ )
+
+ # Verify that available tools are always recorded regardless of PII settings
+ for chat_span in chat_spans:
+ tools_data = chat_span["data"][SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS]
+ assert "get_word_length" in tools_data
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+@pytest.mark.parametrize(
+ "system_instructions_content",
+ [
+ "You are very powerful assistant, but don't know current events",
+ ["You are a helpful assistant.", "Be concise and clear."],
+ [
+ {"type": "text", "text": "You are a helpful assistant."},
+ {"type": "text", "text": "Be concise and clear."},
+ ],
+ ],
+ ids=["string", "list", "blocks"],
+)
+def test_langchain_openai_tools_agent(
+ sentry_init,
+ capture_events,
+ send_default_pii,
+ include_prompts,
+ system_instructions_content,
+ request,
+ get_model_response,
+ server_side_event_chunks,
+):
+ sentry_init(
+ integrations=[
+ LangchainIntegration(
+ include_prompts=include_prompts,
+ )
+ ],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ prompt = ChatPromptTemplate.from_messages(
+ [
+ (
+ "system",
+ system_instructions_content,
+ ),
+ ("user", "{input}"),
+ MessagesPlaceholder(variable_name="agent_scratchpad"),
+ ]
+ )
+
+ tool_response = get_model_response(
+ server_side_event_chunks(
+ [
+ ChatCompletionChunk(
+ id="chatcmpl-turn-1",
+ object="chat.completion.chunk",
+ created=10000000,
+ model="gpt-3.5-turbo",
+ choices=[
+ Choice(
+ index=0,
+ delta=ChoiceDelta(role="assistant"),
+ finish_reason=None,
+ ),
+ ],
+ ),
+ ChatCompletionChunk(
+ id="chatcmpl-turn-1",
+ object="chat.completion.chunk",
+ created=10000000,
+ model="gpt-3.5-turbo",
+ choices=[
+ Choice(
+ index=0,
+ delta=ChoiceDelta(
+ tool_calls=[
+ ChoiceDeltaToolCall(
+ index=0,
+ id="call_BbeyNhCKa6kYLYzrD40NGm3b",
+ type="function",
+ function=ChoiceDeltaToolCallFunction(
+ name="get_word_length",
+ arguments="",
+ ),
+ ),
+ ],
+ ),
+ finish_reason=None,
+ ),
+ ],
+ ),
+ ChatCompletionChunk(
+ id="chatcmpl-turn-1",
+ object="chat.completion.chunk",
+ created=10000000,
+ model="gpt-3.5-turbo",
+ choices=[
+ Choice(
+ index=0,
+ delta=ChoiceDelta(
+ tool_calls=[
+ ChoiceDeltaToolCall(
+ index=0,
+ function=ChoiceDeltaToolCallFunction(
+ arguments='{"word": "eudca"}',
+ ),
+ ),
+ ],
+ ),
+ finish_reason=None,
+ ),
+ ],
+ ),
+ ChatCompletionChunk(
+ id="chatcmpl-turn-1",
+ object="chat.completion.chunk",
+ created=10000000,
+ model="gpt-3.5-turbo",
+ choices=[
+ Choice(
+ index=0,
+ delta=ChoiceDelta(content="5"),
+ finish_reason=None,
+ ),
+ ],
+ ),
+ ChatCompletionChunk(
+ id="chatcmpl-turn-1",
+ object="chat.completion.chunk",
+ created=10000000,
+ model="gpt-3.5-turbo",
+ choices=[
+ Choice(
+ index=0,
+ delta=ChoiceDelta(),
+ finish_reason="function_call",
+ ),
+ ],
+ ),
+ ChatCompletionChunk(
+ id="chatcmpl-turn-1",
+ object="chat.completion.chunk",
+ created=10000000,
+ model="gpt-3.5-turbo",
+ choices=[],
+ usage=CompletionUsage(
+ prompt_tokens=142,
+ completion_tokens=50,
+ total_tokens=192,
+ ),
+ ),
+ ],
+ include_event_type=False,
+ )
+ )
+
+ final_response = get_model_response(
+ server_side_event_chunks(
+ [
+ ChatCompletionChunk(
+ id="chatcmpl-turn-2",
+ object="chat.completion.chunk",
+ created=10000000,
+ model="gpt-3.5-turbo",
+ choices=[
+ Choice(
+ index=0,
+ delta=ChoiceDelta(role="assistant"),
+ finish_reason=None,
+ ),
+ ],
+ ),
+ ChatCompletionChunk(
+ id="chatcmpl-turn-2",
+ object="chat.completion.chunk",
+ created=10000000,
+ model="gpt-3.5-turbo",
+ choices=[
+ Choice(
+ index=0,
+ delta=ChoiceDelta(content="The word eudca has 5 letters."),
+ finish_reason=None,
+ ),
+ ],
+ ),
+ ChatCompletionChunk(
+ id="chatcmpl-turn-2",
+ object="chat.completion.chunk",
+ created=10000000,
+ model="gpt-3.5-turbo",
+ choices=[
+ Choice(
+ index=0,
+ delta=ChoiceDelta(),
+ finish_reason="stop",
+ ),
+ ],
+ ),
+ ChatCompletionChunk(
+ id="chatcmpl-turn-2",
+ object="chat.completion.chunk",
+ created=10000000,
+ model="gpt-3.5-turbo",
+ choices=[],
+ usage=CompletionUsage(
+ prompt_tokens=89,
+ completion_tokens=28,
+ total_tokens=117,
+ ),
+ ),
+ ],
+ include_event_type=False,
+ )
+ )
+
+ llm = ChatOpenAI(
+ model_name="gpt-3.5-turbo",
+ temperature=0,
+ openai_api_key="badkey",
+ )
+ agent = create_openai_tools_agent(llm, [get_word_length], prompt)
+
+ agent_executor = AgentExecutor(agent=agent, tools=[get_word_length], verbose=True)
+
+ with patch.object(
+ llm.client._client._client,
+ "send",
+ side_effect=[tool_response, final_response],
+ ) as _:
+ with start_transaction():
+ list(agent_executor.stream({"input": "How many letters in the word eudca"}))
+
+ tx = events[0]
+ assert tx["type"] == "transaction"
+ assert tx["contexts"]["trace"]["origin"] == "manual"
+
+ invoke_agent_span = next(x for x in tx["spans"] if x["op"] == "gen_ai.invoke_agent")
+ chat_spans = list(x for x in tx["spans"] if x["op"] == "gen_ai.chat")
+ tool_exec_span = next(x for x in tx["spans"] if x["op"] == "gen_ai.execute_tool")
+
+ assert len(chat_spans) == 2
+
+ assert invoke_agent_span["origin"] == "auto.ai.langchain"
+ assert chat_spans[0]["origin"] == "auto.ai.langchain"
+ assert chat_spans[1]["origin"] == "auto.ai.langchain"
+ assert tool_exec_span["origin"] == "auto.ai.langchain"
+
+ # We can't guarantee anything about the "shape" of the langchain execution graph
+ assert len(list(x for x in tx["spans"] if x["op"] == "gen_ai.chat")) > 0
+
+ # Token usage is only available in newer versions of langchain (v0.2+)
+ # where usage_metadata is supported on AIMessageChunk
+ if "gen_ai.usage.input_tokens" in chat_spans[0]["data"]:
+ assert chat_spans[0]["data"]["gen_ai.usage.input_tokens"] == 142
+ assert chat_spans[0]["data"]["gen_ai.usage.output_tokens"] == 50
+ assert chat_spans[0]["data"]["gen_ai.usage.total_tokens"] == 192
+
+ if "gen_ai.usage.input_tokens" in chat_spans[1]["data"]:
+ assert chat_spans[1]["data"]["gen_ai.usage.input_tokens"] == 89
+ assert chat_spans[1]["data"]["gen_ai.usage.output_tokens"] == 28
+ assert chat_spans[1]["data"]["gen_ai.usage.total_tokens"] == 117
+
+ if send_default_pii and include_prompts:
+ assert "5" in chat_spans[0]["data"][SPANDATA.GEN_AI_RESPONSE_TEXT]
+ assert "word" in tool_exec_span["data"][SPANDATA.GEN_AI_TOOL_INPUT]
+ assert 5 == int(tool_exec_span["data"][SPANDATA.GEN_AI_TOOL_OUTPUT])
+
+ param_id = request.node.callspec.id
+ if "string" in param_id:
+ assert [
+ {
+ "type": "text",
+ "content": "You are very powerful assistant, but don't know current events",
+ }
+ ] == json.loads(chat_spans[0]["data"][SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS])
+ else:
+ assert [
+ {
+ "type": "text",
+ "content": "You are a helpful assistant.",
+ },
+ {
+ "type": "text",
+ "content": "Be concise and clear.",
+ },
+ ] == json.loads(chat_spans[0]["data"][SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS])
+
+ assert "5" in chat_spans[1]["data"][SPANDATA.GEN_AI_RESPONSE_TEXT]
+
+ # Verify tool calls are recorded when PII is enabled
+ assert SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS in chat_spans[0].get("data", {}), (
+ "Tool calls should be recorded when send_default_pii=True and include_prompts=True"
+ )
+ tool_calls_data = chat_spans[0]["data"][SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS]
+ assert isinstance(tool_calls_data, (list, str)) # Could be serialized
+ if isinstance(tool_calls_data, str):
+ assert "get_word_length" in tool_calls_data
+ elif isinstance(tool_calls_data, list) and len(tool_calls_data) > 0:
+ # Check if tool calls contain expected function name
+ tool_call_str = str(tool_calls_data)
+ assert "get_word_length" in tool_call_str
+ else:
+ assert SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS not in chat_spans[0].get("data", {})
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in chat_spans[0].get("data", {})
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in chat_spans[0].get("data", {})
+ assert SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS not in chat_spans[1].get("data", {})
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in chat_spans[1].get("data", {})
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in chat_spans[1].get("data", {})
+ assert SPANDATA.GEN_AI_TOOL_INPUT not in tool_exec_span.get("data", {})
+ assert SPANDATA.GEN_AI_TOOL_OUTPUT not in tool_exec_span.get("data", {})
+
+ # Verify tool calls are NOT recorded when PII is disabled
+ assert SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS not in chat_spans[0].get(
+ "data", {}
+ ), (
+ f"Tool calls should NOT be recorded when send_default_pii={send_default_pii} "
+ f"and include_prompts={include_prompts}"
+ )
+ assert SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS not in chat_spans[1].get(
+ "data", {}
+ ), (
+ f"Tool calls should NOT be recorded when send_default_pii={send_default_pii} "
+ f"and include_prompts={include_prompts}"
+ )
+
+ # Verify finish_reasons is always an array of strings
+ assert chat_spans[0]["data"][SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS] == [
+ "function_call"
+ ]
+ assert chat_spans[1]["data"][SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS] == ["stop"]
+
+ # Verify that available tools are always recorded regardless of PII settings
+ for chat_span in chat_spans:
+ tools_data = chat_span["data"][SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS]
+ assert tools_data is not None, (
+ "Available tools should always be recorded regardless of PII settings"
+ )
+ assert "get_word_length" in tools_data
+
+
+def test_langchain_error(sentry_init, capture_events):
+ global llm_type
+ llm_type = "acme-llm"
+
+ sentry_init(
+ integrations=[LangchainIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ prompt = ChatPromptTemplate.from_messages(
+ [
+ (
+ "system",
+ "You are very powerful assistant, but don't know current events",
+ ),
+ ("user", "{input}"),
+ MessagesPlaceholder(variable_name="agent_scratchpad"),
+ ]
+ )
+ global stream_result_mock
+ stream_result_mock = Mock(side_effect=ValueError("API rate limit error"))
+ llm = MockOpenAI(
+ model_name="gpt-3.5-turbo",
+ temperature=0,
+ openai_api_key="badkey",
+ )
+ agent = create_openai_tools_agent(llm, [get_word_length], prompt)
+
+ agent_executor = AgentExecutor(agent=agent, tools=[get_word_length], verbose=True)
+
+ with start_transaction(), pytest.raises(ValueError):
+ list(agent_executor.stream({"input": "How many letters in the word eudca"}))
+
+ error = events[0]
+ assert error["level"] == "error"
+
+
+def test_span_status_error(sentry_init, capture_events):
+ global llm_type
+ llm_type = "acme-llm"
+
+ sentry_init(
+ integrations=[LangchainIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ with start_transaction(name="test"):
+ prompt = ChatPromptTemplate.from_messages(
+ [
+ (
+ "system",
+ "You are very powerful assistant, but don't know current events",
+ ),
+ ("user", "{input}"),
+ MessagesPlaceholder(variable_name="agent_scratchpad"),
+ ]
+ )
+ global stream_result_mock
+ stream_result_mock = Mock(side_effect=ValueError("API rate limit error"))
+ llm = MockOpenAI(
+ model_name="gpt-3.5-turbo",
+ temperature=0,
+ openai_api_key="badkey",
+ )
+ agent = create_openai_tools_agent(llm, [get_word_length], prompt)
+
+ agent_executor = AgentExecutor(
+ agent=agent, tools=[get_word_length], verbose=True
+ )
+
+ with pytest.raises(ValueError):
+ list(agent_executor.stream({"input": "How many letters in the word eudca"}))
+
+ (error, transaction) = events
+ assert error["level"] == "error"
+ assert transaction["spans"][0]["status"] == "internal_error"
+ assert transaction["spans"][0]["tags"]["status"] == "internal_error"
+ assert transaction["contexts"]["trace"]["status"] == "internal_error"
+
+
+def test_manual_callback_no_duplication(sentry_init):
+ """
+ Test that when a user manually provides a SentryLangchainCallback,
+ the integration doesn't create a duplicate callback.
+ """
+
+ # Track callback instances
+ tracked_callback_instances = set()
+
+ class CallbackTrackingModel(BaseChatModel):
+ """Mock model that tracks callback instances for testing."""
+
+ def _generate(
+ self,
+ messages,
+ stop=None,
+ run_manager=None,
+ **kwargs,
+ ):
+ # Track all SentryLangchainCallback instances
+ if run_manager:
+ for handler in run_manager.handlers:
+ if isinstance(handler, SentryLangchainCallback):
+ tracked_callback_instances.add(id(handler))
+
+ for handler in run_manager.inheritable_handlers:
+ if isinstance(handler, SentryLangchainCallback):
+ tracked_callback_instances.add(id(handler))
+
+ return ChatResult(
+ generations=[
+ ChatGenerationChunk(message=AIMessageChunk(content="Hello!"))
+ ],
+ llm_output={},
+ )
+
+ @property
+ def _llm_type(self):
+ return "test_model"
+
+ @property
+ def _identifying_params(self):
+ return {}
+
+ sentry_init(integrations=[LangchainIntegration()])
+
+ # Create a manual SentryLangchainCallback
+ manual_callback = SentryLangchainCallback(
+ max_span_map_size=100, include_prompts=False
+ )
+
+ # Create RunnableConfig with the manual callback
+ config = RunnableConfig(callbacks=[manual_callback])
+
+ # Invoke the model with the config
+ llm = CallbackTrackingModel()
+ llm.invoke("Hello", config)
+
+ # Verify that only ONE SentryLangchainCallback instance was used
+ assert len(tracked_callback_instances) == 1, (
+ f"Expected exactly 1 SentryLangchainCallback instance, "
+ f"but found {len(tracked_callback_instances)}. "
+ f"This indicates callback duplication occurred."
+ )
+
+ # Verify the callback ID matches our manual callback
+ assert id(manual_callback) in tracked_callback_instances
+
+
+def test_span_map_is_instance_variable():
+ """Test that each SentryLangchainCallback instance has its own span_map."""
+ # Create two separate callback instances
+ callback1 = SentryLangchainCallback(max_span_map_size=100, include_prompts=True)
+ callback2 = SentryLangchainCallback(max_span_map_size=100, include_prompts=True)
+
+ # Verify they have different span_map instances
+ assert callback1.span_map is not callback2.span_map, (
+ "span_map should be an instance variable, not shared between instances"
+ )
+
+
+def test_langchain_callback_manager(sentry_init):
+ sentry_init(
+ integrations=[LangchainIntegration()],
+ traces_sample_rate=1.0,
+ )
+ local_manager = BaseCallbackManager(handlers=[])
+
+ with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module:
+ mock_configure = mock_manager_module._configure
+
+ # Explicitly re-run setup_once, so that mock_manager_module._configure gets patched
+ LangchainIntegration.setup_once()
+
+ callback_manager_cls = Mock()
+
+ mock_manager_module._configure(
+ callback_manager_cls, local_callbacks=local_manager
+ )
+
+ assert mock_configure.call_count == 1
+
+ call_args = mock_configure.call_args
+ assert call_args.args[0] is callback_manager_cls
+
+ passed_manager = call_args.args[2]
+ assert passed_manager is not local_manager
+ assert local_manager.handlers == []
+
+ [handler] = passed_manager.handlers
+ assert isinstance(handler, SentryLangchainCallback)
+
+
+def test_langchain_callback_manager_with_sentry_callback(sentry_init):
+ sentry_init(
+ integrations=[LangchainIntegration()],
+ traces_sample_rate=1.0,
+ )
+ sentry_callback = SentryLangchainCallback(0, False)
+ local_manager = BaseCallbackManager(handlers=[sentry_callback])
+
+ with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module:
+ mock_configure = mock_manager_module._configure
+
+ # Explicitly re-run setup_once, so that mock_manager_module._configure gets patched
+ LangchainIntegration.setup_once()
+
+ callback_manager_cls = Mock()
+
+ mock_manager_module._configure(
+ callback_manager_cls, local_callbacks=local_manager
+ )
+
+ assert mock_configure.call_count == 1
+
+ call_args = mock_configure.call_args
+ assert call_args.args[0] is callback_manager_cls
+
+ passed_manager = call_args.args[2]
+ assert passed_manager is local_manager
+
+ [handler] = passed_manager.handlers
+ assert handler is sentry_callback
+
+
+def test_langchain_callback_list(sentry_init):
+ sentry_init(
+ integrations=[LangchainIntegration()],
+ traces_sample_rate=1.0,
+ )
+ local_callbacks = []
+
+ with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module:
+ mock_configure = mock_manager_module._configure
+
+ # Explicitly re-run setup_once, so that mock_manager_module._configure gets patched
+ LangchainIntegration.setup_once()
+
+ callback_manager_cls = Mock()
+
+ mock_manager_module._configure(
+ callback_manager_cls, local_callbacks=local_callbacks
+ )
+
+ assert mock_configure.call_count == 1
+
+ call_args = mock_configure.call_args
+ assert call_args.args[0] is callback_manager_cls
+
+ passed_callbacks = call_args.args[2]
+ assert passed_callbacks is not local_callbacks
+ assert local_callbacks == []
+
+ [handler] = passed_callbacks
+ assert isinstance(handler, SentryLangchainCallback)
+
+
+def test_langchain_callback_list_existing_callback(sentry_init):
+ sentry_init(
+ integrations=[LangchainIntegration()],
+ traces_sample_rate=1.0,
+ )
+ sentry_callback = SentryLangchainCallback(0, False)
+ local_callbacks = [sentry_callback]
+
+ with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module:
+ mock_configure = mock_manager_module._configure
+
+ # Explicitly re-run setup_once, so that mock_manager_module._configure gets patched
+ LangchainIntegration.setup_once()
+
+ callback_manager_cls = Mock()
+
+ mock_manager_module._configure(
+ callback_manager_cls, local_callbacks=local_callbacks
+ )
+
+ assert mock_configure.call_count == 1
+
+ call_args = mock_configure.call_args
+ assert call_args.args[0] is callback_manager_cls
+
+ passed_callbacks = call_args.args[2]
+ assert passed_callbacks is local_callbacks
+
+ [handler] = passed_callbacks
+ assert handler is sentry_callback
+
+
+def test_langchain_message_role_mapping(sentry_init, capture_events):
+ """Test that message roles are properly normalized in langchain integration."""
+ global llm_type
+ llm_type = "openai-chat"
+
+ sentry_init(
+ integrations=[LangchainIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ prompt = ChatPromptTemplate.from_messages(
+ [
+ ("system", "You are a helpful assistant"),
+ ("human", "{input}"),
+ MessagesPlaceholder(variable_name="agent_scratchpad"),
+ ]
+ )
+
+ global stream_result_mock
+ stream_result_mock = Mock(
+ side_effect=[
+ [
+ ChatGenerationChunk(
+ type="ChatGenerationChunk",
+ message=AIMessageChunk(content="Test response"),
+ ),
+ ]
+ ]
+ )
+
+ llm = MockOpenAI(
+ model_name="gpt-3.5-turbo",
+ temperature=0,
+ openai_api_key="badkey",
+ )
+ agent = create_openai_tools_agent(llm, [get_word_length], prompt)
+ agent_executor = AgentExecutor(agent=agent, tools=[get_word_length], verbose=True)
+
+ # Test input that should trigger message role normalization
+ test_input = "Hello, how are you?"
+
+ with start_transaction():
+ list(agent_executor.stream({"input": test_input}))
+
+ assert len(events) > 0
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ # Find spans with gen_ai operation that should have message data
+ gen_ai_spans = [
+ span for span in tx.get("spans", []) if span.get("op", "").startswith("gen_ai")
+ ]
+
+ # Check if any span has message data with normalized roles
+ message_data_found = False
+ for span in gen_ai_spans:
+ span_data = span.get("data", {})
+ if SPANDATA.GEN_AI_REQUEST_MESSAGES in span_data:
+ message_data_found = True
+ messages_data = span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES]
+
+ # Parse the message data (might be JSON string)
+ if isinstance(messages_data, str):
+ try:
+ messages = json.loads(messages_data)
+ except json.JSONDecodeError:
+ # If not valid JSON, skip this assertion
+ continue
+ else:
+ messages = messages_data
+
+ # Verify that the input message is present and contains the test input
+ assert isinstance(messages, list)
+ assert len(messages) > 0
+
+ # The test input should be in one of the messages
+ input_found = False
+ for msg in messages:
+ if isinstance(msg, dict) and test_input in str(msg.get("content", "")):
+ input_found = True
+ break
+ elif isinstance(msg, str) and test_input in msg:
+ input_found = True
+ break
+
+ assert input_found, (
+ f"Test input '{test_input}' not found in messages: {messages}"
+ )
+ break
+
+ # The message role mapping functionality is primarily tested through the normalization
+ # that happens in the integration code. The fact that we can capture and process
+ # the messages without errors indicates the role mapping is working correctly.
+ assert message_data_found, "No span found with gen_ai request messages data"
+
+
+def test_langchain_message_role_normalization_units():
+ """Test the message role normalization functions directly."""
+ from sentry_sdk.ai.utils import normalize_message_role, normalize_message_roles
+
+ # Test individual role normalization
+ assert normalize_message_role("ai") == "assistant"
+ assert normalize_message_role("human") == "user"
+ assert normalize_message_role("tool_call") == "tool"
+ assert normalize_message_role("system") == "system"
+ assert normalize_message_role("user") == "user"
+ assert normalize_message_role("assistant") == "assistant"
+ assert normalize_message_role("tool") == "tool"
+
+ # Test unknown role (should remain unchanged)
+ assert normalize_message_role("unknown_role") == "unknown_role"
+
+ # Test message list normalization
+ test_messages = [
+ {"role": "human", "content": "Hello"},
+ {"role": "ai", "content": "Hi there!"},
+ {"role": "tool_call", "content": "function_call"},
+ {"role": "system", "content": "You are helpful"},
+ {"content": "Message without role"},
+ "string message",
+ ]
+
+ normalized = normalize_message_roles(test_messages)
+
+ # Verify the original messages are not modified
+ assert test_messages[0]["role"] == "human" # Original unchanged
+ assert test_messages[1]["role"] == "ai" # Original unchanged
+
+ # Verify the normalized messages have correct roles
+ assert normalized[0]["role"] == "user" # human -> user
+ assert normalized[1]["role"] == "assistant" # ai -> assistant
+ assert normalized[2]["role"] == "tool" # tool_call -> tool
+ assert normalized[3]["role"] == "system" # system unchanged
+ assert "role" not in normalized[4] # Message without role unchanged
+ assert normalized[5] == "string message" # String message unchanged
+
+
+def test_langchain_message_truncation(sentry_init, capture_events):
+ """Test that large messages are truncated properly in Langchain integration."""
+ from langchain_core.outputs import LLMResult, Generation
+
+ sentry_init(
+ integrations=[LangchainIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ callback = SentryLangchainCallback(max_span_map_size=100, include_prompts=True)
+
+ run_id = "12345678-1234-1234-1234-123456789012"
+ serialized = {"_type": "openai-chat", "model_name": "gpt-3.5-turbo"}
+
+ large_content = (
+ "This is a very long message that will exceed our size limits. " * 1000
+ )
+ prompts = [
+ "small message 1",
+ large_content,
+ large_content,
+ "small message 4",
+ "small message 5",
+ ]
+
+ with start_transaction():
+ callback.on_llm_start(
+ serialized=serialized,
+ prompts=prompts,
+ run_id=run_id,
+ name="my_pipeline",
+ invocation_params={
+ "temperature": 0.7,
+ "max_tokens": 100,
+ "model": "gpt-3.5-turbo",
+ },
+ )
+
+ response = LLMResult(
+ generations=[[Generation(text="The response")]],
+ llm_output={
+ "token_usage": {
+ "total_tokens": 25,
+ "prompt_tokens": 10,
+ "completion_tokens": 15,
+ }
+ },
+ )
+ callback.on_llm_end(response=response, run_id=run_id)
+
+ assert len(events) > 0
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ llm_spans = [
+ span
+ for span in tx.get("spans", [])
+ if span.get("op") == "gen_ai.text_completion"
+ ]
+ assert len(llm_spans) > 0
+
+ llm_span = llm_spans[0]
+ assert llm_span["data"]["gen_ai.operation.name"] == "text_completion"
+ assert llm_span["data"][SPANDATA.GEN_AI_PIPELINE_NAME] == "my_pipeline"
+
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES in llm_span["data"]
+ messages_data = llm_span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ assert isinstance(messages_data, str)
+
+ parsed_messages = json.loads(messages_data)
+ assert isinstance(parsed_messages, list)
+ assert len(parsed_messages) == 1
+ assert "small message 5" in str(parsed_messages[0])
+ assert tx["_meta"]["spans"]["0"]["data"]["gen_ai.request.messages"][""]["len"] == 5
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+def test_langchain_embeddings_sync(
+ sentry_init, capture_events, send_default_pii, include_prompts
+):
+ """Test that sync embedding methods (embed_documents, embed_query) are properly traced."""
+ try:
+ from langchain_openai import OpenAIEmbeddings
+ except ImportError:
+ pytest.skip("langchain_openai not installed")
+
+ sentry_init(
+ integrations=[LangchainIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ # Mock the actual API call
+ with mock.patch.object(
+ OpenAIEmbeddings,
+ "embed_documents",
+ wraps=lambda self, texts: [[0.1, 0.2, 0.3] for _ in texts],
+ ) as mock_embed_documents:
+ embeddings = OpenAIEmbeddings(
+ model="text-embedding-ada-002", openai_api_key="test-key"
+ )
+
+ # Force setup to re-run to ensure our mock is wrapped
+ LangchainIntegration.setup_once()
+
+ with start_transaction(name="test_embeddings"):
+ # Test embed_documents
+ result = embeddings.embed_documents(["Hello world", "Test document"])
+
+ assert len(result) == 2
+ mock_embed_documents.assert_called_once()
+
+ # Check captured events
+ assert len(events) >= 1
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ # Find embeddings span
+ embeddings_spans = [
+ span for span in tx.get("spans", []) if span.get("op") == "gen_ai.embeddings"
+ ]
+ assert len(embeddings_spans) == 1
+
+ embeddings_span = embeddings_spans[0]
+ assert embeddings_span["description"] == "embeddings text-embedding-ada-002"
+ assert embeddings_span["origin"] == "auto.ai.langchain"
+ assert embeddings_span["data"]["gen_ai.operation.name"] == "embeddings"
+ assert embeddings_span["data"]["gen_ai.request.model"] == "text-embedding-ada-002"
+
+ # Check if input is captured based on PII settings
+ if send_default_pii and include_prompts:
+ assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT in embeddings_span["data"]
+ input_data = embeddings_span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]
+ # Could be serialized as string
+ if isinstance(input_data, str):
+ assert "Hello world" in input_data
+ assert "Test document" in input_data
+ else:
+ assert "Hello world" in input_data
+ assert "Test document" in input_data
+ else:
+ assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT not in embeddings_span.get("data", {})
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (False, False),
+ ],
+)
+def test_langchain_embeddings_embed_query(
+ sentry_init, capture_events, send_default_pii, include_prompts
+):
+ """Test that embed_query method is properly traced."""
+ try:
+ from langchain_openai import OpenAIEmbeddings
+ except ImportError:
+ pytest.skip("langchain_openai not installed")
+
+ sentry_init(
+ integrations=[LangchainIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ # Mock the actual API call
+ with mock.patch.object(
+ OpenAIEmbeddings,
+ "embed_query",
+ wraps=lambda self, text: [0.1, 0.2, 0.3],
+ ) as mock_embed_query:
+ embeddings = OpenAIEmbeddings(
+ model="text-embedding-ada-002", openai_api_key="test-key"
+ )
+
+ # Force setup to re-run to ensure our mock is wrapped
+ LangchainIntegration.setup_once()
+
+ with start_transaction(name="test_embeddings_query"):
+ result = embeddings.embed_query("What is the capital of France?")
+
+ assert len(result) == 3
+ mock_embed_query.assert_called_once()
+
+ # Check captured events
+ assert len(events) >= 1
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ # Find embeddings span
+ embeddings_spans = [
+ span for span in tx.get("spans", []) if span.get("op") == "gen_ai.embeddings"
+ ]
+ assert len(embeddings_spans) == 1
+
+ embeddings_span = embeddings_spans[0]
+ assert embeddings_span["data"]["gen_ai.operation.name"] == "embeddings"
+ assert embeddings_span["data"]["gen_ai.request.model"] == "text-embedding-ada-002"
+
+ # Check if input is captured based on PII settings
+ if send_default_pii and include_prompts:
+ assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT in embeddings_span["data"]
+ input_data = embeddings_span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]
+ # Could be serialized as string
+ if isinstance(input_data, str):
+ assert "What is the capital of France?" in input_data
+ else:
+ assert "What is the capital of France?" in input_data
+ else:
+ assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT not in embeddings_span.get("data", {})
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (False, False),
+ ],
+)
+@pytest.mark.asyncio
+async def test_langchain_embeddings_async(
+ sentry_init, capture_events, send_default_pii, include_prompts
+):
+ """Test that async embedding methods (aembed_documents, aembed_query) are properly traced."""
+ try:
+ from langchain_openai import OpenAIEmbeddings
+ except ImportError:
+ pytest.skip("langchain_openai not installed")
+
+ sentry_init(
+ integrations=[LangchainIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ async def mock_aembed_documents(self, texts):
+ return [[0.1, 0.2, 0.3] for _ in texts]
+
+ # Mock the actual API call
+ with mock.patch.object(
+ OpenAIEmbeddings,
+ "aembed_documents",
+ wraps=mock_aembed_documents,
+ ) as mock_aembed:
+ embeddings = OpenAIEmbeddings(
+ model="text-embedding-ada-002", openai_api_key="test-key"
+ )
+
+ # Force setup to re-run to ensure our mock is wrapped
+ LangchainIntegration.setup_once()
+
+ with start_transaction(name="test_async_embeddings"):
+ result = await embeddings.aembed_documents(
+ ["Async hello", "Async test document"]
+ )
+
+ assert len(result) == 2
+ mock_aembed.assert_called_once()
+
+ # Check captured events
+ assert len(events) >= 1
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ # Find embeddings span
+ embeddings_spans = [
+ span for span in tx.get("spans", []) if span.get("op") == "gen_ai.embeddings"
+ ]
+ assert len(embeddings_spans) == 1
+
+ embeddings_span = embeddings_spans[0]
+ assert embeddings_span["description"] == "embeddings text-embedding-ada-002"
+ assert embeddings_span["origin"] == "auto.ai.langchain"
+ assert embeddings_span["data"]["gen_ai.operation.name"] == "embeddings"
+ assert embeddings_span["data"]["gen_ai.request.model"] == "text-embedding-ada-002"
+
+ # Check if input is captured based on PII settings
+ if send_default_pii and include_prompts:
+ assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT in embeddings_span["data"]
+ input_data = embeddings_span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]
+ # Could be serialized as string
+ if isinstance(input_data, str):
+ assert "Async hello" in input_data or "Async test document" in input_data
+ else:
+ assert "Async hello" in input_data or "Async test document" in input_data
+ else:
+ assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT not in embeddings_span.get("data", {})
+
+
+@pytest.mark.asyncio
+async def test_langchain_embeddings_aembed_query(sentry_init, capture_events):
+ """Test that aembed_query method is properly traced."""
+ try:
+ from langchain_openai import OpenAIEmbeddings
+ except ImportError:
+ pytest.skip("langchain_openai not installed")
+
+ sentry_init(
+ integrations=[LangchainIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ async def mock_aembed_query(self, text):
+ return [0.1, 0.2, 0.3]
+
+ # Mock the actual API call
+ with mock.patch.object(
+ OpenAIEmbeddings,
+ "aembed_query",
+ wraps=mock_aembed_query,
+ ) as mock_aembed:
+ embeddings = OpenAIEmbeddings(
+ model="text-embedding-ada-002", openai_api_key="test-key"
+ )
+
+ # Force setup to re-run to ensure our mock is wrapped
+ LangchainIntegration.setup_once()
+
+ with start_transaction(name="test_async_embeddings_query"):
+ result = await embeddings.aembed_query("Async query test")
+
+ assert len(result) == 3
+ mock_aembed.assert_called_once()
+
+ # Check captured events
+ assert len(events) >= 1
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ # Find embeddings span
+ embeddings_spans = [
+ span for span in tx.get("spans", []) if span.get("op") == "gen_ai.embeddings"
+ ]
+ assert len(embeddings_spans) == 1
+
+ embeddings_span = embeddings_spans[0]
+ assert embeddings_span["data"]["gen_ai.operation.name"] == "embeddings"
+ assert embeddings_span["data"]["gen_ai.request.model"] == "text-embedding-ada-002"
+
+ # Check if input is captured
+ assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT in embeddings_span["data"]
+ input_data = embeddings_span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]
+ # Could be serialized as string
+ if isinstance(input_data, str):
+ assert "Async query test" in input_data
+ else:
+ assert "Async query test" in input_data
+
+
+def test_langchain_embeddings_no_model_name(sentry_init, capture_events):
+ """Test embeddings when model name is not available."""
+ try:
+ from langchain_openai import OpenAIEmbeddings
+ except ImportError:
+ pytest.skip("langchain_openai not installed")
+
+ sentry_init(
+ integrations=[LangchainIntegration(include_prompts=False)],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ # Mock the actual API call and remove model attribute
+ with mock.patch.object(
+ OpenAIEmbeddings,
+ "embed_documents",
+ wraps=lambda self, texts: [[0.1, 0.2, 0.3] for _ in texts],
+ ):
+ embeddings = OpenAIEmbeddings(openai_api_key="test-key")
+ # Remove model attribute to test fallback
+ delattr(embeddings, "model")
+ if hasattr(embeddings, "model_name"):
+ delattr(embeddings, "model_name")
+
+ # Force setup to re-run to ensure our mock is wrapped
+ LangchainIntegration.setup_once()
+
+ with start_transaction(name="test_embeddings_no_model"):
+ embeddings.embed_documents(["Test"])
+
+ # Check captured events
+ assert len(events) >= 1
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ # Find embeddings span
+ embeddings_spans = [
+ span for span in tx.get("spans", []) if span.get("op") == "gen_ai.embeddings"
+ ]
+ assert len(embeddings_spans) == 1
+
+ embeddings_span = embeddings_spans[0]
+ assert embeddings_span["description"] == "embeddings"
+ assert embeddings_span["data"]["gen_ai.operation.name"] == "embeddings"
+ # Model name should not be set if not available
+ assert (
+ "gen_ai.request.model" not in embeddings_span["data"]
+ or embeddings_span["data"]["gen_ai.request.model"] is None
+ )
+
+
+def test_langchain_embeddings_integration_disabled(sentry_init, capture_events):
+ """Test that embeddings are not traced when integration is disabled."""
+ try:
+ from langchain_openai import OpenAIEmbeddings
+ except ImportError:
+ pytest.skip("langchain_openai not installed")
+
+ # Initialize without LangchainIntegration
+ sentry_init(traces_sample_rate=1.0)
+ events = capture_events()
+
+ with mock.patch.object(
+ OpenAIEmbeddings,
+ "embed_documents",
+ return_value=[[0.1, 0.2, 0.3]],
+ ):
+ embeddings = OpenAIEmbeddings(
+ model="text-embedding-ada-002", openai_api_key="test-key"
+ )
+
+ with start_transaction(name="test_embeddings_disabled"):
+ embeddings.embed_documents(["Test"])
+
+ # Check that no embeddings spans were created
+ if events:
+ tx = events[0]
+ embeddings_spans = [
+ span
+ for span in tx.get("spans", [])
+ if span.get("op") == "gen_ai.embeddings"
+ ]
+ # Should be empty since integration is disabled
+ assert len(embeddings_spans) == 0
+
+
+def test_langchain_embeddings_multiple_providers(sentry_init, capture_events):
+ """Test that embeddings work with different providers."""
+ try:
+ from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings
+ except ImportError:
+ pytest.skip("langchain_openai not installed")
+
+ sentry_init(
+ integrations=[LangchainIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ # Mock both providers
+ with mock.patch.object(
+ OpenAIEmbeddings,
+ "embed_documents",
+ wraps=lambda self, texts: [[0.1, 0.2, 0.3] for _ in texts],
+ ), mock.patch.object(
+ AzureOpenAIEmbeddings,
+ "embed_documents",
+ wraps=lambda self, texts: [[0.4, 0.5, 0.6] for _ in texts],
+ ):
+ openai_embeddings = OpenAIEmbeddings(
+ model="text-embedding-ada-002", openai_api_key="test-key"
+ )
+ azure_embeddings = AzureOpenAIEmbeddings(
+ model="text-embedding-ada-002",
+ azure_endpoint="https://test.openai.azure.com/",
+ openai_api_key="test-key",
+ )
+
+ # Force setup to re-run
+ LangchainIntegration.setup_once()
+
+ with start_transaction(name="test_multiple_providers"):
+ openai_embeddings.embed_documents(["OpenAI test"])
+ azure_embeddings.embed_documents(["Azure test"])
+
+ # Check captured events
+ assert len(events) >= 1
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ # Find embeddings spans
+ embeddings_spans = [
+ span for span in tx.get("spans", []) if span.get("op") == "gen_ai.embeddings"
+ ]
+ # Should have 2 spans, one for each provider
+ assert len(embeddings_spans) == 2
+
+ # Verify both spans have proper data
+ for span in embeddings_spans:
+ assert span["data"]["gen_ai.operation.name"] == "embeddings"
+ assert span["data"]["gen_ai.request.model"] == "text-embedding-ada-002"
+ assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT in span["data"]
+
+
+def test_langchain_embeddings_error_handling(sentry_init, capture_events):
+ """Test that errors in embeddings are properly captured."""
+ try:
+ from langchain_openai import OpenAIEmbeddings
+ except ImportError:
+ pytest.skip("langchain_openai not installed")
+
+ sentry_init(
+ integrations=[LangchainIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ # Mock the API call to raise an error
+ with mock.patch.object(
+ OpenAIEmbeddings,
+ "embed_documents",
+ side_effect=ValueError("API error"),
+ ):
+ embeddings = OpenAIEmbeddings(
+ model="text-embedding-ada-002", openai_api_key="test-key"
+ )
+
+ # Force setup to re-run
+ LangchainIntegration.setup_once()
+
+ with start_transaction(name="test_embeddings_error"):
+ with pytest.raises(ValueError):
+ embeddings.embed_documents(["Test"])
+
+ # The error should be captured
+ assert len(events) >= 1
+ # We should have both the transaction and potentially an error event
+ [e for e in events if e.get("level") == "error"]
+ # Note: errors might not be auto-captured depending on SDK settings,
+ # but the span should still be created
+
+
+def test_langchain_embeddings_multiple_calls(sentry_init, capture_events):
+ """Test that multiple embeddings calls within a transaction are all traced."""
+ try:
+ from langchain_openai import OpenAIEmbeddings
+ except ImportError:
+ pytest.skip("langchain_openai not installed")
+
+ sentry_init(
+ integrations=[LangchainIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ # Mock the actual API calls
+ with mock.patch.object(
+ OpenAIEmbeddings,
+ "embed_documents",
+ wraps=lambda self, texts: [[0.1, 0.2, 0.3] for _ in texts],
+ ), mock.patch.object(
+ OpenAIEmbeddings,
+ "embed_query",
+ wraps=lambda self, text: [0.4, 0.5, 0.6],
+ ):
+ embeddings = OpenAIEmbeddings(
+ model="text-embedding-ada-002", openai_api_key="test-key"
+ )
+
+ # Force setup to re-run
+ LangchainIntegration.setup_once()
+
+ with start_transaction(name="test_multiple_embeddings"):
+ # Call embed_documents
+ embeddings.embed_documents(["First batch", "Second batch"])
+ # Call embed_query
+ embeddings.embed_query("Single query")
+ # Call embed_documents again
+ embeddings.embed_documents(["Third batch"])
+
+ # Check captured events
+ assert len(events) >= 1
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ # Find embeddings spans - should have 3 (2 embed_documents + 1 embed_query)
+ embeddings_spans = [
+ span for span in tx.get("spans", []) if span.get("op") == "gen_ai.embeddings"
+ ]
+ assert len(embeddings_spans) == 3
+
+ # Verify all spans have proper data
+ for span in embeddings_spans:
+ assert span["data"]["gen_ai.operation.name"] == "embeddings"
+ assert span["data"]["gen_ai.request.model"] == "text-embedding-ada-002"
+ assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT in span["data"]
+
+ # Verify the input data is different for each span
+ input_data_list = [
+ span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT] for span in embeddings_spans
+ ]
+ # They should all be different (different inputs)
+ assert len(set(str(data) for data in input_data_list)) == 3
+
+
+def test_langchain_embeddings_span_hierarchy(sentry_init, capture_events):
+ """Test that embeddings spans are properly nested within parent spans."""
+ try:
+ from langchain_openai import OpenAIEmbeddings
+ except ImportError:
+ pytest.skip("langchain_openai not installed")
+
+ sentry_init(
+ integrations=[LangchainIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ # Mock the actual API call
+ with mock.patch.object(
+ OpenAIEmbeddings,
+ "embed_documents",
+ wraps=lambda self, texts: [[0.1, 0.2, 0.3] for _ in texts],
+ ):
+ embeddings = OpenAIEmbeddings(
+ model="text-embedding-ada-002", openai_api_key="test-key"
+ )
+
+ # Force setup to re-run
+ LangchainIntegration.setup_once()
+
+ with start_transaction(name="test_span_hierarchy"):
+ with sentry_sdk.start_span(op="custom", name="custom operation"):
+ embeddings.embed_documents(["Test within custom span"])
+
+ # Check captured events
+ assert len(events) >= 1
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ # Find all spans
+ embeddings_spans = [
+ span for span in tx.get("spans", []) if span.get("op") == "gen_ai.embeddings"
+ ]
+ custom_spans = [span for span in tx.get("spans", []) if span.get("op") == "custom"]
+
+ assert len(embeddings_spans) == 1
+ assert len(custom_spans) == 1
+
+ # Both spans should exist
+ embeddings_span = embeddings_spans[0]
+ custom_span = custom_spans[0]
+
+ assert embeddings_span["data"]["gen_ai.operation.name"] == "embeddings"
+ assert custom_span["description"] == "custom operation"
+
+
+def test_langchain_embeddings_with_list_and_string_inputs(sentry_init, capture_events):
+ """Test that embeddings correctly handle both list and string inputs."""
+ try:
+ from langchain_openai import OpenAIEmbeddings
+ except ImportError:
+ pytest.skip("langchain_openai not installed")
+
+ sentry_init(
+ integrations=[LangchainIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ # Mock the actual API calls
+ with mock.patch.object(
+ OpenAIEmbeddings,
+ "embed_documents",
+ wraps=lambda self, texts: [[0.1, 0.2, 0.3] for _ in texts],
+ ), mock.patch.object(
+ OpenAIEmbeddings,
+ "embed_query",
+ wraps=lambda self, text: [0.4, 0.5, 0.6],
+ ):
+ embeddings = OpenAIEmbeddings(
+ model="text-embedding-ada-002", openai_api_key="test-key"
+ )
+
+ # Force setup to re-run
+ LangchainIntegration.setup_once()
+
+ with start_transaction(name="test_input_types"):
+ # embed_documents takes a list
+ embeddings.embed_documents(["List item 1", "List item 2", "List item 3"])
+ # embed_query takes a string
+ embeddings.embed_query("Single string query")
+
+ # Check captured events
+ assert len(events) >= 1
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ # Find embeddings spans
+ embeddings_spans = [
+ span for span in tx.get("spans", []) if span.get("op") == "gen_ai.embeddings"
+ ]
+ assert len(embeddings_spans) == 2
+
+ # Both should have input data captured as lists
+ for span in embeddings_spans:
+ assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT in span["data"]
+ input_data = span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]
+ # Input should be normalized to list format
+ if isinstance(input_data, str):
+ # If serialized, should contain the input text
+ assert "List item" in input_data or "Single string query" in input_data, (
+ f"Expected input text in serialized data: {input_data}"
+ )
+
+
+@pytest.mark.parametrize(
+ "response_metadata_model,expected_model",
+ [
+ ("gpt-3.5-turbo", "gpt-3.5-turbo"),
+ (None, None),
+ ],
+)
+def test_langchain_response_model_extraction(
+ sentry_init,
+ capture_events,
+ response_metadata_model,
+ expected_model,
+):
+ sentry_init(
+ integrations=[LangchainIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ callback = SentryLangchainCallback(max_span_map_size=100, include_prompts=True)
+
+ run_id = "test-response-model-uuid"
+ serialized = {"_type": "openai-chat", "model_name": "gpt-3.5-turbo"}
+ prompts = ["Test prompt"]
+
+ with start_transaction():
+ callback.on_llm_start(
+ serialized=serialized,
+ prompts=prompts,
+ run_id=run_id,
+ invocation_params={"model": "gpt-3.5-turbo"},
+ )
+
+ response_metadata = {"model_name": response_metadata_model}
+ message = AIMessageChunk(
+ content="Test response", response_metadata=response_metadata
+ )
+
+ generation = Mock(text="Test response", message=message)
+ response = Mock(generations=[[generation]])
+ callback.on_llm_end(response=response, run_id=run_id)
+
+ assert len(events) > 0
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ llm_spans = [
+ span
+ for span in tx.get("spans", [])
+ if span.get("op") == "gen_ai.text_completion"
+ ]
+ assert len(llm_spans) > 0
+
+ llm_span = llm_spans[0]
+ assert llm_span["data"]["gen_ai.operation.name"] == "text_completion"
+
+ if expected_model is not None:
+ assert SPANDATA.GEN_AI_RESPONSE_MODEL in llm_span["data"]
+ assert llm_span["data"][SPANDATA.GEN_AI_RESPONSE_MODEL] == expected_model
+ else:
+ assert SPANDATA.GEN_AI_RESPONSE_MODEL not in llm_span.get("data", {})
+
+
+# Tests for multimodal content transformation functions
+
+
+class TestTransformLangchainContentBlock:
+ """Tests for _transform_langchain_content_block function."""
+
+ def test_transform_image_base64(self):
+ """Test transformation of base64-encoded image content."""
+ content_block = {
+ "type": "image",
+ "base64": "/9j/4AAQSkZJRgABAQAAAQABAAD...",
+ "mime_type": "image/jpeg",
+ }
+ result = _transform_langchain_content_block(content_block)
+ assert result == {
+ "type": "blob",
+ "modality": "image",
+ "mime_type": "image/jpeg",
+ "content": "/9j/4AAQSkZJRgABAQAAAQABAAD...",
+ }
+
+ def test_transform_image_url(self):
+ """Test transformation of URL-referenced image content."""
+ content_block = {
+ "type": "image",
+ "url": "https://example.com/image.jpg",
+ "mime_type": "image/jpeg",
+ }
+ result = _transform_langchain_content_block(content_block)
+ assert result == {
+ "type": "uri",
+ "modality": "image",
+ "mime_type": "image/jpeg",
+ "uri": "https://example.com/image.jpg",
+ }
+
+ def test_transform_image_file_id(self):
+ """Test transformation of file_id-referenced image content."""
+ content_block = {
+ "type": "image",
+ "file_id": "file-abc123",
+ "mime_type": "image/png",
+ }
+ result = _transform_langchain_content_block(content_block)
+ assert result == {
+ "type": "file",
+ "modality": "image",
+ "mime_type": "image/png",
+ "file_id": "file-abc123",
+ }
+
+ def test_transform_image_url_legacy_with_data_uri(self):
+ """Test transformation of legacy image_url format with data: URI (base64)."""
+ content_block = {
+ "type": "image_url",
+ "image_url": {"url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD"},
+ }
+ result = _transform_langchain_content_block(content_block)
+ assert result == {
+ "type": "blob",
+ "modality": "image",
+ "mime_type": "image/jpeg",
+ "content": "/9j/4AAQSkZJRgABAQAAAQABAAD",
+ }
+
+ def test_transform_image_url_legacy_with_http_url(self):
+ """Test transformation of legacy image_url format with HTTP URL."""
+ content_block = {
+ "type": "image_url",
+ "image_url": {"url": "https://example.com/image.png"},
+ }
+ result = _transform_langchain_content_block(content_block)
+ assert result == {
+ "type": "uri",
+ "modality": "image",
+ "mime_type": "",
+ "uri": "https://example.com/image.png",
+ }
+
+ def test_transform_image_url_legacy_string_url(self):
+ """Test transformation of legacy image_url format with string URL."""
+ content_block = {
+ "type": "image_url",
+ "image_url": "https://example.com/image.gif",
+ }
+ result = _transform_langchain_content_block(content_block)
+ assert result == {
+ "type": "uri",
+ "modality": "image",
+ "mime_type": "",
+ "uri": "https://example.com/image.gif",
+ }
+
+ def test_transform_image_url_legacy_data_uri_png(self):
+ """Test transformation of legacy image_url format with PNG data URI."""
+ content_block = {
+ "type": "image_url",
+ "image_url": {
+ "url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
+ },
+ }
+ result = _transform_langchain_content_block(content_block)
+ assert result == {
+ "type": "blob",
+ "modality": "image",
+ "mime_type": "image/png",
+ "content": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
+ }
+
+ def test_transform_missing_mime_type(self):
+ """Test transformation when mime_type is not provided."""
+ content_block = {
+ "type": "image",
+ "base64": "/9j/4AAQSkZJRgABAQAAAQABAAD...",
+ }
+ result = _transform_langchain_content_block(content_block)
+ assert result == {
+ "type": "blob",
+ "modality": "image",
+ "mime_type": "",
+ "content": "/9j/4AAQSkZJRgABAQAAAQABAAD...",
+ }
+
+ def test_transform_anthropic_source_base64(self):
+ """Test transformation of Anthropic-style image with base64 source."""
+ content_block = {
+ "type": "image",
+ "source": {
+ "type": "base64",
+ "media_type": "image/png",
+ "data": "iVBORw0KGgoAAAANSUhEUgAAAAE...",
+ },
+ }
+ result = _transform_langchain_content_block(content_block)
+ assert result == {
+ "type": "blob",
+ "modality": "image",
+ "mime_type": "image/png",
+ "content": "iVBORw0KGgoAAAANSUhEUgAAAAE...",
+ }
+
+ def test_transform_anthropic_source_url(self):
+ """Test transformation of Anthropic-style image with URL source."""
+ content_block = {
+ "type": "image",
+ "source": {
+ "type": "url",
+ "media_type": "image/jpeg",
+ "url": "https://example.com/image.jpg",
+ },
+ }
+ result = _transform_langchain_content_block(content_block)
+ assert result == {
+ "type": "uri",
+ "modality": "image",
+ "mime_type": "image/jpeg",
+ "uri": "https://example.com/image.jpg",
+ }
+
+ def test_transform_anthropic_source_without_media_type(self):
+ """Test transformation of Anthropic-style image without media_type uses empty mime_type."""
+ content_block = {
+ "type": "image",
+ "mime_type": "image/webp", # Top-level mime_type is ignored by standard Anthropic format
+ "source": {
+ "type": "base64",
+ "data": "UklGRh4AAABXRUJQVlA4IBIAAAAwAQCdASoBAAEAAQAcJYgCdAEO",
+ },
+ }
+ result = _transform_langchain_content_block(content_block)
+ # Note: The shared transform_content_part uses media_type from source, not top-level mime_type
+ assert result == {
+ "type": "blob",
+ "modality": "image",
+ "mime_type": "",
+ "content": "UklGRh4AAABXRUJQVlA4IBIAAAAwAQCdASoBAAEAAQAcJYgCdAEO",
+ }
+
+ def test_transform_google_inline_data(self):
+ """Test transformation of Google-style inline_data format."""
+ content_block = {
+ "inline_data": {
+ "mime_type": "image/jpeg",
+ "data": "/9j/4AAQSkZJRgABAQAAAQABAAD...",
+ }
+ }
+ result = _transform_langchain_content_block(content_block)
+ assert result == {
+ "type": "blob",
+ "modality": "image",
+ "mime_type": "image/jpeg",
+ "content": "/9j/4AAQSkZJRgABAQAAAQABAAD...",
+ }
+
+ def test_transform_google_file_data(self):
+ """Test transformation of Google-style file_data format."""
+ content_block = {
+ "file_data": {
+ "mime_type": "image/png",
+ "file_uri": "gs://bucket/path/to/image.png",
+ }
+ }
+ result = _transform_langchain_content_block(content_block)
+ assert result == {
+ "type": "uri",
+ "modality": "image",
+ "mime_type": "image/png",
+ "uri": "gs://bucket/path/to/image.png",
+ }
+
+
+@pytest.mark.parametrize(
+ "ai_type,expected_system",
+ [
+ # Real LangChain _type values (from _llm_type properties)
+ # OpenAI
+ ("openai-chat", "openai-chat"),
+ ("openai", "openai"),
+ # Azure OpenAI
+ ("azure-openai-chat", "azure-openai-chat"),
+ ("azure", "azure"),
+ # Anthropic
+ ("anthropic-chat", "anthropic-chat"),
+ # Google
+ ("vertexai", "vertexai"),
+ ("chat-google-generative-ai", "chat-google-generative-ai"),
+ ("google_gemini", "google_gemini"),
+ # AWS Bedrock
+ ("amazon_bedrock_chat", "amazon_bedrock_chat"),
+ ("amazon_bedrock", "amazon_bedrock"),
+ # Cohere
+ ("cohere-chat", "cohere-chat"),
+ # Ollama
+ ("chat-ollama", "chat-ollama"),
+ ("ollama-llm", "ollama-llm"),
+ # Mistral
+ ("mistralai-chat", "mistralai-chat"),
+ # Fireworks
+ ("fireworks-chat", "fireworks-chat"),
+ ("fireworks", "fireworks"),
+ # HuggingFace
+ ("huggingface-chat-wrapper", "huggingface-chat-wrapper"),
+ # Groq
+ ("groq-chat", "groq-chat"),
+ # NVIDIA
+ ("chat-nvidia-ai-playground", "chat-nvidia-ai-playground"),
+ # xAI
+ ("xai-chat", "xai-chat"),
+ # DeepSeek
+ ("chat-deepseek", "chat-deepseek"),
+ # Edge cases
+ ("", None),
+ (None, None),
+ ],
+)
+def test_langchain_ai_system_detection(
+ sentry_init, capture_events, ai_type, expected_system
+):
+ sentry_init(
+ integrations=[LangchainIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ callback = SentryLangchainCallback(max_span_map_size=100, include_prompts=True)
+
+ run_id = "test-ai-system-uuid"
+ serialized = {"_type": ai_type} if ai_type is not None else {}
+ prompts = ["Test prompt"]
+
+ with start_transaction():
+ callback.on_llm_start(
+ serialized=serialized,
+ prompts=prompts,
+ run_id=run_id,
+ invocation_params={"_type": ai_type, "model": "test-model"},
+ )
+
+ generation = Mock(text="Test response", message=None)
+ response = Mock(generations=[[generation]])
+ callback.on_llm_end(response=response, run_id=run_id)
+
+ assert len(events) > 0
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ llm_spans = [
+ span
+ for span in tx.get("spans", [])
+ if span.get("op") == "gen_ai.text_completion"
+ ]
+ assert len(llm_spans) > 0
+
+ llm_span = llm_spans[0]
+
+ if expected_system is not None:
+ assert llm_span["data"][SPANDATA.GEN_AI_SYSTEM] == expected_system
+ else:
+ assert SPANDATA.GEN_AI_SYSTEM not in llm_span.get("data", {})
+
+
+class TestTransformLangchainMessageContent:
+ """Tests for _transform_langchain_message_content function."""
+
+ def test_transform_string_content(self):
+ """Test that string content is returned unchanged."""
+ result = _transform_langchain_message_content("Hello, world!")
+ assert result == "Hello, world!"
+
+ def test_transform_list_with_text_blocks(self):
+ """Test transformation of list with text blocks (unchanged)."""
+ content = [
+ {"type": "text", "text": "First message"},
+ {"type": "text", "text": "Second message"},
+ ]
+ result = _transform_langchain_message_content(content)
+ assert result == content
+
+ def test_transform_list_with_image_blocks(self):
+ """Test transformation of list containing image blocks."""
+ content = [
+ {"type": "text", "text": "Check out this image:"},
+ {
+ "type": "image",
+ "base64": "/9j/4AAQSkZJRgABAQAAAQABAAD...",
+ "mime_type": "image/jpeg",
+ },
+ ]
+ result = _transform_langchain_message_content(content)
+ assert len(result) == 2
+ assert result[0] == {"type": "text", "text": "Check out this image:"}
+ assert result[1] == {
+ "type": "blob",
+ "modality": "image",
+ "mime_type": "image/jpeg",
+ "content": "/9j/4AAQSkZJRgABAQAAAQABAAD...",
+ }
+
+ def test_transform_list_with_mixed_content(self):
+ """Test transformation of list with mixed content types."""
+ content = [
+ {"type": "text", "text": "Here are some files:"},
+ {
+ "type": "image",
+ "url": "https://example.com/image.jpg",
+ "mime_type": "image/jpeg",
+ },
+ {
+ "type": "file",
+ "file_id": "doc-123",
+ "mime_type": "application/pdf",
+ },
+ {"type": "audio", "base64": "audio_data...", "mime_type": "audio/mp3"},
+ ]
+ result = _transform_langchain_message_content(content)
+ assert len(result) == 4
+ assert result[0] == {"type": "text", "text": "Here are some files:"}
+ assert result[1] == {
+ "type": "uri",
+ "modality": "image",
+ "mime_type": "image/jpeg",
+ "uri": "https://example.com/image.jpg",
+ }
+ assert result[2] == {
+ "type": "file",
+ "modality": "document",
+ "mime_type": "application/pdf",
+ "file_id": "doc-123",
+ }
+ assert result[3] == {
+ "type": "blob",
+ "modality": "audio",
+ "mime_type": "audio/mp3",
+ "content": "audio_data...",
+ }
+
+ def test_transform_list_with_non_dict_items(self):
+ """Test transformation handles non-dict items in list."""
+ content = ["plain string", {"type": "text", "text": "dict text"}]
+ result = _transform_langchain_message_content(content)
+ assert result == ["plain string", {"type": "text", "text": "dict text"}]
+
+ def test_transform_tuple_content(self):
+ """Test transformation of tuple content."""
+ content = (
+ {"type": "text", "text": "Message"},
+ {"type": "image", "base64": "data...", "mime_type": "image/png"},
+ )
+ result = _transform_langchain_message_content(content)
+ assert len(result) == 2
+ assert result[1] == {
+ "type": "blob",
+ "modality": "image",
+ "mime_type": "image/png",
+ "content": "data...",
+ }
+
+ def test_transform_list_with_legacy_image_url(self):
+ """Test transformation of list containing legacy image_url blocks."""
+ content = [
+ {"type": "text", "text": "Check this:"},
+ {
+ "type": "image_url",
+ "image_url": {"url": "data:image/jpeg;base64,/9j/4AAQ..."},
+ },
+ ]
+ result = _transform_langchain_message_content(content)
+ assert len(result) == 2
+ assert result[0] == {"type": "text", "text": "Check this:"}
+ assert result[1] == {
+ "type": "blob",
+ "modality": "image",
+ "mime_type": "image/jpeg",
+ "content": "/9j/4AAQ...",
+ }
diff --git a/tests/integrations/langgraph/__init__.py b/tests/integrations/langgraph/__init__.py
new file mode 100644
index 0000000000..b7dd1cb562
--- /dev/null
+++ b/tests/integrations/langgraph/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("langgraph")
diff --git a/tests/integrations/langgraph/test_langgraph.py b/tests/integrations/langgraph/test_langgraph.py
new file mode 100644
index 0000000000..2a385d8a78
--- /dev/null
+++ b/tests/integrations/langgraph/test_langgraph.py
@@ -0,0 +1,1387 @@
+import asyncio
+import sys
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from sentry_sdk import start_transaction
+from sentry_sdk.consts import SPANDATA, OP
+
+
+def mock_langgraph_imports():
+ """Mock langgraph modules to prevent import errors."""
+ mock_state_graph = MagicMock()
+ mock_pregel = MagicMock()
+
+ langgraph_graph_mock = MagicMock()
+ langgraph_graph_mock.StateGraph = mock_state_graph
+
+ langgraph_pregel_mock = MagicMock()
+ langgraph_pregel_mock.Pregel = mock_pregel
+
+ sys.modules["langgraph"] = MagicMock()
+ sys.modules["langgraph.graph"] = langgraph_graph_mock
+ sys.modules["langgraph.pregel"] = langgraph_pregel_mock
+
+ return mock_state_graph, mock_pregel
+
+
+mock_state_graph, mock_pregel = mock_langgraph_imports()
+
+from sentry_sdk.integrations.langgraph import ( # noqa: E402
+ LanggraphIntegration,
+ _parse_langgraph_messages,
+ _wrap_state_graph_compile,
+ _wrap_pregel_invoke,
+ _wrap_pregel_ainvoke,
+)
+
+
+class MockStateGraph:
+ def __init__(self, schema=None):
+ self.name = "test_graph"
+ self.schema = schema
+ self._compiled_graph = None
+
+ def compile(self, *args, **kwargs):
+ compiled = MockCompiledGraph(self.name)
+ compiled.graph = self
+ return compiled
+
+
+class MockCompiledGraph:
+ def __init__(self, name="test_graph"):
+ self.name = name
+ self._graph = None
+
+ def get_graph(self):
+ return MockGraphRepresentation()
+
+ def invoke(self, state, config=None):
+ return {"messages": [MockMessage("Response from graph")]}
+
+ async def ainvoke(self, state, config=None):
+ return {"messages": [MockMessage("Async response from graph")]}
+
+
+class MockGraphRepresentation:
+ def __init__(self):
+ self.nodes = {"tools": MockToolsNode()}
+
+
+class MockToolsNode:
+ def __init__(self):
+ self.data = MockToolsData()
+
+
+class MockToolsData:
+ def __init__(self):
+ self.tools_by_name = {
+ "search_tool": MockTool("search_tool"),
+ "calculator": MockTool("calculator"),
+ }
+
+
+class MockTool:
+ def __init__(self, name):
+ self.name = name
+
+
+class MockMessage:
+ def __init__(
+ self,
+ content,
+ name=None,
+ tool_calls=None,
+ function_call=None,
+ role=None,
+ type=None,
+ response_metadata=None,
+ ):
+ self.content = content
+ self.name = name
+ self.tool_calls = tool_calls
+ self.function_call = function_call
+ self.role = role
+ # The integration uses getattr(message, "type", None) for the role in _normalize_langgraph_message
+ # Set default type based on name if type not explicitly provided
+ if type is None and name in ["assistant", "ai", "user", "system", "function"]:
+ self.type = name
+ else:
+ self.type = type
+ self.response_metadata = response_metadata
+
+
+class MockPregelInstance:
+ def __init__(self, name="test_pregel"):
+ self.name = name
+ self.graph_name = name
+
+ def invoke(self, state, config=None):
+ return {"messages": [MockMessage("Pregel response")]}
+
+ async def ainvoke(self, state, config=None):
+ return {"messages": [MockMessage("Async Pregel response")]}
+
+
+def test_langgraph_integration_init():
+ """Test LanggraphIntegration initialization with different parameters."""
+ integration = LanggraphIntegration()
+ assert integration.include_prompts is True
+ assert integration.identifier == "langgraph"
+ assert integration.origin == "auto.ai.langgraph"
+
+ integration = LanggraphIntegration(include_prompts=False)
+ assert integration.include_prompts is False
+ assert integration.identifier == "langgraph"
+ assert integration.origin == "auto.ai.langgraph"
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+def test_state_graph_compile(
+ sentry_init, capture_events, send_default_pii, include_prompts
+):
+ """Test StateGraph.compile() wrapper creates proper create_agent span."""
+ sentry_init(
+ integrations=[LanggraphIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+ graph = MockStateGraph()
+
+ def original_compile(self, *args, **kwargs):
+ return MockCompiledGraph(self.name)
+
+ with patch("sentry_sdk.integrations.langgraph.StateGraph"):
+ with start_transaction():
+ wrapped_compile = _wrap_state_graph_compile(original_compile)
+ compiled_graph = wrapped_compile(
+ graph, model="test-model", checkpointer=None
+ )
+
+ assert compiled_graph is not None
+ assert compiled_graph.name == "test_graph"
+
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ agent_spans = [span for span in tx["spans"] if span["op"] == OP.GEN_AI_CREATE_AGENT]
+ assert len(agent_spans) == 1
+
+ agent_span = agent_spans[0]
+ assert agent_span["description"] == "create_agent test_graph"
+ assert agent_span["origin"] == "auto.ai.langgraph"
+ assert agent_span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "create_agent"
+ assert agent_span["data"][SPANDATA.GEN_AI_AGENT_NAME] == "test_graph"
+ assert agent_span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "test-model"
+ assert SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS in agent_span["data"]
+
+ tools_data = agent_span["data"][SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS]
+ assert tools_data == ["search_tool", "calculator"]
+ assert len(tools_data) == 2
+ assert "search_tool" in tools_data
+ assert "calculator" in tools_data
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+def test_pregel_invoke(sentry_init, capture_events, send_default_pii, include_prompts):
+ """Test Pregel.invoke() wrapper creates proper invoke_agent span."""
+ sentry_init(
+ integrations=[LanggraphIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ test_state = {
+ "messages": [
+ MockMessage("Hello, can you help me?", name="user"),
+ MockMessage("Of course! How can I assist you?", name="assistant"),
+ ]
+ }
+
+ pregel = MockPregelInstance("test_graph")
+
+ expected_assistant_response = "I'll help you with that task!"
+ expected_tool_calls = [
+ {
+ "id": "call_test_123",
+ "type": "function",
+ "function": {"name": "search_tool", "arguments": '{"query": "help"}'},
+ }
+ ]
+
+ def original_invoke(self, *args, **kwargs):
+ input_messages = args[0].get("messages", [])
+ new_messages = input_messages + [
+ MockMessage(
+ content=expected_assistant_response,
+ name="assistant",
+ tool_calls=expected_tool_calls,
+ )
+ ]
+ return {"messages": new_messages}
+
+ with start_transaction():
+ wrapped_invoke = _wrap_pregel_invoke(original_invoke)
+ result = wrapped_invoke(pregel, test_state)
+
+ assert result is not None
+
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ invoke_spans = [
+ span for span in tx["spans"] if span["op"] == OP.GEN_AI_INVOKE_AGENT
+ ]
+ assert len(invoke_spans) == 1
+
+ invoke_span = invoke_spans[0]
+ assert invoke_span["description"] == "invoke_agent test_graph"
+ assert invoke_span["origin"] == "auto.ai.langgraph"
+ assert invoke_span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "invoke_agent"
+ assert invoke_span["data"][SPANDATA.GEN_AI_PIPELINE_NAME] == "test_graph"
+ assert invoke_span["data"][SPANDATA.GEN_AI_AGENT_NAME] == "test_graph"
+
+ if send_default_pii and include_prompts:
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES in invoke_span["data"]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT in invoke_span["data"]
+
+ request_messages = invoke_span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+
+ if isinstance(request_messages, str):
+ import json
+
+ request_messages = json.loads(request_messages)
+ assert len(request_messages) == 1
+ assert request_messages[0]["content"] == "Of course! How can I assist you?"
+
+ response_text = invoke_span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT]
+ assert response_text == expected_assistant_response
+
+ assert SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS in invoke_span["data"]
+ tool_calls_data = invoke_span["data"][SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS]
+ if isinstance(tool_calls_data, str):
+ import json
+
+ tool_calls_data = json.loads(tool_calls_data)
+
+ assert len(tool_calls_data) == 1
+ assert tool_calls_data[0]["id"] == "call_test_123"
+ assert tool_calls_data[0]["function"]["name"] == "search_tool"
+ else:
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in invoke_span.get("data", {})
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in invoke_span.get("data", {})
+ assert SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS not in invoke_span.get("data", {})
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+def test_pregel_ainvoke(sentry_init, capture_events, send_default_pii, include_prompts):
+ """Test Pregel.ainvoke() async wrapper creates proper invoke_agent span."""
+ sentry_init(
+ integrations=[LanggraphIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+ test_state = {"messages": [MockMessage("What's the weather like?", name="user")]}
+ pregel = MockPregelInstance("async_graph")
+
+ expected_assistant_response = "It's sunny and 72°F today!"
+ expected_tool_calls = [
+ {
+ "id": "call_weather_456",
+ "type": "function",
+ "function": {"name": "get_weather", "arguments": '{"location": "current"}'},
+ }
+ ]
+
+ async def original_ainvoke(self, *args, **kwargs):
+ input_messages = args[0].get("messages", [])
+ new_messages = input_messages + [
+ MockMessage(
+ content=expected_assistant_response,
+ name="assistant",
+ tool_calls=expected_tool_calls,
+ )
+ ]
+ return {"messages": new_messages}
+
+ async def run_test():
+ with start_transaction():
+ wrapped_ainvoke = _wrap_pregel_ainvoke(original_ainvoke)
+ result = await wrapped_ainvoke(pregel, test_state)
+ return result
+
+ result = asyncio.run(run_test())
+ assert result is not None
+
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ invoke_spans = [
+ span for span in tx["spans"] if span["op"] == OP.GEN_AI_INVOKE_AGENT
+ ]
+ assert len(invoke_spans) == 1
+
+ invoke_span = invoke_spans[0]
+ assert invoke_span["description"] == "invoke_agent async_graph"
+ assert invoke_span["origin"] == "auto.ai.langgraph"
+ assert invoke_span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "invoke_agent"
+ assert invoke_span["data"][SPANDATA.GEN_AI_PIPELINE_NAME] == "async_graph"
+ assert invoke_span["data"][SPANDATA.GEN_AI_AGENT_NAME] == "async_graph"
+
+ if send_default_pii and include_prompts:
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES in invoke_span["data"]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT in invoke_span["data"]
+
+ response_text = invoke_span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT]
+ assert response_text == expected_assistant_response
+
+ assert SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS in invoke_span["data"]
+ tool_calls_data = invoke_span["data"][SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS]
+ if isinstance(tool_calls_data, str):
+ import json
+
+ tool_calls_data = json.loads(tool_calls_data)
+
+ assert len(tool_calls_data) == 1
+ assert tool_calls_data[0]["id"] == "call_weather_456"
+ assert tool_calls_data[0]["function"]["name"] == "get_weather"
+ else:
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in invoke_span.get("data", {})
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in invoke_span.get("data", {})
+ assert SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS not in invoke_span.get("data", {})
+
+
+def test_pregel_invoke_error(sentry_init, capture_events):
+ """Test error handling during graph execution."""
+ sentry_init(
+ integrations=[LanggraphIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+ test_state = {"messages": [MockMessage("This will fail")]}
+ pregel = MockPregelInstance("error_graph")
+
+ def original_invoke(self, *args, **kwargs):
+ raise Exception("Graph execution failed")
+
+ with start_transaction(), pytest.raises(Exception, match="Graph execution failed"):
+ wrapped_invoke = _wrap_pregel_invoke(original_invoke)
+ wrapped_invoke(pregel, test_state)
+
+ tx = events[0]
+ invoke_spans = [
+ span for span in tx["spans"] if span["op"] == OP.GEN_AI_INVOKE_AGENT
+ ]
+ assert len(invoke_spans) == 1
+
+ invoke_span = invoke_spans[0]
+ assert invoke_span.get("status") == "internal_error"
+ assert invoke_span.get("tags", {}).get("status") == "internal_error"
+
+
+def test_pregel_ainvoke_error(sentry_init, capture_events):
+ """Test error handling during async graph execution."""
+ sentry_init(
+ integrations=[LanggraphIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+ test_state = {"messages": [MockMessage("This will fail async")]}
+ pregel = MockPregelInstance("async_error_graph")
+
+ async def original_ainvoke(self, *args, **kwargs):
+ raise Exception("Async graph execution failed")
+
+ async def run_error_test():
+ with start_transaction(), pytest.raises(
+ Exception, match="Async graph execution failed"
+ ):
+ wrapped_ainvoke = _wrap_pregel_ainvoke(original_ainvoke)
+ await wrapped_ainvoke(pregel, test_state)
+
+ asyncio.run(run_error_test())
+
+ tx = events[0]
+ invoke_spans = [
+ span for span in tx["spans"] if span["op"] == OP.GEN_AI_INVOKE_AGENT
+ ]
+ assert len(invoke_spans) == 1
+
+ invoke_span = invoke_spans[0]
+ assert invoke_span.get("status") == "internal_error"
+ assert invoke_span.get("tags", {}).get("status") == "internal_error"
+
+
+def test_span_origin(sentry_init, capture_events):
+ """Test that span origins are correctly set."""
+ sentry_init(
+ integrations=[LanggraphIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ graph = MockStateGraph()
+
+ def original_compile(self, *args, **kwargs):
+ return MockCompiledGraph(self.name)
+
+ with start_transaction():
+ from sentry_sdk.integrations.langgraph import _wrap_state_graph_compile
+
+ wrapped_compile = _wrap_state_graph_compile(original_compile)
+ wrapped_compile(graph)
+
+ tx = events[0]
+ assert tx["contexts"]["trace"]["origin"] == "manual"
+
+ for span in tx["spans"]:
+ assert span["origin"] == "auto.ai.langgraph"
+
+
+@pytest.mark.parametrize("graph_name", ["my_graph", None, ""])
+def test_pregel_invoke_with_different_graph_names(
+ sentry_init, capture_events, graph_name
+):
+ """Test Pregel.invoke() with different graph name scenarios."""
+ sentry_init(
+ integrations=[LanggraphIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ pregel = MockPregelInstance(graph_name) if graph_name else MockPregelInstance()
+ if not graph_name:
+ delattr(pregel, "name")
+ delattr(pregel, "graph_name")
+
+ def original_invoke(self, *args, **kwargs):
+ return {"result": "test"}
+
+ with start_transaction():
+ wrapped_invoke = _wrap_pregel_invoke(original_invoke)
+ wrapped_invoke(pregel, {"messages": []})
+
+ tx = events[0]
+ invoke_spans = [
+ span for span in tx["spans"] if span["op"] == OP.GEN_AI_INVOKE_AGENT
+ ]
+ assert len(invoke_spans) == 1
+
+ invoke_span = invoke_spans[0]
+
+ if graph_name and graph_name.strip():
+ assert invoke_span["description"] == "invoke_agent my_graph"
+ assert invoke_span["data"][SPANDATA.GEN_AI_PIPELINE_NAME] == graph_name
+ assert invoke_span["data"][SPANDATA.GEN_AI_AGENT_NAME] == graph_name
+ else:
+ assert invoke_span["description"] == "invoke_agent"
+ assert SPANDATA.GEN_AI_PIPELINE_NAME not in invoke_span.get("data", {})
+ assert SPANDATA.GEN_AI_AGENT_NAME not in invoke_span.get("data", {})
+
+
+def test_pregel_invoke_span_includes_usage_data(sentry_init, capture_events):
+ """
+ Test that invoke_agent spans include aggregated usage data from context_wrapper.
+ This verifies the new functionality added to track token usage in invoke_agent spans.
+ """
+ sentry_init(
+ integrations=[LanggraphIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ test_state = {
+ "messages": [
+ MockMessage("Hello, can you help me?", name="user"),
+ MockMessage("Of course! How can I assist you?", name="assistant"),
+ ]
+ }
+
+ pregel = MockPregelInstance("test_graph")
+
+ expected_assistant_response = "I'll help you with that task!"
+ expected_tool_calls = [
+ {
+ "id": "call_test_123",
+ "type": "function",
+ "function": {"name": "search_tool", "arguments": '{"query": "help"}'},
+ }
+ ]
+
+ def original_invoke(self, *args, **kwargs):
+ input_messages = args[0].get("messages", [])
+ new_messages = input_messages + [
+ MockMessage(
+ content=expected_assistant_response,
+ name="assistant",
+ tool_calls=expected_tool_calls,
+ response_metadata={
+ "token_usage": {
+ "total_tokens": 30,
+ "prompt_tokens": 10,
+ "completion_tokens": 20,
+ },
+ "model_name": "gpt-4.1-2025-04-14",
+ },
+ )
+ ]
+ return {"messages": new_messages}
+
+ with start_transaction():
+ wrapped_invoke = _wrap_pregel_invoke(original_invoke)
+ result = wrapped_invoke(pregel, test_state)
+
+ assert result is not None
+
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ invoke_spans = [
+ span for span in tx["spans"] if span["op"] == OP.GEN_AI_INVOKE_AGENT
+ ]
+ assert len(invoke_spans) == 1
+
+ invoke_agent_span = invoke_spans[0]
+
+ # Verify invoke_agent span has usage data
+ assert invoke_agent_span["description"] == "invoke_agent test_graph"
+ assert "gen_ai.usage.input_tokens" in invoke_agent_span["data"]
+ assert "gen_ai.usage.output_tokens" in invoke_agent_span["data"]
+ assert "gen_ai.usage.total_tokens" in invoke_agent_span["data"]
+
+ # The usage should match the mock_usage values (aggregated across all calls)
+ assert invoke_agent_span["data"]["gen_ai.usage.input_tokens"] == 10
+ assert invoke_agent_span["data"]["gen_ai.usage.output_tokens"] == 20
+ assert invoke_agent_span["data"]["gen_ai.usage.total_tokens"] == 30
+
+
+def test_pregel_ainvoke_span_includes_usage_data(sentry_init, capture_events):
+ """
+ Test that invoke_agent spans include aggregated usage data from context_wrapper.
+ This verifies the new functionality added to track token usage in invoke_agent spans.
+ """
+ sentry_init(
+ integrations=[LanggraphIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ test_state = {
+ "messages": [
+ MockMessage("Hello, can you help me?", name="user"),
+ MockMessage("Of course! How can I assist you?", name="assistant"),
+ ]
+ }
+
+ pregel = MockPregelInstance("test_graph")
+
+ expected_assistant_response = "I'll help you with that task!"
+ expected_tool_calls = [
+ {
+ "id": "call_test_123",
+ "type": "function",
+ "function": {"name": "search_tool", "arguments": '{"query": "help"}'},
+ }
+ ]
+
+ async def original_ainvoke(self, *args, **kwargs):
+ input_messages = args[0].get("messages", [])
+ new_messages = input_messages + [
+ MockMessage(
+ content=expected_assistant_response,
+ name="assistant",
+ tool_calls=expected_tool_calls,
+ response_metadata={
+ "token_usage": {
+ "total_tokens": 30,
+ "prompt_tokens": 10,
+ "completion_tokens": 20,
+ },
+ "model_name": "gpt-4.1-2025-04-14",
+ },
+ )
+ ]
+ return {"messages": new_messages}
+
+ async def run_test():
+ with start_transaction():
+ wrapped_ainvoke = _wrap_pregel_ainvoke(original_ainvoke)
+ result = await wrapped_ainvoke(pregel, test_state)
+ return result
+
+ result = asyncio.run(run_test())
+ assert result is not None
+
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ invoke_spans = [
+ span for span in tx["spans"] if span["op"] == OP.GEN_AI_INVOKE_AGENT
+ ]
+ assert len(invoke_spans) == 1
+
+ invoke_agent_span = invoke_spans[0]
+
+ # Verify invoke_agent span has usage data
+ assert invoke_agent_span["description"] == "invoke_agent test_graph"
+ assert "gen_ai.usage.input_tokens" in invoke_agent_span["data"]
+ assert "gen_ai.usage.output_tokens" in invoke_agent_span["data"]
+ assert "gen_ai.usage.total_tokens" in invoke_agent_span["data"]
+
+ # The usage should match the mock_usage values (aggregated across all calls)
+ assert invoke_agent_span["data"]["gen_ai.usage.input_tokens"] == 10
+ assert invoke_agent_span["data"]["gen_ai.usage.output_tokens"] == 20
+ assert invoke_agent_span["data"]["gen_ai.usage.total_tokens"] == 30
+
+
+def test_pregel_invoke_multiple_llm_calls_aggregate_usage(sentry_init, capture_events):
+ """
+ Test that invoke_agent spans show aggregated usage across multiple LLM calls
+ (e.g., when tools are used and multiple API calls are made).
+ """
+ sentry_init(
+ integrations=[LanggraphIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ test_state = {
+ "messages": [
+ MockMessage("Hello, can you help me?", name="user"),
+ MockMessage("Of course! How can I assist you?", name="assistant"),
+ ]
+ }
+
+ pregel = MockPregelInstance("test_graph")
+
+ expected_assistant_response = "I'll help you with that task!"
+ expected_tool_calls = [
+ {
+ "id": "call_test_123",
+ "type": "function",
+ "function": {"name": "search_tool", "arguments": '{"query": "help"}'},
+ }
+ ]
+
+ def original_invoke(self, *args, **kwargs):
+ input_messages = args[0].get("messages", [])
+ new_messages = input_messages + [
+ MockMessage(
+ content=expected_assistant_response,
+ name="assistant",
+ tool_calls=expected_tool_calls,
+ response_metadata={
+ "token_usage": {
+ "total_tokens": 15,
+ "prompt_tokens": 10,
+ "completion_tokens": 5,
+ },
+ },
+ ),
+ MockMessage(
+ content=expected_assistant_response,
+ name="assistant",
+ tool_calls=expected_tool_calls,
+ response_metadata={
+ "token_usage": {
+ "total_tokens": 35,
+ "prompt_tokens": 20,
+ "completion_tokens": 15,
+ },
+ },
+ ),
+ ]
+ return {"messages": new_messages}
+
+ with start_transaction():
+ wrapped_invoke = _wrap_pregel_invoke(original_invoke)
+ result = wrapped_invoke(pregel, test_state)
+
+ assert result is not None
+
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ invoke_spans = [
+ span for span in tx["spans"] if span["op"] == OP.GEN_AI_INVOKE_AGENT
+ ]
+ assert len(invoke_spans) == 1
+ invoke_agent_span = invoke_spans[0]
+
+ # Verify invoke_agent span has aggregated usage from both API calls
+ # Total: 10 + 20 = 30 input tokens, 5 + 15 = 20 output tokens, 15 + 35 = 50 total
+ assert invoke_agent_span["data"]["gen_ai.usage.input_tokens"] == 30
+ assert invoke_agent_span["data"]["gen_ai.usage.output_tokens"] == 20
+ assert invoke_agent_span["data"]["gen_ai.usage.total_tokens"] == 50
+
+
+def test_pregel_ainvoke_multiple_llm_calls_aggregate_usage(sentry_init, capture_events):
+ """
+ Test that invoke_agent spans show aggregated usage across multiple LLM calls
+ (e.g., when tools are used and multiple API calls are made).
+ """
+ sentry_init(
+ integrations=[LanggraphIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ test_state = {
+ "messages": [
+ MockMessage("Hello, can you help me?", name="user"),
+ MockMessage("Of course! How can I assist you?", name="assistant"),
+ ]
+ }
+
+ pregel = MockPregelInstance("test_graph")
+
+ expected_assistant_response = "I'll help you with that task!"
+ expected_tool_calls = [
+ {
+ "id": "call_test_123",
+ "type": "function",
+ "function": {"name": "search_tool", "arguments": '{"query": "help"}'},
+ }
+ ]
+
+ async def original_ainvoke(self, *args, **kwargs):
+ input_messages = args[0].get("messages", [])
+ new_messages = input_messages + [
+ MockMessage(
+ content=expected_assistant_response,
+ name="assistant",
+ tool_calls=expected_tool_calls,
+ response_metadata={
+ "token_usage": {
+ "total_tokens": 15,
+ "prompt_tokens": 10,
+ "completion_tokens": 5,
+ },
+ },
+ ),
+ MockMessage(
+ content=expected_assistant_response,
+ name="assistant",
+ tool_calls=expected_tool_calls,
+ response_metadata={
+ "token_usage": {
+ "total_tokens": 35,
+ "prompt_tokens": 20,
+ "completion_tokens": 15,
+ },
+ },
+ ),
+ ]
+ return {"messages": new_messages}
+
+ async def run_test():
+ with start_transaction():
+ wrapped_ainvoke = _wrap_pregel_ainvoke(original_ainvoke)
+ result = await wrapped_ainvoke(pregel, test_state)
+ return result
+
+ result = asyncio.run(run_test())
+ assert result is not None
+
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ invoke_spans = [
+ span for span in tx["spans"] if span["op"] == OP.GEN_AI_INVOKE_AGENT
+ ]
+ assert len(invoke_spans) == 1
+ invoke_agent_span = invoke_spans[0]
+
+ # Verify invoke_agent span has aggregated usage from both API calls
+ # Total: 10 + 20 = 30 input tokens, 5 + 15 = 20 output tokens, 15 + 35 = 50 total
+ assert invoke_agent_span["data"]["gen_ai.usage.input_tokens"] == 30
+ assert invoke_agent_span["data"]["gen_ai.usage.output_tokens"] == 20
+ assert invoke_agent_span["data"]["gen_ai.usage.total_tokens"] == 50
+
+
+def test_pregel_invoke_span_includes_response_model(sentry_init, capture_events):
+ """
+ Test that invoke_agent spans include the response model.
+ When an agent makes multiple LLM calls, it should report the last model used.
+ """
+ sentry_init(
+ integrations=[LanggraphIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ test_state = {
+ "messages": [
+ MockMessage("Hello, can you help me?", name="user"),
+ MockMessage("Of course! How can I assist you?", name="assistant"),
+ ]
+ }
+
+ pregel = MockPregelInstance("test_graph")
+
+ expected_assistant_response = "I'll help you with that task!"
+ expected_tool_calls = [
+ {
+ "id": "call_test_123",
+ "type": "function",
+ "function": {"name": "search_tool", "arguments": '{"query": "help"}'},
+ }
+ ]
+
+ def original_invoke(self, *args, **kwargs):
+ input_messages = args[0].get("messages", [])
+ new_messages = input_messages + [
+ MockMessage(
+ content=expected_assistant_response,
+ name="assistant",
+ tool_calls=expected_tool_calls,
+ response_metadata={
+ "token_usage": {
+ "total_tokens": 30,
+ "prompt_tokens": 10,
+ "completion_tokens": 20,
+ },
+ "model_name": "gpt-4.1-2025-04-14",
+ },
+ )
+ ]
+ return {"messages": new_messages}
+
+ with start_transaction():
+ wrapped_invoke = _wrap_pregel_invoke(original_invoke)
+ result = wrapped_invoke(pregel, test_state)
+
+ assert result is not None
+
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ invoke_spans = [
+ span for span in tx["spans"] if span["op"] == OP.GEN_AI_INVOKE_AGENT
+ ]
+ assert len(invoke_spans) == 1
+
+ invoke_agent_span = invoke_spans[0]
+
+ # Verify invoke_agent span has response model
+ assert invoke_agent_span["description"] == "invoke_agent test_graph"
+ assert "gen_ai.response.model" in invoke_agent_span["data"]
+ assert invoke_agent_span["data"]["gen_ai.response.model"] == "gpt-4.1-2025-04-14"
+
+
+def test_pregel_ainvoke_span_includes_response_model(sentry_init, capture_events):
+ """
+ Test that invoke_agent spans include the response model.
+ When an agent makes multiple LLM calls, it should report the last model used.
+ """
+ sentry_init(
+ integrations=[LanggraphIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ test_state = {
+ "messages": [
+ MockMessage("Hello, can you help me?", name="user"),
+ MockMessage("Of course! How can I assist you?", name="assistant"),
+ ]
+ }
+
+ pregel = MockPregelInstance("test_graph")
+
+ expected_assistant_response = "I'll help you with that task!"
+ expected_tool_calls = [
+ {
+ "id": "call_test_123",
+ "type": "function",
+ "function": {"name": "search_tool", "arguments": '{"query": "help"}'},
+ }
+ ]
+
+ async def original_ainvoke(self, *args, **kwargs):
+ input_messages = args[0].get("messages", [])
+ new_messages = input_messages + [
+ MockMessage(
+ content=expected_assistant_response,
+ name="assistant",
+ tool_calls=expected_tool_calls,
+ response_metadata={
+ "token_usage": {
+ "total_tokens": 30,
+ "prompt_tokens": 10,
+ "completion_tokens": 20,
+ },
+ "model_name": "gpt-4.1-2025-04-14",
+ },
+ )
+ ]
+ return {"messages": new_messages}
+
+ async def run_test():
+ with start_transaction():
+ wrapped_ainvoke = _wrap_pregel_ainvoke(original_ainvoke)
+ result = await wrapped_ainvoke(pregel, test_state)
+ return result
+
+ result = asyncio.run(run_test())
+ assert result is not None
+
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ invoke_spans = [
+ span for span in tx["spans"] if span["op"] == OP.GEN_AI_INVOKE_AGENT
+ ]
+ assert len(invoke_spans) == 1
+
+ invoke_agent_span = invoke_spans[0]
+
+ # Verify invoke_agent span has response model
+ assert invoke_agent_span["description"] == "invoke_agent test_graph"
+ assert "gen_ai.response.model" in invoke_agent_span["data"]
+ assert invoke_agent_span["data"]["gen_ai.response.model"] == "gpt-4.1-2025-04-14"
+
+
+def test_pregel_invoke_span_uses_last_response_model(sentry_init, capture_events):
+ """
+ Test that when an agent makes multiple LLM calls (e.g., with tools),
+ the invoke_agent span reports the last response model used.
+ """
+ sentry_init(
+ integrations=[LanggraphIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ test_state = {
+ "messages": [
+ MockMessage("Hello, can you help me?", name="user"),
+ MockMessage("Of course! How can I assist you?", name="assistant"),
+ ]
+ }
+
+ pregel = MockPregelInstance("test_graph")
+
+ expected_assistant_response = "I'll help you with that task!"
+ expected_tool_calls = [
+ {
+ "id": "call_test_123",
+ "type": "function",
+ "function": {"name": "search_tool", "arguments": '{"query": "help"}'},
+ }
+ ]
+
+ def original_invoke(self, *args, **kwargs):
+ input_messages = args[0].get("messages", [])
+ new_messages = input_messages + [
+ MockMessage(
+ content=expected_assistant_response,
+ name="assistant",
+ tool_calls=expected_tool_calls,
+ response_metadata={
+ "token_usage": {
+ "total_tokens": 15,
+ "prompt_tokens": 10,
+ "completion_tokens": 5,
+ },
+ "model_name": "gpt-4-0613",
+ },
+ ),
+ MockMessage(
+ content=expected_assistant_response,
+ name="assistant",
+ tool_calls=expected_tool_calls,
+ response_metadata={
+ "token_usage": {
+ "total_tokens": 35,
+ "prompt_tokens": 20,
+ "completion_tokens": 15,
+ },
+ "model_name": "gpt-4.1-2025-04-14",
+ },
+ ),
+ ]
+ return {"messages": new_messages}
+
+ with start_transaction():
+ wrapped_invoke = _wrap_pregel_invoke(original_invoke)
+ result = wrapped_invoke(pregel, test_state)
+
+ assert result is not None
+
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ invoke_spans = [
+ span for span in tx["spans"] if span["op"] == OP.GEN_AI_INVOKE_AGENT
+ ]
+ assert len(invoke_spans) == 1
+
+ invoke_agent_span = invoke_spans[0]
+
+ # Verify invoke_agent span uses the LAST response model
+ assert "gen_ai.response.model" in invoke_agent_span["data"]
+ assert invoke_agent_span["data"]["gen_ai.response.model"] == "gpt-4.1-2025-04-14"
+
+
+def test_pregel_ainvoke_span_uses_last_response_model(sentry_init, capture_events):
+ """
+ Test that when an agent makes multiple LLM calls (e.g., with tools),
+ the invoke_agent span reports the last response model used.
+ """
+ sentry_init(
+ integrations=[LanggraphIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ test_state = {
+ "messages": [
+ MockMessage("Hello, can you help me?", name="user"),
+ MockMessage("Of course! How can I assist you?", name="assistant"),
+ ]
+ }
+
+ pregel = MockPregelInstance("test_graph")
+
+ expected_assistant_response = "I'll help you with that task!"
+ expected_tool_calls = [
+ {
+ "id": "call_test_123",
+ "type": "function",
+ "function": {"name": "search_tool", "arguments": '{"query": "help"}'},
+ }
+ ]
+
+ async def original_ainvoke(self, *args, **kwargs):
+ input_messages = args[0].get("messages", [])
+ new_messages = input_messages + [
+ MockMessage(
+ content=expected_assistant_response,
+ name="assistant",
+ tool_calls=expected_tool_calls,
+ response_metadata={
+ "token_usage": {
+ "total_tokens": 15,
+ "prompt_tokens": 10,
+ "completion_tokens": 5,
+ },
+ "model_name": "gpt-4-0613",
+ },
+ ),
+ MockMessage(
+ content=expected_assistant_response,
+ name="assistant",
+ tool_calls=expected_tool_calls,
+ response_metadata={
+ "token_usage": {
+ "total_tokens": 35,
+ "prompt_tokens": 20,
+ "completion_tokens": 15,
+ },
+ "model_name": "gpt-4.1-2025-04-14",
+ },
+ ),
+ ]
+ return {"messages": new_messages}
+
+ async def run_test():
+ with start_transaction():
+ wrapped_ainvoke = _wrap_pregel_ainvoke(original_ainvoke)
+ result = await wrapped_ainvoke(pregel, test_state)
+ return result
+
+ result = asyncio.run(run_test())
+ assert result is not None
+
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ invoke_spans = [
+ span for span in tx["spans"] if span["op"] == OP.GEN_AI_INVOKE_AGENT
+ ]
+ assert len(invoke_spans) == 1
+
+ invoke_agent_span = invoke_spans[0]
+
+ # Verify invoke_agent span uses the LAST response model
+ assert "gen_ai.response.model" in invoke_agent_span["data"]
+ assert invoke_agent_span["data"]["gen_ai.response.model"] == "gpt-4.1-2025-04-14"
+
+
+def test_complex_message_parsing():
+ """Test message parsing with complex message structures."""
+ messages = [
+ MockMessage(content="User query", name="user"),
+ MockMessage(
+ content="Assistant response with tools",
+ name="assistant",
+ tool_calls=[
+ {
+ "id": "call_1",
+ "type": "function",
+ "function": {"name": "search", "arguments": "{}"},
+ },
+ {
+ "id": "call_2",
+ "type": "function",
+ "function": {"name": "calculate", "arguments": '{"x": 5}'},
+ },
+ ],
+ ),
+ MockMessage(
+ content="Function call response",
+ name="function",
+ function_call={"name": "search", "arguments": '{"query": "test"}'},
+ ),
+ ]
+
+ state = {"messages": messages}
+ result = _parse_langgraph_messages(state)
+
+ assert result is not None
+ assert len(result) == 3
+
+ assert result[0]["content"] == "User query"
+ assert result[0]["name"] == "user"
+ assert "tool_calls" not in result[0]
+ assert "function_call" not in result[0]
+
+ assert result[1]["content"] == "Assistant response with tools"
+ assert result[1]["name"] == "assistant"
+ assert len(result[1]["tool_calls"]) == 2
+
+ assert result[2]["content"] == "Function call response"
+ assert result[2]["name"] == "function"
+ assert result[2]["function_call"]["name"] == "search"
+
+
+def test_extraction_functions_complex_scenario(sentry_init, capture_events):
+ """Test extraction functions with complex scenarios including multiple messages and edge cases."""
+ sentry_init(
+ integrations=[LanggraphIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ pregel = MockPregelInstance("complex_graph")
+ test_state = {"messages": [MockMessage("Complex request", name="user")]}
+
+ def original_invoke(self, *args, **kwargs):
+ input_messages = args[0].get("messages", [])
+ new_messages = input_messages + [
+ MockMessage(
+ content="I'll help with multiple tasks",
+ name="assistant",
+ tool_calls=[
+ {
+ "id": "call_multi_1",
+ "type": "function",
+ "function": {
+ "name": "search",
+ "arguments": '{"query": "complex"}',
+ },
+ },
+ {
+ "id": "call_multi_2",
+ "type": "function",
+ "function": {
+ "name": "calculate",
+ "arguments": '{"expr": "2+2"}',
+ },
+ },
+ ],
+ ),
+ MockMessage("", name="assistant"),
+ MockMessage("Final response", name="ai", type="ai"),
+ ]
+ return {"messages": new_messages}
+
+ with start_transaction():
+ wrapped_invoke = _wrap_pregel_invoke(original_invoke)
+ result = wrapped_invoke(pregel, test_state)
+
+ assert result is not None
+
+ tx = events[0]
+ invoke_spans = [
+ span for span in tx["spans"] if span["op"] == OP.GEN_AI_INVOKE_AGENT
+ ]
+ assert len(invoke_spans) == 1
+
+ invoke_span = invoke_spans[0]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT in invoke_span["data"]
+ response_text = invoke_span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT]
+ assert response_text == "Final response"
+
+ assert SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS in invoke_span["data"]
+ import json
+
+ tool_calls_data = invoke_span["data"][SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS]
+ if isinstance(tool_calls_data, str):
+ tool_calls_data = json.loads(tool_calls_data)
+
+ assert len(tool_calls_data) == 2
+ assert tool_calls_data[0]["id"] == "call_multi_1"
+ assert tool_calls_data[0]["function"]["name"] == "search"
+ assert tool_calls_data[1]["id"] == "call_multi_2"
+ assert tool_calls_data[1]["function"]["name"] == "calculate"
+
+
+def test_langgraph_message_role_mapping(sentry_init, capture_events):
+ """Test that Langgraph integration properly maps message roles like 'ai' to 'assistant'"""
+ sentry_init(
+ integrations=[LanggraphIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ # Mock a langgraph message with mixed roles
+ class MockMessage:
+ def __init__(self, content, message_type="human"):
+ self.content = content
+ self.type = message_type
+
+ # Create mock state with messages having different roles
+ state_data = {
+ "messages": [
+ MockMessage("System prompt", "system"),
+ MockMessage("Hello", "human"),
+ MockMessage("Hi there!", "ai"), # Should be mapped to "assistant"
+ MockMessage("How can I help?", "assistant"), # Should stay "assistant"
+ ]
+ }
+
+ compiled_graph = MockCompiledGraph("test_graph")
+ pregel = MockPregelInstance(compiled_graph)
+
+ with start_transaction(name="langgraph tx"):
+ # Use the wrapped invoke function directly
+ from sentry_sdk.integrations.langgraph import _wrap_pregel_invoke
+
+ wrapped_invoke = _wrap_pregel_invoke(
+ lambda self, state_data: {"result": "success"}
+ )
+ wrapped_invoke(pregel, state_data)
+
+ (event,) = events
+ span = event["spans"][0]
+
+ # Verify that the span was created correctly
+ assert span["op"] == "gen_ai.invoke_agent"
+
+ # If messages were captured, verify role mapping
+ if SPANDATA.GEN_AI_REQUEST_MESSAGES in span["data"]:
+ import json
+
+ stored_messages = json.loads(span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+
+ # Find messages with specific content to verify role mapping
+ ai_message = next(
+ (msg for msg in stored_messages if msg.get("content") == "Hi there!"), None
+ )
+ assistant_message = next(
+ (msg for msg in stored_messages if msg.get("content") == "How can I help?"),
+ None,
+ )
+
+ if ai_message:
+ # "ai" should have been mapped to "assistant"
+ assert ai_message["role"] == "assistant"
+
+ if assistant_message:
+ # "assistant" should stay "assistant"
+ assert assistant_message["role"] == "assistant"
+
+ # Verify no "ai" roles remain
+ roles = [msg["role"] for msg in stored_messages if "role" in msg]
+ assert "ai" not in roles
+
+
+def test_langgraph_message_truncation(sentry_init, capture_events):
+ """Test that large messages are truncated properly in Langgraph integration."""
+ import json
+
+ sentry_init(
+ integrations=[LanggraphIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ large_content = (
+ "This is a very long message that will exceed our size limits. " * 1000
+ )
+ test_state = {
+ "messages": [
+ MockMessage("small message 1", name="user"),
+ MockMessage(large_content, name="assistant"),
+ MockMessage(large_content, name="user"),
+ MockMessage("small message 4", name="assistant"),
+ MockMessage("small message 5", name="user"),
+ ]
+ }
+
+ pregel = MockPregelInstance("test_graph")
+
+ def original_invoke(self, *args, **kwargs):
+ return {"messages": args[0].get("messages", [])}
+
+ with start_transaction():
+ wrapped_invoke = _wrap_pregel_invoke(original_invoke)
+ result = wrapped_invoke(pregel, test_state)
+
+ assert result is not None
+ assert len(events) > 0
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ invoke_spans = [
+ span for span in tx.get("spans", []) if span.get("op") == OP.GEN_AI_INVOKE_AGENT
+ ]
+ assert len(invoke_spans) > 0
+
+ invoke_span = invoke_spans[0]
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES in invoke_span["data"]
+
+ messages_data = invoke_span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ assert isinstance(messages_data, str)
+
+ parsed_messages = json.loads(messages_data)
+ assert isinstance(parsed_messages, list)
+ assert len(parsed_messages) == 1
+ assert "small message 5" in str(parsed_messages[0])
+ assert tx["_meta"]["spans"]["0"]["data"]["gen_ai.request.messages"][""]["len"] == 5
diff --git a/tests/integrations/launchdarkly/__init__.py b/tests/integrations/launchdarkly/__init__.py
new file mode 100644
index 0000000000..06e09884c8
--- /dev/null
+++ b/tests/integrations/launchdarkly/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("ldclient")
diff --git a/tests/integrations/launchdarkly/test_launchdarkly.py b/tests/integrations/launchdarkly/test_launchdarkly.py
new file mode 100644
index 0000000000..e588b596d3
--- /dev/null
+++ b/tests/integrations/launchdarkly/test_launchdarkly.py
@@ -0,0 +1,251 @@
+import concurrent.futures as cf
+import sys
+
+import ldclient
+import pytest
+
+from ldclient import LDClient
+from ldclient.config import Config
+from ldclient.context import Context
+from ldclient.integrations.test_data import TestData
+
+import sentry_sdk
+from sentry_sdk.integrations import DidNotEnable
+from sentry_sdk.integrations.launchdarkly import LaunchDarklyIntegration
+from sentry_sdk import start_span, start_transaction
+from tests.conftest import ApproxDict
+
+
+@pytest.mark.parametrize(
+ "use_global_client",
+ (False, True),
+)
+def test_launchdarkly_integration(
+ sentry_init, use_global_client, capture_events, uninstall_integration
+):
+ td = TestData.data_source()
+ td.update(td.flag("hello").variation_for_all(True))
+ td.update(td.flag("world").variation_for_all(True))
+ # Disable background requests as we aren't using a server.
+ config = Config(
+ "sdk-key", update_processor_class=td, diagnostic_opt_out=True, send_events=False
+ )
+
+ uninstall_integration(LaunchDarklyIntegration.identifier)
+ if use_global_client:
+ ldclient.set_config(config)
+ sentry_init(integrations=[LaunchDarklyIntegration()])
+ client = ldclient.get()
+ else:
+ client = LDClient(config=config)
+ sentry_init(integrations=[LaunchDarklyIntegration(ld_client=client)])
+
+ # Evaluate
+ client.variation("hello", Context.create("my-org", "organization"), False)
+ client.variation("world", Context.create("user1", "user"), False)
+ client.variation("other", Context.create("user2", "user"), False)
+
+ events = capture_events()
+ sentry_sdk.capture_exception(Exception("something wrong!"))
+
+ assert len(events) == 1
+ assert events[0]["contexts"]["flags"] == {
+ "values": [
+ {"flag": "hello", "result": True},
+ {"flag": "world", "result": True},
+ {"flag": "other", "result": False},
+ ]
+ }
+
+
+def test_launchdarkly_integration_threaded(
+ sentry_init, capture_events, uninstall_integration
+):
+ td = TestData.data_source()
+ td.update(td.flag("hello").variation_for_all(True))
+ td.update(td.flag("world").variation_for_all(True))
+ client = LDClient(
+ config=Config(
+ "sdk-key",
+ update_processor_class=td,
+ diagnostic_opt_out=True, # Disable background requests as we aren't using a server.
+ send_events=False,
+ )
+ )
+ context = Context.create("user1")
+
+ uninstall_integration(LaunchDarklyIntegration.identifier)
+ sentry_init(integrations=[LaunchDarklyIntegration(ld_client=client)])
+ events = capture_events()
+
+ def task(flag_key):
+ # Creates a new isolation scope for the thread.
+ # This means the evaluations in each task are captured separately.
+ with sentry_sdk.isolation_scope():
+ client.variation(flag_key, context, False)
+ # use a tag to identify to identify events later on
+ sentry_sdk.set_tag("task_id", flag_key)
+ sentry_sdk.capture_exception(Exception("something wrong!"))
+
+ # Capture an eval before we split isolation scopes.
+ client.variation("hello", context, False)
+
+ with cf.ThreadPoolExecutor(max_workers=2) as pool:
+ pool.map(task, ["world", "other"])
+
+ # Capture error in original scope
+ sentry_sdk.set_tag("task_id", "0")
+ sentry_sdk.capture_exception(Exception("something wrong!"))
+
+ assert len(events) == 3
+ events.sort(key=lambda e: e["tags"]["task_id"])
+
+ assert events[0]["contexts"]["flags"] == {
+ "values": [
+ {"flag": "hello", "result": True},
+ ]
+ }
+ assert events[1]["contexts"]["flags"] == {
+ "values": [
+ {"flag": "hello", "result": True},
+ {"flag": "other", "result": False},
+ ]
+ }
+ assert events[2]["contexts"]["flags"] == {
+ "values": [
+ {"flag": "hello", "result": True},
+ {"flag": "world", "result": True},
+ ]
+ }
+
+
+@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher")
+def test_launchdarkly_integration_asyncio(
+ sentry_init, capture_events, uninstall_integration
+):
+ """Assert concurrently evaluated flags do not pollute one another."""
+
+ asyncio = pytest.importorskip("asyncio")
+
+ td = TestData.data_source()
+ td.update(td.flag("hello").variation_for_all(True))
+ td.update(td.flag("world").variation_for_all(True))
+ client = LDClient(
+ config=Config(
+ "sdk-key",
+ update_processor_class=td,
+ diagnostic_opt_out=True, # Disable background requests as we aren't using a server.
+ send_events=False,
+ )
+ )
+ context = Context.create("user1")
+
+ uninstall_integration(LaunchDarklyIntegration.identifier)
+ sentry_init(integrations=[LaunchDarklyIntegration(ld_client=client)])
+ events = capture_events()
+
+ async def task(flag_key):
+ with sentry_sdk.isolation_scope():
+ client.variation(flag_key, context, False)
+ # use a tag to identify to identify events later on
+ sentry_sdk.set_tag("task_id", flag_key)
+ sentry_sdk.capture_exception(Exception("something wrong!"))
+
+ async def runner():
+ return asyncio.gather(task("world"), task("other"))
+
+ # Capture an eval before we split isolation scopes.
+ client.variation("hello", context, False)
+
+ asyncio.run(runner())
+
+ # Capture error in original scope
+ sentry_sdk.set_tag("task_id", "0")
+ sentry_sdk.capture_exception(Exception("something wrong!"))
+
+ assert len(events) == 3
+ events.sort(key=lambda e: e["tags"]["task_id"])
+
+ assert events[0]["contexts"]["flags"] == {
+ "values": [
+ {"flag": "hello", "result": True},
+ ]
+ }
+ assert events[1]["contexts"]["flags"] == {
+ "values": [
+ {"flag": "hello", "result": True},
+ {"flag": "other", "result": False},
+ ]
+ }
+ assert events[2]["contexts"]["flags"] == {
+ "values": [
+ {"flag": "hello", "result": True},
+ {"flag": "world", "result": True},
+ ]
+ }
+
+
+def test_launchdarkly_integration_did_not_enable(monkeypatch):
+ # Client is not passed in and set_config wasn't called.
+ # TODO: Bad practice to access internals like this. We can skip this test, or remove this
+ # case entirely (force user to pass in a client instance).
+ ldclient._reset_client()
+ try:
+ ldclient.__lock.lock()
+ ldclient.__config = None
+ finally:
+ ldclient.__lock.unlock()
+
+ with pytest.raises(DidNotEnable):
+ LaunchDarklyIntegration()
+
+ td = TestData.data_source()
+ # Disable background requests as we aren't using a server.
+ # Required because we corrupt the internal state above.
+ config = Config(
+ "sdk-key", update_processor_class=td, diagnostic_opt_out=True, send_events=False
+ )
+ # Client not initialized.
+ client = LDClient(config=config)
+ monkeypatch.setattr(client, "is_initialized", lambda: False)
+ with pytest.raises(DidNotEnable):
+ LaunchDarklyIntegration(ld_client=client)
+
+
+@pytest.mark.parametrize(
+ "use_global_client",
+ (False, True),
+)
+def test_launchdarkly_span_integration(
+ sentry_init, use_global_client, capture_events, uninstall_integration
+):
+ td = TestData.data_source()
+ td.update(td.flag("hello").variation_for_all(True))
+ # Disable background requests as we aren't using a server.
+ config = Config(
+ "sdk-key", update_processor_class=td, diagnostic_opt_out=True, send_events=False
+ )
+
+ uninstall_integration(LaunchDarklyIntegration.identifier)
+ if use_global_client:
+ ldclient.set_config(config)
+ sentry_init(traces_sample_rate=1.0, integrations=[LaunchDarklyIntegration()])
+ client = ldclient.get()
+ else:
+ client = LDClient(config=config)
+ sentry_init(
+ traces_sample_rate=1.0,
+ integrations=[LaunchDarklyIntegration(ld_client=client)],
+ )
+
+ events = capture_events()
+
+ with start_transaction(name="hi"):
+ with start_span(op="foo", name="bar"):
+ client.variation("hello", Context.create("my-org", "organization"), False)
+ client.variation("other", Context.create("my-org", "organization"), False)
+
+ (event,) = events
+ assert event["spans"][0]["data"] == ApproxDict(
+ {"flag.evaluation.hello": True, "flag.evaluation.other": False}
+ )
diff --git a/tests/integrations/litellm/__init__.py b/tests/integrations/litellm/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/integrations/litellm/test_litellm.py b/tests/integrations/litellm/test_litellm.py
new file mode 100644
index 0000000000..a8df5891ce
--- /dev/null
+++ b/tests/integrations/litellm/test_litellm.py
@@ -0,0 +1,1986 @@
+import base64
+import json
+import pytest
+import time
+import asyncio
+from unittest import mock
+from datetime import datetime
+
+try:
+ from unittest.mock import AsyncMock
+except ImportError:
+
+ class AsyncMock(mock.MagicMock):
+ async def __call__(self, *args, **kwargs):
+ return super(AsyncMock, self).__call__(*args, **kwargs)
+
+
+try:
+ import litellm
+except ImportError:
+ pytest.skip("litellm not installed", allow_module_level=True)
+
+from sentry_sdk import start_transaction
+from sentry_sdk.consts import OP, SPANDATA
+from sentry_sdk._types import BLOB_DATA_SUBSTITUTE
+from sentry_sdk.integrations.litellm import (
+ LiteLLMIntegration,
+ _convert_message_parts,
+ _input_callback,
+ _success_callback,
+ _failure_callback,
+)
+from sentry_sdk.utils import package_version
+
+from openai import OpenAI, AsyncOpenAI
+
+from concurrent.futures import ThreadPoolExecutor
+
+import litellm.utils as litellm_utils
+from litellm.litellm_core_utils import streaming_handler
+from litellm.litellm_core_utils import thread_pool_executor
+from litellm.litellm_core_utils import litellm_logging
+from litellm.litellm_core_utils.logging_worker import GLOBAL_LOGGING_WORKER
+from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
+
+
+LITELLM_VERSION = package_version("litellm")
+
+
+def _reset_litellm_executor():
+ thread_pool_executor.executor = ThreadPoolExecutor(max_workers=100)
+ litellm_utils.executor = thread_pool_executor.executor
+ streaming_handler.executor = thread_pool_executor.executor
+ litellm_logging.executor = thread_pool_executor.executor
+
+
+@pytest.fixture()
+def reset_litellm_executor():
+ yield
+ _reset_litellm_executor()
+
+
+@pytest.fixture
+def clear_litellm_cache():
+ """
+ Clear litellm's client cache and reset integration state to ensure test isolation.
+
+ The LiteLLM integration uses setup_once() which only runs once per Python process.
+ This fixture ensures the integration is properly re-initialized for each test.
+ """
+
+ # Stop all existing mocks
+ mock.patch.stopall()
+
+ # Clear client cache
+ if (
+ hasattr(litellm, "in_memory_llm_clients_cache")
+ and litellm.in_memory_llm_clients_cache
+ ):
+ litellm.in_memory_llm_clients_cache.flush_cache()
+
+ yield
+
+ # Clean up after test as well
+ mock.patch.stopall()
+ if (
+ hasattr(litellm, "in_memory_llm_clients_cache")
+ and litellm.in_memory_llm_clients_cache
+ ):
+ litellm.in_memory_llm_clients_cache.flush_cache()
+
+
+# Mock response objects
+class MockMessage:
+ def __init__(self, role="assistant", content="Test response"):
+ self.role = role
+ self.content = content
+ self.tool_calls = None
+
+ def model_dump(self):
+ return {"role": self.role, "content": self.content}
+
+
+class MockChoice:
+ def __init__(self, message=None):
+ self.message = message or MockMessage()
+ self.index = 0
+ self.finish_reason = "stop"
+
+
+class MockUsage:
+ def __init__(self, prompt_tokens=10, completion_tokens=20, total_tokens=30):
+ self.prompt_tokens = prompt_tokens
+ self.completion_tokens = completion_tokens
+ self.total_tokens = total_tokens
+
+
+class MockCompletionResponse:
+ def __init__(
+ self,
+ model="gpt-3.5-turbo",
+ choices=None,
+ usage=None,
+ ):
+ self.id = "chatcmpl-test"
+ self.model = model
+ self.choices = choices or [MockChoice()]
+ self.usage = usage or MockUsage()
+ self.object = "chat.completion"
+ self.created = 1234567890
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+def test_nonstreaming_chat_completion(
+ reset_litellm_executor,
+ sentry_init,
+ capture_events,
+ send_default_pii,
+ include_prompts,
+ get_model_response,
+ nonstreaming_chat_completions_model_response,
+):
+ sentry_init(
+ integrations=[LiteLLMIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ messages = [{"role": "user", "content": "Hello!"}]
+
+ client = OpenAI(api_key="test-key")
+
+ model_response = get_model_response(
+ nonstreaming_chat_completions_model_response,
+ serialize_pydantic=True,
+ request_headers={"X-Stainless-Raw-Response": "true"},
+ )
+
+ with mock.patch.object(
+ client.completions._client._client,
+ "send",
+ return_value=model_response,
+ ):
+ with start_transaction(name="litellm test"):
+ litellm.completion(
+ model="gpt-3.5-turbo",
+ messages=messages,
+ client=client,
+ )
+
+ litellm_utils.executor.shutdown(wait=True)
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "litellm test"
+
+ chat_spans = list(
+ x
+ for x in event["spans"]
+ if x["op"] == OP.GEN_AI_CHAT and x["origin"] == "auto.ai.litellm"
+ )
+ assert len(chat_spans) == 1
+ span = chat_spans[0]
+
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["description"] == "chat gpt-3.5-turbo"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "gpt-3.5-turbo"
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_MODEL] == "gpt-3.5-turbo"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "openai"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+
+ if send_default_pii and include_prompts:
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES in span["data"]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT in span["data"]
+ else:
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]
+
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 20
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 30
+
+
+@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+async def test_async_nonstreaming_chat_completion(
+ sentry_init,
+ capture_events,
+ send_default_pii,
+ include_prompts,
+ get_model_response,
+ nonstreaming_chat_completions_model_response,
+):
+ sentry_init(
+ integrations=[LiteLLMIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ messages = [{"role": "user", "content": "Hello!"}]
+
+ client = AsyncOpenAI(api_key="test-key")
+
+ model_response = get_model_response(
+ nonstreaming_chat_completions_model_response,
+ serialize_pydantic=True,
+ request_headers={"X-Stainless-Raw-Response": "true"},
+ )
+
+ with mock.patch.object(
+ client.completions._client._client,
+ "send",
+ return_value=model_response,
+ ):
+ with start_transaction(name="litellm test"):
+ await litellm.acompletion(
+ model="gpt-3.5-turbo",
+ messages=messages,
+ client=client,
+ )
+
+ await GLOBAL_LOGGING_WORKER.flush()
+ await asyncio.sleep(0.5)
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ assert event["transaction"] == "litellm test"
+
+ chat_spans = list(
+ x
+ for x in event["spans"]
+ if x["op"] == OP.GEN_AI_CHAT and x["origin"] == "auto.ai.litellm"
+ )
+ assert len(chat_spans) == 1
+ span = chat_spans[0]
+
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["description"] == "chat gpt-3.5-turbo"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "gpt-3.5-turbo"
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_MODEL] == "gpt-3.5-turbo"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "openai"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
+
+ if send_default_pii and include_prompts:
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES in span["data"]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT in span["data"]
+ else:
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]
+
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
+ assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 20
+ assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 30
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+def test_streaming_chat_completion(
+ reset_litellm_executor,
+ sentry_init,
+ capture_events,
+ send_default_pii,
+ include_prompts,
+ get_model_response,
+ server_side_event_chunks,
+ streaming_chat_completions_model_response,
+):
+ sentry_init(
+ integrations=[LiteLLMIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ messages = [{"role": "user", "content": "Hello!"}]
+
+ client = OpenAI(api_key="test-key")
+
+ model_response = get_model_response(
+ server_side_event_chunks(
+ streaming_chat_completions_model_response,
+ include_event_type=False,
+ ),
+ request_headers={"X-Stainless-Raw-Response": "true"},
+ )
+
+ with mock.patch.object(
+ client.completions._client._client,
+ "send",
+ return_value=model_response,
+ ):
+ with start_transaction(name="litellm test"):
+ response = litellm.completion(
+ model="gpt-3.5-turbo",
+ messages=messages,
+ client=client,
+ stream=True,
+ )
+ for _ in response:
+ pass
+
+ streaming_handler.executor.shutdown(wait=True)
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ chat_spans = list(
+ x
+ for x in event["spans"]
+ if x["op"] == OP.GEN_AI_CHAT and x["origin"] == "auto.ai.litellm"
+ )
+ assert len(chat_spans) == 1
+ span = chat_spans[0]
+
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
+
+
+@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+async def test_async_streaming_chat_completion(
+ sentry_init,
+ capture_events,
+ send_default_pii,
+ include_prompts,
+ get_model_response,
+ async_iterator,
+ server_side_event_chunks,
+ streaming_chat_completions_model_response,
+):
+ sentry_init(
+ integrations=[LiteLLMIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ messages = [{"role": "user", "content": "Hello!"}]
+
+ client = AsyncOpenAI(api_key="test-key")
+
+ model_response = get_model_response(
+ async_iterator(
+ server_side_event_chunks(
+ streaming_chat_completions_model_response,
+ include_event_type=False,
+ ),
+ ),
+ request_headers={"X-Stainless-Raw-Response": "true"},
+ )
+
+ with mock.patch.object(
+ client.completions._client._client,
+ "send",
+ return_value=model_response,
+ ):
+ with start_transaction(name="litellm test"):
+ response = await litellm.acompletion(
+ model="gpt-3.5-turbo",
+ messages=messages,
+ client=client,
+ stream=True,
+ )
+ async for _ in response:
+ pass
+
+ await GLOBAL_LOGGING_WORKER.flush()
+ await asyncio.sleep(0.5)
+
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ chat_spans = list(
+ x
+ for x in event["spans"]
+ if x["op"] == OP.GEN_AI_CHAT and x["origin"] == "auto.ai.litellm"
+ )
+ assert len(chat_spans) == 1
+ span = chat_spans[0]
+
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
+
+
+def test_embeddings_create(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ openai_embedding_model_response,
+ clear_litellm_cache,
+):
+ """
+ Test that litellm.embedding() calls are properly instrumented.
+
+ This test calls the actual litellm.embedding() function (not just callbacks)
+ to ensure proper integration testing.
+ """
+ sentry_init(
+ integrations=[LiteLLMIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ client = OpenAI(api_key="test-key")
+
+ model_response = get_model_response(
+ openai_embedding_model_response,
+ serialize_pydantic=True,
+ request_headers={"X-Stainless-Raw-Response": "true"},
+ )
+
+ with mock.patch.object(
+ client.embeddings._client._client,
+ "send",
+ return_value=model_response,
+ ):
+ with start_transaction(name="litellm test"):
+ response = litellm.embedding(
+ model="text-embedding-ada-002",
+ input="Hello, world!",
+ client=client,
+ )
+ # Allow time for callbacks to complete (they may run in separate threads)
+ time.sleep(0.1)
+
+ # Response is processed by litellm, so just check it exists
+ assert response is not None
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ spans = list(
+ x
+ for x in event["spans"]
+ if x["op"] == OP.GEN_AI_EMBEDDINGS and x["origin"] == "auto.ai.litellm"
+ )
+ assert len(spans) == 1
+ span = spans[0]
+
+ assert span["op"] == OP.GEN_AI_EMBEDDINGS
+ assert span["description"] == "embeddings text-embedding-ada-002"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "embeddings"
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 5
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "text-embedding-ada-002"
+ # Check that embeddings input is captured (it's JSON serialized)
+ embeddings_input = span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]
+ assert json.loads(embeddings_input) == ["Hello, world!"]
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_async_embeddings_create(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ openai_embedding_model_response,
+ clear_litellm_cache,
+):
+ """
+ Test that litellm.embedding() calls are properly instrumented.
+
+ This test calls the actual litellm.embedding() function (not just callbacks)
+ to ensure proper integration testing.
+ """
+ sentry_init(
+ integrations=[LiteLLMIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ client = AsyncOpenAI(api_key="test-key")
+
+ model_response = get_model_response(
+ openai_embedding_model_response,
+ serialize_pydantic=True,
+ request_headers={"X-Stainless-Raw-Response": "true"},
+ )
+
+ with mock.patch.object(
+ client.embeddings._client._client,
+ "send",
+ return_value=model_response,
+ ):
+ with start_transaction(name="litellm test"):
+ response = await litellm.aembedding(
+ model="text-embedding-ada-002",
+ input="Hello, world!",
+ client=client,
+ )
+
+ await GLOBAL_LOGGING_WORKER.flush()
+ await asyncio.sleep(0.5)
+
+ # Response is processed by litellm, so just check it exists
+ assert response is not None
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ spans = list(
+ x
+ for x in event["spans"]
+ if x["op"] == OP.GEN_AI_EMBEDDINGS and x["origin"] == "auto.ai.litellm"
+ )
+ assert len(spans) == 1
+ span = spans[0]
+
+ assert span["op"] == OP.GEN_AI_EMBEDDINGS
+ assert span["description"] == "embeddings text-embedding-ada-002"
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "embeddings"
+ assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 5
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "text-embedding-ada-002"
+ # Check that embeddings input is captured (it's JSON serialized)
+ embeddings_input = span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]
+ assert json.loads(embeddings_input) == ["Hello, world!"]
+
+
+def test_embeddings_create_with_list_input(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ openai_embedding_model_response,
+ clear_litellm_cache,
+):
+ """Test embedding with list input."""
+ sentry_init(
+ integrations=[LiteLLMIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ client = OpenAI(api_key="test-key")
+
+ model_response = get_model_response(
+ openai_embedding_model_response,
+ serialize_pydantic=True,
+ request_headers={"X-Stainless-Raw-Response": "true"},
+ )
+
+ with mock.patch.object(
+ client.embeddings._client._client,
+ "send",
+ return_value=model_response,
+ ):
+ with start_transaction(name="litellm test"):
+ response = litellm.embedding(
+ model="text-embedding-ada-002",
+ input=["First text", "Second text", "Third text"],
+ client=client,
+ )
+ # Allow time for callbacks to complete (they may run in separate threads)
+ time.sleep(0.1)
+
+ # Response is processed by litellm, so just check it exists
+ assert response is not None
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ spans = list(
+ x
+ for x in event["spans"]
+ if x["op"] == OP.GEN_AI_EMBEDDINGS and x["origin"] == "auto.ai.litellm"
+ )
+ assert len(spans) == 1
+ span = spans[0]
+
+ assert span["op"] == OP.GEN_AI_EMBEDDINGS
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "embeddings"
+ # Check that list of embeddings input is captured (it's JSON serialized)
+ embeddings_input = span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]
+ assert json.loads(embeddings_input) == [
+ "First text",
+ "Second text",
+ "Third text",
+ ]
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_async_embeddings_create_with_list_input(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ openai_embedding_model_response,
+ clear_litellm_cache,
+):
+ """Test embedding with list input."""
+ sentry_init(
+ integrations=[LiteLLMIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ client = AsyncOpenAI(api_key="test-key")
+
+ model_response = get_model_response(
+ openai_embedding_model_response,
+ serialize_pydantic=True,
+ request_headers={"X-Stainless-Raw-Response": "true"},
+ )
+
+ with mock.patch.object(
+ client.embeddings._client._client,
+ "send",
+ return_value=model_response,
+ ):
+ with start_transaction(name="litellm test"):
+ response = await litellm.aembedding(
+ model="text-embedding-ada-002",
+ input=["First text", "Second text", "Third text"],
+ client=client,
+ )
+
+ await GLOBAL_LOGGING_WORKER.flush()
+ await asyncio.sleep(0.5)
+
+ # Response is processed by litellm, so just check it exists
+ assert response is not None
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ spans = list(
+ x
+ for x in event["spans"]
+ if x["op"] == OP.GEN_AI_EMBEDDINGS and x["origin"] == "auto.ai.litellm"
+ )
+ assert len(spans) == 1
+ span = spans[0]
+
+ assert span["op"] == OP.GEN_AI_EMBEDDINGS
+ assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "embeddings"
+ # Check that list of embeddings input is captured (it's JSON serialized)
+ embeddings_input = span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]
+ assert json.loads(embeddings_input) == [
+ "First text",
+ "Second text",
+ "Third text",
+ ]
+
+
+def test_embeddings_no_pii(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ openai_embedding_model_response,
+ clear_litellm_cache,
+):
+ """Test that PII is not captured when disabled."""
+ sentry_init(
+ integrations=[LiteLLMIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=False, # PII disabled
+ )
+ events = capture_events()
+
+ client = OpenAI(api_key="test-key")
+
+ model_response = get_model_response(
+ openai_embedding_model_response,
+ serialize_pydantic=True,
+ request_headers={"X-Stainless-Raw-Response": "true"},
+ )
+
+ with mock.patch.object(
+ client.embeddings._client._client,
+ "send",
+ return_value=model_response,
+ ):
+ with start_transaction(name="litellm test"):
+ response = litellm.embedding(
+ model="text-embedding-ada-002",
+ input="Hello, world!",
+ client=client,
+ )
+ # Allow time for callbacks to complete (they may run in separate threads)
+ time.sleep(0.1)
+
+ # Response is processed by litellm, so just check it exists
+ assert response is not None
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ spans = list(
+ x
+ for x in event["spans"]
+ if x["op"] == OP.GEN_AI_EMBEDDINGS and x["origin"] == "auto.ai.litellm"
+ )
+ assert len(spans) == 1
+ span = spans[0]
+
+ assert span["op"] == OP.GEN_AI_EMBEDDINGS
+ # Check that embeddings input is NOT captured when PII is disabled
+ assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT not in span["data"]
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_async_embeddings_no_pii(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ openai_embedding_model_response,
+ clear_litellm_cache,
+):
+ """Test that PII is not captured when disabled."""
+ sentry_init(
+ integrations=[LiteLLMIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=False, # PII disabled
+ )
+ events = capture_events()
+
+ client = AsyncOpenAI(api_key="test-key")
+
+ model_response = get_model_response(
+ openai_embedding_model_response,
+ serialize_pydantic=True,
+ request_headers={"X-Stainless-Raw-Response": "true"},
+ )
+
+ with mock.patch.object(
+ client.embeddings._client._client,
+ "send",
+ return_value=model_response,
+ ):
+ with start_transaction(name="litellm test"):
+ response = await litellm.aembedding(
+ model="text-embedding-ada-002",
+ input="Hello, world!",
+ client=client,
+ )
+
+ await GLOBAL_LOGGING_WORKER.flush()
+ await asyncio.sleep(0.5)
+
+ # Response is processed by litellm, so just check it exists
+ assert response is not None
+ assert len(events) == 1
+ (event,) = events
+
+ assert event["type"] == "transaction"
+ spans = list(
+ x
+ for x in event["spans"]
+ if x["op"] == OP.GEN_AI_EMBEDDINGS and x["origin"] == "auto.ai.litellm"
+ )
+ assert len(spans) == 1
+ span = spans[0]
+
+ assert span["op"] == OP.GEN_AI_EMBEDDINGS
+ # Check that embeddings input is NOT captured when PII is disabled
+ assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT not in span["data"]
+
+
+def test_exception_handling(
+ reset_litellm_executor, sentry_init, capture_events, get_rate_limit_model_response
+):
+ sentry_init(
+ integrations=[LiteLLMIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ messages = [{"role": "user", "content": "Hello!"}]
+
+ client = OpenAI(api_key="test-key")
+
+ model_response = get_rate_limit_model_response()
+
+ with mock.patch.object(
+ client.completions._client._client,
+ "send",
+ return_value=model_response,
+ ):
+ with start_transaction(name="litellm test"):
+ with pytest.raises(litellm.RateLimitError):
+ litellm.completion(
+ model="gpt-3.5-turbo",
+ messages=messages,
+ client=client,
+ )
+
+ # Should have error event and transaction
+ assert len(events) >= 1
+ # Find the error event
+ error_events = [e for e in events if e.get("level") == "error"]
+ assert len(error_events) == 1
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_async_exception_handling(
+ sentry_init, capture_events, get_rate_limit_model_response
+):
+ sentry_init(
+ integrations=[LiteLLMIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ messages = [{"role": "user", "content": "Hello!"}]
+
+ client = AsyncOpenAI(api_key="test-key")
+
+ model_response = get_rate_limit_model_response()
+
+ with mock.patch.object(
+ client.embeddings._client._client,
+ "send",
+ return_value=model_response,
+ ):
+ with start_transaction(name="litellm test"):
+ with pytest.raises(litellm.RateLimitError):
+ await litellm.acompletion(
+ model="gpt-3.5-turbo",
+ messages=messages,
+ client=client,
+ )
+
+ # Should have error event and transaction
+ assert len(events) >= 1
+ # Find the error event
+ error_events = [e for e in events if e.get("level") == "error"]
+ assert len(error_events) == 1
+
+
+def test_span_origin(
+ reset_litellm_executor,
+ sentry_init,
+ capture_events,
+ get_model_response,
+ nonstreaming_chat_completions_model_response,
+):
+ sentry_init(
+ integrations=[LiteLLMIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ messages = [{"role": "user", "content": "Hello!"}]
+
+ client = OpenAI(api_key="test-key")
+
+ model_response = get_model_response(
+ nonstreaming_chat_completions_model_response,
+ serialize_pydantic=True,
+ request_headers={"X-Stainless-Raw-Response": "true"},
+ )
+
+ with mock.patch.object(
+ client.completions._client._client,
+ "send",
+ return_value=model_response,
+ ):
+ with start_transaction(name="litellm test"):
+ litellm.completion(
+ model="gpt-3.5-turbo",
+ messages=messages,
+ client=client,
+ )
+
+ litellm_utils.executor.shutdown(wait=True)
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+ assert event["spans"][0]["origin"] == "auto.ai.litellm"
+
+
+def test_multiple_providers(
+ reset_litellm_executor,
+ sentry_init,
+ capture_events,
+ get_model_response,
+ nonstreaming_chat_completions_model_response,
+ nonstreaming_anthropic_model_response,
+ nonstreaming_google_genai_model_response,
+):
+ """Test that the integration correctly identifies different providers."""
+ sentry_init(
+ integrations=[LiteLLMIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ messages = [{"role": "user", "content": "Hello!"}]
+
+ openai_client = OpenAI(api_key="test-key")
+ openai_model_response = get_model_response(
+ nonstreaming_chat_completions_model_response,
+ serialize_pydantic=True,
+ request_headers={"X-Stainless-Raw-Response": "true"},
+ )
+
+ with mock.patch.object(
+ openai_client.completions._client._client,
+ "send",
+ return_value=openai_model_response,
+ ):
+ with start_transaction(name="test gpt-3.5-turbo"):
+ litellm.completion(
+ model="gpt-3.5-turbo",
+ messages=messages,
+ client=openai_client,
+ )
+
+ litellm_utils.executor.shutdown(wait=True)
+
+ _reset_litellm_executor()
+
+ anthropic_client = HTTPHandler()
+ anthropic_model_response = get_model_response(
+ nonstreaming_anthropic_model_response,
+ serialize_pydantic=True,
+ request_headers={"X-Stainless-Raw-Response": "true"},
+ )
+
+ with mock.patch.object(
+ anthropic_client,
+ "post",
+ return_value=anthropic_model_response,
+ ):
+ with start_transaction(name="test claude-3-opus-20240229"):
+ litellm.completion(
+ model="claude-3-opus-20240229",
+ messages=messages,
+ client=anthropic_client,
+ api_key="test-key",
+ )
+
+ litellm_utils.executor.shutdown(wait=True)
+
+ _reset_litellm_executor()
+
+ gemini_client = HTTPHandler()
+ gemini_model_response = get_model_response(
+ nonstreaming_google_genai_model_response,
+ serialize_pydantic=True,
+ )
+
+ with mock.patch.object(
+ gemini_client,
+ "post",
+ return_value=gemini_model_response,
+ ):
+ with start_transaction(name="test gemini/gemini-pro"):
+ litellm.completion(
+ model="gemini/gemini-pro",
+ messages=messages,
+ client=gemini_client,
+ api_key="test-key",
+ )
+
+ litellm_utils.executor.shutdown(wait=True)
+
+ assert len(events) == 3
+
+ for i in range(3):
+ span = events[i]["spans"][0]
+ # The provider should be detected by litellm.get_llm_provider
+ assert SPANDATA.GEN_AI_SYSTEM in span["data"]
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_async_multiple_providers(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ nonstreaming_chat_completions_model_response,
+ nonstreaming_anthropic_model_response,
+ nonstreaming_google_genai_model_response,
+):
+ """Test that the integration correctly identifies different providers."""
+ sentry_init(
+ integrations=[LiteLLMIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ messages = [{"role": "user", "content": "Hello!"}]
+
+ openai_client = AsyncOpenAI(api_key="test-key")
+ openai_model_response = get_model_response(
+ nonstreaming_chat_completions_model_response,
+ serialize_pydantic=True,
+ request_headers={"X-Stainless-Raw-Response": "true"},
+ )
+
+ with mock.patch.object(
+ openai_client.completions._client._client,
+ "send",
+ return_value=openai_model_response,
+ ):
+ with start_transaction(name="test gpt-3.5-turbo"):
+ await litellm.acompletion(
+ model="gpt-3.5-turbo",
+ messages=messages,
+ client=openai_client,
+ )
+
+ await GLOBAL_LOGGING_WORKER.flush()
+ await asyncio.sleep(0.5)
+
+ _reset_litellm_executor()
+
+ anthropic_client = AsyncHTTPHandler()
+ anthropic_model_response = get_model_response(
+ nonstreaming_anthropic_model_response,
+ serialize_pydantic=True,
+ request_headers={"X-Stainless-Raw-Response": "True"},
+ )
+
+ with mock.patch.object(
+ anthropic_client,
+ "post",
+ return_value=anthropic_model_response,
+ ):
+ with start_transaction(name="test claude-3-opus-20240229"):
+ await litellm.acompletion(
+ model="claude-3-opus-20240229",
+ messages=messages,
+ client=anthropic_client,
+ api_key="test-key",
+ )
+
+ await GLOBAL_LOGGING_WORKER.flush()
+ await asyncio.sleep(0.5)
+
+ _reset_litellm_executor()
+
+ gemini_client = AsyncHTTPHandler()
+ gemini_model_response = get_model_response(
+ nonstreaming_google_genai_model_response,
+ serialize_pydantic=True,
+ )
+
+ with mock.patch.object(
+ gemini_client,
+ "post",
+ return_value=gemini_model_response,
+ ):
+ with start_transaction(name="test gemini/gemini-pro"):
+ await litellm.acompletion(
+ model="gemini/gemini-pro",
+ messages=messages,
+ client=gemini_client,
+ api_key="test-key",
+ )
+
+ await GLOBAL_LOGGING_WORKER.flush()
+ await asyncio.sleep(0.5)
+
+ assert len(events) == 3
+
+ for i in range(3):
+ span = events[i]["spans"][0]
+ # The provider should be detected by litellm.get_llm_provider
+ assert SPANDATA.GEN_AI_SYSTEM in span["data"]
+
+
+def test_additional_parameters(
+ reset_litellm_executor,
+ sentry_init,
+ capture_events,
+ get_model_response,
+ nonstreaming_chat_completions_model_response,
+):
+ """Test that additional parameters are captured."""
+ sentry_init(
+ integrations=[LiteLLMIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ messages = [{"role": "user", "content": "Hello!"}]
+ client = OpenAI(api_key="test-key")
+
+ model_response = get_model_response(
+ nonstreaming_chat_completions_model_response,
+ serialize_pydantic=True,
+ request_headers={"X-Stainless-Raw-Response": "true"},
+ )
+
+ with mock.patch.object(
+ client.completions._client._client,
+ "send",
+ return_value=model_response,
+ ):
+ with start_transaction(name="litellm test"):
+ litellm.completion(
+ model="gpt-3.5-turbo",
+ messages=messages,
+ client=client,
+ temperature=0.7,
+ max_tokens=100,
+ top_p=0.9,
+ frequency_penalty=0.5,
+ presence_penalty=0.5,
+ )
+
+ litellm_utils.executor.shutdown(wait=True)
+
+ (event,) = events
+ chat_spans = list(
+ x
+ for x in event["spans"]
+ if x["op"] == OP.GEN_AI_CHAT and x["origin"] == "auto.ai.litellm"
+ )
+ assert len(chat_spans) == 1
+ span = chat_spans[0]
+
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_TEMPERATURE] == 0.7
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MAX_TOKENS] == 100
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_TOP_P] == 0.9
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY] == 0.5
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY] == 0.5
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_async_additional_parameters(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ nonstreaming_chat_completions_model_response,
+):
+ """Test that additional parameters are captured."""
+ sentry_init(
+ integrations=[LiteLLMIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ messages = [{"role": "user", "content": "Hello!"}]
+ client = AsyncOpenAI(api_key="test-key")
+
+ model_response = get_model_response(
+ nonstreaming_chat_completions_model_response,
+ serialize_pydantic=True,
+ request_headers={"X-Stainless-Raw-Response": "true"},
+ )
+
+ with mock.patch.object(
+ client.completions._client._client,
+ "send",
+ return_value=model_response,
+ ):
+ with start_transaction(name="litellm test"):
+ await litellm.acompletion(
+ model="gpt-3.5-turbo",
+ messages=messages,
+ client=client,
+ temperature=0.7,
+ max_tokens=100,
+ top_p=0.9,
+ frequency_penalty=0.5,
+ presence_penalty=0.5,
+ )
+
+ await GLOBAL_LOGGING_WORKER.flush()
+ await asyncio.sleep(0.5)
+
+ (event,) = events
+ chat_spans = list(
+ x
+ for x in event["spans"]
+ if x["op"] == OP.GEN_AI_CHAT and x["origin"] == "auto.ai.litellm"
+ )
+ assert len(chat_spans) == 1
+ span = chat_spans[0]
+
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_TEMPERATURE] == 0.7
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MAX_TOKENS] == 100
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_TOP_P] == 0.9
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY] == 0.5
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY] == 0.5
+
+
+def test_no_integration(
+ reset_litellm_executor,
+ sentry_init,
+ capture_events,
+ get_model_response,
+ nonstreaming_chat_completions_model_response,
+):
+ """Test that when integration is not enabled, callbacks don't break."""
+ sentry_init(
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ messages = [{"role": "user", "content": "Hello!"}]
+ client = OpenAI(api_key="test-key")
+
+ model_response = get_model_response(
+ nonstreaming_chat_completions_model_response,
+ serialize_pydantic=True,
+ request_headers={"X-Stainless-Raw-Response": "true"},
+ )
+
+ with mock.patch.object(
+ client.completions._client._client,
+ "send",
+ return_value=model_response,
+ ):
+ with start_transaction(name="litellm test"):
+ litellm.completion(
+ model="gpt-3.5-turbo",
+ messages=messages,
+ client=client,
+ )
+
+ litellm_utils.executor.shutdown(wait=True)
+
+ (event,) = events
+ # Should still have the transaction, but no child spans since integration is off
+ assert event["type"] == "transaction"
+ chat_spans = list(
+ x
+ for x in event["spans"]
+ if x["op"] == OP.GEN_AI_CHAT and x["origin"] == "auto.ai.litellm"
+ )
+ assert len(chat_spans) == 0
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_async_no_integration(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ nonstreaming_chat_completions_model_response,
+):
+ """Test that when integration is not enabled, callbacks don't break."""
+ sentry_init(
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ messages = [{"role": "user", "content": "Hello!"}]
+ client = AsyncOpenAI(api_key="test-key")
+
+ model_response = get_model_response(
+ nonstreaming_chat_completions_model_response,
+ serialize_pydantic=True,
+ request_headers={"X-Stainless-Raw-Response": "true"},
+ )
+
+ with mock.patch.object(
+ client.completions._client._client,
+ "send",
+ return_value=model_response,
+ ):
+ with start_transaction(name="litellm test"):
+ await litellm.acompletion(
+ model="gpt-3.5-turbo",
+ messages=messages,
+ client=client,
+ )
+
+ await GLOBAL_LOGGING_WORKER.flush()
+ await asyncio.sleep(0.5)
+
+ (event,) = events
+ # Should still have the transaction, but no child spans since integration is off
+ assert event["type"] == "transaction"
+ chat_spans = list(
+ x
+ for x in event["spans"]
+ if x["op"] == OP.GEN_AI_CHAT and x["origin"] == "auto.ai.litellm"
+ )
+ assert len(chat_spans) == 0
+
+
+def test_response_without_usage(sentry_init, capture_events):
+ """Test handling of responses without usage information."""
+ sentry_init(
+ integrations=[LiteLLMIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ messages = [{"role": "user", "content": "Hello!"}]
+
+ # Create a mock response without usage
+ mock_response = type(
+ "obj",
+ (object,),
+ {
+ "model": "gpt-3.5-turbo",
+ "choices": [MockChoice()],
+ },
+ )()
+
+ with start_transaction(name="litellm test"):
+ kwargs = {
+ "model": "gpt-3.5-turbo",
+ "messages": messages,
+ }
+
+ _input_callback(kwargs)
+ _success_callback(
+ kwargs,
+ mock_response,
+ datetime.now(),
+ datetime.now(),
+ )
+
+ (event,) = events
+ (span,) = event["spans"]
+
+ # Span should still be created even without usage info
+ assert span["op"] == OP.GEN_AI_CHAT
+ assert span["description"] == "chat gpt-3.5-turbo"
+
+
+def test_integration_setup(sentry_init):
+ """Test that the integration sets up the callbacks correctly."""
+ sentry_init(
+ integrations=[LiteLLMIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ # Check that callbacks are registered
+ assert _input_callback in (litellm.input_callback or [])
+ assert _success_callback in (litellm.success_callback or [])
+ assert _failure_callback in (litellm.failure_callback or [])
+
+
+def test_litellm_message_truncation(sentry_init, capture_events):
+ """Test that large messages are truncated properly in LiteLLM integration."""
+ sentry_init(
+ integrations=[LiteLLMIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ large_content = (
+ "This is a very long message that will exceed our size limits. " * 1000
+ )
+ messages = [
+ {"role": "user", "content": "small message 1"},
+ {"role": "assistant", "content": large_content},
+ {"role": "user", "content": large_content},
+ {"role": "assistant", "content": "small message 4"},
+ {"role": "user", "content": "small message 5"},
+ ]
+ mock_response = MockCompletionResponse()
+
+ with start_transaction(name="litellm test"):
+ kwargs = {
+ "model": "gpt-3.5-turbo",
+ "messages": messages,
+ }
+
+ _input_callback(kwargs)
+ _success_callback(
+ kwargs,
+ mock_response,
+ datetime.now(),
+ datetime.now(),
+ )
+
+ assert len(events) > 0
+ tx = events[0]
+ assert tx["type"] == "transaction"
+
+ chat_spans = [
+ span for span in tx.get("spans", []) if span.get("op") == OP.GEN_AI_CHAT
+ ]
+ assert len(chat_spans) > 0
+
+ chat_span = chat_spans[0]
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES in chat_span["data"]
+
+ messages_data = chat_span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ assert isinstance(messages_data, str)
+
+ parsed_messages = json.loads(messages_data)
+ assert isinstance(parsed_messages, list)
+ assert len(parsed_messages) == 1
+ assert "small message 5" in str(parsed_messages[0])
+ assert tx["_meta"]["spans"]["0"]["data"]["gen_ai.request.messages"][""]["len"] == 5
+
+
+IMAGE_DATA = b"fake_image_data_12345"
+IMAGE_B64 = base64.b64encode(IMAGE_DATA).decode("utf-8")
+IMAGE_DATA_URI = f"data:image/png;base64,{IMAGE_B64}"
+
+
+def test_binary_content_encoding_image_url(
+ reset_litellm_executor,
+ sentry_init,
+ capture_events,
+ get_model_response,
+ nonstreaming_chat_completions_model_response,
+):
+ sentry_init(
+ integrations=[LiteLLMIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Look at this image:"},
+ {
+ "type": "image_url",
+ "image_url": {"url": IMAGE_DATA_URI, "detail": "high"},
+ },
+ ],
+ }
+ ]
+ client = OpenAI(api_key="test-key")
+
+ model_response = get_model_response(
+ nonstreaming_chat_completions_model_response,
+ serialize_pydantic=True,
+ request_headers={"X-Stainless-Raw-Response": "true"},
+ )
+
+ with mock.patch.object(
+ client.completions._client._client,
+ "send",
+ return_value=model_response,
+ ):
+ with start_transaction(name="litellm test"):
+ litellm.completion(
+ model="gpt-4-vision-preview",
+ messages=messages,
+ client=client,
+ custom_llm_provider="openai",
+ )
+
+ litellm_utils.executor.shutdown(wait=True)
+
+ (event,) = events
+ chat_spans = list(
+ x
+ for x in event["spans"]
+ if x["op"] == OP.GEN_AI_CHAT and x["origin"] == "auto.ai.litellm"
+ )
+ assert len(chat_spans) == 1
+ span = chat_spans[0]
+ messages_data = json.loads(span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+
+ blob_item = next(
+ (
+ item
+ for msg in messages_data
+ if "content" in msg
+ for item in msg["content"]
+ if item.get("type") == "blob"
+ ),
+ None,
+ )
+ assert blob_item is not None
+ assert blob_item["modality"] == "image"
+ assert blob_item["mime_type"] == "image/png"
+ assert (
+ IMAGE_B64 in blob_item["content"]
+ or blob_item["content"] == BLOB_DATA_SUBSTITUTE
+ )
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_async_binary_content_encoding_image_url(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ nonstreaming_chat_completions_model_response,
+):
+ sentry_init(
+ integrations=[LiteLLMIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Look at this image:"},
+ {
+ "type": "image_url",
+ "image_url": {"url": IMAGE_DATA_URI, "detail": "high"},
+ },
+ ],
+ }
+ ]
+ client = AsyncOpenAI(api_key="test-key")
+
+ model_response = get_model_response(
+ nonstreaming_chat_completions_model_response,
+ serialize_pydantic=True,
+ request_headers={"X-Stainless-Raw-Response": "true"},
+ )
+
+ with mock.patch.object(
+ client.completions._client._client,
+ "send",
+ return_value=model_response,
+ ):
+ with start_transaction(name="litellm test"):
+ await litellm.acompletion(
+ model="gpt-4-vision-preview",
+ messages=messages,
+ client=client,
+ custom_llm_provider="openai",
+ )
+
+ await GLOBAL_LOGGING_WORKER.flush()
+ await asyncio.sleep(0.5)
+
+ (event,) = events
+ chat_spans = list(
+ x
+ for x in event["spans"]
+ if x["op"] == OP.GEN_AI_CHAT and x["origin"] == "auto.ai.litellm"
+ )
+ assert len(chat_spans) == 1
+ span = chat_spans[0]
+ messages_data = json.loads(span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+
+ blob_item = next(
+ (
+ item
+ for msg in messages_data
+ if "content" in msg
+ for item in msg["content"]
+ if item.get("type") == "blob"
+ ),
+ None,
+ )
+ assert blob_item is not None
+ assert blob_item["modality"] == "image"
+ assert blob_item["mime_type"] == "image/png"
+ assert (
+ IMAGE_B64 in blob_item["content"]
+ or blob_item["content"] == BLOB_DATA_SUBSTITUTE
+ )
+
+
+def test_binary_content_encoding_mixed_content(
+ reset_litellm_executor,
+ sentry_init,
+ capture_events,
+ get_model_response,
+ nonstreaming_chat_completions_model_response,
+):
+ sentry_init(
+ integrations=[LiteLLMIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Here is an image:"},
+ {
+ "type": "image_url",
+ "image_url": {"url": IMAGE_DATA_URI},
+ },
+ {"type": "text", "text": "What do you see?"},
+ ],
+ }
+ ]
+ client = OpenAI(api_key="test-key")
+
+ model_response = get_model_response(
+ nonstreaming_chat_completions_model_response,
+ serialize_pydantic=True,
+ request_headers={"X-Stainless-Raw-Response": "true"},
+ )
+
+ with mock.patch.object(
+ client.completions._client._client,
+ "send",
+ return_value=model_response,
+ ):
+ with start_transaction(name="litellm test"):
+ litellm.completion(
+ model="gpt-4-vision-preview",
+ messages=messages,
+ client=client,
+ custom_llm_provider="openai",
+ )
+
+ litellm_utils.executor.shutdown(wait=True)
+
+ (event,) = events
+ chat_spans = list(
+ x
+ for x in event["spans"]
+ if x["op"] == OP.GEN_AI_CHAT and x["origin"] == "auto.ai.litellm"
+ )
+ assert len(chat_spans) == 1
+ span = chat_spans[0]
+ messages_data = json.loads(span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+
+ content_items = [
+ item for msg in messages_data if "content" in msg for item in msg["content"]
+ ]
+ assert any(item.get("type") == "text" for item in content_items)
+ assert any(item.get("type") == "blob" for item in content_items)
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_async_binary_content_encoding_mixed_content(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ nonstreaming_chat_completions_model_response,
+):
+ sentry_init(
+ integrations=[LiteLLMIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Here is an image:"},
+ {
+ "type": "image_url",
+ "image_url": {"url": IMAGE_DATA_URI},
+ },
+ {"type": "text", "text": "What do you see?"},
+ ],
+ }
+ ]
+ client = AsyncOpenAI(api_key="test-key")
+
+ model_response = get_model_response(
+ nonstreaming_chat_completions_model_response,
+ serialize_pydantic=True,
+ request_headers={"X-Stainless-Raw-Response": "true"},
+ )
+
+ with mock.patch.object(
+ client.completions._client._client,
+ "send",
+ return_value=model_response,
+ ):
+ with start_transaction(name="litellm test"):
+ await litellm.acompletion(
+ model="gpt-4-vision-preview",
+ messages=messages,
+ client=client,
+ custom_llm_provider="openai",
+ )
+
+ await GLOBAL_LOGGING_WORKER.flush()
+ await asyncio.sleep(0.5)
+
+ (event,) = events
+ chat_spans = list(
+ x
+ for x in event["spans"]
+ if x["op"] == OP.GEN_AI_CHAT and x["origin"] == "auto.ai.litellm"
+ )
+ assert len(chat_spans) == 1
+ span = chat_spans[0]
+ messages_data = json.loads(span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+
+ content_items = [
+ item for msg in messages_data if "content" in msg for item in msg["content"]
+ ]
+ assert any(item.get("type") == "text" for item in content_items)
+ assert any(item.get("type") == "blob" for item in content_items)
+
+
+def test_binary_content_encoding_uri_type(
+ reset_litellm_executor,
+ sentry_init,
+ capture_events,
+ get_model_response,
+ nonstreaming_chat_completions_model_response,
+):
+ sentry_init(
+ integrations=[LiteLLMIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image_url",
+ "image_url": {"url": "https://example.com/image.jpg"},
+ }
+ ],
+ }
+ ]
+ client = OpenAI(api_key="test-key")
+
+ model_response = get_model_response(
+ nonstreaming_chat_completions_model_response,
+ serialize_pydantic=True,
+ request_headers={"X-Stainless-Raw-Response": "true"},
+ )
+
+ with mock.patch.object(
+ client.completions._client._client,
+ "send",
+ return_value=model_response,
+ ):
+ with start_transaction(name="litellm test"):
+ litellm.completion(
+ model="gpt-4-vision-preview",
+ messages=messages,
+ client=client,
+ custom_llm_provider="openai",
+ )
+
+ litellm_utils.executor.shutdown(wait=True)
+
+ (event,) = events
+ chat_spans = list(
+ x
+ for x in event["spans"]
+ if x["op"] == OP.GEN_AI_CHAT and x["origin"] == "auto.ai.litellm"
+ )
+ assert len(chat_spans) == 1
+ span = chat_spans[0]
+ messages_data = json.loads(span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+
+ uri_item = next(
+ (
+ item
+ for msg in messages_data
+ if "content" in msg
+ for item in msg["content"]
+ if item.get("type") == "uri"
+ ),
+ None,
+ )
+ assert uri_item is not None
+ assert uri_item["uri"] == "https://example.com/image.jpg"
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_async_binary_content_encoding_uri_type(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ nonstreaming_chat_completions_model_response,
+):
+ sentry_init(
+ integrations=[LiteLLMIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image_url",
+ "image_url": {"url": "https://example.com/image.jpg"},
+ }
+ ],
+ }
+ ]
+ client = AsyncOpenAI(api_key="test-key")
+
+ model_response = get_model_response(
+ nonstreaming_chat_completions_model_response,
+ serialize_pydantic=True,
+ request_headers={"X-Stainless-Raw-Response": "true"},
+ )
+
+ with mock.patch.object(
+ client.completions._client._client,
+ "send",
+ return_value=model_response,
+ ):
+ with start_transaction(name="litellm test"):
+ await litellm.acompletion(
+ model="gpt-4-vision-preview",
+ messages=messages,
+ client=client,
+ custom_llm_provider="openai",
+ )
+
+ await GLOBAL_LOGGING_WORKER.flush()
+ await asyncio.sleep(0.5)
+
+ (event,) = events
+ chat_spans = list(
+ x
+ for x in event["spans"]
+ if x["op"] == OP.GEN_AI_CHAT and x["origin"] == "auto.ai.litellm"
+ )
+ assert len(chat_spans) == 1
+ span = chat_spans[0]
+ messages_data = json.loads(span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+
+ uri_item = next(
+ (
+ item
+ for msg in messages_data
+ if "content" in msg
+ for item in msg["content"]
+ if item.get("type") == "uri"
+ ),
+ None,
+ )
+ assert uri_item is not None
+ assert uri_item["uri"] == "https://example.com/image.jpg"
+
+
+def test_convert_message_parts_direct():
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Hello"},
+ {
+ "type": "image_url",
+ "image_url": {"url": IMAGE_DATA_URI},
+ },
+ ],
+ }
+ ]
+ converted = _convert_message_parts(messages)
+ blob_item = next(
+ item for item in converted[0]["content"] if item.get("type") == "blob"
+ )
+ assert blob_item["modality"] == "image"
+ assert blob_item["mime_type"] == "image/png"
+ assert IMAGE_B64 in blob_item["content"]
+
+
+def test_convert_message_parts_does_not_mutate_original():
+ """Ensure _convert_message_parts does not mutate the original messages."""
+ original_url = IMAGE_DATA_URI
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image_url",
+ "image_url": {"url": original_url},
+ },
+ ],
+ }
+ ]
+ _convert_message_parts(messages)
+ # Original should be unchanged
+ assert messages[0]["content"][0]["type"] == "image_url"
+ assert messages[0]["content"][0]["image_url"]["url"] == original_url
+
+
+def test_convert_message_parts_data_url_without_base64():
+ """Data URLs without ;base64, marker are still inline data and should be blobs."""
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image_url",
+ "image_url": {"url": "data:image/png,rawdata"},
+ },
+ ],
+ }
+ ]
+ converted = _convert_message_parts(messages)
+ blob_item = converted[0]["content"][0]
+ # Data URIs (with or without base64 encoding) contain inline data and should be blobs
+ assert blob_item["type"] == "blob"
+ assert blob_item["modality"] == "image"
+ assert blob_item["mime_type"] == "image/png"
+ assert blob_item["content"] == "rawdata"
+
+
+def test_convert_message_parts_image_url_none():
+ """image_url being None should not crash."""
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image_url",
+ "image_url": None,
+ },
+ ],
+ }
+ ]
+ converted = _convert_message_parts(messages)
+ # Should return item unchanged
+ assert converted[0]["content"][0]["type"] == "image_url"
+
+
+def test_convert_message_parts_image_url_missing_url():
+ """image_url missing the url key should not crash."""
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image_url",
+ "image_url": {"detail": "high"},
+ },
+ ],
+ }
+ ]
+ converted = _convert_message_parts(messages)
+ # Should return item unchanged
+ assert converted[0]["content"][0]["type"] == "image_url"
diff --git a/tests/integrations/litestar/__init__.py b/tests/integrations/litestar/__init__.py
new file mode 100644
index 0000000000..3a4a6235de
--- /dev/null
+++ b/tests/integrations/litestar/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("litestar")
diff --git a/tests/integrations/litestar/test_litestar.py b/tests/integrations/litestar/test_litestar.py
new file mode 100644
index 0000000000..b064c17112
--- /dev/null
+++ b/tests/integrations/litestar/test_litestar.py
@@ -0,0 +1,493 @@
+from __future__ import annotations
+import functools
+
+from litestar.exceptions import HTTPException
+import pytest
+
+from sentry_sdk import capture_message
+from sentry_sdk.integrations.litestar import LitestarIntegration
+
+from typing import Any
+
+from litestar import Litestar, get, Controller
+from litestar.logging.config import LoggingConfig
+from litestar.middleware import AbstractMiddleware
+from litestar.middleware.logging import LoggingMiddlewareConfig
+from litestar.middleware.rate_limit import RateLimitConfig
+from litestar.middleware.session.server_side import ServerSideSessionConfig
+from litestar.testing import TestClient
+
+from tests.integrations.conftest import parametrize_test_configurable_status_codes
+
+
+def litestar_app_factory(middleware=None, debug=True, exception_handlers=None):
+ class MyController(Controller):
+ path = "/controller"
+
+ @get("/error")
+ async def controller_error(self) -> None:
+ raise Exception("Whoa")
+
+ @get("/some_url")
+ async def homepage_handler() -> "dict[str, Any]":
+ 1 / 0
+ return {"status": "ok"}
+
+ @get("/custom_error", name="custom_name")
+ async def custom_error() -> Any:
+ raise Exception("Too Hot")
+
+ @get("/message")
+ async def message() -> "dict[str, Any]":
+ capture_message("hi")
+ return {"status": "ok"}
+
+ @get("/message/{message_id:str}")
+ async def message_with_id() -> "dict[str, Any]":
+ capture_message("hi")
+ return {"status": "ok"}
+
+ logging_config = LoggingConfig()
+
+ app = Litestar(
+ route_handlers=[
+ homepage_handler,
+ custom_error,
+ message,
+ message_with_id,
+ MyController,
+ ],
+ debug=debug,
+ middleware=middleware,
+ logging_config=logging_config,
+ exception_handlers=exception_handlers,
+ )
+
+ return app
+
+
+@pytest.mark.parametrize(
+ "test_url,expected_error,expected_message,expected_tx_name",
+ [
+ (
+ "/some_url",
+ ZeroDivisionError,
+ "division by zero",
+ "tests.integrations.litestar.test_litestar.litestar_app_factory..homepage_handler",
+ ),
+ (
+ "/custom_error",
+ Exception,
+ "Too Hot",
+ "custom_name",
+ ),
+ (
+ "/controller/error",
+ Exception,
+ "Whoa",
+ "tests.integrations.litestar.test_litestar.litestar_app_factory..MyController.controller_error",
+ ),
+ ],
+)
+def test_catch_exceptions(
+ sentry_init,
+ capture_exceptions,
+ capture_events,
+ test_url,
+ expected_error,
+ expected_message,
+ expected_tx_name,
+):
+ sentry_init(integrations=[LitestarIntegration()])
+ litestar_app = litestar_app_factory()
+ exceptions = capture_exceptions()
+ events = capture_events()
+
+ client = TestClient(litestar_app)
+ try:
+ client.get(test_url)
+ except Exception:
+ pass
+
+ (exc,) = exceptions
+ assert isinstance(exc, expected_error)
+ assert str(exc) == expected_message
+
+ (event,) = events
+ assert expected_tx_name in event["transaction"]
+ assert event["exception"]["values"][0]["mechanism"]["type"] == "litestar"
+
+
+def test_middleware_spans(sentry_init, capture_events):
+ sentry_init(
+ traces_sample_rate=1.0,
+ integrations=[LitestarIntegration()],
+ )
+
+ logging_config = LoggingMiddlewareConfig()
+ session_config = ServerSideSessionConfig()
+ rate_limit_config = RateLimitConfig(rate_limit=("hour", 5))
+
+ litestar_app = litestar_app_factory(
+ middleware=[
+ session_config.middleware,
+ logging_config.middleware,
+ rate_limit_config.middleware,
+ ]
+ )
+ events = capture_events()
+
+ client = TestClient(
+ litestar_app, raise_server_exceptions=False, base_url="http://testserver.local"
+ )
+ client.get("/message")
+
+ (_, transaction_event) = events
+
+ expected = {"SessionMiddleware", "LoggingMiddleware", "RateLimitMiddleware"}
+ found = set()
+
+ litestar_spans = (
+ span
+ for span in transaction_event["spans"]
+ if span["op"] == "middleware.litestar"
+ )
+
+ for span in litestar_spans:
+ assert span["description"] in expected
+ assert span["description"] not in found
+ found.add(span["description"])
+ assert span["description"] == span["tags"]["litestar.middleware_name"]
+
+
+def test_middleware_callback_spans(sentry_init, capture_events):
+ class SampleMiddleware(AbstractMiddleware):
+ async def __call__(self, scope, receive, send) -> None:
+ async def do_stuff(message):
+ if message["type"] == "http.response.start":
+ # do something here.
+ pass
+ await send(message)
+
+ await self.app(scope, receive, do_stuff)
+
+ sentry_init(
+ traces_sample_rate=1.0,
+ integrations=[LitestarIntegration()],
+ )
+ litestar_app = litestar_app_factory(middleware=[SampleMiddleware])
+ events = capture_events()
+
+ client = TestClient(litestar_app, raise_server_exceptions=False)
+ client.get("/message")
+
+ (_, transaction_events) = events
+
+ expected_litestar_spans = [
+ {
+ "op": "middleware.litestar",
+ "description": "SampleMiddleware",
+ "tags": {"litestar.middleware_name": "SampleMiddleware"},
+ },
+ {
+ "op": "middleware.litestar.send",
+ "description": "SentryAsgiMiddleware._run_app.._sentry_wrapped_send",
+ "tags": {"litestar.middleware_name": "SampleMiddleware"},
+ },
+ {
+ "op": "middleware.litestar.send",
+ "description": "SentryAsgiMiddleware._run_app.._sentry_wrapped_send",
+ "tags": {"litestar.middleware_name": "SampleMiddleware"},
+ },
+ ]
+
+ def is_matching_span(expected_span, actual_span):
+ return (
+ expected_span["op"] == actual_span["op"]
+ and expected_span["description"] == actual_span["description"]
+ and expected_span["tags"] == actual_span["tags"]
+ )
+
+ actual_litestar_spans = list(
+ span
+ for span in transaction_events["spans"]
+ if "middleware.litestar" in span["op"]
+ )
+ assert len(actual_litestar_spans) == 3
+
+ for expected_span in expected_litestar_spans:
+ assert any(
+ is_matching_span(expected_span, actual_span)
+ for actual_span in actual_litestar_spans
+ )
+
+
+def test_middleware_receive_send(sentry_init, capture_events):
+ class SampleReceiveSendMiddleware(AbstractMiddleware):
+ async def __call__(self, scope, receive, send):
+ message = await receive()
+ assert message
+ assert message["type"] == "http.request"
+
+ send_output = await send({"type": "something-unimportant"})
+ assert send_output is None
+
+ await self.app(scope, receive, send)
+
+ sentry_init(
+ traces_sample_rate=1.0,
+ integrations=[LitestarIntegration()],
+ )
+ litestar_app = litestar_app_factory(middleware=[SampleReceiveSendMiddleware])
+
+ client = TestClient(litestar_app, raise_server_exceptions=False)
+ # See SampleReceiveSendMiddleware.__call__ above for assertions of correct behavior
+ client.get("/message")
+
+
+def test_middleware_partial_receive_send(sentry_init, capture_events):
+ class SamplePartialReceiveSendMiddleware(AbstractMiddleware):
+ async def __call__(self, scope, receive, send):
+ message = await receive()
+ assert message
+ assert message["type"] == "http.request"
+
+ send_output = await send({"type": "something-unimportant"})
+ assert send_output is None
+
+ async def my_receive(*args, **kwargs):
+ pass
+
+ async def my_send(*args, **kwargs):
+ pass
+
+ partial_receive = functools.partial(my_receive)
+ partial_send = functools.partial(my_send)
+
+ await self.app(scope, partial_receive, partial_send)
+
+ sentry_init(
+ traces_sample_rate=1.0,
+ integrations=[LitestarIntegration()],
+ )
+ litestar_app = litestar_app_factory(middleware=[SamplePartialReceiveSendMiddleware])
+ events = capture_events()
+
+ client = TestClient(litestar_app, raise_server_exceptions=False)
+ # See SamplePartialReceiveSendMiddleware.__call__ above for assertions of correct behavior
+ client.get("/message")
+
+ (_, transaction_events) = events
+
+ expected_litestar_spans = [
+ {
+ "op": "middleware.litestar",
+ "description": "SamplePartialReceiveSendMiddleware",
+ "tags": {"litestar.middleware_name": "SamplePartialReceiveSendMiddleware"},
+ },
+ {
+ "op": "middleware.litestar.receive",
+ "description": "TestClientTransport.create_receive..receive",
+ "tags": {"litestar.middleware_name": "SamplePartialReceiveSendMiddleware"},
+ },
+ {
+ "op": "middleware.litestar.send",
+ "description": "SentryAsgiMiddleware._run_app.._sentry_wrapped_send",
+ "tags": {"litestar.middleware_name": "SamplePartialReceiveSendMiddleware"},
+ },
+ ]
+
+ def is_matching_span(expected_span, actual_span):
+ return (
+ expected_span["op"] == actual_span["op"]
+ and actual_span["description"].startswith(expected_span["description"])
+ and expected_span["tags"] == actual_span["tags"]
+ )
+
+ actual_litestar_spans = list(
+ span
+ for span in transaction_events["spans"]
+ if "middleware.litestar" in span["op"]
+ )
+ assert len(actual_litestar_spans) == 3
+
+ for expected_span in expected_litestar_spans:
+ assert any(
+ is_matching_span(expected_span, actual_span)
+ for actual_span in actual_litestar_spans
+ )
+
+
+def test_span_origin(sentry_init, capture_events):
+ sentry_init(
+ integrations=[LitestarIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ logging_config = LoggingMiddlewareConfig()
+ session_config = ServerSideSessionConfig()
+ rate_limit_config = RateLimitConfig(rate_limit=("hour", 5))
+
+ litestar_app = litestar_app_factory(
+ middleware=[
+ session_config.middleware,
+ logging_config.middleware,
+ rate_limit_config.middleware,
+ ]
+ )
+ events = capture_events()
+
+ client = TestClient(
+ litestar_app, raise_server_exceptions=False, base_url="http://testserver.local"
+ )
+ client.get("/message")
+
+ (_, event) = events
+
+ assert event["contexts"]["trace"]["origin"] == "auto.http.litestar"
+ for span in event["spans"]:
+ assert span["origin"] == "auto.http.litestar"
+
+
+@pytest.mark.parametrize(
+ "is_send_default_pii",
+ [
+ True,
+ False,
+ ],
+ ids=[
+ "send_default_pii=True",
+ "send_default_pii=False",
+ ],
+)
+def test_litestar_scope_user_on_exception_event(
+ sentry_init, capture_exceptions, capture_events, is_send_default_pii
+):
+ class TestUserMiddleware(AbstractMiddleware):
+ async def __call__(self, scope, receive, send):
+ scope["user"] = {
+ "email": "lennon@thebeatles.com",
+ "username": "john",
+ "id": "1",
+ }
+ await self.app(scope, receive, send)
+
+ sentry_init(
+ integrations=[LitestarIntegration()], send_default_pii=is_send_default_pii
+ )
+ litestar_app = litestar_app_factory(middleware=[TestUserMiddleware])
+ exceptions = capture_exceptions()
+ events = capture_events()
+
+ # This request intentionally raises an exception
+ client = TestClient(litestar_app)
+ try:
+ client.get("/some_url")
+ except Exception:
+ pass
+
+ assert len(exceptions) == 1
+ assert len(events) == 1
+ (event,) = events
+
+ if is_send_default_pii:
+ assert "user" in event
+ assert event["user"] == {
+ "email": "lennon@thebeatles.com",
+ "username": "john",
+ "id": "1",
+ }
+ else:
+ assert "user" not in event
+
+
+@parametrize_test_configurable_status_codes
+def test_configurable_status_codes_handler(
+ sentry_init,
+ capture_events,
+ failed_request_status_codes,
+ status_code,
+ expected_error,
+):
+ integration_kwargs = (
+ {"failed_request_status_codes": failed_request_status_codes}
+ if failed_request_status_codes is not None
+ else {}
+ )
+ sentry_init(integrations=[LitestarIntegration(**integration_kwargs)])
+
+ events = capture_events()
+
+ @get("/error")
+ async def error() -> None:
+ raise HTTPException(status_code=status_code)
+
+ app = Litestar([error])
+ client = TestClient(app)
+ client.get("/error")
+
+ assert len(events) == int(expected_error)
+
+
+@parametrize_test_configurable_status_codes
+def test_configurable_status_codes_middleware(
+ sentry_init,
+ capture_events,
+ failed_request_status_codes,
+ status_code,
+ expected_error,
+):
+ integration_kwargs = (
+ {"failed_request_status_codes": failed_request_status_codes}
+ if failed_request_status_codes is not None
+ else {}
+ )
+ sentry_init(integrations=[LitestarIntegration(**integration_kwargs)])
+
+ events = capture_events()
+
+ def create_raising_middleware(app):
+ async def raising_middleware(scope, receive, send):
+ raise HTTPException(status_code=status_code)
+
+ return raising_middleware
+
+ @get("/error")
+ async def error() -> None: ...
+
+ app = Litestar([error], middleware=[create_raising_middleware])
+ client = TestClient(app)
+ client.get("/error")
+
+ assert len(events) == int(expected_error)
+
+
+def test_catch_non_http_exceptions_in_middleware(
+ sentry_init,
+ capture_events,
+):
+ sentry_init(integrations=[LitestarIntegration()])
+
+ events = capture_events()
+
+ def create_raising_middleware(app):
+ async def raising_middleware(scope, receive, send):
+ raise RuntimeError("Too Hot")
+
+ return raising_middleware
+
+ @get("/error")
+ async def error() -> None: ...
+
+ app = Litestar([error], middleware=[create_raising_middleware])
+ client = TestClient(app)
+
+ try:
+ client.get("/error")
+ except RuntimeError:
+ pass
+
+ assert len(events) == 1
+ event_exception = events[0]["exception"]["values"][0]
+ assert event_exception["type"] == "RuntimeError"
+ assert event_exception["value"] == "Too Hot"
diff --git a/tests/integrations/logging/test_logging.py b/tests/integrations/logging/test_logging.py
index de1c55e26f..5e384bd3db 100644
--- a/tests/integrations/logging/test_logging.py
+++ b/tests/integrations/logging/test_logging.py
@@ -1,11 +1,18 @@
-# coding: utf-8
-import sys
-
-import pytest
import logging
import warnings
-from sentry_sdk.integrations.logging import LoggingIntegration, ignore_logger
+import pytest
+
+from sentry_sdk import get_client
+from sentry_sdk.consts import VERSION
+from sentry_sdk.integrations.logging import (
+ LoggingIntegration,
+ ignore_logger,
+ ignore_logger_for_sentry_logs,
+ unignore_logger,
+ unignore_logger_for_sentry_logs,
+)
+from tests.test_logs import envelopes_to_logs
other_logger = logging.getLogger("testfoo")
logger = logging.getLogger(__name__)
@@ -28,6 +35,7 @@ def test_logging_works_with_many_loggers(sentry_init, capture_events, logger):
assert event["level"] == "fatal"
assert not event["logentry"]["params"]
assert event["logentry"]["message"] == "LOL"
+ assert event["logentry"]["formatted"] == "LOL"
assert any(crumb["message"] == "bread" for crumb in event["breadcrumbs"]["values"])
@@ -79,12 +87,18 @@ def test_logging_extra_data_integer_keys(sentry_init, capture_events):
assert event["extra"] == {"1": 1}
-@pytest.mark.xfail(sys.version_info[:2] == (3, 4), reason="buggy logging module")
-def test_logging_stack(sentry_init, capture_events):
+@pytest.mark.parametrize(
+ "enable_stack_trace_kwarg",
+ (
+ pytest.param({"exc_info": True}, id="exc_info"),
+ pytest.param({"stack_info": True}, id="stack_info"),
+ ),
+)
+def test_logging_stack_trace(sentry_init, capture_events, enable_stack_trace_kwarg):
sentry_init(integrations=[LoggingIntegration()], default_integrations=False)
events = capture_events()
- logger.error("first", exc_info=True)
+ logger.error("first", **enable_stack_trace_kwarg)
logger.error("second")
(
@@ -108,6 +122,7 @@ def test_logging_level(sentry_init, capture_events):
(event,) = events
assert event["level"] == "error"
assert event["logentry"]["message"] == "hi"
+ assert event["logentry"]["formatted"] == "hi"
del events[:]
@@ -128,9 +143,7 @@ def test_custom_log_level_names(sentry_init, capture_events):
}
# set custom log level names
- # fmt: off
- logging.addLevelName(logging.DEBUG, u"custom level debüg: ")
- # fmt: on
+ logging.addLevelName(logging.DEBUG, "custom level debüg: ")
logging.addLevelName(logging.INFO, "")
logging.addLevelName(logging.WARN, "custom level warn: ")
logging.addLevelName(logging.WARNING, "custom level warning: ")
@@ -150,6 +163,7 @@ def test_custom_log_level_names(sentry_init, capture_events):
assert events
assert events[0]["level"] == sentry_level
assert events[0]["logentry"]["message"] == "Trying level %s"
+ assert events[0]["logentry"]["formatted"] == f"Trying level {logging_level}"
assert events[0]["logentry"]["params"] == [logging_level]
del events[:]
@@ -175,6 +189,7 @@ def filter(self, record):
(event,) = events
assert event["logentry"]["message"] == "hi"
+ assert event["logentry"]["formatted"] == "hi"
def test_logging_captured_warnings(sentry_init, capture_events, recwarn):
@@ -185,21 +200,27 @@ def test_logging_captured_warnings(sentry_init, capture_events, recwarn):
events = capture_events()
logging.captureWarnings(True)
- warnings.warn("first")
- warnings.warn("second")
+ warnings.warn("first", stacklevel=2)
+ warnings.warn("second", stacklevel=2)
logging.captureWarnings(False)
- warnings.warn("third")
+ warnings.warn("third", stacklevel=2)
assert len(events) == 2
assert events[0]["level"] == "warning"
# Captured warnings start with the path where the warning was raised
assert "UserWarning: first" in events[0]["logentry"]["message"]
+ assert "UserWarning: first" in events[0]["logentry"]["formatted"]
+ # For warnings, the message and formatted message are the same
+ assert events[0]["logentry"]["message"] == events[0]["logentry"]["formatted"]
assert events[0]["logentry"]["params"] == []
assert events[1]["level"] == "warning"
assert "UserWarning: second" in events[1]["logentry"]["message"]
+ assert "UserWarning: second" in events[1]["logentry"]["formatted"]
+ # For warnings, the message and formatted message are the same
+ assert events[1]["logentry"]["message"] == events[1]["logentry"]["formatted"]
assert events[1]["logentry"]["params"] == []
# Using recwarn suppresses the "third" warning in the test output
@@ -207,22 +228,37 @@ def test_logging_captured_warnings(sentry_init, capture_events, recwarn):
assert str(recwarn[0].message) == "third"
-def test_ignore_logger(sentry_init, capture_events):
+def test_ignore_logger(sentry_init, capture_events, request):
sentry_init(integrations=[LoggingIntegration()], default_integrations=False)
events = capture_events()
ignore_logger("testfoo")
+ request.addfinalizer(lambda: unignore_logger("testfoo"))
other_logger.error("hi")
assert not events
-def test_ignore_logger_wildcard(sentry_init, capture_events):
+def test_ignore_logger_whitespace_padding(sentry_init, capture_events, request):
+ """Here we test insensitivity to whitespace padding of ignored loggers"""
+ sentry_init(integrations=[LoggingIntegration()], default_integrations=False)
+ events = capture_events()
+
+ ignore_logger("testfoo")
+ request.addfinalizer(lambda: unignore_logger("testfoo"))
+
+ padded_logger = logging.getLogger(" testfoo ")
+ padded_logger.error("hi")
+ assert not events
+
+
+def test_ignore_logger_wildcard(sentry_init, capture_events, request):
sentry_init(integrations=[LoggingIntegration()], default_integrations=False)
events = capture_events()
ignore_logger("testfoo.*")
+ request.addfinalizer(lambda: unignore_logger("testfoo.*"))
nested_logger = logging.getLogger("testfoo.submodule")
@@ -232,3 +268,375 @@ def test_ignore_logger_wildcard(sentry_init, capture_events):
(event,) = events
assert event["logentry"]["message"] == "hi"
+ assert event["logentry"]["formatted"] == "hi"
+
+
+def test_ignore_logger_does_not_affect_sentry_logs(
+ sentry_init, capture_envelopes, request
+):
+ """ignore_logger should suppress events/breadcrumbs but not Sentry Logs."""
+ sentry_init(enable_logs=True)
+ envelopes = capture_envelopes()
+
+ ignore_logger("testfoo")
+ request.addfinalizer(lambda: unignore_logger("testfoo"))
+
+ other_logger.error("hi")
+ get_client().flush()
+
+ logs = envelopes_to_logs(envelopes)
+ assert len(logs) == 1
+ assert logs[0]["body"] == "hi"
+
+
+def test_ignore_logger_for_sentry_logs(sentry_init, capture_envelopes, request):
+ """ignore_logger_for_sentry_logs should suppress Sentry Logs but not events."""
+ sentry_init(enable_logs=True)
+ envelopes = capture_envelopes()
+
+ ignore_logger_for_sentry_logs("testfoo")
+ request.addfinalizer(lambda: unignore_logger_for_sentry_logs("testfoo"))
+
+ other_logger.error("hi")
+ get_client().flush()
+
+ # Event should still be captured
+ event_envelopes = [e for e in envelopes if e.items[0].type == "event"]
+ assert len(event_envelopes) == 1
+
+ # But no Sentry Logs
+ logs = envelopes_to_logs(envelopes)
+ assert len(logs) == 0
+
+
+def test_logging_dictionary_interpolation(sentry_init, capture_events):
+ """Here we test an entire dictionary being interpolated into the log message."""
+ sentry_init(integrations=[LoggingIntegration()], default_integrations=False)
+ events = capture_events()
+
+ logger.error("this is a log with a dictionary %s", {"foo": "bar"})
+
+ (event,) = events
+ assert event["logentry"]["message"] == "this is a log with a dictionary %s"
+ assert (
+ event["logentry"]["formatted"]
+ == "this is a log with a dictionary {'foo': 'bar'}"
+ )
+ assert event["logentry"]["params"] == {"foo": "bar"}
+
+
+def test_logging_dictionary_args(sentry_init, capture_events):
+ """Here we test items from a dictionary being interpolated into the log message."""
+ sentry_init(integrations=[LoggingIntegration()], default_integrations=False)
+ events = capture_events()
+
+ logger.error(
+ "the value of foo is %(foo)s, and the value of bar is %(bar)s",
+ {"foo": "bar", "bar": "baz"},
+ )
+
+ (event,) = events
+ assert (
+ event["logentry"]["message"]
+ == "the value of foo is %(foo)s, and the value of bar is %(bar)s"
+ )
+ assert (
+ event["logentry"]["formatted"]
+ == "the value of foo is bar, and the value of bar is baz"
+ )
+ assert event["logentry"]["params"] == {"foo": "bar", "bar": "baz"}
+
+
+def test_sentry_logs_warning(sentry_init, capture_envelopes):
+ """
+ The python logger module should create 'warn' sentry logs if the flag is on.
+ """
+ sentry_init(enable_logs=True)
+ envelopes = capture_envelopes()
+
+ python_logger = logging.Logger("test-logger")
+ python_logger.warning("this is %s a template %s", "1", "2")
+
+ get_client().flush()
+ logs = envelopes_to_logs(envelopes)
+ attrs = logs[0]["attributes"]
+ assert attrs["sentry.message.template"] == "this is %s a template %s"
+ assert "code.file.path" in attrs
+ assert "code.line.number" in attrs
+ assert attrs["logger.name"] == "test-logger"
+ assert attrs["sentry.environment"] == "production"
+ assert attrs["sentry.message.parameter.0"] == "1"
+ assert attrs["sentry.message.parameter.1"] == "2"
+ assert attrs["sentry.origin"] == "auto.log.stdlib"
+ assert logs[0]["severity_number"] == 13
+ assert logs[0]["severity_text"] == "warn"
+
+
+def test_sentry_logs_debug(sentry_init, capture_envelopes):
+ """
+ The python logger module should not create 'debug' sentry logs if the flag is on by default
+ """
+ sentry_init(enable_logs=True)
+ envelopes = capture_envelopes()
+
+ python_logger = logging.Logger("test-logger")
+ python_logger.debug("this is %s a template %s", "1", "2")
+ get_client().flush()
+
+ assert len(envelopes) == 0
+
+
+def test_no_log_infinite_loop(sentry_init, capture_envelopes):
+ """
+ If 'debug' mode is true, and you set a low log level in the logging integration, there should be no infinite loops.
+ """
+ sentry_init(
+ enable_logs=True,
+ integrations=[LoggingIntegration(sentry_logs_level=logging.DEBUG)],
+ debug=True,
+ )
+ envelopes = capture_envelopes()
+
+ python_logger = logging.Logger("test-logger")
+ python_logger.debug("this is %s a template %s", "1", "2")
+ get_client().flush()
+
+ assert len(envelopes) == 1
+
+
+def test_logging_errors(sentry_init, capture_envelopes):
+ """
+ The python logger module should be able to log errors without erroring
+ """
+ sentry_init(enable_logs=True)
+ envelopes = capture_envelopes()
+
+ python_logger = logging.Logger("test-logger")
+ python_logger.error(Exception("test exc 1"))
+ python_logger.error("error is %s", Exception("test exc 2"))
+ get_client().flush()
+
+ error_event_1 = envelopes[0].items[0].payload.json
+ assert error_event_1["level"] == "error"
+ error_event_2 = envelopes[1].items[0].payload.json
+ assert error_event_2["level"] == "error"
+
+ logs = envelopes_to_logs(envelopes)
+ assert logs[0]["severity_text"] == "error"
+ assert "sentry.message.template" not in logs[0]["attributes"]
+ assert "sentry.message.parameter.0" not in logs[0]["attributes"]
+ assert "code.line.number" in logs[0]["attributes"]
+
+ assert logs[1]["severity_text"] == "error"
+ assert logs[1]["attributes"]["sentry.message.template"] == "error is %s"
+ assert logs[1]["attributes"]["sentry.message.parameter.0"] in (
+ "Exception('test exc 2')",
+ "Exception('test exc 2',)", # py3.6
+ )
+ assert "code.line.number" in logs[1]["attributes"]
+
+ assert len(logs) == 2
+
+
+def test_log_strips_project_root(sentry_init, capture_envelopes):
+ """
+ The python logger should strip project roots from the log record path
+ """
+ sentry_init(
+ enable_logs=True,
+ project_root="/custom/test",
+ )
+ envelopes = capture_envelopes()
+
+ python_logger = logging.Logger("test-logger")
+ python_logger.handle(
+ logging.LogRecord(
+ name="test-logger",
+ level=logging.WARN,
+ pathname="/custom/test/blah/path.py",
+ lineno=123,
+ msg="This is a test log with a custom pathname",
+ args=(),
+ exc_info=None,
+ )
+ )
+ get_client().flush()
+
+ logs = envelopes_to_logs(envelopes)
+ assert len(logs) == 1
+ attrs = logs[0]["attributes"]
+ assert attrs["code.file.path"] == "blah/path.py"
+
+
+def test_logger_with_all_attributes(sentry_init, capture_envelopes):
+ """
+ The python logger should be able to log all attributes, including extra data.
+ """
+ sentry_init(enable_logs=True)
+ envelopes = capture_envelopes()
+
+ python_logger = logging.Logger("test-logger")
+ python_logger.warning(
+ "log #%d",
+ 1,
+ extra={"foo": "bar", "numeric": 42, "more_complex": {"nested": "data"}},
+ )
+ get_client().flush()
+
+ logs = envelopes_to_logs(envelopes)
+
+ assert "span_id" in logs[0]
+ assert isinstance(logs[0]["span_id"], str)
+
+ attributes = logs[0]["attributes"]
+
+ assert "process.pid" in attributes
+ assert isinstance(attributes["process.pid"], int)
+ del attributes["process.pid"]
+
+ assert "sentry.release" in attributes
+ assert isinstance(attributes["sentry.release"], str)
+ del attributes["sentry.release"]
+
+ assert "server.address" in attributes
+ assert isinstance(attributes["server.address"], str)
+ del attributes["server.address"]
+
+ assert "thread.id" in attributes
+ assert isinstance(attributes["thread.id"], int)
+ del attributes["thread.id"]
+
+ assert "code.file.path" in attributes
+ assert isinstance(attributes["code.file.path"], str)
+ del attributes["code.file.path"]
+
+ assert "code.function.name" in attributes
+ assert isinstance(attributes["code.function.name"], str)
+ del attributes["code.function.name"]
+
+ assert "code.line.number" in attributes
+ assert isinstance(attributes["code.line.number"], int)
+ del attributes["code.line.number"]
+
+ assert "process.executable.name" in attributes
+ assert isinstance(attributes["process.executable.name"], str)
+ del attributes["process.executable.name"]
+
+ assert "thread.name" in attributes
+ assert isinstance(attributes["thread.name"], str)
+ del attributes["thread.name"]
+
+ assert attributes.pop("sentry.sdk.name").startswith("sentry.python")
+
+ # Assert on the remaining non-dynamic attributes.
+ assert attributes == {
+ "foo": "bar",
+ "numeric": 42,
+ "more_complex": "{'nested': 'data'}",
+ "logger.name": "test-logger",
+ "sentry.origin": "auto.log.stdlib",
+ "sentry.message.template": "log #%d",
+ "sentry.message.parameter.0": 1,
+ "sentry.environment": "production",
+ "sentry.sdk.version": VERSION,
+ "sentry.severity_number": 13,
+ "sentry.severity_text": "warn",
+ }
+
+
+def test_sentry_logs_named_parameters(sentry_init, capture_envelopes):
+ """
+ The python logger module should capture named parameters from dictionary arguments in Sentry logs.
+ """
+ sentry_init(enable_logs=True)
+ envelopes = capture_envelopes()
+
+ python_logger = logging.Logger("test-logger")
+ python_logger.info(
+ "%(source)s call completed, %(input_tk)i input tk, %(output_tk)i output tk (model %(model)s, cost $%(cost).4f)",
+ {
+ "source": "test_source",
+ "input_tk": 100,
+ "output_tk": 50,
+ "model": "gpt-4",
+ "cost": 0.0234,
+ },
+ )
+
+ get_client().flush()
+ logs = envelopes_to_logs(envelopes)
+
+ assert len(logs) == 1
+ attrs = logs[0]["attributes"]
+
+ # Check that the template is captured
+ assert (
+ attrs["sentry.message.template"]
+ == "%(source)s call completed, %(input_tk)i input tk, %(output_tk)i output tk (model %(model)s, cost $%(cost).4f)"
+ )
+
+ # Check that dictionary arguments are captured as named parameters
+ assert attrs["sentry.message.parameter.source"] == "test_source"
+ assert attrs["sentry.message.parameter.input_tk"] == 100
+ assert attrs["sentry.message.parameter.output_tk"] == 50
+ assert attrs["sentry.message.parameter.model"] == "gpt-4"
+ assert attrs["sentry.message.parameter.cost"] == 0.0234
+
+ # Check other standard attributes
+ assert attrs["logger.name"] == "test-logger"
+ assert attrs["sentry.origin"] == "auto.log.stdlib"
+ assert logs[0]["severity_number"] == 9 # info level
+ assert logs[0]["severity_text"] == "info"
+
+
+def test_sentry_logs_named_parameters_complex_values(sentry_init, capture_envelopes):
+ """
+ The python logger module should handle complex values in named parameters using safe_repr.
+ """
+ sentry_init(enable_logs=True)
+ envelopes = capture_envelopes()
+
+ python_logger = logging.Logger("test-logger")
+ complex_object = {"nested": {"data": [1, 2, 3]}, "tuple": (4, 5, 6)}
+ python_logger.warning(
+ "Processing %(simple)s with %(complex)s data",
+ {
+ "simple": "simple_value",
+ "complex": complex_object,
+ },
+ )
+
+ get_client().flush()
+ logs = envelopes_to_logs(envelopes)
+
+ assert len(logs) == 1
+ attrs = logs[0]["attributes"]
+
+ # Check that simple values are kept as-is
+ assert attrs["sentry.message.parameter.simple"] == "simple_value"
+
+ # Check that complex values are converted using safe_repr
+ assert "sentry.message.parameter.complex" in attrs
+ complex_param = attrs["sentry.message.parameter.complex"]
+ assert isinstance(complex_param, str)
+ assert "nested" in complex_param
+ assert "data" in complex_param
+
+
+def test_sentry_logs_no_parameters_no_template(sentry_init, capture_envelopes):
+ """
+ There shouldn't be a template if there are no parameters.
+ """
+ sentry_init(enable_logs=True)
+ envelopes = capture_envelopes()
+
+ python_logger = logging.Logger("test-logger")
+ python_logger.warning("Warning about something without any parameters.")
+
+ get_client().flush()
+ logs = envelopes_to_logs(envelopes)
+
+ assert len(logs) == 1
+
+ attrs = logs[0]["attributes"]
+ assert "sentry.message.template" not in attrs
diff --git a/tests/integrations/loguru/__init__.py b/tests/integrations/loguru/__init__.py
new file mode 100644
index 0000000000..9d67fb3799
--- /dev/null
+++ b/tests/integrations/loguru/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("loguru")
diff --git a/tests/integrations/loguru/test_loguru.py b/tests/integrations/loguru/test_loguru.py
new file mode 100644
index 0000000000..66cc336de5
--- /dev/null
+++ b/tests/integrations/loguru/test_loguru.py
@@ -0,0 +1,587 @@
+from unittest.mock import MagicMock, patch
+import re
+
+import pytest
+from loguru import logger
+from loguru._recattrs import RecordFile, RecordLevel
+
+import sentry_sdk
+from sentry_sdk.consts import VERSION
+from sentry_sdk.integrations.loguru import LoguruIntegration, LoggingLevels
+from tests.test_logs import envelopes_to_logs
+
+logger.remove(0) # don't print to console
+
+
+@pytest.mark.parametrize(
+ "level,created_event,expected_sentry_level",
+ [
+ # None - no breadcrumb
+ # False - no event
+ # True - event created
+ (LoggingLevels.TRACE, None, "debug"),
+ (LoggingLevels.DEBUG, None, "debug"),
+ (LoggingLevels.INFO, False, "info"),
+ (LoggingLevels.SUCCESS, False, "info"),
+ (LoggingLevels.WARNING, False, "warning"),
+ (LoggingLevels.ERROR, True, "error"),
+ (LoggingLevels.CRITICAL, True, "critical"),
+ ],
+)
+@pytest.mark.parametrize("disable_breadcrumbs", [True, False])
+@pytest.mark.parametrize("disable_events", [True, False])
+def test_just_log(
+ sentry_init,
+ capture_events,
+ level,
+ created_event,
+ expected_sentry_level,
+ disable_breadcrumbs,
+ disable_events,
+ uninstall_integration,
+ request,
+):
+ uninstall_integration("loguru")
+ request.addfinalizer(logger.remove)
+
+ sentry_init(
+ integrations=[
+ LoguruIntegration(
+ level=None if disable_breadcrumbs else LoggingLevels.INFO.value,
+ event_level=None if disable_events else LoggingLevels.ERROR.value,
+ )
+ ],
+ default_integrations=False,
+ )
+ events = capture_events()
+
+ getattr(logger, level.name.lower())("test")
+
+ expected_pattern = (
+ r" \| "
+ + r"{:9}".format(level.name.upper())
+ + r"\| tests\.integrations\.loguru\.test_loguru:test_just_log:\d+ - test"
+ )
+
+ if not created_event:
+ assert not events
+
+ breadcrumbs = sentry_sdk.get_isolation_scope()._breadcrumbs
+ if (
+ not disable_breadcrumbs and created_event is not None
+ ): # not None == not TRACE or DEBUG level
+ (breadcrumb,) = breadcrumbs
+ assert breadcrumb["level"] == expected_sentry_level
+ assert breadcrumb["category"] == "tests.integrations.loguru.test_loguru"
+ assert re.fullmatch(expected_pattern, breadcrumb["message"][23:])
+ else:
+ assert not breadcrumbs
+
+ return
+
+ if disable_events:
+ assert not events
+ return
+
+ (event,) = events
+ assert event["level"] == expected_sentry_level
+ assert event["logger"] == "tests.integrations.loguru.test_loguru"
+ assert re.fullmatch(expected_pattern, event["logentry"]["message"][23:])
+
+
+def test_breadcrumb_format(sentry_init, capture_events, uninstall_integration, request):
+ uninstall_integration("loguru")
+ request.addfinalizer(logger.remove)
+
+ sentry_init(
+ integrations=[
+ LoguruIntegration(
+ level=LoggingLevels.INFO.value,
+ event_level=None,
+ breadcrumb_format="{message}",
+ )
+ ],
+ default_integrations=False,
+ )
+
+ logger.info("test")
+ formatted_message = "test"
+
+ breadcrumbs = sentry_sdk.get_isolation_scope()._breadcrumbs
+ (breadcrumb,) = breadcrumbs
+ assert breadcrumb["message"] == formatted_message
+
+
+def test_event_format(sentry_init, capture_events, uninstall_integration, request):
+ uninstall_integration("loguru")
+ request.addfinalizer(logger.remove)
+
+ sentry_init(
+ integrations=[
+ LoguruIntegration(
+ level=None,
+ event_level=LoggingLevels.ERROR.value,
+ event_format="{message}",
+ )
+ ],
+ default_integrations=False,
+ )
+ events = capture_events()
+
+ logger.error("test")
+ formatted_message = "test"
+
+ (event,) = events
+ assert event["logentry"]["message"] == formatted_message
+
+
+def test_sentry_logs_warning(
+ sentry_init, capture_envelopes, uninstall_integration, request
+):
+ uninstall_integration("loguru")
+ request.addfinalizer(logger.remove)
+
+ sentry_init(enable_logs=True)
+ envelopes = capture_envelopes()
+
+ logger.warning("this is {} a {}", "just", "template")
+
+ sentry_sdk.get_client().flush()
+ logs = envelopes_to_logs(envelopes)
+
+ attrs = logs[0]["attributes"]
+ assert "code.file.path" in attrs
+ assert "code.line.number" in attrs
+ assert attrs["logger.name"] == "tests.integrations.loguru.test_loguru"
+ assert attrs["sentry.environment"] == "production"
+ assert attrs["sentry.origin"] == "auto.log.loguru"
+ assert logs[0]["severity_number"] == 13
+ assert logs[0]["severity_text"] == "warn"
+
+
+def test_sentry_logs_debug(
+ sentry_init, capture_envelopes, uninstall_integration, request
+):
+ uninstall_integration("loguru")
+ request.addfinalizer(logger.remove)
+
+ sentry_init(enable_logs=True)
+ envelopes = capture_envelopes()
+
+ logger.debug("this is %s a template %s", "1", "2")
+ sentry_sdk.get_client().flush()
+
+ assert len(envelopes) == 0
+
+
+def test_sentry_log_levels(
+ sentry_init, capture_envelopes, uninstall_integration, request
+):
+ uninstall_integration("loguru")
+ request.addfinalizer(logger.remove)
+
+ sentry_init(
+ integrations=[LoguruIntegration(sentry_logs_level=LoggingLevels.SUCCESS)],
+ enable_logs=True,
+ )
+ envelopes = capture_envelopes()
+
+ logger.trace("this is a log")
+ logger.debug("this is a log")
+ logger.info("this is a log")
+ logger.success("this is a log")
+ logger.warning("this is a log")
+ logger.error("this is a log")
+ logger.critical("this is a log")
+
+ sentry_sdk.get_client().flush()
+ logs = envelopes_to_logs(envelopes)
+ assert len(logs) == 4
+
+ assert logs[0]["severity_number"] == 11
+ assert logs[0]["severity_text"] == "info"
+ assert logs[1]["severity_number"] == 13
+ assert logs[1]["severity_text"] == "warn"
+ assert logs[2]["severity_number"] == 17
+ assert logs[2]["severity_text"] == "error"
+ assert logs[3]["severity_number"] == 21
+ assert logs[3]["severity_text"] == "fatal"
+
+
+def test_disable_loguru_logs(
+ sentry_init, capture_envelopes, uninstall_integration, request
+):
+ uninstall_integration("loguru")
+ request.addfinalizer(logger.remove)
+
+ sentry_init(
+ integrations=[LoguruIntegration(sentry_logs_level=None)],
+ enable_logs=True,
+ )
+ envelopes = capture_envelopes()
+
+ logger.trace("this is a log")
+ logger.debug("this is a log")
+ logger.info("this is a log")
+ logger.success("this is a log")
+ logger.warning("this is a log")
+ logger.error("this is a log")
+ logger.critical("this is a log")
+
+ sentry_sdk.get_client().flush()
+ logs = envelopes_to_logs(envelopes)
+ assert len(logs) == 0
+
+
+def test_disable_sentry_logs(
+ sentry_init, capture_envelopes, uninstall_integration, request
+):
+ uninstall_integration("loguru")
+ request.addfinalizer(logger.remove)
+
+ sentry_init(
+ _experiments={"enable_logs": False},
+ )
+ envelopes = capture_envelopes()
+
+ logger.trace("this is a log")
+ logger.debug("this is a log")
+ logger.info("this is a log")
+ logger.success("this is a log")
+ logger.warning("this is a log")
+ logger.error("this is a log")
+ logger.critical("this is a log")
+
+ sentry_sdk.get_client().flush()
+ logs = envelopes_to_logs(envelopes)
+ assert len(logs) == 0
+
+
+def test_no_log_infinite_loop(
+ sentry_init, capture_envelopes, uninstall_integration, request
+):
+ """
+ In debug mode, there should be no infinite loops even when a low log level is set.
+ """
+ uninstall_integration("loguru")
+ request.addfinalizer(logger.remove)
+
+ sentry_init(
+ enable_logs=True,
+ integrations=[LoguruIntegration(sentry_logs_level=LoggingLevels.DEBUG)],
+ debug=True,
+ )
+ envelopes = capture_envelopes()
+
+ logger.debug("this is %s a template %s", "1", "2")
+ sentry_sdk.get_client().flush()
+
+ assert len(envelopes) == 1
+
+
+def test_logging_errors(sentry_init, capture_envelopes, uninstall_integration, request):
+ """We're able to log errors without erroring."""
+ uninstall_integration("loguru")
+ request.addfinalizer(logger.remove)
+
+ sentry_init(enable_logs=True)
+ envelopes = capture_envelopes()
+
+ logger.error(Exception("test exc 1"))
+ logger.error("error is %s", Exception("test exc 2"))
+ sentry_sdk.get_client().flush()
+
+ error_event_1 = envelopes[0].items[0].payload.json
+ assert error_event_1["level"] == "error"
+ error_event_2 = envelopes[1].items[0].payload.json
+ assert error_event_2["level"] == "error"
+
+ logs = envelopes_to_logs(envelopes)
+ assert logs[0]["severity_text"] == "error"
+ assert "code.line.number" in logs[0]["attributes"]
+
+ assert logs[1]["severity_text"] == "error"
+ assert "code.line.number" in logs[1]["attributes"]
+
+ assert len(logs) == 2
+
+
+def test_log_strips_project_root(
+ sentry_init, capture_envelopes, uninstall_integration, request
+):
+ uninstall_integration("loguru")
+ request.addfinalizer(logger.remove)
+
+ sentry_init(
+ enable_logs=True,
+ project_root="/custom/test",
+ )
+ envelopes = capture_envelopes()
+
+ class FakeMessage:
+ def __init__(self, *args, **kwargs):
+ pass
+
+ @property
+ def record(self):
+ return {
+ "elapsed": MagicMock(),
+ "exception": None,
+ "file": RecordFile(name="app.py", path="/custom/test/blah/path.py"),
+ "function": "",
+ "level": RecordLevel(name="ERROR", no=20, icon=""),
+ "line": 35,
+ "message": "some message",
+ "module": "app",
+ "name": "__main__",
+ "process": MagicMock(),
+ "thread": MagicMock(),
+ "time": MagicMock(),
+ "extra": MagicMock(),
+ }
+
+ @record.setter
+ def record(self, val):
+ pass
+
+ with patch("loguru._handler.Message", FakeMessage):
+ logger.error("some message")
+
+ sentry_sdk.get_client().flush()
+
+ logs = envelopes_to_logs(envelopes)
+ assert len(logs) == 1
+ attrs = logs[0]["attributes"]
+ assert attrs["code.file.path"] == "blah/path.py"
+
+
+def test_log_keeps_full_path_if_not_in_project_root(
+ sentry_init, capture_envelopes, uninstall_integration, request
+):
+ uninstall_integration("loguru")
+ request.addfinalizer(logger.remove)
+
+ sentry_init(
+ enable_logs=True,
+ project_root="/custom/test",
+ )
+ envelopes = capture_envelopes()
+
+ class FakeMessage:
+ def __init__(self, *args, **kwargs):
+ pass
+
+ @property
+ def record(self):
+ return {
+ "elapsed": MagicMock(),
+ "exception": None,
+ "file": RecordFile(name="app.py", path="/blah/path.py"),
+ "function": "",
+ "level": RecordLevel(name="ERROR", no=20, icon=""),
+ "line": 35,
+ "message": "some message",
+ "module": "app",
+ "name": "__main__",
+ "process": MagicMock(),
+ "thread": MagicMock(),
+ "time": MagicMock(),
+ "extra": MagicMock(),
+ }
+
+ @record.setter
+ def record(self, val):
+ pass
+
+ with patch("loguru._handler.Message", FakeMessage):
+ logger.error("some message")
+
+ sentry_sdk.get_client().flush()
+
+ logs = envelopes_to_logs(envelopes)
+ assert len(logs) == 1
+ attrs = logs[0]["attributes"]
+ assert attrs["code.file.path"] == "/blah/path.py"
+
+
+def test_logger_with_all_attributes(
+ sentry_init, capture_envelopes, uninstall_integration, request
+):
+ uninstall_integration("loguru")
+ request.addfinalizer(logger.remove)
+
+ sentry_init(enable_logs=True)
+ envelopes = capture_envelopes()
+
+ logger.warning("log #{}", 1)
+ sentry_sdk.get_client().flush()
+
+ logs = envelopes_to_logs(envelopes)
+
+ assert "span_id" in logs[0]
+ assert isinstance(logs[0]["span_id"], str)
+
+ attributes = logs[0]["attributes"]
+
+ assert "process.pid" in attributes
+ assert isinstance(attributes["process.pid"], int)
+ del attributes["process.pid"]
+
+ assert "sentry.release" in attributes
+ assert isinstance(attributes["sentry.release"], str)
+ del attributes["sentry.release"]
+
+ assert "server.address" in attributes
+ assert isinstance(attributes["server.address"], str)
+ del attributes["server.address"]
+
+ assert "thread.id" in attributes
+ assert isinstance(attributes["thread.id"], int)
+ del attributes["thread.id"]
+
+ assert "code.file.path" in attributes
+ assert isinstance(attributes["code.file.path"], str)
+ del attributes["code.file.path"]
+
+ assert "code.function.name" in attributes
+ assert isinstance(attributes["code.function.name"], str)
+ del attributes["code.function.name"]
+
+ assert "code.line.number" in attributes
+ assert isinstance(attributes["code.line.number"], int)
+ del attributes["code.line.number"]
+
+ assert "process.executable.name" in attributes
+ assert isinstance(attributes["process.executable.name"], str)
+ del attributes["process.executable.name"]
+
+ assert "thread.name" in attributes
+ assert isinstance(attributes["thread.name"], str)
+ del attributes["thread.name"]
+
+ assert attributes.pop("sentry.sdk.name").startswith("sentry.python")
+
+ # Assert on the remaining non-dynamic attributes.
+ assert attributes == {
+ "logger.name": "tests.integrations.loguru.test_loguru",
+ "sentry.origin": "auto.log.loguru",
+ "sentry.environment": "production",
+ "sentry.sdk.version": VERSION,
+ "sentry.severity_number": 13,
+ "sentry.severity_text": "warn",
+ }
+
+
+def test_logger_capture_parameters_from_args(
+ sentry_init, capture_envelopes, uninstall_integration, request
+):
+ # This is currently not supported as regular args don't get added to extra
+ # (which we use for populating parameters). Adding this test to make that
+ # explicit and so that it's easy to change later.
+ uninstall_integration("loguru")
+ request.addfinalizer(logger.remove)
+
+ sentry_init(enable_logs=True)
+ envelopes = capture_envelopes()
+
+ logger.warning("Task ID: {}", 123)
+
+ sentry_sdk.get_client().flush()
+
+ logs = envelopes_to_logs(envelopes)
+
+ attributes = logs[0]["attributes"]
+ assert "sentry.message.parameter.0" not in attributes
+
+
+def test_logger_capture_parameters_from_kwargs(
+ sentry_init, capture_envelopes, uninstall_integration, request
+):
+ uninstall_integration("loguru")
+ request.addfinalizer(logger.remove)
+
+ sentry_init(enable_logs=True)
+ envelopes = capture_envelopes()
+
+ logger.warning("Task ID: {task_id}", task_id=123)
+
+ sentry_sdk.get_client().flush()
+
+ logs = envelopes_to_logs(envelopes)
+
+ attributes = logs[0]["attributes"]
+ assert attributes["sentry.message.parameter.task_id"] == 123
+
+
+def test_logger_capture_parameters_from_contextualize(
+ sentry_init, capture_envelopes, uninstall_integration, request
+):
+ uninstall_integration("loguru")
+ request.addfinalizer(logger.remove)
+
+ sentry_init(enable_logs=True)
+ envelopes = capture_envelopes()
+
+ with logger.contextualize(task_id=123):
+ logger.warning("Log")
+
+ sentry_sdk.get_client().flush()
+
+ logs = envelopes_to_logs(envelopes)
+
+ attributes = logs[0]["attributes"]
+ assert attributes["sentry.message.parameter.task_id"] == 123
+
+
+def test_logger_capture_parameters_from_bind(
+ sentry_init, capture_envelopes, uninstall_integration, request
+):
+ uninstall_integration("loguru")
+ request.addfinalizer(logger.remove)
+
+ sentry_init(enable_logs=True)
+ envelopes = capture_envelopes()
+
+ logger.bind(task_id=123).warning("Log")
+ sentry_sdk.get_client().flush()
+
+ logs = envelopes_to_logs(envelopes)
+
+ attributes = logs[0]["attributes"]
+ assert attributes["sentry.message.parameter.task_id"] == 123
+
+
+def test_logger_capture_parameters_from_patch(
+ sentry_init, capture_envelopes, uninstall_integration, request
+):
+ uninstall_integration("loguru")
+ request.addfinalizer(logger.remove)
+
+ sentry_init(enable_logs=True)
+ envelopes = capture_envelopes()
+
+ logger.patch(lambda record: record["extra"].update(task_id=123)).warning("Log")
+ sentry_sdk.get_client().flush()
+
+ logs = envelopes_to_logs(envelopes)
+
+ attributes = logs[0]["attributes"]
+ assert attributes["sentry.message.parameter.task_id"] == 123
+
+
+def test_no_parameters_no_template(
+ sentry_init, capture_envelopes, uninstall_integration, request
+):
+ uninstall_integration("loguru")
+ request.addfinalizer(logger.remove)
+
+ sentry_init(enable_logs=True)
+ envelopes = capture_envelopes()
+
+ logger.warning("Logging a hardcoded warning")
+ sentry_sdk.get_client().flush()
+
+ logs = envelopes_to_logs(envelopes)
+
+ attributes = logs[0]["attributes"]
+ assert "sentry.message.template" not in attributes
diff --git a/tests/integrations/mcp/__init__.py b/tests/integrations/mcp/__init__.py
new file mode 100644
index 0000000000..01ef442500
--- /dev/null
+++ b/tests/integrations/mcp/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("mcp")
diff --git a/tests/integrations/mcp/test_mcp.py b/tests/integrations/mcp/test_mcp.py
new file mode 100644
index 0000000000..f51d0491ae
--- /dev/null
+++ b/tests/integrations/mcp/test_mcp.py
@@ -0,0 +1,1133 @@
+"""
+Unit tests for the MCP (Model Context Protocol) integration.
+
+This test suite covers:
+- Tool handlers (sync and async)
+- Prompt handlers (sync and async)
+- Resource handlers (sync and async)
+- Error handling for each handler type
+- Request context data extraction (request_id, session_id, transport)
+- Tool result content extraction (various formats)
+- Span data validation
+- Origin tracking
+
+The tests mock the MCP server components and request context to verify
+that the integration properly instruments MCP handlers with Sentry spans.
+"""
+
+import anyio
+import asyncio
+
+import pytest
+import json
+from unittest import mock
+
+try:
+ from unittest.mock import AsyncMock
+except ImportError:
+
+ class AsyncMock(mock.MagicMock):
+ async def __call__(self, *args, **kwargs):
+ return super(AsyncMock, self).__call__(*args, **kwargs)
+
+
+from mcp.server.lowlevel import Server
+from mcp.server.lowlevel.server import request_ctx
+from mcp.types import GetPromptResult, PromptMessage, TextContent
+from mcp.server.lowlevel.helper_types import ReadResourceContents
+
+try:
+ from mcp.server.lowlevel.server import request_ctx
+except ImportError:
+ request_ctx = None
+
+from sentry_sdk import start_transaction
+from sentry_sdk.consts import SPANDATA, OP
+from sentry_sdk.integrations.mcp import MCPIntegration
+
+from mcp.server.sse import SseServerTransport
+from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
+from starlette.routing import Mount, Route
+from starlette.applications import Starlette
+from starlette.responses import Response
+
+
+@pytest.fixture(autouse=True)
+def reset_request_ctx():
+ """Reset request context before and after each test"""
+ if request_ctx is not None:
+ try:
+ if request_ctx.get() is not None:
+ request_ctx.set(None)
+ except LookupError:
+ pass
+
+ yield
+
+ if request_ctx is not None:
+ try:
+ request_ctx.set(None)
+ except LookupError:
+ pass
+
+
+class MockTextContent:
+ """Mock TextContent object"""
+
+ def __init__(self, text):
+ self.text = text
+
+
+def test_integration_patches_server(sentry_init):
+ """Test that MCPIntegration patches the Server class"""
+ # Get original methods before integration
+ original_call_tool = Server.call_tool
+ original_get_prompt = Server.get_prompt
+ original_read_resource = Server.read_resource
+
+ sentry_init(
+ integrations=[MCPIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ # After initialization, the methods should be patched
+ assert Server.call_tool is not original_call_tool
+ assert Server.get_prompt is not original_get_prompt
+ assert Server.read_resource is not original_read_resource
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [(True, True), (True, False), (False, True), (False, False)],
+)
+async def test_tool_handler_stdio(
+ sentry_init, capture_events, send_default_pii, include_prompts, stdio
+):
+ """Test that synchronous tool handlers create proper spans"""
+ sentry_init(
+ integrations=[MCPIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ server = Server("test-server")
+
+ @server.call_tool()
+ async def test_tool(tool_name, arguments):
+ return {"result": "success", "value": 42}
+
+ with start_transaction(name="mcp tx"):
+ result = await stdio(
+ server,
+ method="tools/call",
+ params={
+ "name": "calculate",
+ "arguments": {"x": 10, "y": 5},
+ },
+ request_id="req-123",
+ )
+
+ assert result.message.root.result["content"][0]["text"] == json.dumps(
+ {"result": "success", "value": 42},
+ indent=2,
+ )
+
+ (tx,) = events
+ assert tx["type"] == "transaction"
+ assert len(tx["spans"]) == 1
+
+ span = tx["spans"][0]
+ assert span["op"] == OP.MCP_SERVER
+ assert span["description"] == "tools/call calculate"
+ assert span["origin"] == "auto.ai.mcp"
+
+ # Check span data
+ assert span["data"][SPANDATA.MCP_TOOL_NAME] == "calculate"
+ assert span["data"][SPANDATA.MCP_METHOD_NAME] == "tools/call"
+ assert span["data"][SPANDATA.MCP_TRANSPORT] == "stdio"
+ assert span["data"][SPANDATA.MCP_REQUEST_ID] == "req-123"
+ assert span["data"]["mcp.request.argument.x"] == "10"
+ assert span["data"]["mcp.request.argument.y"] == "5"
+
+ # Check PII-sensitive data is only present when both flags are True
+ if send_default_pii and include_prompts:
+ assert span["data"][SPANDATA.MCP_TOOL_RESULT_CONTENT] == json.dumps(
+ {
+ "result": "success",
+ "value": 42,
+ }
+ )
+ assert span["data"][SPANDATA.MCP_TOOL_RESULT_CONTENT_COUNT] == 2
+ else:
+ assert SPANDATA.MCP_TOOL_RESULT_CONTENT not in span["data"]
+ assert SPANDATA.MCP_TOOL_RESULT_CONTENT_COUNT not in span["data"]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [(True, True), (True, False), (False, True), (False, False)],
+)
+async def test_tool_handler_streamable_http(
+ sentry_init,
+ capture_events,
+ send_default_pii,
+ include_prompts,
+ json_rpc,
+ select_transactions_with_mcp_spans,
+):
+ """Test that async tool handlers create proper spans"""
+ sentry_init(
+ integrations=[MCPIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ server = Server("test-server")
+
+ session_manager = StreamableHTTPSessionManager(
+ app=server,
+ json_response=True,
+ )
+
+ app = Starlette(
+ routes=[
+ Mount("/mcp", app=session_manager.handle_request),
+ ],
+ lifespan=lambda app: session_manager.run(),
+ )
+
+ @server.call_tool()
+ async def test_tool_async(tool_name, arguments):
+ return [
+ TextContent(
+ type="text",
+ text=json.dumps({"status": "completed"}),
+ )
+ ]
+
+ session_id, result = json_rpc(
+ app,
+ method="tools/call",
+ params={
+ "name": "process",
+ "arguments": {
+ "data": "test",
+ },
+ },
+ request_id="req-456",
+ )
+ assert result.json()["result"]["content"][0]["text"] == json.dumps(
+ {"status": "completed"}
+ )
+
+ transactions = select_transactions_with_mcp_spans(events, method_name="tools/call")
+ assert len(transactions) == 1
+ tx = transactions[0]
+ assert tx["type"] == "transaction"
+ assert len(tx["spans"]) == 1
+ span = tx["spans"][0]
+
+ assert span["op"] == OP.MCP_SERVER
+ assert span["description"] == "tools/call process"
+ assert span["origin"] == "auto.ai.mcp"
+
+ # Check span data
+ assert span["data"][SPANDATA.MCP_TOOL_NAME] == "process"
+ assert span["data"][SPANDATA.MCP_METHOD_NAME] == "tools/call"
+ assert span["data"][SPANDATA.MCP_TRANSPORT] == "http"
+ assert span["data"][SPANDATA.MCP_REQUEST_ID] == "req-456"
+ assert span["data"][SPANDATA.MCP_SESSION_ID] == session_id
+ assert span["data"]["mcp.request.argument.data"] == "test"
+
+ # Check PII-sensitive data
+ if send_default_pii and include_prompts:
+ assert span["data"][SPANDATA.MCP_TOOL_RESULT_CONTENT] == json.dumps(
+ {"status": "completed"}
+ )
+ else:
+ assert SPANDATA.MCP_TOOL_RESULT_CONTENT not in span["data"]
+
+
+@pytest.mark.asyncio
+async def test_tool_handler_with_error(sentry_init, capture_events, stdio):
+ """Test that tool handler errors are captured properly"""
+ sentry_init(
+ integrations=[MCPIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ server = Server("test-server")
+
+ @server.call_tool()
+ def failing_tool(tool_name, arguments):
+ raise ValueError("Tool execution failed")
+
+ with start_transaction(name="mcp tx"):
+ result = await stdio(
+ server,
+ method="tools/call",
+ params={
+ "name": "bad_tool",
+ "arguments": {},
+ },
+ request_id="req-error",
+ )
+
+ assert (
+ result.message.root.result["content"][0]["text"] == "Tool execution failed"
+ )
+
+ # Should have error event and transaction
+ assert len(events) == 2
+ error_event, tx = events
+
+ # Check error event
+ assert error_event["level"] == "error"
+ assert error_event["exception"]["values"][0]["type"] == "ValueError"
+ assert error_event["exception"]["values"][0]["value"] == "Tool execution failed"
+
+ # Check transaction and span
+ assert tx["type"] == "transaction"
+ assert len(tx["spans"]) == 1
+ span = tx["spans"][0]
+
+ # Error flag should be set for tools
+ assert span["data"][SPANDATA.MCP_TOOL_RESULT_IS_ERROR] is True
+ assert span["status"] == "internal_error"
+ assert span["tags"]["status"] == "internal_error"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [(True, True), (True, False), (False, True), (False, False)],
+)
+async def test_prompt_handler_stdio(
+ sentry_init, capture_events, send_default_pii, include_prompts, stdio
+):
+ """Test that synchronous prompt handlers create proper spans"""
+ sentry_init(
+ integrations=[MCPIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ server = Server("test-server")
+
+ @server.get_prompt()
+ async def test_prompt(name, arguments):
+ return GetPromptResult(
+ description="A helpful test prompt",
+ messages=[
+ PromptMessage(
+ role="user",
+ content=TextContent(type="text", text="Tell me about Python"),
+ ),
+ ],
+ )
+
+ with start_transaction(name="mcp tx"):
+ result = await stdio(
+ server,
+ method="prompts/get",
+ params={
+ "name": "code_help",
+ "arguments": {"language": "python"},
+ },
+ request_id="req-prompt",
+ )
+
+ assert result.message.root.result["messages"][0]["role"] == "user"
+ assert (
+ result.message.root.result["messages"][0]["content"]["text"]
+ == "Tell me about Python"
+ )
+
+ (tx,) = events
+ assert tx["type"] == "transaction"
+ assert len(tx["spans"]) == 1
+
+ span = tx["spans"][0]
+ assert span["op"] == OP.MCP_SERVER
+ assert span["description"] == "prompts/get code_help"
+ assert span["origin"] == "auto.ai.mcp"
+
+ # Check span data
+ assert span["data"][SPANDATA.MCP_PROMPT_NAME] == "code_help"
+ assert span["data"][SPANDATA.MCP_METHOD_NAME] == "prompts/get"
+ assert span["data"][SPANDATA.MCP_TRANSPORT] == "stdio"
+ assert span["data"][SPANDATA.MCP_REQUEST_ID] == "req-prompt"
+ assert span["data"]["mcp.request.argument.name"] == "code_help"
+ assert span["data"]["mcp.request.argument.language"] == "python"
+
+ # Message count is always captured
+ assert span["data"][SPANDATA.MCP_PROMPT_RESULT_MESSAGE_COUNT] == 1
+
+ # For single message prompts, role and content should be captured only with PII
+ if send_default_pii and include_prompts:
+ assert span["data"][SPANDATA.MCP_PROMPT_RESULT_MESSAGE_ROLE] == "user"
+ assert (
+ span["data"][SPANDATA.MCP_PROMPT_RESULT_MESSAGE_CONTENT]
+ == "Tell me about Python"
+ )
+ else:
+ assert SPANDATA.MCP_PROMPT_RESULT_MESSAGE_ROLE not in span["data"]
+ assert SPANDATA.MCP_PROMPT_RESULT_MESSAGE_CONTENT not in span["data"]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [(True, True), (True, False), (False, True), (False, False)],
+)
+async def test_prompt_handler_streamable_http(
+ sentry_init,
+ capture_events,
+ send_default_pii,
+ include_prompts,
+ json_rpc,
+ select_transactions_with_mcp_spans,
+):
+ """Test that async prompt handlers create proper spans"""
+ sentry_init(
+ integrations=[MCPIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ server = Server("test-server")
+
+ session_manager = StreamableHTTPSessionManager(
+ app=server,
+ json_response=True,
+ )
+
+ app = Starlette(
+ routes=[
+ Mount("/mcp", app=session_manager.handle_request),
+ ],
+ lifespan=lambda app: session_manager.run(),
+ )
+
+ @server.get_prompt()
+ async def test_prompt_async(name, arguments):
+ return GetPromptResult(
+ description="A helpful test prompt",
+ messages=[
+ PromptMessage(
+ role="user",
+ content=TextContent(
+ type="text", text="You are a helpful assistant"
+ ),
+ ),
+ PromptMessage(
+ role="user", content=TextContent(type="text", text="What is MCP?")
+ ),
+ ],
+ )
+
+ _, result = json_rpc(
+ app,
+ method="prompts/get",
+ params={
+ "name": "mcp_info",
+ "arguments": {},
+ },
+ request_id="req-async-prompt",
+ )
+ assert len(result.json()["result"]["messages"]) == 2
+
+ transactions = select_transactions_with_mcp_spans(events, method_name="prompts/get")
+ assert len(transactions) == 1
+ tx = transactions[0]
+ assert tx["type"] == "transaction"
+ assert len(tx["spans"]) == 1
+ span = tx["spans"][0]
+
+ assert span["op"] == OP.MCP_SERVER
+ assert span["description"] == "prompts/get mcp_info"
+
+ # For multi-message prompts, count is always captured
+ assert span["data"][SPANDATA.MCP_PROMPT_RESULT_MESSAGE_COUNT] == 2
+ # Role/content are never captured for multi-message prompts (even with PII)
+ assert (
+ SPANDATA.MCP_PROMPT_RESULT_MESSAGE_ROLE not in tx["contexts"]["trace"]["data"]
+ )
+ assert (
+ SPANDATA.MCP_PROMPT_RESULT_MESSAGE_CONTENT
+ not in tx["contexts"]["trace"]["data"]
+ )
+
+
+@pytest.mark.asyncio
+async def test_prompt_handler_with_error(sentry_init, capture_events, stdio):
+ """Test that prompt handler errors are captured"""
+ sentry_init(
+ integrations=[MCPIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ server = Server("test-server")
+
+ @server.get_prompt()
+ async def failing_prompt(name, arguments):
+ raise RuntimeError("Prompt not found")
+
+ with start_transaction(name="mcp tx"):
+ response = await stdio(
+ server,
+ method="prompts/get",
+ params={
+ "name": "code_help",
+ "arguments": {"language": "python"},
+ },
+ request_id="req-error-prompt",
+ )
+
+ assert response.message.root.error.message == "Prompt not found"
+
+ # Should have error event and transaction
+ assert len(events) == 2
+ error_event, tx = events
+
+ assert error_event["level"] == "error"
+ assert error_event["exception"]["values"][0]["type"] == "RuntimeError"
+
+
+@pytest.mark.asyncio
+async def test_resource_handler_stdio(sentry_init, capture_events, stdio):
+ """Test that synchronous resource handlers create proper spans"""
+ sentry_init(
+ integrations=[MCPIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ server = Server("test-server")
+
+ @server.read_resource()
+ async def test_resource(uri):
+ return [
+ ReadResourceContents(
+ content=json.dumps({"content": "file contents"}), mime_type="text/plain"
+ )
+ ]
+
+ with start_transaction(name="mcp tx"):
+ result = await stdio(
+ server,
+ method="resources/read",
+ params={
+ "uri": "file:///path/to/file.txt",
+ },
+ request_id="req-resource",
+ )
+
+ assert result.message.root.result["contents"][0]["text"] == json.dumps(
+ {"content": "file contents"},
+ )
+
+ (tx,) = events
+ assert tx["type"] == "transaction"
+ assert len(tx["spans"]) == 1
+
+ span = tx["spans"][0]
+ assert span["op"] == OP.MCP_SERVER
+ assert span["description"] == "resources/read file:///path/to/file.txt"
+ assert span["origin"] == "auto.ai.mcp"
+
+ # Check span data
+ assert span["data"][SPANDATA.MCP_RESOURCE_URI] == "file:///path/to/file.txt"
+ assert span["data"][SPANDATA.MCP_METHOD_NAME] == "resources/read"
+ assert span["data"][SPANDATA.MCP_TRANSPORT] == "stdio"
+ assert span["data"][SPANDATA.MCP_REQUEST_ID] == "req-resource"
+ assert span["data"][SPANDATA.MCP_RESOURCE_PROTOCOL] == "file"
+ # Resources don't capture result content
+ assert SPANDATA.MCP_TOOL_RESULT_CONTENT not in span["data"]
+
+
+@pytest.mark.asyncio
+async def test_resource_handler_streamble_http(
+ sentry_init,
+ capture_events,
+ json_rpc,
+ select_transactions_with_mcp_spans,
+):
+ """Test that async resource handlers create proper spans"""
+ sentry_init(
+ integrations=[MCPIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ server = Server("test-server")
+
+ session_manager = StreamableHTTPSessionManager(
+ app=server,
+ json_response=True,
+ )
+
+ app = Starlette(
+ routes=[
+ Mount("/mcp", app=session_manager.handle_request),
+ ],
+ lifespan=lambda app: session_manager.run(),
+ )
+
+ @server.read_resource()
+ async def test_resource_async(uri):
+ return [
+ ReadResourceContents(
+ content=json.dumps({"data": "resource data"}), mime_type="text/plain"
+ )
+ ]
+
+ session_id, result = json_rpc(
+ app,
+ method="resources/read",
+ params={
+ "uri": "https://example.com/resource",
+ },
+ request_id="req-async-resource",
+ )
+
+ assert result.json()["result"]["contents"][0]["text"] == json.dumps(
+ {"data": "resource data"}
+ )
+
+ transactions = select_transactions_with_mcp_spans(
+ events, method_name="resources/read"
+ )
+ assert len(transactions) == 1
+ tx = transactions[0]
+ assert tx["type"] == "transaction"
+ assert len(tx["spans"]) == 1
+ span = tx["spans"][0]
+
+ assert span["op"] == OP.MCP_SERVER
+ assert span["description"] == "resources/read https://example.com/resource"
+
+ assert span["data"][SPANDATA.MCP_RESOURCE_URI] == "https://example.com/resource"
+ assert span["data"][SPANDATA.MCP_RESOURCE_PROTOCOL] == "https"
+ assert span["data"][SPANDATA.MCP_SESSION_ID] == session_id
+
+
+@pytest.mark.asyncio
+async def test_resource_handler_with_error(sentry_init, capture_events, stdio):
+ """Test that resource handler errors are captured"""
+ sentry_init(
+ integrations=[MCPIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ server = Server("test-server")
+
+ @server.read_resource()
+ def failing_resource(uri):
+ raise FileNotFoundError("Resource not found")
+
+ with start_transaction(name="mcp tx"):
+ await stdio(
+ server,
+ method="resources/read",
+ params={
+ "uri": "file:///missing.txt",
+ },
+ request_id="req-error-resource",
+ )
+
+ # Should have error event and transaction
+ assert len(events) == 2
+ error_event, tx = events
+
+ assert error_event["level"] == "error"
+ assert error_event["exception"]["values"][0]["type"] == "FileNotFoundError"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [(True, True), (False, False)],
+)
+async def test_tool_result_extraction_tuple(
+ sentry_init, capture_events, send_default_pii, include_prompts, stdio
+):
+ """Test extraction of tool results from tuple format (UnstructuredContent, StructuredContent)"""
+ sentry_init(
+ integrations=[MCPIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ server = Server("test-server")
+
+ @server.call_tool()
+ def test_tool_tuple(tool_name, arguments):
+ # Return CombinationContent: (UnstructuredContent, StructuredContent)
+ unstructured = [MockTextContent("Result text")]
+ structured = {"key": "value", "count": 5}
+ return (unstructured, structured)
+
+ with start_transaction(name="mcp tx"):
+ await stdio(
+ server,
+ method="tools/call",
+ params={
+ "name": "calculate",
+ "arguments": {},
+ },
+ request_id="req-tuple",
+ )
+
+ (tx,) = events
+ span = tx["spans"][0]
+
+ # Should extract the structured content (second element of tuple) only with PII
+ if send_default_pii and include_prompts:
+ assert span["data"][SPANDATA.MCP_TOOL_RESULT_CONTENT] == json.dumps(
+ {
+ "key": "value",
+ "count": 5,
+ }
+ )
+ assert span["data"][SPANDATA.MCP_TOOL_RESULT_CONTENT_COUNT] == 2
+ else:
+ assert SPANDATA.MCP_TOOL_RESULT_CONTENT not in span["data"]
+ assert SPANDATA.MCP_TOOL_RESULT_CONTENT_COUNT not in span["data"]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [(True, True), (False, False)],
+)
+async def test_tool_result_extraction_unstructured(
+ sentry_init, capture_events, send_default_pii, include_prompts, stdio
+):
+ """Test extraction of tool results from UnstructuredContent (list of content blocks)"""
+ sentry_init(
+ integrations=[MCPIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ server = Server("test-server")
+
+ @server.call_tool()
+ def test_tool_unstructured(tool_name, arguments):
+ # Return UnstructuredContent as list of content blocks
+ return [
+ MockTextContent("First part"),
+ MockTextContent("Second part"),
+ ]
+
+ with start_transaction(name="mcp tx"):
+ await stdio(
+ server,
+ method="tools/call",
+ params={
+ "name": "text_tool",
+ "arguments": {},
+ },
+ request_id="req-unstructured",
+ )
+
+ (tx,) = events
+ span = tx["spans"][0]
+
+ # Should extract and join text from content blocks only with PII
+ if send_default_pii and include_prompts:
+ assert (
+ span["data"][SPANDATA.MCP_TOOL_RESULT_CONTENT] == "First part Second part"
+ )
+ else:
+ assert SPANDATA.MCP_TOOL_RESULT_CONTENT not in span["data"]
+
+
+@pytest.mark.asyncio
+async def test_span_origin(sentry_init, capture_events, stdio):
+ """Test that span origin is set correctly"""
+ sentry_init(
+ integrations=[MCPIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ server = Server("test-server")
+
+ @server.call_tool()
+ def test_tool(tool_name, arguments):
+ return {"result": "test"}
+
+ with start_transaction(name="mcp tx"):
+ await stdio(
+ server,
+ method="tools/call",
+ params={
+ "name": "calculate",
+ "arguments": {"x": 10, "y": 5},
+ },
+ request_id="req-origin",
+ )
+
+ (tx,) = events
+
+ assert tx["contexts"]["trace"]["origin"] == "manual"
+ assert tx["spans"][0]["origin"] == "auto.ai.mcp"
+
+
+@pytest.mark.asyncio
+async def test_multiple_handlers(sentry_init, capture_events, stdio):
+ """Test that multiple handler calls create multiple spans"""
+ sentry_init(
+ integrations=[MCPIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ server = Server("test-server")
+
+ @server.call_tool()
+ def tool1(tool_name, arguments):
+ return {"result": "tool1"}
+
+ @server.call_tool()
+ def tool2(tool_name, arguments):
+ return {"result": "tool2"}
+
+ @server.get_prompt()
+ def prompt1(name, arguments):
+ return GetPromptResult(
+ description="A test prompt",
+ messages=[
+ PromptMessage(
+ role="user", content=TextContent(type="text", text="Test prompt")
+ )
+ ],
+ )
+
+ with start_transaction(name="mcp tx"):
+ await stdio(
+ server,
+ method="tools/call",
+ params={
+ "name": "tool_a",
+ "arguments": {},
+ },
+ request_id="req-multi",
+ )
+
+ await stdio(
+ server,
+ method="tools/call",
+ params={
+ "name": "tool_b",
+ "arguments": {},
+ },
+ request_id="req-multi",
+ )
+
+ await stdio(
+ server,
+ method="prompts/get",
+ params={
+ "name": "prompt_a",
+ "arguments": {},
+ },
+ request_id="req-multi",
+ )
+
+ (tx,) = events
+ assert tx["type"] == "transaction"
+ assert len(tx["spans"]) == 3
+
+ # Check that we have different span types
+ span_ops = [span["op"] for span in tx["spans"]]
+ assert all(op == OP.MCP_SERVER for op in span_ops)
+
+ span_descriptions = [span["description"] for span in tx["spans"]]
+ assert "tools/call tool_a" in span_descriptions
+ assert "tools/call tool_b" in span_descriptions
+ assert "prompts/get prompt_a" in span_descriptions
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [(True, True), (False, False)],
+)
+async def test_prompt_with_dict_result(
+ sentry_init, capture_events, send_default_pii, include_prompts, stdio
+):
+ """Test prompt handler with dict result instead of GetPromptResult object"""
+ sentry_init(
+ integrations=[MCPIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ server = Server("test-server")
+
+ @server.get_prompt()
+ def test_prompt_dict(name, arguments):
+ # Return dict format instead of GetPromptResult object
+ return {
+ "messages": [
+ {"role": "user", "content": {"text": "Hello from dict"}},
+ ]
+ }
+
+ with start_transaction(name="mcp tx"):
+ await stdio(
+ server,
+ method="prompts/get",
+ params={
+ "name": "dict_prompt",
+ "arguments": {},
+ },
+ request_id="req-dict-prompt",
+ )
+
+ (tx,) = events
+ span = tx["spans"][0]
+
+ # Message count is always captured
+ assert span["data"][SPANDATA.MCP_PROMPT_RESULT_MESSAGE_COUNT] == 1
+
+ # Role and content only captured with PII
+ if send_default_pii and include_prompts:
+ assert span["data"][SPANDATA.MCP_PROMPT_RESULT_MESSAGE_ROLE] == "user"
+ assert (
+ span["data"][SPANDATA.MCP_PROMPT_RESULT_MESSAGE_CONTENT]
+ == "Hello from dict"
+ )
+ else:
+ assert SPANDATA.MCP_PROMPT_RESULT_MESSAGE_ROLE not in span["data"]
+ assert SPANDATA.MCP_PROMPT_RESULT_MESSAGE_CONTENT not in span["data"]
+
+
+@pytest.mark.asyncio
+async def test_tool_with_complex_arguments(sentry_init, capture_events, stdio):
+ """Test tool handler with complex nested arguments"""
+ sentry_init(
+ integrations=[MCPIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ server = Server("test-server")
+
+ @server.call_tool()
+ def test_tool_complex(tool_name, arguments):
+ return {"processed": True}
+
+ with start_transaction(name="mcp tx"):
+ complex_args = {
+ "nested": {"key": "value", "list": [1, 2, 3]},
+ "string": "test",
+ "number": 42,
+ }
+ await stdio(
+ server,
+ method="tools/call",
+ params={
+ "name": "complex_tool",
+ "arguments": complex_args,
+ },
+ request_id="req-complex",
+ )
+
+ (tx,) = events
+ span = tx["spans"][0]
+
+ # Complex arguments should be serialized
+ assert span["data"]["mcp.request.argument.nested"] == json.dumps(
+ {"key": "value", "list": [1, 2, 3]}
+ )
+ assert span["data"]["mcp.request.argument.string"] == "test"
+ assert span["data"]["mcp.request.argument.number"] == "42"
+
+
+@pytest.mark.asyncio
+async def test_sse_transport_detection(sentry_init, capture_events, json_rpc_sse):
+ """Test that SSE transport is correctly detected via query parameter"""
+ sentry_init(
+ integrations=[MCPIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ server = Server("test-server")
+ sse = SseServerTransport("/messages/")
+
+ sse_connection_closed = asyncio.Event()
+
+ async def handle_sse(request):
+ async with sse.connect_sse(
+ request.scope, request.receive, request._send
+ ) as streams:
+ async with anyio.create_task_group() as tg:
+
+ async def run_server():
+ await server.run(
+ streams[0], streams[1], server.create_initialization_options()
+ )
+
+ tg.start_soon(run_server)
+
+ sse_connection_closed.set()
+ return Response()
+
+ app = Starlette(
+ routes=[
+ Route("/sse", endpoint=handle_sse, methods=["GET"]),
+ Mount("/messages/", app=sse.handle_post_message),
+ ],
+ )
+
+ @server.call_tool()
+ async def test_tool(tool_name, arguments):
+ return {"result": "success"}
+
+ keep_sse_alive = asyncio.Event()
+ app_task, session_id, result = await json_rpc_sse(
+ app,
+ method="tools/call",
+ params={
+ "name": "sse_tool",
+ "arguments": {},
+ },
+ request_id="req-sse",
+ keep_sse_alive=keep_sse_alive,
+ )
+
+ await sse_connection_closed.wait()
+ await app_task
+
+ assert result["result"]["structuredContent"] == {"result": "success"}
+
+ transactions = [
+ event
+ for event in events
+ if event["type"] == "transaction" and event["transaction"] == "/sse"
+ ]
+ assert len(transactions) == 1
+ tx = transactions[0]
+ span = tx["spans"][0]
+
+ # Check that SSE transport is detected
+ assert span["data"][SPANDATA.MCP_TRANSPORT] == "sse"
+ assert span["data"][SPANDATA.NETWORK_TRANSPORT] == "tcp"
+ assert span["data"][SPANDATA.MCP_SESSION_ID] == session_id
+
+
+def test_streamable_http_transport_detection(
+ sentry_init,
+ capture_events,
+ json_rpc,
+ select_transactions_with_mcp_spans,
+):
+ """Test that StreamableHTTP transport is correctly detected via header"""
+ sentry_init(
+ integrations=[MCPIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ server = Server("test-server")
+
+ session_manager = StreamableHTTPSessionManager(
+ app=server,
+ json_response=True,
+ )
+
+ app = Starlette(
+ routes=[
+ Mount("/mcp", app=session_manager.handle_request),
+ ],
+ lifespan=lambda app: session_manager.run(),
+ )
+
+ @server.call_tool()
+ async def test_tool(tool_name, arguments):
+ return [
+ TextContent(
+ type="text",
+ text=json.dumps({"status": "success"}),
+ )
+ ]
+
+ session_id, result = json_rpc(
+ app,
+ method="tools/call",
+ params={
+ "name": "http_tool",
+ "arguments": {},
+ },
+ request_id="req-http",
+ )
+ assert result.json()["result"]["content"][0]["text"] == json.dumps(
+ {"status": "success"}
+ )
+
+ transactions = select_transactions_with_mcp_spans(events, method_name="tools/call")
+ assert len(transactions) == 1
+ tx = transactions[0]
+ assert tx["type"] == "transaction"
+ assert len(tx["spans"]) == 1
+ span = tx["spans"][0]
+
+ # Check that HTTP transport is detected
+ assert span["data"][SPANDATA.MCP_TRANSPORT] == "http"
+ assert span["data"][SPANDATA.NETWORK_TRANSPORT] == "tcp"
+ assert span["data"][SPANDATA.MCP_SESSION_ID] == session_id
+
+
+@pytest.mark.asyncio
+async def test_stdio_transport_detection(sentry_init, capture_events, stdio):
+ """Test that stdio transport is correctly detected when no HTTP request"""
+ sentry_init(
+ integrations=[MCPIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ server = Server("test-server")
+
+ @server.call_tool()
+ async def test_tool(tool_name, arguments):
+ return {"result": "success"}
+
+ with start_transaction(name="mcp tx"):
+ result = await stdio(
+ server,
+ method="tools/call",
+ params={
+ "name": "stdio_tool",
+ "arguments": {},
+ },
+ request_id="req-stdio",
+ )
+
+ assert result.message.root.result["structuredContent"] == {"result": "success"}
+
+ (tx,) = events
+ span = tx["spans"][0]
+
+ # Check that stdio transport is detected
+ assert span["data"][SPANDATA.MCP_TRANSPORT] == "stdio"
+ assert span["data"][SPANDATA.NETWORK_TRANSPORT] == "pipe"
+ # No session ID for stdio transport
+ assert SPANDATA.MCP_SESSION_ID not in span["data"]
diff --git a/tests/integrations/openai/__init__.py b/tests/integrations/openai/__init__.py
new file mode 100644
index 0000000000..d6cc3d5505
--- /dev/null
+++ b/tests/integrations/openai/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("openai")
diff --git a/tests/integrations/openai/test_openai.py b/tests/integrations/openai/test_openai.py
new file mode 100644
index 0000000000..ada2e633de
--- /dev/null
+++ b/tests/integrations/openai/test_openai.py
@@ -0,0 +1,3997 @@
+import json
+import pytest
+
+from sentry_sdk.utils import package_version
+
+try:
+ from openai import NOT_GIVEN
+except ImportError:
+ NOT_GIVEN = None
+try:
+ from openai import omit
+ from openai import Omit
+except ImportError:
+ omit = None
+ Omit = None
+
+from openai import AsyncOpenAI, OpenAI, AsyncStream, Stream, OpenAIError
+from openai.types import CompletionUsage, CreateEmbeddingResponse, Embedding
+from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionChunk
+from openai.types.chat.chat_completion import Choice
+from openai.types.chat.chat_completion_chunk import ChoiceDelta, Choice as DeltaChoice
+from openai.types.create_embedding_response import Usage as EmbeddingTokenUsage
+
+SKIP_RESPONSES_TESTS = False
+
+try:
+ from openai.types.responses.response_completed_event import ResponseCompletedEvent
+ from openai.types.responses.response_created_event import ResponseCreatedEvent
+ from openai.types.responses.response_text_delta_event import ResponseTextDeltaEvent
+ from openai.types.responses.response_usage import (
+ InputTokensDetails,
+ OutputTokensDetails,
+ )
+ from openai.types.responses import (
+ Response,
+ ResponseUsage,
+ ResponseOutputMessage,
+ ResponseOutputText,
+ )
+except ImportError:
+ SKIP_RESPONSES_TESTS = True
+
+from sentry_sdk import start_transaction
+from sentry_sdk.consts import SPANDATA, OP
+from sentry_sdk.integrations.openai import (
+ OpenAIIntegration,
+ _calculate_completions_token_usage,
+ _calculate_responses_token_usage,
+)
+from sentry_sdk.utils import safe_serialize
+
+from unittest import mock # python 3.3 and above
+
+try:
+ from unittest.mock import AsyncMock
+except ImportError:
+
+ class AsyncMock(mock.MagicMock):
+ async def __call__(self, *args, **kwargs):
+ return super(AsyncMock, self).__call__(*args, **kwargs)
+
+
+OPENAI_VERSION = package_version("openai")
+EXAMPLE_CHAT_COMPLETION = ChatCompletion(
+ id="chat-id",
+ choices=[
+ Choice(
+ index=0,
+ finish_reason="stop",
+ message=ChatCompletionMessage(
+ role="assistant", content="the model response"
+ ),
+ )
+ ],
+ created=10000000,
+ model="response-model-id",
+ object="chat.completion",
+ usage=CompletionUsage(
+ completion_tokens=10,
+ prompt_tokens=20,
+ total_tokens=30,
+ ),
+)
+
+
+if SKIP_RESPONSES_TESTS:
+ EXAMPLE_RESPONSE = None
+else:
+ EXAMPLE_RESPONSE = Response(
+ id="chat-id",
+ output=[
+ ResponseOutputMessage(
+ id="message-id",
+ content=[
+ ResponseOutputText(
+ annotations=[],
+ text="the model response",
+ type="output_text",
+ ),
+ ],
+ role="assistant",
+ status="completed",
+ type="message",
+ ),
+ ],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model="response-model-id",
+ object="response",
+ usage=ResponseUsage(
+ input_tokens=20,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=5,
+ ),
+ output_tokens=10,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=8,
+ ),
+ total_tokens=30,
+ ),
+ )
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+def test_nonstreaming_chat_completion_no_prompts(
+ sentry_init, capture_events, send_default_pii, include_prompts
+):
+ sentry_init(
+ integrations=[OpenAIIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ client = OpenAI(api_key="z")
+ client.chat.completions._post = mock.Mock(return_value=EXAMPLE_CHAT_COMPLETION)
+
+ with start_transaction(name="openai tx"):
+ response = (
+ client.chat.completions.create(
+ model="some-model",
+ messages=[
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "hello"},
+ ],
+ max_tokens=100,
+ presence_penalty=0.1,
+ frequency_penalty=0.2,
+ temperature=0.7,
+ top_p=0.9,
+ )
+ .choices[0]
+ .message.content
+ )
+
+ assert response == "the model response"
+ tx = events[0]
+ assert tx["type"] == "transaction"
+ span = tx["spans"][0]
+ assert span["op"] == "gen_ai.chat"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "openai"
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is False
+
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "some-model"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MAX_TOKENS] == 100
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY] == 0.1
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY] == 0.2
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_TEMPERATURE] == 0.7
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_TOP_P] == 0.9
+
+ assert SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS not in span["data"]
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]
+
+ assert span["data"]["gen_ai.usage.output_tokens"] == 10
+ assert span["data"]["gen_ai.usage.input_tokens"] == 20
+ assert span["data"]["gen_ai.usage.total_tokens"] == 30
+
+
+@pytest.mark.parametrize(
+ "messages",
+ [
+ pytest.param(
+ [
+ {
+ "role": "system",
+ "content": "You are a helpful assistant.",
+ },
+ {"role": "user", "content": "hello"},
+ ],
+ id="blocks",
+ ),
+ pytest.param(
+ [
+ {
+ "role": "system",
+ "content": [
+ {"type": "text", "text": "You are a helpful assistant."},
+ {"type": "text", "text": "Be concise and clear."},
+ ],
+ },
+ {"role": "user", "content": "hello"},
+ ],
+ id="parts",
+ ),
+ pytest.param(
+ iter(
+ [
+ {
+ "role": "system",
+ "content": [
+ {"type": "text", "text": "You are a helpful assistant."},
+ {"type": "text", "text": "Be concise and clear."},
+ ],
+ },
+ {"role": "user", "content": "hello"},
+ ]
+ ),
+ id="iterator",
+ ),
+ ],
+)
+def test_nonstreaming_chat_completion(sentry_init, capture_events, messages, request):
+ sentry_init(
+ integrations=[OpenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ client = OpenAI(api_key="z")
+ client.chat.completions._post = mock.Mock(return_value=EXAMPLE_CHAT_COMPLETION)
+
+ with start_transaction(name="openai tx"):
+ response = (
+ client.chat.completions.create(
+ model="some-model",
+ messages=messages,
+ max_tokens=100,
+ presence_penalty=0.1,
+ frequency_penalty=0.2,
+ temperature=0.7,
+ top_p=0.9,
+ )
+ .choices[0]
+ .message.content
+ )
+
+ assert response == "the model response"
+ tx = events[0]
+ assert tx["type"] == "transaction"
+ span = tx["spans"][0]
+ assert span["op"] == "gen_ai.chat"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "openai"
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is False
+
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "some-model"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MAX_TOKENS] == 100
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY] == 0.1
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY] == 0.2
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_TEMPERATURE] == 0.7
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_TOP_P] == 0.9
+
+ param_id = request.node.callspec.id
+ if "blocks" in param_id:
+ assert json.loads(span["data"][SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS]) == [
+ {
+ "type": "text",
+ "content": "You are a helpful assistant.",
+ }
+ ]
+ else:
+ assert json.loads(span["data"][SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS]) == [
+ {
+ "type": "text",
+ "content": "You are a helpful assistant.",
+ },
+ {
+ "type": "text",
+ "content": "Be concise and clear.",
+ },
+ ]
+
+ assert "hello" in span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ assert "the model response" in span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT]
+
+ assert span["data"]["gen_ai.usage.output_tokens"] == 10
+ assert span["data"]["gen_ai.usage.input_tokens"] == 20
+ assert span["data"]["gen_ai.usage.total_tokens"] == 30
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+async def test_nonstreaming_chat_completion_async_no_prompts(
+ sentry_init, capture_events, send_default_pii, include_prompts
+):
+ sentry_init(
+ integrations=[OpenAIIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ client = AsyncOpenAI(api_key="z")
+ client.chat.completions._post = mock.AsyncMock(return_value=EXAMPLE_CHAT_COMPLETION)
+
+ with start_transaction(name="openai tx"):
+ response = await client.chat.completions.create(
+ model="some-model",
+ messages=[
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "hello"},
+ ],
+ max_tokens=100,
+ presence_penalty=0.1,
+ frequency_penalty=0.2,
+ temperature=0.7,
+ top_p=0.9,
+ )
+ response = response.choices[0].message.content
+
+ assert response == "the model response"
+ tx = events[0]
+ assert tx["type"] == "transaction"
+ span = tx["spans"][0]
+ assert span["op"] == "gen_ai.chat"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "openai"
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is False
+
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "some-model"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MAX_TOKENS] == 100
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY] == 0.1
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY] == 0.2
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_TEMPERATURE] == 0.7
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_TOP_P] == 0.9
+
+ assert SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS not in span["data"]
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]
+
+ assert span["data"]["gen_ai.usage.output_tokens"] == 10
+ assert span["data"]["gen_ai.usage.input_tokens"] == 20
+ assert span["data"]["gen_ai.usage.total_tokens"] == 30
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "messages",
+ [
+ pytest.param(
+ [
+ {
+ "role": "system",
+ "content": "You are a helpful assistant.",
+ },
+ {"role": "user", "content": "hello"},
+ ],
+ id="blocks",
+ ),
+ pytest.param(
+ [
+ {
+ "role": "system",
+ "content": [
+ {"type": "text", "text": "You are a helpful assistant."},
+ {"type": "text", "text": "Be concise and clear."},
+ ],
+ },
+ {"role": "user", "content": "hello"},
+ ],
+ id="parts",
+ ),
+ pytest.param(
+ iter(
+ [
+ {
+ "role": "system",
+ "content": [
+ {"type": "text", "text": "You are a helpful assistant."},
+ {"type": "text", "text": "Be concise and clear."},
+ ],
+ },
+ {"role": "user", "content": "hello"},
+ ]
+ ),
+ id="iterator",
+ ),
+ ],
+)
+async def test_nonstreaming_chat_completion_async(
+ sentry_init, capture_events, messages, request
+):
+ sentry_init(
+ integrations=[OpenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ client = AsyncOpenAI(api_key="z")
+ client.chat.completions._post = AsyncMock(return_value=EXAMPLE_CHAT_COMPLETION)
+
+ with start_transaction(name="openai tx"):
+ response = await client.chat.completions.create(
+ model="some-model",
+ messages=messages,
+ max_tokens=100,
+ presence_penalty=0.1,
+ frequency_penalty=0.2,
+ temperature=0.7,
+ top_p=0.9,
+ )
+ response = response.choices[0].message.content
+
+ assert response == "the model response"
+ tx = events[0]
+ assert tx["type"] == "transaction"
+ span = tx["spans"][0]
+ assert span["op"] == "gen_ai.chat"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "openai"
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is False
+
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "some-model"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MAX_TOKENS] == 100
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY] == 0.1
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY] == 0.2
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_TEMPERATURE] == 0.7
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_TOP_P] == 0.9
+
+ param_id = request.node.callspec.id
+ if "blocks" in param_id:
+ assert json.loads(span["data"][SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS]) == [
+ {
+ "type": "text",
+ "content": "You are a helpful assistant.",
+ }
+ ]
+ else:
+ assert json.loads(span["data"][SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS]) == [
+ {
+ "type": "text",
+ "content": "You are a helpful assistant.",
+ },
+ {
+ "type": "text",
+ "content": "Be concise and clear.",
+ },
+ ]
+
+ assert "hello" in span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ assert "the model response" in span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT]
+
+ assert span["data"]["gen_ai.usage.output_tokens"] == 10
+ assert span["data"]["gen_ai.usage.input_tokens"] == 20
+ assert span["data"]["gen_ai.usage.total_tokens"] == 30
+
+
+def tiktoken_encoding_if_installed():
+ try:
+ import tiktoken # type: ignore # noqa # pylint: disable=unused-import
+
+ return "cl100k_base"
+ except ImportError:
+ return None
+
+
+# noinspection PyTypeChecker
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+def test_streaming_chat_completion_no_prompts(
+ sentry_init,
+ capture_events,
+ send_default_pii,
+ include_prompts,
+ get_model_response,
+ server_side_event_chunks,
+):
+ sentry_init(
+ integrations=[
+ OpenAIIntegration(
+ include_prompts=include_prompts,
+ tiktoken_encoding_name=tiktoken_encoding_if_installed(),
+ )
+ ],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ client = OpenAI(api_key="z")
+ returned_stream = get_model_response(
+ server_side_event_chunks(
+ [
+ ChatCompletionChunk(
+ id="1",
+ choices=[
+ DeltaChoice(
+ index=0,
+ delta=ChoiceDelta(content="hel"),
+ finish_reason=None,
+ )
+ ],
+ created=100000,
+ model="model-id",
+ object="chat.completion.chunk",
+ ),
+ ChatCompletionChunk(
+ id="1",
+ choices=[
+ DeltaChoice(
+ index=1,
+ delta=ChoiceDelta(content="lo "),
+ finish_reason=None,
+ )
+ ],
+ created=100000,
+ model="model-id",
+ object="chat.completion.chunk",
+ ),
+ ChatCompletionChunk(
+ id="1",
+ choices=[
+ DeltaChoice(
+ index=2,
+ delta=ChoiceDelta(content="world"),
+ finish_reason="stop",
+ )
+ ],
+ created=100000,
+ model="model-id",
+ object="chat.completion.chunk",
+ ),
+ ],
+ include_event_type=False,
+ )
+ )
+
+ with mock.patch.object(
+ client.chat._client._client,
+ "send",
+ return_value=returned_stream,
+ ):
+ with start_transaction(name="openai tx"):
+ response_stream = client.chat.completions.create(
+ model="some-model",
+ messages=[
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "hello"},
+ ],
+ stream=True,
+ max_tokens=100,
+ presence_penalty=0.1,
+ frequency_penalty=0.2,
+ temperature=0.7,
+ top_p=0.9,
+ )
+ response_string = "".join(
+ map(lambda x: x.choices[0].delta.content, response_stream)
+ )
+
+ assert response_string == "hello world"
+ tx = events[0]
+ assert tx["type"] == "transaction"
+ span = tx["spans"][0]
+ assert span["op"] == "gen_ai.chat"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "openai"
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
+
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "some-model"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MAX_TOKENS] == 100
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY] == 0.1
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY] == 0.2
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_TEMPERATURE] == 0.7
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_TOP_P] == 0.9
+
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_MODEL] == "model-id"
+
+ assert SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS not in span["data"]
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]
+
+ try:
+ import tiktoken # type: ignore # noqa # pylint: disable=unused-import
+
+ assert span["data"]["gen_ai.usage.output_tokens"] == 2
+ assert span["data"]["gen_ai.usage.input_tokens"] == 7
+ assert span["data"]["gen_ai.usage.total_tokens"] == 9
+ except ImportError:
+ pass # if tiktoken is not installed, we can't guarantee token usage will be calculated properly
+
+
+@pytest.mark.skipif(
+ OPENAI_VERSION <= (1, 1, 0),
+ reason="OpenAI versions <=1.1.0 do not support the stream_options parameter.",
+)
+def test_streaming_chat_completion_with_usage_in_stream(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ server_side_event_chunks,
+):
+ """When stream_options=include_usage is set, token usage comes from the final chunk's usage field."""
+ sentry_init(
+ integrations=[OpenAIIntegration(include_prompts=False)],
+ traces_sample_rate=1.0,
+ send_default_pii=False,
+ )
+ events = capture_events()
+
+ client = OpenAI(api_key="z")
+ returned_stream = get_model_response(
+ server_side_event_chunks(
+ [
+ ChatCompletionChunk(
+ id="1",
+ choices=[
+ DeltaChoice(
+ index=0,
+ delta=ChoiceDelta(content="hel"),
+ finish_reason=None,
+ )
+ ],
+ created=100000,
+ model="model-id",
+ object="chat.completion.chunk",
+ ),
+ ChatCompletionChunk(
+ id="1",
+ choices=[
+ DeltaChoice(
+ index=0,
+ delta=ChoiceDelta(content="lo"),
+ finish_reason="stop",
+ )
+ ],
+ created=100000,
+ model="model-id",
+ object="chat.completion.chunk",
+ usage=CompletionUsage(
+ prompt_tokens=20,
+ completion_tokens=10,
+ total_tokens=30,
+ ),
+ ),
+ ],
+ include_event_type=False,
+ )
+ )
+
+ with mock.patch.object(
+ client.chat._client._client,
+ "send",
+ return_value=returned_stream,
+ ):
+ with start_transaction(name="openai tx"):
+ response_stream = client.chat.completions.create(
+ model="some-model",
+ messages=[{"role": "user", "content": "hello"}],
+ stream=True,
+ stream_options={"include_usage": True},
+ )
+ for _ in response_stream:
+ pass
+
+ tx = events[0]
+ assert tx["type"] == "transaction"
+ span = tx["spans"][0]
+ assert span["op"] == "gen_ai.chat"
+ assert span["data"]["gen_ai.usage.input_tokens"] == 20
+ assert span["data"]["gen_ai.usage.output_tokens"] == 10
+ assert span["data"]["gen_ai.usage.total_tokens"] == 30
+
+
+@pytest.mark.skipif(
+ OPENAI_VERSION <= (1, 1, 0),
+ reason="OpenAI versions <=1.1.0 do not support the stream_options parameter.",
+)
+def test_streaming_chat_completion_empty_content_preserves_token_usage(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ server_side_event_chunks,
+):
+ """Token usage from the stream is recorded even when no content is produced (e.g. content filter)."""
+ sentry_init(
+ integrations=[OpenAIIntegration(include_prompts=False)],
+ traces_sample_rate=1.0,
+ send_default_pii=False,
+ )
+ events = capture_events()
+
+ client = OpenAI(api_key="z")
+ returned_stream = get_model_response(
+ server_side_event_chunks(
+ [
+ ChatCompletionChunk(
+ id="1",
+ choices=[],
+ created=100000,
+ model="model-id",
+ object="chat.completion.chunk",
+ usage=CompletionUsage(
+ prompt_tokens=20,
+ completion_tokens=0,
+ total_tokens=20,
+ ),
+ ),
+ ],
+ include_event_type=False,
+ )
+ )
+
+ with mock.patch.object(
+ client.chat._client._client,
+ "send",
+ return_value=returned_stream,
+ ):
+ with start_transaction(name="openai tx"):
+ response_stream = client.chat.completions.create(
+ model="some-model",
+ messages=[{"role": "user", "content": "hello"}],
+ stream=True,
+ stream_options={"include_usage": True},
+ )
+ for _ in response_stream:
+ pass
+
+ tx = events[0]
+ assert tx["type"] == "transaction"
+ span = tx["spans"][0]
+ assert span["op"] == "gen_ai.chat"
+ assert span["data"]["gen_ai.usage.input_tokens"] == 20
+ assert "gen_ai.usage.output_tokens" not in span["data"]
+ assert span["data"]["gen_ai.usage.total_tokens"] == 20
+
+
+@pytest.mark.skipif(
+ OPENAI_VERSION <= (1, 1, 0),
+ reason="OpenAI versions <=1.1.0 do not support the stream_options parameter.",
+)
+@pytest.mark.asyncio
+async def test_streaming_chat_completion_empty_content_preserves_token_usage_async(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ async_iterator,
+ server_side_event_chunks,
+):
+ """Token usage from the stream is recorded even when no content is produced - async variant."""
+ sentry_init(
+ integrations=[OpenAIIntegration(include_prompts=False)],
+ traces_sample_rate=1.0,
+ send_default_pii=False,
+ )
+ events = capture_events()
+
+ client = AsyncOpenAI(api_key="z")
+ returned_stream = get_model_response(
+ async_iterator(
+ server_side_event_chunks(
+ [
+ ChatCompletionChunk(
+ id="1",
+ choices=[],
+ created=100000,
+ model="model-id",
+ object="chat.completion.chunk",
+ usage=CompletionUsage(
+ prompt_tokens=20,
+ completion_tokens=0,
+ total_tokens=20,
+ ),
+ ),
+ ],
+ include_event_type=False,
+ )
+ )
+ )
+
+ with mock.patch.object(
+ client.chat._client._client,
+ "send",
+ return_value=returned_stream,
+ ):
+ with start_transaction(name="openai tx"):
+ response_stream = await client.chat.completions.create(
+ model="some-model",
+ messages=[{"role": "user", "content": "hello"}],
+ stream=True,
+ stream_options={"include_usage": True},
+ )
+ async for _ in response_stream:
+ pass
+
+ tx = events[0]
+ assert tx["type"] == "transaction"
+ span = tx["spans"][0]
+ assert span["op"] == "gen_ai.chat"
+ assert span["data"]["gen_ai.usage.input_tokens"] == 20
+ assert "gen_ai.usage.output_tokens" not in span["data"]
+ assert span["data"]["gen_ai.usage.total_tokens"] == 20
+
+
+@pytest.mark.skipif(
+ OPENAI_VERSION <= (1, 1, 0),
+ reason="OpenAI versions <=1.1.0 do not support the stream_options parameter.",
+)
+@pytest.mark.asyncio
+async def test_streaming_chat_completion_async_with_usage_in_stream(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ async_iterator,
+ server_side_event_chunks,
+):
+ """When stream_options=include_usage is set, token usage comes from the final chunk's usage field (async)."""
+ sentry_init(
+ integrations=[OpenAIIntegration(include_prompts=False)],
+ traces_sample_rate=1.0,
+ send_default_pii=False,
+ )
+ events = capture_events()
+
+ client = AsyncOpenAI(api_key="z")
+ returned_stream = get_model_response(
+ async_iterator(
+ server_side_event_chunks(
+ [
+ ChatCompletionChunk(
+ id="1",
+ choices=[
+ DeltaChoice(
+ index=0,
+ delta=ChoiceDelta(content="hel"),
+ finish_reason=None,
+ )
+ ],
+ created=100000,
+ model="model-id",
+ object="chat.completion.chunk",
+ ),
+ ChatCompletionChunk(
+ id="1",
+ choices=[
+ DeltaChoice(
+ index=0,
+ delta=ChoiceDelta(content="lo"),
+ finish_reason="stop",
+ )
+ ],
+ created=100000,
+ model="model-id",
+ object="chat.completion.chunk",
+ usage=CompletionUsage(
+ prompt_tokens=20,
+ completion_tokens=10,
+ total_tokens=30,
+ ),
+ ),
+ ],
+ include_event_type=False,
+ )
+ )
+ )
+
+ with mock.patch.object(
+ client.chat._client._client,
+ "send",
+ return_value=returned_stream,
+ ):
+ with start_transaction(name="openai tx"):
+ response_stream = await client.chat.completions.create(
+ model="some-model",
+ messages=[{"role": "user", "content": "hello"}],
+ stream=True,
+ stream_options={"include_usage": True},
+ )
+ async for _ in response_stream:
+ pass
+
+ tx = events[0]
+ assert tx["type"] == "transaction"
+ span = tx["spans"][0]
+ assert span["op"] == "gen_ai.chat"
+ assert span["data"]["gen_ai.usage.input_tokens"] == 20
+ assert span["data"]["gen_ai.usage.output_tokens"] == 10
+ assert span["data"]["gen_ai.usage.total_tokens"] == 30
+
+
+# noinspection PyTypeChecker
+@pytest.mark.parametrize(
+ "messages",
+ [
+ pytest.param(
+ [
+ {
+ "role": "system",
+ "content": "You are a helpful assistant.",
+ },
+ {"role": "user", "content": "hello"},
+ ],
+ id="blocks",
+ ),
+ pytest.param(
+ [
+ {
+ "role": "system",
+ "content": [
+ {"type": "text", "text": "You are a helpful assistant."},
+ {"type": "text", "text": "Be concise and clear."},
+ ],
+ },
+ {"role": "user", "content": "hello"},
+ ],
+ id="parts",
+ ),
+ pytest.param(
+ iter(
+ [
+ {
+ "role": "system",
+ "content": [
+ {"type": "text", "text": "You are a helpful assistant."},
+ {"type": "text", "text": "Be concise and clear."},
+ ],
+ },
+ {"role": "user", "content": "hello"},
+ ]
+ ),
+ id="iterator",
+ ),
+ ],
+)
+def test_streaming_chat_completion(
+ sentry_init,
+ capture_events,
+ messages,
+ request,
+ get_model_response,
+ server_side_event_chunks,
+):
+ sentry_init(
+ integrations=[
+ OpenAIIntegration(
+ include_prompts=True,
+ tiktoken_encoding_name=tiktoken_encoding_if_installed(),
+ )
+ ],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ client = OpenAI(api_key="z")
+ returned_stream = get_model_response(
+ server_side_event_chunks(
+ [
+ ChatCompletionChunk(
+ id="1",
+ choices=[
+ DeltaChoice(
+ index=0,
+ delta=ChoiceDelta(content="hel"),
+ finish_reason=None,
+ )
+ ],
+ created=100000,
+ model="model-id",
+ object="chat.completion.chunk",
+ ),
+ ChatCompletionChunk(
+ id="1",
+ choices=[
+ DeltaChoice(
+ index=1,
+ delta=ChoiceDelta(content="lo "),
+ finish_reason=None,
+ )
+ ],
+ created=100000,
+ model="model-id",
+ object="chat.completion.chunk",
+ ),
+ ChatCompletionChunk(
+ id="1",
+ choices=[
+ DeltaChoice(
+ index=2,
+ delta=ChoiceDelta(content="world"),
+ finish_reason="stop",
+ )
+ ],
+ created=100000,
+ model="model-id",
+ object="chat.completion.chunk",
+ ),
+ ],
+ include_event_type=False,
+ )
+ )
+
+ with mock.patch.object(
+ client.chat._client._client,
+ "send",
+ return_value=returned_stream,
+ ):
+ with start_transaction(name="openai tx"):
+ response_stream = client.chat.completions.create(
+ model="some-model",
+ messages=messages,
+ stream=True,
+ max_tokens=100,
+ presence_penalty=0.1,
+ frequency_penalty=0.2,
+ temperature=0.7,
+ top_p=0.9,
+ )
+ response_string = "".join(
+ map(lambda x: x.choices[0].delta.content, response_stream)
+ )
+ assert response_string == "hello world"
+ tx = events[0]
+ assert tx["type"] == "transaction"
+ span = tx["spans"][0]
+ assert span["op"] == "gen_ai.chat"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "openai"
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
+
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "some-model"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MAX_TOKENS] == 100
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY] == 0.1
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY] == 0.2
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_TEMPERATURE] == 0.7
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_TOP_P] == 0.9
+
+ param_id = request.node.callspec.id
+ if "blocks" in param_id:
+ assert json.loads(span["data"][SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS]) == [
+ {
+ "type": "text",
+ "content": "You are a helpful assistant.",
+ }
+ ]
+ else:
+ assert json.loads(span["data"][SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS]) == [
+ {
+ "type": "text",
+ "content": "You are a helpful assistant.",
+ },
+ {
+ "type": "text",
+ "content": "Be concise and clear.",
+ },
+ ]
+
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_MODEL] == "model-id"
+
+ assert "hello" in span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ assert "hello world" in span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT]
+
+ try:
+ import tiktoken # type: ignore # noqa # pylint: disable=unused-import
+
+ if "blocks" in param_id:
+ assert span["data"]["gen_ai.usage.output_tokens"] == 2
+ assert span["data"]["gen_ai.usage.input_tokens"] == 7
+ assert span["data"]["gen_ai.usage.total_tokens"] == 9
+ else:
+ assert span["data"]["gen_ai.usage.output_tokens"] == 2
+ assert span["data"]["gen_ai.usage.input_tokens"] == 12
+ assert span["data"]["gen_ai.usage.total_tokens"] == 14
+ except ImportError:
+ pass # if tiktoken is not installed, we can't guarantee token usage will be calculated properly
+
+
+# noinspection PyTypeChecker
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+async def test_streaming_chat_completion_async_no_prompts(
+ sentry_init,
+ capture_events,
+ send_default_pii,
+ include_prompts,
+ get_model_response,
+ async_iterator,
+ server_side_event_chunks,
+):
+ sentry_init(
+ integrations=[
+ OpenAIIntegration(
+ include_prompts=include_prompts,
+ tiktoken_encoding_name=tiktoken_encoding_if_installed(),
+ )
+ ],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ client = AsyncOpenAI(api_key="z")
+ returned_stream = get_model_response(
+ async_iterator(
+ server_side_event_chunks(
+ [
+ ChatCompletionChunk(
+ id="1",
+ choices=[
+ DeltaChoice(
+ index=0,
+ delta=ChoiceDelta(content="hel"),
+ finish_reason=None,
+ )
+ ],
+ created=100000,
+ model="model-id",
+ object="chat.completion.chunk",
+ ),
+ ChatCompletionChunk(
+ id="1",
+ choices=[
+ DeltaChoice(
+ index=1,
+ delta=ChoiceDelta(content="lo "),
+ finish_reason=None,
+ )
+ ],
+ created=100000,
+ model="model-id",
+ object="chat.completion.chunk",
+ ),
+ ChatCompletionChunk(
+ id="1",
+ choices=[
+ DeltaChoice(
+ index=2,
+ delta=ChoiceDelta(content="world"),
+ finish_reason="stop",
+ )
+ ],
+ created=100000,
+ model="model-id",
+ object="chat.completion.chunk",
+ ),
+ ],
+ include_event_type=False,
+ )
+ )
+ )
+
+ with mock.patch.object(
+ client.chat._client._client,
+ "send",
+ return_value=returned_stream,
+ ):
+ with start_transaction(name="openai tx"):
+ response_stream = await client.chat.completions.create(
+ model="some-model",
+ messages=[
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "hello"},
+ ],
+ stream=True,
+ max_tokens=100,
+ presence_penalty=0.1,
+ frequency_penalty=0.2,
+ temperature=0.7,
+ top_p=0.9,
+ )
+
+ response_string = ""
+ async for x in response_stream:
+ response_string += x.choices[0].delta.content
+
+ assert response_string == "hello world"
+ tx = events[0]
+ assert tx["type"] == "transaction"
+ span = tx["spans"][0]
+ assert span["op"] == "gen_ai.chat"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "openai"
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
+
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "some-model"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MAX_TOKENS] == 100
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY] == 0.1
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY] == 0.2
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_TEMPERATURE] == 0.7
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_TOP_P] == 0.9
+
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_MODEL] == "model-id"
+
+ assert SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS not in span["data"]
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]
+
+ try:
+ import tiktoken # type: ignore # noqa # pylint: disable=unused-import
+
+ assert span["data"]["gen_ai.usage.output_tokens"] == 2
+ assert span["data"]["gen_ai.usage.input_tokens"] == 7
+ assert span["data"]["gen_ai.usage.total_tokens"] == 9
+
+ except ImportError:
+ pass # if tiktoken is not installed, we can't guarantee token usage will be calculated properly
+
+
+# noinspection PyTypeChecker
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "messages",
+ [
+ pytest.param(
+ [
+ {
+ "role": "system",
+ "content": "You are a helpful assistant.",
+ },
+ {"role": "user", "content": "hello"},
+ ],
+ id="blocks",
+ ),
+ pytest.param(
+ [
+ {
+ "role": "system",
+ "content": [
+ {"type": "text", "text": "You are a helpful assistant."},
+ {"type": "text", "text": "Be concise and clear."},
+ ],
+ },
+ {"role": "user", "content": "hello"},
+ ],
+ id="parts",
+ ),
+ pytest.param(
+ iter(
+ [
+ {
+ "role": "system",
+ "content": [
+ {"type": "text", "text": "You are a helpful assistant."},
+ {"type": "text", "text": "Be concise and clear."},
+ ],
+ },
+ {"role": "user", "content": "hello"},
+ ]
+ ),
+ id="iterator",
+ ),
+ ],
+)
+async def test_streaming_chat_completion_async(
+ sentry_init,
+ capture_events,
+ messages,
+ request,
+ get_model_response,
+ async_iterator,
+ server_side_event_chunks,
+):
+ sentry_init(
+ integrations=[
+ OpenAIIntegration(
+ include_prompts=True,
+ tiktoken_encoding_name=tiktoken_encoding_if_installed(),
+ )
+ ],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ client = AsyncOpenAI(api_key="z")
+
+ returned_stream = get_model_response(
+ async_iterator(
+ server_side_event_chunks(
+ [
+ ChatCompletionChunk(
+ id="1",
+ choices=[
+ DeltaChoice(
+ index=0,
+ delta=ChoiceDelta(content="hel"),
+ finish_reason=None,
+ )
+ ],
+ created=100000,
+ model="model-id",
+ object="chat.completion.chunk",
+ ),
+ ChatCompletionChunk(
+ id="1",
+ choices=[
+ DeltaChoice(
+ index=1,
+ delta=ChoiceDelta(content="lo "),
+ finish_reason=None,
+ )
+ ],
+ created=100000,
+ model="model-id",
+ object="chat.completion.chunk",
+ ),
+ ChatCompletionChunk(
+ id="1",
+ choices=[
+ DeltaChoice(
+ index=2,
+ delta=ChoiceDelta(content="world"),
+ finish_reason="stop",
+ )
+ ],
+ created=100000,
+ model="model-id",
+ object="chat.completion.chunk",
+ ),
+ ],
+ include_event_type=False,
+ )
+ )
+ )
+
+ with mock.patch.object(
+ client.chat._client._client,
+ "send",
+ return_value=returned_stream,
+ ):
+ with start_transaction(name="openai tx"):
+ response_stream = await client.chat.completions.create(
+ model="some-model",
+ messages=messages,
+ stream=True,
+ max_tokens=100,
+ presence_penalty=0.1,
+ frequency_penalty=0.2,
+ temperature=0.7,
+ top_p=0.9,
+ )
+
+ response_string = ""
+ async for x in response_stream:
+ response_string += x.choices[0].delta.content
+
+ assert response_string == "hello world"
+ tx = events[0]
+ assert tx["type"] == "transaction"
+ span = tx["spans"][0]
+ assert span["op"] == "gen_ai.chat"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "openai"
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
+
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "some-model"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MAX_TOKENS] == 100
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY] == 0.1
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY] == 0.2
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_TEMPERATURE] == 0.7
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_TOP_P] == 0.9
+
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_MODEL] == "model-id"
+
+ param_id = request.node.callspec.id
+ if "blocks" in param_id:
+ assert json.loads(span["data"][SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS]) == [
+ {
+ "type": "text",
+ "content": "You are a helpful assistant.",
+ }
+ ]
+ else:
+ assert json.loads(span["data"][SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS]) == [
+ {
+ "type": "text",
+ "content": "You are a helpful assistant.",
+ },
+ {
+ "type": "text",
+ "content": "Be concise and clear.",
+ },
+ ]
+
+ assert "hello" in span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ assert "hello world" in span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT]
+
+ try:
+ import tiktoken # type: ignore # noqa # pylint: disable=unused-import
+
+ if "blocks" in param_id:
+ assert span["data"]["gen_ai.usage.output_tokens"] == 2
+ assert span["data"]["gen_ai.usage.input_tokens"] == 7
+ assert span["data"]["gen_ai.usage.total_tokens"] == 9
+ else:
+ assert span["data"]["gen_ai.usage.output_tokens"] == 2
+ assert span["data"]["gen_ai.usage.input_tokens"] == 12
+ assert span["data"]["gen_ai.usage.total_tokens"] == 14
+
+ except ImportError:
+ pass # if tiktoken is not installed, we can't guarantee token usage will be calculated properly
+
+
+def test_bad_chat_completion(sentry_init, capture_events):
+ sentry_init(integrations=[OpenAIIntegration()], traces_sample_rate=1.0)
+ events = capture_events()
+
+ client = OpenAI(api_key="z")
+ client.chat.completions._post = mock.Mock(
+ side_effect=OpenAIError("API rate limit reached")
+ )
+ with pytest.raises(OpenAIError):
+ client.chat.completions.create(
+ model="some-model",
+ messages=[{"role": "system", "content": "hello"}],
+ )
+
+ (event,) = events
+ assert event["level"] == "error"
+
+
+def test_span_status_error(sentry_init, capture_events):
+ sentry_init(integrations=[OpenAIIntegration()], traces_sample_rate=1.0)
+ events = capture_events()
+
+ with start_transaction(name="test"):
+ client = OpenAI(api_key="z")
+ client.chat.completions._post = mock.Mock(
+ side_effect=OpenAIError("API rate limit reached")
+ )
+ with pytest.raises(OpenAIError):
+ client.chat.completions.create(
+ model="some-model", messages=[{"role": "system", "content": "hello"}]
+ )
+
+ (error, transaction) = events
+ assert error["level"] == "error"
+ assert transaction["spans"][0]["status"] == "internal_error"
+ assert transaction["spans"][0]["tags"]["status"] == "internal_error"
+ assert transaction["contexts"]["trace"]["status"] == "internal_error"
+
+
+@pytest.mark.asyncio
+async def test_bad_chat_completion_async(sentry_init, capture_events):
+ sentry_init(integrations=[OpenAIIntegration()], traces_sample_rate=1.0)
+ events = capture_events()
+
+ client = AsyncOpenAI(api_key="z")
+ client.chat.completions._post = AsyncMock(
+ side_effect=OpenAIError("API rate limit reached")
+ )
+ with pytest.raises(OpenAIError):
+ await client.chat.completions.create(
+ model="some-model", messages=[{"role": "system", "content": "hello"}]
+ )
+
+ (event,) = events
+ assert event["level"] == "error"
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+def test_embeddings_create_no_pii(
+ sentry_init, capture_events, send_default_pii, include_prompts
+):
+ sentry_init(
+ integrations=[OpenAIIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ client = OpenAI(api_key="z")
+
+ returned_embedding = CreateEmbeddingResponse(
+ data=[Embedding(object="embedding", index=0, embedding=[1.0, 2.0, 3.0])],
+ model="some-model",
+ object="list",
+ usage=EmbeddingTokenUsage(
+ prompt_tokens=20,
+ total_tokens=30,
+ ),
+ )
+
+ client.embeddings._post = mock.Mock(return_value=returned_embedding)
+ with start_transaction(name="openai tx"):
+ response = client.embeddings.create(
+ input="hello", model="text-embedding-3-large"
+ )
+
+ assert len(response.data[0].embedding) == 3
+
+ tx = events[0]
+ assert tx["type"] == "transaction"
+ span = tx["spans"][0]
+ assert span["op"] == "gen_ai.embeddings"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "openai"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "text-embedding-3-large"
+
+ assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT not in span["data"]
+
+ assert span["data"]["gen_ai.usage.input_tokens"] == 20
+ assert span["data"]["gen_ai.usage.total_tokens"] == 30
+
+
+@pytest.mark.parametrize(
+ "input",
+ [
+ pytest.param(
+ "hello",
+ id="string",
+ ),
+ pytest.param(
+ ["First text", "Second text", "Third text"],
+ id="string_sequence",
+ ),
+ pytest.param(
+ iter(["First text", "Second text", "Third text"]),
+ id="string_iterable",
+ ),
+ pytest.param(
+ [5, 8, 13, 21, 34],
+ id="tokens",
+ ),
+ pytest.param(
+ iter(
+ [5, 8, 13, 21, 34],
+ ),
+ id="token_iterable",
+ ),
+ pytest.param(
+ [
+ [5, 8, 13, 21, 34],
+ [8, 13, 21, 34, 55],
+ ],
+ id="tokens_sequence",
+ ),
+ pytest.param(
+ iter(
+ [
+ [5, 8, 13, 21, 34],
+ [8, 13, 21, 34, 55],
+ ]
+ ),
+ id="tokens_sequence_iterable",
+ ),
+ ],
+)
+def test_embeddings_create(sentry_init, capture_events, input, request):
+ sentry_init(
+ integrations=[OpenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ client = OpenAI(api_key="z")
+
+ returned_embedding = CreateEmbeddingResponse(
+ data=[Embedding(object="embedding", index=0, embedding=[1.0, 2.0, 3.0])],
+ model="some-model",
+ object="list",
+ usage=EmbeddingTokenUsage(
+ prompt_tokens=20,
+ total_tokens=30,
+ ),
+ )
+
+ client.embeddings._post = mock.Mock(return_value=returned_embedding)
+ with start_transaction(name="openai tx"):
+ response = client.embeddings.create(input=input, model="text-embedding-3-large")
+
+ assert len(response.data[0].embedding) == 3
+
+ tx = events[0]
+ assert tx["type"] == "transaction"
+ span = tx["spans"][0]
+ assert span["op"] == "gen_ai.embeddings"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "openai"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "text-embedding-3-large"
+
+ param_id = request.node.callspec.id
+ if param_id == "string":
+ assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == ["hello"]
+ elif param_id == "string_sequence" or param_id == "string_iterable":
+ assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [
+ "First text",
+ "Second text",
+ "Third text",
+ ]
+ elif param_id == "tokens" or param_id == "token_iterable":
+ assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [
+ 5,
+ 8,
+ 13,
+ 21,
+ 34,
+ ]
+ else:
+ assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [
+ [5, 8, 13, 21, 34],
+ [8, 13, 21, 34, 55],
+ ]
+
+ assert span["data"]["gen_ai.usage.input_tokens"] == 20
+ assert span["data"]["gen_ai.usage.total_tokens"] == 30
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+async def test_embeddings_create_async_no_pii(
+ sentry_init, capture_events, send_default_pii, include_prompts
+):
+ sentry_init(
+ integrations=[OpenAIIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ client = AsyncOpenAI(api_key="z")
+
+ returned_embedding = CreateEmbeddingResponse(
+ data=[Embedding(object="embedding", index=0, embedding=[1.0, 2.0, 3.0])],
+ model="some-model",
+ object="list",
+ usage=EmbeddingTokenUsage(
+ prompt_tokens=20,
+ total_tokens=30,
+ ),
+ )
+
+ client.embeddings._post = AsyncMock(return_value=returned_embedding)
+ with start_transaction(name="openai tx"):
+ response = await client.embeddings.create(
+ input="hello", model="text-embedding-3-large"
+ )
+
+ assert len(response.data[0].embedding) == 3
+
+ tx = events[0]
+ assert tx["type"] == "transaction"
+ span = tx["spans"][0]
+ assert span["op"] == "gen_ai.embeddings"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "openai"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "text-embedding-3-large"
+
+ assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT not in span["data"]
+
+ assert span["data"]["gen_ai.usage.input_tokens"] == 20
+ assert span["data"]["gen_ai.usage.total_tokens"] == 30
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "input",
+ [
+ pytest.param(
+ "hello",
+ id="string",
+ ),
+ pytest.param(
+ ["First text", "Second text", "Third text"],
+ id="string_sequence",
+ ),
+ pytest.param(
+ iter(["First text", "Second text", "Third text"]),
+ id="string_iterable",
+ ),
+ pytest.param(
+ [5, 8, 13, 21, 34],
+ id="tokens",
+ ),
+ pytest.param(
+ iter(
+ [5, 8, 13, 21, 34],
+ ),
+ id="token_iterable",
+ ),
+ pytest.param(
+ [
+ [5, 8, 13, 21, 34],
+ [8, 13, 21, 34, 55],
+ ],
+ id="tokens_sequence",
+ ),
+ pytest.param(
+ iter(
+ [
+ [5, 8, 13, 21, 34],
+ [8, 13, 21, 34, 55],
+ ]
+ ),
+ id="tokens_sequence_iterable",
+ ),
+ ],
+)
+async def test_embeddings_create_async(sentry_init, capture_events, input, request):
+ sentry_init(
+ integrations=[OpenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ client = AsyncOpenAI(api_key="z")
+
+ returned_embedding = CreateEmbeddingResponse(
+ data=[Embedding(object="embedding", index=0, embedding=[1.0, 2.0, 3.0])],
+ model="some-model",
+ object="list",
+ usage=EmbeddingTokenUsage(
+ prompt_tokens=20,
+ total_tokens=30,
+ ),
+ )
+
+ client.embeddings._post = AsyncMock(return_value=returned_embedding)
+ with start_transaction(name="openai tx"):
+ response = await client.embeddings.create(
+ input=input, model="text-embedding-3-large"
+ )
+
+ assert len(response.data[0].embedding) == 3
+
+ tx = events[0]
+ assert tx["type"] == "transaction"
+ span = tx["spans"][0]
+ assert span["op"] == "gen_ai.embeddings"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "openai"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "text-embedding-3-large"
+
+ param_id = request.node.callspec.id
+ if param_id == "string":
+ assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == ["hello"]
+ elif param_id == "string_sequence" or param_id == "string_iterable":
+ assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [
+ "First text",
+ "Second text",
+ "Third text",
+ ]
+ elif param_id == "tokens" or param_id == "token_iterable":
+ assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [
+ 5,
+ 8,
+ 13,
+ 21,
+ 34,
+ ]
+ else:
+ assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [
+ [5, 8, 13, 21, 34],
+ [8, 13, 21, 34, 55],
+ ]
+
+ assert span["data"]["gen_ai.usage.input_tokens"] == 20
+ assert span["data"]["gen_ai.usage.total_tokens"] == 30
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [(True, True), (True, False), (False, True), (False, False)],
+)
+def test_embeddings_create_raises_error(
+ sentry_init, capture_events, send_default_pii, include_prompts
+):
+ sentry_init(
+ integrations=[OpenAIIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ client = OpenAI(api_key="z")
+
+ client.embeddings._post = mock.Mock(
+ side_effect=OpenAIError("API rate limit reached")
+ )
+
+ with pytest.raises(OpenAIError):
+ client.embeddings.create(input="hello", model="text-embedding-3-large")
+
+ (event,) = events
+ assert event["level"] == "error"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [(True, True), (True, False), (False, True), (False, False)],
+)
+async def test_embeddings_create_raises_error_async(
+ sentry_init, capture_events, send_default_pii, include_prompts
+):
+ sentry_init(
+ integrations=[OpenAIIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ client = AsyncOpenAI(api_key="z")
+
+ client.embeddings._post = AsyncMock(
+ side_effect=OpenAIError("API rate limit reached")
+ )
+
+ with pytest.raises(OpenAIError):
+ await client.embeddings.create(input="hello", model="text-embedding-3-large")
+
+ (event,) = events
+ assert event["level"] == "error"
+
+
+def test_span_origin_nonstreaming_chat(sentry_init, capture_events):
+ sentry_init(
+ integrations=[OpenAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client = OpenAI(api_key="z")
+ client.chat.completions._post = mock.Mock(return_value=EXAMPLE_CHAT_COMPLETION)
+
+ with start_transaction(name="openai tx"):
+ client.chat.completions.create(
+ model="some-model", messages=[{"role": "system", "content": "hello"}]
+ )
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+ assert event["spans"][0]["origin"] == "auto.ai.openai"
+
+
+@pytest.mark.asyncio
+async def test_span_origin_nonstreaming_chat_async(sentry_init, capture_events):
+ sentry_init(
+ integrations=[OpenAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client = AsyncOpenAI(api_key="z")
+ client.chat.completions._post = AsyncMock(return_value=EXAMPLE_CHAT_COMPLETION)
+
+ with start_transaction(name="openai tx"):
+ await client.chat.completions.create(
+ model="some-model", messages=[{"role": "system", "content": "hello"}]
+ )
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+ assert event["spans"][0]["origin"] == "auto.ai.openai"
+
+
+def test_span_origin_streaming_chat(sentry_init, capture_events):
+ sentry_init(
+ integrations=[OpenAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client = OpenAI(api_key="z")
+ returned_stream = Stream(cast_to=None, response=None, client=client)
+ returned_stream._iterator = [
+ ChatCompletionChunk(
+ id="1",
+ choices=[
+ DeltaChoice(
+ index=0, delta=ChoiceDelta(content="hel"), finish_reason=None
+ )
+ ],
+ created=100000,
+ model="model-id",
+ object="chat.completion.chunk",
+ ),
+ ChatCompletionChunk(
+ id="1",
+ choices=[
+ DeltaChoice(
+ index=1, delta=ChoiceDelta(content="lo "), finish_reason=None
+ )
+ ],
+ created=100000,
+ model="model-id",
+ object="chat.completion.chunk",
+ ),
+ ChatCompletionChunk(
+ id="1",
+ choices=[
+ DeltaChoice(
+ index=2, delta=ChoiceDelta(content="world"), finish_reason="stop"
+ )
+ ],
+ created=100000,
+ model="model-id",
+ object="chat.completion.chunk",
+ ),
+ ]
+
+ client.chat.completions._post = mock.Mock(return_value=returned_stream)
+ with start_transaction(name="openai tx"):
+ response_stream = client.chat.completions.create(
+ model="some-model", messages=[{"role": "system", "content": "hello"}]
+ )
+
+ "".join(map(lambda x: x.choices[0].delta.content, response_stream))
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+ assert event["spans"][0]["origin"] == "auto.ai.openai"
+
+
+@pytest.mark.asyncio
+async def test_span_origin_streaming_chat_async(
+ sentry_init, capture_events, async_iterator
+):
+ sentry_init(
+ integrations=[OpenAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client = AsyncOpenAI(api_key="z")
+ returned_stream = AsyncStream(cast_to=None, response=None, client=client)
+ returned_stream._iterator = async_iterator(
+ [
+ ChatCompletionChunk(
+ id="1",
+ choices=[
+ DeltaChoice(
+ index=0, delta=ChoiceDelta(content="hel"), finish_reason=None
+ )
+ ],
+ created=100000,
+ model="model-id",
+ object="chat.completion.chunk",
+ ),
+ ChatCompletionChunk(
+ id="1",
+ choices=[
+ DeltaChoice(
+ index=1, delta=ChoiceDelta(content="lo "), finish_reason=None
+ )
+ ],
+ created=100000,
+ model="model-id",
+ object="chat.completion.chunk",
+ ),
+ ChatCompletionChunk(
+ id="1",
+ choices=[
+ DeltaChoice(
+ index=2,
+ delta=ChoiceDelta(content="world"),
+ finish_reason="stop",
+ )
+ ],
+ created=100000,
+ model="model-id",
+ object="chat.completion.chunk",
+ ),
+ ]
+ )
+
+ client.chat.completions._post = AsyncMock(return_value=returned_stream)
+ with start_transaction(name="openai tx"):
+ response_stream = await client.chat.completions.create(
+ model="some-model", messages=[{"role": "system", "content": "hello"}]
+ )
+ async for _ in response_stream:
+ pass
+
+ # "".join(map(lambda x: x.choices[0].delta.content, response_stream))
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+ assert event["spans"][0]["origin"] == "auto.ai.openai"
+
+
+def test_span_origin_embeddings(sentry_init, capture_events):
+ sentry_init(
+ integrations=[OpenAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client = OpenAI(api_key="z")
+
+ returned_embedding = CreateEmbeddingResponse(
+ data=[Embedding(object="embedding", index=0, embedding=[1.0, 2.0, 3.0])],
+ model="some-model",
+ object="list",
+ usage=EmbeddingTokenUsage(
+ prompt_tokens=20,
+ total_tokens=30,
+ ),
+ )
+
+ client.embeddings._post = mock.Mock(return_value=returned_embedding)
+ with start_transaction(name="openai tx"):
+ client.embeddings.create(input="hello", model="text-embedding-3-large")
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+ assert event["spans"][0]["origin"] == "auto.ai.openai"
+
+
+@pytest.mark.asyncio
+async def test_span_origin_embeddings_async(sentry_init, capture_events):
+ sentry_init(
+ integrations=[OpenAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client = AsyncOpenAI(api_key="z")
+
+ returned_embedding = CreateEmbeddingResponse(
+ data=[Embedding(object="embedding", index=0, embedding=[1.0, 2.0, 3.0])],
+ model="some-model",
+ object="list",
+ usage=EmbeddingTokenUsage(
+ prompt_tokens=20,
+ total_tokens=30,
+ ),
+ )
+
+ client.embeddings._post = AsyncMock(return_value=returned_embedding)
+ with start_transaction(name="openai tx"):
+ await client.embeddings.create(input="hello", model="text-embedding-3-large")
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+ assert event["spans"][0]["origin"] == "auto.ai.openai"
+
+
+def test_completions_token_usage_from_response():
+ """Token counts are extracted from response.usage using Completions API field names."""
+ span = mock.MagicMock()
+
+ def count_tokens(msg):
+ return len(str(msg))
+
+ response = mock.MagicMock()
+ response.usage = mock.MagicMock()
+ response.usage.completion_tokens = 10
+ response.usage.prompt_tokens = 20
+ response.usage.total_tokens = 30
+ messages = []
+ streaming_message_responses = []
+
+ with mock.patch(
+ "sentry_sdk.integrations.openai.record_token_usage"
+ ) as mock_record_token_usage:
+ _calculate_completions_token_usage(
+ messages=messages,
+ response=response,
+ span=span,
+ streaming_message_responses=streaming_message_responses,
+ streaming_message_total_token_usage=None,
+ count_tokens=count_tokens,
+ )
+ mock_record_token_usage.assert_called_once_with(
+ span,
+ input_tokens=20,
+ input_tokens_cached=None,
+ output_tokens=10,
+ output_tokens_reasoning=None,
+ total_tokens=30,
+ )
+
+
+def test_completions_token_usage_with_detailed_fields():
+ """Cached and reasoning token counts are extracted from prompt_tokens_details and completion_tokens_details."""
+ span = mock.MagicMock()
+
+ def count_tokens(msg):
+ return len(str(msg))
+
+ response = mock.MagicMock()
+ response.usage = mock.MagicMock()
+ response.usage.prompt_tokens = 20
+ response.usage.prompt_tokens_details = mock.MagicMock()
+ response.usage.prompt_tokens_details.cached_tokens = 5
+ response.usage.completion_tokens = 10
+ response.usage.completion_tokens_details = mock.MagicMock()
+ response.usage.completion_tokens_details.reasoning_tokens = 8
+ response.usage.total_tokens = 30
+
+ with mock.patch(
+ "sentry_sdk.integrations.openai.record_token_usage"
+ ) as mock_record_token_usage:
+ _calculate_completions_token_usage(
+ messages=[],
+ response=response,
+ span=span,
+ streaming_message_responses=[],
+ streaming_message_total_token_usage=None,
+ count_tokens=count_tokens,
+ )
+ mock_record_token_usage.assert_called_once_with(
+ span,
+ input_tokens=20,
+ input_tokens_cached=5,
+ output_tokens=10,
+ output_tokens_reasoning=8,
+ total_tokens=30,
+ )
+
+
+def test_completions_token_usage_manual_input_counting():
+ """When prompt_tokens is missing, input tokens are counted manually from messages."""
+ span = mock.MagicMock()
+
+ def count_tokens(msg):
+ return len(str(msg))
+
+ response = mock.MagicMock()
+ response.usage = mock.MagicMock()
+ response.usage.completion_tokens = 10
+ response.usage.total_tokens = 10
+ messages = [
+ {"content": "one"},
+ {"content": "two"},
+ {"content": "three"},
+ ]
+ streaming_message_responses = []
+
+ with mock.patch(
+ "sentry_sdk.integrations.openai.record_token_usage"
+ ) as mock_record_token_usage:
+ _calculate_completions_token_usage(
+ messages=messages,
+ response=response,
+ span=span,
+ streaming_message_responses=streaming_message_responses,
+ streaming_message_total_token_usage=None,
+ count_tokens=count_tokens,
+ )
+ mock_record_token_usage.assert_called_once_with(
+ span,
+ input_tokens=11,
+ input_tokens_cached=None,
+ output_tokens=10,
+ output_tokens_reasoning=None,
+ total_tokens=10,
+ )
+
+
+def test_completions_token_usage_manual_output_counting_streaming():
+ """When completion_tokens is missing, output tokens are counted from streaming responses."""
+ span = mock.MagicMock()
+
+ def count_tokens(msg):
+ return len(str(msg))
+
+ response = mock.MagicMock()
+ response.usage = mock.MagicMock()
+ response.usage.prompt_tokens = 20
+ response.usage.total_tokens = 20
+ messages = []
+ streaming_message_responses = [
+ "one",
+ "two",
+ "three",
+ ]
+
+ with mock.patch(
+ "sentry_sdk.integrations.openai.record_token_usage"
+ ) as mock_record_token_usage:
+ _calculate_completions_token_usage(
+ messages=messages,
+ response=response,
+ span=span,
+ streaming_message_responses=streaming_message_responses,
+ streaming_message_total_token_usage=None,
+ count_tokens=count_tokens,
+ )
+ mock_record_token_usage.assert_called_once_with(
+ span,
+ input_tokens=20,
+ input_tokens_cached=None,
+ output_tokens=11,
+ output_tokens_reasoning=None,
+ total_tokens=20,
+ )
+
+
+def test_completions_token_usage_manual_output_counting_choices():
+ """When completion_tokens is missing, output tokens are counted from response.choices."""
+ span = mock.MagicMock()
+
+ def count_tokens(msg):
+ return len(str(msg))
+
+ response = mock.MagicMock()
+ response.usage = mock.MagicMock()
+ response.usage.prompt_tokens = 20
+ response.usage.total_tokens = 20
+ response.choices = [
+ Choice(
+ index=0,
+ finish_reason="stop",
+ message=ChatCompletionMessage(role="assistant", content="one"),
+ ),
+ Choice(
+ index=1,
+ finish_reason="stop",
+ message=ChatCompletionMessage(role="assistant", content="two"),
+ ),
+ Choice(
+ index=2,
+ finish_reason="stop",
+ message=ChatCompletionMessage(role="assistant", content="three"),
+ ),
+ ]
+ messages = []
+ streaming_message_responses = None
+
+ with mock.patch(
+ "sentry_sdk.integrations.openai.record_token_usage"
+ ) as mock_record_token_usage:
+ _calculate_completions_token_usage(
+ messages=messages,
+ response=response,
+ span=span,
+ streaming_message_responses=streaming_message_responses,
+ streaming_message_total_token_usage=None,
+ count_tokens=count_tokens,
+ )
+ mock_record_token_usage.assert_called_once_with(
+ span,
+ input_tokens=20,
+ input_tokens_cached=None,
+ output_tokens=11,
+ output_tokens_reasoning=None,
+ total_tokens=20,
+ )
+
+
+def test_completions_token_usage_no_usage_data():
+ """When response has no usage data and no streaming responses, all tokens are None."""
+ span = mock.MagicMock()
+
+ def count_tokens(msg):
+ return len(str(msg))
+
+ response = mock.MagicMock()
+ messages = []
+ streaming_message_responses = None
+
+ with mock.patch(
+ "sentry_sdk.integrations.openai.record_token_usage"
+ ) as mock_record_token_usage:
+ _calculate_completions_token_usage(
+ messages=messages,
+ response=response,
+ span=span,
+ streaming_message_responses=streaming_message_responses,
+ streaming_message_total_token_usage=None,
+ count_tokens=count_tokens,
+ )
+ mock_record_token_usage.assert_called_once_with(
+ span,
+ input_tokens=None,
+ input_tokens_cached=None,
+ output_tokens=None,
+ output_tokens_reasoning=None,
+ total_tokens=None,
+ )
+
+
+@pytest.mark.skipif(SKIP_RESPONSES_TESTS, reason="Responses API not available")
+def test_responses_token_usage_from_response():
+ """Token counts including cached and reasoning tokens are extracted from Responses API."""
+ span = mock.MagicMock()
+
+ def count_tokens(msg):
+ return len(str(msg))
+
+ response = mock.MagicMock()
+ response.usage = mock.MagicMock()
+ response.usage.input_tokens = 20
+ response.usage.input_tokens_details = mock.MagicMock()
+ response.usage.input_tokens_details.cached_tokens = 5
+ response.usage.output_tokens = 10
+ response.usage.output_tokens_details = mock.MagicMock()
+ response.usage.output_tokens_details.reasoning_tokens = 8
+ response.usage.total_tokens = 30
+ input = []
+
+ with mock.patch(
+ "sentry_sdk.integrations.openai.record_token_usage"
+ ) as mock_record_token_usage:
+ _calculate_responses_token_usage(input, response, span, None, count_tokens)
+ mock_record_token_usage.assert_called_once_with(
+ span,
+ input_tokens=20,
+ input_tokens_cached=5,
+ output_tokens=10,
+ output_tokens_reasoning=8,
+ total_tokens=30,
+ )
+
+
+@pytest.mark.skipif(SKIP_RESPONSES_TESTS, reason="Responses API not available")
+def test_responses_token_usage_no_usage_data():
+ """When Responses API response has no usage data, all tokens are None."""
+ span = mock.MagicMock()
+
+ def count_tokens(msg):
+ return len(str(msg))
+
+ response = mock.MagicMock()
+ response.usage = None
+ input = []
+ streaming_message_responses = None
+
+ with mock.patch(
+ "sentry_sdk.integrations.openai.record_token_usage"
+ ) as mock_record_token_usage:
+ _calculate_responses_token_usage(
+ input, response, span, streaming_message_responses, count_tokens
+ )
+ mock_record_token_usage.assert_called_once_with(
+ span,
+ input_tokens=None,
+ input_tokens_cached=None,
+ output_tokens=None,
+ output_tokens_reasoning=None,
+ total_tokens=None,
+ )
+
+
+@pytest.mark.skipif(SKIP_RESPONSES_TESTS, reason="Responses API not available")
+def test_responses_token_usage_manual_output_counting_response_output():
+ """When output_tokens is missing, output tokens are counted from response.output."""
+ span = mock.MagicMock()
+
+ def count_tokens(msg):
+ return len(str(msg))
+
+ response = mock.MagicMock()
+ response.usage = mock.MagicMock()
+ response.usage.input_tokens = 20
+ response.usage.total_tokens = 20
+ response.output = [
+ ResponseOutputMessage(
+ id="msg-1",
+ content=[
+ ResponseOutputText(
+ annotations=[],
+ text="one",
+ type="output_text",
+ ),
+ ],
+ role="assistant",
+ status="completed",
+ type="message",
+ ),
+ ResponseOutputMessage(
+ id="msg-2",
+ content=[
+ ResponseOutputText(
+ annotations=[],
+ text="two",
+ type="output_text",
+ ),
+ ResponseOutputText(
+ annotations=[],
+ text="three",
+ type="output_text",
+ ),
+ ],
+ role="assistant",
+ status="completed",
+ type="message",
+ ),
+ ]
+ input = []
+ streaming_message_responses = None
+
+ with mock.patch(
+ "sentry_sdk.integrations.openai.record_token_usage"
+ ) as mock_record_token_usage:
+ _calculate_responses_token_usage(
+ input, response, span, streaming_message_responses, count_tokens
+ )
+ mock_record_token_usage.assert_called_once_with(
+ span,
+ input_tokens=20,
+ input_tokens_cached=None,
+ output_tokens=11,
+ output_tokens_reasoning=None,
+ total_tokens=20,
+ )
+
+
+@pytest.mark.skipif(SKIP_RESPONSES_TESTS, reason="Responses API not available")
+def test_ai_client_span_responses_api_no_pii(sentry_init, capture_events):
+ sentry_init(
+ integrations=[OpenAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client = OpenAI(api_key="z")
+ client.responses._post = mock.Mock(return_value=EXAMPLE_RESPONSE)
+
+ with start_transaction(name="openai tx"):
+ client.responses.create(
+ model="gpt-4o",
+ instructions="You are a coding assistant that talks like a pirate.",
+ input="How do I check if a Python object is an instance of a class?",
+ max_output_tokens=100,
+ temperature=0.7,
+ top_p=0.9,
+ )
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ assert len(spans) == 1
+ assert spans[0]["op"] == "gen_ai.responses"
+ assert spans[0]["origin"] == "auto.ai.openai"
+ assert spans[0]["data"] == {
+ "gen_ai.operation.name": "responses",
+ "gen_ai.request.max_tokens": 100,
+ "gen_ai.request.temperature": 0.7,
+ "gen_ai.request.top_p": 0.9,
+ "gen_ai.request.model": "gpt-4o",
+ "gen_ai.response.model": "response-model-id",
+ "gen_ai.response.streaming": False,
+ "gen_ai.system": "openai",
+ "gen_ai.usage.input_tokens": 20,
+ "gen_ai.usage.input_tokens.cached": 5,
+ "gen_ai.usage.output_tokens": 10,
+ "gen_ai.usage.output_tokens.reasoning": 8,
+ "gen_ai.usage.total_tokens": 30,
+ "thread.id": mock.ANY,
+ "thread.name": mock.ANY,
+ }
+
+ assert "gen_ai.system_instructions" not in spans[0]["data"]
+ assert "gen_ai.request.messages" not in spans[0]["data"]
+ assert "gen_ai.response.text" not in spans[0]["data"]
+
+
+@pytest.mark.parametrize(
+ "instructions",
+ (
+ omit,
+ None,
+ "You are a coding assistant that talks like a pirate.",
+ ),
+)
+@pytest.mark.parametrize(
+ "input",
+ [
+ pytest.param(
+ "How do I check if a Python object is an instance of a class?", id="string"
+ ),
+ pytest.param(
+ [
+ {
+ "role": "system",
+ "content": "You are a helpful assistant.",
+ },
+ {"role": "user", "content": "hello"},
+ ],
+ id="blocks_no_type",
+ ),
+ pytest.param(
+ [
+ {
+ "type": "message",
+ "role": "system",
+ "content": "You are a helpful assistant.",
+ },
+ {"type": "message", "role": "user", "content": "hello"},
+ ],
+ id="blocks",
+ ),
+ pytest.param(
+ [
+ {
+ "role": "system",
+ "content": [
+ {"type": "text", "text": "You are a helpful assistant."},
+ {"type": "text", "text": "Be concise and clear."},
+ ],
+ },
+ {"role": "user", "content": "hello"},
+ ],
+ id="parts_no_type",
+ ),
+ pytest.param(
+ [
+ {
+ "type": "message",
+ "role": "system",
+ "content": [
+ {"type": "text", "text": "You are a helpful assistant."},
+ {"type": "text", "text": "Be concise and clear."},
+ ],
+ },
+ {"type": "message", "role": "user", "content": "hello"},
+ ],
+ id="parts",
+ ),
+ ],
+)
+@pytest.mark.skipif(SKIP_RESPONSES_TESTS, reason="Responses API not available")
+def test_ai_client_span_responses_api(
+ sentry_init, capture_events, instructions, input, request
+):
+ sentry_init(
+ integrations=[OpenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ client = OpenAI(api_key="z")
+ client.responses._post = mock.Mock(return_value=EXAMPLE_RESPONSE)
+
+ with start_transaction(name="openai tx"):
+ client.responses.create(
+ model="gpt-4o",
+ instructions=instructions,
+ input=input,
+ max_output_tokens=100,
+ temperature=0.7,
+ top_p=0.9,
+ )
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ assert len(spans) == 1
+ assert spans[0]["op"] == "gen_ai.responses"
+ assert spans[0]["origin"] == "auto.ai.openai"
+
+ expected_data = {
+ "gen_ai.operation.name": "responses",
+ "gen_ai.request.max_tokens": 100,
+ "gen_ai.request.temperature": 0.7,
+ "gen_ai.request.top_p": 0.9,
+ "gen_ai.system": "openai",
+ "gen_ai.response.model": "response-model-id",
+ "gen_ai.response.streaming": False,
+ "gen_ai.usage.input_tokens": 20,
+ "gen_ai.usage.input_tokens.cached": 5,
+ "gen_ai.usage.output_tokens": 10,
+ "gen_ai.usage.output_tokens.reasoning": 8,
+ "gen_ai.usage.total_tokens": 30,
+ "gen_ai.request.model": "gpt-4o",
+ "gen_ai.response.text": "the model response",
+ "thread.id": mock.ANY,
+ "thread.name": mock.ANY,
+ }
+
+ param_id = request.node.callspec.id
+ if "string" in param_id and (
+ instructions is None or isinstance(instructions, Omit)
+ ): # type: ignore
+ expected_data.update(
+ {
+ "gen_ai.request.messages": safe_serialize(
+ ["How do I check if a Python object is an instance of a class?"]
+ ),
+ }
+ )
+ elif "string" in param_id:
+ expected_data.update(
+ {
+ "gen_ai.system_instructions": safe_serialize(
+ [
+ {
+ "type": "text",
+ "content": "You are a coding assistant that talks like a pirate.",
+ }
+ ]
+ ),
+ "gen_ai.request.messages": safe_serialize(
+ ["How do I check if a Python object is an instance of a class?"]
+ ),
+ }
+ )
+ elif "blocks_no_type" in param_id and (
+ instructions is None or isinstance(instructions, Omit)
+ ): # type: ignore
+ expected_data.update(
+ {
+ "gen_ai.system_instructions": safe_serialize(
+ [{"type": "text", "content": "You are a helpful assistant."}]
+ ),
+ "gen_ai.request.messages": safe_serialize(
+ [{"role": "user", "content": "hello"}]
+ ),
+ }
+ )
+ elif "blocks_no_type" in param_id:
+ expected_data.update(
+ {
+ "gen_ai.system_instructions": safe_serialize(
+ [
+ {
+ "type": "text",
+ "content": "You are a coding assistant that talks like a pirate.",
+ },
+ {"type": "text", "content": "You are a helpful assistant."},
+ ]
+ ),
+ "gen_ai.request.messages": safe_serialize(
+ [{"role": "user", "content": "hello"}]
+ ),
+ }
+ )
+ elif "blocks" in param_id and (
+ instructions is None or isinstance(instructions, Omit)
+ ): # type: ignore
+ expected_data.update(
+ {
+ "gen_ai.system_instructions": safe_serialize(
+ [{"type": "text", "content": "You are a helpful assistant."}]
+ ),
+ "gen_ai.request.messages": safe_serialize(
+ [{"type": "message", "role": "user", "content": "hello"}]
+ ),
+ }
+ )
+ elif "blocks" in param_id:
+ expected_data.update(
+ {
+ "gen_ai.system_instructions": safe_serialize(
+ [
+ {
+ "type": "text",
+ "content": "You are a coding assistant that talks like a pirate.",
+ },
+ {"type": "text", "content": "You are a helpful assistant."},
+ ]
+ ),
+ "gen_ai.request.messages": safe_serialize(
+ [{"type": "message", "role": "user", "content": "hello"}]
+ ),
+ }
+ )
+ elif "parts_no_type" in param_id and (
+ instructions is None or isinstance(instructions, Omit)
+ ): # type: ignore
+ expected_data.update(
+ {
+ "gen_ai.system_instructions": safe_serialize(
+ [
+ {"type": "text", "content": "You are a helpful assistant."},
+ {"type": "text", "content": "Be concise and clear."},
+ ]
+ ),
+ "gen_ai.request.messages": safe_serialize(
+ [{"role": "user", "content": "hello"}]
+ ),
+ }
+ )
+ elif "parts_no_type" in param_id:
+ expected_data.update(
+ {
+ "gen_ai.system_instructions": safe_serialize(
+ [
+ {
+ "type": "text",
+ "content": "You are a coding assistant that talks like a pirate.",
+ },
+ {"type": "text", "content": "You are a helpful assistant."},
+ {"type": "text", "content": "Be concise and clear."},
+ ]
+ ),
+ "gen_ai.request.messages": safe_serialize(
+ [{"role": "user", "content": "hello"}]
+ ),
+ }
+ )
+ elif instructions is None or isinstance(instructions, Omit): # type: ignore
+ expected_data.update(
+ {
+ "gen_ai.system_instructions": safe_serialize(
+ [
+ {"type": "text", "content": "You are a helpful assistant."},
+ {"type": "text", "content": "Be concise and clear."},
+ ]
+ ),
+ "gen_ai.request.messages": safe_serialize(
+ [{"type": "message", "role": "user", "content": "hello"}]
+ ),
+ }
+ )
+ else:
+ expected_data.update(
+ {
+ "gen_ai.system_instructions": safe_serialize(
+ [
+ {
+ "type": "text",
+ "content": "You are a coding assistant that talks like a pirate.",
+ },
+ {"type": "text", "content": "You are a helpful assistant."},
+ {"type": "text", "content": "Be concise and clear."},
+ ]
+ ),
+ "gen_ai.request.messages": safe_serialize(
+ [{"type": "message", "role": "user", "content": "hello"}]
+ ),
+ }
+ )
+
+ assert spans[0]["data"] == expected_data
+
+
+@pytest.mark.skipif(SKIP_RESPONSES_TESTS, reason="Responses API not available")
+def test_error_in_responses_api(sentry_init, capture_events):
+ sentry_init(
+ integrations=[OpenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ client = OpenAI(api_key="z")
+ client.responses._post = mock.Mock(
+ side_effect=OpenAIError("API rate limit reached")
+ )
+
+ with start_transaction(name="openai tx"):
+ with pytest.raises(OpenAIError):
+ client.responses.create(
+ model="gpt-4o",
+ instructions="You are a coding assistant that talks like a pirate.",
+ input="How do I check if a Python object is an instance of a class?",
+ )
+
+ (error_event, transaction_event) = events
+
+ assert transaction_event["type"] == "transaction"
+ # make sure the span where the error occurred is captured
+ assert transaction_event["spans"][0]["op"] == "gen_ai.responses"
+
+ assert error_event["level"] == "error"
+ assert error_event["exception"]["values"][0]["type"] == "OpenAIError"
+
+ assert (
+ error_event["contexts"]["trace"]["trace_id"]
+ == transaction_event["contexts"]["trace"]["trace_id"]
+ )
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(SKIP_RESPONSES_TESTS, reason="Responses API not available")
+@pytest.mark.parametrize(
+ "instructions",
+ (
+ omit,
+ None,
+ "You are a coding assistant that talks like a pirate.",
+ ),
+)
+@pytest.mark.parametrize(
+ "input",
+ [
+ pytest.param(
+ "How do I check if a Python object is an instance of a class?", id="string"
+ ),
+ pytest.param(
+ [
+ {
+ "role": "system",
+ "content": "You are a helpful assistant.",
+ },
+ {"role": "user", "content": "hello"},
+ ],
+ id="blocks_no_type",
+ ),
+ pytest.param(
+ [
+ {
+ "type": "message",
+ "role": "system",
+ "content": "You are a helpful assistant.",
+ },
+ {"type": "message", "role": "user", "content": "hello"},
+ ],
+ id="blocks",
+ ),
+ pytest.param(
+ [
+ {
+ "role": "system",
+ "content": [
+ {"type": "text", "text": "You are a helpful assistant."},
+ {"type": "text", "text": "Be concise and clear."},
+ ],
+ },
+ {"role": "user", "content": "hello"},
+ ],
+ id="parts_no_type",
+ ),
+ pytest.param(
+ [
+ {
+ "type": "message",
+ "role": "system",
+ "content": [
+ {"type": "text", "text": "You are a helpful assistant."},
+ {"type": "text", "text": "Be concise and clear."},
+ ],
+ },
+ {"type": "message", "role": "user", "content": "hello"},
+ ],
+ id="parts",
+ ),
+ ],
+)
+async def test_ai_client_span_responses_async_api(
+ sentry_init, capture_events, instructions, input, request
+):
+ sentry_init(
+ integrations=[OpenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ client = AsyncOpenAI(api_key="z")
+ client.responses._post = AsyncMock(return_value=EXAMPLE_RESPONSE)
+
+ with start_transaction(name="openai tx"):
+ await client.responses.create(
+ model="gpt-4o",
+ instructions=instructions,
+ input=input,
+ max_output_tokens=100,
+ temperature=0.7,
+ top_p=0.9,
+ )
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ assert len(spans) == 1
+ assert spans[0]["op"] == "gen_ai.responses"
+ assert spans[0]["origin"] == "auto.ai.openai"
+
+ expected_data = {
+ "gen_ai.operation.name": "responses",
+ "gen_ai.request.max_tokens": 100,
+ "gen_ai.request.temperature": 0.7,
+ "gen_ai.request.top_p": 0.9,
+ "gen_ai.request.messages": '["How do I check if a Python object is an instance of a class?"]',
+ "gen_ai.request.model": "gpt-4o",
+ "gen_ai.response.model": "response-model-id",
+ "gen_ai.response.streaming": False,
+ "gen_ai.system": "openai",
+ "gen_ai.usage.input_tokens": 20,
+ "gen_ai.usage.input_tokens.cached": 5,
+ "gen_ai.usage.output_tokens": 10,
+ "gen_ai.usage.output_tokens.reasoning": 8,
+ "gen_ai.usage.total_tokens": 30,
+ "gen_ai.response.text": "the model response",
+ "thread.id": mock.ANY,
+ "thread.name": mock.ANY,
+ }
+
+ param_id = request.node.callspec.id
+ if "string" in param_id and (
+ instructions is None or isinstance(instructions, Omit)
+ ): # type: ignore
+ expected_data.update(
+ {
+ "gen_ai.request.messages": safe_serialize(
+ ["How do I check if a Python object is an instance of a class?"]
+ ),
+ }
+ )
+ elif "string" in param_id:
+ expected_data.update(
+ {
+ "gen_ai.system_instructions": safe_serialize(
+ [
+ {
+ "type": "text",
+ "content": "You are a coding assistant that talks like a pirate.",
+ }
+ ]
+ ),
+ "gen_ai.request.messages": safe_serialize(
+ ["How do I check if a Python object is an instance of a class?"]
+ ),
+ }
+ )
+ elif "blocks_no_type" in param_id and (
+ instructions is None or isinstance(instructions, Omit)
+ ): # type: ignore
+ expected_data.update(
+ {
+ "gen_ai.system_instructions": safe_serialize(
+ [{"type": "text", "content": "You are a helpful assistant."}]
+ ),
+ "gen_ai.request.messages": safe_serialize(
+ [{"role": "user", "content": "hello"}]
+ ),
+ }
+ )
+ elif "blocks_no_type" in param_id:
+ expected_data.update(
+ {
+ "gen_ai.system_instructions": safe_serialize(
+ [
+ {
+ "type": "text",
+ "content": "You are a coding assistant that talks like a pirate.",
+ },
+ {"type": "text", "content": "You are a helpful assistant."},
+ ]
+ ),
+ "gen_ai.request.messages": safe_serialize(
+ [{"role": "user", "content": "hello"}]
+ ),
+ }
+ )
+ elif "blocks" in param_id and (
+ instructions is None or isinstance(instructions, Omit)
+ ): # type: ignore
+ expected_data.update(
+ {
+ "gen_ai.system_instructions": safe_serialize(
+ [{"type": "text", "content": "You are a helpful assistant."}]
+ ),
+ "gen_ai.request.messages": safe_serialize(
+ [{"type": "message", "role": "user", "content": "hello"}]
+ ),
+ }
+ )
+ elif "blocks" in param_id:
+ expected_data.update(
+ {
+ "gen_ai.system_instructions": safe_serialize(
+ [
+ {
+ "type": "text",
+ "content": "You are a coding assistant that talks like a pirate.",
+ },
+ {"type": "text", "content": "You are a helpful assistant."},
+ ]
+ ),
+ "gen_ai.request.messages": safe_serialize(
+ [{"type": "message", "role": "user", "content": "hello"}]
+ ),
+ }
+ )
+ elif "parts_no_type" in param_id and (
+ instructions is None or isinstance(instructions, Omit)
+ ): # type: ignore
+ expected_data.update(
+ {
+ "gen_ai.system_instructions": safe_serialize(
+ [
+ {"type": "text", "content": "You are a helpful assistant."},
+ {"type": "text", "content": "Be concise and clear."},
+ ]
+ ),
+ "gen_ai.request.messages": safe_serialize(
+ [{"role": "user", "content": "hello"}]
+ ),
+ }
+ )
+ elif "parts_no_type" in param_id:
+ expected_data.update(
+ {
+ "gen_ai.system_instructions": safe_serialize(
+ [
+ {
+ "type": "text",
+ "content": "You are a coding assistant that talks like a pirate.",
+ },
+ {"type": "text", "content": "You are a helpful assistant."},
+ {"type": "text", "content": "Be concise and clear."},
+ ]
+ ),
+ "gen_ai.request.messages": safe_serialize(
+ [{"role": "user", "content": "hello"}]
+ ),
+ }
+ )
+ elif instructions is None or isinstance(instructions, Omit): # type: ignore
+ expected_data.update(
+ {
+ "gen_ai.system_instructions": safe_serialize(
+ [
+ {"type": "text", "content": "You are a helpful assistant."},
+ {"type": "text", "content": "Be concise and clear."},
+ ]
+ ),
+ "gen_ai.request.messages": safe_serialize(
+ [{"type": "message", "role": "user", "content": "hello"}]
+ ),
+ }
+ )
+ else:
+ expected_data.update(
+ {
+ "gen_ai.system_instructions": safe_serialize(
+ [
+ {
+ "type": "text",
+ "content": "You are a coding assistant that talks like a pirate.",
+ },
+ {"type": "text", "content": "You are a helpful assistant."},
+ {"type": "text", "content": "Be concise and clear."},
+ ]
+ ),
+ "gen_ai.request.messages": safe_serialize(
+ [{"type": "message", "role": "user", "content": "hello"}]
+ ),
+ }
+ )
+
+ assert spans[0]["data"] == expected_data
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "instructions",
+ (
+ omit,
+ None,
+ "You are a coding assistant that talks like a pirate.",
+ ),
+)
+@pytest.mark.parametrize(
+ "input",
+ [
+ pytest.param(
+ "How do I check if a Python object is an instance of a class?", id="string"
+ ),
+ pytest.param(
+ [
+ {
+ "role": "system",
+ "content": "You are a helpful assistant.",
+ },
+ {"role": "user", "content": "hello"},
+ ],
+ id="blocks_no_type",
+ ),
+ pytest.param(
+ [
+ {
+ "type": "message",
+ "role": "system",
+ "content": "You are a helpful assistant.",
+ },
+ {"type": "message", "role": "user", "content": "hello"},
+ ],
+ id="blocks",
+ ),
+ pytest.param(
+ [
+ {
+ "role": "system",
+ "content": [
+ {"type": "text", "text": "You are a helpful assistant."},
+ {"type": "text", "text": "Be concise and clear."},
+ ],
+ },
+ {"role": "user", "content": "hello"},
+ ],
+ id="parts_no_type",
+ ),
+ pytest.param(
+ [
+ {
+ "type": "message",
+ "role": "system",
+ "content": [
+ {"type": "text", "text": "You are a helpful assistant."},
+ {"type": "text", "text": "Be concise and clear."},
+ ],
+ },
+ {"type": "message", "role": "user", "content": "hello"},
+ ],
+ id="parts",
+ ),
+ ],
+)
+@pytest.mark.skipif(SKIP_RESPONSES_TESTS, reason="Responses API not available")
+async def test_ai_client_span_streaming_responses_async_api(
+ sentry_init,
+ capture_events,
+ instructions,
+ input,
+ request,
+ get_model_response,
+ async_iterator,
+ server_side_event_chunks,
+):
+ sentry_init(
+ integrations=[OpenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ client = AsyncOpenAI(api_key="z")
+ returned_stream = get_model_response(
+ async_iterator(server_side_event_chunks(EXAMPLE_RESPONSES_STREAM))
+ )
+
+ with mock.patch.object(
+ client.responses._client._client,
+ "send",
+ return_value=returned_stream,
+ ):
+ with start_transaction(name="openai tx"):
+ result = await client.responses.create(
+ model="gpt-4o",
+ instructions=instructions,
+ input=input,
+ stream=True,
+ max_output_tokens=100,
+ temperature=0.7,
+ top_p=0.9,
+ )
+ async for _ in result:
+ pass
+
+ (transaction,) = events
+ spans = [span for span in transaction["spans"] if span["op"] == OP.GEN_AI_RESPONSES]
+
+ assert len(spans) == 1
+ assert spans[0]["origin"] == "auto.ai.openai"
+
+ expected_data = {
+ "gen_ai.operation.name": "responses",
+ "gen_ai.request.max_tokens": 100,
+ "gen_ai.request.temperature": 0.7,
+ "gen_ai.request.top_p": 0.9,
+ "gen_ai.response.model": "response-model-id",
+ "gen_ai.response.streaming": True,
+ "gen_ai.system": "openai",
+ "gen_ai.response.time_to_first_token": mock.ANY,
+ "gen_ai.usage.input_tokens": 20,
+ "gen_ai.usage.input_tokens.cached": 5,
+ "gen_ai.usage.output_tokens": 10,
+ "gen_ai.usage.output_tokens.reasoning": 8,
+ "gen_ai.usage.total_tokens": 30,
+ "gen_ai.request.model": "gpt-4o",
+ "gen_ai.response.text": "hello world",
+ "thread.id": mock.ANY,
+ "thread.name": mock.ANY,
+ }
+
+ param_id = request.node.callspec.id
+ if "string" in param_id and (
+ instructions is None or isinstance(instructions, Omit)
+ ): # type: ignore
+ expected_data.update(
+ {
+ "gen_ai.request.messages": safe_serialize(
+ ["How do I check if a Python object is an instance of a class?"]
+ ),
+ }
+ )
+ elif "string" in param_id:
+ expected_data.update(
+ {
+ "gen_ai.system_instructions": safe_serialize(
+ [
+ {
+ "type": "text",
+ "content": "You are a coding assistant that talks like a pirate.",
+ }
+ ]
+ ),
+ "gen_ai.request.messages": safe_serialize(
+ ["How do I check if a Python object is an instance of a class?"]
+ ),
+ }
+ )
+ elif "blocks_no_type" in param_id and (
+ instructions is None or isinstance(instructions, Omit)
+ ): # type: ignore
+ expected_data.update(
+ {
+ "gen_ai.system_instructions": safe_serialize(
+ [{"type": "text", "content": "You are a helpful assistant."}]
+ ),
+ "gen_ai.request.messages": safe_serialize(
+ [{"role": "user", "content": "hello"}]
+ ),
+ }
+ )
+ elif "blocks_no_type" in param_id:
+ expected_data.update(
+ {
+ "gen_ai.system_instructions": safe_serialize(
+ [
+ {
+ "type": "text",
+ "content": "You are a coding assistant that talks like a pirate.",
+ },
+ {"type": "text", "content": "You are a helpful assistant."},
+ ]
+ ),
+ "gen_ai.request.messages": safe_serialize(
+ [{"role": "user", "content": "hello"}]
+ ),
+ }
+ )
+ elif "blocks" in param_id and (
+ instructions is None or isinstance(instructions, Omit)
+ ): # type: ignore
+ expected_data.update(
+ {
+ "gen_ai.system_instructions": safe_serialize(
+ [{"type": "text", "content": "You are a helpful assistant."}]
+ ),
+ "gen_ai.request.messages": safe_serialize(
+ [{"type": "message", "role": "user", "content": "hello"}]
+ ),
+ }
+ )
+ elif "blocks" in param_id:
+ expected_data.update(
+ {
+ "gen_ai.system_instructions": safe_serialize(
+ [
+ {
+ "type": "text",
+ "content": "You are a coding assistant that talks like a pirate.",
+ },
+ {"type": "text", "content": "You are a helpful assistant."},
+ ]
+ ),
+ "gen_ai.request.messages": safe_serialize(
+ [{"type": "message", "role": "user", "content": "hello"}]
+ ),
+ }
+ )
+ elif "parts_no_type" in param_id and (
+ instructions is None or isinstance(instructions, Omit)
+ ): # type: ignore
+ expected_data.update(
+ {
+ "gen_ai.system_instructions": safe_serialize(
+ [
+ {"type": "text", "content": "You are a helpful assistant."},
+ {"type": "text", "content": "Be concise and clear."},
+ ]
+ ),
+ "gen_ai.request.messages": safe_serialize(
+ [{"role": "user", "content": "hello"}]
+ ),
+ }
+ )
+ elif "parts_no_type" in param_id:
+ expected_data.update(
+ {
+ "gen_ai.system_instructions": safe_serialize(
+ [
+ {
+ "type": "text",
+ "content": "You are a coding assistant that talks like a pirate.",
+ },
+ {"type": "text", "content": "You are a helpful assistant."},
+ {"type": "text", "content": "Be concise and clear."},
+ ]
+ ),
+ "gen_ai.request.messages": safe_serialize(
+ [{"role": "user", "content": "hello"}]
+ ),
+ }
+ )
+ elif instructions is None or isinstance(instructions, Omit): # type: ignore
+ expected_data.update(
+ {
+ "gen_ai.system_instructions": safe_serialize(
+ [
+ {"type": "text", "content": "You are a helpful assistant."},
+ {"type": "text", "content": "Be concise and clear."},
+ ]
+ ),
+ "gen_ai.request.messages": safe_serialize(
+ [{"type": "message", "role": "user", "content": "hello"}]
+ ),
+ }
+ )
+ else:
+ expected_data.update(
+ {
+ "gen_ai.system_instructions": safe_serialize(
+ [
+ {
+ "type": "text",
+ "content": "You are a coding assistant that talks like a pirate.",
+ },
+ {"type": "text", "content": "You are a helpful assistant."},
+ {"type": "text", "content": "Be concise and clear."},
+ ]
+ ),
+ "gen_ai.request.messages": safe_serialize(
+ [{"type": "message", "role": "user", "content": "hello"}]
+ ),
+ }
+ )
+
+ assert spans[0]["data"] == expected_data
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(SKIP_RESPONSES_TESTS, reason="Responses API not available")
+async def test_error_in_responses_async_api(sentry_init, capture_events):
+ sentry_init(
+ integrations=[OpenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ client = AsyncOpenAI(api_key="z")
+ client.responses._post = AsyncMock(
+ side_effect=OpenAIError("API rate limit reached")
+ )
+
+ with start_transaction(name="openai tx"):
+ with pytest.raises(OpenAIError):
+ await client.responses.create(
+ model="gpt-4o",
+ instructions="You are a coding assistant that talks like a pirate.",
+ input="How do I check if a Python object is an instance of a class?",
+ )
+
+ (error_event, transaction_event) = events
+
+ assert transaction_event["type"] == "transaction"
+ # make sure the span where the error occurred is captured
+ assert transaction_event["spans"][0]["op"] == "gen_ai.responses"
+
+ assert error_event["level"] == "error"
+ assert error_event["exception"]["values"][0]["type"] == "OpenAIError"
+
+ assert (
+ error_event["contexts"]["trace"]["trace_id"]
+ == transaction_event["contexts"]["trace"]["trace_id"]
+ )
+
+
+if SKIP_RESPONSES_TESTS:
+ EXAMPLE_RESPONSES_STREAM = []
+else:
+ EXAMPLE_RESPONSES_STREAM = [
+ ResponseCreatedEvent(
+ sequence_number=1,
+ type="response.created",
+ response=Response(
+ id="chat-id",
+ created_at=10000000,
+ model="response-model-id",
+ object="response",
+ output=[],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ ),
+ ),
+ ResponseTextDeltaEvent(
+ item_id="msg_1",
+ sequence_number=2,
+ type="response.output_text.delta",
+ logprobs=[],
+ content_index=0,
+ output_index=0,
+ delta="hel",
+ ),
+ ResponseTextDeltaEvent(
+ item_id="msg_1",
+ sequence_number=3,
+ type="response.output_text.delta",
+ logprobs=[],
+ content_index=0,
+ output_index=0,
+ delta="lo ",
+ ),
+ ResponseTextDeltaEvent(
+ item_id="msg_1",
+ sequence_number=4,
+ type="response.output_text.delta",
+ logprobs=[],
+ content_index=0,
+ output_index=0,
+ delta="world",
+ ),
+ ResponseCompletedEvent(
+ sequence_number=5,
+ type="response.completed",
+ response=Response(
+ id="chat-id",
+ created_at=10000000,
+ model="response-model-id",
+ object="response",
+ output=[],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ usage=ResponseUsage(
+ input_tokens=20,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=5,
+ ),
+ output_tokens=10,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=8,
+ ),
+ total_tokens=30,
+ ),
+ ),
+ ),
+ ]
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [(True, True), (True, False), (False, True), (False, False)],
+)
+@pytest.mark.skipif(SKIP_RESPONSES_TESTS, reason="Responses API not available")
+def test_streaming_responses_api(
+ sentry_init,
+ capture_events,
+ send_default_pii,
+ include_prompts,
+ get_model_response,
+ server_side_event_chunks,
+):
+ sentry_init(
+ integrations=[
+ OpenAIIntegration(
+ include_prompts=include_prompts,
+ )
+ ],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ client = OpenAI(api_key="z")
+ returned_stream = get_model_response(
+ server_side_event_chunks(
+ EXAMPLE_RESPONSES_STREAM,
+ )
+ )
+
+ with mock.patch.object(
+ client.responses._client._client,
+ "send",
+ return_value=returned_stream,
+ ):
+ with start_transaction(name="openai tx"):
+ response_stream = client.responses.create(
+ model="some-model",
+ input="hello",
+ stream=True,
+ max_output_tokens=100,
+ temperature=0.7,
+ top_p=0.9,
+ )
+
+ response_string = ""
+ for item in response_stream:
+ if hasattr(item, "delta"):
+ response_string += item.delta
+
+ assert response_string == "hello world"
+
+ (transaction,) = events
+ (span,) = transaction["spans"]
+ assert span["op"] == "gen_ai.responses"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "openai"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MAX_TOKENS] == 100
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_TEMPERATURE] == 0.7
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_TOP_P] == 0.9
+
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_MODEL] == "response-model-id"
+
+ if send_default_pii and include_prompts:
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES] == '["hello"]'
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] == "hello world"
+ else:
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]
+
+ assert span["data"]["gen_ai.usage.input_tokens"] == 20
+ assert span["data"]["gen_ai.usage.output_tokens"] == 10
+ assert span["data"]["gen_ai.usage.total_tokens"] == 30
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [(True, True), (True, False), (False, True), (False, False)],
+)
+@pytest.mark.skipif(SKIP_RESPONSES_TESTS, reason="Responses API not available")
+async def test_streaming_responses_api_async(
+ sentry_init,
+ capture_events,
+ send_default_pii,
+ include_prompts,
+ get_model_response,
+ async_iterator,
+ server_side_event_chunks,
+):
+ sentry_init(
+ integrations=[
+ OpenAIIntegration(
+ include_prompts=include_prompts,
+ )
+ ],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ client = AsyncOpenAI(api_key="z")
+ returned_stream = get_model_response(
+ async_iterator(server_side_event_chunks(EXAMPLE_RESPONSES_STREAM))
+ )
+
+ with mock.patch.object(
+ client.responses._client._client,
+ "send",
+ return_value=returned_stream,
+ ):
+ with start_transaction(name="openai tx"):
+ response_stream = await client.responses.create(
+ model="some-model",
+ input="hello",
+ stream=True,
+ max_output_tokens=100,
+ temperature=0.7,
+ top_p=0.9,
+ )
+
+ response_string = ""
+ async for item in response_stream:
+ if hasattr(item, "delta"):
+ response_string += item.delta
+
+ assert response_string == "hello world"
+
+ (transaction,) = events
+ (span,) = transaction["spans"]
+ assert span["op"] == "gen_ai.responses"
+ assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "openai"
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MAX_TOKENS] == 100
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_TEMPERATURE] == 0.7
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_TOP_P] == 0.9
+
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_MODEL] == "response-model-id"
+
+ if send_default_pii and include_prompts:
+ assert span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES] == '["hello"]'
+ assert span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] == "hello world"
+ else:
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
+ assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]
+
+ assert span["data"]["gen_ai.usage.input_tokens"] == 20
+ assert span["data"]["gen_ai.usage.output_tokens"] == 10
+ assert span["data"]["gen_ai.usage.total_tokens"] == 30
+
+
+@pytest.mark.skipif(
+ OPENAI_VERSION <= (1, 1, 0),
+ reason="OpenAI versions <=1.1.0 do not support the tools parameter.",
+)
+@pytest.mark.parametrize(
+ "tools",
+ [[], None, NOT_GIVEN, omit],
+)
+def test_empty_tools_in_chat_completion(sentry_init, capture_events, tools):
+ sentry_init(
+ integrations=[OpenAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client = OpenAI(api_key="z")
+ client.chat.completions._post = mock.Mock(return_value=EXAMPLE_CHAT_COMPLETION)
+
+ with start_transaction(name="openai tx"):
+ client.chat.completions.create(
+ model="some-model",
+ messages=[{"role": "system", "content": "hello"}],
+ tools=tools,
+ )
+
+ (event,) = events
+ span = event["spans"][0]
+
+ assert "gen_ai.request.available_tools" not in span["data"]
+
+
+# Test messages with mixed roles including "ai" that should be mapped to "assistant"
+@pytest.mark.parametrize(
+ "test_message,expected_role",
+ [
+ ({"role": "user", "content": "Hello"}, "user"),
+ (
+ {"role": "ai", "content": "Hi there!"},
+ "assistant",
+ ), # Should be mapped to "assistant"
+ (
+ {"role": "assistant", "content": "How can I help?"},
+ "assistant",
+ ), # Should stay "assistant"
+ ],
+)
+def test_openai_message_role_mapping(
+ sentry_init, capture_events, test_message, expected_role
+):
+ """Test that OpenAI integration properly maps message roles like 'ai' to 'assistant'"""
+
+ sentry_init(
+ integrations=[OpenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ client = OpenAI(api_key="z")
+ client.chat.completions._post = mock.Mock(return_value=EXAMPLE_CHAT_COMPLETION)
+
+ test_messages = [test_message]
+
+ with start_transaction(name="openai tx"):
+ client.chat.completions.create(model="test-model", messages=test_messages)
+ # Verify that the span was created correctly
+ (event,) = events
+ span = event["spans"][0]
+ assert span["op"] == "gen_ai.chat"
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES in span["data"]
+
+ # Parse the stored messages
+ import json
+
+ stored_messages = json.loads(span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES])
+
+ assert len(stored_messages) == 1
+ assert stored_messages[0]["role"] == expected_role
+
+
+def test_openai_message_truncation(sentry_init, capture_events):
+ """Test that large messages are truncated properly in OpenAI integration."""
+ sentry_init(
+ integrations=[OpenAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ client = OpenAI(api_key="z")
+ client.chat.completions._post = mock.Mock(return_value=EXAMPLE_CHAT_COMPLETION)
+
+ large_content = (
+ "This is a very long message that will exceed our size limits. " * 1000
+ )
+ large_messages = [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": large_content},
+ {"role": "assistant", "content": large_content},
+ {"role": "user", "content": large_content},
+ ]
+
+ with start_transaction(name="openai tx"):
+ client.chat.completions.create(
+ model="some-model",
+ messages=large_messages,
+ )
+
+ (event,) = events
+ span = event["spans"][0]
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES in span["data"]
+
+ messages_data = span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ assert isinstance(messages_data, str)
+
+ parsed_messages = json.loads(messages_data)
+ assert isinstance(parsed_messages, list)
+ assert len(parsed_messages) <= len(large_messages)
+
+ meta_path = event["_meta"]
+ span_meta = meta_path["spans"]["0"]["data"]
+ messages_meta = span_meta[SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ assert "len" in messages_meta.get("", {})
+
+
+# noinspection PyTypeChecker
+def test_streaming_chat_completion_ttft(
+ sentry_init, capture_events, get_model_response, server_side_event_chunks
+):
+ """
+ Test that streaming chat completions capture time-to-first-token (TTFT).
+ """
+ sentry_init(
+ integrations=[OpenAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client = OpenAI(api_key="z")
+ returned_stream = get_model_response(
+ server_side_event_chunks(
+ [
+ ChatCompletionChunk(
+ id="1",
+ choices=[
+ DeltaChoice(
+ index=0,
+ delta=ChoiceDelta(content="Hello"),
+ finish_reason=None,
+ )
+ ],
+ created=100000,
+ model="model-id",
+ object="chat.completion.chunk",
+ ),
+ ChatCompletionChunk(
+ id="1",
+ choices=[
+ DeltaChoice(
+ index=0,
+ delta=ChoiceDelta(content=" world"),
+ finish_reason="stop",
+ )
+ ],
+ created=100000,
+ model="model-id",
+ object="chat.completion.chunk",
+ ),
+ ],
+ include_event_type=False,
+ ),
+ )
+
+ with mock.patch.object(
+ client.chat._client._client,
+ "send",
+ return_value=returned_stream,
+ ):
+ with start_transaction(name="openai tx"):
+ response_stream = client.chat.completions.create(
+ model="some-model",
+ messages=[{"role": "user", "content": "Say hello"}],
+ stream=True,
+ )
+ # Consume the stream
+ for _ in response_stream:
+ pass
+
+ (tx,) = events
+ span = tx["spans"][0]
+ assert span["op"] == "gen_ai.chat"
+
+ # Verify TTFT is captured
+ assert SPANDATA.GEN_AI_RESPONSE_TIME_TO_FIRST_TOKEN in span["data"]
+ ttft = span["data"][SPANDATA.GEN_AI_RESPONSE_TIME_TO_FIRST_TOKEN]
+ assert isinstance(ttft, float)
+ assert ttft > 0
+
+
+# noinspection PyTypeChecker
+@pytest.mark.asyncio
+async def test_streaming_chat_completion_ttft_async(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ async_iterator,
+ server_side_event_chunks,
+):
+ """
+ Test that async streaming chat completions capture time-to-first-token (TTFT).
+ """
+ sentry_init(
+ integrations=[OpenAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client = AsyncOpenAI(api_key="z")
+ returned_stream = get_model_response(
+ async_iterator(
+ server_side_event_chunks(
+ [
+ ChatCompletionChunk(
+ id="1",
+ choices=[
+ DeltaChoice(
+ index=0,
+ delta=ChoiceDelta(content="Hello"),
+ finish_reason=None,
+ )
+ ],
+ created=100000,
+ model="model-id",
+ object="chat.completion.chunk",
+ ),
+ ChatCompletionChunk(
+ id="1",
+ choices=[
+ DeltaChoice(
+ index=0,
+ delta=ChoiceDelta(content=" world"),
+ finish_reason="stop",
+ )
+ ],
+ created=100000,
+ model="model-id",
+ object="chat.completion.chunk",
+ ),
+ ],
+ include_event_type=False,
+ ),
+ )
+ )
+
+ with mock.patch.object(
+ client.chat._client._client,
+ "send",
+ return_value=returned_stream,
+ ):
+ with start_transaction(name="openai tx"):
+ response_stream = await client.chat.completions.create(
+ model="some-model",
+ messages=[{"role": "user", "content": "Say hello"}],
+ stream=True,
+ )
+ # Consume the stream
+ async for _ in response_stream:
+ pass
+
+ (tx,) = events
+ span = tx["spans"][0]
+ assert span["op"] == "gen_ai.chat"
+
+ # Verify TTFT is captured
+ assert SPANDATA.GEN_AI_RESPONSE_TIME_TO_FIRST_TOKEN in span["data"]
+ ttft = span["data"][SPANDATA.GEN_AI_RESPONSE_TIME_TO_FIRST_TOKEN]
+ assert isinstance(ttft, float)
+ assert ttft > 0
+
+
+# noinspection PyTypeChecker
+@pytest.mark.skipif(SKIP_RESPONSES_TESTS, reason="Responses API not available")
+def test_streaming_responses_api_ttft(
+ sentry_init, capture_events, get_model_response, server_side_event_chunks
+):
+ """
+ Test that streaming responses API captures time-to-first-token (TTFT).
+ """
+ sentry_init(
+ integrations=[OpenAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client = OpenAI(api_key="z")
+ returned_stream = get_model_response(
+ server_side_event_chunks(EXAMPLE_RESPONSES_STREAM)
+ )
+
+ with mock.patch.object(
+ client.responses._client._client,
+ "send",
+ return_value=returned_stream,
+ ):
+ with start_transaction(name="openai tx"):
+ response_stream = client.responses.create(
+ model="some-model",
+ input="hello",
+ stream=True,
+ )
+ # Consume the stream
+ for _ in response_stream:
+ pass
+
+ (tx,) = events
+ span = tx["spans"][0]
+ assert span["op"] == "gen_ai.responses"
+
+ # Verify TTFT is captured
+ assert SPANDATA.GEN_AI_RESPONSE_TIME_TO_FIRST_TOKEN in span["data"]
+ ttft = span["data"][SPANDATA.GEN_AI_RESPONSE_TIME_TO_FIRST_TOKEN]
+ assert isinstance(ttft, float)
+ assert ttft > 0
+
+
+# noinspection PyTypeChecker
+@pytest.mark.asyncio
+@pytest.mark.skipif(SKIP_RESPONSES_TESTS, reason="Responses API not available")
+async def test_streaming_responses_api_ttft_async(
+ sentry_init,
+ capture_events,
+ get_model_response,
+ async_iterator,
+ server_side_event_chunks,
+):
+ """
+ Test that async streaming responses API captures time-to-first-token (TTFT).
+ """
+ sentry_init(
+ integrations=[OpenAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client = AsyncOpenAI(api_key="z")
+ returned_stream = get_model_response(
+ async_iterator(server_side_event_chunks(EXAMPLE_RESPONSES_STREAM))
+ )
+
+ with mock.patch.object(
+ client.responses._client._client,
+ "send",
+ return_value=returned_stream,
+ ):
+ with start_transaction(name="openai tx"):
+ response_stream = await client.responses.create(
+ model="some-model",
+ input="hello",
+ stream=True,
+ )
+ # Consume the stream
+ async for _ in response_stream:
+ pass
+
+ (tx,) = events
+ span = tx["spans"][0]
+ assert span["op"] == "gen_ai.responses"
+
+ # Verify TTFT is captured
+ assert SPANDATA.GEN_AI_RESPONSE_TIME_TO_FIRST_TOKEN in span["data"]
+ ttft = span["data"][SPANDATA.GEN_AI_RESPONSE_TIME_TO_FIRST_TOKEN]
+ assert isinstance(ttft, float)
+ assert ttft > 0
diff --git a/tests/integrations/openai_agents/__init__.py b/tests/integrations/openai_agents/__init__.py
new file mode 100644
index 0000000000..6940e2bbbe
--- /dev/null
+++ b/tests/integrations/openai_agents/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("agents")
diff --git a/tests/integrations/openai_agents/test_openai_agents.py b/tests/integrations/openai_agents/test_openai_agents.py
new file mode 100644
index 0000000000..7310e86df5
--- /dev/null
+++ b/tests/integrations/openai_agents/test_openai_agents.py
@@ -0,0 +1,3479 @@
+import asyncio
+import pytest
+from unittest.mock import MagicMock, patch
+import os
+import json
+import logging
+import httpx
+
+import sentry_sdk
+from sentry_sdk import start_span
+from sentry_sdk.consts import SPANDATA, OP
+from sentry_sdk.integrations.logging import LoggingIntegration
+from sentry_sdk.integrations.openai_agents import OpenAIAgentsIntegration
+from sentry_sdk.integrations.openai_agents.utils import _set_input_data, safe_serialize
+from sentry_sdk.utils import parse_version, package_version
+
+from openai import AsyncOpenAI, InternalServerError
+from agents.models.openai_responses import OpenAIResponsesModel
+
+from unittest import mock
+
+import agents
+from agents import (
+ Agent,
+ ModelResponse,
+ Usage,
+ ModelSettings,
+)
+from agents.items import (
+ McpCall,
+ ResponseOutputMessage,
+ ResponseOutputText,
+ ResponseFunctionToolCall,
+)
+from agents.tool import HostedMCPTool
+from agents.exceptions import MaxTurnsExceeded, ModelBehaviorError
+from agents.version import __version__ as OPENAI_AGENTS_VERSION
+
+OPENAI_VERSION = package_version("openai")
+
+from openai.types.responses import (
+ ResponseCreatedEvent,
+ ResponseTextDeltaEvent,
+ ResponseCompletedEvent,
+ Response,
+ ResponseUsage,
+)
+from openai.types.responses.response_usage import (
+ InputTokensDetails,
+ OutputTokensDetails,
+)
+
+
+test_run_config = agents.RunConfig(tracing_disabled=True)
+
+EXAMPLE_RESPONSE = Response(
+ id="chat-id",
+ output=[
+ ResponseOutputMessage(
+ id="message-id",
+ content=[
+ ResponseOutputText(
+ annotations=[],
+ text="the model response",
+ type="output_text",
+ ),
+ ],
+ role="assistant",
+ status="completed",
+ type="message",
+ ),
+ ],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model="response-model-id",
+ object="response",
+ usage=ResponseUsage(
+ input_tokens=20,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=5,
+ ),
+ output_tokens=10,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=8,
+ ),
+ total_tokens=30,
+ ),
+)
+
+
+@pytest.fixture
+def mock_usage():
+ return Usage(
+ requests=1,
+ input_tokens=10,
+ output_tokens=20,
+ total_tokens=30,
+ input_tokens_details=InputTokensDetails(cached_tokens=0),
+ output_tokens_details=OutputTokensDetails(reasoning_tokens=5),
+ )
+
+
+@pytest.fixture
+def test_agent():
+ """Create a real Agent instance for testing."""
+ return Agent(
+ name="test_agent",
+ instructions="You are a helpful test assistant.",
+ model="gpt-4",
+ model_settings=ModelSettings(
+ max_tokens=100,
+ temperature=0.7,
+ top_p=1.0,
+ presence_penalty=0.0,
+ frequency_penalty=0.0,
+ ),
+ )
+
+
+@pytest.fixture
+def test_agent_with_instructions():
+ def inner(instructions):
+ """Create a real Agent instance for testing."""
+ return Agent(
+ name="test_agent",
+ instructions=instructions,
+ model="gpt-4",
+ model_settings=ModelSettings(
+ max_tokens=100,
+ temperature=0.7,
+ top_p=1.0,
+ presence_penalty=0.0,
+ frequency_penalty=0.0,
+ ),
+ )
+
+ return inner
+
+
+@pytest.fixture
+def test_agent_custom_model():
+ """Create a real Agent instance for testing."""
+ return Agent(
+ name="test_agent_custom_model",
+ instructions="You are a helpful test assistant.",
+ # the model could be agents.OpenAIChatCompletionsModel()
+ model="my-custom-model",
+ model_settings=ModelSettings(
+ max_tokens=100,
+ temperature=0.7,
+ top_p=1.0,
+ presence_penalty=0.0,
+ frequency_penalty=0.0,
+ ),
+ )
+
+
+@pytest.mark.asyncio
+async def test_agent_invocation_span_no_pii(
+ sentry_init,
+ capture_events,
+ test_agent,
+ nonstreaming_responses_model_response,
+ get_model_response,
+):
+ client = AsyncOpenAI(api_key="test-key")
+ model = OpenAIResponsesModel(model="gpt-4", openai_client=client)
+ agent = test_agent.clone(model=model)
+
+ response = get_model_response(
+ nonstreaming_responses_model_response, serialize_pydantic=True
+ )
+
+ with patch.object(
+ agent.model._client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=False,
+ )
+
+ events = capture_events()
+
+ result = await agents.Runner.run(
+ agent, "Test input", run_config=test_run_config
+ )
+
+ assert result is not None
+ assert result.final_output == "Hello, how can I help you?"
+
+ (transaction,) = events
+ spans = transaction["spans"]
+ invoke_agent_span = next(
+ span for span in spans if span["op"] == OP.GEN_AI_INVOKE_AGENT
+ )
+ ai_client_span = next(span for span in spans if span["op"] == OP.GEN_AI_CHAT)
+
+ assert transaction["transaction"] == "test_agent workflow"
+ assert transaction["contexts"]["trace"]["origin"] == "auto.ai.openai_agents"
+
+ assert invoke_agent_span["description"] == "invoke_agent test_agent"
+
+ assert SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS not in invoke_agent_span["data"]
+ assert "gen_ai.request.messages" not in invoke_agent_span["data"]
+ assert "gen_ai.response.text" not in invoke_agent_span["data"]
+
+ assert invoke_agent_span["data"]["gen_ai.operation.name"] == "invoke_agent"
+ assert invoke_agent_span["data"]["gen_ai.system"] == "openai"
+ assert invoke_agent_span["data"]["gen_ai.agent.name"] == "test_agent"
+ assert invoke_agent_span["data"]["gen_ai.request.max_tokens"] == 100
+ assert invoke_agent_span["data"]["gen_ai.request.model"] == "gpt-4"
+ assert invoke_agent_span["data"]["gen_ai.request.temperature"] == 0.7
+ assert invoke_agent_span["data"]["gen_ai.request.top_p"] == 1.0
+
+ assert ai_client_span["description"] == "chat gpt-4"
+ assert ai_client_span["data"]["gen_ai.operation.name"] == "chat"
+ assert ai_client_span["data"]["gen_ai.system"] == "openai"
+ assert ai_client_span["data"]["gen_ai.agent.name"] == "test_agent"
+ assert ai_client_span["data"]["gen_ai.request.max_tokens"] == 100
+ assert ai_client_span["data"]["gen_ai.request.model"] == "gpt-4"
+ assert ai_client_span["data"]["gen_ai.request.temperature"] == 0.7
+ assert ai_client_span["data"]["gen_ai.request.top_p"] == 1.0
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "instructions",
+ (
+ None,
+ "You are a coding assistant that talks like a pirate.",
+ ),
+)
+@pytest.mark.parametrize(
+ "input",
+ [
+ pytest.param("Test input", id="string"),
+ pytest.param(
+ [
+ {
+ "role": "system",
+ "content": "You are a helpful assistant.",
+ },
+ {
+ "role": "user",
+ "content": "Test input",
+ },
+ ],
+ id="blocks_no_type",
+ ),
+ pytest.param(
+ [
+ {
+ "type": "message",
+ "role": "system",
+ "content": "You are a helpful assistant.",
+ },
+ {
+ "type": "message",
+ "role": "user",
+ "content": "Test input",
+ },
+ ],
+ id="blocks",
+ ),
+ pytest.param(
+ [
+ {
+ "role": "system",
+ "content": [
+ {"type": "text", "text": "You are a helpful assistant."},
+ {"type": "text", "text": "Be concise and clear."},
+ ],
+ },
+ {
+ "role": "user",
+ "content": "Test input",
+ },
+ ],
+ id="parts_no_type",
+ ),
+ pytest.param(
+ [
+ {
+ "type": "message",
+ "role": "system",
+ "content": [
+ {"type": "text", "text": "You are a helpful assistant."},
+ {"type": "text", "text": "Be concise and clear."},
+ ],
+ },
+ {
+ "type": "message",
+ "role": "user",
+ "content": "Test input",
+ },
+ ],
+ id="parts",
+ ),
+ ],
+)
+async def test_agent_invocation_span(
+ sentry_init,
+ capture_events,
+ test_agent_with_instructions,
+ nonstreaming_responses_model_response,
+ instructions,
+ input,
+ request,
+ get_model_response,
+):
+ """
+ Test that the integration creates spans for agent invocations.
+ """
+ client = AsyncOpenAI(api_key="test-key")
+ model = OpenAIResponsesModel(model="gpt-4", openai_client=client)
+ agent = test_agent_with_instructions(instructions).clone(model=model)
+
+ response = get_model_response(
+ nonstreaming_responses_model_response, serialize_pydantic=True
+ )
+
+ with patch.object(
+ agent.model._client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+
+ result = await agents.Runner.run(
+ agent,
+ input,
+ run_config=test_run_config,
+ )
+
+ assert result is not None
+ assert result.final_output == "Hello, how can I help you?"
+
+ (transaction,) = events
+ spans = transaction["spans"]
+ invoke_agent_span, ai_client_span = spans
+
+ assert transaction["transaction"] == "test_agent workflow"
+ assert transaction["contexts"]["trace"]["origin"] == "auto.ai.openai_agents"
+
+ assert invoke_agent_span["description"] == "invoke_agent test_agent"
+
+ # Only first case checks "gen_ai.request.messages" until further input handling work.
+ param_id = request.node.callspec.id
+ if "string" in param_id and instructions is None: # type: ignore
+ assert "gen_ai.system_instructions" not in ai_client_span["data"]
+
+ assert invoke_agent_span["data"]["gen_ai.request.messages"] == safe_serialize(
+ [
+ {"content": [{"text": "Test input", "type": "text"}], "role": "user"},
+ ]
+ )
+
+ elif "string" in param_id:
+ assert ai_client_span["data"]["gen_ai.system_instructions"] == safe_serialize(
+ [
+ {
+ "type": "text",
+ "content": "You are a coding assistant that talks like a pirate.",
+ },
+ ]
+ )
+ elif "blocks_no_type" in param_id and instructions is None: # type: ignore
+ assert ai_client_span["data"]["gen_ai.system_instructions"] == safe_serialize(
+ [
+ {"type": "text", "content": "You are a helpful assistant."},
+ ]
+ )
+ elif "blocks_no_type" in param_id:
+ assert ai_client_span["data"]["gen_ai.system_instructions"] == safe_serialize(
+ [
+ {
+ "type": "text",
+ "content": "You are a coding assistant that talks like a pirate.",
+ },
+ {"type": "text", "content": "You are a helpful assistant."},
+ ]
+ )
+ elif "blocks" in param_id and instructions is None: # type: ignore
+ assert ai_client_span["data"]["gen_ai.system_instructions"] == safe_serialize(
+ [
+ {"type": "text", "content": "You are a helpful assistant."},
+ ]
+ )
+ elif "blocks" in param_id:
+ assert ai_client_span["data"]["gen_ai.system_instructions"] == safe_serialize(
+ [
+ {
+ "type": "text",
+ "content": "You are a coding assistant that talks like a pirate.",
+ },
+ {"type": "text", "content": "You are a helpful assistant."},
+ ]
+ )
+ elif "parts_no_type" in param_id and instructions is None:
+ assert ai_client_span["data"]["gen_ai.system_instructions"] == safe_serialize(
+ [
+ {"type": "text", "content": "You are a helpful assistant."},
+ {"type": "text", "content": "Be concise and clear."},
+ ]
+ )
+ elif "parts_no_type" in param_id:
+ assert ai_client_span["data"]["gen_ai.system_instructions"] == safe_serialize(
+ [
+ {
+ "type": "text",
+ "content": "You are a coding assistant that talks like a pirate.",
+ },
+ {"type": "text", "content": "You are a helpful assistant."},
+ {"type": "text", "content": "Be concise and clear."},
+ ]
+ )
+ elif instructions is None: # type: ignore
+ assert ai_client_span["data"]["gen_ai.system_instructions"] == safe_serialize(
+ [
+ {"type": "text", "content": "You are a helpful assistant."},
+ {"type": "text", "content": "Be concise and clear."},
+ ]
+ )
+ else:
+ assert ai_client_span["data"]["gen_ai.system_instructions"] == safe_serialize(
+ [
+ {
+ "type": "text",
+ "content": "You are a coding assistant that talks like a pirate.",
+ },
+ {"type": "text", "content": "You are a helpful assistant."},
+ {"type": "text", "content": "Be concise and clear."},
+ ]
+ )
+
+ assert (
+ invoke_agent_span["data"]["gen_ai.response.text"]
+ == "Hello, how can I help you?"
+ )
+
+ assert invoke_agent_span["data"]["gen_ai.operation.name"] == "invoke_agent"
+ assert invoke_agent_span["data"]["gen_ai.system"] == "openai"
+ assert invoke_agent_span["data"]["gen_ai.agent.name"] == "test_agent"
+ assert invoke_agent_span["data"]["gen_ai.request.max_tokens"] == 100
+ assert invoke_agent_span["data"]["gen_ai.request.model"] == "gpt-4"
+ assert invoke_agent_span["data"]["gen_ai.request.temperature"] == 0.7
+ assert invoke_agent_span["data"]["gen_ai.request.top_p"] == 1.0
+
+ assert ai_client_span["description"] == "chat gpt-4"
+ assert ai_client_span["data"]["gen_ai.operation.name"] == "chat"
+ assert ai_client_span["data"]["gen_ai.system"] == "openai"
+ assert ai_client_span["data"]["gen_ai.agent.name"] == "test_agent"
+ assert ai_client_span["data"]["gen_ai.request.max_tokens"] == 100
+ assert ai_client_span["data"]["gen_ai.request.model"] == "gpt-4"
+ assert ai_client_span["data"]["gen_ai.request.temperature"] == 0.7
+ assert ai_client_span["data"]["gen_ai.request.top_p"] == 1.0
+
+
+@pytest.mark.asyncio
+async def test_client_span_custom_model(
+ sentry_init,
+ capture_events,
+ test_agent_custom_model,
+ nonstreaming_responses_model_response,
+ get_model_response,
+):
+ """
+ Test that the integration uses the correct model name if a custom model is used.
+ """
+
+ client = AsyncOpenAI(api_key="test-key")
+ model = OpenAIResponsesModel(model="my-custom-model", openai_client=client)
+ agent = test_agent_custom_model.clone(model=model)
+
+ response = get_model_response(
+ nonstreaming_responses_model_response, serialize_pydantic=True
+ )
+
+ with patch.object(
+ agent.model._client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ result = await agents.Runner.run(
+ agent, "Test input", run_config=test_run_config
+ )
+
+ assert result is not None
+ assert result.final_output == "Hello, how can I help you?"
+
+ (transaction,) = events
+ spans = transaction["spans"]
+ ai_client_span = next(span for span in spans if span["op"] == OP.GEN_AI_CHAT)
+
+ assert ai_client_span["description"] == "chat my-custom-model"
+ assert ai_client_span["data"]["gen_ai.request.model"] == "my-custom-model"
+
+
+def test_agent_invocation_span_sync_no_pii(
+ sentry_init,
+ capture_events,
+ test_agent,
+ nonstreaming_responses_model_response,
+ get_model_response,
+):
+ """
+ Test that the integration creates spans for agent invocations.
+ """
+ client = AsyncOpenAI(api_key="test-key")
+ model = OpenAIResponsesModel(model="gpt-4", openai_client=client)
+ agent = test_agent.clone(model=model)
+
+ response = get_model_response(
+ nonstreaming_responses_model_response, serialize_pydantic=True
+ )
+
+ with patch.object(
+ agent.model._client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=False,
+ )
+
+ events = capture_events()
+
+ result = agents.Runner.run_sync(agent, "Test input", run_config=test_run_config)
+
+ assert result is not None
+ assert result.final_output == "Hello, how can I help you?"
+
+ (transaction,) = events
+ spans = transaction["spans"]
+ invoke_agent_span = next(
+ span for span in spans if span["op"] == OP.GEN_AI_INVOKE_AGENT
+ )
+ ai_client_span = next(span for span in spans if span["op"] == OP.GEN_AI_CHAT)
+
+ assert transaction["transaction"] == "test_agent workflow"
+ assert transaction["contexts"]["trace"]["origin"] == "auto.ai.openai_agents"
+
+ assert invoke_agent_span["description"] == "invoke_agent test_agent"
+ assert invoke_agent_span["data"]["gen_ai.operation.name"] == "invoke_agent"
+ assert invoke_agent_span["data"]["gen_ai.system"] == "openai"
+ assert invoke_agent_span["data"]["gen_ai.agent.name"] == "test_agent"
+ assert invoke_agent_span["data"]["gen_ai.request.max_tokens"] == 100
+ assert invoke_agent_span["data"]["gen_ai.request.model"] == "gpt-4"
+ assert invoke_agent_span["data"]["gen_ai.request.temperature"] == 0.7
+ assert invoke_agent_span["data"]["gen_ai.request.top_p"] == 1.0
+
+ assert ai_client_span["description"] == "chat gpt-4"
+ assert ai_client_span["data"]["gen_ai.operation.name"] == "chat"
+ assert ai_client_span["data"]["gen_ai.system"] == "openai"
+ assert ai_client_span["data"]["gen_ai.agent.name"] == "test_agent"
+ assert ai_client_span["data"]["gen_ai.request.max_tokens"] == 100
+ assert ai_client_span["data"]["gen_ai.request.model"] == "gpt-4"
+ assert ai_client_span["data"]["gen_ai.request.temperature"] == 0.7
+ assert ai_client_span["data"]["gen_ai.request.top_p"] == 1.0
+
+ assert SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS not in invoke_agent_span["data"]
+
+
+@pytest.mark.parametrize(
+ "instructions",
+ (
+ None,
+ "You are a coding assistant that talks like a pirate.",
+ ),
+)
+@pytest.mark.parametrize(
+ "input",
+ [
+ pytest.param("Test input", id="string"),
+ pytest.param(
+ [
+ {
+ "role": "system",
+ "content": "You are a helpful assistant.",
+ },
+ {
+ "role": "user",
+ "content": "Test input",
+ },
+ ],
+ id="blocks_no_type",
+ ),
+ pytest.param(
+ [
+ {
+ "type": "message",
+ "role": "system",
+ "content": "You are a helpful assistant.",
+ },
+ {
+ "type": "message",
+ "role": "user",
+ "content": "Test input",
+ },
+ ],
+ id="blocks",
+ ),
+ pytest.param(
+ [
+ {
+ "role": "system",
+ "content": [
+ {"type": "text", "text": "You are a helpful assistant."},
+ {"type": "text", "text": "Be concise and clear."},
+ ],
+ },
+ {
+ "role": "user",
+ "content": "Test input",
+ },
+ ],
+ id="parts_no_type",
+ ),
+ pytest.param(
+ [
+ {
+ "type": "message",
+ "role": "system",
+ "content": [
+ {"type": "text", "text": "You are a helpful assistant."},
+ {"type": "text", "text": "Be concise and clear."},
+ ],
+ },
+ {
+ "type": "message",
+ "role": "user",
+ "content": "Test input",
+ },
+ ],
+ id="parts",
+ ),
+ ],
+)
+def test_agent_invocation_span_sync(
+ sentry_init,
+ capture_events,
+ test_agent_with_instructions,
+ nonstreaming_responses_model_response,
+ instructions,
+ input,
+ request,
+ get_model_response,
+):
+ """
+ Test that the integration creates spans for agent invocations.
+ """
+ client = AsyncOpenAI(api_key="test-key")
+ model = OpenAIResponsesModel(model="gpt-4", openai_client=client)
+ agent = test_agent_with_instructions(instructions).clone(model=model)
+
+ response = get_model_response(
+ nonstreaming_responses_model_response, serialize_pydantic=True
+ )
+
+ with patch.object(
+ agent.model._client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+
+ result = agents.Runner.run_sync(
+ agent,
+ input,
+ run_config=test_run_config,
+ )
+
+ assert result is not None
+ assert result.final_output == "Hello, how can I help you?"
+
+ (transaction,) = events
+ spans = transaction["spans"]
+ invoke_agent_span, ai_client_span = spans
+
+ assert transaction["transaction"] == "test_agent workflow"
+ assert transaction["contexts"]["trace"]["origin"] == "auto.ai.openai_agents"
+
+ assert invoke_agent_span["description"] == "invoke_agent test_agent"
+ assert invoke_agent_span["data"]["gen_ai.operation.name"] == "invoke_agent"
+ assert invoke_agent_span["data"]["gen_ai.system"] == "openai"
+ assert invoke_agent_span["data"]["gen_ai.agent.name"] == "test_agent"
+ assert invoke_agent_span["data"]["gen_ai.request.max_tokens"] == 100
+ assert invoke_agent_span["data"]["gen_ai.request.model"] == "gpt-4"
+ assert invoke_agent_span["data"]["gen_ai.request.temperature"] == 0.7
+ assert invoke_agent_span["data"]["gen_ai.request.top_p"] == 1.0
+
+ assert ai_client_span["description"] == "chat gpt-4"
+ assert ai_client_span["data"]["gen_ai.operation.name"] == "chat"
+ assert ai_client_span["data"]["gen_ai.system"] == "openai"
+ assert ai_client_span["data"]["gen_ai.agent.name"] == "test_agent"
+ assert ai_client_span["data"]["gen_ai.request.max_tokens"] == 100
+ assert ai_client_span["data"]["gen_ai.request.model"] == "gpt-4"
+ assert ai_client_span["data"]["gen_ai.request.temperature"] == 0.7
+ assert ai_client_span["data"]["gen_ai.request.top_p"] == 1.0
+
+ param_id = request.node.callspec.id
+ if "string" in param_id and instructions is None: # type: ignore
+ assert "gen_ai.system_instructions" not in ai_client_span["data"]
+ elif "string" in param_id:
+ assert ai_client_span["data"]["gen_ai.system_instructions"] == safe_serialize(
+ [
+ {
+ "type": "text",
+ "content": "You are a coding assistant that talks like a pirate.",
+ },
+ ]
+ )
+ elif "blocks_no_type" in param_id and instructions is None: # type: ignore
+ assert ai_client_span["data"]["gen_ai.system_instructions"] == safe_serialize(
+ [
+ {"type": "text", "content": "You are a helpful assistant."},
+ ]
+ )
+ elif "blocks_no_type" in param_id:
+ assert ai_client_span["data"]["gen_ai.system_instructions"] == safe_serialize(
+ [
+ {
+ "type": "text",
+ "content": "You are a coding assistant that talks like a pirate.",
+ },
+ {"type": "text", "content": "You are a helpful assistant."},
+ ]
+ )
+ elif "blocks" in param_id and instructions is None: # type: ignore
+ assert ai_client_span["data"]["gen_ai.system_instructions"] == safe_serialize(
+ [
+ {"type": "text", "content": "You are a helpful assistant."},
+ ]
+ )
+ elif "blocks" in param_id:
+ assert ai_client_span["data"]["gen_ai.system_instructions"] == safe_serialize(
+ [
+ {
+ "type": "text",
+ "content": "You are a coding assistant that talks like a pirate.",
+ },
+ {"type": "text", "content": "You are a helpful assistant."},
+ ]
+ )
+ elif "parts_no_type" in param_id and instructions is None:
+ assert ai_client_span["data"]["gen_ai.system_instructions"] == safe_serialize(
+ [
+ {"type": "text", "content": "You are a helpful assistant."},
+ {"type": "text", "content": "Be concise and clear."},
+ ]
+ )
+ elif "parts_no_type" in param_id:
+ assert ai_client_span["data"]["gen_ai.system_instructions"] == safe_serialize(
+ [
+ {
+ "type": "text",
+ "content": "You are a coding assistant that talks like a pirate.",
+ },
+ {"type": "text", "content": "You are a helpful assistant."},
+ {"type": "text", "content": "Be concise and clear."},
+ ]
+ )
+ elif instructions is None: # type: ignore
+ assert ai_client_span["data"]["gen_ai.system_instructions"] == safe_serialize(
+ [
+ {"type": "text", "content": "You are a helpful assistant."},
+ {"type": "text", "content": "Be concise and clear."},
+ ]
+ )
+ else:
+ assert ai_client_span["data"]["gen_ai.system_instructions"] == safe_serialize(
+ [
+ {
+ "type": "text",
+ "content": "You are a coding assistant that talks like a pirate.",
+ },
+ {"type": "text", "content": "You are a helpful assistant."},
+ {"type": "text", "content": "Be concise and clear."},
+ ]
+ )
+
+
+@pytest.mark.asyncio
+async def test_handoff_span(sentry_init, capture_events, get_model_response):
+ """
+ Test that handoff spans are created when agents hand off to other agents.
+ """
+ client = AsyncOpenAI(api_key="test-key")
+ model = OpenAIResponsesModel(model="gpt-4-mini", openai_client=client)
+
+ # Create two simple agents with a handoff relationship
+ secondary_agent = agents.Agent(
+ name="secondary_agent",
+ instructions="You are a secondary agent.",
+ model=model,
+ )
+
+ primary_agent = agents.Agent(
+ name="primary_agent",
+ instructions="You are a primary agent that hands off to secondary agent.",
+ model=model,
+ handoffs=[secondary_agent],
+ )
+
+ handoff_response = get_model_response(
+ Response(
+ id="resp_tool_123",
+ output=[
+ ResponseFunctionToolCall(
+ id="call_handoff_123",
+ call_id="call_handoff_123",
+ name="transfer_to_secondary_agent",
+ type="function_call",
+ arguments="{}",
+ )
+ ],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model="gpt-4",
+ object="response",
+ usage=ResponseUsage(
+ input_tokens=10,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=0,
+ ),
+ output_tokens=20,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=5,
+ ),
+ total_tokens=30,
+ ),
+ ),
+ serialize_pydantic=True,
+ )
+
+ final_response = get_model_response(
+ Response(
+ id="resp_final_123",
+ output=[
+ ResponseOutputMessage(
+ id="msg_final",
+ type="message",
+ status="completed",
+ content=[
+ ResponseOutputText(
+ text="I'm the specialist and I can help with that!",
+ type="output_text",
+ annotations=[],
+ )
+ ],
+ role="assistant",
+ )
+ ],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model="gpt-4",
+ object="response",
+ usage=ResponseUsage(
+ input_tokens=10,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=0,
+ ),
+ output_tokens=20,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=5,
+ ),
+ total_tokens=30,
+ ),
+ ),
+ serialize_pydantic=True,
+ )
+
+ with patch.object(
+ primary_agent.model._client._client,
+ "send",
+ side_effect=[handoff_response, final_response],
+ ) as _:
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ result = await agents.Runner.run(
+ primary_agent,
+ "Please hand off to secondary agent",
+ run_config=test_run_config,
+ )
+
+ assert result is not None
+
+ (transaction,) = events
+ spans = transaction["spans"]
+ handoff_span = next(span for span in spans if span.get("op") == OP.GEN_AI_HANDOFF)
+
+ # Verify handoff span was created
+ assert handoff_span is not None
+ assert (
+ handoff_span["description"] == "handoff from primary_agent to secondary_agent"
+ )
+ assert handoff_span["data"]["gen_ai.operation.name"] == "handoff"
+
+
+@pytest.mark.asyncio
+async def test_max_turns_before_handoff_span(
+ sentry_init, capture_events, get_model_response
+):
+ """
+ Example raising agents.exceptions.AgentsException after the agent invocation span is complete.
+ """
+ client = AsyncOpenAI(api_key="test-key")
+ model = OpenAIResponsesModel(model="gpt-4-mini", openai_client=client)
+
+ # Create two simple agents with a handoff relationship
+ secondary_agent = agents.Agent(
+ name="secondary_agent",
+ instructions="You are a secondary agent.",
+ model=model,
+ )
+
+ primary_agent = agents.Agent(
+ name="primary_agent",
+ instructions="You are a primary agent that hands off to secondary agent.",
+ model=model,
+ handoffs=[secondary_agent],
+ )
+
+ handoff_response = get_model_response(
+ Response(
+ id="resp_tool_123",
+ output=[
+ ResponseFunctionToolCall(
+ id="call_handoff_123",
+ call_id="call_handoff_123",
+ name="transfer_to_secondary_agent",
+ type="function_call",
+ arguments="{}",
+ )
+ ],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model="gpt-4",
+ object="response",
+ usage=ResponseUsage(
+ input_tokens=10,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=0,
+ ),
+ output_tokens=20,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=5,
+ ),
+ total_tokens=30,
+ ),
+ ),
+ serialize_pydantic=True,
+ )
+
+ final_response = get_model_response(
+ Response(
+ id="resp_final_123",
+ output=[
+ ResponseOutputMessage(
+ id="msg_final",
+ type="message",
+ status="completed",
+ content=[
+ ResponseOutputText(
+ text="I'm the specialist and I can help with that!",
+ type="output_text",
+ annotations=[],
+ )
+ ],
+ role="assistant",
+ )
+ ],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model="gpt-4",
+ object="response",
+ usage=ResponseUsage(
+ input_tokens=10,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=0,
+ ),
+ output_tokens=20,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=5,
+ ),
+ total_tokens=30,
+ ),
+ ),
+ serialize_pydantic=True,
+ )
+
+ with patch.object(
+ primary_agent.model._client._client,
+ "send",
+ side_effect=[handoff_response, final_response],
+ ) as _:
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ with pytest.raises(MaxTurnsExceeded):
+ await agents.Runner.run(
+ primary_agent,
+ "Please hand off to secondary agent",
+ run_config=test_run_config,
+ max_turns=1,
+ )
+
+ (error, transaction) = events
+ spans = transaction["spans"]
+ handoff_span = next(span for span in spans if span.get("op") == OP.GEN_AI_HANDOFF)
+
+ # Verify handoff span was created
+ assert handoff_span is not None
+ assert (
+ handoff_span["description"] == "handoff from primary_agent to secondary_agent"
+ )
+ assert handoff_span["data"]["gen_ai.operation.name"] == "handoff"
+
+
+@pytest.mark.asyncio
+async def test_tool_execution_span(
+ sentry_init,
+ capture_events,
+ test_agent,
+ get_model_response,
+ responses_tool_call_model_responses,
+):
+ """
+ Test tool execution span creation.
+ """
+
+ @agents.function_tool
+ def simple_test_tool(message: str) -> str:
+ """A simple tool"""
+ return f"Tool executed with: {message}"
+
+ # Create agent with the tool
+ client = AsyncOpenAI(api_key="test-key")
+ model = OpenAIResponsesModel(model="gpt-4", openai_client=client)
+ agent_with_tool = test_agent.clone(tools=[simple_test_tool], model=model)
+
+ responses = responses_tool_call_model_responses(
+ tool_name="simple_test_tool",
+ arguments='{"message": "hello"}',
+ response_model="gpt-4",
+ response_text="Task completed using the tool",
+ response_ids=iter(["resp_tool_123", "resp_final_123"]),
+ usages=iter(
+ [
+ ResponseUsage(
+ input_tokens=10,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=0,
+ ),
+ output_tokens=5,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=0,
+ ),
+ total_tokens=15,
+ ),
+ ResponseUsage(
+ input_tokens=15,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=0,
+ ),
+ output_tokens=10,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=0,
+ ),
+ total_tokens=25,
+ ),
+ ]
+ ),
+ )
+ tool_response = get_model_response(
+ next(responses),
+ serialize_pydantic=True,
+ )
+ final_response = get_model_response(
+ next(responses),
+ serialize_pydantic=True,
+ )
+
+ with patch.object(
+ agent_with_tool.model._client._client,
+ "send",
+ side_effect=[tool_response, final_response],
+ ) as _:
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+
+ await agents.Runner.run(
+ agent_with_tool,
+ "Please use the simple test tool",
+ run_config=test_run_config,
+ )
+
+ (transaction,) = events
+ spans = transaction["spans"]
+ agent_span = next(span for span in spans if span["op"] == OP.GEN_AI_INVOKE_AGENT)
+ ai_client_span1, ai_client_span2 = (
+ span for span in spans if span["op"] == OP.GEN_AI_CHAT
+ )
+ tool_span = next(span for span in spans if span["op"] == OP.GEN_AI_EXECUTE_TOOL)
+
+ available_tool = {
+ "name": "simple_test_tool",
+ "description": "A simple tool",
+ "params_json_schema": {
+ "properties": {"message": {"title": "Message", "type": "string"}},
+ "required": ["message"],
+ "title": "simple_test_tool_args",
+ "type": "object",
+ "additionalProperties": False,
+ },
+ "on_invoke_tool": mock.ANY,
+ "strict_json_schema": True,
+ "is_enabled": True,
+ }
+
+ if parse_version(OPENAI_AGENTS_VERSION) >= (0, 3, 3):
+ available_tool.update(
+ {"tool_input_guardrails": None, "tool_output_guardrails": None}
+ )
+
+ if parse_version(OPENAI_AGENTS_VERSION) >= (
+ 0,
+ 8,
+ ):
+ available_tool["needs_approval"] = False
+ if parse_version(OPENAI_AGENTS_VERSION) >= (
+ 0,
+ 9,
+ 0,
+ ):
+ available_tool.update(
+ {
+ "timeout_seconds": None,
+ "timeout_behavior": "error_as_result",
+ "timeout_error_function": None,
+ }
+ )
+
+ assert transaction["transaction"] == "test_agent workflow"
+ assert transaction["contexts"]["trace"]["origin"] == "auto.ai.openai_agents"
+
+ assert agent_span["description"] == "invoke_agent test_agent"
+ assert agent_span["origin"] == "auto.ai.openai_agents"
+ assert agent_span["data"]["gen_ai.agent.name"] == "test_agent"
+ assert agent_span["data"]["gen_ai.operation.name"] == "invoke_agent"
+
+ agent_span_available_tool = json.loads(
+ agent_span["data"]["gen_ai.request.available_tools"]
+ )[0]
+ assert all(agent_span_available_tool[k] == v for k, v in available_tool.items())
+
+ assert agent_span["data"]["gen_ai.request.max_tokens"] == 100
+ assert agent_span["data"]["gen_ai.request.model"] == "gpt-4"
+ assert agent_span["data"]["gen_ai.request.temperature"] == 0.7
+ assert agent_span["data"]["gen_ai.request.top_p"] == 1.0
+ assert agent_span["data"]["gen_ai.system"] == "openai"
+
+ assert ai_client_span1["description"] == "chat gpt-4"
+ assert ai_client_span1["data"]["gen_ai.operation.name"] == "chat"
+ assert ai_client_span1["data"]["gen_ai.system"] == "openai"
+ assert ai_client_span1["data"]["gen_ai.agent.name"] == "test_agent"
+
+ ai_client_span1_available_tool = json.loads(
+ ai_client_span1["data"]["gen_ai.request.available_tools"]
+ )[0]
+ assert all(
+ ai_client_span1_available_tool[k] == v for k, v in available_tool.items()
+ )
+
+ assert ai_client_span1["data"]["gen_ai.request.max_tokens"] == 100
+ assert ai_client_span1["data"]["gen_ai.request.messages"] == safe_serialize(
+ [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Please use the simple test tool"}
+ ],
+ },
+ ]
+ )
+ assert ai_client_span1["data"]["gen_ai.request.model"] == "gpt-4"
+ assert ai_client_span1["data"]["gen_ai.request.temperature"] == 0.7
+ assert ai_client_span1["data"]["gen_ai.request.top_p"] == 1.0
+ assert ai_client_span1["data"]["gen_ai.usage.input_tokens"] == 10
+ assert ai_client_span1["data"]["gen_ai.usage.input_tokens.cached"] == 0
+ assert ai_client_span1["data"]["gen_ai.usage.output_tokens"] == 5
+ assert ai_client_span1["data"]["gen_ai.usage.output_tokens.reasoning"] == 0
+ assert ai_client_span1["data"]["gen_ai.usage.total_tokens"] == 15
+
+ tool_call = {
+ "arguments": '{"message": "hello"}',
+ "call_id": "call_123",
+ "name": "simple_test_tool",
+ "type": "function_call",
+ "id": "call_123",
+ "status": None,
+ }
+
+ if OPENAI_VERSION >= (2, 25, 0):
+ tool_call["namespace"] = None
+
+ assert json.loads(ai_client_span1["data"]["gen_ai.response.tool_calls"]) == [
+ tool_call
+ ]
+
+ assert tool_span["description"] == "execute_tool simple_test_tool"
+ assert tool_span["data"]["gen_ai.agent.name"] == "test_agent"
+ assert tool_span["data"]["gen_ai.operation.name"] == "execute_tool"
+
+ tool_span_available_tool = json.loads(
+ tool_span["data"]["gen_ai.request.available_tools"]
+ )[0]
+ assert all(tool_span_available_tool[k] == v for k, v in available_tool.items())
+
+ assert tool_span["data"]["gen_ai.request.max_tokens"] == 100
+ assert tool_span["data"]["gen_ai.request.model"] == "gpt-4"
+ assert tool_span["data"]["gen_ai.request.temperature"] == 0.7
+ assert tool_span["data"]["gen_ai.request.top_p"] == 1.0
+ assert tool_span["data"]["gen_ai.system"] == "openai"
+ assert tool_span["data"]["gen_ai.tool.description"] == "A simple tool"
+ assert tool_span["data"]["gen_ai.tool.input"] == '{"message": "hello"}'
+ assert tool_span["data"]["gen_ai.tool.name"] == "simple_test_tool"
+ assert tool_span["data"]["gen_ai.tool.output"] == "Tool executed with: hello"
+ assert ai_client_span2["description"] == "chat gpt-4"
+ assert ai_client_span2["data"]["gen_ai.agent.name"] == "test_agent"
+ assert ai_client_span2["data"]["gen_ai.operation.name"] == "chat"
+
+ ai_client_span2_available_tool = json.loads(
+ ai_client_span2["data"]["gen_ai.request.available_tools"]
+ )[0]
+ assert all(
+ ai_client_span2_available_tool[k] == v for k, v in available_tool.items()
+ )
+
+ assert ai_client_span2["data"]["gen_ai.request.max_tokens"] == 100
+ assert ai_client_span2["data"]["gen_ai.request.messages"] == safe_serialize(
+ [
+ {
+ "role": "tool",
+ "content": [
+ {
+ "call_id": "call_123",
+ "output": "Tool executed with: hello",
+ "type": "function_call_output",
+ }
+ ],
+ },
+ ]
+ )
+ assert ai_client_span2["data"]["gen_ai.request.model"] == "gpt-4"
+ assert ai_client_span2["data"]["gen_ai.request.temperature"] == 0.7
+ assert ai_client_span2["data"]["gen_ai.request.top_p"] == 1.0
+ assert (
+ ai_client_span2["data"]["gen_ai.response.text"]
+ == "Task completed using the tool"
+ )
+ assert ai_client_span2["data"]["gen_ai.system"] == "openai"
+ assert ai_client_span2["data"]["gen_ai.usage.input_tokens.cached"] == 0
+ assert ai_client_span2["data"]["gen_ai.usage.input_tokens"] == 15
+ assert ai_client_span2["data"]["gen_ai.usage.output_tokens.reasoning"] == 0
+ assert ai_client_span2["data"]["gen_ai.usage.output_tokens"] == 10
+ assert ai_client_span2["data"]["gen_ai.usage.total_tokens"] == 25
+
+
+@pytest.mark.asyncio
+async def test_hosted_mcp_tool_propagation_header_streamed(
+ sentry_init,
+ test_agent,
+ get_model_response,
+ async_iterator,
+ server_side_event_chunks,
+):
+ """
+ Test responses API is given trace propagation headers with HostedMCPTool.
+ """
+
+ hosted_tool = HostedMCPTool(
+ tool_config={
+ "type": "mcp",
+ "server_label": "test_server",
+ "server_url": "http://example.com/",
+ "headers": {
+ "baggage": "custom=data",
+ },
+ },
+ )
+
+ client = AsyncOpenAI(api_key="z")
+
+ model = OpenAIResponsesModel(model="gpt-4", openai_client=client)
+
+ agent_with_tool = test_agent.clone(
+ tools=[hosted_tool],
+ model=model,
+ )
+
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ release="d08ebdb9309e1b004c6f52202de58a09c2268e42",
+ )
+
+ request_headers = {}
+ # openai-agents calls with_streaming_response() if available starting with
+ # https://github.com/openai/openai-agents-python/commit/159beb56130f7d85192acfd593c9168757984dc0.
+ # When using with_streaming_response() the header set below changes the response type:
+ # https://github.com/openai/openai-python/blob/656e3cab4a18262a49b961d41293367e45ee71b9/src/openai/_response.py#L67.
+ if parse_version(OPENAI_AGENTS_VERSION) >= (0, 10, 3) and hasattr(
+ agent_with_tool.model._client.responses, "with_streaming_response"
+ ):
+ request_headers["X-Stainless-Raw-Response"] = "stream"
+
+ response = get_model_response(
+ async_iterator(
+ server_side_event_chunks(
+ [
+ ResponseCreatedEvent(
+ response=Response(
+ id="chat-id",
+ output=[],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model="response-model-id",
+ object="response",
+ ),
+ type="response.created",
+ sequence_number=0,
+ ),
+ ResponseCompletedEvent(
+ response=Response(
+ id="chat-id",
+ output=[
+ ResponseOutputMessage(
+ id="message-id",
+ content=[
+ ResponseOutputText(
+ annotations=[],
+ text="the model response",
+ type="output_text",
+ ),
+ ],
+ role="assistant",
+ status="completed",
+ type="message",
+ ),
+ ],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model="response-model-id",
+ object="response",
+ usage=ResponseUsage(
+ input_tokens=20,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=5,
+ ),
+ output_tokens=10,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=8,
+ ),
+ total_tokens=30,
+ ),
+ ),
+ type="response.completed",
+ sequence_number=1,
+ ),
+ ]
+ )
+ ),
+ request_headers=request_headers,
+ )
+
+ # Patching https://github.com/openai/openai-python/blob/656e3cab4a18262a49b961d41293367e45ee71b9/src/openai/_base_client.py#L1604
+ with patch.object(
+ agent_with_tool.model._client._client,
+ "send",
+ return_value=response,
+ ) as create, mock.patch(
+ "sentry_sdk.tracing_utils.Random.randrange", return_value=500000
+ ):
+ with sentry_sdk.start_transaction(
+ name="/interactions/other-dogs/new-dog",
+ op="greeting.sniff",
+ trace_id="01234567890123456789012345678901",
+ ) as transaction:
+ result = agents.Runner.run_streamed(
+ agent_with_tool,
+ "Please use the simple test tool",
+ run_config=test_run_config,
+ )
+
+ async for event in result.stream_events():
+ pass
+
+ ai_client_span = next(
+ span
+ for span in transaction._span_recorder.spans
+ if span.op == OP.GEN_AI_CHAT
+ )
+
+ args, kwargs = create.call_args
+
+ request = args[0]
+ body = json.loads(request.content.decode("utf-8"))
+ hosted_mcp_tool = body["tools"][0]
+
+ assert hosted_mcp_tool["headers"][
+ "sentry-trace"
+ ] == "{trace_id}-{parent_span_id}-{sampled}".format(
+ trace_id=transaction.trace_id,
+ parent_span_id=ai_client_span.span_id,
+ sampled=1,
+ )
+
+ expected_outgoing_baggage = (
+ "custom=data,"
+ "sentry-trace_id=01234567890123456789012345678901,"
+ "sentry-sample_rand=0.500000,"
+ "sentry-environment=production,"
+ "sentry-release=d08ebdb9309e1b004c6f52202de58a09c2268e42,"
+ "sentry-transaction=/interactions/other-dogs/new-dog,"
+ "sentry-sample_rate=1.0,"
+ "sentry-sampled=true"
+ )
+
+ assert hosted_mcp_tool["headers"]["baggage"] == expected_outgoing_baggage
+
+
+@pytest.mark.asyncio
+async def test_hosted_mcp_tool_propagation_headers(
+ sentry_init, test_agent, get_model_response
+):
+ """
+ Test responses API is given trace propagation headers with HostedMCPTool.
+ """
+
+ hosted_tool = HostedMCPTool(
+ tool_config={
+ "type": "mcp",
+ "server_label": "test_server",
+ "server_url": "http://example.com/",
+ "headers": {
+ "baggage": "custom=data",
+ },
+ },
+ )
+
+ client = AsyncOpenAI(api_key="test-key")
+ model = OpenAIResponsesModel(model="gpt-4", openai_client=client)
+
+ agent_with_tool = test_agent.clone(
+ tools=[hosted_tool],
+ model=model,
+ )
+
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ release="d08ebdb9309e1b004c6f52202de58a09c2268e42",
+ )
+
+ response = get_model_response(EXAMPLE_RESPONSE, serialize_pydantic=True)
+
+ with patch.object(
+ agent_with_tool.model._client._client,
+ "send",
+ return_value=response,
+ ) as send, mock.patch(
+ "sentry_sdk.tracing_utils.Random.randrange", return_value=500000
+ ):
+ with sentry_sdk.start_transaction(
+ name="/interactions/other-dogs/new-dog",
+ op="greeting.sniff",
+ trace_id="01234567890123456789012345678901",
+ ) as transaction:
+ await agents.Runner.run(
+ agent_with_tool,
+ "Please use the simple test tool",
+ run_config=test_run_config,
+ )
+
+ ai_client_span = next(
+ span
+ for span in transaction._span_recorder.spans
+ if span.op == OP.GEN_AI_CHAT
+ )
+
+ args, kwargs = send.call_args
+
+ request = args[0]
+ body = json.loads(request.content.decode("utf-8"))
+ hosted_mcp_tool = body["tools"][0]
+
+ assert hosted_mcp_tool["headers"][
+ "sentry-trace"
+ ] == "{trace_id}-{parent_span_id}-{sampled}".format(
+ trace_id=transaction.trace_id,
+ parent_span_id=ai_client_span.span_id,
+ sampled=1,
+ )
+
+ expected_outgoing_baggage = (
+ "custom=data,"
+ "sentry-trace_id=01234567890123456789012345678901,"
+ "sentry-sample_rand=0.500000,"
+ "sentry-environment=production,"
+ "sentry-release=d08ebdb9309e1b004c6f52202de58a09c2268e42,"
+ "sentry-transaction=/interactions/other-dogs/new-dog,"
+ "sentry-sample_rate=1.0,"
+ "sentry-sampled=true"
+ )
+
+ assert hosted_mcp_tool["headers"]["baggage"] == expected_outgoing_baggage
+
+
+@pytest.mark.asyncio
+async def test_model_behavior_error(sentry_init, capture_events, test_agent):
+ """
+ Example raising agents.exceptions.AgentsException before the agent invocation span is complete.
+ The mocked API response indicates that "wrong_tool" was called.
+ """
+
+ @agents.function_tool
+ def simple_test_tool(message: str) -> str:
+ """A simple tool"""
+ return f"Tool executed with: {message}"
+
+ # Create agent with the tool
+ agent_with_tool = test_agent.clone(tools=[simple_test_tool])
+
+ with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
+ with patch(
+ "agents.models.openai_responses.OpenAIResponsesModel.get_response"
+ ) as mock_get_response:
+ # Create a mock response that includes tool calls
+ tool_call = ResponseFunctionToolCall(
+ id="call_123",
+ call_id="call_123",
+ name="wrong_tool",
+ type="function_call",
+ arguments='{"message": "hello"}',
+ )
+
+ tool_response = ModelResponse(
+ output=[tool_call],
+ usage=Usage(
+ requests=1, input_tokens=10, output_tokens=5, total_tokens=15
+ ),
+ response_id="resp_tool_123",
+ )
+
+ mock_get_response.side_effect = [tool_response]
+
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+
+ with pytest.raises(ModelBehaviorError):
+ await agents.Runner.run(
+ agent_with_tool,
+ "Please use the simple test tool",
+ run_config=test_run_config,
+ )
+
+ (error, transaction) = events
+ spans = transaction["spans"]
+ (
+ agent_span,
+ ai_client_span1,
+ ) = spans
+
+ assert transaction["transaction"] == "test_agent workflow"
+ assert transaction["contexts"]["trace"]["origin"] == "auto.ai.openai_agents"
+
+ assert agent_span["description"] == "invoke_agent test_agent"
+ assert agent_span["origin"] == "auto.ai.openai_agents"
+
+ # Error due to unrecognized tool in model response.
+ assert agent_span["status"] == "internal_error"
+ assert agent_span["tags"]["status"] == "internal_error"
+
+
+@pytest.mark.asyncio
+async def test_error_handling(sentry_init, capture_events, test_agent):
+ """
+ Test error handling in agent execution.
+ """
+
+ with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
+ with patch(
+ "agents.models.openai_responses.OpenAIResponsesModel.get_response"
+ ) as mock_get_response:
+ mock_get_response.side_effect = Exception("Model Error")
+
+ sentry_init(
+ integrations=[
+ OpenAIAgentsIntegration(),
+ LoggingIntegration(event_level=logging.CRITICAL),
+ ],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ with pytest.raises(Exception, match="Model Error"):
+ await agents.Runner.run(
+ test_agent, "Test input", run_config=test_run_config
+ )
+
+ (
+ error_event,
+ transaction,
+ ) = events
+
+ assert error_event["exception"]["values"][0]["type"] == "Exception"
+ assert error_event["exception"]["values"][0]["value"] == "Model Error"
+ assert error_event["exception"]["values"][0]["mechanism"]["type"] == "openai_agents"
+
+ spans = transaction["spans"]
+ (invoke_agent_span, ai_client_span) = spans
+
+ assert transaction["transaction"] == "test_agent workflow"
+ assert transaction["contexts"]["trace"]["origin"] == "auto.ai.openai_agents"
+
+ assert invoke_agent_span["description"] == "invoke_agent test_agent"
+ assert invoke_agent_span["origin"] == "auto.ai.openai_agents"
+
+ assert ai_client_span["description"] == "chat gpt-4"
+ assert ai_client_span["origin"] == "auto.ai.openai_agents"
+ assert ai_client_span["status"] == "internal_error"
+ assert ai_client_span["tags"]["status"] == "internal_error"
+
+
+@pytest.mark.asyncio
+async def test_error_captures_input_data(sentry_init, capture_events, test_agent):
+ """
+ Test that input data is captured even when the API call raises an exception.
+ This verifies that _set_input_data is called before the API call.
+ """
+ client = AsyncOpenAI(api_key="test-key")
+ model = OpenAIResponsesModel(model="gpt-4", openai_client=client)
+ agent = test_agent.clone(model=model)
+
+ model_request = httpx.Request(
+ "POST",
+ "/responses",
+ )
+
+ response = httpx.Response(
+ 500,
+ request=model_request,
+ )
+
+ with patch.object(
+ agent.model._client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ sentry_init(
+ integrations=[
+ OpenAIAgentsIntegration(),
+ LoggingIntegration(event_level=logging.CRITICAL),
+ ],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+
+ with pytest.raises(InternalServerError, match="Error code: 500"):
+ await agents.Runner.run(agent, "Test input", run_config=test_run_config)
+
+ (
+ error_event,
+ transaction,
+ ) = events
+
+ assert error_event["exception"]["values"][0]["type"] == "InternalServerError"
+ assert error_event["exception"]["values"][0]["value"] == "Error code: 500"
+
+ spans = transaction["spans"]
+ ai_client_span = [s for s in spans if s["op"] == "gen_ai.chat"][0]
+
+ assert ai_client_span["description"] == "chat gpt-4"
+ assert ai_client_span["status"] == "internal_error"
+ assert ai_client_span["tags"]["status"] == "internal_error"
+
+ assert "gen_ai.request.messages" in ai_client_span["data"]
+ request_messages = safe_serialize(
+ [
+ {"role": "user", "content": [{"type": "text", "text": "Test input"}]},
+ ]
+ )
+ assert ai_client_span["data"]["gen_ai.request.messages"] == request_messages
+
+
+@pytest.mark.asyncio
+async def test_span_status_error(sentry_init, capture_events, test_agent):
+ with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
+ with patch(
+ "agents.models.openai_responses.OpenAIResponsesModel.get_response"
+ ) as mock_get_response:
+ mock_get_response.side_effect = ValueError("Model Error")
+
+ sentry_init(
+ integrations=[
+ OpenAIAgentsIntegration(),
+ LoggingIntegration(event_level=logging.CRITICAL),
+ ],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ with pytest.raises(ValueError, match="Model Error"):
+ await agents.Runner.run(
+ test_agent, "Test input", run_config=test_run_config
+ )
+
+ (error, transaction) = events
+ assert error["level"] == "error"
+ assert transaction["spans"][0]["status"] == "internal_error"
+ assert transaction["spans"][0]["tags"]["status"] == "internal_error"
+ assert transaction["contexts"]["trace"]["status"] == "internal_error"
+
+
+@pytest.mark.asyncio
+async def test_mcp_tool_execution_spans(
+ sentry_init, capture_events, test_agent, get_model_response
+):
+ """
+ Test that MCP (Model Context Protocol) tool calls create execute_tool spans.
+ """
+ client = AsyncOpenAI(api_key="test-key")
+ model = OpenAIResponsesModel(model="gpt-4", openai_client=client)
+ agent = test_agent.clone(model=model)
+
+ mcp_response = get_model_response(
+ Response(
+ id="resp_mcp_123",
+ output=[
+ McpCall(
+ id="mcp_call_123",
+ name="test_mcp_tool",
+ arguments='{"query": "search term"}',
+ output="MCP tool executed successfully",
+ error=None,
+ type="mcp_call",
+ server_label="test_server",
+ )
+ ],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model="gpt-4.1-2025-04-14",
+ object="response",
+ usage=ResponseUsage(
+ input_tokens=10,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=0,
+ ),
+ output_tokens=5,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=0,
+ ),
+ total_tokens=15,
+ ),
+ ),
+ serialize_pydantic=True,
+ )
+
+ final_response = get_model_response(
+ Response(
+ id="resp_final_123",
+ output=[
+ ResponseOutputMessage(
+ id="msg_final",
+ type="message",
+ status="completed",
+ content=[
+ ResponseOutputText(
+ text="Task completed using MCP tool",
+ type="output_text",
+ annotations=[],
+ )
+ ],
+ role="assistant",
+ )
+ ],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model="gpt-4.1-2025-04-14",
+ object="response",
+ usage=ResponseUsage(
+ input_tokens=15,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=0,
+ ),
+ output_tokens=10,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=0,
+ ),
+ total_tokens=25,
+ ),
+ ),
+ serialize_pydantic=True,
+ )
+
+ with patch.object(
+ agent.model._client._client,
+ "send",
+ side_effect=[mcp_response, final_response],
+ ) as _:
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+
+ await agents.Runner.run(
+ agent,
+ "Please use MCP tool",
+ run_config=test_run_config,
+ )
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ # Find the MCP execute_tool span
+ mcp_tool_span = None
+ for span in spans:
+ if span.get("description") == "execute_tool test_mcp_tool":
+ mcp_tool_span = span
+ break
+
+ # Verify the MCP tool span was created
+ assert mcp_tool_span is not None, "MCP execute_tool span was not created"
+ assert mcp_tool_span["description"] == "execute_tool test_mcp_tool"
+ assert mcp_tool_span["data"]["gen_ai.tool.name"] == "test_mcp_tool"
+ assert mcp_tool_span["data"]["gen_ai.tool.input"] == '{"query": "search term"}'
+ assert (
+ mcp_tool_span["data"]["gen_ai.tool.output"] == "MCP tool executed successfully"
+ )
+
+ # Verify no error status since error was None
+ assert mcp_tool_span.get("status") != "internal_error"
+ assert mcp_tool_span.get("tags", {}).get("status") != "internal_error"
+
+
+@pytest.mark.asyncio
+async def test_mcp_tool_execution_with_error(
+ sentry_init, capture_events, test_agent, get_model_response
+):
+ """
+ Test that MCP tool calls with errors are tracked with error status.
+ """
+ client = AsyncOpenAI(api_key="test-key")
+ model = OpenAIResponsesModel(model="gpt-4", openai_client=client)
+ agent = test_agent.clone(model=model)
+
+ mcp_response = get_model_response(
+ Response(
+ id="resp_mcp_123",
+ output=[
+ McpCall(
+ id="mcp_call_error_123",
+ name="failing_mcp_tool",
+ arguments='{"query": "test"}',
+ output=None,
+ error="MCP tool execution failed",
+ type="mcp_call",
+ server_label="test_server",
+ )
+ ],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model="gpt-4.1-2025-04-14",
+ object="response",
+ usage=ResponseUsage(
+ input_tokens=10,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=0,
+ ),
+ output_tokens=5,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=0,
+ ),
+ total_tokens=15,
+ ),
+ ),
+ serialize_pydantic=True,
+ )
+
+ final_response = get_model_response(
+ Response(
+ id="resp_final_123",
+ output=[
+ ResponseOutputMessage(
+ id="msg_final",
+ type="message",
+ status="completed",
+ content=[
+ ResponseOutputText(
+ text="Task completed using MCP tool",
+ type="output_text",
+ annotations=[],
+ )
+ ],
+ role="assistant",
+ )
+ ],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model="gpt-4.1-2025-04-14",
+ object="response",
+ usage=ResponseUsage(
+ input_tokens=15,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=0,
+ ),
+ output_tokens=10,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=0,
+ ),
+ total_tokens=25,
+ ),
+ ),
+ serialize_pydantic=True,
+ )
+
+ with patch.object(
+ agent.model._client._client,
+ "send",
+ side_effect=[mcp_response, final_response],
+ ) as _:
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+
+ await agents.Runner.run(
+ agent,
+ "Please use failing MCP tool",
+ run_config=test_run_config,
+ )
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ # Find the MCP execute_tool span with error
+ mcp_tool_span = None
+ for span in spans:
+ if span.get("description") == "execute_tool failing_mcp_tool":
+ mcp_tool_span = span
+ break
+
+ # Verify the MCP tool span was created with error status
+ assert mcp_tool_span is not None, "MCP execute_tool span was not created"
+ assert mcp_tool_span["description"] == "execute_tool failing_mcp_tool"
+ assert mcp_tool_span["data"]["gen_ai.tool.name"] == "failing_mcp_tool"
+ assert mcp_tool_span["data"]["gen_ai.tool.input"] == '{"query": "test"}'
+ assert mcp_tool_span["data"]["gen_ai.tool.output"] is None
+
+ # Verify error status was set
+ assert mcp_tool_span["status"] == "internal_error"
+ assert mcp_tool_span["tags"]["status"] == "internal_error"
+
+
+@pytest.mark.asyncio
+async def test_mcp_tool_execution_without_pii(
+ sentry_init, capture_events, test_agent, get_model_response
+):
+ """
+ Test that MCP tool input/output are not included when send_default_pii is False.
+ """
+ client = AsyncOpenAI(api_key="test-key")
+ model = OpenAIResponsesModel(model="gpt-4", openai_client=client)
+ agent = test_agent.clone(model=model)
+
+ mcp_response = get_model_response(
+ Response(
+ id="resp_mcp_123",
+ output=[
+ McpCall(
+ id="mcp_call_pii_123",
+ name="test_mcp_tool",
+ arguments='{"query": "sensitive data"}',
+ output="Result with sensitive info",
+ error=None,
+ type="mcp_call",
+ server_label="test_server",
+ )
+ ],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model="gpt-4.1-2025-04-14",
+ object="response",
+ usage=ResponseUsage(
+ input_tokens=10,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=0,
+ ),
+ output_tokens=5,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=0,
+ ),
+ total_tokens=15,
+ ),
+ ),
+ serialize_pydantic=True,
+ )
+
+ final_response = get_model_response(
+ Response(
+ id="resp_final_123",
+ output=[
+ ResponseOutputMessage(
+ id="msg_final",
+ type="message",
+ status="completed",
+ content=[
+ ResponseOutputText(
+ text="Task completed",
+ type="output_text",
+ annotations=[],
+ )
+ ],
+ role="assistant",
+ )
+ ],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model="gpt-4.1-2025-04-14",
+ object="response",
+ usage=ResponseUsage(
+ input_tokens=15,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=0,
+ ),
+ output_tokens=10,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=5,
+ ),
+ total_tokens=25,
+ ),
+ ),
+ serialize_pydantic=True,
+ )
+
+ with patch.object(
+ agent.model._client._client,
+ "send",
+ side_effect=[mcp_response, final_response],
+ ) as _:
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=False, # PII disabled
+ )
+
+ events = capture_events()
+
+ await agents.Runner.run(
+ agent,
+ "Please use MCP tool",
+ run_config=test_run_config,
+ )
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ # Find the MCP execute_tool span
+ mcp_tool_span = None
+ for span in spans:
+ if span.get("description") == "execute_tool test_mcp_tool":
+ mcp_tool_span = span
+ break
+
+ # Verify the MCP tool span was created but without input/output
+ assert mcp_tool_span is not None, "MCP execute_tool span was not created"
+ assert mcp_tool_span["description"] == "execute_tool test_mcp_tool"
+ assert mcp_tool_span["data"]["gen_ai.tool.name"] == "test_mcp_tool"
+
+ # Verify input and output are not included when send_default_pii is False
+ assert "gen_ai.tool.input" not in mcp_tool_span["data"]
+ assert "gen_ai.tool.output" not in mcp_tool_span["data"]
+
+
+@pytest.mark.asyncio
+async def test_multiple_agents_asyncio(
+ sentry_init,
+ capture_events,
+ test_agent,
+ nonstreaming_responses_model_response,
+ get_model_response,
+):
+ """
+ Test that multiple agents can be run at the same time in asyncio tasks
+ without interfering with each other.
+ """
+ client = AsyncOpenAI(api_key="test-key")
+ model = OpenAIResponsesModel(model="gpt-4", openai_client=client)
+ agent = test_agent.clone(model=model)
+
+ response = get_model_response(
+ nonstreaming_responses_model_response, serialize_pydantic=True
+ )
+
+ with patch.object(
+ agent.model._client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ async def run():
+ await agents.Runner.run(
+ starting_agent=agent,
+ input="Test input",
+ run_config=test_run_config,
+ )
+
+ await asyncio.gather(*[run() for _ in range(3)])
+
+ assert len(events) == 3
+ txn1, txn2, txn3 = events
+
+ assert txn1["type"] == "transaction"
+ assert txn1["transaction"] == "test_agent workflow"
+ assert txn2["type"] == "transaction"
+ assert txn2["transaction"] == "test_agent workflow"
+ assert txn3["type"] == "transaction"
+ assert txn3["transaction"] == "test_agent workflow"
+
+
+# Test input messages with mixed roles including "ai"
+@pytest.mark.parametrize(
+ "test_message,expected_role",
+ [
+ ({"role": "user", "content": "Hello"}, "user"),
+ (
+ {"role": "ai", "content": "Hi there!"},
+ "assistant",
+ ), # Should be mapped to "assistant"
+ (
+ {"role": "assistant", "content": "How can I help?"},
+ "assistant",
+ ), # Should stay "assistant"
+ ],
+)
+def test_openai_agents_message_role_mapping(
+ sentry_init, capture_events, test_message, expected_role
+):
+ """Test that OpenAI Agents integration properly maps message roles like 'ai' to 'assistant'"""
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ get_response_kwargs = {"input": [test_message]}
+
+ from sentry_sdk.integrations.openai_agents.utils import _set_input_data
+ from sentry_sdk import start_span
+
+ with start_span(op="test") as span:
+ _set_input_data(span, get_response_kwargs)
+
+ # Verify that messages were processed and roles were mapped
+ from sentry_sdk.consts import SPANDATA
+
+ stored_messages = json.loads(span._data[SPANDATA.GEN_AI_REQUEST_MESSAGES])
+
+ # Verify roles were properly mapped
+ assert stored_messages[0]["role"] == expected_role
+
+
+@pytest.mark.asyncio
+async def test_tool_execution_error_tracing(
+ sentry_init,
+ capture_events,
+ test_agent,
+ get_model_response,
+ responses_tool_call_model_responses,
+):
+ """
+ Test that tool execution errors are properly tracked via error tracing patch.
+
+ This tests the patch of agents error tracing function to ensure execute_tool
+ spans are set to error status when tool execution fails.
+
+ The function location varies by version:
+ - Newer versions: agents.util._error_tracing.attach_error_to_current_span
+ - Older versions: agents._utils.attach_error_to_current_span
+ """
+
+ @agents.function_tool
+ def failing_tool(message: str) -> str:
+ """A tool that fails"""
+ raise ValueError("Tool execution failed")
+
+ # Create agent with the failing tool
+ client = AsyncOpenAI(api_key="test-key")
+ model = OpenAIResponsesModel(model="gpt-4", openai_client=client)
+ agent_with_tool = test_agent.clone(tools=[failing_tool], model=model)
+
+ responses = responses_tool_call_model_responses(
+ tool_name="failing_tool",
+ arguments='{"message": "test"}',
+ response_model="gpt-4-0613",
+ response_text="An error occurred while running the tool",
+ response_ids=iter(["resp_1", "resp_2"]),
+ usages=iter(
+ [
+ ResponseUsage(
+ input_tokens=10,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=0,
+ ),
+ output_tokens=5,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=0,
+ ),
+ total_tokens=15,
+ ),
+ ResponseUsage(
+ input_tokens=15,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=0,
+ ),
+ output_tokens=10,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=0,
+ ),
+ total_tokens=25,
+ ),
+ ]
+ ),
+ )
+ tool_response = get_model_response(
+ next(responses),
+ serialize_pydantic=True,
+ )
+ final_response = get_model_response(
+ next(responses),
+ serialize_pydantic=True,
+ )
+
+ with patch.object(
+ agent_with_tool.model._client._client,
+ "send",
+ side_effect=[tool_response, final_response],
+ ) as _:
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+
+ # Note: The agents library catches tool exceptions internally,
+ # so we don't expect this to raise
+ await agents.Runner.run(
+ agent_with_tool,
+ "Please use the failing tool",
+ run_config=test_run_config,
+ )
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ # Find the execute_tool span
+ execute_tool_span = None
+ for span in spans:
+ description = span.get("description", "")
+ if description is not None and description.startswith(
+ "execute_tool failing_tool"
+ ):
+ execute_tool_span = span
+ break
+
+ # Verify the execute_tool span was created
+ assert execute_tool_span is not None, "execute_tool span was not created"
+ assert execute_tool_span["description"] == "execute_tool failing_tool"
+ assert execute_tool_span["data"]["gen_ai.tool.name"] == "failing_tool"
+
+ # Verify error status was set (this is the key test for our patch)
+ # The span should be marked as error because the tool execution failed
+ assert execute_tool_span["status"] == "internal_error"
+ assert execute_tool_span["tags"]["status"] == "internal_error"
+
+
+@pytest.mark.asyncio
+async def test_invoke_agent_span_includes_usage_data(
+ sentry_init,
+ capture_events,
+ test_agent,
+ get_model_response,
+):
+ """
+ Test that invoke_agent spans include aggregated usage data from context_wrapper.
+ This verifies the new functionality added to track token usage in invoke_agent spans.
+ """
+ client = AsyncOpenAI(api_key="z")
+ model = OpenAIResponsesModel(model="gpt-4", openai_client=client)
+ agent = test_agent.clone(model=model)
+
+ response = get_model_response(
+ Response(
+ id="resp_123",
+ output=[
+ ResponseOutputMessage(
+ id="msg_123",
+ type="message",
+ status="completed",
+ content=[
+ ResponseOutputText(
+ text="Response with usage",
+ type="output_text",
+ annotations=[],
+ )
+ ],
+ role="assistant",
+ )
+ ],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model="gpt-4.1-2025-04-14",
+ object="response",
+ usage=ResponseUsage(
+ input_tokens=10,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=0,
+ ),
+ output_tokens=20,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=5,
+ ),
+ total_tokens=30,
+ ),
+ ),
+ serialize_pydantic=True,
+ )
+
+ with patch.object(
+ agent.model._client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+
+ result = await agents.Runner.run(
+ agent, "Test input", run_config=test_run_config
+ )
+
+ assert result is not None
+
+ (transaction,) = events
+ spans = transaction["spans"]
+ invoke_agent_span = next(
+ span for span in spans if span["op"] == OP.GEN_AI_INVOKE_AGENT
+ )
+
+ # Verify invoke_agent span has usage data from context_wrapper
+ assert invoke_agent_span["description"] == "invoke_agent test_agent"
+ assert "gen_ai.usage.input_tokens" in invoke_agent_span["data"]
+ assert "gen_ai.usage.output_tokens" in invoke_agent_span["data"]
+ assert "gen_ai.usage.total_tokens" in invoke_agent_span["data"]
+
+ assert invoke_agent_span["data"]["gen_ai.usage.input_tokens"] == 10
+ assert invoke_agent_span["data"]["gen_ai.usage.output_tokens"] == 20
+ assert invoke_agent_span["data"]["gen_ai.usage.total_tokens"] == 30
+ assert invoke_agent_span["data"]["gen_ai.usage.input_tokens.cached"] == 0
+ assert invoke_agent_span["data"]["gen_ai.usage.output_tokens.reasoning"] == 5
+
+
+@pytest.mark.asyncio
+async def test_ai_client_span_includes_response_model(
+ sentry_init,
+ capture_events,
+ test_agent,
+ get_model_response,
+):
+ """
+ Test that ai_client spans (gen_ai.chat) include the response model from the actual API response.
+ This verifies we capture the actual model used (which may differ from the requested model).
+ """
+ client = AsyncOpenAI(api_key="z")
+ model = OpenAIResponsesModel(model="gpt-4", openai_client=client)
+ agent = test_agent.clone(model=model)
+
+ response = get_model_response(
+ Response(
+ id="resp_123",
+ output=[
+ ResponseOutputMessage(
+ id="msg_123",
+ type="message",
+ status="completed",
+ content=[
+ ResponseOutputText(
+ text="Hello from GPT-4.1",
+ type="output_text",
+ annotations=[],
+ )
+ ],
+ role="assistant",
+ )
+ ],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model="gpt-4.1-2025-04-14",
+ object="response",
+ usage=ResponseUsage(
+ input_tokens=10,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=0,
+ ),
+ output_tokens=20,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=5,
+ ),
+ total_tokens=30,
+ ),
+ ),
+ serialize_pydantic=True,
+ )
+
+ with patch.object(
+ agent.model._client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+
+ result = await agents.Runner.run(
+ agent, "Test input", run_config=test_run_config
+ )
+
+ assert result is not None
+
+ (transaction,) = events
+ spans = transaction["spans"]
+ ai_client_span = next(span for span in spans if span["op"] == OP.GEN_AI_CHAT)
+
+ # Verify ai_client span has response model from API response
+ assert ai_client_span["description"] == "chat gpt-4"
+ assert "gen_ai.response.model" in ai_client_span["data"]
+ assert ai_client_span["data"]["gen_ai.response.model"] == "gpt-4.1-2025-04-14"
+
+
+@pytest.mark.asyncio
+async def test_ai_client_span_response_model_with_chat_completions(
+ sentry_init,
+ capture_events,
+ get_model_response,
+):
+ """
+ Test that response model is captured when using ChatCompletions API (not Responses API).
+ This ensures our implementation works with different OpenAI model types.
+ """
+ # Create agent that uses ChatCompletions model
+ client = AsyncOpenAI(api_key="z")
+ model = OpenAIResponsesModel(model="gpt-4o-mini", openai_client=client)
+
+ agent = Agent(
+ name="chat_completions_agent",
+ instructions="Test agent using ChatCompletions",
+ model=model,
+ )
+
+ response = get_model_response(
+ Response(
+ id="resp_123",
+ output=[
+ ResponseOutputMessage(
+ id="msg_123",
+ type="message",
+ status="completed",
+ content=[
+ ResponseOutputText(
+ text="Response from model",
+ type="output_text",
+ annotations=[],
+ )
+ ],
+ role="assistant",
+ )
+ ],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model="gpt-4o-mini-2024-07-18",
+ object="response",
+ usage=ResponseUsage(
+ input_tokens=15,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=0,
+ ),
+ output_tokens=25,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=5,
+ ),
+ total_tokens=40,
+ ),
+ ),
+ serialize_pydantic=True,
+ )
+
+ with patch.object(
+ agent.model._client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ result = await agents.Runner.run(
+ agent, "Test input", run_config=test_run_config
+ )
+
+ assert result is not None
+
+ (transaction,) = events
+ spans = transaction["spans"]
+ ai_client_span = next(span for span in spans if span["op"] == OP.GEN_AI_CHAT)
+
+ # Verify response model from API response is captured
+ assert "gen_ai.response.model" in ai_client_span["data"]
+ assert ai_client_span["data"]["gen_ai.response.model"] == "gpt-4o-mini-2024-07-18"
+
+
+@pytest.mark.asyncio
+async def test_multiple_llm_calls_aggregate_usage(
+ sentry_init, capture_events, test_agent, get_model_response
+):
+ """
+ Test that invoke_agent spans show aggregated usage across multiple LLM calls
+ (e.g., when tools are used and multiple API calls are made).
+ """
+
+ @agents.function_tool
+ def calculator(a: int, b: int) -> int:
+ """Add two numbers"""
+ return a + b
+
+ client = AsyncOpenAI(api_key="test-key")
+ model = OpenAIResponsesModel(model="gpt-4", openai_client=client)
+ agent_with_tool = test_agent.clone(tools=[calculator], model=model)
+
+ tool_call_response = get_model_response(
+ Response(
+ id="resp_1",
+ output=[
+ ResponseFunctionToolCall(
+ id="call_123",
+ call_id="call_123",
+ name="calculator",
+ type="function_call",
+ arguments='{"a": 5, "b": 3}',
+ )
+ ],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model="gpt-4.1-2025-04-14",
+ object="response",
+ usage=ResponseUsage(
+ input_tokens=10,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=0,
+ ),
+ output_tokens=5,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=0,
+ ),
+ total_tokens=15,
+ ),
+ ),
+ serialize_pydantic=True,
+ )
+
+ final_response = get_model_response(
+ Response(
+ id="resp_2",
+ output=[
+ ResponseOutputMessage(
+ id="msg_final",
+ type="message",
+ status="completed",
+ content=[
+ ResponseOutputText(
+ text="The result is 8",
+ type="output_text",
+ annotations=[],
+ )
+ ],
+ role="assistant",
+ )
+ ],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model="gpt-4-0613",
+ object="response",
+ usage=ResponseUsage(
+ input_tokens=20,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=5,
+ ),
+ output_tokens=15,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=3,
+ ),
+ total_tokens=35,
+ ),
+ ),
+ serialize_pydantic=True,
+ )
+
+ with patch.object(
+ agent_with_tool.model._client._client,
+ "send",
+ side_effect=[tool_call_response, final_response],
+ ) as _:
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+
+ result = await agents.Runner.run(
+ agent_with_tool,
+ "What is 5 + 3?",
+ run_config=test_run_config,
+ )
+
+ assert result is not None
+
+ (transaction,) = events
+ spans = transaction["spans"]
+ invoke_agent_span = spans[0]
+
+ # Verify invoke_agent span has aggregated usage from both API calls
+ # Total: 10 + 20 = 30 input tokens, 5 + 15 = 20 output tokens, 15 + 35 = 50 total
+ assert invoke_agent_span["data"]["gen_ai.usage.input_tokens"] == 30
+ assert invoke_agent_span["data"]["gen_ai.usage.output_tokens"] == 20
+ assert invoke_agent_span["data"]["gen_ai.usage.total_tokens"] == 50
+ # Cached tokens should be aggregated: 0 + 5 = 5
+ assert invoke_agent_span["data"]["gen_ai.usage.input_tokens.cached"] == 5
+ # Reasoning tokens should be aggregated: 0 + 3 = 3
+ assert invoke_agent_span["data"]["gen_ai.usage.output_tokens.reasoning"] == 3
+
+
+@pytest.mark.asyncio
+async def test_invoke_agent_span_includes_response_model(
+ sentry_init,
+ capture_events,
+ test_agent,
+ get_model_response,
+):
+ """
+ Test that invoke_agent spans include the response model from the API response.
+ """
+ client = AsyncOpenAI(api_key="z")
+ model = OpenAIResponsesModel(model="gpt-4", openai_client=client)
+ agent = test_agent.clone(model=model)
+
+ response = get_model_response(
+ Response(
+ id="resp_123",
+ output=[
+ ResponseOutputMessage(
+ id="msg_123",
+ type="message",
+ status="completed",
+ content=[
+ ResponseOutputText(
+ text="Response from model",
+ type="output_text",
+ annotations=[],
+ )
+ ],
+ role="assistant",
+ )
+ ],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model="gpt-4.1-2025-04-14",
+ object="response",
+ usage=ResponseUsage(
+ input_tokens=10,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=0,
+ ),
+ output_tokens=20,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=5,
+ ),
+ total_tokens=30,
+ ),
+ ),
+ serialize_pydantic=True,
+ )
+
+ with patch.object(
+ agent.model._client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+
+ result = await agents.Runner.run(
+ agent, "Test input", run_config=test_run_config
+ )
+
+ assert result is not None
+
+ (transaction,) = events
+ spans = transaction["spans"]
+ invoke_agent_span = next(
+ span for span in spans if span["op"] == OP.GEN_AI_INVOKE_AGENT
+ )
+ ai_client_span = next(span for span in spans if span["op"] == OP.GEN_AI_CHAT)
+
+ # Verify invoke_agent span has response model from API
+ assert invoke_agent_span["description"] == "invoke_agent test_agent"
+ assert "gen_ai.response.model" in invoke_agent_span["data"]
+ assert invoke_agent_span["data"]["gen_ai.response.model"] == "gpt-4.1-2025-04-14"
+
+ # Also verify ai_client span has it
+ assert "gen_ai.response.model" in ai_client_span["data"]
+ assert ai_client_span["data"]["gen_ai.response.model"] == "gpt-4.1-2025-04-14"
+
+
+@pytest.mark.asyncio
+async def test_invoke_agent_span_uses_last_response_model(
+ sentry_init,
+ capture_events,
+ test_agent,
+ get_model_response,
+):
+ """
+ Test that when an agent makes multiple LLM calls (e.g., with tools),
+ the invoke_agent span reports the last response model used.
+ """
+
+ @agents.function_tool
+ def calculator(a: int, b: int) -> int:
+ """Add two numbers"""
+ return a + b
+
+ client = AsyncOpenAI(api_key="test-key")
+ model = OpenAIResponsesModel(model="gpt-4", openai_client=client)
+ agent_with_tool = test_agent.clone(tools=[calculator], model=model)
+
+ first_response = get_model_response(
+ Response(
+ id="resp_1",
+ output=[
+ ResponseFunctionToolCall(
+ id="call_123",
+ call_id="call_123",
+ name="calculator",
+ type="function_call",
+ arguments='{"a": 5, "b": 3}',
+ )
+ ],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model="gpt-4-0613",
+ object="response",
+ usage=ResponseUsage(
+ input_tokens=10,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=0,
+ ),
+ output_tokens=5,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=0,
+ ),
+ total_tokens=15,
+ ),
+ ),
+ serialize_pydantic=True,
+ )
+
+ second_response = get_model_response(
+ Response(
+ id="resp_2",
+ output=[
+ ResponseOutputMessage(
+ id="msg_final",
+ type="message",
+ status="completed",
+ content=[
+ ResponseOutputText(
+ text="I'm the specialist and I can help with that!",
+ type="output_text",
+ annotations=[],
+ )
+ ],
+ role="assistant",
+ )
+ ],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model="gpt-4.1-2025-04-14",
+ object="response",
+ usage=ResponseUsage(
+ input_tokens=20,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=0,
+ ),
+ output_tokens=15,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=5,
+ ),
+ total_tokens=35,
+ ),
+ ),
+ serialize_pydantic=True,
+ )
+
+ with patch.object(
+ agent_with_tool.model._client._client,
+ "send",
+ side_effect=[first_response, second_response],
+ ) as _:
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+
+ result = await agents.Runner.run(
+ agent_with_tool,
+ "What is 5 + 3?",
+ run_config=test_run_config,
+ )
+
+ assert result is not None
+
+ (transaction,) = events
+ spans = transaction["spans"]
+ invoke_agent_span = spans[0]
+ first_ai_client_span = spans[1]
+ second_ai_client_span = spans[3] # After tool span
+
+ # Invoke_agent span uses the LAST response model
+ assert "gen_ai.response.model" in invoke_agent_span["data"]
+ assert invoke_agent_span["data"]["gen_ai.response.model"] == "gpt-4.1-2025-04-14"
+
+ # Each ai_client span has its own response model from the API
+ assert first_ai_client_span["data"]["gen_ai.response.model"] == "gpt-4-0613"
+ assert (
+ second_ai_client_span["data"]["gen_ai.response.model"] == "gpt-4.1-2025-04-14"
+ )
+
+
+def test_openai_agents_message_truncation(sentry_init, capture_events):
+ """Test that large messages are truncated properly in OpenAI Agents integration."""
+
+ large_content = (
+ "This is a very long message that will exceed our size limits. " * 1000
+ )
+
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ test_messages = [
+ {"role": "user", "content": large_content},
+ {"role": "assistant", "content": large_content},
+ {"role": "user", "content": "small message 4"},
+ {"role": "assistant", "content": "small message 5"},
+ ]
+
+ get_response_kwargs = {"input": test_messages}
+
+ with start_span(op="gen_ai.chat") as span:
+ scope = sentry_sdk.get_current_scope()
+ _set_input_data(span, get_response_kwargs)
+ if hasattr(scope, "_gen_ai_original_message_count"):
+ truncated_count = scope._gen_ai_original_message_count.get(span.span_id)
+ assert truncated_count == 4, (
+ f"Expected 4 original messages, got {truncated_count}"
+ )
+
+ assert SPANDATA.GEN_AI_REQUEST_MESSAGES in span._data
+ messages_data = span._data[SPANDATA.GEN_AI_REQUEST_MESSAGES]
+ assert isinstance(messages_data, str)
+
+ parsed_messages = json.loads(messages_data)
+ assert isinstance(parsed_messages, list)
+ assert len(parsed_messages) == 1
+ assert "small message 5" in str(parsed_messages[0])
+
+
+@pytest.mark.asyncio
+async def test_streaming_span_update_captures_response_data(
+ sentry_init, test_agent, mock_usage
+):
+ """
+ Test that update_ai_client_span correctly captures response text,
+ usage data, and response model from a streaming response.
+ """
+ from sentry_sdk.integrations.openai_agents.spans.ai_client import (
+ update_ai_client_span,
+ )
+
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ # Create a mock streaming response object (similar to what we'd get from ResponseCompletedEvent)
+ mock_streaming_response = MagicMock()
+ mock_streaming_response.model = "gpt-4-streaming"
+ mock_streaming_response.usage = mock_usage
+ mock_streaming_response.output = [
+ ResponseOutputMessage(
+ id="msg_streaming_123",
+ type="message",
+ status="completed",
+ content=[
+ ResponseOutputText(
+ text="Hello from streaming!",
+ type="output_text",
+ annotations=[],
+ )
+ ],
+ role="assistant",
+ )
+ ]
+
+ # Test the unified update function (works for both streaming and non-streaming)
+ with start_span(op="gen_ai.chat", description="test chat") as span:
+ update_ai_client_span(span, mock_streaming_response)
+
+ # Verify the span data was set correctly
+ assert span._data["gen_ai.response.text"] == "Hello from streaming!"
+ assert span._data["gen_ai.usage.input_tokens"] == 10
+ assert span._data["gen_ai.usage.output_tokens"] == 20
+ assert span._data["gen_ai.response.model"] == "gpt-4-streaming"
+
+
+@pytest.mark.asyncio
+async def test_streaming_ttft_on_chat_span(
+ sentry_init,
+ test_agent,
+ get_model_response,
+ async_iterator,
+ server_side_event_chunks,
+):
+ """
+ Test that time-to-first-token (TTFT) is recorded on chat spans during streaming.
+
+ TTFT is triggered by events with a `delta` attribute, which includes:
+ - ResponseTextDeltaEvent (text output)
+ - ResponseAudioDeltaEvent (audio output)
+ - ResponseReasoningTextDeltaEvent (reasoning/thinking)
+ - ResponseFunctionCallArgumentsDeltaEvent (function call args)
+ - and other delta events...
+
+ Events WITHOUT delta (like ResponseCompletedEvent, ResponseCreatedEvent, etc.)
+ should NOT trigger TTFT.
+ """
+ client = AsyncOpenAI(api_key="z")
+
+ model = OpenAIResponsesModel(model="gpt-4", openai_client=client)
+
+ agent_with_tool = test_agent.clone(
+ model=model,
+ )
+
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ request_headers = {}
+ # openai-agents calls with_streaming_response() if available starting with
+ # https://github.com/openai/openai-agents-python/commit/159beb56130f7d85192acfd593c9168757984dc0.
+ # When using with_streaming_response() the header set below changes the response type:
+ # https://github.com/openai/openai-python/blob/656e3cab4a18262a49b961d41293367e45ee71b9/src/openai/_response.py#L67.
+ if parse_version(OPENAI_AGENTS_VERSION) >= (0, 10, 3) and hasattr(
+ agent_with_tool.model._client.responses, "with_streaming_response"
+ ):
+ request_headers["X-Stainless-Raw-Response"] = "stream"
+
+ response = get_model_response(
+ async_iterator(
+ server_side_event_chunks(
+ [
+ ResponseCreatedEvent(
+ response=Response(
+ id="chat-id",
+ output=[],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model="response-model-id",
+ object="response",
+ ),
+ type="response.created",
+ sequence_number=0,
+ ),
+ ResponseTextDeltaEvent(
+ type="response.output_text.delta",
+ item_id="message-id",
+ output_index=0,
+ content_index=0,
+ delta="Hello",
+ logprobs=[],
+ sequence_number=1,
+ ),
+ ResponseTextDeltaEvent(
+ type="response.output_text.delta",
+ item_id="message-id",
+ output_index=0,
+ content_index=0,
+ delta=" world!",
+ logprobs=[],
+ sequence_number=2,
+ ),
+ ResponseCompletedEvent(
+ response=Response(
+ id="chat-id",
+ output=[
+ ResponseOutputMessage(
+ id="message-id",
+ content=[
+ ResponseOutputText(
+ annotations=[],
+ text="Hello world!",
+ type="output_text",
+ ),
+ ],
+ role="assistant",
+ status="completed",
+ type="message",
+ ),
+ ],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model="response-model-id",
+ object="response",
+ usage=ResponseUsage(
+ input_tokens=20,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=5,
+ ),
+ output_tokens=10,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=8,
+ ),
+ total_tokens=30,
+ ),
+ ),
+ type="response.completed",
+ sequence_number=3,
+ ),
+ ]
+ )
+ ),
+ request_headers=request_headers,
+ )
+
+ # Patching https://github.com/openai/openai-python/blob/656e3cab4a18262a49b961d41293367e45ee71b9/src/openai/_base_client.py#L1604
+ with patch.object(
+ agent_with_tool.model._client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ with sentry_sdk.start_transaction(
+ name="test_ttft", sampled=True
+ ) as transaction:
+ result = agents.Runner.run_streamed(
+ agent_with_tool,
+ "Please use the simple test tool",
+ run_config=test_run_config,
+ )
+
+ async for event in result.stream_events():
+ pass
+
+ # Verify TTFT is recorded on the chat span (must be inside transaction context)
+ chat_spans = [
+ s for s in transaction._span_recorder.spans if s.op == "gen_ai.chat"
+ ]
+ assert len(chat_spans) >= 1
+ chat_span = chat_spans[0]
+
+ assert SPANDATA.GEN_AI_RESPONSE_TIME_TO_FIRST_TOKEN in chat_span._data
+ assert chat_span._data.get(SPANDATA.GEN_AI_RESPONSE_STREAMING) is True
+
+
+@pytest.mark.skipif(
+ parse_version(OPENAI_AGENTS_VERSION) < (0, 4, 0),
+ reason="conversation_id support requires openai-agents >= 0.4.0",
+)
+@pytest.mark.asyncio
+async def test_conversation_id_on_all_spans(
+ sentry_init,
+ capture_events,
+ test_agent,
+ nonstreaming_responses_model_response,
+ get_model_response,
+):
+ """
+ Test that gen_ai.conversation.id is set on all AI-related spans when passed to Runner.run().
+ """
+
+ client = AsyncOpenAI(api_key="test-key")
+ model = OpenAIResponsesModel(model="gpt-4", openai_client=client)
+ agent = test_agent.clone(model=model)
+
+ response = get_model_response(
+ nonstreaming_responses_model_response, serialize_pydantic=True
+ )
+
+ with patch.object(
+ agent.model._client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ result = await agents.Runner.run(
+ agent,
+ "Test input",
+ run_config=test_run_config,
+ conversation_id="conv_test_123",
+ )
+
+ assert result is not None
+
+ (transaction,) = events
+ spans = transaction["spans"]
+ invoke_agent_span = next(
+ span for span in spans if span["op"] == OP.GEN_AI_INVOKE_AGENT
+ )
+ ai_client_span = next(span for span in spans if span["op"] == OP.GEN_AI_CHAT)
+
+ # Verify workflow span (transaction) has conversation_id
+ assert (
+ transaction["contexts"]["trace"]["data"]["gen_ai.conversation.id"]
+ == "conv_test_123"
+ )
+
+ # Verify invoke_agent span has conversation_id
+ assert invoke_agent_span["data"]["gen_ai.conversation.id"] == "conv_test_123"
+
+ # Verify ai_client span has conversation_id
+ assert ai_client_span["data"]["gen_ai.conversation.id"] == "conv_test_123"
+
+
+@pytest.mark.skipif(
+ parse_version(OPENAI_AGENTS_VERSION) < (0, 4, 0),
+ reason="conversation_id support requires openai-agents >= 0.4.0",
+)
+@pytest.mark.asyncio
+async def test_conversation_id_on_tool_span(
+ sentry_init, capture_events, test_agent, get_model_response
+):
+ """
+ Test that gen_ai.conversation.id is set on tool execution spans when passed to Runner.run().
+ """
+
+ @agents.function_tool
+ def simple_tool(message: str) -> str:
+ """A simple tool"""
+ return f"Result: {message}"
+
+ client = AsyncOpenAI(api_key="test-key")
+ model = OpenAIResponsesModel(model="gpt-4", openai_client=client)
+ agent_with_tool = test_agent.clone(tools=[simple_tool], model=model)
+
+ tool_response = get_model_response(
+ Response(
+ id="call_123",
+ output=[
+ ResponseFunctionToolCall(
+ id="call_123",
+ call_id="call_123",
+ name="simple_tool",
+ type="function_call",
+ arguments='{"message": "hello"}',
+ )
+ ],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model="gpt-4",
+ object="response",
+ usage=ResponseUsage(
+ input_tokens=10,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=0,
+ ),
+ output_tokens=5,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=0,
+ ),
+ total_tokens=15,
+ ),
+ ),
+ serialize_pydantic=True,
+ )
+
+ final_response = get_model_response(
+ Response(
+ id="resp_final_789",
+ output=[
+ ResponseOutputMessage(
+ id="msg_final",
+ type="message",
+ status="completed",
+ content=[
+ ResponseOutputText(
+ text="Done",
+ type="output_text",
+ annotations=[],
+ )
+ ],
+ role="assistant",
+ )
+ ],
+ parallel_tool_calls=False,
+ tool_choice="none",
+ tools=[],
+ created_at=10000000,
+ model="gpt-4",
+ object="response",
+ usage=ResponseUsage(
+ input_tokens=20,
+ input_tokens_details=InputTokensDetails(
+ cached_tokens=5,
+ ),
+ output_tokens=10,
+ output_tokens_details=OutputTokensDetails(
+ reasoning_tokens=8,
+ ),
+ total_tokens=30,
+ ),
+ ),
+ serialize_pydantic=True,
+ )
+
+ with patch.object(
+ agent_with_tool.model._client._client,
+ "send",
+ side_effect=[tool_response, final_response],
+ ) as _:
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ await agents.Runner.run(
+ agent_with_tool,
+ "Use the tool",
+ run_config=test_run_config,
+ conversation_id="conv_tool_test_456",
+ )
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ # Find the tool span
+ tool_span = None
+ for span in spans:
+ if span.get("description", "").startswith("execute_tool"):
+ tool_span = span
+ break
+
+ assert tool_span is not None
+ # Tool span should have the conversation_id passed to Runner.run()
+ assert tool_span["data"]["gen_ai.conversation.id"] == "conv_tool_test_456"
+
+ # Workflow span (transaction) should have the same conversation_id
+ assert (
+ transaction["contexts"]["trace"]["data"]["gen_ai.conversation.id"]
+ == "conv_tool_test_456"
+ )
+
+
+@pytest.mark.skipif(
+ parse_version(OPENAI_AGENTS_VERSION) < (0, 4, 0),
+ reason="conversation_id support requires openai-agents >= 0.4.0",
+)
+@pytest.mark.asyncio
+async def test_no_conversation_id_when_not_provided(
+ sentry_init,
+ capture_events,
+ test_agent,
+ nonstreaming_responses_model_response,
+ get_model_response,
+):
+ """
+ Test that gen_ai.conversation.id is not set when not passed to Runner.run().
+ """
+
+ client = AsyncOpenAI(api_key="test-key")
+ model = OpenAIResponsesModel(model="gpt-4", openai_client=client)
+ agent = test_agent.clone(model=model)
+
+ response = get_model_response(
+ nonstreaming_responses_model_response, serialize_pydantic=True
+ )
+
+ with patch.object(
+ agent.model._client._client,
+ "send",
+ return_value=response,
+ ) as _:
+ sentry_init(
+ integrations=[OpenAIAgentsIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ # Don't pass conversation_id
+ result = await agents.Runner.run(
+ agent, "Test input", run_config=test_run_config
+ )
+
+ assert result is not None
+
+ (transaction,) = events
+ spans = transaction["spans"]
+ invoke_agent_span = next(
+ span for span in spans if span["op"] == OP.GEN_AI_INVOKE_AGENT
+ )
+ ai_client_span = next(span for span in spans if span["op"] == OP.GEN_AI_CHAT)
+
+ # Verify conversation_id is NOT set on any spans
+ assert "gen_ai.conversation.id" not in transaction["contexts"]["trace"].get(
+ "data", {}
+ )
+ assert "gen_ai.conversation.id" not in invoke_agent_span.get("data", {})
+ assert "gen_ai.conversation.id" not in ai_client_span.get("data", {})
diff --git a/tests/integrations/openfeature/__init__.py b/tests/integrations/openfeature/__init__.py
new file mode 100644
index 0000000000..a17549ea79
--- /dev/null
+++ b/tests/integrations/openfeature/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("openfeature")
diff --git a/tests/integrations/openfeature/test_openfeature.py b/tests/integrations/openfeature/test_openfeature.py
new file mode 100644
index 0000000000..46acc61ae7
--- /dev/null
+++ b/tests/integrations/openfeature/test_openfeature.py
@@ -0,0 +1,179 @@
+import concurrent.futures as cf
+import sys
+
+import pytest
+
+from openfeature import api
+from openfeature.provider.in_memory_provider import InMemoryFlag, InMemoryProvider
+
+import sentry_sdk
+from sentry_sdk import start_span, start_transaction
+from sentry_sdk.integrations.openfeature import OpenFeatureIntegration
+from tests.conftest import ApproxDict
+
+
+def test_openfeature_integration(sentry_init, capture_events, uninstall_integration):
+ uninstall_integration(OpenFeatureIntegration.identifier)
+ sentry_init(integrations=[OpenFeatureIntegration()])
+
+ flags = {
+ "hello": InMemoryFlag("on", {"on": True, "off": False}),
+ "world": InMemoryFlag("off", {"on": True, "off": False}),
+ }
+ api.set_provider(InMemoryProvider(flags))
+
+ client = api.get_client()
+ client.get_boolean_value("hello", default_value=False)
+ client.get_boolean_value("world", default_value=False)
+ client.get_boolean_value("other", default_value=True)
+
+ events = capture_events()
+ sentry_sdk.capture_exception(Exception("something wrong!"))
+
+ assert len(events) == 1
+ assert events[0]["contexts"]["flags"] == {
+ "values": [
+ {"flag": "hello", "result": True},
+ {"flag": "world", "result": False},
+ {"flag": "other", "result": True},
+ ]
+ }
+
+
+def test_openfeature_integration_threaded(
+ sentry_init, capture_events, uninstall_integration
+):
+ uninstall_integration(OpenFeatureIntegration.identifier)
+ sentry_init(integrations=[OpenFeatureIntegration()])
+ events = capture_events()
+
+ flags = {
+ "hello": InMemoryFlag("on", {"on": True, "off": False}),
+ "world": InMemoryFlag("off", {"on": True, "off": False}),
+ }
+ api.set_provider(InMemoryProvider(flags))
+
+ # Capture an eval before we split isolation scopes.
+ client = api.get_client()
+ client.get_boolean_value("hello", default_value=False)
+
+ def task(flag):
+ # Create a new isolation scope for the thread. This means the flags
+ with sentry_sdk.isolation_scope():
+ client.get_boolean_value(flag, default_value=False)
+ # use a tag to identify to identify events later on
+ sentry_sdk.set_tag("task_id", flag)
+ sentry_sdk.capture_exception(Exception("something wrong!"))
+
+ # Run tasks in separate threads
+ with cf.ThreadPoolExecutor(max_workers=2) as pool:
+ pool.map(task, ["world", "other"])
+
+ # Capture error in original scope
+ sentry_sdk.set_tag("task_id", "0")
+ sentry_sdk.capture_exception(Exception("something wrong!"))
+
+ assert len(events) == 3
+ events.sort(key=lambda e: e["tags"]["task_id"])
+
+ assert events[0]["contexts"]["flags"] == {
+ "values": [
+ {"flag": "hello", "result": True},
+ ]
+ }
+ assert events[1]["contexts"]["flags"] == {
+ "values": [
+ {"flag": "hello", "result": True},
+ {"flag": "other", "result": False},
+ ]
+ }
+ assert events[2]["contexts"]["flags"] == {
+ "values": [
+ {"flag": "hello", "result": True},
+ {"flag": "world", "result": False},
+ ]
+ }
+
+
+@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher")
+def test_openfeature_integration_asyncio(
+ sentry_init, capture_events, uninstall_integration
+):
+ """Assert concurrently evaluated flags do not pollute one another."""
+
+ asyncio = pytest.importorskip("asyncio")
+
+ uninstall_integration(OpenFeatureIntegration.identifier)
+ sentry_init(integrations=[OpenFeatureIntegration()])
+ events = capture_events()
+
+ async def task(flag):
+ with sentry_sdk.isolation_scope():
+ client.get_boolean_value(flag, default_value=False)
+ # use a tag to identify to identify events later on
+ sentry_sdk.set_tag("task_id", flag)
+ sentry_sdk.capture_exception(Exception("something wrong!"))
+
+ async def runner():
+ return asyncio.gather(task("world"), task("other"))
+
+ flags = {
+ "hello": InMemoryFlag("on", {"on": True, "off": False}),
+ "world": InMemoryFlag("off", {"on": True, "off": False}),
+ }
+ api.set_provider(InMemoryProvider(flags))
+
+ # Capture an eval before we split isolation scopes.
+ client = api.get_client()
+ client.get_boolean_value("hello", default_value=False)
+
+ asyncio.run(runner())
+
+ # Capture error in original scope
+ sentry_sdk.set_tag("task_id", "0")
+ sentry_sdk.capture_exception(Exception("something wrong!"))
+
+ assert len(events) == 3
+ events.sort(key=lambda e: e["tags"]["task_id"])
+
+ assert events[0]["contexts"]["flags"] == {
+ "values": [
+ {"flag": "hello", "result": True},
+ ]
+ }
+ assert events[1]["contexts"]["flags"] == {
+ "values": [
+ {"flag": "hello", "result": True},
+ {"flag": "other", "result": False},
+ ]
+ }
+ assert events[2]["contexts"]["flags"] == {
+ "values": [
+ {"flag": "hello", "result": True},
+ {"flag": "world", "result": False},
+ ]
+ }
+
+
+def test_openfeature_span_integration(
+ sentry_init, capture_events, uninstall_integration
+):
+ uninstall_integration(OpenFeatureIntegration.identifier)
+ sentry_init(traces_sample_rate=1.0, integrations=[OpenFeatureIntegration()])
+
+ api.set_provider(
+ InMemoryProvider({"hello": InMemoryFlag("on", {"on": True, "off": False})})
+ )
+ client = api.get_client()
+
+ events = capture_events()
+
+ with start_transaction(name="hi"):
+ with start_span(op="foo", name="bar"):
+ client.get_boolean_value("hello", default_value=False)
+ client.get_boolean_value("world", default_value=False)
+
+ (event,) = events
+ assert event["spans"][0]["data"] == ApproxDict(
+ {"flag.evaluation.hello": True, "flag.evaluation.world": False}
+ )
diff --git a/tests/integrations/opentelemetry/__init__.py b/tests/integrations/opentelemetry/__init__.py
index 39ecc610d5..75763c2fee 100644
--- a/tests/integrations/opentelemetry/__init__.py
+++ b/tests/integrations/opentelemetry/__init__.py
@@ -1,3 +1,3 @@
import pytest
-django = pytest.importorskip("opentelemetry")
+pytest.importorskip("opentelemetry")
diff --git a/tests/integrations/opentelemetry/test_entry_points.py b/tests/integrations/opentelemetry/test_entry_points.py
new file mode 100644
index 0000000000..cd78209432
--- /dev/null
+++ b/tests/integrations/opentelemetry/test_entry_points.py
@@ -0,0 +1,17 @@
+import importlib
+import os
+from unittest.mock import patch
+
+from opentelemetry import propagate
+from sentry_sdk.integrations.opentelemetry import SentryPropagator
+
+
+def test_propagator_loaded_if_mentioned_in_environment_variable():
+ try:
+ with patch.dict(os.environ, {"OTEL_PROPAGATORS": "sentry"}):
+ importlib.reload(propagate)
+
+ assert len(propagate.propagators) == 1
+ assert isinstance(propagate.propagators[0], SentryPropagator)
+ finally:
+ importlib.reload(propagate)
diff --git a/tests/integrations/opentelemetry/test_experimental.py b/tests/integrations/opentelemetry/test_experimental.py
new file mode 100644
index 0000000000..8e4b703361
--- /dev/null
+++ b/tests/integrations/opentelemetry/test_experimental.py
@@ -0,0 +1,47 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+
+@pytest.mark.forked
+def test_integration_enabled_if_option_is_on(sentry_init, reset_integrations):
+ mocked_setup_once = MagicMock()
+
+ with patch(
+ "sentry_sdk.integrations.opentelemetry.integration.OpenTelemetryIntegration.setup_once",
+ mocked_setup_once,
+ ):
+ sentry_init(
+ _experiments={
+ "otel_powered_performance": True,
+ },
+ )
+ mocked_setup_once.assert_called_once()
+
+
+@pytest.mark.forked
+def test_integration_not_enabled_if_option_is_off(sentry_init, reset_integrations):
+ mocked_setup_once = MagicMock()
+
+ with patch(
+ "sentry_sdk.integrations.opentelemetry.integration.OpenTelemetryIntegration.setup_once",
+ mocked_setup_once,
+ ):
+ sentry_init(
+ _experiments={
+ "otel_powered_performance": False,
+ },
+ )
+ mocked_setup_once.assert_not_called()
+
+
+@pytest.mark.forked
+def test_integration_not_enabled_if_option_is_missing(sentry_init, reset_integrations):
+ mocked_setup_once = MagicMock()
+
+ with patch(
+ "sentry_sdk.integrations.opentelemetry.integration.OpenTelemetryIntegration.setup_once",
+ mocked_setup_once,
+ ):
+ sentry_init()
+ mocked_setup_once.assert_not_called()
diff --git a/tests/integrations/opentelemetry/test_propagator.py b/tests/integrations/opentelemetry/test_propagator.py
index 529aa99c09..d999b0bb2b 100644
--- a/tests/integrations/opentelemetry/test_propagator.py
+++ b/tests/integrations/opentelemetry/test_propagator.py
@@ -1,23 +1,26 @@
-from mock import MagicMock
-import mock
+import pytest
+
+from unittest import mock
+from unittest.mock import MagicMock
from opentelemetry.context import get_current
-from opentelemetry.trace.propagation import get_current_span
from opentelemetry.trace import (
- set_span_in_context,
- TraceFlags,
SpanContext,
+ TraceFlags,
+ set_span_in_context,
)
+from opentelemetry.trace.propagation import get_current_span
+
from sentry_sdk.integrations.opentelemetry.consts import (
SENTRY_BAGGAGE_KEY,
SENTRY_TRACE_KEY,
)
-
from sentry_sdk.integrations.opentelemetry.propagator import SentryPropagator
from sentry_sdk.integrations.opentelemetry.span_processor import SentrySpanProcessor
from sentry_sdk.tracing_utils import Baggage
+@pytest.mark.forked
def test_extract_no_context_no_sentry_trace_header():
"""
No context and NO Sentry trace data in getter.
@@ -33,6 +36,7 @@ def test_extract_no_context_no_sentry_trace_header():
assert modified_context == {}
+@pytest.mark.forked
def test_extract_context_no_sentry_trace_header():
"""
Context but NO Sentry trace data in getter.
@@ -48,6 +52,7 @@ def test_extract_context_no_sentry_trace_header():
assert modified_context == context
+@pytest.mark.forked
def test_extract_empty_context_sentry_trace_header_no_baggage():
"""
Empty context but Sentry trace data but NO Baggage in getter.
@@ -77,6 +82,7 @@ def test_extract_empty_context_sentry_trace_header_no_baggage():
assert span_context.trace_id == int("1234567890abcdef1234567890abcdef", 16)
+@pytest.mark.forked
def test_extract_context_sentry_trace_header_baggage():
"""
Empty context but Sentry trace data and Baggage in getter.
@@ -117,6 +123,7 @@ def test_extract_context_sentry_trace_header_baggage():
assert span_context.trace_id == int("1234567890abcdef1234567890abcdef", 16)
+@pytest.mark.forked
def test_inject_empty_otel_span_map():
"""
Empty otel_span_map.
@@ -135,7 +142,7 @@ def test_inject_empty_otel_span_map():
is_remote=True,
)
span = MagicMock()
- span.context = span_context
+ span.get_span_context.return_value = span_context
with mock.patch(
"sentry_sdk.integrations.opentelemetry.propagator.trace.get_current_span",
@@ -147,6 +154,7 @@ def test_inject_empty_otel_span_map():
setter.set.assert_not_called()
+@pytest.mark.forked
def test_inject_sentry_span_no_baggage():
"""
Inject a sentry span with no baggage.
@@ -166,7 +174,7 @@ def test_inject_sentry_span_no_baggage():
is_remote=True,
)
span = MagicMock()
- span.context = span_context
+ span.get_span_context.return_value = span_context
sentry_span = MagicMock()
sentry_span.to_traceparent = mock.Mock(
@@ -191,6 +199,50 @@ def test_inject_sentry_span_no_baggage():
)
+def test_inject_sentry_span_empty_baggage():
+ """
+ Inject a sentry span with no baggage.
+ """
+ carrier = None
+ context = get_current()
+ setter = MagicMock()
+ setter.set = MagicMock()
+
+ trace_id = "1234567890abcdef1234567890abcdef"
+ span_id = "1234567890abcdef"
+
+ span_context = SpanContext(
+ trace_id=int(trace_id, 16),
+ span_id=int(span_id, 16),
+ trace_flags=TraceFlags(TraceFlags.SAMPLED),
+ is_remote=True,
+ )
+ span = MagicMock()
+ span.get_span_context.return_value = span_context
+
+ sentry_span = MagicMock()
+ sentry_span.to_traceparent = mock.Mock(
+ return_value="1234567890abcdef1234567890abcdef-1234567890abcdef-1"
+ )
+ sentry_span.containing_transaction.get_baggage = mock.Mock(return_value=Baggage({}))
+
+ span_processor = SentrySpanProcessor()
+ span_processor.otel_span_map[span_id] = sentry_span
+
+ with mock.patch(
+ "sentry_sdk.integrations.opentelemetry.propagator.trace.get_current_span",
+ return_value=span,
+ ):
+ full_context = set_span_in_context(span, context)
+ SentryPropagator().inject(carrier, full_context, setter)
+
+ setter.set.assert_called_once_with(
+ carrier,
+ "sentry-trace",
+ "1234567890abcdef1234567890abcdef-1234567890abcdef-1",
+ )
+
+
def test_inject_sentry_span_baggage():
"""
Inject a sentry span with baggage.
@@ -210,7 +262,7 @@ def test_inject_sentry_span_baggage():
is_remote=True,
)
span = MagicMock()
- span.context = span_context
+ span.get_span_context.return_value = span_context
sentry_span = MagicMock()
sentry_span.to_traceparent = mock.Mock(
diff --git a/tests/integrations/opentelemetry/test_span_processor.py b/tests/integrations/opentelemetry/test_span_processor.py
index 0467da7673..e1cd849b94 100644
--- a/tests/integrations/opentelemetry/test_span_processor.py
+++ b/tests/integrations/opentelemetry/test_span_processor.py
@@ -1,42 +1,43 @@
-from datetime import datetime
-from mock import MagicMock
-import mock
import time
+from datetime import datetime, timezone
+from unittest import mock
+from unittest.mock import MagicMock
+
+import pytest
+from opentelemetry.trace import SpanKind, SpanContext, Status, StatusCode
+
+import sentry_sdk
from sentry_sdk.integrations.opentelemetry.span_processor import (
SentrySpanProcessor,
link_trace_context_to_error_event,
)
+from sentry_sdk.utils import Dsn
from sentry_sdk.tracing import Span, Transaction
-
-from opentelemetry.trace import SpanKind, SpanContext
from sentry_sdk.tracing_utils import extract_sentrytrace_data
def test_is_sentry_span():
otel_span = MagicMock()
- hub = MagicMock()
- hub.client = None
-
span_processor = SentrySpanProcessor()
- assert not span_processor._is_sentry_span(hub, otel_span)
+ assert not span_processor._is_sentry_span(otel_span)
client = MagicMock()
client.options = {"instrumenter": "otel"}
- client.dsn = "https://1234567890abcdef@o123456.ingest.sentry.io/123456"
+ client.parsed_dsn = Dsn("https://1234567890abcdef@o123456.ingest.sentry.io/123456")
+ sentry_sdk.get_global_scope().set_client(client)
- hub.client = client
- assert not span_processor._is_sentry_span(hub, otel_span)
+ assert not span_processor._is_sentry_span(otel_span)
otel_span.attributes = {
"http.url": "https://example.com",
}
- assert not span_processor._is_sentry_span(hub, otel_span)
+ assert not span_processor._is_sentry_span(otel_span)
otel_span.attributes = {
"http.url": "https://o123456.ingest.sentry.io/api/123/envelope",
}
- assert span_processor._is_sentry_span(hub, otel_span)
+ assert span_processor._is_sentry_span(otel_span)
def test_get_otel_context():
@@ -56,15 +57,20 @@ def test_get_otel_context():
def test_get_trace_data_with_span_and_trace():
otel_span = MagicMock()
- otel_span.context = MagicMock()
- otel_span.context.trace_id = int("1234567890abcdef1234567890abcdef", 16)
- otel_span.context.span_id = int("1234567890abcdef", 16)
+ span_context = SpanContext(
+ trace_id=int("1234567890abcdef1234567890abcdef", 16),
+ span_id=int("1234567890abcdef", 16),
+ is_remote=True,
+ )
+ otel_span.get_span_context.return_value = span_context
otel_span.parent = None
parent_context = {}
span_processor = SentrySpanProcessor()
- sentry_trace_data = span_processor._get_trace_data(otel_span, parent_context)
+ sentry_trace_data = span_processor._get_trace_data(
+ otel_span.get_span_context(), otel_span.parent, parent_context
+ )
assert sentry_trace_data["trace_id"] == "1234567890abcdef1234567890abcdef"
assert sentry_trace_data["span_id"] == "1234567890abcdef"
assert sentry_trace_data["parent_span_id"] is None
@@ -74,16 +80,21 @@ def test_get_trace_data_with_span_and_trace():
def test_get_trace_data_with_span_and_trace_and_parent():
otel_span = MagicMock()
- otel_span.context = MagicMock()
- otel_span.context.trace_id = int("1234567890abcdef1234567890abcdef", 16)
- otel_span.context.span_id = int("1234567890abcdef", 16)
+ span_context = SpanContext(
+ trace_id=int("1234567890abcdef1234567890abcdef", 16),
+ span_id=int("1234567890abcdef", 16),
+ is_remote=True,
+ )
+ otel_span.get_span_context.return_value = span_context
otel_span.parent = MagicMock()
otel_span.parent.span_id = int("abcdef1234567890", 16)
parent_context = {}
span_processor = SentrySpanProcessor()
- sentry_trace_data = span_processor._get_trace_data(otel_span, parent_context)
+ sentry_trace_data = span_processor._get_trace_data(
+ otel_span.get_span_context(), otel_span.parent, parent_context
+ )
assert sentry_trace_data["trace_id"] == "1234567890abcdef1234567890abcdef"
assert sentry_trace_data["span_id"] == "1234567890abcdef"
assert sentry_trace_data["parent_span_id"] == "abcdef1234567890"
@@ -93,9 +104,12 @@ def test_get_trace_data_with_span_and_trace_and_parent():
def test_get_trace_data_with_sentry_trace():
otel_span = MagicMock()
- otel_span.context = MagicMock()
- otel_span.context.trace_id = int("1234567890abcdef1234567890abcdef", 16)
- otel_span.context.span_id = int("1234567890abcdef", 16)
+ span_context = SpanContext(
+ trace_id=int("1234567890abcdef1234567890abcdef", 16),
+ span_id=int("1234567890abcdef", 16),
+ is_remote=True,
+ )
+ otel_span.get_span_context.return_value = span_context
otel_span.parent = MagicMock()
otel_span.parent.span_id = int("abcdef1234567890", 16)
@@ -111,7 +125,9 @@ def test_get_trace_data_with_sentry_trace():
],
):
span_processor = SentrySpanProcessor()
- sentry_trace_data = span_processor._get_trace_data(otel_span, parent_context)
+ sentry_trace_data = span_processor._get_trace_data(
+ otel_span.get_span_context(), otel_span.parent, parent_context
+ )
assert sentry_trace_data["trace_id"] == "1234567890abcdef1234567890abcdef"
assert sentry_trace_data["span_id"] == "1234567890abcdef"
assert sentry_trace_data["parent_span_id"] == "abcdef1234567890"
@@ -128,7 +144,9 @@ def test_get_trace_data_with_sentry_trace():
],
):
span_processor = SentrySpanProcessor()
- sentry_trace_data = span_processor._get_trace_data(otel_span, parent_context)
+ sentry_trace_data = span_processor._get_trace_data(
+ otel_span.get_span_context(), otel_span.parent, parent_context
+ )
assert sentry_trace_data["trace_id"] == "1234567890abcdef1234567890abcdef"
assert sentry_trace_data["span_id"] == "1234567890abcdef"
assert sentry_trace_data["parent_span_id"] == "abcdef1234567890"
@@ -138,9 +156,12 @@ def test_get_trace_data_with_sentry_trace():
def test_get_trace_data_with_sentry_trace_and_baggage():
otel_span = MagicMock()
- otel_span.context = MagicMock()
- otel_span.context.trace_id = int("1234567890abcdef1234567890abcdef", 16)
- otel_span.context.span_id = int("1234567890abcdef", 16)
+ span_context = SpanContext(
+ trace_id=int("1234567890abcdef1234567890abcdef", 16),
+ span_id=int("1234567890abcdef", 16),
+ is_remote=True,
+ )
+ otel_span.get_span_context.return_value = span_context
otel_span.parent = MagicMock()
otel_span.parent.span_id = int("abcdef1234567890", 16)
@@ -162,7 +183,9 @@ def test_get_trace_data_with_sentry_trace_and_baggage():
],
):
span_processor = SentrySpanProcessor()
- sentry_trace_data = span_processor._get_trace_data(otel_span, parent_context)
+ sentry_trace_data = span_processor._get_trace_data(
+ otel_span.get_span_context(), otel_span.parent, parent_context
+ )
assert sentry_trace_data["trace_id"] == "1234567890abcdef1234567890abcdef"
assert sentry_trace_data["span_id"] == "1234567890abcdef"
assert sentry_trace_data["parent_span_id"] == "abcdef1234567890"
@@ -190,17 +213,38 @@ def test_update_span_with_otel_data_http_method():
assert sentry_span.op == "http.client"
assert sentry_span.description == "GET example.com /"
- assert sentry_span._tags["http.status_code"] == "429"
assert sentry_span.status == "resource_exhausted"
assert sentry_span._data["http.method"] == "GET"
- assert sentry_span._data["http.status_code"] == 429
+ assert sentry_span._data["http.response.status_code"] == 429
assert sentry_span._data["http.status_text"] == "xxx"
assert sentry_span._data["http.user_agent"] == "curl/7.64.1"
assert sentry_span._data["net.peer.name"] == "example.com"
assert sentry_span._data["http.target"] == "/"
+@pytest.mark.parametrize(
+ "otel_status, expected_status",
+ [
+ pytest.param(Status(StatusCode.UNSET), None, id="unset"),
+ pytest.param(Status(StatusCode.OK), "ok", id="ok"),
+ pytest.param(Status(StatusCode.ERROR), "internal_error", id="error"),
+ ],
+)
+def test_update_span_with_otel_status(otel_status, expected_status):
+ sentry_span = Span()
+
+ otel_span = MagicMock()
+ otel_span.name = "Test OTel Span"
+ otel_span.kind = SpanKind.INTERNAL
+ otel_span.status = otel_status
+
+ span_processor = SentrySpanProcessor()
+ span_processor._update_span_with_otel_status(sentry_span, otel_span)
+
+ assert sentry_span.get_trace_context().get("status") == expected_status
+
+
def test_update_span_with_otel_data_http_method2():
sentry_span = Span()
@@ -220,11 +264,10 @@ def test_update_span_with_otel_data_http_method2():
assert sentry_span.op == "http.server"
assert sentry_span.description == "GET https://example.com/status/403"
- assert sentry_span._tags["http.status_code"] == "429"
assert sentry_span.status == "resource_exhausted"
assert sentry_span._data["http.method"] == "GET"
- assert sentry_span._data["http.status_code"] == 429
+ assert sentry_span._data["http.response.status_code"] == 429
assert sentry_span._data["http.status_text"] == "xxx"
assert sentry_span._data["http.user_agent"] == "curl/7.64.1"
assert (
@@ -259,38 +302,42 @@ def test_on_start_transaction():
otel_span = MagicMock()
otel_span.name = "Sample OTel Span"
otel_span.start_time = time.time_ns()
- otel_span.context = MagicMock()
- otel_span.context.trace_id = int("1234567890abcdef1234567890abcdef", 16)
- otel_span.context.span_id = int("1234567890abcdef", 16)
+ span_context = SpanContext(
+ trace_id=int("1234567890abcdef1234567890abcdef", 16),
+ span_id=int("1234567890abcdef", 16),
+ is_remote=True,
+ )
+ otel_span.get_span_context.return_value = span_context
otel_span.parent = MagicMock()
otel_span.parent.span_id = int("abcdef1234567890", 16)
parent_context = {}
+ fake_start_transaction = MagicMock()
+
fake_client = MagicMock()
fake_client.options = {"instrumenter": "otel"}
fake_client.dsn = "https://1234567890abcdef@o123456.ingest.sentry.io/123456"
-
- current_hub = MagicMock()
- current_hub.client = fake_client
-
- fake_hub = MagicMock()
- fake_hub.current = current_hub
+ sentry_sdk.get_global_scope().set_client(fake_client)
with mock.patch(
- "sentry_sdk.integrations.opentelemetry.span_processor.Hub", fake_hub
+ "sentry_sdk.integrations.opentelemetry.span_processor.start_transaction",
+ fake_start_transaction,
):
span_processor = SentrySpanProcessor()
span_processor.on_start(otel_span, parent_context)
- fake_hub.current.start_transaction.assert_called_once_with(
+ fake_start_transaction.assert_called_once_with(
name="Sample OTel Span",
span_id="1234567890abcdef",
parent_span_id="abcdef1234567890",
trace_id="1234567890abcdef1234567890abcdef",
baggage=None,
- start_timestamp=datetime.fromtimestamp(otel_span.start_time / 1e9),
+ start_timestamp=datetime.fromtimestamp(
+ otel_span.start_time / 1e9, timezone.utc
+ ),
instrumenter="otel",
+ origin="auto.otel",
)
assert len(span_processor.otel_span_map.keys()) == 1
@@ -301,9 +348,12 @@ def test_on_start_child():
otel_span = MagicMock()
otel_span.name = "Sample OTel Span"
otel_span.start_time = time.time_ns()
- otel_span.context = MagicMock()
- otel_span.context.trace_id = int("1234567890abcdef1234567890abcdef", 16)
- otel_span.context.span_id = int("1234567890abcdef", 16)
+ span_context = SpanContext(
+ trace_id=int("1234567890abcdef1234567890abcdef", 16),
+ span_id=int("1234567890abcdef", 16),
+ is_remote=True,
+ )
+ otel_span.get_span_context.return_value = span_context
otel_span.parent = MagicMock()
otel_span.parent.span_id = int("abcdef1234567890", 16)
@@ -312,32 +362,27 @@ def test_on_start_child():
fake_client = MagicMock()
fake_client.options = {"instrumenter": "otel"}
fake_client.dsn = "https://1234567890abcdef@o123456.ingest.sentry.io/123456"
+ sentry_sdk.get_global_scope().set_client(fake_client)
- current_hub = MagicMock()
- current_hub.client = fake_client
+ fake_span = MagicMock()
- fake_hub = MagicMock()
- fake_hub.current = current_hub
-
- with mock.patch(
- "sentry_sdk.integrations.opentelemetry.span_processor.Hub", fake_hub
- ):
- fake_span = MagicMock()
-
- span_processor = SentrySpanProcessor()
- span_processor.otel_span_map["abcdef1234567890"] = fake_span
- span_processor.on_start(otel_span, parent_context)
-
- fake_span.start_child.assert_called_once_with(
- span_id="1234567890abcdef",
- description="Sample OTel Span",
- start_timestamp=datetime.fromtimestamp(otel_span.start_time / 1e9),
- instrumenter="otel",
- )
+ span_processor = SentrySpanProcessor()
+ span_processor.otel_span_map["abcdef1234567890"] = fake_span
+ span_processor.on_start(otel_span, parent_context)
+
+ fake_span.start_child.assert_called_once_with(
+ span_id="1234567890abcdef",
+ name="Sample OTel Span",
+ start_timestamp=datetime.fromtimestamp(
+ otel_span.start_time / 1e9, timezone.utc
+ ),
+ instrumenter="otel",
+ origin="auto.otel",
+ )
- assert len(span_processor.otel_span_map.keys()) == 2
- assert "abcdef1234567890" in span_processor.otel_span_map.keys()
- assert "1234567890abcdef" in span_processor.otel_span_map.keys()
+ assert len(span_processor.otel_span_map.keys()) == 2
+ assert "abcdef1234567890" in span_processor.otel_span_map.keys()
+ assert "1234567890abcdef" in span_processor.otel_span_map.keys()
def test_on_end_no_sentry_span():
@@ -347,8 +392,12 @@ def test_on_end_no_sentry_span():
otel_span = MagicMock()
otel_span.name = "Sample OTel Span"
otel_span.end_time = time.time_ns()
- otel_span.context = MagicMock()
- otel_span.context.span_id = int("1234567890abcdef", 16)
+ span_context = SpanContext(
+ trace_id=int("1234567890abcdef1234567890abcdef", 16),
+ span_id=int("1234567890abcdef", 16),
+ is_remote=True,
+ )
+ otel_span.get_span_context.return_value = span_context
span_processor = SentrySpanProcessor()
span_processor.otel_span_map = {}
@@ -368,8 +417,17 @@ def test_on_end_sentry_transaction():
otel_span = MagicMock()
otel_span.name = "Sample OTel Span"
otel_span.end_time = time.time_ns()
- otel_span.context = MagicMock()
- otel_span.context.span_id = int("1234567890abcdef", 16)
+ otel_span.status = Status(StatusCode.OK)
+ span_context = SpanContext(
+ trace_id=int("1234567890abcdef1234567890abcdef", 16),
+ span_id=int("1234567890abcdef", 16),
+ is_remote=True,
+ )
+ otel_span.get_span_context.return_value = span_context
+
+ fake_client = MagicMock()
+ fake_client.options = {"instrumenter": "otel"}
+ sentry_sdk.get_global_scope().set_client(fake_client)
fake_sentry_span = MagicMock(spec=Transaction)
fake_sentry_span.set_context = MagicMock()
@@ -384,6 +442,7 @@ def test_on_end_sentry_transaction():
fake_sentry_span.set_context.assert_called_once()
span_processor._update_span_with_otel_data.assert_not_called()
+ fake_sentry_span.set_status.assert_called_once_with("ok")
fake_sentry_span.finish.assert_called_once()
@@ -394,8 +453,17 @@ def test_on_end_sentry_span():
otel_span = MagicMock()
otel_span.name = "Sample OTel Span"
otel_span.end_time = time.time_ns()
- otel_span.context = MagicMock()
- otel_span.context.span_id = int("1234567890abcdef", 16)
+ otel_span.status = Status(StatusCode.OK)
+ span_context = SpanContext(
+ trace_id=int("1234567890abcdef1234567890abcdef", 16),
+ span_id=int("1234567890abcdef", 16),
+ is_remote=True,
+ )
+ otel_span.get_span_context.return_value = span_context
+
+ fake_client = MagicMock()
+ fake_client.options = {"instrumenter": "otel"}
+ sentry_sdk.get_global_scope().set_client(fake_client)
fake_sentry_span = MagicMock(spec=Span)
fake_sentry_span.set_context = MagicMock()
@@ -412,6 +480,7 @@ def test_on_end_sentry_span():
span_processor._update_span_with_otel_data.assert_called_once_with(
fake_sentry_span, otel_span
)
+ fake_sentry_span.set_status.assert_called_once_with("ok")
fake_sentry_span.finish.assert_called_once()
@@ -421,13 +490,7 @@ def test_link_trace_context_to_error_event():
"""
fake_client = MagicMock()
fake_client.options = {"instrumenter": "otel"}
- fake_client
-
- current_hub = MagicMock()
- current_hub.client = fake_client
-
- fake_hub = MagicMock()
- fake_hub.current = current_hub
+ sentry_sdk.get_global_scope().set_client(fake_client)
span_id = "1234567890abcdef"
trace_id = "1234567890abcdef1234567890abcdef"
@@ -466,3 +529,95 @@ def test_link_trace_context_to_error_event():
assert "contexts" in event
assert "trace" in event["contexts"]
assert event["contexts"]["trace"] == fake_trace_context
+
+
+def test_pruning_old_spans_on_start():
+ otel_span = MagicMock()
+ otel_span.name = "Sample OTel Span"
+ otel_span.start_time = time.time_ns()
+ span_context = SpanContext(
+ trace_id=int("1234567890abcdef1234567890abcdef", 16),
+ span_id=int("1234567890abcdef", 16),
+ is_remote=True,
+ )
+ otel_span.get_span_context.return_value = span_context
+ otel_span.parent = MagicMock()
+ otel_span.parent.span_id = int("abcdef1234567890", 16)
+
+ parent_context = {}
+ fake_client = MagicMock()
+ fake_client.options = {"instrumenter": "otel", "debug": False}
+ fake_client.dsn = "https://1234567890abcdef@o123456.ingest.sentry.io/123456"
+ sentry_sdk.get_global_scope().set_client(fake_client)
+
+ span_processor = SentrySpanProcessor()
+
+ span_processor.otel_span_map = {
+ "111111111abcdef": MagicMock(), # should stay
+ "2222222222abcdef": MagicMock(), # should go
+ "3333333333abcdef": MagicMock(), # should go
+ }
+ current_time_minutes = int(time.time() / 60)
+ span_processor.open_spans = {
+ current_time_minutes - 3: {"111111111abcdef"}, # should stay
+ current_time_minutes - 11: {
+ "2222222222abcdef",
+ "3333333333abcdef",
+ }, # should go
+ }
+
+ span_processor.on_start(otel_span, parent_context)
+ assert sorted(list(span_processor.otel_span_map.keys())) == [
+ "111111111abcdef",
+ "1234567890abcdef",
+ ]
+ assert sorted(list(span_processor.open_spans.values())) == [
+ {"111111111abcdef"},
+ {"1234567890abcdef"},
+ ]
+
+
+def test_pruning_old_spans_on_end():
+ otel_span = MagicMock()
+ otel_span.name = "Sample OTel Span"
+ otel_span.start_time = time.time_ns()
+ span_context = SpanContext(
+ trace_id=int("1234567890abcdef1234567890abcdef", 16),
+ span_id=int("1234567890abcdef", 16),
+ is_remote=True,
+ )
+ otel_span.get_span_context.return_value = span_context
+ otel_span.parent = MagicMock()
+ otel_span.parent.span_id = int("abcdef1234567890", 16)
+
+ fake_client = MagicMock()
+ fake_client.options = {"instrumenter": "otel"}
+ sentry_sdk.get_global_scope().set_client(fake_client)
+
+ fake_sentry_span = MagicMock(spec=Span)
+ fake_sentry_span.set_context = MagicMock()
+ fake_sentry_span.finish = MagicMock()
+
+ span_processor = SentrySpanProcessor()
+ span_processor._get_otel_context = MagicMock()
+ span_processor._update_span_with_otel_data = MagicMock()
+
+ span_processor.otel_span_map = {
+ "111111111abcdef": MagicMock(), # should stay
+ "2222222222abcdef": MagicMock(), # should go
+ "3333333333abcdef": MagicMock(), # should go
+ "1234567890abcdef": fake_sentry_span, # should go (because it is closed)
+ }
+ current_time_minutes = int(time.time() / 60)
+ span_processor.open_spans = {
+ current_time_minutes: {"1234567890abcdef"}, # should go (because it is closed)
+ current_time_minutes - 3: {"111111111abcdef"}, # should stay
+ current_time_minutes - 11: {
+ "2222222222abcdef",
+ "3333333333abcdef",
+ }, # should go
+ }
+
+ span_processor.on_end(otel_span)
+ assert sorted(list(span_processor.otel_span_map.keys())) == ["111111111abcdef"]
+ assert sorted(list(span_processor.open_spans.values())) == [{"111111111abcdef"}]
diff --git a/tests/integrations/otlp/__init__.py b/tests/integrations/otlp/__init__.py
new file mode 100644
index 0000000000..75763c2fee
--- /dev/null
+++ b/tests/integrations/otlp/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("opentelemetry")
diff --git a/tests/integrations/otlp/test_otlp.py b/tests/integrations/otlp/test_otlp.py
new file mode 100644
index 0000000000..e085a22ac0
--- /dev/null
+++ b/tests/integrations/otlp/test_otlp.py
@@ -0,0 +1,370 @@
+import pytest
+import responses
+
+from opentelemetry import trace
+from opentelemetry.trace import (
+ get_tracer_provider,
+ set_tracer_provider,
+ ProxyTracerProvider,
+ format_span_id,
+ format_trace_id,
+ get_current_span,
+)
+from opentelemetry.context import attach, detach
+from opentelemetry.propagate import get_global_textmap, set_global_textmap
+from opentelemetry.util._once import Once
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import BatchSpanProcessor
+from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
+
+from sentry_sdk.integrations.otlp import OTLPIntegration, SentryOTLPPropagator
+from sentry_sdk.scope import get_external_propagation_context
+
+
+original_propagator = get_global_textmap()
+
+
+@pytest.fixture(autouse=True)
+def mock_otlp_ingest():
+ responses.start()
+ responses.add(
+ responses.POST,
+ url="https://bla.ingest.sentry.io/api/12312012/integration/otlp/v1/traces/",
+ status=200,
+ )
+ responses.add(
+ responses.POST,
+ url="https://my-collector.example.com/v1/traces",
+ status=200,
+ )
+
+ yield
+
+ tracer_provider = get_tracer_provider()
+ if isinstance(tracer_provider, TracerProvider):
+ tracer_provider.force_flush()
+
+ responses.stop()
+ responses.reset()
+
+
+@pytest.fixture(autouse=True)
+def reset_otlp(uninstall_integration):
+ trace._TRACER_PROVIDER_SET_ONCE = Once()
+ trace._TRACER_PROVIDER = None
+
+ set_global_textmap(original_propagator)
+
+ uninstall_integration("otlp")
+
+
+def test_sets_new_tracer_provider_with_otlp_exporter(sentry_init):
+ existing_tracer_provider = get_tracer_provider()
+ assert isinstance(existing_tracer_provider, ProxyTracerProvider)
+
+ sentry_init(
+ dsn="https://mysecret@bla.ingest.sentry.io/12312012",
+ integrations=[OTLPIntegration()],
+ )
+
+ tracer_provider = get_tracer_provider()
+ assert tracer_provider is not existing_tracer_provider
+ assert isinstance(tracer_provider, TracerProvider)
+
+ (span_processor,) = tracer_provider._active_span_processor._span_processors
+ assert isinstance(span_processor, BatchSpanProcessor)
+
+ exporter = span_processor.span_exporter
+ assert isinstance(exporter, OTLPSpanExporter)
+ assert (
+ exporter._endpoint
+ == "https://bla.ingest.sentry.io/api/12312012/integration/otlp/v1/traces/"
+ )
+ assert "X-Sentry-Auth" in exporter._headers
+ assert (
+ "Sentry sentry_key=mysecret, sentry_version=7, sentry_client=sentry.python/"
+ in exporter._headers["X-Sentry-Auth"]
+ )
+
+
+def test_uses_existing_tracer_provider_with_otlp_exporter(sentry_init):
+ existing_tracer_provider = TracerProvider()
+ set_tracer_provider(existing_tracer_provider)
+
+ sentry_init(
+ dsn="https://mysecret@bla.ingest.sentry.io/12312012",
+ integrations=[OTLPIntegration()],
+ )
+
+ tracer_provider = get_tracer_provider()
+ assert tracer_provider == existing_tracer_provider
+ assert isinstance(tracer_provider, TracerProvider)
+
+ (span_processor,) = tracer_provider._active_span_processor._span_processors
+ assert isinstance(span_processor, BatchSpanProcessor)
+
+ exporter = span_processor.span_exporter
+ assert isinstance(exporter, OTLPSpanExporter)
+ assert (
+ exporter._endpoint
+ == "https://bla.ingest.sentry.io/api/12312012/integration/otlp/v1/traces/"
+ )
+ assert "X-Sentry-Auth" in exporter._headers
+ assert (
+ "Sentry sentry_key=mysecret, sentry_version=7, sentry_client=sentry.python/"
+ in exporter._headers["X-Sentry-Auth"]
+ )
+
+
+def test_does_not_setup_exporter_when_disabled(sentry_init):
+ existing_tracer_provider = get_tracer_provider()
+ assert isinstance(existing_tracer_provider, ProxyTracerProvider)
+
+ sentry_init(
+ dsn="https://mysecret@bla.ingest.sentry.io/12312012",
+ integrations=[OTLPIntegration(setup_otlp_traces_exporter=False)],
+ )
+
+ tracer_provider = get_tracer_provider()
+ assert tracer_provider is existing_tracer_provider
+
+
+def test_sets_propagator(sentry_init):
+ sentry_init(
+ dsn="https://mysecret@bla.ingest.sentry.io/12312012",
+ integrations=[OTLPIntegration()],
+ )
+
+ propagator = get_global_textmap()
+ assert isinstance(get_global_textmap(), SentryOTLPPropagator)
+ assert propagator is not original_propagator
+
+
+def test_does_not_set_propagator_if_disabled(sentry_init):
+ sentry_init(
+ dsn="https://mysecret@bla.ingest.sentry.io/12312012",
+ integrations=[OTLPIntegration(setup_propagator=False)],
+ )
+
+ propagator = get_global_textmap()
+ assert not isinstance(propagator, SentryOTLPPropagator)
+ assert propagator is original_propagator
+
+
+def test_otel_propagation_context(sentry_init):
+ sentry_init(
+ dsn="https://mysecret@bla.ingest.sentry.io/12312012",
+ integrations=[OTLPIntegration()],
+ )
+
+ tracer = trace.get_tracer(__name__)
+ with tracer.start_as_current_span("foo") as root_span:
+ with tracer.start_as_current_span("bar") as span:
+ external_propagation_context = get_external_propagation_context()
+
+ assert external_propagation_context is not None
+ (trace_id, span_id) = external_propagation_context
+ assert trace_id == format_trace_id(root_span.get_span_context().trace_id)
+ assert trace_id == format_trace_id(span.get_span_context().trace_id)
+ assert span_id == format_span_id(span.get_span_context().span_id)
+
+
+def test_propagator_inject_head_of_trace(sentry_init):
+ sentry_init(
+ dsn="https://mysecret@bla.ingest.sentry.io/12312012",
+ integrations=[OTLPIntegration()],
+ )
+
+ tracer = trace.get_tracer(__name__)
+ propagator = get_global_textmap()
+ carrier = {}
+
+ with tracer.start_as_current_span("foo") as span:
+ propagator.inject(carrier)
+
+ span_context = span.get_span_context()
+ trace_id = format_trace_id(span_context.trace_id)
+ span_id = format_span_id(span_context.span_id)
+
+ assert "sentry-trace" in carrier
+ assert carrier["sentry-trace"] == f"{trace_id}-{span_id}-1"
+
+ #! we cannot populate baggage in otlp as head SDK yet
+ assert "baggage" not in carrier
+
+
+def test_propagator_inject_continue_trace(sentry_init):
+ sentry_init(
+ dsn="https://mysecret@bla.ingest.sentry.io/12312012",
+ integrations=[OTLPIntegration()],
+ )
+
+ tracer = trace.get_tracer(__name__)
+ propagator = get_global_textmap()
+ carrier = {}
+
+ incoming_headers = {
+ "sentry-trace": "771a43a4192642f0b136d5159a501700-1234567890abcdef-1",
+ "baggage": (
+ "sentry-trace_id=771a43a4192642f0b136d5159a501700,sentry-sampled=true"
+ ),
+ }
+
+ ctx = propagator.extract(incoming_headers)
+ token = attach(ctx)
+
+ parent_span_context = get_current_span().get_span_context()
+ assert (
+ format_trace_id(parent_span_context.trace_id)
+ == "771a43a4192642f0b136d5159a501700"
+ )
+ assert format_span_id(parent_span_context.span_id) == "1234567890abcdef"
+
+ with tracer.start_as_current_span("foo") as span:
+ propagator.inject(carrier)
+
+ span_context = span.get_span_context()
+ trace_id = format_trace_id(span_context.trace_id)
+ span_id = format_span_id(span_context.span_id)
+
+ assert trace_id == "771a43a4192642f0b136d5159a501700"
+
+ assert "sentry-trace" in carrier
+ assert carrier["sentry-trace"] == f"{trace_id}-{span_id}-1"
+
+ assert "baggage" in carrier
+ assert carrier["baggage"] == incoming_headers["baggage"]
+
+ detach(token)
+
+
+def test_collector_url_sets_endpoint(sentry_init):
+ sentry_init(
+ dsn="https://mysecret@bla.ingest.sentry.io/12312012",
+ integrations=[
+ OTLPIntegration(collector_url="https://my-collector.example.com/v1/traces")
+ ],
+ )
+
+ tracer_provider = get_tracer_provider()
+ assert isinstance(tracer_provider, TracerProvider)
+
+ (span_processor,) = tracer_provider._active_span_processor._span_processors
+ assert isinstance(span_processor, BatchSpanProcessor)
+
+ exporter = span_processor.span_exporter
+ assert isinstance(exporter, OTLPSpanExporter)
+ assert exporter._endpoint == "https://my-collector.example.com/v1/traces"
+ assert exporter._headers is None or "X-Sentry-Auth" not in exporter._headers
+
+
+def test_collector_url_takes_precedence_over_dsn(sentry_init):
+ sentry_init(
+ dsn="https://mysecret@bla.ingest.sentry.io/12312012",
+ integrations=[
+ OTLPIntegration(collector_url="https://my-collector.example.com/v1/traces")
+ ],
+ )
+
+ tracer_provider = get_tracer_provider()
+ assert isinstance(tracer_provider, TracerProvider)
+
+ (span_processor,) = tracer_provider._active_span_processor._span_processors
+ exporter = span_processor.span_exporter
+ assert isinstance(exporter, OTLPSpanExporter)
+ # Should use collector_url, NOT the DSN-derived endpoint
+ assert exporter._endpoint == "https://my-collector.example.com/v1/traces"
+ assert (
+ exporter._endpoint
+ != "https://bla.ingest.sentry.io/api/12312012/integration/otlp/v1/traces/"
+ )
+
+
+def test_collector_url_none_falls_back_to_dsn(sentry_init):
+ sentry_init(
+ dsn="https://mysecret@bla.ingest.sentry.io/12312012",
+ integrations=[OTLPIntegration(collector_url=None)],
+ )
+
+ tracer_provider = get_tracer_provider()
+ assert isinstance(tracer_provider, TracerProvider)
+
+ (span_processor,) = tracer_provider._active_span_processor._span_processors
+ exporter = span_processor.span_exporter
+ assert isinstance(exporter, OTLPSpanExporter)
+ assert (
+ exporter._endpoint
+ == "https://bla.ingest.sentry.io/api/12312012/integration/otlp/v1/traces/"
+ )
+ assert "X-Sentry-Auth" in exporter._headers
+
+
+def test_capture_exceptions_enabled(sentry_init, capture_events):
+ sentry_init(
+ dsn="https://mysecret@bla.ingest.sentry.io/12312012",
+ integrations=[OTLPIntegration(capture_exceptions=True)],
+ )
+
+ events = capture_events()
+
+ tracer = trace.get_tracer(__name__)
+ with tracer.start_as_current_span("test_span") as span:
+ try:
+ raise ValueError("Test exception")
+ except ValueError as e:
+ span.record_exception(e)
+
+ (event,) = events
+ assert event["exception"]["values"][0]["type"] == "ValueError"
+ assert event["exception"]["values"][0]["value"] == "Test exception"
+ assert event["exception"]["values"][0]["mechanism"]["type"] == "otlp"
+ assert event["exception"]["values"][0]["mechanism"]["handled"] is False
+
+ trace_context = event["contexts"]["trace"]
+ assert trace_context["trace_id"] == format_trace_id(
+ span.get_span_context().trace_id
+ )
+ assert trace_context["span_id"] == format_span_id(span.get_span_context().span_id)
+
+
+def test_capture_exceptions_disabled(sentry_init, capture_events):
+ sentry_init(
+ dsn="https://mysecret@bla.ingest.sentry.io/12312012",
+ integrations=[OTLPIntegration(capture_exceptions=False)],
+ )
+
+ events = capture_events()
+
+ tracer = trace.get_tracer(__name__)
+ with tracer.start_as_current_span("test_span") as span:
+ try:
+ raise ValueError("Test exception")
+ except ValueError as e:
+ span.record_exception(e)
+
+ assert len(events) == 0
+
+
+def test_capture_exceptions_preserves_otel_behavior(sentry_init, capture_events):
+ sentry_init(
+ dsn="https://mysecret@bla.ingest.sentry.io/12312012",
+ integrations=[OTLPIntegration(capture_exceptions=True)],
+ )
+
+ events = capture_events()
+
+ tracer = trace.get_tracer(__name__)
+ with tracer.start_as_current_span("test_span") as span:
+ try:
+ raise ValueError("Test exception")
+ except ValueError as e:
+ span.record_exception(e, attributes={"foo": "bar"})
+
+ # Verify the span recorded the exception (OpenTelemetry behavior)
+ # The span should have events with the exception information
+ (otel_event,) = span._events
+ assert otel_event.name == "exception"
+ assert otel_event.attributes["foo"] == "bar"
+
+ # verify sentry also captured it
+ assert len(events) == 1
diff --git a/tests/integrations/pure_eval/__init__.py b/tests/integrations/pure_eval/__init__.py
index 3f645e75f6..47ad99aa8d 100644
--- a/tests/integrations/pure_eval/__init__.py
+++ b/tests/integrations/pure_eval/__init__.py
@@ -1,3 +1,3 @@
import pytest
-pure_eval = pytest.importorskip("pure_eval")
+pytest.importorskip("pure_eval")
diff --git a/tests/integrations/pure_eval/test_pure_eval.py b/tests/integrations/pure_eval/test_pure_eval.py
index 2d1a92026e..497a8768d0 100644
--- a/tests/integrations/pure_eval/test_pure_eval.py
+++ b/tests/integrations/pure_eval/test_pure_eval.py
@@ -1,4 +1,3 @@
-import sys
from types import SimpleNamespace
import pytest
@@ -64,10 +63,7 @@ def foo():
"u",
"y",
]
- if sys.version_info[:2] == (3, 5):
- assert frame_vars.keys() == set(expected_keys)
- else:
- assert list(frame_vars.keys()) == expected_keys
+ assert list(frame_vars.keys()) == expected_keys
assert frame_vars["namespace.d"] == {"1": "2"}
assert frame_vars["namespace.d[1]"] == "2"
else:
diff --git a/tests/integrations/pydantic_ai/__init__.py b/tests/integrations/pydantic_ai/__init__.py
new file mode 100644
index 0000000000..3a2ad11c0c
--- /dev/null
+++ b/tests/integrations/pydantic_ai/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("pydantic_ai")
diff --git a/tests/integrations/pydantic_ai/test_pydantic_ai.py b/tests/integrations/pydantic_ai/test_pydantic_ai.py
new file mode 100644
index 0000000000..50ce155f5b
--- /dev/null
+++ b/tests/integrations/pydantic_ai/test_pydantic_ai.py
@@ -0,0 +1,3063 @@
+import asyncio
+import json
+import pytest
+from unittest.mock import MagicMock
+
+from typing import Annotated
+from pydantic import Field
+
+import sentry_sdk
+from sentry_sdk._types import BLOB_DATA_SUBSTITUTE
+from sentry_sdk.consts import SPANDATA
+from sentry_sdk.integrations.pydantic_ai import PydanticAIIntegration
+from sentry_sdk.integrations.pydantic_ai.spans.ai_client import _set_input_messages
+from sentry_sdk.integrations.pydantic_ai.spans.utils import _set_usage_data
+from pydantic_ai import Agent
+from pydantic_ai.messages import BinaryContent, ImageUrl, UserPromptPart
+from pydantic_ai.usage import RequestUsage
+from pydantic_ai.exceptions import ModelRetry, UnexpectedModelBehavior
+from pydantic_ai.models.function import FunctionModel
+
+
+@pytest.fixture
+def get_test_agent():
+ def inner():
+ """Create a test agent with model settings."""
+ return Agent(
+ "test",
+ name="test_agent",
+ system_prompt="You are a helpful test assistant.",
+ )
+
+ return inner
+
+
+@pytest.fixture
+def get_test_agent_with_settings():
+ def inner():
+ """Create a test agent with explicit model settings."""
+ from pydantic_ai import ModelSettings
+
+ return Agent(
+ "test",
+ name="test_agent_settings",
+ system_prompt="You are a test assistant with settings.",
+ model_settings=ModelSettings(
+ temperature=0.7,
+ max_tokens=100,
+ top_p=0.9,
+ ),
+ )
+
+ return inner
+
+
+@pytest.mark.asyncio
+async def test_agent_run_async(sentry_init, capture_events, get_test_agent):
+ """
+ Test that the integration creates spans for async agent runs.
+ """
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+
+ test_agent = get_test_agent()
+ result = await test_agent.run("Test input")
+
+ assert result is not None
+ assert result.output is not None
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ # Verify transaction (the transaction IS the invoke_agent span)
+ assert transaction["transaction"] == "invoke_agent test_agent"
+ assert transaction["contexts"]["trace"]["origin"] == "auto.ai.pydantic_ai"
+
+ # The transaction itself should have invoke_agent data
+ assert transaction["contexts"]["trace"]["op"] == "gen_ai.invoke_agent"
+
+ # Find child span types (invoke_agent is the transaction, not a child span)
+ chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
+ assert len(chat_spans) >= 1
+
+ # Check chat span
+ chat_span = chat_spans[0]
+ assert "chat" in chat_span["description"]
+ assert chat_span["data"]["gen_ai.operation.name"] == "chat"
+ assert chat_span["data"]["gen_ai.response.streaming"] is False
+ assert "gen_ai.request.messages" in chat_span["data"]
+ assert "gen_ai.usage.input_tokens" in chat_span["data"]
+ assert "gen_ai.usage.output_tokens" in chat_span["data"]
+
+
+@pytest.mark.asyncio
+async def test_agent_run_async_model_error(sentry_init, capture_events):
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ def failing_model(messages, info):
+ raise RuntimeError("model exploded")
+
+ agent = Agent(
+ FunctionModel(failing_model),
+ name="test_agent",
+ )
+
+ with pytest.raises(RuntimeError, match="model exploded"):
+ await agent.run("Test input")
+
+ (error, transaction) = events
+ assert error["level"] == "error"
+
+ spans = transaction["spans"]
+ assert len(spans) == 1
+
+ assert spans[0]["status"] == "internal_error"
+
+
+@pytest.mark.asyncio
+async def test_agent_run_async_usage_data(sentry_init, capture_events, get_test_agent):
+ """
+ Test that the invoke_agent span includes token usage and model data.
+ """
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+
+ test_agent = get_test_agent()
+ result = await test_agent.run("Test input")
+
+ assert result is not None
+ assert result.output is not None
+
+ (transaction,) = events
+
+ # Verify transaction (the transaction IS the invoke_agent span)
+ assert transaction["transaction"] == "invoke_agent test_agent"
+
+ # The invoke_agent span should have token usage data
+ trace_data = transaction["contexts"]["trace"].get("data", {})
+ assert "gen_ai.usage.input_tokens" in trace_data, (
+ "Missing input_tokens on invoke_agent span"
+ )
+ assert "gen_ai.usage.output_tokens" in trace_data, (
+ "Missing output_tokens on invoke_agent span"
+ )
+ assert "gen_ai.usage.total_tokens" in trace_data, (
+ "Missing total_tokens on invoke_agent span"
+ )
+ assert "gen_ai.response.model" in trace_data, (
+ "Missing response.model on invoke_agent span"
+ )
+
+ # Verify the values are reasonable
+ assert trace_data["gen_ai.usage.input_tokens"] > 0
+ assert trace_data["gen_ai.usage.output_tokens"] > 0
+ assert trace_data["gen_ai.usage.total_tokens"] > 0
+ assert trace_data["gen_ai.response.model"] == "test" # Test model name
+
+
+def test_agent_run_sync(sentry_init, capture_events, get_test_agent):
+ """
+ Test that the integration creates spans for sync agent runs.
+ """
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+
+ test_agent = get_test_agent()
+ result = test_agent.run_sync("Test input")
+
+ assert result is not None
+ assert result.output is not None
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ # Verify transaction
+ assert transaction["transaction"] == "invoke_agent test_agent"
+ assert transaction["contexts"]["trace"]["origin"] == "auto.ai.pydantic_ai"
+
+ # Find span types
+ chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
+ assert len(chat_spans) >= 1
+
+ # Verify streaming flag is False for sync
+ for chat_span in chat_spans:
+ assert chat_span["data"]["gen_ai.response.streaming"] is False
+
+
+def test_agent_run_sync_model_error(sentry_init, capture_events):
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ def failing_model(messages, info):
+ raise RuntimeError("model exploded")
+
+ agent = Agent(
+ FunctionModel(failing_model),
+ name="test_agent",
+ )
+
+ with pytest.raises(RuntimeError, match="model exploded"):
+ agent.run_sync("Test input")
+
+ (error, transaction) = events
+ assert error["level"] == "error"
+
+ spans = transaction["spans"]
+ assert len(spans) == 1
+
+ assert spans[0]["status"] == "internal_error"
+
+
+@pytest.mark.asyncio
+async def test_agent_run_stream(sentry_init, capture_events, get_test_agent):
+ """
+ Test that the integration creates spans for streaming agent runs.
+ """
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+
+ test_agent = get_test_agent()
+ async with test_agent.run_stream("Test input") as result:
+ # Consume the stream
+ async for _ in result.stream_output():
+ pass
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ # Verify transaction
+ assert transaction["transaction"] == "invoke_agent test_agent"
+ assert transaction["contexts"]["trace"]["origin"] == "auto.ai.pydantic_ai"
+
+ # Find chat spans
+ chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
+ assert len(chat_spans) >= 1
+
+ # Verify streaming flag is True for streaming
+ for chat_span in chat_spans:
+ assert chat_span["data"]["gen_ai.response.streaming"] is True
+ assert "gen_ai.request.messages" in chat_span["data"]
+ assert "gen_ai.usage.input_tokens" in chat_span["data"]
+ # Streaming responses should still have output data
+ assert (
+ "gen_ai.response.text" in chat_span["data"]
+ or "gen_ai.response.model" in chat_span["data"]
+ )
+
+
+@pytest.mark.asyncio
+async def test_agent_run_stream_events(sentry_init, capture_events, get_test_agent):
+ """
+ Test that run_stream_events creates spans (it uses run internally, so non-streaming).
+ """
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+
+ # Consume all events
+ test_agent = get_test_agent()
+ async for _ in test_agent.run_stream_events("Test input"):
+ pass
+
+ (transaction,) = events
+
+ # Verify transaction
+ assert transaction["transaction"] == "invoke_agent test_agent"
+
+ # Find chat spans
+ spans = transaction["spans"]
+ chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
+ assert len(chat_spans) >= 1
+
+ # run_stream_events uses run() internally, so streaming should be False
+ for chat_span in chat_spans:
+ assert chat_span["data"]["gen_ai.response.streaming"] is False
+
+
+@pytest.mark.asyncio
+async def test_agent_with_tools(sentry_init, capture_events, get_test_agent):
+ """
+ Test that tool execution creates execute_tool spans.
+ """
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ test_agent = get_test_agent()
+
+ @test_agent.tool_plain
+ def add_numbers(a: int, b: int) -> int:
+ """Add two numbers together."""
+ return a + b
+
+ events = capture_events()
+
+ result = await test_agent.run("What is 5 + 3?")
+
+ assert result is not None
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ # Find child span types (invoke_agent is the transaction, not a child span)
+ chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
+ tool_spans = [s for s in spans if s["op"] == "gen_ai.execute_tool"]
+
+ # Should have tool spans
+ assert len(tool_spans) >= 1
+
+ # Check tool span
+ tool_span = tool_spans[0]
+ assert "execute_tool" in tool_span["description"]
+ assert tool_span["data"]["gen_ai.operation.name"] == "execute_tool"
+ assert tool_span["data"]["gen_ai.tool.name"] == "add_numbers"
+ assert "gen_ai.tool.input" in tool_span["data"]
+ assert "gen_ai.tool.output" in tool_span["data"]
+
+ # Check chat spans have available_tools
+ for chat_span in chat_spans:
+ assert "gen_ai.request.available_tools" in chat_span["data"]
+ available_tools_str = chat_span["data"]["gen_ai.request.available_tools"]
+ # Available tools is serialized as a string
+ assert "add_numbers" in available_tools_str
+
+
+@pytest.mark.parametrize(
+ "handled_tool_call_exceptions",
+ [False, True],
+)
+@pytest.mark.asyncio
+async def test_agent_with_tool_model_retry(
+ sentry_init, capture_events, get_test_agent, handled_tool_call_exceptions
+):
+ """
+ Test that a handled exception is captured when a tool raises ModelRetry.
+ """
+ sentry_init(
+ integrations=[
+ PydanticAIIntegration(
+ handled_tool_call_exceptions=handled_tool_call_exceptions
+ )
+ ],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ retries = 0
+
+ test_agent = get_test_agent()
+
+ @test_agent.tool_plain
+ def add_numbers(a: int, b: int) -> float:
+ """Add two numbers together, but raises an exception on the first attempt."""
+ nonlocal retries
+ if retries == 0:
+ retries += 1
+ raise ModelRetry(message="Try again with the same arguments.")
+ return a + b
+
+ events = capture_events()
+
+ result = await test_agent.run("What is 5 + 3?")
+
+ assert result is not None
+
+ if handled_tool_call_exceptions:
+ (error, transaction) = events
+ else:
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ if handled_tool_call_exceptions:
+ assert error["level"] == "error"
+ assert error["exception"]["values"][0]["mechanism"]["handled"]
+
+ # Find child span types (invoke_agent is the transaction, not a child span)
+ chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
+ tool_spans = [s for s in spans if s["op"] == "gen_ai.execute_tool"]
+
+ # Should have tool spans
+ assert len(tool_spans) >= 1
+
+ # Check tool spans
+ model_retry_tool_span = tool_spans[0]
+ assert "execute_tool" in model_retry_tool_span["description"]
+ assert model_retry_tool_span["data"]["gen_ai.operation.name"] == "execute_tool"
+ assert model_retry_tool_span["data"]["gen_ai.tool.name"] == "add_numbers"
+ assert "gen_ai.tool.input" in model_retry_tool_span["data"]
+
+ tool_span = tool_spans[1]
+ assert "execute_tool" in tool_span["description"]
+ assert tool_span["data"]["gen_ai.operation.name"] == "execute_tool"
+ assert tool_span["data"]["gen_ai.tool.name"] == "add_numbers"
+ assert "gen_ai.tool.input" in tool_span["data"]
+ assert "gen_ai.tool.output" in tool_span["data"]
+
+ # Check chat spans have available_tools
+ for chat_span in chat_spans:
+ assert "gen_ai.request.available_tools" in chat_span["data"]
+ available_tools_str = chat_span["data"]["gen_ai.request.available_tools"]
+ # Available tools is serialized as a string
+ assert "add_numbers" in available_tools_str
+
+
+@pytest.mark.parametrize(
+ "handled_tool_call_exceptions",
+ [False, True],
+)
+@pytest.mark.asyncio
+async def test_agent_with_tool_validation_error(
+ sentry_init, capture_events, get_test_agent, handled_tool_call_exceptions
+):
+ """
+ Test that a handled exception is captured when a tool has unsatisfiable constraints.
+ """
+ sentry_init(
+ integrations=[
+ PydanticAIIntegration(
+ handled_tool_call_exceptions=handled_tool_call_exceptions
+ )
+ ],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ test_agent = get_test_agent()
+
+ @test_agent.tool_plain
+ def add_numbers(a: Annotated[int, Field(gt=0, lt=0)], b: int) -> int:
+ """Add two numbers together."""
+ return a + b
+
+ events = capture_events()
+
+ result = None
+ with pytest.raises(UnexpectedModelBehavior):
+ result = await test_agent.run("What is 5 + 3?")
+
+ assert result is None
+
+ if handled_tool_call_exceptions:
+ (error, model_behaviour_error, transaction) = events
+ else:
+ (
+ model_behaviour_error,
+ transaction,
+ ) = events
+ spans = transaction["spans"]
+
+ if handled_tool_call_exceptions:
+ assert error["level"] == "error"
+ assert error["exception"]["values"][0]["mechanism"]["handled"]
+
+ # Find child span types (invoke_agent is the transaction, not a child span)
+ chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
+ tool_spans = [s for s in spans if s["op"] == "gen_ai.execute_tool"]
+
+ # Should have tool spans
+ assert len(tool_spans) >= 1
+
+ # Check tool spans
+ model_retry_tool_span = tool_spans[0]
+ assert "execute_tool" in model_retry_tool_span["description"]
+ assert model_retry_tool_span["data"]["gen_ai.operation.name"] == "execute_tool"
+ assert model_retry_tool_span["data"]["gen_ai.tool.name"] == "add_numbers"
+ assert "gen_ai.tool.input" in model_retry_tool_span["data"]
+
+ # Check chat spans have available_tools
+ for chat_span in chat_spans:
+ assert "gen_ai.request.available_tools" in chat_span["data"]
+ available_tools_str = chat_span["data"]["gen_ai.request.available_tools"]
+ # Available tools is serialized as a string
+ assert "add_numbers" in available_tools_str
+
+
+@pytest.mark.asyncio
+async def test_agent_with_tools_streaming(sentry_init, capture_events, get_test_agent):
+ """
+ Test that tool execution works correctly with streaming.
+ """
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ test_agent = get_test_agent()
+
+ @test_agent.tool_plain
+ def multiply(a: int, b: int) -> int:
+ """Multiply two numbers."""
+ return a * b
+
+ events = capture_events()
+
+ async with test_agent.run_stream("What is 7 times 8?") as result:
+ async for _ in result.stream_output():
+ pass
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ # Find span types
+ chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
+ tool_spans = [s for s in spans if s["op"] == "gen_ai.execute_tool"]
+
+ # Should have tool spans
+ assert len(tool_spans) >= 1
+
+ # Verify streaming flag is True
+ for chat_span in chat_spans:
+ assert chat_span["data"]["gen_ai.response.streaming"] is True
+
+ # Check tool span
+ tool_span = tool_spans[0]
+ assert tool_span["data"]["gen_ai.tool.name"] == "multiply"
+ assert "gen_ai.tool.input" in tool_span["data"]
+ assert "gen_ai.tool.output" in tool_span["data"]
+
+
+@pytest.mark.asyncio
+async def test_model_settings(
+ sentry_init, capture_events, get_test_agent_with_settings
+):
+ """
+ Test that model settings are captured in spans.
+ """
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ test_agent_with_settings = get_test_agent_with_settings()
+ await test_agent_with_settings.run("Test input")
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ # Find chat span
+ chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
+ assert len(chat_spans) >= 1
+
+ chat_span = chat_spans[0]
+ # Check that model settings are captured
+ assert chat_span["data"].get("gen_ai.request.temperature") == 0.7
+ assert chat_span["data"].get("gen_ai.request.max_tokens") == 100
+ assert chat_span["data"].get("gen_ai.request.top_p") == 0.9
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+async def test_system_prompt_attribute(
+ sentry_init, capture_events, send_default_pii, include_prompts
+):
+ """
+ Test that system prompts are included as the first message.
+ """
+ agent = Agent(
+ "test",
+ name="test_system",
+ system_prompt="You are a helpful assistant specialized in testing.",
+ )
+
+ sentry_init(
+ integrations=[PydanticAIIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+
+ events = capture_events()
+
+ await agent.run("Hello")
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ # The transaction IS the invoke_agent span, check for messages in chat spans instead
+ chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
+ assert len(chat_spans) >= 1
+
+ chat_span = chat_spans[0]
+
+ if send_default_pii and include_prompts:
+ system_instructions = chat_span["data"][SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS]
+ assert json.loads(system_instructions) == [
+ {
+ "type": "text",
+ "content": "You are a helpful assistant specialized in testing.",
+ }
+ ]
+ else:
+ assert SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS not in chat_span["data"]
+
+
+@pytest.mark.asyncio
+async def test_error_handling(sentry_init, capture_events):
+ """
+ Test error handling in agent execution.
+ """
+ # Use a simpler test that doesn't cause tool failures
+ # as pydantic-ai has complex error handling for tool errors
+ agent = Agent(
+ "test",
+ name="test_error",
+ )
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ # Simple run that should succeed
+ await agent.run("Hello")
+
+ # At minimum, we should have a transaction
+ assert len(events) >= 1
+ transaction = [e for e in events if e.get("type") == "transaction"][0]
+ assert transaction["transaction"] == "invoke_agent test_error"
+ # Transaction should complete successfully (status key may not exist if no error)
+ trace_status = transaction["contexts"]["trace"].get("status")
+ assert trace_status != "error" # Could be None or some other status
+
+
+@pytest.mark.asyncio
+async def test_without_pii(sentry_init, capture_events, get_test_agent):
+ """
+ Test that PII is not captured when send_default_pii is False.
+ """
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=False,
+ )
+
+ events = capture_events()
+
+ test_agent = get_test_agent()
+ await test_agent.run("Sensitive input")
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ # Find child spans (invoke_agent is the transaction, not a child span)
+ chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
+
+ # Verify that messages and response text are not captured
+ for span in chat_spans:
+ assert "gen_ai.request.messages" not in span["data"]
+ assert "gen_ai.response.text" not in span["data"]
+
+
+@pytest.mark.asyncio
+async def test_without_pii_tools(sentry_init, capture_events, get_test_agent):
+ """
+ Test that tool input/output are not captured when send_default_pii is False.
+ """
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=False,
+ )
+
+ test_agent = get_test_agent()
+
+ @test_agent.tool_plain
+ def sensitive_tool(data: str) -> str:
+ """A tool with sensitive data."""
+ return f"Processed: {data}"
+
+ events = capture_events()
+
+ await test_agent.run("Use sensitive tool with private data")
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ # Find tool spans
+ tool_spans = [s for s in spans if s["op"] == "gen_ai.execute_tool"]
+
+ # If tool was executed, verify input/output are not captured
+ for tool_span in tool_spans:
+ assert "gen_ai.tool.input" not in tool_span["data"]
+ assert "gen_ai.tool.output" not in tool_span["data"]
+
+
+@pytest.mark.asyncio
+async def test_multiple_agents_concurrent(sentry_init, capture_events, get_test_agent):
+ """
+ Test that multiple agents can run concurrently without interfering.
+ """
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ test_agent = get_test_agent()
+
+ async def run_agent(input_text):
+ return await test_agent.run(input_text)
+
+ # Run 3 agents concurrently
+ results = await asyncio.gather(*[run_agent(f"Input {i}") for i in range(3)])
+
+ assert len(results) == 3
+ assert len(events) == 3
+
+ # Verify each transaction is separate
+ for i, transaction in enumerate(events):
+ assert transaction["type"] == "transaction"
+ assert transaction["transaction"] == "invoke_agent test_agent"
+ # Each should have its own spans
+ assert len(transaction["spans"]) >= 1
+
+
+@pytest.mark.asyncio
+async def test_message_history(sentry_init, capture_events):
+ """
+ Test that full conversation history is captured in chat spans.
+ """
+ agent = Agent(
+ "test",
+ name="test_history",
+ )
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+
+ # First message
+ await agent.run("Hello, I'm Alice")
+
+ # Second message with history
+ from pydantic_ai import messages
+
+ history = [
+ messages.ModelRequest(
+ parts=[messages.UserPromptPart(content="Hello, I'm Alice")]
+ ),
+ messages.ModelResponse(
+ parts=[messages.TextPart(content="Hello Alice! How can I help you?")],
+ model_name="test",
+ ),
+ ]
+
+ await agent.run("What is my name?", message_history=history)
+
+ # We should have 2 transactions
+ assert len(events) >= 2
+
+ # Check the second transaction has the full history
+ second_transaction = events[1]
+ spans = second_transaction["spans"]
+ chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
+
+ if chat_spans:
+ chat_span = chat_spans[0]
+ if "gen_ai.request.messages" in chat_span["data"]:
+ messages_data = chat_span["data"]["gen_ai.request.messages"]
+ # Should have multiple messages including history
+ assert len(messages_data) > 1
+
+
+@pytest.mark.asyncio
+async def test_gen_ai_system(sentry_init, capture_events, get_test_agent):
+ """
+ Test that gen_ai.system is set from the model.
+ """
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ test_agent = get_test_agent()
+ await test_agent.run("Test input")
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ # Find chat span
+ chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
+ assert len(chat_spans) >= 1
+
+ chat_span = chat_spans[0]
+ # gen_ai.system should be set from the model (TestModel -> 'test')
+ assert "gen_ai.system" in chat_span["data"]
+ assert chat_span["data"]["gen_ai.system"] == "test"
+
+
+@pytest.mark.asyncio
+async def test_include_prompts_false(sentry_init, capture_events, get_test_agent):
+ """
+ Test that prompts are not captured when include_prompts=False.
+ """
+ sentry_init(
+ integrations=[PydanticAIIntegration(include_prompts=False)],
+ traces_sample_rate=1.0,
+ send_default_pii=True, # Even with PII enabled, prompts should not be captured
+ )
+
+ events = capture_events()
+
+ test_agent = get_test_agent()
+ await test_agent.run("Sensitive prompt")
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ # Find child spans (invoke_agent is the transaction, not a child span)
+ chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
+
+ # Verify that messages and response text are not captured
+ for span in chat_spans:
+ assert "gen_ai.request.messages" not in span["data"]
+ assert "gen_ai.response.text" not in span["data"]
+
+
+@pytest.mark.asyncio
+async def test_include_prompts_true(sentry_init, capture_events, get_test_agent):
+ """
+ Test that prompts are captured when include_prompts=True (default).
+ """
+ sentry_init(
+ integrations=[PydanticAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+
+ test_agent = get_test_agent()
+ await test_agent.run("Test prompt")
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ # Find child spans (invoke_agent is the transaction, not a child span)
+ chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
+
+ # Verify that messages are captured in chat spans
+ assert len(chat_spans) >= 1
+ for chat_span in chat_spans:
+ assert "gen_ai.request.messages" in chat_span["data"]
+
+
+@pytest.mark.asyncio
+async def test_include_prompts_false_with_tools(
+ sentry_init, capture_events, get_test_agent
+):
+ """
+ Test that tool input/output are not captured when include_prompts=False.
+ """
+ sentry_init(
+ integrations=[PydanticAIIntegration(include_prompts=False)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ test_agent = get_test_agent()
+
+ @test_agent.tool_plain
+ def test_tool(value: int) -> int:
+ """A test tool."""
+ return value * 2
+
+ events = capture_events()
+
+ await test_agent.run("Use the test tool with value 5")
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ # Find tool spans
+ tool_spans = [s for s in spans if s["op"] == "gen_ai.execute_tool"]
+
+ # If tool was executed, verify input/output are not captured
+ for tool_span in tool_spans:
+ assert "gen_ai.tool.input" not in tool_span["data"]
+ assert "gen_ai.tool.output" not in tool_span["data"]
+
+
+@pytest.mark.asyncio
+async def test_include_prompts_requires_pii(
+ sentry_init, capture_events, get_test_agent
+):
+ """
+ Test that include_prompts requires send_default_pii=True.
+ """
+ sentry_init(
+ integrations=[PydanticAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=False, # PII disabled
+ )
+
+ events = capture_events()
+
+ test_agent = get_test_agent()
+ await test_agent.run("Test prompt")
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ # Find child spans (invoke_agent is the transaction, not a child span)
+ chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
+
+ # Even with include_prompts=True, if PII is disabled, messages should not be captured
+ for span in chat_spans:
+ assert "gen_ai.request.messages" not in span["data"]
+ assert "gen_ai.response.text" not in span["data"]
+
+
+@pytest.mark.asyncio
+async def test_mcp_tool_execution_spans(sentry_init, capture_events):
+ """
+ Test that MCP (Model Context Protocol) tool calls create execute_tool spans.
+
+ Tests MCP tools accessed through CombinedToolset, which is how they're typically
+ used in practice (when an agent combines regular functions with MCP servers).
+ """
+ pytest.importorskip("mcp")
+
+ from unittest.mock import MagicMock
+ from pydantic_ai.mcp import MCPServerStdio
+ from pydantic_ai import Agent
+ from pydantic_ai.toolsets.combined import CombinedToolset
+ import sentry_sdk
+
+ # Create mock MCP server
+ mock_server = MCPServerStdio(
+ command="python",
+ args=["-m", "test_server"],
+ )
+
+ # Mock the server's internal methods
+ mock_server._client = MagicMock()
+ mock_server._is_initialized = True
+ mock_server._server_info = MagicMock()
+
+ # Mock tool call response
+ async def mock_send_request(request, response_type):
+ from mcp.types import CallToolResult, TextContent
+
+ return CallToolResult(
+ content=[TextContent(type="text", text="MCP tool executed successfully")],
+ isError=False,
+ )
+
+ mock_server._client.send_request = mock_send_request
+
+ # Mock context manager methods
+ async def mock_aenter():
+ return mock_server
+
+ async def mock_aexit(*args):
+ pass
+
+ mock_server.__aenter__ = mock_aenter
+ mock_server.__aexit__ = mock_aexit
+
+ # Mock _map_tool_result_part
+ async def mock_map_tool_result_part(part):
+ return part.text if hasattr(part, "text") else str(part)
+
+ mock_server._map_tool_result_part = mock_map_tool_result_part
+
+ # Create a CombinedToolset with the MCP server
+ # This simulates how MCP servers are typically used in practice
+ from pydantic_ai.toolsets.function import FunctionToolset
+
+ function_toolset = FunctionToolset()
+ combined = CombinedToolset([function_toolset, mock_server])
+
+ # Create agent
+ agent = Agent(
+ "test",
+ name="test_mcp_agent",
+ )
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+
+ # Simulate MCP tool execution within a transaction through CombinedToolset
+ with sentry_sdk.start_transaction(
+ op="ai.run", name="invoke_agent test_mcp_agent"
+ ) as transaction:
+ # Set up the agent context
+ scope = sentry_sdk.get_current_scope()
+ scope._contexts["pydantic_ai_agent"] = {
+ "_agent": agent,
+ }
+
+ # Create a mock tool that simulates an MCP tool from CombinedToolset
+ from pydantic_ai._run_context import RunContext
+ from pydantic_ai.result import RunUsage
+ from pydantic_ai.models.test import TestModel
+ from pydantic_ai.toolsets.combined import _CombinedToolsetTool
+
+ ctx = RunContext(
+ deps=None,
+ model=TestModel(),
+ usage=RunUsage(),
+ retry=0,
+ tool_name="test_mcp_tool",
+ )
+
+ tool_name = "test_mcp_tool"
+
+ # Create a tool that points to the MCP server
+ # This simulates how CombinedToolset wraps tools from different sources
+ tool = _CombinedToolsetTool(
+ toolset=combined,
+ tool_def=MagicMock(name=tool_name),
+ max_retries=0,
+ args_validator=MagicMock(),
+ source_toolset=mock_server,
+ source_tool=MagicMock(),
+ )
+
+ try:
+ await combined.call_tool(tool_name, {"query": "test"}, ctx, tool)
+ except Exception:
+ # MCP tool might raise if not fully mocked, that's okay
+ pass
+
+ events_list = events
+ if len(events_list) == 0:
+ pytest.skip("No events captured, MCP test setup incomplete")
+
+ (transaction,) = events_list
+ transaction["spans"]
+
+ # Note: This test manually calls combined.call_tool which doesn't go through
+ # ToolManager._call_tool (which is what the integration patches).
+ # In real-world usage, MCP tools are called through agent.run() which uses ToolManager.
+ # This synthetic test setup doesn't trigger the integration's tool patches.
+ # We skip this test as it doesn't represent actual usage patterns.
+ pytest.skip(
+ "MCP test needs to be rewritten to use agent.run() instead of manually calling toolset methods"
+ )
+
+
+@pytest.mark.asyncio
+async def test_context_cleanup_after_run(sentry_init, get_test_agent):
+ """
+ Test that the pydantic_ai_agent context is properly cleaned up after agent execution.
+ """
+ import sentry_sdk
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ # Verify context is not set before run
+ scope = sentry_sdk.get_current_scope()
+ assert "pydantic_ai_agent" not in scope._contexts
+
+ # Run the agent
+ test_agent = get_test_agent()
+ await test_agent.run("Test input")
+
+ # Verify context is cleaned up after run
+ assert "pydantic_ai_agent" not in scope._contexts
+
+
+def test_context_cleanup_after_run_sync(sentry_init, get_test_agent):
+ """
+ Test that the pydantic_ai_agent context is properly cleaned up after sync agent execution.
+ """
+ import sentry_sdk
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ # Verify context is not set before run
+ scope = sentry_sdk.get_current_scope()
+ assert "pydantic_ai_agent" not in scope._contexts
+
+ # Run the agent synchronously
+ test_agent = get_test_agent()
+ test_agent.run_sync("Test input")
+
+ # Verify context is cleaned up after run
+ assert "pydantic_ai_agent" not in scope._contexts
+
+
+@pytest.mark.asyncio
+async def test_context_cleanup_after_streaming(sentry_init, get_test_agent):
+ """
+ Test that the pydantic_ai_agent context is properly cleaned up after streaming execution.
+ """
+ import sentry_sdk
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ # Verify context is not set before run
+ scope = sentry_sdk.get_current_scope()
+ assert "pydantic_ai_agent" not in scope._contexts
+
+ test_agent = get_test_agent()
+ # Run the agent with streaming
+ async with test_agent.run_stream("Test input") as result:
+ async for _ in result.stream_output():
+ pass
+
+ # Verify context is cleaned up after streaming completes
+ assert "pydantic_ai_agent" not in scope._contexts
+
+
+@pytest.mark.asyncio
+async def test_context_cleanup_on_error(sentry_init, get_test_agent):
+ """
+ Test that the pydantic_ai_agent context is cleaned up even when an error occurs.
+ """
+ import sentry_sdk
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ test_agent = get_test_agent()
+
+ # Create an agent with a tool that raises an error
+ @test_agent.tool_plain
+ def failing_tool() -> str:
+ """A tool that always fails."""
+ raise ValueError("Tool error")
+
+ # Verify context is not set before run
+ scope = sentry_sdk.get_current_scope()
+ assert "pydantic_ai_agent" not in scope._contexts
+
+ # Run the agent - this may or may not raise depending on pydantic-ai's error handling
+ try:
+ await test_agent.run("Use the failing tool")
+ except Exception:
+ pass
+
+ # Verify context is cleaned up even if there was an error
+ assert "pydantic_ai_agent" not in scope._contexts
+
+
+@pytest.mark.asyncio
+async def test_context_isolation_concurrent_agents(sentry_init, get_test_agent):
+ """
+ Test that concurrent agent executions maintain isolated contexts.
+ """
+ import sentry_sdk
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ # Create a second agent
+ agent2 = Agent(
+ "test",
+ name="test_agent_2",
+ system_prompt="Second test agent.",
+ )
+
+ async def run_and_check_context(agent, agent_name):
+ """Run an agent and verify its context during and after execution."""
+ # Before execution, context should not exist in the outer scope
+ outer_scope = sentry_sdk.get_current_scope()
+
+ # Run the agent
+ await agent.run(f"Input for {agent_name}")
+
+ # After execution, verify context is cleaned up
+ # Note: Due to isolation_scope, we can't easily check the inner scope here,
+ # but we can verify the outer scope remains clean
+ assert "pydantic_ai_agent" not in outer_scope._contexts
+
+ return agent_name
+
+ test_agent = get_test_agent()
+ # Run both agents concurrently
+ results = await asyncio.gather(
+ run_and_check_context(test_agent, "agent1"),
+ run_and_check_context(agent2, "agent2"),
+ )
+
+ assert results == ["agent1", "agent2"]
+
+ # Final check: outer scope should be clean
+ final_scope = sentry_sdk.get_current_scope()
+ assert "pydantic_ai_agent" not in final_scope._contexts
+
+
+# ==================== Additional Coverage Tests ====================
+
+
+@pytest.mark.asyncio
+async def test_invoke_agent_with_list_user_prompt(sentry_init, capture_events):
+ """
+ Test that invoke_agent span handles list user prompts correctly.
+ """
+ agent = Agent(
+ "test",
+ name="test_list_prompt",
+ )
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+
+ # Use a list as user prompt
+ await agent.run(["First part", "Second part"])
+
+ (transaction,) = events
+
+ # Check that the invoke_agent transaction has messages data
+ # The invoke_agent is the transaction itself
+ if "gen_ai.request.messages" in transaction["contexts"]["trace"]["data"]:
+ messages_str = transaction["contexts"]["trace"]["data"][
+ "gen_ai.request.messages"
+ ]
+ assert "First part" in messages_str
+ assert "Second part" in messages_str
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "send_default_pii, include_prompts",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+async def test_invoke_agent_with_instructions(
+ sentry_init, capture_events, send_default_pii, include_prompts
+):
+ """
+ Test that invoke_agent span handles instructions correctly.
+ """
+ from pydantic_ai import Agent
+
+ # Create agent with instructions (can be string or list)
+ agent = Agent(
+ "test",
+ name="test_instructions",
+ )
+
+ # Add instructions via _instructions attribute (internal API)
+ agent._instructions = ["Instruction 1", "Instruction 2"]
+ agent._system_prompts = ["System prompt"]
+
+ sentry_init(
+ integrations=[PydanticAIIntegration(include_prompts=include_prompts)],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+
+ events = capture_events()
+
+ await agent.run("Test input")
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ # The transaction IS the invoke_agent span, check for messages in chat spans instead
+ chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
+ assert len(chat_spans) >= 1
+
+ chat_span = chat_spans[0]
+
+ if send_default_pii and include_prompts:
+ system_instructions = chat_span["data"][SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS]
+ assert json.loads(system_instructions) == [
+ {"type": "text", "content": "System prompt"},
+ {"type": "text", "content": "Instruction 1\nInstruction 2"},
+ ]
+ else:
+ assert SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS not in chat_span["data"]
+
+
+@pytest.mark.asyncio
+async def test_model_name_extraction_with_callable(sentry_init, capture_events):
+ """
+ Test model name extraction when model has a callable name() method.
+ """
+ from unittest.mock import MagicMock
+ from sentry_sdk.integrations.pydantic_ai.utils import _get_model_name
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ # Test the utility function directly
+ mock_model = MagicMock()
+ # Remove model_name attribute so it checks name() next
+ del mock_model.model_name
+ mock_model.name = lambda: "custom-model-name"
+
+ # Get model name - should call the callable name()
+ result = _get_model_name(mock_model)
+
+ # Should return the result from callable
+ assert result == "custom-model-name"
+
+
+@pytest.mark.asyncio
+async def test_model_name_extraction_fallback_to_str(sentry_init, capture_events):
+ """
+ Test model name extraction falls back to str() when no name attribute exists.
+ """
+ from unittest.mock import MagicMock
+ from sentry_sdk.integrations.pydantic_ai.utils import _get_model_name
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ # Test the utility function directly
+ mock_model = MagicMock()
+ # Remove name and model_name attributes
+ del mock_model.name
+ del mock_model.model_name
+
+ # Get model name - should fall back to str()
+ result = _get_model_name(mock_model)
+
+ # Should return string representation
+ assert result is not None
+ assert isinstance(result, str)
+
+
+@pytest.mark.asyncio
+async def test_model_settings_object_style(sentry_init, capture_events):
+ """
+ Test that object-style model settings (non-dict) are handled correctly.
+ """
+ import sentry_sdk
+ from unittest.mock import MagicMock
+ from sentry_sdk.integrations.pydantic_ai.utils import _set_model_data
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ span = sentry_sdk.start_span(op="test_span")
+
+ # Create mock settings object (not a dict)
+ mock_settings = MagicMock()
+ mock_settings.temperature = 0.8
+ mock_settings.max_tokens = 200
+ mock_settings.top_p = 0.95
+ mock_settings.frequency_penalty = 0.5
+ mock_settings.presence_penalty = 0.3
+
+ # Set model data with object-style settings
+ _set_model_data(span, None, mock_settings)
+
+ span.finish()
+
+ # Should not crash and should set the settings
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_usage_data_partial(sentry_init, capture_events):
+ """
+ Test that usage data is correctly handled when only some fields are present.
+ """
+ agent = Agent(
+ "test",
+ name="test_usage",
+ )
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ await agent.run("Test input")
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
+ assert len(chat_spans) >= 1
+
+ # Check that usage data fields exist (they may or may not be set depending on TestModel)
+ chat_span = chat_spans[0]
+ # At minimum, the span should have been created
+ assert chat_span is not None
+
+
+@pytest.mark.asyncio
+async def test_agent_data_from_scope(sentry_init, capture_events):
+ """
+ Test that agent data can be retrieved from Sentry scope when not passed directly.
+ """
+
+ agent = Agent(
+ "test",
+ name="test_scope_agent",
+ )
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ # The integration automatically sets agent in scope during execution
+ await agent.run("Test input")
+
+ (transaction,) = events
+
+ # Verify agent name is captured
+ assert transaction["transaction"] == "invoke_agent test_scope_agent"
+
+
+@pytest.mark.asyncio
+async def test_available_tools_without_description(
+ sentry_init, capture_events, get_test_agent
+):
+ """
+ Test that available tools are captured even when description is missing.
+ """
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ test_agent = get_test_agent()
+
+ @test_agent.tool_plain
+ def tool_without_desc(x: int) -> int:
+ # No docstring = no description
+ return x * 2
+
+ events = capture_events()
+
+ await test_agent.run("Use the tool with 5")
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
+ if chat_spans:
+ chat_span = chat_spans[0]
+ if "gen_ai.request.available_tools" in chat_span["data"]:
+ tools_str = chat_span["data"]["gen_ai.request.available_tools"]
+ assert "tool_without_desc" in tools_str
+
+
+@pytest.mark.asyncio
+async def test_output_with_tool_calls(sentry_init, capture_events, get_test_agent):
+ """
+ Test that tool calls in model response are captured correctly.
+ """
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ test_agent = get_test_agent()
+
+ @test_agent.tool_plain
+ def calc_tool(value: int) -> int:
+ """Calculate something."""
+ return value + 10
+
+ events = capture_events()
+
+ await test_agent.run("Use calc_tool with 5")
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
+
+ # At least one chat span should exist
+ assert len(chat_spans) >= 1
+
+ # Check if tool calls are captured in response
+ for chat_span in chat_spans:
+ # Tool calls may or may not be in response depending on TestModel behavior
+ # Just verify the span was created and has basic data
+ assert "gen_ai.operation.name" in chat_span["data"]
+
+
+@pytest.mark.asyncio
+async def test_message_formatting_with_different_parts(sentry_init, capture_events):
+ """
+ Test that different message part types are handled correctly in ai_client span.
+ """
+ from pydantic_ai import Agent, messages
+
+ agent = Agent(
+ "test",
+ name="test_message_parts",
+ )
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+
+ # Create message history with different part types
+ history = [
+ messages.ModelRequest(parts=[messages.UserPromptPart(content="Hello")]),
+ messages.ModelResponse(
+ parts=[
+ messages.TextPart(content="Hi there!"),
+ ],
+ model_name="test",
+ ),
+ ]
+
+ await agent.run("What did I say?", message_history=history)
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
+
+ # Should have chat spans
+ assert len(chat_spans) >= 1
+
+ # Check that messages are captured
+ chat_span = chat_spans[0]
+ if "gen_ai.request.messages" in chat_span["data"]:
+ messages_data = chat_span["data"]["gen_ai.request.messages"]
+ # Should contain message history
+ assert messages_data is not None
+
+
+@pytest.mark.asyncio
+async def test_update_invoke_agent_span_with_none_output(sentry_init, capture_events):
+ """
+ Test that update_invoke_agent_span handles None output gracefully.
+ """
+ import sentry_sdk
+ from sentry_sdk.integrations.pydantic_ai.spans.invoke_agent import (
+ update_invoke_agent_span,
+ )
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ span = sentry_sdk.start_span(op="test_span")
+
+ # Update with None output - should not raise
+ update_invoke_agent_span(span, None)
+
+ span.finish()
+
+ # Should not crash
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_update_ai_client_span_with_none_response(sentry_init, capture_events):
+ """
+ Test that update_ai_client_span handles None response gracefully.
+ """
+ import sentry_sdk
+ from sentry_sdk.integrations.pydantic_ai.spans.ai_client import (
+ update_ai_client_span,
+ )
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ span = sentry_sdk.start_span(op="test_span")
+
+ # Update with None response - should not raise
+ update_ai_client_span(span, None)
+
+ span.finish()
+
+ # Should not crash
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_agent_without_name(sentry_init, capture_events):
+ """
+ Test that agent without a name is handled correctly.
+ """
+ # Create agent without explicit name
+ agent = Agent("test")
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ events = capture_events()
+
+ await agent.run("Test input")
+
+ (transaction,) = events
+
+ # Should still create transaction, just with default name
+ assert transaction["type"] == "transaction"
+ # Transaction name should be "invoke_agent agent" or similar default
+ assert "invoke_agent" in transaction["transaction"]
+
+
+@pytest.mark.asyncio
+async def test_model_response_without_parts(sentry_init, capture_events):
+ """
+ Test handling of model response without parts attribute.
+ """
+ import sentry_sdk
+ from unittest.mock import MagicMock
+ from sentry_sdk.integrations.pydantic_ai.spans.ai_client import _set_output_data
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ span = sentry_sdk.start_span(op="test_span")
+
+ # Create mock response without parts
+ mock_response = MagicMock()
+ mock_response.model_name = "test-model"
+ del mock_response.parts # Remove parts attribute
+
+ # Should not raise, just skip formatting
+ _set_output_data(span, mock_response)
+
+ span.finish()
+
+ # Should not crash
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_input_messages_error_handling(sentry_init, capture_events):
+ """
+ Test that _set_input_messages handles errors gracefully.
+ """
+ import sentry_sdk
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ span = sentry_sdk.start_span(op="test_span")
+
+ # Pass invalid messages that would cause an error
+ invalid_messages = [object()] # Plain object without expected attributes
+
+ # Should not raise, error is caught internally
+ _set_input_messages(span, invalid_messages)
+
+ span.finish()
+
+ # Should not crash
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_available_tools_error_handling(sentry_init, capture_events):
+ """
+ Test that _set_available_tools handles errors gracefully.
+ """
+ import sentry_sdk
+ from unittest.mock import MagicMock
+ from sentry_sdk.integrations.pydantic_ai.utils import _set_available_tools
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ span = sentry_sdk.start_span(op="test_span")
+
+ # Create mock agent with invalid toolset
+ mock_agent = MagicMock()
+ mock_agent._function_toolset.tools.items.side_effect = Exception("Error")
+
+ # Should not raise, error is caught internally
+ _set_available_tools(span, mock_agent)
+
+ span.finish()
+
+ # Should not crash
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_set_usage_data_with_none_usage(sentry_init, capture_events):
+ """
+ Test that _set_usage_data handles None usage gracefully.
+ """
+ import sentry_sdk
+ from sentry_sdk.integrations.pydantic_ai.spans.ai_client import _set_usage_data
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ span = sentry_sdk.start_span(op="test_span")
+
+ # Pass None usage - should not raise
+ _set_usage_data(span, None)
+
+ span.finish()
+
+ # Should not crash
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_set_usage_data_with_partial_fields(sentry_init, capture_events):
+ """
+ Test that _set_usage_data handles usage with only some fields.
+ """
+ import sentry_sdk
+ from unittest.mock import MagicMock
+ from sentry_sdk.integrations.pydantic_ai.spans.ai_client import _set_usage_data
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ span = sentry_sdk.start_span(op="test_span")
+
+ # Create usage object with only some fields
+ mock_usage = MagicMock()
+ mock_usage.input_tokens = 100
+ mock_usage.output_tokens = None # Missing
+ mock_usage.total_tokens = 100
+
+ # Should only set the non-None fields
+ _set_usage_data(span, mock_usage)
+
+ span.finish()
+
+ # Should not crash
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_message_parts_with_tool_return(sentry_init, capture_events):
+ """
+ Test that ToolReturnPart messages are handled correctly.
+ """
+ from pydantic_ai import Agent
+
+ agent = Agent(
+ "test",
+ name="test_tool_return",
+ )
+
+ @agent.tool_plain
+ def test_tool(x: int) -> int:
+ """Test tool."""
+ return x * 2
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+
+ # Run with history containing tool return
+ await agent.run("Use test_tool with 5")
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
+
+ # Should have chat spans
+ assert len(chat_spans) >= 1
+
+
+@pytest.mark.asyncio
+async def test_message_parts_with_list_content(sentry_init, capture_events):
+ """
+ Test that message parts with list content are handled correctly.
+ """
+ import sentry_sdk
+ from unittest.mock import MagicMock
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ span = sentry_sdk.start_span(op="test_span")
+
+ # Create message with list content
+ mock_msg = MagicMock()
+ mock_part = MagicMock()
+ mock_part.content = ["item1", "item2", {"complex": "item"}]
+ mock_msg.parts = [mock_part]
+ mock_msg.instructions = None
+
+ messages = [mock_msg]
+
+ # Should handle list content
+ _set_input_messages(span, messages)
+
+ span.finish()
+
+ # Should not crash
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_output_data_with_text_and_tool_calls(sentry_init, capture_events):
+ """
+ Test that _set_output_data handles both text and tool calls in response.
+ """
+ import sentry_sdk
+ from unittest.mock import MagicMock
+ from sentry_sdk.integrations.pydantic_ai.spans.ai_client import _set_output_data
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ span = sentry_sdk.start_span(op="test_span")
+
+ # Create mock response with both TextPart and ToolCallPart
+ from pydantic_ai import messages
+
+ text_part = messages.TextPart(content="Here's the result")
+ tool_call_part = MagicMock()
+ tool_call_part.tool_name = "test_tool"
+ tool_call_part.args = {"x": 5}
+
+ mock_response = MagicMock()
+ mock_response.model_name = "test-model"
+ mock_response.parts = [text_part, tool_call_part]
+
+ # Should handle both text and tool calls
+ _set_output_data(span, mock_response)
+
+ span.finish()
+
+ # Should not crash
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_output_data_error_handling(sentry_init, capture_events):
+ """
+ Test that _set_output_data handles errors in formatting gracefully.
+ """
+ import sentry_sdk
+ from unittest.mock import MagicMock
+ from sentry_sdk.integrations.pydantic_ai.spans.ai_client import _set_output_data
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ span = sentry_sdk.start_span(op="test_span")
+
+ # Create mock response that will cause error
+ mock_response = MagicMock()
+ mock_response.model_name = "test-model"
+ mock_response.parts = [MagicMock(side_effect=Exception("Error"))]
+
+ # Should catch error and not crash
+ _set_output_data(span, mock_response)
+
+ span.finish()
+
+ # Should not crash
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_message_with_system_prompt_part(sentry_init, capture_events):
+ """
+ Test that SystemPromptPart is handled with correct role.
+ """
+ import sentry_sdk
+ from unittest.mock import MagicMock
+ from pydantic_ai import messages
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ span = sentry_sdk.start_span(op="test_span")
+
+ # Create message with SystemPromptPart
+ system_part = messages.SystemPromptPart(content="You are a helpful assistant")
+
+ mock_msg = MagicMock()
+ mock_msg.parts = [system_part]
+ mock_msg.instructions = None
+
+ msgs = [mock_msg]
+
+ # Should handle system prompt
+ _set_input_messages(span, msgs)
+
+ span.finish()
+
+ # Should not crash
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_message_with_instructions(sentry_init, capture_events):
+ """
+ Test that messages with instructions field are handled correctly.
+ """
+ import sentry_sdk
+ from unittest.mock import MagicMock
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ span = sentry_sdk.start_span(op="test_span")
+
+ # Create message with instructions
+ mock_msg = MagicMock()
+ mock_msg.instructions = "System instructions here"
+ mock_part = MagicMock()
+ mock_part.content = "User message"
+ mock_msg.parts = [mock_part]
+
+ msgs = [mock_msg]
+
+ # Should extract system prompt from instructions
+ _set_input_messages(span, msgs)
+
+ span.finish()
+
+ # Should not crash
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_set_input_messages_without_prompts(sentry_init, capture_events):
+ """
+ Test that _set_input_messages respects _should_send_prompts().
+ """
+ import sentry_sdk
+
+ sentry_init(
+ integrations=[PydanticAIIntegration(include_prompts=False)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ span = sentry_sdk.start_span(op="test_span")
+
+ # Even with messages, should not set them
+ messages = ["test"]
+ _set_input_messages(span, messages)
+
+ span.finish()
+
+ # Should not crash and should not set messages
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_set_output_data_without_prompts(sentry_init, capture_events):
+ """
+ Test that _set_output_data respects _should_send_prompts().
+ """
+ import sentry_sdk
+ from unittest.mock import MagicMock
+ from sentry_sdk.integrations.pydantic_ai.spans.ai_client import _set_output_data
+
+ sentry_init(
+ integrations=[PydanticAIIntegration(include_prompts=False)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ span = sentry_sdk.start_span(op="test_span")
+
+ # Even with response, should not set output data
+ mock_response = MagicMock()
+ mock_response.model_name = "test"
+ _set_output_data(span, mock_response)
+
+ span.finish()
+
+ # Should not crash and should not set output
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_get_model_name_with_exception_in_callable(sentry_init, capture_events):
+ """
+ Test that _get_model_name handles exceptions in name() callable.
+ """
+ from unittest.mock import MagicMock
+ from sentry_sdk.integrations.pydantic_ai.utils import _get_model_name
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ # Create model with callable name that raises exception
+ mock_model = MagicMock()
+ mock_model.name = MagicMock(side_effect=Exception("Error"))
+
+ # Should fall back to str()
+ result = _get_model_name(mock_model)
+
+ # Should return something (str fallback)
+ assert result is not None
+
+
+@pytest.mark.asyncio
+async def test_get_model_name_with_string_model(sentry_init, capture_events):
+ """
+ Test that _get_model_name handles string models.
+ """
+ from sentry_sdk.integrations.pydantic_ai.utils import _get_model_name
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ # Pass a string as model
+ result = _get_model_name("gpt-4")
+
+ # Should return the string
+ assert result == "gpt-4"
+
+
+@pytest.mark.asyncio
+async def test_get_model_name_with_none(sentry_init, capture_events):
+ """
+ Test that _get_model_name handles None model.
+ """
+ from sentry_sdk.integrations.pydantic_ai.utils import _get_model_name
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ # Pass None
+ result = _get_model_name(None)
+
+ # Should return None
+ assert result is None
+
+
+@pytest.mark.asyncio
+async def test_set_model_data_with_system(sentry_init, capture_events):
+ """
+ Test that _set_model_data captures system from model.
+ """
+ import sentry_sdk
+ from unittest.mock import MagicMock
+ from sentry_sdk.integrations.pydantic_ai.utils import _set_model_data
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ span = sentry_sdk.start_span(op="test_span")
+
+ # Create model with system
+ mock_model = MagicMock()
+ mock_model.system = "openai"
+ mock_model.model_name = "gpt-4"
+
+ # Set model data
+ _set_model_data(span, mock_model, None)
+
+ span.finish()
+
+ # Should not crash
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_set_model_data_from_agent_scope(sentry_init, capture_events):
+ """
+ Test that _set_model_data retrieves model from agent in scope when not passed.
+ """
+ import sentry_sdk
+ from unittest.mock import MagicMock
+ from sentry_sdk.integrations.pydantic_ai.utils import _set_model_data
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ # Set agent in scope
+ scope = sentry_sdk.get_current_scope()
+ mock_agent = MagicMock()
+ mock_agent.model = MagicMock()
+ mock_agent.model.model_name = "test-model"
+ mock_agent.model_settings = {"temperature": 0.5}
+ scope._contexts["pydantic_ai_agent"] = {"_agent": mock_agent}
+
+ span = sentry_sdk.start_span(op="test_span")
+
+ # Pass None for model, should get from scope
+ _set_model_data(span, None, None)
+
+ span.finish()
+
+ # Should not crash
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_set_model_data_with_none_settings_values(sentry_init, capture_events):
+ """
+ Test that _set_model_data skips None values in settings.
+ """
+ import sentry_sdk
+ from sentry_sdk.integrations.pydantic_ai.utils import _set_model_data
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ span = sentry_sdk.start_span(op="test_span")
+
+ # Create settings with None values
+ settings = {
+ "temperature": 0.7,
+ "max_tokens": None, # Should be skipped
+ "top_p": None, # Should be skipped
+ }
+
+ # Set model data
+ _set_model_data(span, None, settings)
+
+ span.finish()
+
+ # Should not crash
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_should_send_prompts_without_pii(sentry_init, capture_events):
+ """
+ Test that _should_send_prompts returns False when PII disabled.
+ """
+ from sentry_sdk.integrations.pydantic_ai.utils import _should_send_prompts
+
+ sentry_init(
+ integrations=[PydanticAIIntegration(include_prompts=True)],
+ traces_sample_rate=1.0,
+ send_default_pii=False, # PII disabled
+ )
+
+ # Should return False
+ result = _should_send_prompts()
+ assert result is False
+
+
+@pytest.mark.asyncio
+async def test_set_agent_data_without_agent(sentry_init, capture_events):
+ """
+ Test that _set_agent_data handles None agent gracefully.
+ """
+ import sentry_sdk
+ from sentry_sdk.integrations.pydantic_ai.utils import _set_agent_data
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ span = sentry_sdk.start_span(op="test_span")
+
+ # Pass None agent, with no agent in scope
+ _set_agent_data(span, None)
+
+ span.finish()
+
+ # Should not crash
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_set_agent_data_from_scope(sentry_init, capture_events):
+ """
+ Test that _set_agent_data retrieves agent from scope when not passed.
+ """
+ import sentry_sdk
+ from unittest.mock import MagicMock
+ from sentry_sdk.integrations.pydantic_ai.utils import _set_agent_data
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ # Set agent in scope
+ scope = sentry_sdk.get_current_scope()
+ mock_agent = MagicMock()
+ mock_agent.name = "test_agent_from_scope"
+ scope._contexts["pydantic_ai_agent"] = {"_agent": mock_agent}
+
+ span = sentry_sdk.start_span(op="test_span")
+
+ # Pass None for agent, should get from scope
+ _set_agent_data(span, None)
+
+ span.finish()
+
+ # Should not crash
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_set_agent_data_without_name(sentry_init, capture_events):
+ """
+ Test that _set_agent_data handles agent without name attribute.
+ """
+ import sentry_sdk
+ from unittest.mock import MagicMock
+ from sentry_sdk.integrations.pydantic_ai.utils import _set_agent_data
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ span = sentry_sdk.start_span(op="test_span")
+
+ # Create agent without name
+ mock_agent = MagicMock()
+ mock_agent.name = None # No name
+
+ # Should not set agent name
+ _set_agent_data(span, mock_agent)
+
+ span.finish()
+
+ # Should not crash
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_set_available_tools_without_toolset(sentry_init, capture_events):
+ """
+ Test that _set_available_tools handles agent without toolset.
+ """
+ import sentry_sdk
+ from unittest.mock import MagicMock
+ from sentry_sdk.integrations.pydantic_ai.utils import _set_available_tools
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ span = sentry_sdk.start_span(op="test_span")
+
+ # Create agent without _function_toolset
+ mock_agent = MagicMock()
+ del mock_agent._function_toolset
+
+ # Should handle gracefully
+ _set_available_tools(span, mock_agent)
+
+ span.finish()
+
+ # Should not crash
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_set_available_tools_with_schema(sentry_init, capture_events):
+ """
+ Test that _set_available_tools extracts tool schema correctly.
+ """
+ import sentry_sdk
+ from unittest.mock import MagicMock
+ from sentry_sdk.integrations.pydantic_ai.utils import _set_available_tools
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ span = sentry_sdk.start_span(op="test_span")
+
+ # Create agent with toolset containing schema
+ mock_agent = MagicMock()
+ mock_tool = MagicMock()
+ mock_schema = MagicMock()
+ mock_schema.description = "Test tool description"
+ mock_schema.json_schema = {"type": "object", "properties": {}}
+ mock_tool.function_schema = mock_schema
+
+ mock_agent._function_toolset.tools = {"test_tool": mock_tool}
+
+ # Should extract schema
+ _set_available_tools(span, mock_agent)
+
+ span.finish()
+
+ # Should not crash
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_execute_tool_span_creation(sentry_init, capture_events):
+ """
+ Test direct creation of execute_tool span.
+ """
+ import sentry_sdk
+ from sentry_sdk.integrations.pydantic_ai.spans.execute_tool import (
+ execute_tool_span,
+ update_execute_tool_span,
+ )
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ # Create execute_tool span
+ with execute_tool_span("test_tool", {"arg": "value"}, None, "function") as span:
+ # Update with result
+ update_execute_tool_span(span, {"result": "success"})
+
+ # Should not crash
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_execute_tool_span_with_mcp_type(sentry_init, capture_events):
+ """
+ Test execute_tool span with MCP tool type.
+ """
+ import sentry_sdk
+ from sentry_sdk.integrations.pydantic_ai.spans.execute_tool import (
+ execute_tool_span,
+ )
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ # Create execute_tool span with mcp type
+ with execute_tool_span("mcp_tool", {"arg": "value"}, None, "mcp") as span:
+ # Verify type is set
+ assert span is not None
+
+ # Should not crash
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_execute_tool_span_without_prompts(sentry_init, capture_events):
+ """
+ Test that execute_tool span respects _should_send_prompts().
+ """
+ import sentry_sdk
+ from sentry_sdk.integrations.pydantic_ai.spans.execute_tool import (
+ execute_tool_span,
+ update_execute_tool_span,
+ )
+
+ sentry_init(
+ integrations=[PydanticAIIntegration(include_prompts=False)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ # Create execute_tool span
+ with execute_tool_span("test_tool", {"arg": "value"}, None, "function") as span:
+ # Update with result - should not set input/output
+ update_execute_tool_span(span, {"result": "success"})
+
+ # Should not crash
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_execute_tool_span_with_none_args(sentry_init, capture_events):
+ """
+ Test execute_tool span with None args.
+ """
+ import sentry_sdk
+ from sentry_sdk.integrations.pydantic_ai.spans.execute_tool import execute_tool_span
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ # Create execute_tool span with None args
+ with execute_tool_span("test_tool", None, None, "function") as span:
+ assert span is not None
+
+ # Should not crash
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_update_execute_tool_span_with_none_span(sentry_init, capture_events):
+ """
+ Test that update_execute_tool_span handles None span gracefully.
+ """
+ from sentry_sdk.integrations.pydantic_ai.spans.execute_tool import (
+ update_execute_tool_span,
+ )
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ # Update with None span - should not raise
+ update_execute_tool_span(None, {"result": "success"})
+
+ # Should not crash
+ assert True
+
+
+@pytest.mark.asyncio
+async def test_update_execute_tool_span_with_none_result(sentry_init, capture_events):
+ """
+ Test that update_execute_tool_span handles None result gracefully.
+ """
+ import sentry_sdk
+ from sentry_sdk.integrations.pydantic_ai.spans.execute_tool import (
+ execute_tool_span,
+ update_execute_tool_span,
+ )
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ # Create execute_tool span
+ with execute_tool_span("test_tool", {"arg": "value"}, None, "function") as span:
+ # Update with None result
+ update_execute_tool_span(span, None)
+
+ # Should not crash
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_tool_execution_without_span_context(sentry_init, capture_events):
+ """
+ Test that tool execution patch handles case when no span context exists.
+ This tests the code path where current_span is None in _patch_tool_execution.
+ """
+ # Import the patching function
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ # Create a simple agent with no tools (won't have function_toolset)
+ agent = Agent("test", name="test_no_span")
+
+ # Call without span context (no transaction active)
+ # The patches should handle this gracefully
+ try:
+ # This will fail because we're not in a transaction, but it should not crash
+ await agent.run("test")
+ except Exception:
+ # Expected to fail, that's okay
+ pass
+
+ # Should not crash
+ assert True
+
+
+@pytest.mark.asyncio
+async def test_invoke_agent_span_with_callable_instruction(sentry_init, capture_events):
+ """
+ Test that invoke_agent_span skips callable instructions correctly.
+ """
+ import sentry_sdk
+ from unittest.mock import MagicMock
+ from sentry_sdk.integrations.pydantic_ai.spans.invoke_agent import invoke_agent_span
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ # Create mock agent with callable instruction
+ mock_agent = MagicMock()
+ mock_agent.name = "test_agent"
+ mock_agent._system_prompts = []
+
+ # Add both string and callable instructions
+ mock_callable = lambda: "Dynamic instruction"
+ mock_agent._instructions = ["Static instruction", mock_callable]
+
+ # Create span
+ span = invoke_agent_span("Test prompt", mock_agent, None, None)
+ span.finish()
+
+ # Should not crash (callable should be skipped)
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_invoke_agent_span_with_string_instructions(sentry_init, capture_events):
+ """
+ Test that invoke_agent_span handles string instructions (not list).
+ """
+ import sentry_sdk
+ from unittest.mock import MagicMock
+ from sentry_sdk.integrations.pydantic_ai.spans.invoke_agent import invoke_agent_span
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ # Create mock agent with string instruction
+ mock_agent = MagicMock()
+ mock_agent.name = "test_agent"
+ mock_agent._system_prompts = []
+ mock_agent._instructions = "Single instruction string"
+
+ # Create span
+ span = invoke_agent_span("Test prompt", mock_agent, None, None)
+ span.finish()
+
+ # Should not crash
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_ai_client_span_with_streaming_flag(sentry_init, capture_events):
+ """
+ Test that ai_client_span reads streaming flag from scope.
+ """
+ import sentry_sdk
+ from sentry_sdk.integrations.pydantic_ai.spans.ai_client import ai_client_span
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ # Set streaming flag in scope
+ scope = sentry_sdk.get_current_scope()
+ scope._contexts["pydantic_ai_agent"] = {"_streaming": True}
+
+ # Create ai_client span
+ span = ai_client_span([], None, None, None)
+ span.finish()
+
+ # Should not crash
+ assert transaction is not None
+
+
+@pytest.mark.asyncio
+async def test_ai_client_span_gets_agent_from_scope(sentry_init, capture_events):
+ """
+ Test that ai_client_span gets agent from scope when not passed.
+ """
+ import sentry_sdk
+ from unittest.mock import MagicMock
+ from sentry_sdk.integrations.pydantic_ai.spans.ai_client import ai_client_span
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ with sentry_sdk.start_transaction(op="test", name="test") as transaction:
+ # Set agent in scope
+ scope = sentry_sdk.get_current_scope()
+ mock_agent = MagicMock()
+ mock_agent.name = "test_agent"
+ mock_agent._function_toolset = MagicMock()
+ mock_agent._function_toolset.tools = {}
+ scope._contexts["pydantic_ai_agent"] = {"_agent": mock_agent}
+
+ # Create ai_client span without passing agent
+ span = ai_client_span([], None, None, None)
+ span.finish()
+
+ # Should not crash
+ assert transaction is not None
+
+
+def _get_messages_from_span(span_data):
+ """Helper to extract and parse messages from span data."""
+ messages_data = span_data["gen_ai.request.messages"]
+ return (
+ json.loads(messages_data) if isinstance(messages_data, str) else messages_data
+ )
+
+
+def _find_binary_content(messages_data, expected_modality, expected_mime_type):
+ """Helper to find and verify binary content in messages."""
+ for msg in messages_data:
+ if "content" not in msg:
+ continue
+ for content_item in msg["content"]:
+ if content_item.get("type") == "blob":
+ assert content_item["modality"] == expected_modality
+ assert content_item["mime_type"] == expected_mime_type
+ assert content_item["content"] == BLOB_DATA_SUBSTITUTE
+ return True
+ return False
+
+
+@pytest.mark.asyncio
+async def test_binary_content_encoding_image(sentry_init, capture_events):
+ """Test that BinaryContent with image data is properly encoded in messages."""
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+
+ with sentry_sdk.start_transaction(op="test", name="test"):
+ span = sentry_sdk.start_span(op="test_span")
+ binary_content = BinaryContent(
+ data=b"fake_image_data_12345", media_type="image/png"
+ )
+ user_part = UserPromptPart(content=["Look at this image:", binary_content])
+ mock_msg = MagicMock()
+ mock_msg.parts = [user_part]
+ mock_msg.instructions = None
+
+ _set_input_messages(span, [mock_msg])
+ span.finish()
+
+ (event,) = events
+ span_data = event["spans"][0]["data"]
+ messages_data = _get_messages_from_span(span_data)
+ assert _find_binary_content(messages_data, "image", "image/png")
+
+
+@pytest.mark.asyncio
+async def test_binary_content_encoding_mixed_content(sentry_init, capture_events):
+ """Test that BinaryContent mixed with text content is properly handled."""
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+
+ with sentry_sdk.start_transaction(op="test", name="test"):
+ span = sentry_sdk.start_span(op="test_span")
+ binary_content = BinaryContent(
+ data=b"fake_image_bytes", media_type="image/jpeg"
+ )
+ user_part = UserPromptPart(
+ content=["Here is an image:", binary_content, "What do you see?"]
+ )
+ mock_msg = MagicMock()
+ mock_msg.parts = [user_part]
+ mock_msg.instructions = None
+
+ _set_input_messages(span, [mock_msg])
+ span.finish()
+
+ (event,) = events
+ span_data = event["spans"][0]["data"]
+ messages_data = _get_messages_from_span(span_data)
+
+ # Verify both text and binary content are present
+ found_text = any(
+ content_item.get("type") == "text"
+ for msg in messages_data
+ if "content" in msg
+ for content_item in msg["content"]
+ )
+ assert found_text, "Text content should be found"
+ assert _find_binary_content(messages_data, "image", "image/jpeg")
+
+
+@pytest.mark.asyncio
+async def test_binary_content_in_agent_run(sentry_init, capture_events):
+ """Test that BinaryContent in actual agent run is properly captured in spans."""
+ agent = Agent("test", name="test_binary_agent")
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+ binary_content = BinaryContent(
+ data=b"fake_image_data_for_testing", media_type="image/png"
+ )
+ await agent.run(["Analyze this image:", binary_content])
+
+ (transaction,) = events
+ chat_spans = [s for s in transaction["spans"] if s["op"] == "gen_ai.chat"]
+ assert len(chat_spans) >= 1
+
+ chat_span = chat_spans[0]
+ if "gen_ai.request.messages" in chat_span["data"]:
+ messages_str = str(chat_span["data"]["gen_ai.request.messages"])
+ assert any(keyword in messages_str for keyword in ["blob", "image", "base64"])
+
+
+@pytest.mark.asyncio
+async def test_set_usage_data_with_cache_tokens(sentry_init, capture_events):
+ """Test that cache_read_tokens and cache_write_tokens are tracked."""
+ sentry_init(integrations=[PydanticAIIntegration()], traces_sample_rate=1.0)
+
+ events = capture_events()
+
+ with sentry_sdk.start_transaction(op="test", name="test"):
+ span = sentry_sdk.start_span(op="test_span")
+ usage = RequestUsage(
+ input_tokens=100,
+ output_tokens=50,
+ cache_read_tokens=80,
+ cache_write_tokens=20,
+ )
+ _set_usage_data(span, usage)
+ span.finish()
+
+ (event,) = events
+ (span_data,) = event["spans"]
+ assert span_data["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 80
+ assert span_data["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE] == 20
+
+
+@pytest.mark.parametrize(
+ "url,image_url_kwargs,expected_content",
+ [
+ pytest.param(
+ "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs",
+ {},
+ BLOB_DATA_SUBSTITUTE,
+ id="base64_data_url",
+ ),
+ pytest.param(
+ "https://example.com/image.png",
+ {},
+ "https://example.com/image.png",
+ id="http_url_no_redaction",
+ ),
+ pytest.param(
+ "https://example.com/api?data=iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs",
+ {"media_type": "image/png"},
+ "https://example.com/api?data=iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs",
+ id="http_url_with_base64_query_param",
+ ),
+ pytest.param(
+ "data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciLz4=",
+ {},
+ BLOB_DATA_SUBSTITUTE,
+ id="complex_mime_type",
+ ),
+ pytest.param(
+ "data:image/png;name=file.png;base64,iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs",
+ {},
+ BLOB_DATA_SUBSTITUTE,
+ id="optional_parameters",
+ ),
+ pytest.param(
+ "data:text/plain;charset=utf-8;name=hello.txt;base64,SGVsbG8sIFdvcmxkIQ==",
+ {},
+ BLOB_DATA_SUBSTITUTE,
+ id="multiple_optional_parameters",
+ ),
+ ],
+)
+def test_image_url_base64_content_in_span(
+ sentry_init, capture_events, url, image_url_kwargs, expected_content
+):
+ from sentry_sdk.integrations.pydantic_ai.spans.ai_client import ai_client_span
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+
+ with sentry_sdk.start_transaction(op="test", name="test"):
+ image_url = ImageUrl(url=url, **image_url_kwargs)
+ user_part = UserPromptPart(content=["Look at this image:", image_url])
+ mock_msg = MagicMock()
+ mock_msg.parts = [user_part]
+ mock_msg.instructions = None
+
+ span = ai_client_span([mock_msg], None, None, None)
+ span.finish()
+
+ (event,) = events
+ chat_spans = [s for s in event["spans"] if s["op"] == "gen_ai.chat"]
+ assert len(chat_spans) >= 1
+ messages_data = _get_messages_from_span(chat_spans[0]["data"])
+
+ found_image = False
+ for msg in messages_data:
+ if "content" not in msg:
+ continue
+ for content_item in msg["content"]:
+ if content_item.get("type") == "image":
+ found_image = True
+ assert content_item["content"] == expected_content
+
+ assert found_image, "Image content item should be found in messages data"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "url, image_url_kwargs, expected_content",
+ [
+ pytest.param(
+ "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs",
+ {},
+ BLOB_DATA_SUBSTITUTE,
+ id="base64_data_url_redacted",
+ ),
+ pytest.param(
+ "https://example.com/image.png",
+ {},
+ "https://example.com/image.png",
+ id="http_url_no_redaction",
+ ),
+ pytest.param(
+ "https://example.com/api?data=iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs",
+ {},
+ "https://example.com/api?data=iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs",
+ id="http_url_with_base64_query_param",
+ ),
+ pytest.param(
+ "https://example.com/api?data=iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs",
+ {"media_type": "image/png"},
+ "https://example.com/api?data=iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs",
+ id="http_url_with_base64_query_param_and_media_type",
+ ),
+ ],
+)
+async def test_invoke_agent_image_url(
+ sentry_init, capture_events, url, image_url_kwargs, expected_content
+):
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ agent = Agent("test", name="test_image_url_agent")
+
+ events = capture_events()
+ image_url = ImageUrl(url=url, **image_url_kwargs)
+ await agent.run([image_url, "Describe this image"])
+
+ (transaction,) = events
+
+ found_image = False
+
+ chat_spans = [s for s in transaction["spans"] if s["op"] == "gen_ai.chat"]
+ for chat_span in chat_spans:
+ messages_data = _get_messages_from_span(chat_span["data"])
+ for msg in messages_data:
+ if "content" not in msg:
+ continue
+ for content_item in msg["content"]:
+ if content_item.get("type") == "image":
+ assert content_item["content"] == expected_content
+ found_image = True
+
+ assert found_image, "Image content item should be found in messages data"
+
+
+@pytest.mark.asyncio
+async def test_tool_description_in_execute_tool_span(sentry_init, capture_events):
+ """
+ Test that tool description from the tool's docstring is included in execute_tool spans.
+ """
+ agent = Agent(
+ "test",
+ name="test_agent",
+ system_prompt="You are a helpful test assistant.",
+ )
+
+ @agent.tool_plain
+ def multiply_numbers(a: int, b: int) -> int:
+ """Multiply two numbers and return the product."""
+ return a * b
+
+ sentry_init(
+ integrations=[PydanticAIIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+
+ events = capture_events()
+
+ result = await agent.run("What is 5 times 3?")
+ assert result is not None
+
+ (transaction,) = events
+ spans = transaction["spans"]
+
+ tool_spans = [s for s in spans if s["op"] == "gen_ai.execute_tool"]
+ assert len(tool_spans) >= 1
+
+ tool_span = tool_spans[0]
+ assert tool_span["data"]["gen_ai.tool.name"] == "multiply_numbers"
+ assert SPANDATA.GEN_AI_TOOL_DESCRIPTION in tool_span["data"]
+ assert "Multiply two numbers" in tool_span["data"][SPANDATA.GEN_AI_TOOL_DESCRIPTION]
diff --git a/tests/integrations/pymongo/test_pymongo.py b/tests/integrations/pymongo/test_pymongo.py
index 16438ac971..b57061b0a0 100644
--- a/tests/integrations/pymongo/test_pymongo.py
+++ b/tests/integrations/pymongo/test_pymongo.py
@@ -1,4 +1,5 @@
from sentry_sdk import capture_message, start_transaction
+from sentry_sdk.consts import SPANDATA
from sentry_sdk.integrations.pymongo import PyMongoIntegration, _strip_pii
from mockupdb import MockupDB, OpQuery
@@ -9,7 +10,7 @@
@pytest.fixture(scope="session")
def mongo_server():
server = MockupDB(verbose=True)
- server.autoresponds("ismaster", maxWireVersion=6)
+ server.autoresponds("ismaster", maxWireVersion=8)
server.run()
server.autoresponds(
{"find": "test_collection"}, cursor={"id": 123, "firstBatch": []}
@@ -51,24 +52,41 @@ def test_transactions(sentry_init, capture_events, mongo_server, with_pii):
common_tags = {
"db.name": "test_db",
"db.system": "mongodb",
+ "db.driver.name": "pymongo",
"net.peer.name": mongo_server.host,
"net.peer.port": str(mongo_server.port),
}
for span in find, insert_success, insert_fail:
+ assert span["data"][SPANDATA.DB_SYSTEM] == "mongodb"
+ assert span["data"][SPANDATA.DB_DRIVER_NAME] == "pymongo"
+ assert span["data"][SPANDATA.DB_NAME] == "test_db"
+ assert span["data"][SPANDATA.SERVER_ADDRESS] == "localhost"
+ assert span["data"][SPANDATA.SERVER_PORT] == mongo_server.port
for field, value in common_tags.items():
assert span["tags"][field] == value
+ assert span["data"][field] == value
- assert find["op"] == "db.query"
- assert insert_success["op"] == "db.query"
- assert insert_fail["op"] == "db.query"
+ assert find["op"] == "db"
+ assert insert_success["op"] == "db"
+ assert insert_fail["op"] == "db"
+ assert find["data"]["db.operation"] == "find"
assert find["tags"]["db.operation"] == "find"
+ assert insert_success["data"]["db.operation"] == "insert"
assert insert_success["tags"]["db.operation"] == "insert"
+ assert insert_fail["data"]["db.operation"] == "insert"
assert insert_fail["tags"]["db.operation"] == "insert"
- assert find["description"].startswith("find {")
- assert insert_success["description"].startswith("insert {")
- assert insert_fail["description"].startswith("insert {")
+ assert find["description"].startswith('{"find')
+ assert insert_success["description"].startswith('{"insert')
+ assert insert_fail["description"].startswith('{"insert')
+
+ assert find["data"][SPANDATA.DB_MONGODB_COLLECTION] == "test_collection"
+ assert find["tags"][SPANDATA.DB_MONGODB_COLLECTION] == "test_collection"
+ assert insert_success["data"][SPANDATA.DB_MONGODB_COLLECTION] == "test_collection"
+ assert insert_success["tags"][SPANDATA.DB_MONGODB_COLLECTION] == "test_collection"
+ assert insert_fail["data"][SPANDATA.DB_MONGODB_COLLECTION] == "erroneous"
+ assert insert_fail["tags"][SPANDATA.DB_MONGODB_COLLECTION] == "erroneous"
if with_pii:
assert "1" in find["description"]
assert "2" in insert_success["description"]
@@ -83,8 +101,11 @@ def test_transactions(sentry_init, capture_events, mongo_server, with_pii):
and "4" not in insert_fail["description"]
)
+ assert find["status"] == "ok"
assert find["tags"]["status"] == "ok"
+ assert insert_success["status"] == "ok"
assert insert_success["tags"]["status"] == "ok"
+ assert insert_fail["status"] == "internal_error"
assert insert_fail["tags"]["status"] == "internal_error"
@@ -108,18 +129,20 @@ def test_breadcrumbs(sentry_init, capture_events, mongo_server, with_pii):
(crumb,) = event["breadcrumbs"]["values"]
assert crumb["category"] == "query"
- assert crumb["message"].startswith("find {")
+ assert crumb["message"].startswith('{"find')
if with_pii:
assert "1" in crumb["message"]
else:
assert "1" not in crumb["message"]
- assert crumb["type"] == "db.query"
+ assert crumb["type"] == "db"
assert crumb["data"] == {
"db.name": "test_db",
"db.system": "mongodb",
+ "db.driver.name": "pymongo",
"db.operation": "find",
"net.peer.name": mongo_server.host,
"net.peer.port": str(mongo_server.port),
+ "db.mongodb.collection": "test_collection",
}
@@ -417,3 +440,23 @@ def test_breadcrumbs(sentry_init, capture_events, mongo_server, with_pii):
)
def test_strip_pii(testcase):
assert _strip_pii(testcase["command"]) == testcase["command_stripped"]
+
+
+def test_span_origin(sentry_init, capture_events, mongo_server):
+ sentry_init(
+ integrations=[PyMongoIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ connection = MongoClient(mongo_server.uri)
+
+ with start_transaction():
+ list(
+ connection["test_db"]["test_collection"].find({"foobar": 1})
+ ) # force query execution
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+ assert event["spans"][0]["origin"] == "auto.db.pymongo"
diff --git a/tests/integrations/pyramid/__init__.py b/tests/integrations/pyramid/__init__.py
index b63de1d1d3..a77a4d54ca 100644
--- a/tests/integrations/pyramid/__init__.py
+++ b/tests/integrations/pyramid/__init__.py
@@ -1,3 +1,3 @@
import pytest
-pyramid = pytest.importorskip("pyramid")
+pytest.importorskip("pyramid")
diff --git a/tests/integrations/pyramid/test_pyramid.py b/tests/integrations/pyramid/test_pyramid.py
index 0f8755ac6b..95efe8172b 100644
--- a/tests/integrations/pyramid/test_pyramid.py
+++ b/tests/integrations/pyramid/test_pyramid.py
@@ -1,24 +1,33 @@
import json
import logging
-import pkg_resources
-import pytest
-
from io import BytesIO
import pyramid.testing
-
+import pytest
from pyramid.authorization import ACLAuthorizationPolicy
from pyramid.response import Response
+from packaging.version import Version
+from werkzeug.test import Client
from sentry_sdk import capture_message, add_breadcrumb
+from sentry_sdk.consts import DEFAULT_MAX_VALUE_LENGTH
from sentry_sdk.integrations.pyramid import PyramidIntegration
+from sentry_sdk.serializer import MAX_DATABAG_BREADTH
+from tests.conftest import unpack_werkzeug_response
-from werkzeug.test import Client
+try:
+ from importlib.metadata import version
-PYRAMID_VERSION = tuple(
- map(int, pkg_resources.get_distribution("pyramid").version.split("."))
-)
+ PYRAMID_VERSION = Version(version("pyramid")).release
+
+except ImportError:
+ # < py3.8
+ import pkg_resources
+
+ PYRAMID_VERSION = tuple(
+ map(int, pkg_resources.get_distribution("pyramid").version.split("."))
+ )
def hi(request):
@@ -89,7 +98,10 @@ def errors(request):
(event,) = events
(breadcrumb,) = event["breadcrumbs"]["values"]
assert breadcrumb["message"] == "hi2"
- assert event["exception"]["values"][0]["mechanism"]["type"] == "pyramid"
+ # Checking only the last value in the exceptions list,
+ # because Pyramid >= 1.9 returns a chained exception and before just a single exception
+ assert event["exception"]["values"][-1]["mechanism"]["type"] == "pyramid"
+ assert event["exception"]["values"][-1]["type"] == "ZeroDivisionError"
def test_has_context(route, get_client, sentry_init, capture_events):
@@ -146,9 +158,9 @@ def test_transaction_style(
def test_large_json_request(sentry_init, capture_events, route, get_client):
- sentry_init(integrations=[PyramidIntegration()])
+ sentry_init(integrations=[PyramidIntegration()], max_request_body_size="always")
- data = {"foo": {"bar": "a" * 2000}}
+ data = {"foo": {"bar": "a" * (DEFAULT_MAX_VALUE_LENGTH + 10)}}
@route("/")
def index(request):
@@ -165,9 +177,14 @@ def index(request):
(event,) = events
assert event["_meta"]["request"]["data"]["foo"]["bar"] == {
- "": {"len": 2000, "rem": [["!limit", "x", 1021, 1024]]}
+ "": {
+ "len": DEFAULT_MAX_VALUE_LENGTH + 10,
+ "rem": [
+ ["!limit", "x", DEFAULT_MAX_VALUE_LENGTH - 3, DEFAULT_MAX_VALUE_LENGTH]
+ ],
+ }
}
- assert len(event["request"]["data"]["foo"]["bar"]) == 1024
+ assert len(event["request"]["data"]["foo"]["bar"]) == DEFAULT_MAX_VALUE_LENGTH
@pytest.mark.parametrize("data", [{}, []], ids=["empty-dict", "empty-list"])
@@ -192,10 +209,38 @@ def index(request):
assert event["request"]["data"] == data
+def test_json_not_truncated_if_max_request_body_size_is_always(
+ sentry_init, capture_events, route, get_client
+):
+ sentry_init(integrations=[PyramidIntegration()], max_request_body_size="always")
+
+ data = {
+ "key{}".format(i): "value{}".format(i) for i in range(MAX_DATABAG_BREADTH + 10)
+ }
+
+ @route("/")
+ def index(request):
+ assert request.json == data
+ assert request.text == json.dumps(data)
+ capture_message("hi")
+ return Response("ok")
+
+ events = capture_events()
+
+ client = get_client()
+ client.post("/", content_type="application/json", data=json.dumps(data))
+
+ (event,) = events
+ assert event["request"]["data"] == data
+
+
def test_files_and_form(sentry_init, capture_events, route, get_client):
- sentry_init(integrations=[PyramidIntegration()], request_bodies="always")
+ sentry_init(integrations=[PyramidIntegration()], max_request_body_size="always")
- data = {"foo": "a" * 2000, "file": (BytesIO(b"hello"), "hello.txt")}
+ data = {
+ "foo": "a" * (DEFAULT_MAX_VALUE_LENGTH + 10),
+ "file": (BytesIO(b"hello"), "hello.txt"),
+ }
@route("/")
def index(request):
@@ -209,9 +254,14 @@ def index(request):
(event,) = events
assert event["_meta"]["request"]["data"]["foo"] == {
- "": {"len": 2000, "rem": [["!limit", "x", 1021, 1024]]}
+ "": {
+ "len": DEFAULT_MAX_VALUE_LENGTH + 10,
+ "rem": [
+ ["!limit", "x", DEFAULT_MAX_VALUE_LENGTH - 3, DEFAULT_MAX_VALUE_LENGTH]
+ ],
+ }
}
- assert len(event["request"]["data"]["foo"]) == 1024
+ assert len(event["request"]["data"]["foo"]) == DEFAULT_MAX_VALUE_LENGTH
assert event["_meta"]["request"]["data"]["file"] == {"": {"rem": [["!raw", "x"]]}}
assert not event["request"]["data"]["file"]
@@ -281,8 +331,8 @@ def errorhandler(exc, request):
pyramid_config.add_view(errorhandler, context=Exception)
client = get_client()
- app_iter, status, headers = client.get("/")
- assert b"".join(app_iter) == b"bad request"
+ app_iter, status, headers = unpack_werkzeug_response(client.get("/"))
+ assert app_iter == b"bad request"
assert status.lower() == "500 internal server error"
(error,) = errors
@@ -331,9 +381,9 @@ def test_error_in_authenticated_userid(
)
logger = logging.getLogger("test_pyramid")
- class AuthenticationPolicy(object):
+ class AuthenticationPolicy:
def authenticated_userid(self, request):
- logger.error("failed to identify user")
+ logger.warning("failed to identify user")
pyramid_config.set_authorization_policy(ACLAuthorizationPolicy())
pyramid_config.set_authentication_policy(AuthenticationPolicy())
@@ -345,6 +395,16 @@ def authenticated_userid(self, request):
assert len(events) == 1
+ # In `authenticated_userid` there used to be a call to `logging.error`. This would print this error in the
+ # event processor of the Pyramid integration and the logging integration would capture this and send it to Sentry.
+ # This is not possible anymore, because capturing that error in the logging integration would again run all the
+ # event processors (from the global, isolation and current scope) and thus would again run the same pyramid
+ # event processor that raised the error in the first place, leading on an infinite loop.
+ # This test here is now deactivated and always passes, but it is kept here to document the problem.
+ # This change in behavior is also mentioned in the migration documentation for Python SDK 2.0
+
+ # assert "message" not in events[0].keys()
+
def tween_factory(handler, registry):
def tween(request):
@@ -376,3 +436,18 @@ def index(request):
client.get("/")
assert not errors
+
+
+def test_span_origin(sentry_init, capture_events, get_client):
+ sentry_init(
+ integrations=[PyramidIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ client = get_client()
+ client.get("/message")
+
+ (_, event) = events
+
+ assert event["contexts"]["trace"]["origin"] == "auto.http.pyramid"
diff --git a/tests/integrations/pyreqwest/__init__.py b/tests/integrations/pyreqwest/__init__.py
new file mode 100644
index 0000000000..dfdd787852
--- /dev/null
+++ b/tests/integrations/pyreqwest/__init__.py
@@ -0,0 +1,9 @@
+import os
+import sys
+import pytest
+
+pytest.importorskip("pyreqwest")
+
+# Load `pyreqwest_helpers` into the module search path to test request source path names relative to module. See
+# `test_request_source_with_module_in_search_path`
+sys.path.insert(0, os.path.join(os.path.dirname(__file__)))
diff --git a/tests/integrations/pyreqwest/pyreqwest_helpers/__init__.py b/tests/integrations/pyreqwest/pyreqwest_helpers/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/integrations/pyreqwest/pyreqwest_helpers/helpers.py b/tests/integrations/pyreqwest/pyreqwest_helpers/helpers.py
new file mode 100644
index 0000000000..3abc554522
--- /dev/null
+++ b/tests/integrations/pyreqwest/pyreqwest_helpers/helpers.py
@@ -0,0 +1,2 @@
+def get_request_with_client(client, url):
+ client.get(url).build().send()
diff --git a/tests/integrations/pyreqwest/test_pyreqwest.py b/tests/integrations/pyreqwest/test_pyreqwest.py
new file mode 100644
index 0000000000..ad20e2b08a
--- /dev/null
+++ b/tests/integrations/pyreqwest/test_pyreqwest.py
@@ -0,0 +1,484 @@
+import datetime
+import os
+from contextlib import contextmanager
+from http.server import BaseHTTPRequestHandler, HTTPServer
+from threading import Thread
+from unittest import mock
+import pytest
+
+from pyreqwest.client import ClientBuilder, SyncClientBuilder
+from pyreqwest.simple.request import pyreqwest_get as async_pyreqwest_get
+from pyreqwest.simple.sync_request import pyreqwest_get as sync_pyreqwest_get
+
+import sentry_sdk
+from sentry_sdk import start_transaction
+from sentry_sdk.consts import MATCH_ALL, SPANDATA
+from sentry_sdk.integrations.pyreqwest import PyreqwestIntegration
+from tests.conftest import get_free_port
+
+
+class PyreqwestMockHandler(BaseHTTPRequestHandler):
+ captured_requests = []
+
+ def do_GET(self) -> None:
+ self.captured_requests.append(
+ {
+ "path": self.path,
+ "headers": {k.lower(): v for k, v in self.headers.items()},
+ }
+ )
+
+ code = 200
+ if "/status/" in self.path:
+ try:
+ code = int(self.path.split("/")[-1])
+ except (ValueError, IndexError):
+ code = 200
+
+ self.send_response(code)
+ self.end_headers()
+ self.wfile.write(b"OK")
+
+ def log_message(self, format: str, *args: object) -> None:
+ pass
+
+
+@pytest.fixture(scope="module")
+def server_port():
+ port = get_free_port()
+ server = HTTPServer(("localhost", port), PyreqwestMockHandler)
+ thread = Thread(target=server.serve_forever)
+ thread.daemon = True
+ thread.start()
+ yield port
+ server.shutdown()
+
+
+@pytest.fixture(autouse=True)
+def clear_captured_requests():
+ PyreqwestMockHandler.captured_requests.clear()
+
+
+def test_sync_client_spans(sentry_init, capture_events, server_port):
+ sentry_init(integrations=[PyreqwestIntegration()], traces_sample_rate=1.0)
+ events = capture_events()
+
+ url = f"http://localhost:{server_port}/hello?q=test#frag"
+ with start_transaction(name="test_transaction"):
+ client = SyncClientBuilder().build()
+ response = client.get(url).build().send()
+ assert response.status == 200
+
+ (event,) = events
+ assert len(event["spans"]) == 1
+ span = event["spans"][0]
+ assert span["op"] == "http.client"
+ assert span["description"] == f"GET http://localhost:{server_port}/hello"
+ assert span["data"]["url"] == f"http://localhost:{server_port}/hello"
+ assert span["data"][SPANDATA.HTTP_METHOD] == "GET"
+ assert span["data"][SPANDATA.HTTP_STATUS_CODE] == 200
+ assert span["data"][SPANDATA.HTTP_QUERY] == "q=test"
+ assert span["data"][SPANDATA.HTTP_FRAGMENT] == "frag"
+ assert span["origin"] == "auto.http.pyreqwest"
+
+
+@pytest.mark.asyncio
+async def test_async_client_spans(sentry_init, capture_events, server_port):
+ sentry_init(integrations=[PyreqwestIntegration()], traces_sample_rate=1.0)
+ events = capture_events()
+
+ url = f"http://localhost:{server_port}/hello"
+ async with ClientBuilder().build() as client:
+ with start_transaction(name="test_transaction"):
+ response = await client.get(url).build().send()
+ assert response.status == 200
+
+ (event,) = events
+ assert len(event["spans"]) == 1
+ span = event["spans"][0]
+ assert span["op"] == "http.client"
+ assert span["description"] == f"GET {url}"
+ assert span["data"]["url"] == url
+ assert span["data"][SPANDATA.HTTP_METHOD] == "GET"
+ assert span["data"][SPANDATA.HTTP_STATUS_CODE] == 200
+ assert span["origin"] == "auto.http.pyreqwest"
+
+
+def test_sync_simple_request_spans(sentry_init, capture_events, server_port):
+ sentry_init(integrations=[PyreqwestIntegration()], traces_sample_rate=1.0)
+ events = capture_events()
+
+ url = f"http://localhost:{server_port}/hello-simple"
+ with start_transaction(name="test_transaction"):
+ response = sync_pyreqwest_get(url).send()
+ assert response.status == 200
+
+ (event,) = events
+ assert len(event["spans"]) == 1
+ span = event["spans"][0]
+ assert span["op"] == "http.client"
+ assert span["description"] == f"GET {url}"
+ assert span["data"]["url"] == url
+ assert span["data"][SPANDATA.HTTP_METHOD] == "GET"
+ assert span["data"][SPANDATA.HTTP_STATUS_CODE] == 200
+ assert span["origin"] == "auto.http.pyreqwest"
+
+
+@pytest.mark.asyncio
+async def test_async_simple_request_spans(sentry_init, capture_events, server_port):
+ sentry_init(integrations=[PyreqwestIntegration()], traces_sample_rate=1.0)
+ events = capture_events()
+
+ url = f"http://localhost:{server_port}/hello-simple-async"
+ with start_transaction(name="test_transaction"):
+ response = await async_pyreqwest_get(url).send()
+ assert response.status == 200
+
+ (event,) = events
+ assert len(event["spans"]) == 1
+ span = event["spans"][0]
+ assert span["op"] == "http.client"
+ assert span["description"] == f"GET {url}"
+ assert span["data"]["url"] == url
+ assert span["data"][SPANDATA.HTTP_METHOD] == "GET"
+ assert span["data"][SPANDATA.HTTP_STATUS_CODE] == 200
+ assert span["origin"] == "auto.http.pyreqwest"
+
+
+def test_span_origin(sentry_init, capture_events, server_port):
+ sentry_init(integrations=[PyreqwestIntegration()], traces_sample_rate=1.0)
+ events = capture_events()
+
+ url = f"http://localhost:{server_port}/origin"
+ with start_transaction(name="test_transaction"):
+ client = SyncClientBuilder().build()
+ client.get(url).build().send()
+
+ (event,) = events
+ assert event["spans"][0]["origin"] == "auto.http.pyreqwest"
+
+
+def test_outgoing_trace_headers(sentry_init, server_port):
+ sentry_init(
+ integrations=[PyreqwestIntegration()],
+ traces_sample_rate=1.0,
+ trace_propagation_targets=["localhost"],
+ )
+
+ url = f"http://localhost:{server_port}/trace"
+ with start_transaction(
+ name="test_transaction", trace_id="01234567890123456789012345678901"
+ ):
+ client = SyncClientBuilder().build()
+ response = client.get(url).build().send()
+ assert response.status == 200
+
+ assert len(PyreqwestMockHandler.captured_requests) == 1
+ headers = PyreqwestMockHandler.captured_requests[0]["headers"]
+
+ assert "sentry-trace" in headers
+ assert headers["sentry-trace"].startswith("01234567890123456789012345678901")
+ assert "baggage" in headers
+ assert "sentry-trace_id=01234567890123456789012345678901" in headers["baggage"]
+
+
+def test_outgoing_trace_headers_append_to_baggage(sentry_init, server_port):
+ sentry_init(
+ integrations=[PyreqwestIntegration()],
+ traces_sample_rate=1.0,
+ trace_propagation_targets=["localhost"],
+ release="d08ebdb9309e1b004c6f52202de58a09c2268e42",
+ )
+
+ url = f"http://localhost:{server_port}/baggage"
+
+ with mock.patch("sentry_sdk.tracing_utils.Random.randrange", return_value=500000):
+ with start_transaction(
+ name="/interactions/other-dogs/new-dog",
+ op="greeting.sniff",
+ trace_id="01234567890123456789012345678901",
+ ):
+ client = SyncClientBuilder().build()
+ client.get(url).header("baggage", "custom=data").build().send()
+
+ assert len(PyreqwestMockHandler.captured_requests) == 1
+ headers = PyreqwestMockHandler.captured_requests[0]["headers"]
+
+ assert "baggage" in headers
+ baggage = headers["baggage"]
+ assert "custom=data" in baggage
+ assert "sentry-trace_id=01234567890123456789012345678901" in baggage
+ assert "sentry-sample_rand=0.500000" in baggage
+ assert "sentry-environment=production" in baggage
+ assert "sentry-release=d08ebdb9309e1b004c6f52202de58a09c2268e42" in baggage
+ assert "sentry-transaction=/interactions/other-dogs/new-dog" in baggage
+ assert "sentry-sample_rate=1.0" in baggage
+ assert "sentry-sampled=true" in baggage
+
+
+@pytest.mark.parametrize(
+ "trace_propagation_targets,trace_propagated",
+ [
+ [None, False],
+ [[], False],
+ [[MATCH_ALL], True],
+ [["localhost"], True],
+ [[r"https?:\/\/[\w\-]+(\.[\w\-]+)+\.net"], False],
+ ],
+)
+def test_trace_propagation_targets(
+ sentry_init, server_port, trace_propagation_targets, trace_propagated
+):
+ sentry_init(
+ integrations=[PyreqwestIntegration()],
+ trace_propagation_targets=trace_propagation_targets,
+ traces_sample_rate=1.0,
+ )
+
+ url = f"http://localhost:{server_port}/propagation"
+
+ with start_transaction():
+ client = SyncClientBuilder().build()
+ client.get(url).build().send()
+
+ assert len(PyreqwestMockHandler.captured_requests) == 1
+ headers = PyreqwestMockHandler.captured_requests[0]["headers"]
+
+ if trace_propagated:
+ assert "sentry-trace" in headers
+ else:
+ assert "sentry-trace" not in headers
+
+
+@pytest.mark.tests_internal_exceptions
+def test_omit_url_data_if_parsing_fails(sentry_init, capture_events, server_port):
+ sentry_init(
+ integrations=[PyreqwestIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ url = f"http://localhost:{server_port}/parse-fail"
+
+ with start_transaction(name="test_transaction"):
+ with mock.patch(
+ "sentry_sdk.integrations.pyreqwest.parse_url",
+ side_effect=ValueError,
+ ):
+ client = SyncClientBuilder().build()
+ client.get(url).build().send()
+
+ (event,) = events
+ span = event["spans"][0]
+
+ assert span["description"] == "GET [Filtered]"
+ assert span["data"][SPANDATA.HTTP_METHOD] == "GET"
+ assert span["data"][SPANDATA.HTTP_STATUS_CODE] == 200
+ assert "url" not in span["data"]
+ assert SPANDATA.HTTP_QUERY not in span["data"]
+ assert SPANDATA.HTTP_FRAGMENT not in span["data"]
+
+
+def test_request_source_disabled(sentry_init, capture_events, server_port):
+ sentry_init(
+ integrations=[PyreqwestIntegration()],
+ traces_sample_rate=1.0,
+ enable_http_request_source=False,
+ http_request_source_threshold_ms=0,
+ )
+ events = capture_events()
+
+ url = f"http://localhost:{server_port}/hello"
+
+ with start_transaction(name="test_transaction"):
+ client = SyncClientBuilder().build()
+ client.get(url).build().send()
+
+ (event,) = events
+ span = event["spans"][0]
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO not in data
+ assert SPANDATA.CODE_NAMESPACE not in data
+ assert SPANDATA.CODE_FILEPATH not in data
+ assert SPANDATA.CODE_FUNCTION not in data
+
+
+@pytest.mark.parametrize("enable_http_request_source", [None, True])
+def test_request_source_enabled(
+ sentry_init, capture_events, server_port, enable_http_request_source
+):
+ sentry_options = {
+ "integrations": [PyreqwestIntegration()],
+ "traces_sample_rate": 1.0,
+ "http_request_source_threshold_ms": 0,
+ }
+ if enable_http_request_source is not None:
+ sentry_options["enable_http_request_source"] = enable_http_request_source
+
+ sentry_init(**sentry_options)
+ events = capture_events()
+
+ url = f"http://localhost:{server_port}/hello"
+
+ with start_transaction(name="test_transaction"):
+ client = SyncClientBuilder().build()
+ client.get(url).build().send()
+
+ (event,) = events
+ span = event["spans"][0]
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO in data
+ assert SPANDATA.CODE_NAMESPACE in data
+ assert SPANDATA.CODE_FILEPATH in data
+ assert SPANDATA.CODE_FUNCTION in data
+
+
+def test_request_source(sentry_init, capture_events, server_port):
+ sentry_init(
+ integrations=[PyreqwestIntegration()],
+ traces_sample_rate=1.0,
+ enable_http_request_source=True,
+ http_request_source_threshold_ms=0,
+ )
+ events = capture_events()
+
+ url = f"http://localhost:{server_port}/hello"
+
+ with start_transaction(name="test_transaction"):
+ client = SyncClientBuilder().build()
+ client.get(url).build().send()
+
+ (event,) = events
+ span = event["spans"][0]
+ data = span.get("data", {})
+
+ assert type(data.get(SPANDATA.CODE_LINENO)) == int
+ assert data.get(SPANDATA.CODE_LINENO) > 0
+ assert (
+ data.get(SPANDATA.CODE_NAMESPACE)
+ == "tests.integrations.pyreqwest.test_pyreqwest"
+ )
+ assert data.get(SPANDATA.CODE_FILEPATH).endswith(
+ "tests/integrations/pyreqwest/test_pyreqwest.py"
+ )
+
+ is_relative_path = data.get(SPANDATA.CODE_FILEPATH)[0] != os.sep
+ assert is_relative_path
+
+ assert data.get(SPANDATA.CODE_FUNCTION) == "test_request_source"
+
+
+def test_request_source_with_module_in_search_path(
+ sentry_init, capture_events, server_port
+):
+ sentry_init(
+ integrations=[PyreqwestIntegration()],
+ traces_sample_rate=1.0,
+ enable_http_request_source=True,
+ http_request_source_threshold_ms=0,
+ )
+ events = capture_events()
+
+ url = f"http://localhost:{server_port}/hello"
+
+ with start_transaction(name="test_transaction"):
+ from pyreqwest_helpers.helpers import get_request_with_client
+
+ client = SyncClientBuilder().build()
+ get_request_with_client(client, url)
+
+ (event,) = events
+ span = event["spans"][0]
+ data = span.get("data", {})
+
+ assert type(data.get(SPANDATA.CODE_LINENO)) == int
+ assert data.get(SPANDATA.CODE_LINENO) > 0
+ assert data.get(SPANDATA.CODE_NAMESPACE) == "pyreqwest_helpers.helpers"
+ assert data.get(SPANDATA.CODE_FILEPATH) == "pyreqwest_helpers/helpers.py"
+
+ is_relative_path = data.get(SPANDATA.CODE_FILEPATH)[0] != os.sep
+ assert is_relative_path
+
+ assert data.get(SPANDATA.CODE_FUNCTION) == "get_request_with_client"
+
+
+def test_no_request_source_if_duration_too_short(
+ sentry_init, capture_events, server_port
+):
+ sentry_init(
+ integrations=[PyreqwestIntegration()],
+ traces_sample_rate=1.0,
+ enable_http_request_source=True,
+ http_request_source_threshold_ms=100,
+ )
+ events = capture_events()
+
+ url = f"http://localhost:{server_port}/hello"
+
+ with start_transaction(name="test_transaction"):
+
+ @contextmanager
+ def fake_start_span(*args, **kwargs):
+ with sentry_sdk.start_span(*args, **kwargs) as span:
+ pass
+ span.start_timestamp = datetime.datetime(2024, 1, 1, microsecond=0)
+ span.timestamp = datetime.datetime(2024, 1, 1, microsecond=99999)
+ yield span
+
+ with mock.patch(
+ "sentry_sdk.integrations.pyreqwest.start_span",
+ fake_start_span,
+ ):
+ client = SyncClientBuilder().build()
+ client.get(url).build().send()
+
+ (event,) = events
+ span = event["spans"][-1]
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO not in data
+ assert SPANDATA.CODE_NAMESPACE not in data
+ assert SPANDATA.CODE_FILEPATH not in data
+ assert SPANDATA.CODE_FUNCTION not in data
+
+
+def test_request_source_if_duration_over_threshold(
+ sentry_init, capture_events, server_port
+):
+ sentry_init(
+ integrations=[PyreqwestIntegration()],
+ traces_sample_rate=1.0,
+ enable_http_request_source=True,
+ http_request_source_threshold_ms=100,
+ )
+ events = capture_events()
+
+ url = f"http://localhost:{server_port}/hello"
+
+ with start_transaction(name="test_transaction"):
+
+ @contextmanager
+ def fake_start_span(*args, **kwargs):
+ with sentry_sdk.start_span(*args, **kwargs) as span:
+ pass
+ span.start_timestamp = datetime.datetime(2024, 1, 1, microsecond=0)
+ span.timestamp = datetime.datetime(2024, 1, 1, microsecond=100001)
+ yield span
+
+ with mock.patch(
+ "sentry_sdk.integrations.pyreqwest.start_span",
+ fake_start_span,
+ ):
+ client = SyncClientBuilder().build()
+ client.get(url).build().send()
+
+ (event,) = events
+ span = event["spans"][-1]
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO in data
+ assert SPANDATA.CODE_NAMESPACE in data
+ assert SPANDATA.CODE_FILEPATH in data
+ assert SPANDATA.CODE_FUNCTION in data
diff --git a/tests/integrations/quart/__init__.py b/tests/integrations/quart/__init__.py
index ea02dfb3a6..2bf976c50d 100644
--- a/tests/integrations/quart/__init__.py
+++ b/tests/integrations/quart/__init__.py
@@ -1,3 +1,3 @@
import pytest
-quart = pytest.importorskip("quart")
+pytest.importorskip("quart")
diff --git a/tests/integrations/quart/test_quart.py b/tests/integrations/quart/test_quart.py
index bda2c1013e..7c027455c0 100644
--- a/tests/integrations/quart/test_quart.py
+++ b/tests/integrations/quart/test_quart.py
@@ -1,35 +1,39 @@
+import importlib
import json
+import sys
import threading
+from unittest import mock
import pytest
-import pytest_asyncio
-
-quart = pytest.importorskip("quart")
-
-from quart import Quart, Response, abort, stream_with_context
-from quart.views import View
-
-from quart_auth import AuthManager, AuthUser, login_user
+import sentry_sdk
from sentry_sdk import (
set_tag,
- configure_scope,
capture_message,
capture_exception,
- last_event_id,
)
from sentry_sdk.integrations.logging import LoggingIntegration
import sentry_sdk.integrations.quart as quart_sentry
-auth_manager = AuthManager()
+def quart_app_factory():
+ # These imports are inlined because the `test_quart_flask_patch` testcase
+ # tests behavior that is triggered by importing a package before any Quart
+ # imports happen, so we can't have these on the module level
+ from quart import Quart
+
+ try:
+ from quart_auth import QuartAuth
+ auth_manager = QuartAuth()
+ except ImportError:
+ from quart_auth import AuthManager
+
+ auth_manager = AuthManager()
-@pytest_asyncio.fixture
-async def app():
app = Quart(__name__)
- app.debug = True
- app.config["TESTING"] = True
+ app.debug = False
+ app.config["TESTING"] = False
app.secret_key = "haha"
auth_manager.init_app(app)
@@ -70,8 +74,49 @@ def integration_enabled_params(request):
@pytest.mark.asyncio
-async def test_has_context(sentry_init, app, capture_events):
+@pytest.mark.forked
+@pytest.mark.skipif(
+ not importlib.util.find_spec("quart_flask_patch"),
+ reason="requires quart_flask_patch",
+)
+@pytest.mark.skipif(
+ sys.version_info >= (3, 14),
+ reason="quart_flask_patch not working on 3.14 (yet?)",
+)
+async def test_quart_flask_patch(sentry_init, capture_events, reset_integrations):
+ # This testcase is forked because `import quart_flask_patch` needs to run
+ # before anything else Quart-related is imported (since it monkeypatches
+ # some things) and we don't want this to affect other testcases.
+ #
+ # It's also important this testcase be run before any other testcase
+ # that uses `quart_app_factory`.
+ import quart_flask_patch # noqa: F401
+
+ app = quart_app_factory()
+ sentry_init(
+ integrations=[quart_sentry.QuartIntegration()],
+ )
+
+ @app.route("/")
+ async def index():
+ 1 / 0
+
+ events = capture_events()
+
+ client = app.test_client()
+ try:
+ await client.get("/")
+ except ZeroDivisionError:
+ pass
+
+ (event,) = events
+ assert event["exception"]["values"][0]["mechanism"]["type"] == "quart"
+
+
+@pytest.mark.asyncio
+async def test_has_context(sentry_init, capture_events):
sentry_init(integrations=[quart_sentry.QuartIntegration()])
+ app = quart_app_factory()
events = capture_events()
client = app.test_client()
@@ -96,7 +141,6 @@ async def test_has_context(sentry_init, app, capture_events):
)
async def test_transaction_style(
sentry_init,
- app,
capture_events,
url,
transaction_style,
@@ -108,6 +152,7 @@ async def test_transaction_style(
quart_sentry.QuartIntegration(transaction_style=transaction_style)
]
)
+ app = quart_app_factory()
events = capture_events()
client = app.test_client()
@@ -119,21 +164,14 @@ async def test_transaction_style(
@pytest.mark.asyncio
-@pytest.mark.parametrize("debug", (True, False))
-@pytest.mark.parametrize("testing", (True, False))
async def test_errors(
sentry_init,
capture_exceptions,
capture_events,
- app,
- debug,
- testing,
integration_enabled_params,
):
- sentry_init(debug=True, **integration_enabled_params)
-
- app.debug = debug
- app.testing = testing
+ sentry_init(**integration_enabled_params)
+ app = quart_app_factory()
@app.route("/")
async def index():
@@ -157,9 +195,10 @@ async def index():
@pytest.mark.asyncio
async def test_quart_auth_not_installed(
- sentry_init, app, capture_events, monkeypatch, integration_enabled_params
+ sentry_init, capture_events, monkeypatch, integration_enabled_params
):
sentry_init(**integration_enabled_params)
+ app = quart_app_factory()
monkeypatch.setattr(quart_sentry, "quart_auth", None)
@@ -174,9 +213,10 @@ async def test_quart_auth_not_installed(
@pytest.mark.asyncio
async def test_quart_auth_not_configured(
- sentry_init, app, capture_events, monkeypatch, integration_enabled_params
+ sentry_init, capture_events, monkeypatch, integration_enabled_params
):
sentry_init(**integration_enabled_params)
+ app = quart_app_factory()
assert quart_sentry.quart_auth
@@ -190,9 +230,10 @@ async def test_quart_auth_not_configured(
@pytest.mark.asyncio
async def test_quart_auth_partially_configured(
- sentry_init, app, capture_events, monkeypatch, integration_enabled_params
+ sentry_init, capture_events, monkeypatch, integration_enabled_params
):
sentry_init(**integration_enabled_params)
+ app = quart_app_factory()
events = capture_events()
@@ -209,13 +250,15 @@ async def test_quart_auth_partially_configured(
async def test_quart_auth_configured(
send_default_pii,
sentry_init,
- app,
user_id,
capture_events,
monkeypatch,
integration_enabled_params,
):
+ from quart_auth import AuthUser, login_user
+
sentry_init(send_default_pii=send_default_pii, **integration_enabled_params)
+ app = quart_app_factory()
@app.route("/login")
async def login():
@@ -246,10 +289,9 @@ async def login():
[quart_sentry.QuartIntegration(), LoggingIntegration(event_level="ERROR")],
],
)
-async def test_errors_not_reported_twice(
- sentry_init, integrations, capture_events, app
-):
+async def test_errors_not_reported_twice(sentry_init, integrations, capture_events):
sentry_init(integrations=integrations)
+ app = quart_app_factory()
@app.route("/")
async def index():
@@ -269,7 +311,7 @@ async def index():
@pytest.mark.asyncio
-async def test_logging(sentry_init, capture_events, app):
+async def test_logging(sentry_init, capture_events):
# ensure that Quart's logger magic doesn't break ours
sentry_init(
integrations=[
@@ -277,6 +319,7 @@ async def test_logging(sentry_init, capture_events, app):
LoggingIntegration(event_level="ERROR"),
]
)
+ app = quart_app_factory()
@app.route("/")
async def index():
@@ -293,13 +336,17 @@ async def index():
@pytest.mark.asyncio
-async def test_no_errors_without_request(app, sentry_init):
+async def test_no_errors_without_request(sentry_init):
sentry_init(integrations=[quart_sentry.QuartIntegration()])
+ app = quart_app_factory()
+
async with app.app_context():
capture_exception(ValueError())
-def test_cli_commands_raise(app):
+def test_cli_commands_raise():
+ app = quart_app_factory()
+
if not hasattr(app, "cli"):
pytest.skip("Too old quart version")
@@ -316,11 +363,9 @@ def foo():
@pytest.mark.asyncio
-async def test_500(sentry_init, capture_events, app):
+async def test_500(sentry_init):
sentry_init(integrations=[quart_sentry.QuartIntegration()])
-
- app.debug = False
- app.testing = False
+ app = quart_app_factory()
@app.route("/")
async def index():
@@ -328,25 +373,18 @@ async def index():
@app.errorhandler(500)
async def error_handler(err):
- return "Sentry error: %s" % last_event_id()
-
- events = capture_events()
+ return "Sentry error."
client = app.test_client()
response = await client.get("/")
- (event,) = events
- assert (await response.get_data(as_text=True)) == "Sentry error: %s" % event[
- "event_id"
- ]
+ assert (await response.get_data(as_text=True)) == "Sentry error."
@pytest.mark.asyncio
-async def test_error_in_errorhandler(sentry_init, capture_events, app):
+async def test_error_in_errorhandler(sentry_init, capture_events):
sentry_init(integrations=[quart_sentry.QuartIntegration()])
-
- app.debug = False
- app.testing = False
+ app = quart_app_factory()
@app.route("/")
async def index():
@@ -373,8 +411,11 @@ async def error_handler(err):
@pytest.mark.asyncio
-async def test_bad_request_not_captured(sentry_init, capture_events, app):
+async def test_bad_request_not_captured(sentry_init, capture_events):
+ from quart import abort
+
sentry_init(integrations=[quart_sentry.QuartIntegration()])
+ app = quart_app_factory()
events = capture_events()
@app.route("/")
@@ -389,22 +430,22 @@ async def index():
@pytest.mark.asyncio
-async def test_does_not_leak_scope(sentry_init, capture_events, app):
+async def test_does_not_leak_scope(sentry_init, capture_events):
+ from quart import Response, stream_with_context
+
sentry_init(integrations=[quart_sentry.QuartIntegration()])
+ app = quart_app_factory()
events = capture_events()
- with configure_scope() as scope:
- scope.set_tag("request_data", False)
+ sentry_sdk.get_isolation_scope().set_tag("request_data", False)
@app.route("/")
async def index():
- with configure_scope() as scope:
- scope.set_tag("request_data", True)
+ sentry_sdk.get_isolation_scope().set_tag("request_data", True)
async def generate():
for row in range(1000):
- with configure_scope() as scope:
- assert scope._tags["request_data"]
+ assert sentry_sdk.get_isolation_scope()._tags["request_data"]
yield str(row) + "\n"
@@ -416,14 +457,13 @@ async def generate():
str(row) + "\n" for row in range(1000)
)
assert not events
-
- with configure_scope() as scope:
- assert not scope._tags["request_data"]
+ assert not sentry_sdk.get_isolation_scope()._tags["request_data"]
@pytest.mark.asyncio
-async def test_scoped_test_client(sentry_init, app):
+async def test_scoped_test_client(sentry_init):
sentry_init(integrations=[quart_sentry.QuartIntegration()])
+ app = quart_app_factory()
@app.route("/")
async def index():
@@ -437,12 +477,13 @@ async def index():
@pytest.mark.asyncio
@pytest.mark.parametrize("exc_cls", [ZeroDivisionError, Exception])
async def test_errorhandler_for_exception_swallows_exception(
- sentry_init, app, capture_events, exc_cls
+ sentry_init, capture_events, exc_cls
):
# In contrast to error handlers for a status code, error
# handlers for exceptions can swallow the exception (this is
# just how the Quart signal works)
sentry_init(integrations=[quart_sentry.QuartIntegration()])
+ app = quart_app_factory()
events = capture_events()
@app.route("/")
@@ -461,8 +502,9 @@ async def zerodivision(e):
@pytest.mark.asyncio
-async def test_tracing_success(sentry_init, capture_events, app):
+async def test_tracing_success(sentry_init, capture_events):
sentry_init(traces_sample_rate=1.0, integrations=[quart_sentry.QuartIntegration()])
+ app = quart_app_factory()
@app.before_request
async def _():
@@ -494,8 +536,9 @@ async def hi_tx():
@pytest.mark.asyncio
-async def test_tracing_error(sentry_init, capture_events, app):
+async def test_tracing_error(sentry_init, capture_events):
sentry_init(traces_sample_rate=1.0, integrations=[quart_sentry.QuartIntegration()])
+ app = quart_app_factory()
events = capture_events()
@@ -518,8 +561,11 @@ async def error():
@pytest.mark.asyncio
-async def test_class_based_views(sentry_init, app, capture_events):
+async def test_class_based_views(sentry_init, capture_events):
+ from quart.views import View
+
sentry_init(integrations=[quart_sentry.QuartIntegration()])
+ app = quart_app_factory()
events = capture_events()
@app.route("/")
@@ -543,27 +589,61 @@ async def dispatch_request(self):
@pytest.mark.parametrize("endpoint", ["/sync/thread_ids", "/async/thread_ids"])
-async def test_active_thread_id(sentry_init, capture_envelopes, endpoint, app):
- sentry_init(
- traces_sample_rate=1.0,
- _experiments={"profiles_sample_rate": 1.0},
- )
+@pytest.mark.asyncio
+async def test_active_thread_id(
+ sentry_init, capture_envelopes, teardown_profiling, endpoint
+):
+ with mock.patch(
+ "sentry_sdk.profiler.transaction_profiler.PROFILE_MINIMUM_SAMPLES", 0
+ ):
+ sentry_init(
+ traces_sample_rate=1.0,
+ profiles_sample_rate=1.0,
+ )
+ app = quart_app_factory()
- envelopes = capture_envelopes()
+ envelopes = capture_envelopes()
- async with app.test_client() as client:
- response = await client.get(endpoint)
- assert response.status_code == 200
+ async with app.test_client() as client:
+ response = await client.get(endpoint)
+ assert response.status_code == 200
+
+ data = json.loads(await response.get_data(as_text=True))
- data = json.loads(response.content)
+ envelopes = [envelope for envelope in envelopes]
+ assert len(envelopes) == 1
- envelopes = [envelope for envelope in envelopes]
- assert len(envelopes) == 1
+ profiles = [item for item in envelopes[0].items if item.type == "profile"]
+ assert len(profiles) == 1, envelopes[0].items
- profiles = [item for item in envelopes[0].items if item.type == "profile"]
- assert len(profiles) == 1
+ for item in profiles:
+ transactions = item.payload.json["transactions"]
+ assert len(transactions) == 1
+ assert str(data["active"]) == transactions[0]["active_thread_id"]
- for profile in profiles:
- transactions = profile.payload.json["transactions"]
+ transactions = [
+ item for item in envelopes[0].items if item.type == "transaction"
+ ]
assert len(transactions) == 1
- assert str(data["active"]) == transactions[0]["active_thread_id"]
+
+ for item in transactions:
+ transaction = item.payload.json
+ trace_context = transaction["contexts"]["trace"]
+ assert str(data["active"]) == trace_context["data"]["thread.id"]
+
+
+@pytest.mark.asyncio
+async def test_span_origin(sentry_init, capture_events):
+ sentry_init(
+ integrations=[quart_sentry.QuartIntegration()],
+ traces_sample_rate=1.0,
+ )
+ app = quart_app_factory()
+ events = capture_events()
+
+ client = app.test_client()
+ await client.get("/message")
+
+ (_, event) = events
+
+ assert event["contexts"]["trace"]["origin"] == "auto.http.quart"
diff --git a/tests/integrations/ray/__init__.py b/tests/integrations/ray/__init__.py
new file mode 100644
index 0000000000..92f6d93906
--- /dev/null
+++ b/tests/integrations/ray/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("ray")
diff --git a/tests/integrations/ray/test_ray.py b/tests/integrations/ray/test_ray.py
new file mode 100644
index 0000000000..be7ebc9d05
--- /dev/null
+++ b/tests/integrations/ray/test_ray.py
@@ -0,0 +1,292 @@
+import json
+import os
+import pytest
+import shutil
+import uuid
+
+import ray
+
+import sentry_sdk
+from sentry_sdk.envelope import Envelope
+from sentry_sdk.integrations.ray import RayIntegration
+from tests.conftest import TestTransport
+
+
+@pytest.fixture(autouse=True)
+def shutdown_ray(tmpdir):
+ yield
+ ray.shutdown()
+
+
+class RayTestTransport(TestTransport):
+ def __init__(self):
+ self.envelopes = []
+ super().__init__()
+
+ def capture_envelope(self, envelope: Envelope) -> None:
+ self.envelopes.append(envelope)
+
+
+class RayLoggingTransport(TestTransport):
+ def capture_envelope(self, envelope: Envelope) -> None:
+ print(envelope.serialize().decode("utf-8", "replace"))
+
+
+def setup_sentry_with_logging_transport():
+ setup_sentry(transport=RayLoggingTransport())
+
+
+def setup_sentry(transport=None):
+ sentry_sdk.init(
+ integrations=[RayIntegration()],
+ transport=RayTestTransport() if transport is None else transport,
+ traces_sample_rate=1.0,
+ )
+
+
+def read_error_from_log(job_id, ray_temp_dir):
+ # Find the actual session directory that Ray created
+ session_dirs = [d for d in os.listdir(ray_temp_dir) if d.startswith("session_")]
+ if not session_dirs:
+ raise FileNotFoundError(f"No session directory found in {ray_temp_dir}")
+
+ session_dir = os.path.join(ray_temp_dir, session_dirs[0])
+ log_dir = os.path.join(session_dir, "logs")
+
+ if not os.path.exists(log_dir):
+ raise FileNotFoundError(f"No logs directory found at {log_dir}")
+
+ log_file = [
+ f
+ for f in os.listdir(log_dir)
+ if "worker" in f and job_id in f and f.endswith(".out")
+ ][0]
+
+ with open(os.path.join(log_dir, log_file), "r") as file:
+ lines = file.readlines()
+
+ try:
+ # parse error object from log line
+ error = json.loads(lines[4][:-1])
+ except IndexError:
+ error = None
+
+ return error
+
+
+def example_task():
+ with sentry_sdk.start_span(op="task", name="example task step"):
+ ...
+
+ return sentry_sdk.get_client().transport.envelopes
+
+
+# RayIntegration must leave variadic keyword arguments at the end
+def example_task_with_kwargs(**kwargs):
+ with sentry_sdk.start_span(op="task", name="example task step"):
+ ...
+
+ return sentry_sdk.get_client().transport.envelopes
+
+
+@pytest.mark.parametrize(
+ "task_options", [{}, {"num_cpus": 0, "memory": 1024 * 1024 * 10}]
+)
+@pytest.mark.parametrize(
+ "task",
+ [example_task, example_task_with_kwargs],
+)
+def test_tracing_in_ray_tasks(task_options, task):
+ setup_sentry()
+
+ ray.init(
+ runtime_env={
+ "worker_process_setup_hook": setup_sentry,
+ "working_dir": "./",
+ }
+ )
+
+ # Setup ray task, calling decorator directly instead of @,
+ # to accommodate for test parametrization
+ if task_options:
+ example_task = ray.remote(**task_options)(task)
+ else:
+ example_task = ray.remote(task)
+
+ # Function name shouldn't be overwritten by Sentry wrapper
+ assert (
+ example_task._function_name
+ == f"tests.integrations.ray.test_ray.{task.__name__}"
+ )
+
+ with sentry_sdk.start_transaction(op="task", name="ray test transaction"):
+ worker_envelopes = ray.get(example_task.remote())
+
+ client_envelope = sentry_sdk.get_client().transport.envelopes[0]
+ client_transaction = client_envelope.get_transaction_event()
+ assert client_transaction["transaction"] == "ray test transaction"
+ assert client_transaction["transaction_info"] == {"source": "custom"}
+
+ worker_envelope = worker_envelopes[0]
+ worker_transaction = worker_envelope.get_transaction_event()
+ assert (
+ worker_transaction["transaction"]
+ == f"tests.integrations.ray.test_ray.{task.__name__}"
+ )
+ assert worker_transaction["transaction_info"] == {"source": "task"}
+
+ (span,) = client_transaction["spans"]
+ assert span["op"] == "queue.submit.ray"
+ assert span["origin"] == "auto.queue.ray"
+ assert span["description"] == f"tests.integrations.ray.test_ray.{task.__name__}"
+ assert span["parent_span_id"] == client_transaction["contexts"]["trace"]["span_id"]
+ assert span["trace_id"] == client_transaction["contexts"]["trace"]["trace_id"]
+
+ (span,) = worker_transaction["spans"]
+ assert span["op"] == "task"
+ assert span["origin"] == "manual"
+ assert span["description"] == "example task step"
+ assert span["parent_span_id"] == worker_transaction["contexts"]["trace"]["span_id"]
+ assert span["trace_id"] == worker_transaction["contexts"]["trace"]["trace_id"]
+
+ assert (
+ client_transaction["contexts"]["trace"]["trace_id"]
+ == worker_transaction["contexts"]["trace"]["trace_id"]
+ )
+
+
+def test_errors_in_ray_tasks():
+ setup_sentry_with_logging_transport()
+
+ ray_temp_dir = os.path.join("/tmp", f"ray_test_{uuid.uuid4().hex[:8]}")
+ os.makedirs(ray_temp_dir, exist_ok=True)
+
+ try:
+ ray.init(
+ runtime_env={
+ "worker_process_setup_hook": setup_sentry_with_logging_transport,
+ "working_dir": "./",
+ },
+ _temp_dir=ray_temp_dir,
+ )
+
+ # Setup ray task
+ @ray.remote
+ def example_task():
+ 1 / 0
+
+ with sentry_sdk.start_transaction(op="task", name="ray test transaction"):
+ with pytest.raises(ZeroDivisionError):
+ future = example_task.remote()
+ ray.get(future)
+
+ job_id = future.job_id().hex()
+ error = read_error_from_log(job_id, ray_temp_dir)
+
+ assert error["level"] == "error"
+ assert (
+ error["transaction"]
+ == "tests.integrations.ray.test_ray.test_errors_in_ray_tasks..example_task"
+ )
+ assert error["exception"]["values"][0]["mechanism"]["type"] == "ray"
+ assert not error["exception"]["values"][0]["mechanism"]["handled"]
+
+ finally:
+ if os.path.exists(ray_temp_dir):
+ shutil.rmtree(ray_temp_dir, ignore_errors=True)
+
+
+# Arbitrary keyword argument to test all decorator paths
+@pytest.mark.parametrize("remote_kwargs", [{}, {"namespace": "actors"}])
+def test_tracing_in_ray_actors(remote_kwargs):
+ setup_sentry()
+
+ ray.init(
+ runtime_env={
+ "worker_process_setup_hook": setup_sentry,
+ "working_dir": "./",
+ }
+ )
+
+ # Setup ray actor
+ if remote_kwargs:
+
+ @ray.remote(**remote_kwargs)
+ class Counter:
+ def __init__(self):
+ self.n = 0
+
+ def increment(self):
+ with sentry_sdk.start_span(op="task", name="example actor execution"):
+ self.n += 1
+
+ return sentry_sdk.get_client().transport.envelopes
+ else:
+
+ @ray.remote
+ class Counter:
+ def __init__(self):
+ self.n = 0
+
+ def increment(self):
+ with sentry_sdk.start_span(op="task", name="example actor execution"):
+ self.n += 1
+
+ return sentry_sdk.get_client().transport.envelopes
+
+ with sentry_sdk.start_transaction(op="task", name="ray test transaction"):
+ counter = Counter.remote()
+ worker_envelopes = ray.get(counter.increment.remote())
+
+ client_envelope = sentry_sdk.get_client().transport.envelopes[0]
+ client_transaction = client_envelope.get_transaction_event()
+
+ # Spans for submitting the actor task are not created (actors are not supported yet)
+ assert client_transaction["spans"] == []
+
+ # Transaction are not yet created when executing ray actors (actors are not supported yet)
+ assert worker_envelopes == []
+
+
+def test_errors_in_ray_actors():
+ setup_sentry_with_logging_transport()
+
+ ray_temp_dir = os.path.join("/tmp", f"ray_test_{uuid.uuid4().hex[:8]}")
+ os.makedirs(ray_temp_dir, exist_ok=True)
+
+ try:
+ ray.init(
+ runtime_env={
+ "worker_process_setup_hook": setup_sentry_with_logging_transport,
+ "working_dir": "./",
+ },
+ _temp_dir=ray_temp_dir,
+ )
+
+ # Setup ray actor
+ @ray.remote
+ class Counter:
+ def __init__(self):
+ self.n = 0
+
+ def increment(self):
+ with sentry_sdk.start_span(op="task", name="example actor execution"):
+ 1 / 0
+
+ return sentry_sdk.get_client().transport.envelopes
+
+ with sentry_sdk.start_transaction(op="task", name="ray test transaction"):
+ with pytest.raises(ZeroDivisionError):
+ counter = Counter.remote()
+ future = counter.increment.remote()
+ ray.get(future)
+
+ job_id = future.job_id().hex()
+ error = read_error_from_log(job_id, ray_temp_dir)
+
+ # We do not capture errors in ray actors yet
+ assert error is None
+
+ finally:
+ if os.path.exists(ray_temp_dir):
+ shutil.rmtree(ray_temp_dir, ignore_errors=True)
diff --git a/tests/integrations/redis/asyncio/__init__.py b/tests/integrations/redis/asyncio/__init__.py
new file mode 100644
index 0000000000..bd93246a9a
--- /dev/null
+++ b/tests/integrations/redis/asyncio/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("fakeredis.aioredis")
diff --git a/tests/integrations/redis/asyncio/test_redis_asyncio.py b/tests/integrations/redis/asyncio/test_redis_asyncio.py
new file mode 100644
index 0000000000..17130b337b
--- /dev/null
+++ b/tests/integrations/redis/asyncio/test_redis_asyncio.py
@@ -0,0 +1,112 @@
+import pytest
+
+from sentry_sdk import capture_message, start_transaction
+from sentry_sdk.consts import SPANDATA
+from sentry_sdk.integrations.redis import RedisIntegration
+from tests.conftest import ApproxDict
+
+from fakeredis.aioredis import FakeRedis
+
+
+@pytest.mark.asyncio
+async def test_async_basic(sentry_init, capture_events):
+ sentry_init(integrations=[RedisIntegration()])
+ events = capture_events()
+
+ connection = FakeRedis()
+
+ await connection.get("foobar")
+ capture_message("hi")
+
+ (event,) = events
+ (crumb,) = event["breadcrumbs"]["values"]
+
+ assert crumb == {
+ "category": "redis",
+ "message": "GET 'foobar'",
+ "data": {
+ "db.operation": "GET",
+ "redis.key": "foobar",
+ "redis.command": "GET",
+ "redis.is_cluster": False,
+ },
+ "timestamp": crumb["timestamp"],
+ "type": "redis",
+ }
+
+
+@pytest.mark.parametrize(
+ "is_transaction, send_default_pii, expected_first_ten",
+ [
+ (False, False, ["GET 'foo'", "SET 'bar' [Filtered]", "SET 'baz' [Filtered]"]),
+ (True, True, ["GET 'foo'", "SET 'bar' 1", "SET 'baz' 2"]),
+ ],
+)
+@pytest.mark.asyncio
+async def test_async_redis_pipeline(
+ sentry_init, capture_events, is_transaction, send_default_pii, expected_first_ten
+):
+ sentry_init(
+ integrations=[RedisIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ connection = FakeRedis()
+ with start_transaction():
+ pipeline = connection.pipeline(transaction=is_transaction)
+ pipeline.get("foo")
+ pipeline.set("bar", 1)
+ pipeline.set("baz", 2)
+ await pipeline.execute()
+
+ (event,) = events
+ (span,) = event["spans"]
+ assert span["op"] == "db.redis"
+ assert span["description"] == "redis.pipeline.execute"
+ assert span["data"] == ApproxDict(
+ {
+ "redis.commands": {
+ "count": 3,
+ "first_ten": expected_first_ten,
+ },
+ SPANDATA.DB_SYSTEM: "redis",
+ SPANDATA.DB_NAME: "0",
+ SPANDATA.SERVER_ADDRESS: connection.connection_pool.connection_kwargs.get(
+ "host"
+ ),
+ SPANDATA.SERVER_PORT: 6379,
+ }
+ )
+ assert span["tags"] == {
+ "redis.transaction": is_transaction,
+ "redis.is_cluster": False,
+ }
+
+
+@pytest.mark.asyncio
+async def test_async_span_origin(sentry_init, capture_events):
+ sentry_init(
+ integrations=[RedisIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ connection = FakeRedis()
+ with start_transaction(name="custom_transaction"):
+ # default case
+ await connection.set("somekey", "somevalue")
+
+ # pipeline
+ pipeline = connection.pipeline(transaction=False)
+ pipeline.get("somekey")
+ pipeline.set("anotherkey", 1)
+ await pipeline.execute()
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+
+ for span in event["spans"]:
+ assert span["origin"] == "auto.db.redis"
diff --git a/tests/integrations/redis/cluster/__init__.py b/tests/integrations/redis/cluster/__init__.py
new file mode 100644
index 0000000000..008b24295f
--- /dev/null
+++ b/tests/integrations/redis/cluster/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("redis.cluster")
diff --git a/tests/integrations/redis/cluster/test_redis_cluster.py b/tests/integrations/redis/cluster/test_redis_cluster.py
new file mode 100644
index 0000000000..83d1b45cc9
--- /dev/null
+++ b/tests/integrations/redis/cluster/test_redis_cluster.py
@@ -0,0 +1,172 @@
+import pytest
+from sentry_sdk import capture_message
+from sentry_sdk.consts import SPANDATA
+from sentry_sdk.api import start_transaction
+from sentry_sdk.integrations.redis import RedisIntegration
+from tests.conftest import ApproxDict
+
+import redis
+
+
+@pytest.fixture(autouse=True)
+def monkeypatch_rediscluster_class(reset_integrations):
+ pipeline_cls = redis.cluster.ClusterPipeline
+ redis.cluster.NodesManager.initialize = lambda *_, **__: None
+ redis.RedisCluster.command = lambda *_: []
+ redis.RedisCluster.pipeline = lambda *_, **__: pipeline_cls(None, None)
+ redis.RedisCluster.get_default_node = lambda *_, **__: redis.cluster.ClusterNode(
+ "localhost", 6379
+ )
+ pipeline_cls.execute = lambda *_, **__: None
+ redis.RedisCluster.execute_command = lambda *_, **__: []
+
+
+def test_rediscluster_breadcrumb(sentry_init, capture_events):
+ sentry_init(integrations=[RedisIntegration()])
+ events = capture_events()
+
+ rc = redis.RedisCluster(host="localhost", port=6379)
+ rc.get("foobar")
+ capture_message("hi")
+
+ (event,) = events
+ crumbs = event["breadcrumbs"]["values"]
+
+ # on initializing a RedisCluster, a COMMAND call is made - this is not important for the test
+ # but must be accounted for
+ assert len(crumbs) in (1, 2)
+ assert len(crumbs) == 1 or crumbs[0]["message"] == "COMMAND"
+
+ crumb = crumbs[-1]
+
+ assert crumb == {
+ "category": "redis",
+ "message": "GET 'foobar'",
+ "data": {
+ "db.operation": "GET",
+ "redis.key": "foobar",
+ "redis.command": "GET",
+ "redis.is_cluster": True,
+ },
+ "timestamp": crumb["timestamp"],
+ "type": "redis",
+ }
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, description",
+ [
+ (False, "SET 'bar' [Filtered]"),
+ (True, "SET 'bar' 1"),
+ ],
+)
+def test_rediscluster_basic(sentry_init, capture_events, send_default_pii, description):
+ sentry_init(
+ integrations=[RedisIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ with start_transaction():
+ rc = redis.RedisCluster(host="localhost", port=6379)
+ rc.set("bar", 1)
+
+ (event,) = events
+ spans = event["spans"]
+
+ # on initializing a RedisCluster, a COMMAND call is made - this is not important for the test
+ # but must be accounted for
+ assert len(spans) in (1, 2)
+ assert len(spans) == 1 or spans[0]["description"] == "COMMAND"
+
+ span = spans[-1]
+ assert span["op"] == "db.redis"
+ assert span["description"] == description
+ assert span["data"] == ApproxDict(
+ {
+ SPANDATA.DB_SYSTEM: "redis",
+ # ClusterNode converts localhost to 127.0.0.1
+ SPANDATA.SERVER_ADDRESS: "127.0.0.1",
+ SPANDATA.SERVER_PORT: 6379,
+ }
+ )
+ assert span["tags"] == {
+ "db.operation": "SET",
+ "redis.command": "SET",
+ "redis.is_cluster": True,
+ "redis.key": "bar",
+ }
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, expected_first_ten",
+ [
+ (False, ["GET 'foo'", "SET 'bar' [Filtered]", "SET 'baz' [Filtered]"]),
+ (True, ["GET 'foo'", "SET 'bar' 1", "SET 'baz' 2"]),
+ ],
+)
+def test_rediscluster_pipeline(
+ sentry_init, capture_events, send_default_pii, expected_first_ten
+):
+ sentry_init(
+ integrations=[RedisIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ rc = redis.RedisCluster(host="localhost", port=6379)
+ with start_transaction():
+ pipeline = rc.pipeline()
+ pipeline.get("foo")
+ pipeline.set("bar", 1)
+ pipeline.set("baz", 2)
+ pipeline.execute()
+
+ (event,) = events
+ (span,) = event["spans"]
+ assert span["op"] == "db.redis"
+ assert span["description"] == "redis.pipeline.execute"
+ assert span["data"] == ApproxDict(
+ {
+ "redis.commands": {
+ "count": 3,
+ "first_ten": expected_first_ten,
+ },
+ SPANDATA.DB_SYSTEM: "redis",
+ # ClusterNode converts localhost to 127.0.0.1
+ SPANDATA.SERVER_ADDRESS: "127.0.0.1",
+ SPANDATA.SERVER_PORT: 6379,
+ }
+ )
+ assert span["tags"] == {
+ "redis.transaction": False, # For Cluster, this is always False
+ "redis.is_cluster": True,
+ }
+
+
+def test_rediscluster_span_origin(sentry_init, capture_events):
+ sentry_init(
+ integrations=[RedisIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ rc = redis.RedisCluster(host="localhost", port=6379)
+ with start_transaction(name="custom_transaction"):
+ # default case
+ rc.set("somekey", "somevalue")
+
+ # pipeline
+ pipeline = rc.pipeline(transaction=False)
+ pipeline.get("somekey")
+ pipeline.set("anotherkey", 1)
+ pipeline.execute()
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+
+ for span in event["spans"]:
+ assert span["origin"] == "auto.db.redis"
diff --git a/tests/integrations/redis/cluster_asyncio/__init__.py b/tests/integrations/redis/cluster_asyncio/__init__.py
new file mode 100644
index 0000000000..663979a4e2
--- /dev/null
+++ b/tests/integrations/redis/cluster_asyncio/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("redis.asyncio.cluster")
diff --git a/tests/integrations/redis/cluster_asyncio/test_redis_cluster_asyncio.py b/tests/integrations/redis/cluster_asyncio/test_redis_cluster_asyncio.py
new file mode 100644
index 0000000000..993a2962ca
--- /dev/null
+++ b/tests/integrations/redis/cluster_asyncio/test_redis_cluster_asyncio.py
@@ -0,0 +1,176 @@
+import pytest
+
+from sentry_sdk import capture_message, start_transaction
+from sentry_sdk.consts import SPANDATA
+from sentry_sdk.integrations.redis import RedisIntegration
+from tests.conftest import ApproxDict
+
+from redis.asyncio import cluster
+
+
+async def fake_initialize(*_, **__):
+ return None
+
+
+async def fake_execute_command(*_, **__):
+ return []
+
+
+async def fake_execute(*_, **__):
+ return None
+
+
+@pytest.fixture(autouse=True)
+def monkeypatch_rediscluster_asyncio_class(reset_integrations):
+ pipeline_cls = cluster.ClusterPipeline
+ cluster.NodesManager.initialize = fake_initialize
+ cluster.RedisCluster.get_default_node = lambda *_, **__: cluster.ClusterNode(
+ "localhost", 6379
+ )
+ cluster.RedisCluster.pipeline = lambda self, *_, **__: pipeline_cls(self)
+ pipeline_cls.execute = fake_execute
+ cluster.RedisCluster.execute_command = fake_execute_command
+
+
+@pytest.mark.asyncio
+async def test_async_breadcrumb(sentry_init, capture_events):
+ sentry_init(integrations=[RedisIntegration()])
+ events = capture_events()
+
+ connection = cluster.RedisCluster(host="localhost", port=6379)
+
+ await connection.get("foobar")
+ capture_message("hi")
+
+ (event,) = events
+ (crumb,) = event["breadcrumbs"]["values"]
+
+ assert crumb == {
+ "category": "redis",
+ "message": "GET 'foobar'",
+ "data": ApproxDict(
+ {
+ "db.operation": "GET",
+ "redis.key": "foobar",
+ "redis.command": "GET",
+ "redis.is_cluster": True,
+ }
+ ),
+ "timestamp": crumb["timestamp"],
+ "type": "redis",
+ }
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, description",
+ [
+ (False, "SET 'bar' [Filtered]"),
+ (True, "SET 'bar' 1"),
+ ],
+)
+@pytest.mark.asyncio
+async def test_async_basic(sentry_init, capture_events, send_default_pii, description):
+ sentry_init(
+ integrations=[RedisIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ connection = cluster.RedisCluster(host="localhost", port=6379)
+ with start_transaction():
+ await connection.set("bar", 1)
+
+ (event,) = events
+ (span,) = event["spans"]
+ assert span["op"] == "db.redis"
+ assert span["description"] == description
+ assert span["data"] == ApproxDict(
+ {
+ SPANDATA.DB_SYSTEM: "redis",
+ # ClusterNode converts localhost to 127.0.0.1
+ SPANDATA.SERVER_ADDRESS: "127.0.0.1",
+ SPANDATA.SERVER_PORT: 6379,
+ }
+ )
+ assert span["tags"] == {
+ "redis.is_cluster": True,
+ "db.operation": "SET",
+ "redis.command": "SET",
+ "redis.key": "bar",
+ }
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, expected_first_ten",
+ [
+ (False, ["GET 'foo'", "SET 'bar' [Filtered]", "SET 'baz' [Filtered]"]),
+ (True, ["GET 'foo'", "SET 'bar' 1", "SET 'baz' 2"]),
+ ],
+)
+@pytest.mark.asyncio
+async def test_async_redis_pipeline(
+ sentry_init, capture_events, send_default_pii, expected_first_ten
+):
+ sentry_init(
+ integrations=[RedisIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ connection = cluster.RedisCluster(host="localhost", port=6379)
+ with start_transaction():
+ pipeline = connection.pipeline()
+ pipeline.get("foo")
+ pipeline.set("bar", 1)
+ pipeline.set("baz", 2)
+ await pipeline.execute()
+
+ (event,) = events
+ (span,) = event["spans"]
+ assert span["op"] == "db.redis"
+ assert span["description"] == "redis.pipeline.execute"
+ assert span["data"] == ApproxDict(
+ {
+ "redis.commands": {
+ "count": 3,
+ "first_ten": expected_first_ten,
+ },
+ SPANDATA.DB_SYSTEM: "redis",
+ # ClusterNode converts localhost to 127.0.0.1
+ SPANDATA.SERVER_ADDRESS: "127.0.0.1",
+ SPANDATA.SERVER_PORT: 6379,
+ }
+ )
+ assert span["tags"] == {
+ "redis.transaction": False,
+ "redis.is_cluster": True,
+ }
+
+
+@pytest.mark.asyncio
+async def test_async_span_origin(sentry_init, capture_events):
+ sentry_init(
+ integrations=[RedisIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ connection = cluster.RedisCluster(host="localhost", port=6379)
+ with start_transaction(name="custom_transaction"):
+ # default case
+ await connection.set("somekey", "somevalue")
+
+ # pipeline
+ pipeline = connection.pipeline(transaction=False)
+ pipeline.get("somekey")
+ pipeline.set("anotherkey", 1)
+ await pipeline.execute()
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+
+ for span in event["spans"]:
+ assert span["origin"] == "auto.db.redis"
diff --git a/tests/integrations/redis/test_redis.py b/tests/integrations/redis/test_redis.py
index 9a6d066e03..84c5699d14 100644
--- a/tests/integrations/redis/test_redis.py
+++ b/tests/integrations/redis/test_redis.py
@@ -1,8 +1,19 @@
+from unittest import mock
+
+import pytest
+from fakeredis import FakeStrictRedis
+
from sentry_sdk import capture_message, start_transaction
+from sentry_sdk.consts import SPANDATA
from sentry_sdk.integrations.redis import RedisIntegration
-from fakeredis import FakeStrictRedis
-import pytest
+
+MOCK_CONNECTION_POOL = mock.MagicMock()
+MOCK_CONNECTION_POOL.connection_kwargs = {
+ "host": "localhost",
+ "port": 63791,
+ "db": 1,
+}
def test_basic(sentry_init, capture_events):
@@ -24,20 +35,32 @@ def test_basic(sentry_init, capture_events):
"redis.key": "foobar",
"redis.command": "GET",
"redis.is_cluster": False,
+ "db.operation": "GET",
},
"timestamp": crumb["timestamp"],
"type": "redis",
}
-@pytest.mark.parametrize("is_transaction", [False, True])
-def test_redis_pipeline(sentry_init, capture_events, is_transaction):
- sentry_init(integrations=[RedisIntegration()], traces_sample_rate=1.0)
+@pytest.mark.parametrize(
+ "is_transaction, send_default_pii, expected_first_ten",
+ [
+ (False, False, ["GET 'foo'", "SET 'bar' [Filtered]", "SET 'baz' [Filtered]"]),
+ (True, True, ["GET 'foo'", "SET 'bar' 1", "SET 'baz' 2"]),
+ ],
+)
+def test_redis_pipeline(
+ sentry_init, capture_events, is_transaction, send_default_pii, expected_first_ten
+):
+ sentry_init(
+ integrations=[RedisIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
events = capture_events()
connection = FakeStrictRedis()
with start_transaction():
-
pipeline = connection.pipeline(transaction=is_transaction)
pipeline.get("foo")
pipeline.set("bar", 1)
@@ -48,13 +71,251 @@ def test_redis_pipeline(sentry_init, capture_events, is_transaction):
(span,) = event["spans"]
assert span["op"] == "db.redis"
assert span["description"] == "redis.pipeline.execute"
- assert span["data"] == {
- "redis.commands": {
- "count": 3,
- "first_ten": ["GET 'foo'", "SET 'bar' 1", "SET 'baz' 2"],
- }
+ assert span["data"][SPANDATA.DB_SYSTEM] == "redis"
+ assert span["data"]["redis.commands"] == {
+ "count": 3,
+ "first_ten": expected_first_ten,
}
assert span["tags"] == {
"redis.transaction": is_transaction,
"redis.is_cluster": False,
}
+
+
+def test_sensitive_data(sentry_init, capture_events):
+ # fakeredis does not support the AUTH command, so we need to mock it
+ with mock.patch(
+ "sentry_sdk.integrations.redis.utils._COMMANDS_INCLUDING_SENSITIVE_DATA",
+ ["get"],
+ ):
+ sentry_init(
+ integrations=[RedisIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ connection = FakeStrictRedis()
+ with start_transaction():
+ connection.get(
+ "this is super secret"
+ ) # because fakeredis does not support AUTH we use GET instead
+
+ (event,) = events
+ spans = event["spans"]
+ assert spans[0]["op"] == "db.redis"
+ assert spans[0]["description"] == "GET [Filtered]"
+
+
+def test_pii_data_redacted(sentry_init, capture_events):
+ sentry_init(
+ integrations=[RedisIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ connection = FakeStrictRedis()
+ with start_transaction():
+ connection.set("somekey1", "my secret string1")
+ connection.set("somekey2", "my secret string2")
+ connection.get("somekey2")
+ connection.delete("somekey1", "somekey2")
+
+ (event,) = events
+ spans = event["spans"]
+ assert spans[0]["op"] == "db.redis"
+ assert spans[0]["description"] == "SET 'somekey1' [Filtered]"
+ assert spans[1]["description"] == "SET 'somekey2' [Filtered]"
+ assert spans[2]["description"] == "GET 'somekey2'"
+ assert spans[3]["description"] == "DEL 'somekey1' [Filtered]"
+
+
+def test_pii_data_sent(sentry_init, capture_events):
+ sentry_init(
+ integrations=[RedisIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ connection = FakeStrictRedis()
+ with start_transaction():
+ connection.set("somekey1", "my secret string1")
+ connection.set("somekey2", "my secret string2")
+ connection.get("somekey2")
+ connection.delete("somekey1", "somekey2")
+
+ (event,) = events
+ spans = event["spans"]
+ assert spans[0]["op"] == "db.redis"
+ assert spans[0]["description"] == "SET 'somekey1' 'my secret string1'"
+ assert spans[1]["description"] == "SET 'somekey2' 'my secret string2'"
+ assert spans[2]["description"] == "GET 'somekey2'"
+ assert spans[3]["description"] == "DEL 'somekey1' 'somekey2'"
+
+
+def test_no_data_truncation_by_default(sentry_init, capture_events):
+ sentry_init(
+ integrations=[RedisIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ connection = FakeStrictRedis()
+ with start_transaction():
+ long_string = "a" * 100000
+ connection.set("somekey1", long_string)
+ short_string = "b" * 10
+ connection.set("somekey2", short_string)
+
+ (event,) = events
+ spans = event["spans"]
+ assert spans[0]["op"] == "db.redis"
+ assert spans[0]["description"] == f"SET 'somekey1' '{long_string}'"
+ assert spans[1]["description"] == f"SET 'somekey2' '{short_string}'"
+
+
+def test_data_truncation_custom(sentry_init, capture_events):
+ sentry_init(
+ integrations=[RedisIntegration(max_data_size=30)],
+ traces_sample_rate=1.0,
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ connection = FakeStrictRedis()
+ with start_transaction():
+ long_string = "a" * 100000
+ connection.set("somekey1", long_string)
+ short_string = "b" * 10
+ connection.set("somekey2", short_string)
+
+ (event,) = events
+ spans = event["spans"]
+ assert spans[0]["op"] == "db.redis"
+ assert spans[0]["description"] == "SET 'somekey1' '%s..." % (
+ long_string[: 30 - len("...") - len("SET 'somekey1' '")],
+ )
+ assert spans[1]["description"] == "SET 'somekey2' '%s'" % (short_string,)
+
+
+def test_breadcrumbs(sentry_init, capture_events):
+ sentry_init(
+ integrations=[RedisIntegration(max_data_size=30)],
+ send_default_pii=True,
+ )
+ events = capture_events()
+
+ connection = FakeStrictRedis()
+
+ long_string = "a" * 100000
+ connection.set("somekey1", long_string)
+ short_string = "b" * 10
+ connection.set("somekey2", short_string)
+
+ capture_message("hi")
+
+ (event,) = events
+ crumbs = event["breadcrumbs"]["values"]
+
+ assert crumbs[0] == {
+ "message": "SET 'somekey1' 'aaaaaaaaaaa...",
+ "type": "redis",
+ "category": "redis",
+ "data": {
+ "db.operation": "SET",
+ "redis.is_cluster": False,
+ "redis.command": "SET",
+ "redis.key": "somekey1",
+ },
+ "timestamp": crumbs[0]["timestamp"],
+ }
+ assert crumbs[1] == {
+ "message": "SET 'somekey2' 'bbbbbbbbbb'",
+ "type": "redis",
+ "category": "redis",
+ "data": {
+ "db.operation": "SET",
+ "redis.is_cluster": False,
+ "redis.command": "SET",
+ "redis.key": "somekey2",
+ },
+ "timestamp": crumbs[1]["timestamp"],
+ }
+
+
+def test_db_connection_attributes_client(sentry_init, capture_events):
+ sentry_init(
+ traces_sample_rate=1.0,
+ integrations=[RedisIntegration()],
+ )
+ events = capture_events()
+
+ with start_transaction():
+ connection = FakeStrictRedis(connection_pool=MOCK_CONNECTION_POOL)
+ connection.get("foobar")
+
+ (event,) = events
+ (span,) = event["spans"]
+
+ assert span["op"] == "db.redis"
+ assert span["description"] == "GET 'foobar'"
+ assert span["data"][SPANDATA.DB_SYSTEM] == "redis"
+ assert span["data"][SPANDATA.DB_DRIVER_NAME] == "redis-py"
+ assert span["data"][SPANDATA.DB_NAME] == "1"
+ assert span["data"][SPANDATA.SERVER_ADDRESS] == "localhost"
+ assert span["data"][SPANDATA.SERVER_PORT] == 63791
+
+
+def test_db_connection_attributes_pipeline(sentry_init, capture_events):
+ sentry_init(
+ traces_sample_rate=1.0,
+ integrations=[RedisIntegration()],
+ )
+ events = capture_events()
+
+ with start_transaction():
+ connection = FakeStrictRedis(connection_pool=MOCK_CONNECTION_POOL)
+ pipeline = connection.pipeline(transaction=False)
+ pipeline.get("foo")
+ pipeline.set("bar", 1)
+ pipeline.set("baz", 2)
+ pipeline.execute()
+
+ (event,) = events
+ (span,) = event["spans"]
+
+ assert span["op"] == "db.redis"
+ assert span["description"] == "redis.pipeline.execute"
+ assert span["data"][SPANDATA.DB_SYSTEM] == "redis"
+ assert span["data"][SPANDATA.DB_DRIVER_NAME] == "redis-py"
+ assert span["data"][SPANDATA.DB_NAME] == "1"
+ assert span["data"][SPANDATA.SERVER_ADDRESS] == "localhost"
+ assert span["data"][SPANDATA.SERVER_PORT] == 63791
+
+
+def test_span_origin(sentry_init, capture_events):
+ sentry_init(
+ integrations=[RedisIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ connection = FakeStrictRedis()
+ with start_transaction(name="custom_transaction"):
+ # default case
+ connection.set("somekey", "somevalue")
+
+ # pipeline
+ pipeline = connection.pipeline(transaction=False)
+ pipeline.get("somekey")
+ pipeline.set("anotherkey", 1)
+ pipeline.execute()
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+
+ for span in event["spans"]:
+ assert span["origin"] == "auto.db.redis"
diff --git a/tests/integrations/redis/test_redis_cache_module.py b/tests/integrations/redis/test_redis_cache_module.py
new file mode 100644
index 0000000000..f118aa53f5
--- /dev/null
+++ b/tests/integrations/redis/test_redis_cache_module.py
@@ -0,0 +1,318 @@
+import uuid
+
+import pytest
+
+import fakeredis
+from fakeredis import FakeStrictRedis
+
+from sentry_sdk.integrations.redis import RedisIntegration
+from sentry_sdk.integrations.redis.utils import _get_safe_key, _key_as_string
+from sentry_sdk.utils import parse_version
+import sentry_sdk
+
+
+FAKEREDIS_VERSION = parse_version(fakeredis.__version__)
+
+
+def test_no_cache_basic(sentry_init, capture_events):
+ sentry_init(
+ integrations=[
+ RedisIntegration(),
+ ],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ connection = FakeStrictRedis()
+ with sentry_sdk.start_transaction():
+ connection.get("mycachekey")
+
+ (event,) = events
+ spans = event["spans"]
+ assert len(spans) == 1
+ assert spans[0]["op"] == "db.redis"
+
+
+def test_cache_basic(sentry_init, capture_events):
+ sentry_init(
+ integrations=[
+ RedisIntegration(
+ cache_prefixes=["mycache"],
+ ),
+ ],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ connection = FakeStrictRedis()
+ with sentry_sdk.start_transaction():
+ connection.hget("mycachekey", "myfield")
+ connection.get("mycachekey")
+ connection.set("mycachekey1", "bla")
+ connection.setex("mycachekey2", 10, "blub")
+ connection.mget("mycachekey1", "mycachekey2")
+
+ (event,) = events
+ spans = event["spans"]
+ assert len(spans) == 9
+
+ # no cache support for hget command
+ assert spans[0]["op"] == "db.redis"
+ assert spans[0]["tags"]["redis.command"] == "HGET"
+
+ assert spans[1]["op"] == "cache.get"
+ assert spans[2]["op"] == "db.redis"
+ assert spans[2]["tags"]["redis.command"] == "GET"
+
+ assert spans[3]["op"] == "cache.put"
+ assert spans[4]["op"] == "db.redis"
+ assert spans[4]["tags"]["redis.command"] == "SET"
+
+ assert spans[5]["op"] == "cache.put"
+ assert spans[6]["op"] == "db.redis"
+ assert spans[6]["tags"]["redis.command"] == "SETEX"
+
+ assert spans[7]["op"] == "cache.get"
+ assert spans[8]["op"] == "db.redis"
+ assert spans[8]["tags"]["redis.command"] == "MGET"
+
+
+def test_cache_keys(sentry_init, capture_events):
+ sentry_init(
+ integrations=[
+ RedisIntegration(
+ cache_prefixes=["bla", "blub"],
+ ),
+ ],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ connection = FakeStrictRedis()
+ with sentry_sdk.start_transaction():
+ connection.get("somethingelse")
+ connection.get("blub")
+ connection.get("blubkeything")
+ connection.get("bl")
+
+ (event,) = events
+ spans = event["spans"]
+ assert len(spans) == 6
+ assert spans[0]["op"] == "db.redis"
+ assert spans[0]["description"] == "GET 'somethingelse'"
+
+ assert spans[1]["op"] == "cache.get"
+ assert spans[1]["description"] == "blub"
+ assert spans[2]["op"] == "db.redis"
+ assert spans[2]["description"] == "GET 'blub'"
+
+ assert spans[3]["op"] == "cache.get"
+ assert spans[3]["description"] == "blubkeything"
+ assert spans[4]["op"] == "db.redis"
+ assert spans[4]["description"] == "GET 'blubkeything'"
+
+ assert spans[5]["op"] == "db.redis"
+ assert spans[5]["description"] == "GET 'bl'"
+
+
+def test_cache_data(sentry_init, capture_events):
+ sentry_init(
+ integrations=[
+ RedisIntegration(
+ cache_prefixes=["mycache"],
+ ),
+ ],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ connection = FakeStrictRedis(host="mycacheserver.io", port=6378)
+ with sentry_sdk.start_transaction():
+ connection.get("mycachekey")
+ connection.set("mycachekey", "事实胜于雄辩")
+ connection.get("mycachekey")
+
+ (event,) = events
+ spans = event["spans"]
+
+ assert len(spans) == 6
+
+ assert spans[0]["op"] == "cache.get"
+ assert spans[0]["description"] == "mycachekey"
+ assert spans[0]["data"]["cache.key"] == [
+ "mycachekey",
+ ]
+ assert spans[0]["data"]["cache.hit"] == False # noqa: E712
+ assert "cache.item_size" not in spans[0]["data"]
+ # very old fakeredis can not handle port and/or host.
+ # only applicable for Redis v3
+ if FAKEREDIS_VERSION <= (2, 7, 1):
+ assert "network.peer.port" not in spans[0]["data"]
+ else:
+ assert spans[0]["data"]["network.peer.port"] == 6378
+ if FAKEREDIS_VERSION <= (1, 7, 1):
+ assert "network.peer.address" not in spans[0]["data"]
+ else:
+ assert spans[0]["data"]["network.peer.address"] == "mycacheserver.io"
+
+ assert spans[1]["op"] == "db.redis" # we ignore db spans in this test.
+
+ assert spans[2]["op"] == "cache.put"
+ assert spans[2]["description"] == "mycachekey"
+ assert spans[2]["data"]["cache.key"] == [
+ "mycachekey",
+ ]
+ assert "cache.hit" not in spans[1]["data"]
+ assert spans[2]["data"]["cache.item_size"] == 18
+ # very old fakeredis can not handle port.
+ # only used with redis v3
+ if FAKEREDIS_VERSION <= (2, 7, 1):
+ assert "network.peer.port" not in spans[2]["data"]
+ else:
+ assert spans[2]["data"]["network.peer.port"] == 6378
+ if FAKEREDIS_VERSION <= (1, 7, 1):
+ assert "network.peer.address" not in spans[2]["data"]
+ else:
+ assert spans[2]["data"]["network.peer.address"] == "mycacheserver.io"
+
+ assert spans[3]["op"] == "db.redis" # we ignore db spans in this test.
+
+ assert spans[4]["op"] == "cache.get"
+ assert spans[4]["description"] == "mycachekey"
+ assert spans[4]["data"]["cache.key"] == [
+ "mycachekey",
+ ]
+ assert spans[4]["data"]["cache.hit"] == True # noqa: E712
+ assert spans[4]["data"]["cache.item_size"] == 18
+ # very old fakeredis can not handle port.
+ # only used with redis v3
+ if FAKEREDIS_VERSION <= (2, 7, 1):
+ assert "network.peer.port" not in spans[4]["data"]
+ else:
+ assert spans[4]["data"]["network.peer.port"] == 6378
+ if FAKEREDIS_VERSION <= (1, 7, 1):
+ assert "network.peer.address" not in spans[4]["data"]
+ else:
+ assert spans[4]["data"]["network.peer.address"] == "mycacheserver.io"
+
+ assert spans[5]["op"] == "db.redis" # we ignore db spans in this test.
+
+
+def test_cache_prefixes(sentry_init, capture_events):
+ sentry_init(
+ integrations=[
+ RedisIntegration(
+ cache_prefixes=["yes"],
+ ),
+ ],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ connection = FakeStrictRedis()
+ with sentry_sdk.start_transaction():
+ connection.mget("yes", "no")
+ connection.mget("no", 1, "yes")
+ connection.mget("no", "yes.1", "yes.2")
+ connection.mget("no.1", "no.2", "no.3")
+ connection.mget("no.1", "no.2", "no.actually.yes")
+ connection.mget(b"no.3", b"yes.5")
+ connection.mget(uuid.uuid4().bytes)
+ connection.mget(uuid.uuid4().bytes, "yes")
+
+ (event,) = events
+
+ spans = event["spans"]
+ assert len(spans) == 13 # 8 db spans + 5 cache spans
+
+ cache_spans = [span for span in spans if span["op"] == "cache.get"]
+ assert len(cache_spans) == 5
+
+ assert cache_spans[0]["description"] == "yes, no"
+ assert cache_spans[1]["description"] == "no, 1, yes"
+ assert cache_spans[2]["description"] == "no, yes.1, yes.2"
+ assert cache_spans[3]["description"] == "no.3, yes.5"
+ assert cache_spans[4]["description"] == ", yes"
+
+
+@pytest.mark.parametrize(
+ "method_name,args,kwargs,expected_key",
+ [
+ (None, None, None, None),
+ ("", None, None, None),
+ ("set", ["bla", "valuebla"], None, ("bla",)),
+ ("setex", ["bla", 10, "valuebla"], None, ("bla",)),
+ ("get", ["bla"], None, ("bla",)),
+ ("mget", ["bla", "blub", "foo"], None, ("bla", "blub", "foo")),
+ ("set", [b"bla", "valuebla"], None, (b"bla",)),
+ ("setex", [b"bla", 10, "valuebla"], None, (b"bla",)),
+ ("get", [b"bla"], None, (b"bla",)),
+ ("mget", [b"bla", "blub", "foo"], None, (b"bla", "blub", "foo")),
+ ("not-important", None, {"something": "bla"}, None),
+ ("not-important", None, {"key": None}, None),
+ ("not-important", None, {"key": "bla"}, ("bla",)),
+ ("not-important", None, {"key": b"bla"}, (b"bla",)),
+ ("not-important", None, {"key": []}, None),
+ (
+ "not-important",
+ None,
+ {
+ "key": [
+ "bla",
+ ]
+ },
+ ("bla",),
+ ),
+ (
+ "not-important",
+ None,
+ {"key": [b"bla", "blub", "foo"]},
+ (b"bla", "blub", "foo"),
+ ),
+ (
+ "not-important",
+ None,
+ {"key": b"\x00c\x0f\xeaC\xe1L\x1c\xbff\xcb\xcc\xc1\xed\xc6\t"},
+ (b"\x00c\x0f\xeaC\xe1L\x1c\xbff\xcb\xcc\xc1\xed\xc6\t",),
+ ),
+ (
+ "get",
+ [b"\x00c\x0f\xeaC\xe1L\x1c\xbff\xcb\xcc\xc1\xed\xc6\t"],
+ None,
+ (b"\x00c\x0f\xeaC\xe1L\x1c\xbff\xcb\xcc\xc1\xed\xc6\t",),
+ ),
+ (
+ "get",
+ [123],
+ None,
+ (123,),
+ ),
+ ],
+)
+def test_get_safe_key(method_name, args, kwargs, expected_key):
+ assert _get_safe_key(method_name, args, kwargs) == expected_key
+
+
+@pytest.mark.parametrize(
+ "key,expected_key",
+ [
+ (None, ""),
+ (("bla",), "bla"),
+ (("bla", "blub", "foo"), "bla, blub, foo"),
+ ((b"bla",), "bla"),
+ ((b"bla", "blub", "foo"), "bla, blub, foo"),
+ (
+ [
+ "bla",
+ ],
+ "bla",
+ ),
+ (["bla", "blub", "foo"], "bla, blub, foo"),
+ ([uuid.uuid4().bytes], ""),
+ ({"key1": 1, "key2": 2}, "key1, key2"),
+ (1, "1"),
+ ([1, 2, 3, b"hello"], "1, 2, 3, hello"),
+ ],
+)
+def test_key_as_string(key, expected_key):
+ assert _key_as_string(key) == expected_key
diff --git a/tests/integrations/redis/test_redis_cache_module_async.py b/tests/integrations/redis/test_redis_cache_module_async.py
new file mode 100644
index 0000000000..d607f92fbd
--- /dev/null
+++ b/tests/integrations/redis/test_redis_cache_module_async.py
@@ -0,0 +1,187 @@
+import pytest
+
+try:
+ import fakeredis
+ from fakeredis.aioredis import FakeRedis as FakeRedisAsync
+except ModuleNotFoundError:
+ FakeRedisAsync = None
+
+if FakeRedisAsync is None:
+ pytest.skip(
+ "Skipping tests because fakeredis.aioredis not available",
+ allow_module_level=True,
+ )
+
+from sentry_sdk.integrations.redis import RedisIntegration
+from sentry_sdk.utils import parse_version
+import sentry_sdk
+
+
+FAKEREDIS_VERSION = parse_version(fakeredis.__version__)
+
+
+@pytest.mark.asyncio
+async def test_no_cache_basic(sentry_init, capture_events):
+ sentry_init(
+ integrations=[
+ RedisIntegration(),
+ ],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ connection = FakeRedisAsync()
+ with sentry_sdk.start_transaction():
+ await connection.get("myasynccachekey")
+
+ (event,) = events
+ spans = event["spans"]
+ assert len(spans) == 1
+ assert spans[0]["op"] == "db.redis"
+
+
+@pytest.mark.asyncio
+async def test_cache_basic(sentry_init, capture_events):
+ sentry_init(
+ integrations=[
+ RedisIntegration(
+ cache_prefixes=["myasynccache"],
+ ),
+ ],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ connection = FakeRedisAsync()
+ with sentry_sdk.start_transaction():
+ await connection.get("myasynccachekey")
+
+ (event,) = events
+ spans = event["spans"]
+ assert len(spans) == 2
+
+ assert spans[0]["op"] == "cache.get"
+ assert spans[1]["op"] == "db.redis"
+
+
+@pytest.mark.asyncio
+async def test_cache_keys(sentry_init, capture_events):
+ sentry_init(
+ integrations=[
+ RedisIntegration(
+ cache_prefixes=["abla", "ablub"],
+ ),
+ ],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ connection = FakeRedisAsync()
+ with sentry_sdk.start_transaction():
+ await connection.get("asomethingelse")
+ await connection.get("ablub")
+ await connection.get("ablubkeything")
+ await connection.get("abl")
+
+ (event,) = events
+ spans = event["spans"]
+ assert len(spans) == 6
+ assert spans[0]["op"] == "db.redis"
+ assert spans[0]["description"] == "GET 'asomethingelse'"
+
+ assert spans[1]["op"] == "cache.get"
+ assert spans[1]["description"] == "ablub"
+ assert spans[2]["op"] == "db.redis"
+ assert spans[2]["description"] == "GET 'ablub'"
+
+ assert spans[3]["op"] == "cache.get"
+ assert spans[3]["description"] == "ablubkeything"
+ assert spans[4]["op"] == "db.redis"
+ assert spans[4]["description"] == "GET 'ablubkeything'"
+
+ assert spans[5]["op"] == "db.redis"
+ assert spans[5]["description"] == "GET 'abl'"
+
+
+@pytest.mark.asyncio
+async def test_cache_data(sentry_init, capture_events):
+ sentry_init(
+ integrations=[
+ RedisIntegration(
+ cache_prefixes=["myasynccache"],
+ ),
+ ],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ connection = FakeRedisAsync(host="mycacheserver.io", port=6378)
+ with sentry_sdk.start_transaction():
+ await connection.get("myasynccachekey")
+ await connection.set("myasynccachekey", "事实胜于雄辩")
+ await connection.get("myasynccachekey")
+
+ (event,) = events
+ spans = event["spans"]
+
+ assert len(spans) == 6
+
+ assert spans[0]["op"] == "cache.get"
+ assert spans[0]["description"] == "myasynccachekey"
+ assert spans[0]["data"]["cache.key"] == [
+ "myasynccachekey",
+ ]
+ assert spans[0]["data"]["cache.hit"] == False # noqa: E712
+ assert "cache.item_size" not in spans[0]["data"]
+ # very old fakeredis can not handle port and/or host.
+ # only applicable for Redis v3
+ if FAKEREDIS_VERSION <= (2, 7, 1):
+ assert "network.peer.port" not in spans[0]["data"]
+ else:
+ assert spans[0]["data"]["network.peer.port"] == 6378
+ if FAKEREDIS_VERSION <= (1, 7, 1):
+ assert "network.peer.address" not in spans[0]["data"]
+ else:
+ assert spans[0]["data"]["network.peer.address"] == "mycacheserver.io"
+
+ assert spans[1]["op"] == "db.redis" # we ignore db spans in this test.
+
+ assert spans[2]["op"] == "cache.put"
+ assert spans[2]["description"] == "myasynccachekey"
+ assert spans[2]["data"]["cache.key"] == [
+ "myasynccachekey",
+ ]
+ assert "cache.hit" not in spans[1]["data"]
+ assert spans[2]["data"]["cache.item_size"] == 18
+ # very old fakeredis can not handle port.
+ # only used with redis v3
+ if FAKEREDIS_VERSION <= (2, 7, 1):
+ assert "network.peer.port" not in spans[2]["data"]
+ else:
+ assert spans[2]["data"]["network.peer.port"] == 6378
+ if FAKEREDIS_VERSION <= (1, 7, 1):
+ assert "network.peer.address" not in spans[2]["data"]
+ else:
+ assert spans[2]["data"]["network.peer.address"] == "mycacheserver.io"
+
+ assert spans[3]["op"] == "db.redis" # we ignore db spans in this test.
+
+ assert spans[4]["op"] == "cache.get"
+ assert spans[4]["description"] == "myasynccachekey"
+ assert spans[4]["data"]["cache.key"] == [
+ "myasynccachekey",
+ ]
+ assert spans[4]["data"]["cache.hit"] == True # noqa: E712
+ assert spans[4]["data"]["cache.item_size"] == 18
+ # very old fakeredis can not handle port.
+ # only used with redis v3
+ if FAKEREDIS_VERSION <= (2, 7, 1):
+ assert "network.peer.port" not in spans[4]["data"]
+ else:
+ assert spans[4]["data"]["network.peer.port"] == 6378
+ if FAKEREDIS_VERSION <= (1, 7, 1):
+ assert "network.peer.address" not in spans[4]["data"]
+ else:
+ assert spans[4]["data"]["network.peer.address"] == "mycacheserver.io"
+
+ assert spans[5]["op"] == "db.redis" # we ignore db spans in this test.
diff --git a/tests/integrations/rediscluster/__init__.py b/tests/integrations/redis_py_cluster_legacy/__init__.py
similarity index 100%
rename from tests/integrations/rediscluster/__init__.py
rename to tests/integrations/redis_py_cluster_legacy/__init__.py
diff --git a/tests/integrations/redis_py_cluster_legacy/test_redis_py_cluster_legacy.py b/tests/integrations/redis_py_cluster_legacy/test_redis_py_cluster_legacy.py
new file mode 100644
index 0000000000..36a27d569d
--- /dev/null
+++ b/tests/integrations/redis_py_cluster_legacy/test_redis_py_cluster_legacy.py
@@ -0,0 +1,172 @@
+from unittest import mock
+
+import pytest
+import rediscluster
+
+from sentry_sdk import capture_message
+from sentry_sdk.api import start_transaction
+from sentry_sdk.consts import SPANDATA
+from sentry_sdk.integrations.redis import RedisIntegration
+from tests.conftest import ApproxDict
+
+
+MOCK_CONNECTION_POOL = mock.MagicMock()
+MOCK_CONNECTION_POOL.connection_kwargs = {
+ "host": "localhost",
+ "port": 63791,
+ "db": 1,
+}
+
+
+rediscluster_classes = [rediscluster.RedisCluster]
+
+if hasattr(rediscluster, "StrictRedisCluster"):
+ rediscluster_classes.append(rediscluster.StrictRedisCluster)
+
+
+@pytest.fixture(autouse=True)
+def monkeypatch_rediscluster_classes(reset_integrations):
+ try:
+ pipeline_cls = rediscluster.pipeline.ClusterPipeline
+ except AttributeError:
+ pipeline_cls = rediscluster.StrictClusterPipeline
+ rediscluster.RedisCluster.pipeline = lambda *_, **__: pipeline_cls(
+ connection_pool=MOCK_CONNECTION_POOL
+ )
+ pipeline_cls.execute = lambda *_, **__: None
+ for cls in rediscluster_classes:
+ cls.execute_command = lambda *_, **__: None
+
+
+@pytest.mark.parametrize("rediscluster_cls", rediscluster_classes)
+def test_rediscluster_basic(rediscluster_cls, sentry_init, capture_events):
+ sentry_init(integrations=[RedisIntegration()])
+ events = capture_events()
+
+ rc = rediscluster_cls(connection_pool=MOCK_CONNECTION_POOL)
+ rc.get("foobar")
+ capture_message("hi")
+
+ (event,) = events
+ (crumb,) = event["breadcrumbs"]["values"]
+
+ assert crumb == {
+ "category": "redis",
+ "message": "GET 'foobar'",
+ "data": ApproxDict(
+ {
+ "db.operation": "GET",
+ "redis.key": "foobar",
+ "redis.command": "GET",
+ "redis.is_cluster": True,
+ }
+ ),
+ "timestamp": crumb["timestamp"],
+ "type": "redis",
+ }
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, expected_first_ten",
+ [
+ (False, ["GET 'foo'", "SET 'bar' [Filtered]", "SET 'baz' [Filtered]"]),
+ (True, ["GET 'foo'", "SET 'bar' 1", "SET 'baz' 2"]),
+ ],
+)
+def test_rediscluster_pipeline(
+ sentry_init, capture_events, send_default_pii, expected_first_ten
+):
+ sentry_init(
+ integrations=[RedisIntegration()],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ events = capture_events()
+
+ rc = rediscluster.RedisCluster(connection_pool=MOCK_CONNECTION_POOL)
+ with start_transaction():
+ pipeline = rc.pipeline()
+ pipeline.get("foo")
+ pipeline.set("bar", 1)
+ pipeline.set("baz", 2)
+ pipeline.execute()
+
+ (event,) = events
+ (span,) = event["spans"]
+ assert span["op"] == "db.redis"
+ assert span["description"] == "redis.pipeline.execute"
+ assert span["data"] == ApproxDict(
+ {
+ "redis.commands": {
+ "count": 3,
+ "first_ten": expected_first_ten,
+ },
+ SPANDATA.DB_SYSTEM: "redis",
+ SPANDATA.DB_NAME: "1",
+ SPANDATA.SERVER_ADDRESS: "localhost",
+ SPANDATA.SERVER_PORT: 63791,
+ }
+ )
+ assert span["tags"] == {
+ "redis.transaction": False, # For Cluster, this is always False
+ "redis.is_cluster": True,
+ }
+
+
+@pytest.mark.parametrize("rediscluster_cls", rediscluster_classes)
+def test_db_connection_attributes_client(sentry_init, capture_events, rediscluster_cls):
+ sentry_init(
+ traces_sample_rate=1.0,
+ integrations=[RedisIntegration()],
+ )
+ events = capture_events()
+
+ rc = rediscluster_cls(connection_pool=MOCK_CONNECTION_POOL)
+ with start_transaction():
+ rc.get("foobar")
+
+ (event,) = events
+ (span,) = event["spans"]
+
+ assert span["data"] == ApproxDict(
+ {
+ SPANDATA.DB_SYSTEM: "redis",
+ SPANDATA.DB_NAME: "1",
+ SPANDATA.SERVER_ADDRESS: "localhost",
+ SPANDATA.SERVER_PORT: 63791,
+ }
+ )
+
+
+@pytest.mark.parametrize("rediscluster_cls", rediscluster_classes)
+def test_db_connection_attributes_pipeline(
+ sentry_init, capture_events, rediscluster_cls
+):
+ sentry_init(
+ traces_sample_rate=1.0,
+ integrations=[RedisIntegration()],
+ )
+ events = capture_events()
+
+ rc = rediscluster.RedisCluster(connection_pool=MOCK_CONNECTION_POOL)
+ with start_transaction():
+ pipeline = rc.pipeline()
+ pipeline.get("foo")
+ pipeline.execute()
+
+ (event,) = events
+ (span,) = event["spans"]
+ assert span["op"] == "db.redis"
+ assert span["description"] == "redis.pipeline.execute"
+ assert span["data"] == ApproxDict(
+ {
+ "redis.commands": {
+ "count": 1,
+ "first_ten": ["GET 'foo'"],
+ },
+ SPANDATA.DB_SYSTEM: "redis",
+ SPANDATA.DB_NAME: "1",
+ SPANDATA.SERVER_ADDRESS: "localhost",
+ SPANDATA.SERVER_PORT: 63791,
+ }
+ )
diff --git a/tests/integrations/rediscluster/test_rediscluster.py b/tests/integrations/rediscluster/test_rediscluster.py
deleted file mode 100644
index 6c7e5f90a4..0000000000
--- a/tests/integrations/rediscluster/test_rediscluster.py
+++ /dev/null
@@ -1,79 +0,0 @@
-import pytest
-from sentry_sdk import capture_message
-from sentry_sdk.api import start_transaction
-from sentry_sdk.integrations.redis import RedisIntegration
-
-import rediscluster
-
-rediscluster_classes = [rediscluster.RedisCluster]
-
-if hasattr(rediscluster, "StrictRedisCluster"):
- rediscluster_classes.append(rediscluster.StrictRedisCluster)
-
-
-@pytest.fixture(autouse=True)
-def monkeypatch_rediscluster_classes(reset_integrations):
-
- try:
- pipeline_cls = rediscluster.pipeline.ClusterPipeline
- except AttributeError:
- pipeline_cls = rediscluster.StrictClusterPipeline
- rediscluster.RedisCluster.pipeline = lambda *_, **__: pipeline_cls(
- connection_pool=True
- )
- pipeline_cls.execute = lambda *_, **__: None
- for cls in rediscluster_classes:
- cls.execute_command = lambda *_, **__: None
-
-
-@pytest.mark.parametrize("rediscluster_cls", rediscluster_classes)
-def test_rediscluster_basic(rediscluster_cls, sentry_init, capture_events):
- sentry_init(integrations=[RedisIntegration()])
- events = capture_events()
-
- rc = rediscluster_cls(connection_pool=True)
- rc.get("foobar")
- capture_message("hi")
-
- (event,) = events
- (crumb,) = event["breadcrumbs"]["values"]
-
- assert crumb == {
- "category": "redis",
- "message": "GET 'foobar'",
- "data": {
- "redis.key": "foobar",
- "redis.command": "GET",
- "redis.is_cluster": True,
- },
- "timestamp": crumb["timestamp"],
- "type": "redis",
- }
-
-
-def test_rediscluster_pipeline(sentry_init, capture_events):
- sentry_init(integrations=[RedisIntegration()], traces_sample_rate=1.0)
- events = capture_events()
-
- rc = rediscluster.RedisCluster(connection_pool=True)
- with start_transaction():
- pipeline = rc.pipeline()
- pipeline.get("foo")
- pipeline.set("bar", 1)
- pipeline.set("baz", 2)
- pipeline.execute()
-
- (event,) = events
- (span,) = event["spans"]
- assert span["op"] == "db.redis"
- assert span["description"] == "redis.pipeline.execute"
- assert span["data"] == {
- "redis.commands": {
- "count": 3,
- "first_ten": ["GET 'foo'", "SET 'bar' 1", "SET 'baz' 2"],
- }
- }
- assert span["tags"] == {
- "redis.transaction": False, # For Cluster, this is always False
- "redis.is_cluster": True,
- }
diff --git a/tests/integrations/requests/__init__.py b/tests/integrations/requests/__init__.py
new file mode 100644
index 0000000000..a711908293
--- /dev/null
+++ b/tests/integrations/requests/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("requests")
diff --git a/tests/integrations/requests/test_requests.py b/tests/integrations/requests/test_requests.py
index 7070895dfc..8cfc0f932f 100644
--- a/tests/integrations/requests/test_requests.py
+++ b/tests/integrations/requests/test_requests.py
@@ -1,32 +1,114 @@
-import pytest
-import responses
+import sys
+from unittest import mock
-requests = pytest.importorskip("requests")
+import pytest
+import requests
from sentry_sdk import capture_message
+from sentry_sdk.consts import SPANDATA
from sentry_sdk.integrations.stdlib import StdlibIntegration
+from tests.conftest import ApproxDict, create_mock_http_server
+
+PORT = create_mock_http_server()
def test_crumb_capture(sentry_init, capture_events):
sentry_init(integrations=[StdlibIntegration()])
+ events = capture_events()
- url = "http://example.com/"
- responses.add(responses.GET, url, status=200)
+ url = f"http://localhost:{PORT}/hello-world" # noqa:E231
+ response = requests.get(url)
+ capture_message("Testing!")
+
+ (event,) = events
+ (crumb,) = event["breadcrumbs"]["values"]
+ assert crumb["type"] == "http"
+ assert crumb["category"] == "httplib"
+ assert crumb["data"] == ApproxDict(
+ {
+ "url": url,
+ SPANDATA.HTTP_METHOD: "GET",
+ SPANDATA.HTTP_FRAGMENT: "",
+ SPANDATA.HTTP_QUERY: "",
+ SPANDATA.HTTP_STATUS_CODE: response.status_code,
+ "reason": response.reason,
+ }
+ )
+
+
+@pytest.mark.skipif(
+ sys.version_info < (3, 7),
+ reason="The response status is not set on the span early enough in 3.6",
+)
+@pytest.mark.parametrize(
+ "status_code,level",
+ [
+ (200, None),
+ (301, None),
+ (403, "warning"),
+ (405, "warning"),
+ (500, "error"),
+ ],
+)
+def test_crumb_capture_client_error(sentry_init, capture_events, status_code, level):
+ sentry_init(integrations=[StdlibIntegration()])
events = capture_events()
+ url = f"http://localhost:{PORT}/status/{status_code}" # noqa:E231
response = requests.get(url)
+
+ assert response.status_code == status_code
+
capture_message("Testing!")
(event,) = events
(crumb,) = event["breadcrumbs"]["values"]
assert crumb["type"] == "http"
assert crumb["category"] == "httplib"
- assert crumb["data"] == {
- "url": url,
- "method": "GET",
- "http.fragment": "",
- "http.query": "",
- "status_code": response.status_code,
- "reason": response.reason,
- }
+
+ if level is None:
+ assert "level" not in crumb
+ else:
+ assert crumb["level"] == level
+
+ assert crumb["data"] == ApproxDict(
+ {
+ "url": url,
+ SPANDATA.HTTP_METHOD: "GET",
+ SPANDATA.HTTP_FRAGMENT: "",
+ SPANDATA.HTTP_QUERY: "",
+ SPANDATA.HTTP_STATUS_CODE: response.status_code,
+ "reason": response.reason,
+ }
+ )
+
+
+@pytest.mark.tests_internal_exceptions
+def test_omit_url_data_if_parsing_fails(sentry_init, capture_events):
+ sentry_init(integrations=[StdlibIntegration()])
+
+ events = capture_events()
+
+ url = f"http://localhost:{PORT}/ok" # noqa:E231
+
+ with mock.patch(
+ "sentry_sdk.integrations.stdlib.parse_url",
+ side_effect=ValueError,
+ ):
+ response = requests.get(url)
+
+ capture_message("Testing!")
+
+ (event,) = events
+ assert event["breadcrumbs"]["values"][0]["data"] == ApproxDict(
+ {
+ SPANDATA.HTTP_METHOD: "GET",
+ SPANDATA.HTTP_STATUS_CODE: response.status_code,
+ "reason": response.reason,
+ # no url related data
+ }
+ )
+ assert "url" not in event["breadcrumbs"]["values"][0]["data"]
+ assert SPANDATA.HTTP_FRAGMENT not in event["breadcrumbs"]["values"][0]["data"]
+ assert SPANDATA.HTTP_QUERY not in event["breadcrumbs"]["values"][0]["data"]
diff --git a/tests/integrations/rq/__init__.py b/tests/integrations/rq/__init__.py
index d9714d465a..9766a19465 100644
--- a/tests/integrations/rq/__init__.py
+++ b/tests/integrations/rq/__init__.py
@@ -1,3 +1,3 @@
import pytest
-rq = pytest.importorskip("rq")
+pytest.importorskip("rq")
diff --git a/tests/integrations/rq/test_rq.py b/tests/integrations/rq/test_rq.py
index fb25b65a03..23603ad91d 100644
--- a/tests/integrations/rq/test_rq.py
+++ b/tests/integrations/rq/test_rq.py
@@ -1,31 +1,37 @@
-import pytest
-from fakeredis import FakeStrictRedis
-from sentry_sdk.integrations.rq import RqIntegration
+from unittest import mock
+import pytest
import rq
+from fakeredis import FakeStrictRedis
-try:
- from unittest import mock # python 3.3 and above
-except ImportError:
- import mock # python < 3.3
+import sentry_sdk
+from sentry_sdk import start_transaction
+from sentry_sdk.integrations.rq import RqIntegration
+from sentry_sdk.utils import parse_version
@pytest.fixture(autouse=True)
def _patch_rq_get_server_version(monkeypatch):
"""
- Patch up RQ 1.5 to work with fakeredis.
+ Patch RQ lower than 1.5.1 to work with fakeredis.
https://github.com/jamesls/fakeredis/issues/273
"""
+ try:
+ from distutils.version import StrictVersion
+ except ImportError:
+ return
- from distutils.version import StrictVersion
-
- if tuple(map(int, rq.VERSION.split("."))) >= (1, 5):
+ if parse_version(rq.VERSION) <= (1, 5, 1):
for k in (
"rq.job.Job.get_redis_server_version",
"rq.worker.Worker.get_redis_server_version",
):
- monkeypatch.setattr(k, lambda _: StrictVersion("4.0.0"))
+ try:
+ monkeypatch.setattr(k, lambda _: StrictVersion("4.0.0"))
+ except AttributeError:
+ # old RQ Job/Worker doesn't have a get_redis_server_version attr
+ pass
def crashing_job(foo):
@@ -91,9 +97,10 @@ def test_transport_shutdown(sentry_init, capture_events_forksafe):
def test_transaction_with_error(
- sentry_init, capture_events, DictionaryContaining # noqa:N803
+ sentry_init,
+ capture_events,
+ DictionaryContaining, # noqa:N803
):
-
sentry_init(integrations=[RqIntegration()], traces_sample_rate=1.0)
events = capture_events()
@@ -126,8 +133,73 @@ def test_transaction_with_error(
)
+def test_error_has_trace_context_if_tracing_disabled(
+ sentry_init,
+ capture_events,
+):
+ sentry_init(integrations=[RqIntegration()])
+ events = capture_events()
+
+ queue = rq.Queue(connection=FakeStrictRedis())
+ worker = rq.SimpleWorker([queue], connection=queue.connection)
+
+ queue.enqueue(crashing_job, foo=None)
+ worker.work(burst=True)
+
+ (error_event,) = events
+
+ assert error_event["contexts"]["trace"]
+
+
+def test_tracing_enabled(
+ sentry_init,
+ capture_events,
+):
+ sentry_init(integrations=[RqIntegration()], traces_sample_rate=1.0)
+ events = capture_events()
+
+ queue = rq.Queue(connection=FakeStrictRedis())
+ worker = rq.SimpleWorker([queue], connection=queue.connection)
+
+ with start_transaction(op="rq transaction") as transaction:
+ queue.enqueue(crashing_job, foo=None)
+ worker.work(burst=True)
+
+ error_event, envelope, _ = events
+
+ assert error_event["transaction"] == "tests.integrations.rq.test_rq.crashing_job"
+ assert error_event["contexts"]["trace"]["trace_id"] == transaction.trace_id
+
+ assert envelope["contexts"]["trace"] == error_event["contexts"]["trace"]
+
+
+def test_tracing_disabled(
+ sentry_init,
+ capture_events,
+):
+ sentry_init(integrations=[RqIntegration()])
+ events = capture_events()
+
+ queue = rq.Queue(connection=FakeStrictRedis())
+ worker = rq.SimpleWorker([queue], connection=queue.connection)
+
+ scope = sentry_sdk.get_isolation_scope()
+ queue.enqueue(crashing_job, foo=None)
+ worker.work(burst=True)
+
+ (error_event,) = events
+
+ assert error_event["transaction"] == "tests.integrations.rq.test_rq.crashing_job"
+ assert (
+ error_event["contexts"]["trace"]["trace_id"]
+ == scope._propagation_context.trace_id
+ )
+
+
def test_transaction_no_error(
- sentry_init, capture_events, DictionaryContaining # noqa:N803
+ sentry_init,
+ capture_events,
+ DictionaryContaining, # noqa:N803
):
sentry_init(integrations=[RqIntegration()], traces_sample_rate=1.0)
events = capture_events()
@@ -154,7 +226,9 @@ def test_transaction_no_error(
def test_traces_sampler_gets_correct_values_in_sampling_context(
- sentry_init, DictionaryContaining, ObjectDescribedBy # noqa:N803
+ sentry_init,
+ DictionaryContaining,
+ ObjectDescribedBy, # noqa:N803
):
traces_sampler = mock.Mock(return_value=True)
sentry_init(integrations=[RqIntegration()], traces_sampler=traces_sampler)
@@ -184,7 +258,7 @@ def test_traces_sampler_gets_correct_values_in_sampling_context(
@pytest.mark.skipif(
- rq.__version__.split(".") < ["1", "5"], reason="At least rq-1.5 required"
+ parse_version(rq.__version__) < (1, 5), reason="At least rq-1.5 required"
)
def test_job_with_retries(sentry_init, capture_events):
sentry_init(integrations=[RqIntegration()])
@@ -197,3 +271,18 @@ def test_job_with_retries(sentry_init, capture_events):
worker.work(burst=True)
assert len(events) == 1
+
+
+def test_span_origin(sentry_init, capture_events):
+ sentry_init(integrations=[RqIntegration()], traces_sample_rate=1.0)
+ events = capture_events()
+
+ queue = rq.Queue(connection=FakeStrictRedis())
+ worker = rq.SimpleWorker([queue], connection=queue.connection)
+
+ queue.enqueue(do_trick, "Maisey", trick="kangaroo")
+ worker.work(burst=True)
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "auto.queue.rq"
diff --git a/tests/integrations/rust_tracing/__init__.py b/tests/integrations/rust_tracing/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/integrations/rust_tracing/test_rust_tracing.py b/tests/integrations/rust_tracing/test_rust_tracing.py
new file mode 100644
index 0000000000..a3e367f7e9
--- /dev/null
+++ b/tests/integrations/rust_tracing/test_rust_tracing.py
@@ -0,0 +1,473 @@
+from unittest import mock
+import pytest
+
+from string import Template
+from typing import Dict
+
+import sentry_sdk
+from sentry_sdk.integrations.rust_tracing import (
+ RustTracingIntegration,
+ RustTracingLayer,
+ RustTracingLevel,
+ EventTypeMapping,
+)
+from sentry_sdk import start_transaction, capture_message
+
+
+def _test_event_type_mapping(metadata: Dict[str, object]) -> EventTypeMapping:
+ level = RustTracingLevel(metadata.get("level"))
+ if level == RustTracingLevel.Error:
+ return EventTypeMapping.Exc
+ elif level in (RustTracingLevel.Warn, RustTracingLevel.Info):
+ return EventTypeMapping.Breadcrumb
+ elif level == RustTracingLevel.Debug:
+ return EventTypeMapping.Event
+ elif level == RustTracingLevel.Trace:
+ return EventTypeMapping.Ignore
+ else:
+ return EventTypeMapping.Ignore
+
+
+class FakeRustTracing:
+ # Parameters: `level`, `index`
+ span_template = Template(
+ """{"index":$index,"is_root":false,"metadata":{"fields":["index","use_memoized","version"],"file":"src/lib.rs","is_event":false,"is_span":true,"level":"$level","line":40,"module_path":"_bindings","name":"fibonacci","target":"_bindings"},"parent":null,"use_memoized":true}"""
+ )
+
+ # Parameters: `level`, `index`
+ event_template = Template(
+ """{"message":"Getting the ${index}th fibonacci number","metadata":{"fields":["message"],"file":"src/lib.rs","is_event":true,"is_span":false,"level":"$level","line":23,"module_path":"_bindings","name":"event src/lib.rs:23","target":"_bindings"}}"""
+ )
+
+ def __init__(self):
+ self.spans = {}
+
+ def set_layer_impl(self, layer: RustTracingLayer):
+ self.layer = layer
+
+ def new_span(self, level: RustTracingLevel, span_id: int, index_arg: int = 10):
+ span_attrs = self.span_template.substitute(level=level.value, index=index_arg)
+ state = self.layer.on_new_span(span_attrs, str(span_id))
+ self.spans[span_id] = state
+
+ def close_span(self, span_id: int):
+ state = self.spans.pop(span_id)
+ self.layer.on_close(str(span_id), state)
+
+ def event(self, level: RustTracingLevel, span_id: int, index_arg: int = 10):
+ event = self.event_template.substitute(level=level.value, index=index_arg)
+ state = self.spans[span_id]
+ self.layer.on_event(event, state)
+
+ def record(self, span_id: int):
+ state = self.spans[span_id]
+ self.layer.on_record(str(span_id), """{"version": "memoized"}""", state)
+
+
+def test_on_new_span_on_close(sentry_init, capture_events):
+ rust_tracing = FakeRustTracing()
+ integration = RustTracingIntegration(
+ "test_on_new_span_on_close",
+ initializer=rust_tracing.set_layer_impl,
+ include_tracing_fields=True,
+ )
+ sentry_init(integrations=[integration], traces_sample_rate=1.0)
+
+ events = capture_events()
+ with start_transaction():
+ rust_tracing.new_span(RustTracingLevel.Info, 3)
+
+ sentry_first_rust_span = sentry_sdk.get_current_span()
+ rust_first_rust_span = rust_tracing.spans[3]
+
+ assert sentry_first_rust_span == rust_first_rust_span
+
+ rust_tracing.close_span(3)
+ assert sentry_sdk.get_current_span() != sentry_first_rust_span
+
+ (event,) = events
+ assert len(event["spans"]) == 1
+
+ # Ensure the span metadata is wired up
+ span = event["spans"][0]
+ assert span["op"] == "function"
+ assert span["origin"] == "auto.function.rust_tracing.test_on_new_span_on_close"
+ assert span["description"] == "_bindings::fibonacci"
+
+ # Ensure the span was opened/closed appropriately
+ assert span["start_timestamp"] is not None
+ assert span["timestamp"] is not None
+
+ # Ensure the extra data from Rust is hooked up
+ data = span["data"]
+ assert data["use_memoized"]
+ assert data["index"] == 10
+ assert data["version"] is None
+
+
+def test_nested_on_new_span_on_close(sentry_init, capture_events):
+ rust_tracing = FakeRustTracing()
+ integration = RustTracingIntegration(
+ "test_nested_on_new_span_on_close",
+ initializer=rust_tracing.set_layer_impl,
+ include_tracing_fields=True,
+ )
+ sentry_init(integrations=[integration], traces_sample_rate=1.0)
+
+ events = capture_events()
+ with start_transaction():
+ original_sentry_span = sentry_sdk.get_current_span()
+
+ rust_tracing.new_span(RustTracingLevel.Info, 3, index_arg=10)
+ sentry_first_rust_span = sentry_sdk.get_current_span()
+ rust_first_rust_span = rust_tracing.spans[3]
+
+ # Use a different `index_arg` value for the inner span to help
+ # distinguish the two at the end of the test
+ rust_tracing.new_span(RustTracingLevel.Info, 5, index_arg=9)
+ sentry_second_rust_span = sentry_sdk.get_current_span()
+ rust_second_rust_span = rust_tracing.spans[5]
+
+ assert rust_second_rust_span == sentry_second_rust_span
+
+ rust_tracing.close_span(5)
+
+ # Ensure the current sentry span was moved back to the parent
+ sentry_span_after_close = sentry_sdk.get_current_span()
+ assert sentry_span_after_close == sentry_first_rust_span
+ assert sentry_span_after_close == rust_first_rust_span
+
+ rust_tracing.close_span(3)
+
+ assert sentry_sdk.get_current_span() == original_sentry_span
+
+ (event,) = events
+ assert len(event["spans"]) == 2
+
+ # Ensure the span metadata is wired up for all spans
+ first_span, second_span = event["spans"]
+ assert first_span["op"] == "function"
+ assert (
+ first_span["origin"]
+ == "auto.function.rust_tracing.test_nested_on_new_span_on_close"
+ )
+ assert first_span["description"] == "_bindings::fibonacci"
+ assert second_span["op"] == "function"
+ assert (
+ second_span["origin"]
+ == "auto.function.rust_tracing.test_nested_on_new_span_on_close"
+ )
+ assert second_span["description"] == "_bindings::fibonacci"
+
+ # Ensure the spans were opened/closed appropriately
+ assert first_span["start_timestamp"] is not None
+ assert first_span["timestamp"] is not None
+ assert second_span["start_timestamp"] is not None
+ assert second_span["timestamp"] is not None
+
+ # Ensure the extra data from Rust is hooked up in both spans
+ first_span_data = first_span["data"]
+ assert first_span_data["use_memoized"]
+ assert first_span_data["index"] == 10
+ assert first_span_data["version"] is None
+
+ second_span_data = second_span["data"]
+ assert second_span_data["use_memoized"]
+ assert second_span_data["index"] == 9
+ assert second_span_data["version"] is None
+
+
+def test_on_new_span_without_transaction(sentry_init):
+ rust_tracing = FakeRustTracing()
+ integration = RustTracingIntegration(
+ "test_on_new_span_without_transaction", rust_tracing.set_layer_impl
+ )
+ sentry_init(integrations=[integration], traces_sample_rate=1.0)
+
+ assert sentry_sdk.get_current_span() is None
+
+ # Should still create a span hierarchy, it just will not be under a txn
+ rust_tracing.new_span(RustTracingLevel.Info, 3)
+ current_span = sentry_sdk.get_current_span()
+ assert current_span is not None
+ assert current_span.containing_transaction is None
+
+
+def test_on_event_exception(sentry_init, capture_events):
+ rust_tracing = FakeRustTracing()
+ integration = RustTracingIntegration(
+ "test_on_event_exception",
+ rust_tracing.set_layer_impl,
+ event_type_mapping=_test_event_type_mapping,
+ )
+ sentry_init(integrations=[integration], traces_sample_rate=1.0)
+
+ events = capture_events()
+ sentry_sdk.get_isolation_scope().clear_breadcrumbs()
+
+ with start_transaction():
+ rust_tracing.new_span(RustTracingLevel.Info, 3)
+
+ # Mapped to Exception
+ rust_tracing.event(RustTracingLevel.Error, 3)
+
+ rust_tracing.close_span(3)
+
+ assert len(events) == 2
+ exc, _tx = events
+ assert exc["level"] == "error"
+ assert exc["logger"] == "_bindings"
+ assert exc["message"] == "Getting the 10th fibonacci number"
+ assert exc["breadcrumbs"]["values"] == []
+
+ location_context = exc["contexts"]["rust_tracing_location"]
+ assert location_context["module_path"] == "_bindings"
+ assert location_context["file"] == "src/lib.rs"
+ assert location_context["line"] == 23
+
+ field_context = exc["contexts"]["rust_tracing_fields"]
+ assert field_context["message"] == "Getting the 10th fibonacci number"
+
+
+def test_on_event_breadcrumb(sentry_init, capture_events):
+ rust_tracing = FakeRustTracing()
+ integration = RustTracingIntegration(
+ "test_on_event_breadcrumb",
+ rust_tracing.set_layer_impl,
+ event_type_mapping=_test_event_type_mapping,
+ )
+ sentry_init(integrations=[integration], traces_sample_rate=1.0)
+
+ events = capture_events()
+ sentry_sdk.get_isolation_scope().clear_breadcrumbs()
+
+ with start_transaction():
+ rust_tracing.new_span(RustTracingLevel.Info, 3)
+
+ # Mapped to Breadcrumb
+ rust_tracing.event(RustTracingLevel.Info, 3)
+
+ rust_tracing.close_span(3)
+ capture_message("test message")
+
+ assert len(events) == 2
+ message, _tx = events
+
+ breadcrumbs = message["breadcrumbs"]["values"]
+ assert len(breadcrumbs) == 1
+ assert breadcrumbs[0]["level"] == "info"
+ assert breadcrumbs[0]["message"] == "Getting the 10th fibonacci number"
+ assert breadcrumbs[0]["type"] == "default"
+
+
+def test_on_event_event(sentry_init, capture_events):
+ rust_tracing = FakeRustTracing()
+ integration = RustTracingIntegration(
+ "test_on_event_event",
+ rust_tracing.set_layer_impl,
+ event_type_mapping=_test_event_type_mapping,
+ )
+ sentry_init(integrations=[integration], traces_sample_rate=1.0)
+
+ events = capture_events()
+ sentry_sdk.get_isolation_scope().clear_breadcrumbs()
+
+ with start_transaction():
+ rust_tracing.new_span(RustTracingLevel.Info, 3)
+
+ # Mapped to Event
+ rust_tracing.event(RustTracingLevel.Debug, 3)
+
+ rust_tracing.close_span(3)
+
+ assert len(events) == 2
+ event, _tx = events
+
+ assert event["logger"] == "_bindings"
+ assert event["level"] == "debug"
+ assert event["message"] == "Getting the 10th fibonacci number"
+ assert event["breadcrumbs"]["values"] == []
+
+ location_context = event["contexts"]["rust_tracing_location"]
+ assert location_context["module_path"] == "_bindings"
+ assert location_context["file"] == "src/lib.rs"
+ assert location_context["line"] == 23
+
+ field_context = event["contexts"]["rust_tracing_fields"]
+ assert field_context["message"] == "Getting the 10th fibonacci number"
+
+
+def test_on_event_ignored(sentry_init, capture_events):
+ rust_tracing = FakeRustTracing()
+ integration = RustTracingIntegration(
+ "test_on_event_ignored",
+ rust_tracing.set_layer_impl,
+ event_type_mapping=_test_event_type_mapping,
+ )
+ sentry_init(integrations=[integration], traces_sample_rate=1.0)
+
+ events = capture_events()
+ sentry_sdk.get_isolation_scope().clear_breadcrumbs()
+
+ with start_transaction():
+ rust_tracing.new_span(RustTracingLevel.Info, 3)
+
+ # Ignored
+ rust_tracing.event(RustTracingLevel.Trace, 3)
+
+ rust_tracing.close_span(3)
+
+ assert len(events) == 1
+ (tx,) = events
+ assert tx["type"] == "transaction"
+ assert "message" not in tx
+
+
+def test_span_filter(sentry_init, capture_events):
+ def span_filter(metadata: Dict[str, object]) -> bool:
+ return RustTracingLevel(metadata.get("level")) in (
+ RustTracingLevel.Error,
+ RustTracingLevel.Warn,
+ RustTracingLevel.Info,
+ RustTracingLevel.Debug,
+ )
+
+ rust_tracing = FakeRustTracing()
+ integration = RustTracingIntegration(
+ "test_span_filter",
+ initializer=rust_tracing.set_layer_impl,
+ span_filter=span_filter,
+ include_tracing_fields=True,
+ )
+ sentry_init(integrations=[integration], traces_sample_rate=1.0)
+
+ events = capture_events()
+ with start_transaction():
+ original_sentry_span = sentry_sdk.get_current_span()
+
+ # Span is not ignored
+ rust_tracing.new_span(RustTracingLevel.Info, 3, index_arg=10)
+ info_span = sentry_sdk.get_current_span()
+
+ # Span is ignored, current span should remain the same
+ rust_tracing.new_span(RustTracingLevel.Trace, 5, index_arg=9)
+ assert sentry_sdk.get_current_span() == info_span
+
+ # Closing the filtered span should leave the current span alone
+ rust_tracing.close_span(5)
+ assert sentry_sdk.get_current_span() == info_span
+
+ rust_tracing.close_span(3)
+ assert sentry_sdk.get_current_span() == original_sentry_span
+
+ (event,) = events
+ assert len(event["spans"]) == 1
+ # The ignored span has index == 9
+ assert event["spans"][0]["data"]["index"] == 10
+
+
+def test_record(sentry_init):
+ rust_tracing = FakeRustTracing()
+ integration = RustTracingIntegration(
+ "test_record",
+ initializer=rust_tracing.set_layer_impl,
+ include_tracing_fields=True,
+ )
+ sentry_init(integrations=[integration], traces_sample_rate=1.0)
+
+ with start_transaction():
+ rust_tracing.new_span(RustTracingLevel.Info, 3)
+
+ span_before_record = sentry_sdk.get_current_span().to_json()
+ assert span_before_record["data"]["version"] is None
+
+ rust_tracing.record(3)
+
+ span_after_record = sentry_sdk.get_current_span().to_json()
+ assert span_after_record["data"]["version"] == "memoized"
+
+
+def test_record_in_ignored_span(sentry_init):
+ def span_filter(metadata: Dict[str, object]) -> bool:
+ # Just ignore Trace
+ return RustTracingLevel(metadata.get("level")) != RustTracingLevel.Trace
+
+ rust_tracing = FakeRustTracing()
+ integration = RustTracingIntegration(
+ "test_record_in_ignored_span",
+ rust_tracing.set_layer_impl,
+ span_filter=span_filter,
+ include_tracing_fields=True,
+ )
+ sentry_init(integrations=[integration], traces_sample_rate=1.0)
+
+ with start_transaction():
+ rust_tracing.new_span(RustTracingLevel.Info, 3)
+
+ span_before_record = sentry_sdk.get_current_span().to_json()
+ assert span_before_record["data"]["version"] is None
+
+ rust_tracing.new_span(RustTracingLevel.Trace, 5)
+ rust_tracing.record(5)
+
+ # `on_record()` should not do anything to the current Sentry span if the associated Rust span was ignored
+ span_after_record = sentry_sdk.get_current_span().to_json()
+ assert span_after_record["data"]["version"] is None
+
+
+@pytest.mark.parametrize(
+ "send_default_pii, include_tracing_fields, tracing_fields_expected",
+ [
+ (True, True, True),
+ (True, False, False),
+ (True, None, True),
+ (False, True, True),
+ (False, False, False),
+ (False, None, False),
+ ],
+)
+def test_include_tracing_fields(
+ sentry_init, send_default_pii, include_tracing_fields, tracing_fields_expected
+):
+ rust_tracing = FakeRustTracing()
+ integration = RustTracingIntegration(
+ "test_record",
+ initializer=rust_tracing.set_layer_impl,
+ include_tracing_fields=include_tracing_fields,
+ )
+
+ sentry_init(
+ integrations=[integration],
+ traces_sample_rate=1.0,
+ send_default_pii=send_default_pii,
+ )
+ with start_transaction():
+ rust_tracing.new_span(RustTracingLevel.Info, 3)
+
+ span_before_record = sentry_sdk.get_current_span().to_json()
+ if tracing_fields_expected:
+ assert span_before_record["data"]["version"] is None
+ else:
+ assert span_before_record["data"]["version"] == "[Filtered]"
+
+ rust_tracing.record(3)
+
+ span_after_record = sentry_sdk.get_current_span().to_json()
+
+ if tracing_fields_expected:
+ assert span_after_record["data"] == {
+ "thread.id": mock.ANY,
+ "thread.name": mock.ANY,
+ "use_memoized": True,
+ "version": "memoized",
+ "index": 10,
+ }
+
+ else:
+ assert span_after_record["data"] == {
+ "thread.id": mock.ANY,
+ "thread.name": mock.ANY,
+ "use_memoized": "[Filtered]",
+ "version": "[Filtered]",
+ "index": "[Filtered]",
+ }
diff --git a/tests/integrations/sanic/__init__.py b/tests/integrations/sanic/__init__.py
index 53449e2f0e..d6b67797a3 100644
--- a/tests/integrations/sanic/__init__.py
+++ b/tests/integrations/sanic/__init__.py
@@ -1,3 +1,3 @@
import pytest
-sanic = pytest.importorskip("sanic")
+pytest.importorskip("sanic")
diff --git a/tests/integrations/sanic/test_sanic.py b/tests/integrations/sanic/test_sanic.py
index de84845cf4..ff1c5efa26 100644
--- a/tests/integrations/sanic/test_sanic.py
+++ b/tests/integrations/sanic/test_sanic.py
@@ -1,19 +1,39 @@
+import asyncio
+import contextlib
import os
-import sys
import random
-import asyncio
+import sys
from unittest.mock import Mock
import pytest
-from sentry_sdk import capture_message, configure_scope
+import sentry_sdk
+from sentry_sdk import capture_message
from sentry_sdk.integrations.sanic import SanicIntegration
+from sentry_sdk.tracing import TransactionSource
from sanic import Sanic, request, response, __version__ as SANIC_VERSION_RAW
from sanic.response import HTTPResponse
from sanic.exceptions import SanicException
+try:
+ from sanic_testing import TestManager
+except ImportError:
+ TestManager = None
+
+try:
+ from sanic_testing.reusable import ReusableClient
+except ImportError:
+ ReusableClient = None
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from collections.abc import Iterable, Container
+ from typing import Any, Optional
+
SANIC_VERSION = tuple(map(int, SANIC_VERSION_RAW.split(".")))
+PERFORMANCE_SUPPORTED = SANIC_VERSION >= (21, 9)
@pytest.fixture
@@ -35,29 +55,49 @@ def new_test_client(self):
if SANIC_VERSION >= (20, 12) and SANIC_VERSION < (22, 6):
# Some builds (20.12.0 intruduced and 22.6.0 removed again) have a feature where the instance is stored in an internal class
# registry for later retrieval, and so add register=False to disable that
- app = Sanic("Test", register=False)
+ sanic_app = Sanic("Test", register=False)
else:
- app = Sanic("Test")
+ sanic_app = Sanic("Test")
+
+ if TestManager is not None:
+ TestManager(sanic_app)
- @app.route("/message")
+ @sanic_app.route("/message")
def hi(request):
capture_message("hi")
return response.text("ok")
- @app.route("/message/")
+ @sanic_app.route("/message/")
def hi_with_id(request, message_id):
capture_message("hi with id")
return response.text("ok with id")
- return app
+ @sanic_app.route("/500")
+ def fivehundred(_):
+ 1 / 0
+
+ return sanic_app
+
+
+def get_client(app):
+ @contextlib.contextmanager
+ def simple_client(app):
+ yield app.test_client
+
+ if ReusableClient is not None:
+ return ReusableClient(app)
+ else:
+ return simple_client(app)
def test_request_data(sentry_init, app, capture_events):
sentry_init(integrations=[SanicIntegration()])
events = capture_events()
- request, response = app.test_client.get("/message?foo=bar")
- assert response.status == 200
+ c = get_client(app)
+ with c as client:
+ _, response = client.get("/message?foo=bar")
+ assert response.status == 200
(event,) = events
assert event["transaction"] == "hi"
@@ -88,14 +128,16 @@ def test_request_data(sentry_init, app, capture_events):
("/message/123456", "hi_with_id", "component"),
],
)
-def test_transaction(
+def test_transaction_name(
sentry_init, app, capture_events, url, expected_transaction, expected_source
):
sentry_init(integrations=[SanicIntegration()])
events = capture_events()
- request, response = app.test_client.get(url)
- assert response.status == 200
+ c = get_client(app)
+ with c as client:
+ _, response = client.get(url)
+ assert response.status == 200
(event,) = events
assert event["transaction"] == expected_transaction
@@ -110,8 +152,10 @@ def test_errors(sentry_init, app, capture_events):
def myerror(request):
raise ValueError("oh no")
- request, response = app.test_client.get("/error")
- assert response.status == 500
+ c = get_client(app)
+ with c as client:
+ _, response = client.get("/error")
+ assert response.status == 500
(event,) = events
assert event["transaction"] == "myerror"
@@ -133,8 +177,10 @@ def test_bad_request_not_captured(sentry_init, app, capture_events):
def index(request):
raise SanicException("...", status_code=400)
- request, response = app.test_client.get("/")
- assert response.status == 400
+ c = get_client(app)
+ with c as client:
+ _, response = client.get("/")
+ assert response.status == 400
assert not events
@@ -151,8 +197,10 @@ def myerror(request):
def myhandler(request, exception):
1 / 0
- request, response = app.test_client.get("/error")
- assert response.status == 500
+ c = get_client(app)
+ with c as client:
+ _, response = client.get("/error")
+ assert response.status == 500
event1, event2 = events
@@ -182,18 +230,17 @@ def test_concurrency(sentry_init, app):
because that's the only way we could reproduce leakage with such a low
amount of concurrent tasks.
"""
-
sentry_init(integrations=[SanicIntegration()])
@app.route("/context-check/")
async def context_check(request, i):
- with configure_scope() as scope:
- scope.set_tag("i", i)
+ scope = sentry_sdk.get_isolation_scope()
+ scope.set_tag("i", i)
await asyncio.sleep(random.random())
- with configure_scope() as scope:
- assert scope._tags["i"] == i
+ scope = sentry_sdk.get_isolation_scope()
+ assert scope._tags["i"] == i
return response.text("ok")
@@ -282,5 +329,136 @@ async def runner():
else:
asyncio.run(runner())
- with configure_scope() as scope:
- assert not scope._tags
+ scope = sentry_sdk.get_isolation_scope()
+ assert not scope._tags
+
+
+class TransactionTestConfig:
+ """
+ Data class to store configurations for each performance transaction test run, including
+ both the inputs and relevant expected results.
+ """
+
+ def __init__(
+ self,
+ integration_args: "Iterable[Optional[Container[int]]]",
+ url: str,
+ expected_status: int,
+ expected_transaction_name: "Optional[str]",
+ expected_source: "Optional[str]" = None,
+ ) -> None:
+ """
+ expected_transaction_name of None indicates we expect to not receive a transaction
+ """
+ self.integration_args = integration_args
+ self.url = url
+ self.expected_status = expected_status
+ self.expected_transaction_name = expected_transaction_name
+ self.expected_source = expected_source
+
+
+@pytest.mark.skipif(
+ not PERFORMANCE_SUPPORTED, reason="Performance not supported on this Sanic version"
+)
+@pytest.mark.parametrize(
+ "test_config",
+ [
+ TransactionTestConfig(
+ # Transaction for successful page load
+ integration_args=(),
+ url="/message",
+ expected_status=200,
+ expected_transaction_name="hi",
+ expected_source=TransactionSource.COMPONENT,
+ ),
+ TransactionTestConfig(
+ # Transaction still recorded when we have an internal server error
+ integration_args=(),
+ url="/500",
+ expected_status=500,
+ expected_transaction_name="fivehundred",
+ expected_source=TransactionSource.COMPONENT,
+ ),
+ TransactionTestConfig(
+ # By default, no transaction when we have a 404 error
+ integration_args=(),
+ url="/404",
+ expected_status=404,
+ expected_transaction_name=None,
+ ),
+ TransactionTestConfig(
+ # With no ignored HTTP statuses, we should get transactions for 404 errors
+ integration_args=(None,),
+ url="/404",
+ expected_status=404,
+ expected_transaction_name="/404",
+ expected_source=TransactionSource.URL,
+ ),
+ TransactionTestConfig(
+ # Transaction can be suppressed for other HTTP statuses, too, by passing config to the integration
+ integration_args=({200},),
+ url="/message",
+ expected_status=200,
+ expected_transaction_name=None,
+ ),
+ ],
+)
+def test_transactions(
+ test_config: "TransactionTestConfig",
+ sentry_init: "Any",
+ app: "Any",
+ capture_events: "Any",
+) -> None:
+ # Init the SanicIntegration with the desired arguments
+ sentry_init(
+ integrations=[SanicIntegration(*test_config.integration_args)],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ # Make request to the desired URL
+ c = get_client(app)
+ with c as client:
+ _, response = client.get(test_config.url)
+ assert response.status == test_config.expected_status
+
+ # Extract the transaction events by inspecting the event types. We should at most have 1 transaction event.
+ transaction_events = [
+ e for e in events if "type" in e and e["type"] == "transaction"
+ ]
+ assert len(transaction_events) <= 1
+
+ # Get the only transaction event, or set to None if there are no transaction events.
+ (transaction_event, *_) = [*transaction_events, None]
+
+ # We should have no transaction event if and only if we expect no transactions
+ assert (transaction_event is None) == (
+ test_config.expected_transaction_name is None
+ )
+
+ # If a transaction was expected, ensure it is correct
+ assert (
+ transaction_event is None
+ or transaction_event["transaction"] == test_config.expected_transaction_name
+ )
+ assert (
+ transaction_event is None
+ or transaction_event["transaction_info"]["source"]
+ == test_config.expected_source
+ )
+
+
+@pytest.mark.skipif(
+ not PERFORMANCE_SUPPORTED, reason="Performance not supported on this Sanic version"
+)
+def test_span_origin(sentry_init, app, capture_events):
+ sentry_init(integrations=[SanicIntegration()], traces_sample_rate=1.0)
+ events = capture_events()
+
+ c = get_client(app)
+ with c as client:
+ client.get("/message?foo=bar")
+
+ (_, event) = events
+
+ assert event["contexts"]["trace"]["origin"] == "auto.http.sanic"
diff --git a/tests/integrations/serverless/test_serverless.py b/tests/integrations/serverless/test_serverless.py
index cc578ff4c4..a0a33e31ec 100644
--- a/tests/integrations/serverless/test_serverless.py
+++ b/tests/integrations/serverless/test_serverless.py
@@ -11,9 +11,7 @@ def test_basic(sentry_init, capture_exceptions, monkeypatch):
@serverless_function
def foo():
- monkeypatch.setattr(
- "sentry_sdk.Hub.current.flush", lambda: flush_calls.append(1)
- )
+ monkeypatch.setattr("sentry_sdk.flush", lambda: flush_calls.append(1))
1 / 0
with pytest.raises(ZeroDivisionError):
@@ -31,7 +29,7 @@ def test_flush_disabled(sentry_init, capture_exceptions, monkeypatch):
flush_calls = []
- monkeypatch.setattr("sentry_sdk.Hub.current.flush", lambda: flush_calls.append(1))
+ monkeypatch.setattr("sentry_sdk.flush", lambda: flush_calls.append(1))
@serverless_function(flush=False)
def foo():
diff --git a/tests/integrations/socket/test_socket.py b/tests/integrations/socket/test_socket.py
index 914ba0bf84..cc109e0968 100644
--- a/tests/integrations/socket/test_socket.py
+++ b/tests/integrations/socket/test_socket.py
@@ -2,6 +2,9 @@
from sentry_sdk import start_transaction
from sentry_sdk.integrations.socket import SocketIntegration
+from tests.conftest import ApproxDict, create_mock_http_server
+
+PORT = create_mock_http_server()
def test_getaddrinfo_trace(sentry_init, capture_events):
@@ -9,17 +12,19 @@ def test_getaddrinfo_trace(sentry_init, capture_events):
events = capture_events()
with start_transaction():
- socket.getaddrinfo("example.com", 443)
+ socket.getaddrinfo("localhost", PORT)
(event,) = events
(span,) = event["spans"]
assert span["op"] == "socket.dns"
- assert span["description"] == "example.com:443"
- assert span["data"] == {
- "host": "example.com",
- "port": 443,
- }
+ assert span["description"] == f"localhost:{PORT}" # noqa: E231
+ assert span["data"] == ApproxDict(
+ {
+ "host": "localhost",
+ "port": PORT,
+ }
+ )
def test_create_connection_trace(sentry_init, capture_events):
@@ -29,23 +34,48 @@ def test_create_connection_trace(sentry_init, capture_events):
events = capture_events()
with start_transaction():
- socket.create_connection(("example.com", 443), timeout, None)
+ socket.create_connection(("localhost", PORT), timeout, None)
(event,) = events
(connect_span, dns_span) = event["spans"]
# as getaddrinfo gets called in create_connection it should also contain a dns span
assert connect_span["op"] == "socket.connection"
- assert connect_span["description"] == "example.com:443"
- assert connect_span["data"] == {
- "address": ["example.com", 443],
- "timeout": timeout,
- "source_address": None,
- }
+ assert connect_span["description"] == f"localhost:{PORT}" # noqa: E231
+ assert connect_span["data"] == ApproxDict(
+ {
+ "address": ["localhost", PORT],
+ "timeout": timeout,
+ "source_address": None,
+ }
+ )
assert dns_span["op"] == "socket.dns"
- assert dns_span["description"] == "example.com:443"
- assert dns_span["data"] == {
- "host": "example.com",
- "port": 443,
- }
+ assert dns_span["description"] == f"localhost:{PORT}" # noqa: E231
+ assert dns_span["data"] == ApproxDict(
+ {
+ "host": "localhost",
+ "port": PORT,
+ }
+ )
+
+
+def test_span_origin(sentry_init, capture_events):
+ sentry_init(
+ integrations=[SocketIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ with start_transaction(name="foo"):
+ socket.create_connection(("localhost", PORT), 1, None)
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+
+ assert event["spans"][0]["op"] == "socket.connection"
+ assert event["spans"][0]["origin"] == "auto.socket.socket"
+
+ assert event["spans"][1]["op"] == "socket.dns"
+ assert event["spans"][1]["origin"] == "auto.socket.socket"
diff --git a/tests/integrations/spark/__init__.py b/tests/integrations/spark/__init__.py
new file mode 100644
index 0000000000..aa6d24a492
--- /dev/null
+++ b/tests/integrations/spark/__init__.py
@@ -0,0 +1,4 @@
+import pytest
+
+pytest.importorskip("pyspark")
+pytest.importorskip("py4j")
diff --git a/tests/integrations/spark/test_spark.py b/tests/integrations/spark/test_spark.py
index 00c0055f12..c5bb70f4d1 100644
--- a/tests/integrations/spark/test_spark.py
+++ b/tests/integrations/spark/test_spark.py
@@ -1,28 +1,43 @@
import pytest
import sys
+from unittest.mock import patch
+
from sentry_sdk.integrations.spark.spark_driver import (
_set_app_properties,
_start_sentry_listener,
SentryListener,
+ SparkIntegration,
)
-
from sentry_sdk.integrations.spark.spark_worker import SparkWorkerIntegration
-
-pytest.importorskip("pyspark")
-pytest.importorskip("py4j")
-
-from pyspark import SparkContext
+from pyspark import SparkConf, SparkContext
from py4j.protocol import Py4JJavaError
+
################
# DRIVER TESTS #
################
-def test_set_app_properties():
- spark_context = SparkContext(appName="Testing123")
+@pytest.fixture(scope="function")
+def sentry_init_with_reset(sentry_init):
+ from sentry_sdk.integrations import _processed_integrations
+
+ yield lambda: sentry_init(integrations=[SparkIntegration()])
+ _processed_integrations.discard("spark")
+
+
+@pytest.fixture(scope="session")
+def create_spark_context():
+ conf = SparkConf().set("spark.driver.bindAddress", "127.0.0.1")
+ sc = SparkContext(conf=conf, appName="Testing123")
+ yield lambda: sc
+ sc.stop()
+
+
+def test_set_app_properties(create_spark_context):
+ spark_context = create_spark_context()
_set_app_properties()
assert spark_context.getLocalProperty("sentry_app_name") == "Testing123"
@@ -33,9 +48,8 @@ def test_set_app_properties():
)
-def test_start_sentry_listener():
- spark_context = SparkContext.getOrCreate()
-
+def test_start_sentry_listener(create_spark_context):
+ spark_context = create_spark_context()
gateway = spark_context._gateway
assert gateway._callback_server is None
@@ -44,90 +58,179 @@ def test_start_sentry_listener():
assert gateway._callback_server is not None
-@pytest.fixture
-def sentry_listener(monkeypatch):
- class MockHub:
- def __init__(self):
- self.args = []
- self.kwargs = {}
+@patch("sentry_sdk.integrations.spark.spark_driver._patch_spark_context_init")
+def test_initialize_spark_integration_before_spark_context_init(
+ mock_patch_spark_context_init,
+ sentry_init_with_reset,
+):
+ # As we are using the same SparkContext connection for the whole session,
+ # we clean it during this test.
+ original_context = SparkContext._active_spark_context
+ SparkContext._active_spark_context = None
+
+ try:
+ sentry_init_with_reset()
+ mock_patch_spark_context_init.assert_called_once()
+ finally:
+ # Restore the original one.
+ SparkContext._active_spark_context = original_context
+
+
+@patch("sentry_sdk.integrations.spark.spark_driver._activate_integration")
+def test_initialize_spark_integration_after_spark_context_init(
+ mock_activate_integration,
+ create_spark_context,
+ sentry_init_with_reset,
+):
+ create_spark_context()
+ sentry_init_with_reset()
- def add_breadcrumb(self, *args, **kwargs):
- self.args = args
- self.kwargs = kwargs
+ mock_activate_integration.assert_called_once()
- listener = SentryListener()
- mock_hub = MockHub()
- monkeypatch.setattr(listener, "hub", mock_hub)
+@pytest.fixture
+def sentry_listener():
+ listener = SentryListener()
- return listener, mock_hub
+ return listener
def test_sentry_listener_on_job_start(sentry_listener):
- listener, mock_hub = sentry_listener
+ listener = sentry_listener
+ with patch.object(listener, "_add_breadcrumb") as mock_add_breadcrumb:
- class MockJobStart:
- def jobId(self): # noqa: N802
- return "sample-job-id-start"
+ class MockJobStart:
+ def jobId(self): # noqa: N802
+ return "sample-job-id-start"
- mock_job_start = MockJobStart()
- listener.onJobStart(mock_job_start)
+ mock_job_start = MockJobStart()
+ listener.onJobStart(mock_job_start)
- assert mock_hub.kwargs["level"] == "info"
- assert "sample-job-id-start" in mock_hub.kwargs["message"]
+ mock_add_breadcrumb.assert_called_once()
+ mock_hub = mock_add_breadcrumb.call_args
+
+ assert mock_hub.kwargs["level"] == "info"
+ assert "sample-job-id-start" in mock_hub.kwargs["message"]
@pytest.mark.parametrize(
"job_result, level", [("JobSucceeded", "info"), ("JobFailed", "warning")]
)
def test_sentry_listener_on_job_end(sentry_listener, job_result, level):
- listener, mock_hub = sentry_listener
+ listener = sentry_listener
+ with patch.object(listener, "_add_breadcrumb") as mock_add_breadcrumb:
+
+ class MockJobResult:
+ def toString(self): # noqa: N802
+ return job_result
- class MockJobResult:
- def toString(self): # noqa: N802
- return job_result
+ class MockJobEnd:
+ def jobId(self): # noqa: N802
+ return "sample-job-id-end"
- class MockJobEnd:
- def jobId(self): # noqa: N802
- return "sample-job-id-end"
+ def jobResult(self): # noqa: N802
+ result = MockJobResult()
+ return result
- def jobResult(self): # noqa: N802
- result = MockJobResult()
- return result
+ mock_job_end = MockJobEnd()
+ listener.onJobEnd(mock_job_end)
- mock_job_end = MockJobEnd()
- listener.onJobEnd(mock_job_end)
+ mock_add_breadcrumb.assert_called_once()
+ mock_hub = mock_add_breadcrumb.call_args
- assert mock_hub.kwargs["level"] == level
- assert mock_hub.kwargs["data"]["result"] == job_result
- assert "sample-job-id-end" in mock_hub.kwargs["message"]
+ assert mock_hub.kwargs["level"] == level
+ assert mock_hub.kwargs["data"]["result"] == job_result
+ assert "sample-job-id-end" in mock_hub.kwargs["message"]
def test_sentry_listener_on_stage_submitted(sentry_listener):
- listener, mock_hub = sentry_listener
+ listener = sentry_listener
+ with patch.object(listener, "_add_breadcrumb") as mock_add_breadcrumb:
+
+ class StageInfo:
+ def stageId(self): # noqa: N802
+ return "sample-stage-id-submit"
+
+ def name(self):
+ return "run-job"
+
+ def attemptId(self): # noqa: N802
+ return 14
+
+ class MockStageSubmitted:
+ def stageInfo(self): # noqa: N802
+ stageinf = StageInfo()
+ return stageinf
+
+ mock_stage_submitted = MockStageSubmitted()
+ listener.onStageSubmitted(mock_stage_submitted)
+
+ mock_add_breadcrumb.assert_called_once()
+ mock_hub = mock_add_breadcrumb.call_args
+
+ assert mock_hub.kwargs["level"] == "info"
+ assert "sample-stage-id-submit" in mock_hub.kwargs["message"]
+ assert mock_hub.kwargs["data"]["attemptId"] == 14
+ assert mock_hub.kwargs["data"]["name"] == "run-job"
+
+
+def test_sentry_listener_on_stage_submitted_no_attempt_id(sentry_listener):
+ listener = sentry_listener
+ with patch.object(listener, "_add_breadcrumb") as mock_add_breadcrumb:
+
+ class StageInfo:
+ def stageId(self): # noqa: N802
+ return "sample-stage-id-submit"
+
+ def name(self):
+ return "run-job"
+
+ def attemptNumber(self): # noqa: N802
+ return 14
- class StageInfo:
- def stageId(self): # noqa: N802
- return "sample-stage-id-submit"
+ class MockStageSubmitted:
+ def stageInfo(self): # noqa: N802
+ stageinf = StageInfo()
+ return stageinf
+
+ mock_stage_submitted = MockStageSubmitted()
+ listener.onStageSubmitted(mock_stage_submitted)
- def name(self):
- return "run-job"
+ mock_add_breadcrumb.assert_called_once()
+ mock_hub = mock_add_breadcrumb.call_args
- def attemptId(self): # noqa: N802
- return 14
+ assert mock_hub.kwargs["level"] == "info"
+ assert "sample-stage-id-submit" in mock_hub.kwargs["message"]
+ assert mock_hub.kwargs["data"]["attemptId"] == 14
+ assert mock_hub.kwargs["data"]["name"] == "run-job"
- class MockStageSubmitted:
- def stageInfo(self): # noqa: N802
- stageinf = StageInfo()
- return stageinf
- mock_stage_submitted = MockStageSubmitted()
- listener.onStageSubmitted(mock_stage_submitted)
+def test_sentry_listener_on_stage_submitted_no_attempt_id_or_number(sentry_listener):
+ listener = sentry_listener
+ with patch.object(listener, "_add_breadcrumb") as mock_add_breadcrumb:
- assert mock_hub.kwargs["level"] == "info"
- assert "sample-stage-id-submit" in mock_hub.kwargs["message"]
- assert mock_hub.kwargs["data"]["attemptId"] == 14
- assert mock_hub.kwargs["data"]["name"] == "run-job"
+ class StageInfo:
+ def stageId(self): # noqa: N802
+ return "sample-stage-id-submit"
+
+ def name(self):
+ return "run-job"
+
+ class MockStageSubmitted:
+ def stageInfo(self): # noqa: N802
+ stageinf = StageInfo()
+ return stageinf
+
+ mock_stage_submitted = MockStageSubmitted()
+ listener.onStageSubmitted(mock_stage_submitted)
+
+ mock_add_breadcrumb.assert_called_once()
+ mock_hub = mock_add_breadcrumb.call_args
+
+ assert mock_hub.kwargs["level"] == "info"
+ assert "sample-stage-id-submit" in mock_hub.kwargs["message"]
+ assert "attemptId" not in mock_hub.kwargs["data"]
+ assert mock_hub.kwargs["data"]["name"] == "run-job"
@pytest.fixture
@@ -169,31 +272,37 @@ def stageInfo(self): # noqa: N802
def test_sentry_listener_on_stage_completed_success(
sentry_listener, get_mock_stage_completed
):
- listener, mock_hub = sentry_listener
+ listener = sentry_listener
+ with patch.object(listener, "_add_breadcrumb") as mock_add_breadcrumb:
+ mock_stage_completed = get_mock_stage_completed(failure_reason=False)
+ listener.onStageCompleted(mock_stage_completed)
- mock_stage_completed = get_mock_stage_completed(failure_reason=False)
- listener.onStageCompleted(mock_stage_completed)
+ mock_add_breadcrumb.assert_called_once()
+ mock_hub = mock_add_breadcrumb.call_args
- assert mock_hub.kwargs["level"] == "info"
- assert "sample-stage-id-submit" in mock_hub.kwargs["message"]
- assert mock_hub.kwargs["data"]["attemptId"] == 14
- assert mock_hub.kwargs["data"]["name"] == "run-job"
- assert "reason" not in mock_hub.kwargs["data"]
+ assert mock_hub.kwargs["level"] == "info"
+ assert "sample-stage-id-submit" in mock_hub.kwargs["message"]
+ assert mock_hub.kwargs["data"]["attemptId"] == 14
+ assert mock_hub.kwargs["data"]["name"] == "run-job"
+ assert "reason" not in mock_hub.kwargs["data"]
def test_sentry_listener_on_stage_completed_failure(
sentry_listener, get_mock_stage_completed
):
- listener, mock_hub = sentry_listener
-
- mock_stage_completed = get_mock_stage_completed(failure_reason=True)
- listener.onStageCompleted(mock_stage_completed)
-
- assert mock_hub.kwargs["level"] == "warning"
- assert "sample-stage-id-submit" in mock_hub.kwargs["message"]
- assert mock_hub.kwargs["data"]["attemptId"] == 14
- assert mock_hub.kwargs["data"]["name"] == "run-job"
- assert mock_hub.kwargs["data"]["reason"] == "failure-reason"
+ listener = sentry_listener
+ with patch.object(listener, "_add_breadcrumb") as mock_add_breadcrumb:
+ mock_stage_completed = get_mock_stage_completed(failure_reason=True)
+ listener.onStageCompleted(mock_stage_completed)
+
+ mock_add_breadcrumb.assert_called_once()
+ mock_hub = mock_add_breadcrumb.call_args
+
+ assert mock_hub.kwargs["level"] == "warning"
+ assert "sample-stage-id-submit" in mock_hub.kwargs["message"]
+ assert mock_hub.kwargs["data"]["attemptId"] == 14
+ assert mock_hub.kwargs["data"]["name"] == "run-job"
+ assert mock_hub.kwargs["data"]["reason"] == "failure-reason"
################
diff --git a/tests/integrations/sqlalchemy/__init__.py b/tests/integrations/sqlalchemy/__init__.py
index b430bf6d43..33c43a6872 100644
--- a/tests/integrations/sqlalchemy/__init__.py
+++ b/tests/integrations/sqlalchemy/__init__.py
@@ -1,3 +1,9 @@
+import os
+import sys
import pytest
pytest.importorskip("sqlalchemy")
+
+# Load `sqlalchemy_helpers` into the module search path to test query source path names relative to module. See
+# `test_query_source_with_module_in_search_path`
+sys.path.insert(0, os.path.join(os.path.dirname(__file__)))
diff --git a/tests/integrations/sqlalchemy/sqlalchemy_helpers/__init__.py b/tests/integrations/sqlalchemy/sqlalchemy_helpers/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/integrations/sqlalchemy/sqlalchemy_helpers/helpers.py b/tests/integrations/sqlalchemy/sqlalchemy_helpers/helpers.py
new file mode 100644
index 0000000000..ca65a88d25
--- /dev/null
+++ b/tests/integrations/sqlalchemy/sqlalchemy_helpers/helpers.py
@@ -0,0 +1,7 @@
+def add_model_to_session(model, session):
+ session.add(model)
+ session.commit()
+
+
+def query_first_model_from_session(model_klass, session):
+ return session.query(model_klass).first()
diff --git a/tests/integrations/sqlalchemy/test_sqlalchemy.py b/tests/integrations/sqlalchemy/test_sqlalchemy.py
index d45ea36a19..7c7ce3d845 100644
--- a/tests/integrations/sqlalchemy/test_sqlalchemy.py
+++ b/tests/integrations/sqlalchemy/test_sqlalchemy.py
@@ -1,15 +1,21 @@
-import sys
-import pytest
+import os
+from datetime import datetime
+from unittest import mock
+import pytest
from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker
+from sqlalchemy import text
-from sentry_sdk import capture_message, start_transaction, configure_scope
+import sentry_sdk
+from sentry_sdk import capture_message, start_transaction
+from sentry_sdk.consts import DEFAULT_MAX_VALUE_LENGTH, SPANDATA
from sentry_sdk.integrations.sqlalchemy import SqlalchemyIntegration
from sentry_sdk.serializer import MAX_EVENT_BYTES
-from sentry_sdk.utils import json_dumps, MAX_STRING_LENGTH
+from sentry_sdk.tracing_utils import record_sql_queries
+from sentry_sdk.utils import json_dumps
def test_orm_queries(sentry_init, capture_events):
@@ -34,7 +40,9 @@ class Address(Base):
person_id = Column(Integer, ForeignKey("person.id"))
person = relationship(Person)
- engine = create_engine("sqlite:///:memory:")
+ engine = create_engine(
+ "sqlite:///:memory:", connect_args={"check_same_thread": False}
+ )
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine) # noqa: N806
@@ -70,11 +78,7 @@ class Address(Base):
]
-@pytest.mark.skipif(
- sys.version_info < (3,), reason="This sqla usage seems to be broken on Py2"
-)
def test_transactions(sentry_init, capture_events, render_span_tree):
-
sentry_init(
integrations=[SqlalchemyIntegration()],
_experiments={"record_sql_params": True},
@@ -98,7 +102,9 @@ class Address(Base):
person_id = Column(Integer, ForeignKey("person.id"))
person = relationship(Person)
- engine = create_engine("sqlite:///:memory:")
+ engine = create_engine(
+ "sqlite:///:memory:", connect_args={"check_same_thread": False}
+ )
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine) # noqa: N806
@@ -119,6 +125,13 @@ class Address(Base):
(event,) = events
+ for span in event["spans"]:
+ assert span["data"][SPANDATA.DB_SYSTEM] == "sqlite"
+ assert span["data"][SPANDATA.DB_DRIVER_NAME] == "pysqlite"
+ assert span["data"][SPANDATA.DB_NAME] == ":memory:"
+ assert SPANDATA.SERVER_ADDRESS not in span["data"]
+ assert SPANDATA.SERVER_PORT not in span["data"]
+
assert (
render_span_tree(event)
== """\
@@ -139,6 +152,62 @@ class Address(Base):
)
+def test_transactions_no_engine_url(sentry_init, capture_events):
+ sentry_init(
+ integrations=[SqlalchemyIntegration()],
+ _experiments={"record_sql_params": True},
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ Base = declarative_base() # noqa: N806
+
+ class Person(Base):
+ __tablename__ = "person"
+ id = Column(Integer, primary_key=True)
+ name = Column(String(250), nullable=False)
+
+ class Address(Base):
+ __tablename__ = "address"
+ id = Column(Integer, primary_key=True)
+ street_name = Column(String(250))
+ street_number = Column(String(250))
+ post_code = Column(String(250), nullable=False)
+ person_id = Column(Integer, ForeignKey("person.id"))
+ person = relationship(Person)
+
+ engine = create_engine(
+ "sqlite:///:memory:", connect_args={"check_same_thread": False}
+ )
+ engine.url = None
+ Base.metadata.create_all(engine)
+
+ Session = sessionmaker(bind=engine) # noqa: N806
+ session = Session()
+
+ with start_transaction(name="test_transaction", sampled=True):
+ with session.begin_nested():
+ session.query(Person).first()
+
+ for _ in range(2):
+ with pytest.raises(IntegrityError):
+ with session.begin_nested():
+ session.add(Person(id=1, name="bob"))
+ session.add(Person(id=1, name="bob"))
+
+ with session.begin_nested():
+ session.query(Person).first()
+
+ (event,) = events
+
+ for span in event["spans"]:
+ assert span["data"][SPANDATA.DB_SYSTEM] == "sqlite"
+ assert span["data"][SPANDATA.DB_DRIVER_NAME] == "pysqlite"
+ assert SPANDATA.DB_NAME not in span["data"]
+ assert SPANDATA.SERVER_ADDRESS not in span["data"]
+ assert SPANDATA.SERVER_PORT not in span["data"]
+
+
def test_long_sql_query_preserved(sentry_init, capture_events):
sentry_init(
traces_sample_rate=1,
@@ -146,10 +215,12 @@ def test_long_sql_query_preserved(sentry_init, capture_events):
)
events = capture_events()
- engine = create_engine("sqlite:///:memory:")
+ engine = create_engine(
+ "sqlite:///:memory:", connect_args={"check_same_thread": False}
+ )
with start_transaction(name="test"):
with engine.connect() as con:
- con.execute(" UNION ".join("SELECT {}".format(i) for i in range(100)))
+ con.execute(text(" UNION ".join("SELECT {}".format(i) for i in range(100))))
(event,) = events
description = event["spans"][0]["description"]
@@ -164,20 +235,24 @@ def test_large_event_not_truncated(sentry_init, capture_events):
)
events = capture_events()
- long_str = "x" * (MAX_STRING_LENGTH + 10)
+ long_str = "x" * (DEFAULT_MAX_VALUE_LENGTH + 10)
- with configure_scope() as scope:
+ scope = sentry_sdk.get_isolation_scope()
- @scope.add_event_processor
- def processor(event, hint):
- event["message"] = long_str
- return event
+ @scope.add_event_processor
+ def processor(event, hint):
+ event["message"] = long_str
+ return event
- engine = create_engine("sqlite:///:memory:")
+ engine = create_engine(
+ "sqlite:///:memory:", connect_args={"check_same_thread": False}
+ )
with start_transaction(name="test"):
with engine.connect() as con:
for _ in range(1500):
- con.execute(" UNION ".join("SELECT {}".format(i) for i in range(100)))
+ con.execute(
+ text(" UNION ".join("SELECT {}".format(i) for i in range(100)))
+ )
(event,) = events
@@ -198,9 +273,427 @@ def processor(event, hint):
assert description.endswith("SELECT 98 UNION SELECT 99")
# Smoke check that truncation of other fields has not changed.
- assert len(event["message"]) == MAX_STRING_LENGTH
+ assert len(event["message"]) == DEFAULT_MAX_VALUE_LENGTH
# The _meta for other truncated fields should be there as well.
assert event["_meta"]["message"] == {
- "": {"len": 1034, "rem": [["!limit", "x", 1021, 1024]]}
+ "": {
+ "len": DEFAULT_MAX_VALUE_LENGTH + 10,
+ "rem": [
+ ["!limit", "x", DEFAULT_MAX_VALUE_LENGTH - 3, DEFAULT_MAX_VALUE_LENGTH]
+ ],
+ }
+ }
+
+
+def test_engine_name_not_string(sentry_init):
+ sentry_init(
+ integrations=[SqlalchemyIntegration()],
+ )
+
+ engine = create_engine(
+ "sqlite:///:memory:", connect_args={"check_same_thread": False}
+ )
+ engine.dialect.name = b"sqlite"
+
+ with engine.connect() as con:
+ con.execute(text("SELECT 0"))
+
+
+def test_query_source_disabled(sentry_init, capture_events):
+ sentry_options = {
+ "integrations": [SqlalchemyIntegration()],
+ "traces_sample_rate": 1.0,
+ "enable_db_query_source": False,
+ "db_query_source_threshold_ms": 0,
}
+
+ sentry_init(**sentry_options)
+
+ events = capture_events()
+
+ with start_transaction(name="test_transaction", sampled=True):
+ Base = declarative_base() # noqa: N806
+
+ class Person(Base):
+ __tablename__ = "person"
+ id = Column(Integer, primary_key=True)
+ name = Column(String(250), nullable=False)
+
+ engine = create_engine(
+ "sqlite:///:memory:", connect_args={"check_same_thread": False}
+ )
+ Base.metadata.create_all(engine)
+
+ Session = sessionmaker(bind=engine) # noqa: N806
+ session = Session()
+
+ bob = Person(name="Bob")
+ session.add(bob)
+
+ assert session.query(Person).first() == bob
+
+ (event,) = events
+
+ for span in event["spans"]:
+ if span.get("op") == "db" and span.get("description").startswith(
+ "SELECT person"
+ ):
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO not in data
+ assert SPANDATA.CODE_NAMESPACE not in data
+ assert SPANDATA.CODE_FILEPATH not in data
+ assert SPANDATA.CODE_FUNCTION not in data
+ break
+ else:
+ raise AssertionError("No db span found")
+
+
+@pytest.mark.parametrize("enable_db_query_source", [None, True])
+def test_query_source_enabled(sentry_init, capture_events, enable_db_query_source):
+ sentry_options = {
+ "integrations": [SqlalchemyIntegration()],
+ "traces_sample_rate": 1.0,
+ "db_query_source_threshold_ms": 0,
+ }
+ if enable_db_query_source is not None:
+ sentry_options["enable_db_query_source"] = enable_db_query_source
+
+ sentry_init(**sentry_options)
+
+ events = capture_events()
+
+ with start_transaction(name="test_transaction", sampled=True):
+ Base = declarative_base() # noqa: N806
+
+ class Person(Base):
+ __tablename__ = "person"
+ id = Column(Integer, primary_key=True)
+ name = Column(String(250), nullable=False)
+
+ engine = create_engine(
+ "sqlite:///:memory:", connect_args={"check_same_thread": False}
+ )
+ Base.metadata.create_all(engine)
+
+ Session = sessionmaker(bind=engine) # noqa: N806
+ session = Session()
+
+ bob = Person(name="Bob")
+ session.add(bob)
+
+ assert session.query(Person).first() == bob
+
+ (event,) = events
+
+ for span in event["spans"]:
+ if span.get("op") == "db" and span.get("description").startswith(
+ "SELECT person"
+ ):
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO in data
+ assert SPANDATA.CODE_NAMESPACE in data
+ assert SPANDATA.CODE_FILEPATH in data
+ assert SPANDATA.CODE_FUNCTION in data
+ break
+ else:
+ raise AssertionError("No db span found")
+
+
+def test_query_source(sentry_init, capture_events):
+ sentry_init(
+ integrations=[SqlalchemyIntegration()],
+ traces_sample_rate=1.0,
+ enable_db_query_source=True,
+ db_query_source_threshold_ms=0,
+ )
+ events = capture_events()
+
+ with start_transaction(name="test_transaction", sampled=True):
+ Base = declarative_base() # noqa: N806
+
+ class Person(Base):
+ __tablename__ = "person"
+ id = Column(Integer, primary_key=True)
+ name = Column(String(250), nullable=False)
+
+ engine = create_engine(
+ "sqlite:///:memory:", connect_args={"check_same_thread": False}
+ )
+ Base.metadata.create_all(engine)
+
+ Session = sessionmaker(bind=engine) # noqa: N806
+ session = Session()
+
+ bob = Person(name="Bob")
+ session.add(bob)
+
+ assert session.query(Person).first() == bob
+
+ (event,) = events
+
+ for span in event["spans"]:
+ if span.get("op") == "db" and span.get("description").startswith(
+ "SELECT person"
+ ):
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO in data
+ assert SPANDATA.CODE_NAMESPACE in data
+ assert SPANDATA.CODE_FILEPATH in data
+ assert SPANDATA.CODE_FUNCTION in data
+
+ assert type(data.get(SPANDATA.CODE_LINENO)) == int
+ assert data.get(SPANDATA.CODE_LINENO) > 0
+ assert (
+ data.get(SPANDATA.CODE_NAMESPACE)
+ == "tests.integrations.sqlalchemy.test_sqlalchemy"
+ )
+ assert data.get(SPANDATA.CODE_FILEPATH).endswith(
+ "tests/integrations/sqlalchemy/test_sqlalchemy.py"
+ )
+
+ is_relative_path = data.get(SPANDATA.CODE_FILEPATH)[0] != os.sep
+ assert is_relative_path
+
+ assert data.get(SPANDATA.CODE_FUNCTION) == "test_query_source"
+ break
+ else:
+ raise AssertionError("No db span found")
+
+
+def test_query_source_with_module_in_search_path(sentry_init, capture_events):
+ """
+ Test that query source is relative to the path of the module it ran in
+ """
+ sentry_init(
+ integrations=[SqlalchemyIntegration()],
+ traces_sample_rate=1.0,
+ enable_db_query_source=True,
+ db_query_source_threshold_ms=0,
+ )
+ events = capture_events()
+
+ from sqlalchemy_helpers.helpers import (
+ add_model_to_session,
+ query_first_model_from_session,
+ )
+
+ with start_transaction(name="test_transaction", sampled=True):
+ Base = declarative_base() # noqa: N806
+
+ class Person(Base):
+ __tablename__ = "person"
+ id = Column(Integer, primary_key=True)
+ name = Column(String(250), nullable=False)
+
+ engine = create_engine(
+ "sqlite:///:memory:", connect_args={"check_same_thread": False}
+ )
+ Base.metadata.create_all(engine)
+
+ Session = sessionmaker(bind=engine) # noqa: N806
+ session = Session()
+
+ bob = Person(name="Bob")
+
+ add_model_to_session(bob, session)
+
+ assert query_first_model_from_session(Person, session) == bob
+
+ (event,) = events
+
+ for span in event["spans"]:
+ if span.get("op") == "db" and span.get("description").startswith(
+ "SELECT person"
+ ):
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO in data
+ assert SPANDATA.CODE_NAMESPACE in data
+ assert SPANDATA.CODE_FILEPATH in data
+ assert SPANDATA.CODE_FUNCTION in data
+
+ assert type(data.get(SPANDATA.CODE_LINENO)) == int
+ assert data.get(SPANDATA.CODE_LINENO) > 0
+ assert data.get(SPANDATA.CODE_NAMESPACE) == "sqlalchemy_helpers.helpers"
+ assert data.get(SPANDATA.CODE_FILEPATH) == "sqlalchemy_helpers/helpers.py"
+
+ is_relative_path = data.get(SPANDATA.CODE_FILEPATH)[0] != os.sep
+ assert is_relative_path
+
+ assert data.get(SPANDATA.CODE_FUNCTION) == "query_first_model_from_session"
+ break
+ else:
+ raise AssertionError("No db span found")
+
+
+def test_no_query_source_if_duration_too_short(sentry_init, capture_events):
+ sentry_init(
+ integrations=[SqlalchemyIntegration()],
+ traces_sample_rate=1.0,
+ enable_db_query_source=True,
+ db_query_source_threshold_ms=100,
+ )
+ events = capture_events()
+
+ with start_transaction(name="test_transaction", sampled=True):
+ Base = declarative_base() # noqa: N806
+
+ class Person(Base):
+ __tablename__ = "person"
+ id = Column(Integer, primary_key=True)
+ name = Column(String(250), nullable=False)
+
+ engine = create_engine(
+ "sqlite:///:memory:", connect_args={"check_same_thread": False}
+ )
+ Base.metadata.create_all(engine)
+
+ Session = sessionmaker(bind=engine) # noqa: N806
+ session = Session()
+
+ bob = Person(name="Bob")
+ session.add(bob)
+
+ class fake_record_sql_queries: # noqa: N801
+ def __init__(self, *args, **kwargs):
+ with record_sql_queries(*args, **kwargs) as span:
+ self.span = span
+
+ self.span.start_timestamp = datetime(2024, 1, 1, microsecond=0)
+ self.span.timestamp = datetime(2024, 1, 1, microsecond=99999)
+
+ def __enter__(self):
+ return self.span
+
+ def __exit__(self, type, value, traceback):
+ pass
+
+ with mock.patch(
+ "sentry_sdk.integrations.sqlalchemy.record_sql_queries",
+ fake_record_sql_queries,
+ ):
+ assert session.query(Person).first() == bob
+
+ (event,) = events
+
+ for span in event["spans"]:
+ if span.get("op") == "db" and span.get("description").startswith(
+ "SELECT person"
+ ):
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO not in data
+ assert SPANDATA.CODE_NAMESPACE not in data
+ assert SPANDATA.CODE_FILEPATH not in data
+ assert SPANDATA.CODE_FUNCTION not in data
+
+ break
+ else:
+ raise AssertionError("No db span found")
+
+
+def test_query_source_if_duration_over_threshold(sentry_init, capture_events):
+ sentry_init(
+ integrations=[SqlalchemyIntegration()],
+ traces_sample_rate=1.0,
+ enable_db_query_source=True,
+ db_query_source_threshold_ms=100,
+ )
+ events = capture_events()
+
+ with start_transaction(name="test_transaction", sampled=True):
+ Base = declarative_base() # noqa: N806
+
+ class Person(Base):
+ __tablename__ = "person"
+ id = Column(Integer, primary_key=True)
+ name = Column(String(250), nullable=False)
+
+ engine = create_engine(
+ "sqlite:///:memory:", connect_args={"check_same_thread": False}
+ )
+ Base.metadata.create_all(engine)
+
+ Session = sessionmaker(bind=engine) # noqa: N806
+ session = Session()
+
+ bob = Person(name="Bob")
+ session.add(bob)
+
+ class fake_record_sql_queries: # noqa: N801
+ def __init__(self, *args, **kwargs):
+ with record_sql_queries(*args, **kwargs) as span:
+ self.span = span
+
+ self.span.start_timestamp = datetime(2024, 1, 1, microsecond=0)
+ self.span.timestamp = datetime(2024, 1, 1, microsecond=101000)
+
+ def __enter__(self):
+ return self.span
+
+ def __exit__(self, type, value, traceback):
+ pass
+
+ with mock.patch(
+ "sentry_sdk.integrations.sqlalchemy.record_sql_queries",
+ fake_record_sql_queries,
+ ):
+ assert session.query(Person).first() == bob
+
+ (event,) = events
+
+ for span in event["spans"]:
+ if span.get("op") == "db" and span.get("description").startswith(
+ "SELECT person"
+ ):
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO in data
+ assert SPANDATA.CODE_NAMESPACE in data
+ assert SPANDATA.CODE_FILEPATH in data
+ assert SPANDATA.CODE_FUNCTION in data
+
+ assert type(data.get(SPANDATA.CODE_LINENO)) == int
+ assert data.get(SPANDATA.CODE_LINENO) > 0
+ assert (
+ data.get(SPANDATA.CODE_NAMESPACE)
+ == "tests.integrations.sqlalchemy.test_sqlalchemy"
+ )
+ assert data.get(SPANDATA.CODE_FILEPATH).endswith(
+ "tests/integrations/sqlalchemy/test_sqlalchemy.py"
+ )
+
+ is_relative_path = data.get(SPANDATA.CODE_FILEPATH)[0] != os.sep
+ assert is_relative_path
+
+ assert (
+ data.get(SPANDATA.CODE_FUNCTION)
+ == "test_query_source_if_duration_over_threshold"
+ )
+ break
+ else:
+ raise AssertionError("No db span found")
+
+
+def test_span_origin(sentry_init, capture_events):
+ sentry_init(
+ integrations=[SqlalchemyIntegration()],
+ traces_sample_rate=1.0,
+ )
+ events = capture_events()
+
+ engine = create_engine(
+ "sqlite:///:memory:", connect_args={"check_same_thread": False}
+ )
+ with start_transaction(name="foo"):
+ with engine.connect() as con:
+ con.execute(text("SELECT 0"))
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+ assert event["spans"][0]["origin"] == "auto.db.sqlalchemy"
diff --git a/tests/integrations/starlette/templates/trace_meta.html b/tests/integrations/starlette/templates/trace_meta.html
new file mode 100644
index 0000000000..139fd16101
--- /dev/null
+++ b/tests/integrations/starlette/templates/trace_meta.html
@@ -0,0 +1 @@
+{{ sentry_trace_meta }}
diff --git a/tests/integrations/starlette/test_starlette.py b/tests/integrations/starlette/test_starlette.py
index 03cb270049..801cd53bf4 100644
--- a/tests/integrations/starlette/test_starlette.py
+++ b/tests/integrations/starlette/test_starlette.py
@@ -2,37 +2,39 @@
import base64
import functools
import json
+import logging
import os
+import re
import threading
+import warnings
+from unittest import mock
import pytest
-from sentry_sdk import last_event_id, capture_exception
+from sentry_sdk import capture_message, get_baggage, get_traceparent
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
-
-try:
- from unittest import mock # python 3.3 and above
-except ImportError:
- import mock # python < 3.3
-
-from sentry_sdk import capture_message
from sentry_sdk.integrations.starlette import (
StarletteIntegration,
StarletteRequestExtractor,
)
+from sentry_sdk.utils import parse_version
-starlette = pytest.importorskip("starlette")
+import starlette
from starlette.authentication import (
AuthCredentials,
AuthenticationBackend,
AuthenticationError,
SimpleUser,
)
+from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
+from starlette.middleware.trustedhost import TrustedHostMiddleware
from starlette.testclient import TestClient
+from tests.integrations.conftest import parametrize_test_configurable_status_codes
+
-STARLETTE_VERSION = tuple([int(x) for x in starlette.__version__.split(".")])
+STARLETTE_VERSION = parse_version(starlette.__version__)
PICTURE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "photo.jpg")
@@ -61,7 +63,6 @@
starlette.datastructures.UploadFile(
filename="photo.jpg",
file=open(PICTURE, "rb"),
- content_type="image/jpeg",
),
),
]
@@ -93,7 +94,15 @@ async def _mock_receive(msg):
return msg
+from starlette.templating import Jinja2Templates
+
+
def starlette_app_factory(middleware=None, debug=True):
+ template_dir = os.path.join(
+ os.getcwd(), "tests", "integrations", "starlette", "templates"
+ )
+ templates = Jinja2Templates(directory=template_dir)
+
async def _homepage(request):
1 / 0
return starlette.responses.JSONResponse({"status": "ok"})
@@ -105,6 +114,9 @@ async def _message(request):
capture_message("hi")
return starlette.responses.JSONResponse({"status": "ok"})
+ async def _nomessage(request):
+ return starlette.responses.JSONResponse({"status": "ok"})
+
async def _message_with_id(request):
capture_message("hi")
return starlette.responses.JSONResponse({"status": "ok"})
@@ -125,15 +137,43 @@ async def _thread_ids_async(request):
}
)
+ async def _render_template(request):
+ capture_message(get_traceparent() + "\n" + get_baggage())
+
+ template_context = {
+ "request": request,
+ "msg": "Hello Template World!",
+ }
+ if STARLETTE_VERSION >= (1,):
+ return templates.TemplateResponse(
+ request, "trace_meta.html", template_context
+ )
+ else:
+ return templates.TemplateResponse("trace_meta.html", template_context)
+
+ all_methods = [
+ "CONNECT",
+ "DELETE",
+ "GET",
+ "HEAD",
+ "OPTIONS",
+ "PATCH",
+ "POST",
+ "PUT",
+ "TRACE",
+ ]
+
app = starlette.applications.Starlette(
debug=debug,
routes=[
starlette.routing.Route("/some_url", _homepage),
starlette.routing.Route("/custom_error", _custom_error),
starlette.routing.Route("/message", _message),
+ starlette.routing.Route("/nomessage", _nomessage, methods=all_methods),
starlette.routing.Route("/message/{message_id}", _message_with_id),
starlette.routing.Route("/sync/thread_ids", _thread_ids_sync),
starlette.routing.Route("/async/thread_ids", _thread_ids_async),
+ starlette.routing.Route("/render_template", _render_template),
],
middleware=middleware,
)
@@ -202,6 +242,12 @@ async def do_stuff(message):
await self.app(scope, receive, do_stuff)
+class SampleMiddlewareWithArgs(Middleware):
+ def __init__(self, app, bla=None):
+ self.app = app
+ self.bla = bla
+
+
class SampleReceiveSendMiddleware:
def __init__(self, app):
self.app = app
@@ -242,7 +288,7 @@ async def my_send(*args, **kwargs):
@pytest.mark.asyncio
-async def test_starlettrequestextractor_content_length(sentry_init):
+async def test_starletterequestextractor_content_length(sentry_init):
scope = SCOPE.copy()
scope["headers"] = [
[b"content-length", str(len(json.dumps(BODY_JSON))).encode()],
@@ -254,7 +300,7 @@ async def test_starlettrequestextractor_content_length(sentry_init):
@pytest.mark.asyncio
-async def test_starlettrequestextractor_cookies(sentry_init):
+async def test_starletterequestextractor_cookies(sentry_init):
starlette_request = starlette.requests.Request(SCOPE)
extractor = StarletteRequestExtractor(starlette_request)
@@ -265,7 +311,7 @@ async def test_starlettrequestextractor_cookies(sentry_init):
@pytest.mark.asyncio
-async def test_starlettrequestextractor_json(sentry_init):
+async def test_starletterequestextractor_json(sentry_init):
starlette_request = starlette.requests.Request(SCOPE)
# Mocking async `_receive()` that works in Python 3.7+
@@ -279,7 +325,7 @@ async def test_starlettrequestextractor_json(sentry_init):
@pytest.mark.asyncio
-async def test_starlettrequestextractor_form(sentry_init):
+async def test_starletterequestextractor_form(sentry_init):
scope = SCOPE.copy()
scope["headers"] = [
[b"content-type", b"multipart/form-data; boundary=fd721ef49ea403a6"],
@@ -307,7 +353,7 @@ async def test_starlettrequestextractor_form(sentry_init):
@pytest.mark.asyncio
-async def test_starlettrequestextractor_body_consumed_twice(
+async def test_starletterequestextractor_body_consumed_twice(
sentry_init, capture_events
):
"""
@@ -345,7 +391,7 @@ async def test_starlettrequestextractor_body_consumed_twice(
@pytest.mark.asyncio
-async def test_starlettrequestextractor_extract_request_info_too_big(sentry_init):
+async def test_starletterequestextractor_extract_request_info_too_big(sentry_init):
sentry_init(
send_default_pii=True,
integrations=[StarletteIntegration()],
@@ -376,7 +422,7 @@ async def test_starlettrequestextractor_extract_request_info_too_big(sentry_init
@pytest.mark.asyncio
-async def test_starlettrequestextractor_extract_request_info(sentry_init):
+async def test_starletterequestextractor_extract_request_info(sentry_init):
sentry_init(
send_default_pii=True,
integrations=[StarletteIntegration()],
@@ -407,7 +453,7 @@ async def test_starlettrequestextractor_extract_request_info(sentry_init):
@pytest.mark.asyncio
-async def test_starlettrequestextractor_extract_request_info_no_pii(sentry_init):
+async def test_starletterequestextractor_extract_request_info_no_pii(sentry_init):
sentry_init(
send_default_pii=False,
integrations=[StarletteIntegration()],
@@ -605,7 +651,7 @@ def test_user_information_transaction_no_pii(sentry_init, capture_events):
def test_middleware_spans(sentry_init, capture_events):
sentry_init(
traces_sample_rate=1.0,
- integrations=[StarletteIntegration()],
+ integrations=[StarletteIntegration(middleware_spans=True)],
)
starlette_app = starlette_app_factory(
middleware=[Middleware(AuthenticationMiddleware, backend=BasicAuthBackend())]
@@ -620,20 +666,49 @@ def test_middleware_spans(sentry_init, capture_events):
(_, transaction_event) = events
- expected = [
+ expected_middleware_spans = [
"ServerErrorMiddleware",
"AuthenticationMiddleware",
"ExceptionMiddleware",
+ "AuthenticationMiddleware", # 'op': 'middleware.starlette.send'
+ "ServerErrorMiddleware", # 'op': 'middleware.starlette.send'
+ "AuthenticationMiddleware", # 'op': 'middleware.starlette.send'
+ "ServerErrorMiddleware", # 'op': 'middleware.starlette.send'
]
+ assert len(transaction_event["spans"]) == len(expected_middleware_spans)
+
idx = 0
for span in transaction_event["spans"]:
- if span["op"] == "middleware.starlette":
- assert span["description"] == expected[idx]
- assert span["tags"]["starlette.middleware_name"] == expected[idx]
+ if span["op"].startswith("middleware.starlette"):
+ assert (
+ span["tags"]["starlette.middleware_name"]
+ == expected_middleware_spans[idx]
+ )
idx += 1
+def test_middleware_spans_disabled(sentry_init, capture_events):
+ sentry_init(
+ traces_sample_rate=1.0,
+ integrations=[StarletteIntegration(middleware_spans=False)],
+ )
+ starlette_app = starlette_app_factory(
+ middleware=[Middleware(AuthenticationMiddleware, backend=BasicAuthBackend())]
+ )
+ events = capture_events()
+
+ client = TestClient(starlette_app, raise_server_exceptions=False)
+ try:
+ client.get("/message", auth=("Gabriela", "hello123"))
+ except Exception:
+ pass
+
+ (_, transaction_event) = events
+
+ assert len(transaction_event["spans"]) == 0
+
+
def test_middleware_callback_spans(sentry_init, capture_events):
sentry_init(
traces_sample_rate=1.0,
@@ -678,9 +753,7 @@ def test_middleware_callback_spans(sentry_init, capture_events):
},
{
"op": "middleware.starlette.send",
- "description": "_ASGIAdapter.send..send"
- if STARLETTE_VERSION < (0, 21)
- else "_TestClientTransport.handle_request..send",
+ "description": "SentryAsgiMiddleware._run_app.._sentry_wrapped_send",
"tags": {"starlette.middleware_name": "ServerErrorMiddleware"},
},
{
@@ -695,9 +768,7 @@ def test_middleware_callback_spans(sentry_init, capture_events):
},
{
"op": "middleware.starlette.send",
- "description": "_ASGIAdapter.send..send"
- if STARLETTE_VERSION < (0, 21)
- else "_TestClientTransport.handle_request..send",
+ "description": "SentryAsgiMiddleware._run_app.._sentry_wrapped_send",
"tags": {"starlette.middleware_name": "ServerErrorMiddleware"},
},
]
@@ -759,9 +830,11 @@ def test_middleware_partial_receive_send(sentry_init, capture_events):
},
{
"op": "middleware.starlette.receive",
- "description": "_ASGIAdapter.send..receive"
- if STARLETTE_VERSION < (0, 21)
- else "_TestClientTransport.handle_request..receive",
+ "description": (
+ "_ASGIAdapter.send..receive"
+ if STARLETTE_VERSION < (0, 21)
+ else "_TestClientTransport.handle_request..receive"
+ ),
"tags": {"starlette.middleware_name": "ServerErrorMiddleware"},
},
{
@@ -771,9 +844,7 @@ def test_middleware_partial_receive_send(sentry_init, capture_events):
},
{
"op": "middleware.starlette.send",
- "description": "_ASGIAdapter.send..send"
- if STARLETTE_VERSION < (0, 21)
- else "_TestClientTransport.handle_request..send",
+ "description": "SentryAsgiMiddleware._run_app.._sentry_wrapped_send",
"tags": {"starlette.middleware_name": "ServerErrorMiddleware"},
},
{
@@ -801,28 +872,20 @@ def test_middleware_partial_receive_send(sentry_init, capture_events):
idx += 1
-def test_last_event_id(sentry_init, capture_events):
+@pytest.mark.skipif(
+ STARLETTE_VERSION < (0, 35),
+ reason="Positional args for middleware have been introduced in Starlette >= 0.35",
+)
+def test_middleware_positional_args(sentry_init):
sentry_init(
+ traces_sample_rate=1.0,
integrations=[StarletteIntegration()],
)
- events = capture_events()
-
- def handler(request, exc):
- capture_exception(exc)
- return starlette.responses.PlainTextResponse(last_event_id(), status_code=500)
+ _ = starlette_app_factory(middleware=[Middleware(SampleMiddlewareWithArgs, "bla")])
- app = starlette_app_factory(debug=False)
- app.add_exception_handler(500, handler)
-
- client = TestClient(SentryAsgiMiddleware(app), raise_server_exceptions=False)
- response = client.get("/custom_error")
- assert response.status_code == 500
-
- event = events[0]
- assert response.content.strip().decode("ascii") == event["event_id"]
- (exception,) = event["exception"]["values"]
- assert exception["type"] == "Exception"
- assert exception["value"] == "Too Hot"
+ # Only creating the App with an Middleware with args
+ # should not raise an error
+ # So as long as test passes, we are good
def test_legacy_setup(
@@ -846,11 +909,11 @@ def test_legacy_setup(
@pytest.mark.parametrize("endpoint", ["/sync/thread_ids", "/async/thread_ids"])
-@mock.patch("sentry_sdk.profiler.PROFILE_MINIMUM_SAMPLES", 0)
+@mock.patch("sentry_sdk.profiler.transaction_profiler.PROFILE_MINIMUM_SAMPLES", 0)
def test_active_thread_id(sentry_init, capture_envelopes, teardown_profiling, endpoint):
sentry_init(
traces_sample_rate=1.0,
- _experiments={"profiles_sample_rate": 1.0},
+ profiles_sample_rate=1.0,
)
app = starlette_app_factory()
asgi_app = SentryAsgiMiddleware(app)
@@ -869,7 +932,459 @@ def test_active_thread_id(sentry_init, capture_envelopes, teardown_profiling, en
profiles = [item for item in envelopes[0].items if item.type == "profile"]
assert len(profiles) == 1
- for profile in profiles:
- transactions = profile.payload.json["transactions"]
+ for item in profiles:
+ transactions = item.payload.json["transactions"]
assert len(transactions) == 1
assert str(data["active"]) == transactions[0]["active_thread_id"]
+
+ transactions = [item for item in envelopes[0].items if item.type == "transaction"]
+ assert len(transactions) == 1
+
+ for item in transactions:
+ transaction = item.payload.json
+ trace_context = transaction["contexts"]["trace"]
+ assert str(data["active"]) == trace_context["data"]["thread.id"]
+
+
+def test_original_request_not_scrubbed(sentry_init, capture_events):
+ sentry_init(integrations=[StarletteIntegration()])
+
+ events = capture_events()
+
+ async def _error(request):
+ logging.critical("Oh no!")
+ assert request.headers["Authorization"] == "Bearer ohno"
+ assert await request.json() == {"password": "ohno"}
+ return starlette.responses.JSONResponse({"status": "Oh no!"})
+
+ app = starlette.applications.Starlette(
+ routes=[
+ starlette.routing.Route("/error", _error, methods=["POST"]),
+ ],
+ )
+
+ client = TestClient(app)
+ client.post(
+ "/error",
+ json={"password": "ohno"},
+ headers={"Authorization": "Bearer ohno"},
+ )
+
+ event = events[0]
+ assert event["request"]["data"] == {"password": "[Filtered]"}
+ assert event["request"]["headers"]["authorization"] == "[Filtered]"
+
+
+@pytest.mark.skipif(STARLETTE_VERSION < (0, 24), reason="Requires Starlette >= 0.24")
+def test_template_tracing_meta(sentry_init, capture_events):
+ sentry_init(
+ auto_enabling_integrations=False, # Make sure that httpx integration is not added, because it adds tracing information to the starlette test clients request.
+ integrations=[StarletteIntegration()],
+ )
+ events = capture_events()
+
+ app = starlette_app_factory()
+
+ client = TestClient(app)
+ response = client.get("/render_template")
+ assert response.status_code == 200
+
+ rendered_meta = response.text
+ traceparent, baggage = events[0]["message"].split("\n")
+ assert traceparent != ""
+ assert baggage != ""
+
+ match = re.match(
+ r'^',
+ rendered_meta,
+ )
+ assert match is not None
+ assert match.group(1) == traceparent
+
+ rendered_baggage = match.group(2)
+ assert rendered_baggage == baggage
+
+
+@pytest.mark.parametrize(
+ "request_url,transaction_style,expected_transaction_name,expected_transaction_source",
+ [
+ (
+ "/message/123456",
+ "endpoint",
+ "tests.integrations.starlette.test_starlette.starlette_app_factory.._message_with_id",
+ "component",
+ ),
+ (
+ "/message/123456",
+ "url",
+ "/message/{message_id}",
+ "route",
+ ),
+ ],
+)
+def test_transaction_name(
+ sentry_init,
+ request_url,
+ transaction_style,
+ expected_transaction_name,
+ expected_transaction_source,
+ capture_envelopes,
+):
+ """
+ Tests that the transaction name is something meaningful.
+ """
+ sentry_init(
+ auto_enabling_integrations=False, # Make sure that httpx integration is not added, because it adds tracing information to the starlette test clients request.
+ integrations=[StarletteIntegration(transaction_style=transaction_style)],
+ traces_sample_rate=1.0,
+ )
+
+ envelopes = capture_envelopes()
+
+ app = starlette_app_factory()
+ client = TestClient(app)
+ client.get(request_url)
+
+ (_, transaction_envelope) = envelopes
+ transaction_event = transaction_envelope.get_transaction_event()
+
+ assert transaction_event["transaction"] == expected_transaction_name
+ assert (
+ transaction_event["transaction_info"]["source"] == expected_transaction_source
+ )
+
+
+@pytest.mark.parametrize(
+ "request_url,transaction_style,expected_transaction_name,expected_transaction_source",
+ [
+ (
+ "/message/123456",
+ "endpoint",
+ "http://testserver/message/123456",
+ "url",
+ ),
+ (
+ "/message/123456",
+ "url",
+ "http://testserver/message/123456",
+ "url",
+ ),
+ ],
+)
+def test_transaction_name_in_traces_sampler(
+ sentry_init,
+ request_url,
+ transaction_style,
+ expected_transaction_name,
+ expected_transaction_source,
+):
+ """
+ Tests that a custom traces_sampler has a meaningful transaction name.
+ In this case the URL or endpoint, because we do not have the route yet.
+ """
+
+ def dummy_traces_sampler(sampling_context):
+ assert (
+ sampling_context["transaction_context"]["name"] == expected_transaction_name
+ )
+ assert (
+ sampling_context["transaction_context"]["source"]
+ == expected_transaction_source
+ )
+
+ sentry_init(
+ auto_enabling_integrations=False, # Make sure that httpx integration is not added, because it adds tracing information to the starlette test clients request.
+ integrations=[StarletteIntegration(transaction_style=transaction_style)],
+ traces_sampler=dummy_traces_sampler,
+ traces_sample_rate=1.0,
+ )
+
+ app = starlette_app_factory()
+ client = TestClient(app)
+ client.get(request_url)
+
+
+@pytest.mark.parametrize("middleware_spans", [False, True])
+@pytest.mark.parametrize(
+ "request_url,transaction_style,expected_transaction_name,expected_transaction_source",
+ [
+ (
+ "/message/123456",
+ "endpoint",
+ "starlette.middleware.trustedhost.TrustedHostMiddleware",
+ "component",
+ ),
+ (
+ "/message/123456",
+ "url",
+ "http://testserver/message/123456",
+ "url",
+ ),
+ ],
+)
+def test_transaction_name_in_middleware(
+ sentry_init,
+ middleware_spans,
+ request_url,
+ transaction_style,
+ expected_transaction_name,
+ expected_transaction_source,
+ capture_envelopes,
+):
+ """
+ Tests that the transaction name is something meaningful.
+ """
+ sentry_init(
+ auto_enabling_integrations=False, # Make sure that httpx integration is not added, because it adds tracing information to the starlette test clients request.
+ integrations=[
+ StarletteIntegration(
+ transaction_style=transaction_style, middleware_spans=middleware_spans
+ ),
+ ],
+ traces_sample_rate=1.0,
+ )
+
+ envelopes = capture_envelopes()
+
+ middleware = [
+ Middleware(
+ TrustedHostMiddleware,
+ allowed_hosts=["example.com", "*.example.com"],
+ ),
+ ]
+
+ app = starlette_app_factory(middleware=middleware)
+ client = TestClient(app)
+ client.get(request_url)
+
+ (transaction_envelope,) = envelopes
+ transaction_event = transaction_envelope.get_transaction_event()
+
+ assert transaction_event["contexts"]["response"]["status_code"] == 400
+ assert transaction_event["transaction"] == expected_transaction_name
+ assert (
+ transaction_event["transaction_info"]["source"] == expected_transaction_source
+ )
+
+
+def test_span_origin(sentry_init, capture_events):
+ sentry_init(
+ integrations=[StarletteIntegration()],
+ traces_sample_rate=1.0,
+ )
+ starlette_app = starlette_app_factory(
+ middleware=[Middleware(AuthenticationMiddleware, backend=BasicAuthBackend())]
+ )
+ events = capture_events()
+
+ client = TestClient(starlette_app, raise_server_exceptions=False)
+ try:
+ client.get("/message", auth=("Gabriela", "hello123"))
+ except Exception:
+ pass
+
+ (_, event) = events
+
+ assert event["contexts"]["trace"]["origin"] == "auto.http.starlette"
+ for span in event["spans"]:
+ assert span["origin"] == "auto.http.starlette"
+
+
+class NonIterableContainer:
+ """Wraps any container and makes it non-iterable.
+
+ Used to test backwards compatibility with our old way of defining failed_request_status_codes, which allowed
+ passing in a list of (possibly non-iterable) containers. The Python standard library does not provide any built-in
+ non-iterable containers, so we have to define our own.
+ """
+
+ def __init__(self, inner):
+ self.inner = inner
+
+ def __contains__(self, item):
+ return item in self.inner
+
+
+parametrize_test_configurable_status_codes_deprecated = pytest.mark.parametrize(
+ "failed_request_status_codes,status_code,expected_error",
+ [
+ (None, 500, True),
+ (None, 400, False),
+ ([500, 501], 500, True),
+ ([500, 501], 401, False),
+ ([range(400, 499)], 401, True),
+ ([range(400, 499)], 500, False),
+ ([range(400, 499), range(500, 599)], 300, False),
+ ([range(400, 499), range(500, 599)], 403, True),
+ ([range(400, 499), range(500, 599)], 503, True),
+ ([range(400, 403), 500, 501], 401, True),
+ ([range(400, 403), 500, 501], 405, False),
+ ([range(400, 403), 500, 501], 501, True),
+ ([range(400, 403), 500, 501], 503, False),
+ ([], 500, False),
+ ([NonIterableContainer(range(500, 600))], 500, True),
+ ([NonIterableContainer(range(500, 600))], 404, False),
+ ],
+)
+"""Test cases for configurable status codes (deprecated API).
+Also used by the FastAPI tests.
+"""
+
+
+@parametrize_test_configurable_status_codes_deprecated
+def test_configurable_status_codes_deprecated(
+ sentry_init,
+ capture_events,
+ failed_request_status_codes,
+ status_code,
+ expected_error,
+):
+ with pytest.warns(DeprecationWarning):
+ starlette_integration = StarletteIntegration(
+ failed_request_status_codes=failed_request_status_codes
+ )
+
+ sentry_init(integrations=[starlette_integration])
+
+ events = capture_events()
+
+ async def _error(request):
+ raise HTTPException(status_code)
+
+ app = starlette.applications.Starlette(
+ routes=[
+ starlette.routing.Route("/error", _error, methods=["GET"]),
+ ],
+ )
+
+ client = TestClient(app)
+ client.get("/error")
+
+ if expected_error:
+ assert len(events) == 1
+ else:
+ assert not events
+
+
+@pytest.mark.skipif(
+ STARLETTE_VERSION < (0, 21),
+ reason="Requires Starlette >= 0.21, because earlier versions do not support HTTP 'HEAD' requests",
+)
+def test_transaction_http_method_default(sentry_init, capture_events):
+ """
+ By default OPTIONS and HEAD requests do not create a transaction.
+ """
+ sentry_init(
+ traces_sample_rate=1.0,
+ integrations=[
+ StarletteIntegration(),
+ ],
+ )
+ events = capture_events()
+
+ starlette_app = starlette_app_factory()
+
+ client = TestClient(starlette_app)
+ client.get("/nomessage")
+ client.options("/nomessage")
+ client.head("/nomessage")
+
+ assert len(events) == 1
+
+ (event,) = events
+
+ assert event["request"]["method"] == "GET"
+
+
+@pytest.mark.skipif(
+ STARLETTE_VERSION < (0, 21),
+ reason="Requires Starlette >= 0.21, because earlier versions do not support HTTP 'HEAD' requests",
+)
+def test_transaction_http_method_custom(sentry_init, capture_events):
+ sentry_init(
+ traces_sample_rate=1.0,
+ integrations=[
+ StarletteIntegration(
+ http_methods_to_capture=(
+ "OPTIONS",
+ "head",
+ ), # capitalization does not matter
+ ),
+ ],
+ debug=True,
+ )
+ events = capture_events()
+
+ starlette_app = starlette_app_factory()
+
+ client = TestClient(starlette_app)
+ client.get("/nomessage")
+ client.options("/nomessage")
+ client.head("/nomessage")
+
+ assert len(events) == 2
+
+ (event1, event2) = events
+
+ assert event1["request"]["method"] == "OPTIONS"
+ assert event2["request"]["method"] == "HEAD"
+
+
+@parametrize_test_configurable_status_codes
+def test_configurable_status_codes(
+ sentry_init,
+ capture_events,
+ failed_request_status_codes,
+ status_code,
+ expected_error,
+):
+ integration_kwargs = {}
+ if failed_request_status_codes is not None:
+ integration_kwargs["failed_request_status_codes"] = failed_request_status_codes
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("error", DeprecationWarning)
+ starlette_integration = StarletteIntegration(**integration_kwargs)
+
+ sentry_init(integrations=[starlette_integration])
+
+ events = capture_events()
+
+ async def _error(_):
+ raise HTTPException(status_code)
+
+ app = starlette.applications.Starlette(
+ routes=[
+ starlette.routing.Route("/error", _error, methods=["GET"]),
+ ],
+ )
+
+ client = TestClient(app)
+ client.get("/error")
+
+ assert len(events) == int(expected_error)
+
+
+@pytest.mark.asyncio
+async def test_starletterequestextractor_malformed_json_error_handling(sentry_init):
+ scope = SCOPE.copy()
+ scope["headers"] = [
+ [b"content-type", b"application/json"],
+ ]
+ starlette_request = starlette.requests.Request(scope)
+
+ malformed_json = "{invalid json"
+ malformed_messages = [
+ {"type": "http.request", "body": malformed_json.encode("utf-8")},
+ {"type": "http.disconnect"},
+ ]
+
+ side_effect = [_mock_receive(msg) for msg in malformed_messages]
+ starlette_request._receive = mock.Mock(side_effect=side_effect)
+
+ extractor = StarletteRequestExtractor(starlette_request)
+
+ assert extractor.is_json()
+
+ result = await extractor.json()
+ assert result is None
diff --git a/tests/integrations/starlite/test_starlite.py b/tests/integrations/starlite/test_starlite.py
index 603697ce8b..2c3aa704f5 100644
--- a/tests/integrations/starlite/test_starlite.py
+++ b/tests/integrations/starlite/test_starlite.py
@@ -1,65 +1,19 @@
+from __future__ import annotations
import functools
import pytest
-from sentry_sdk import capture_exception, capture_message, last_event_id
+from sentry_sdk import capture_message
from sentry_sdk.integrations.starlite import StarliteIntegration
-starlite = pytest.importorskip("starlite")
-
from typing import Any, Dict
from starlite import AbstractMiddleware, LoggingConfig, Starlite, get, Controller
from starlite.middleware import LoggingMiddlewareConfig, RateLimitConfig
from starlite.middleware.session.memory_backend import MemoryBackendConfig
-from starlite.status_codes import HTTP_500_INTERNAL_SERVER_ERROR
from starlite.testing import TestClient
-class SampleMiddleware(AbstractMiddleware):
- async def __call__(self, scope, receive, send) -> None:
- async def do_stuff(message):
- if message["type"] == "http.response.start":
- # do something here.
- pass
- await send(message)
-
- await self.app(scope, receive, do_stuff)
-
-
-class SampleReceiveSendMiddleware(AbstractMiddleware):
- async def __call__(self, scope, receive, send):
- message = await receive()
- assert message
- assert message["type"] == "http.request"
-
- send_output = await send({"type": "something-unimportant"})
- assert send_output is None
-
- await self.app(scope, receive, send)
-
-
-class SamplePartialReceiveSendMiddleware(AbstractMiddleware):
- async def __call__(self, scope, receive, send):
- message = await receive()
- assert message
- assert message["type"] == "http.request"
-
- send_output = await send({"type": "something-unimportant"})
- assert send_output is None
-
- async def my_receive(*args, **kwargs):
- pass
-
- async def my_send(*args, **kwargs):
- pass
-
- partial_receive = functools.partial(my_receive)
- partial_send = functools.partial(my_send)
-
- await self.app(scope, partial_receive, partial_send)
-
-
def starlite_app_factory(middleware=None, debug=True, exception_handlers=None):
class MyController(Controller):
path = "/controller"
@@ -69,7 +23,7 @@ async def controller_error(self) -> None:
raise Exception("Whoa")
@get("/some_url")
- async def homepage_handler() -> Dict[str, Any]:
+ async def homepage_handler() -> "Dict[str, Any]":
1 / 0
return {"status": "ok"}
@@ -78,12 +32,12 @@ async def custom_error() -> Any:
raise Exception("Too Hot")
@get("/message")
- async def message() -> Dict[str, Any]:
+ async def message() -> "Dict[str, Any]":
capture_message("hi")
return {"status": "ok"}
@get("/message/{message_id:str}")
- async def message_with_id() -> Dict[str, Any]:
+ async def message_with_id() -> "Dict[str, Any]":
capture_message("hi")
return {"status": "ok"}
@@ -154,8 +108,8 @@ def test_catch_exceptions(
assert str(exc) == expected_message
(event,) = events
- assert event["exception"]["values"][0]["mechanism"]["type"] == "starlite"
assert event["transaction"] == expected_tx_name
+ assert event["exception"]["values"][0]["mechanism"]["type"] == "starlite"
def test_middleware_spans(sentry_init, capture_events):
@@ -180,40 +134,50 @@ def test_middleware_spans(sentry_init, capture_events):
client = TestClient(
starlite_app, raise_server_exceptions=False, base_url="http://testserver.local"
)
- try:
- client.get("/message")
- except Exception:
- pass
+ client.get("/message")
(_, transaction_event) = events
- expected = ["SessionMiddleware", "LoggingMiddleware", "RateLimitMiddleware"]
+ expected = {"SessionMiddleware", "LoggingMiddleware", "RateLimitMiddleware"}
+ found = set()
+
+ starlite_spans = (
+ span
+ for span in transaction_event["spans"]
+ if span["op"] == "middleware.starlite"
+ )
- idx = 0
- for span in transaction_event["spans"]:
- if span["op"] == "middleware.starlite":
- assert span["description"] == expected[idx]
- assert span["tags"]["starlite.middleware_name"] == expected[idx]
- idx += 1
+ for span in starlite_spans:
+ assert span["description"] in expected
+ assert span["description"] not in found
+ found.add(span["description"])
+ assert span["description"] == span["tags"]["starlite.middleware_name"]
def test_middleware_callback_spans(sentry_init, capture_events):
+ class SampleMiddleware(AbstractMiddleware):
+ async def __call__(self, scope, receive, send) -> None:
+ async def do_stuff(message):
+ if message["type"] == "http.response.start":
+ # do something here.
+ pass
+ await send(message)
+
+ await self.app(scope, receive, do_stuff)
+
sentry_init(
traces_sample_rate=1.0,
integrations=[StarliteIntegration()],
)
- starlette_app = starlite_app_factory(middleware=[SampleMiddleware])
+ starlite_app = starlite_app_factory(middleware=[SampleMiddleware])
events = capture_events()
- client = TestClient(starlette_app, raise_server_exceptions=False)
- try:
- client.get("/message")
- except Exception:
- pass
+ client = TestClient(starlite_app, raise_server_exceptions=False)
+ client.get("/message")
- (_, transaction_event) = events
+ (_, transaction_events) = events
- expected = [
+ expected_starlite_spans = [
{
"op": "middleware.starlite",
"description": "SampleMiddleware",
@@ -221,59 +185,95 @@ def test_middleware_callback_spans(sentry_init, capture_events):
},
{
"op": "middleware.starlite.send",
- "description": "TestClientTransport.create_send..send",
+ "description": "SentryAsgiMiddleware._run_app.._sentry_wrapped_send",
"tags": {"starlite.middleware_name": "SampleMiddleware"},
},
{
"op": "middleware.starlite.send",
- "description": "TestClientTransport.create_send..send",
+ "description": "SentryAsgiMiddleware._run_app.._sentry_wrapped_send",
"tags": {"starlite.middleware_name": "SampleMiddleware"},
},
]
- print(transaction_event["spans"])
- idx = 0
- for span in transaction_event["spans"]:
- assert span["op"] == expected[idx]["op"]
- assert span["description"] == expected[idx]["description"]
- assert span["tags"] == expected[idx]["tags"]
- idx += 1
+
+ def is_matching_span(expected_span, actual_span):
+ return (
+ expected_span["op"] == actual_span["op"]
+ and expected_span["description"] == actual_span["description"]
+ and expected_span["tags"] == actual_span["tags"]
+ )
+
+ actual_starlite_spans = list(
+ span
+ for span in transaction_events["spans"]
+ if "middleware.starlite" in span["op"]
+ )
+ assert len(actual_starlite_spans) == 3
+
+ for expected_span in expected_starlite_spans:
+ assert any(
+ is_matching_span(expected_span, actual_span)
+ for actual_span in actual_starlite_spans
+ )
def test_middleware_receive_send(sentry_init, capture_events):
+ class SampleReceiveSendMiddleware(AbstractMiddleware):
+ async def __call__(self, scope, receive, send):
+ message = await receive()
+ assert message
+ assert message["type"] == "http.request"
+
+ send_output = await send({"type": "something-unimportant"})
+ assert send_output is None
+
+ await self.app(scope, receive, send)
+
sentry_init(
traces_sample_rate=1.0,
integrations=[StarliteIntegration()],
)
- starlette_app = starlite_app_factory(middleware=[SampleReceiveSendMiddleware])
+ starlite_app = starlite_app_factory(middleware=[SampleReceiveSendMiddleware])
- client = TestClient(starlette_app, raise_server_exceptions=False)
- try:
- # NOTE: the assert statements checking
- # for correct behaviour are in `SampleReceiveSendMiddleware`!
- client.get("/message")
- except Exception:
- pass
+ client = TestClient(starlite_app, raise_server_exceptions=False)
+ # See SampleReceiveSendMiddleware.__call__ above for assertions of correct behavior
+ client.get("/message")
def test_middleware_partial_receive_send(sentry_init, capture_events):
+ class SamplePartialReceiveSendMiddleware(AbstractMiddleware):
+ async def __call__(self, scope, receive, send):
+ message = await receive()
+ assert message
+ assert message["type"] == "http.request"
+
+ send_output = await send({"type": "something-unimportant"})
+ assert send_output is None
+
+ async def my_receive(*args, **kwargs):
+ pass
+
+ async def my_send(*args, **kwargs):
+ pass
+
+ partial_receive = functools.partial(my_receive)
+ partial_send = functools.partial(my_send)
+
+ await self.app(scope, partial_receive, partial_send)
+
sentry_init(
traces_sample_rate=1.0,
integrations=[StarliteIntegration()],
)
- starlette_app = starlite_app_factory(
- middleware=[SamplePartialReceiveSendMiddleware]
- )
+ starlite_app = starlite_app_factory(middleware=[SamplePartialReceiveSendMiddleware])
events = capture_events()
- client = TestClient(starlette_app, raise_server_exceptions=False)
- try:
- client.get("/message")
- except Exception:
- pass
+ client = TestClient(starlite_app, raise_server_exceptions=False)
+ # See SamplePartialReceiveSendMiddleware.__call__ above for assertions of correct behavior
+ client.get("/message")
- (_, transaction_event) = events
+ (_, transaction_events) = events
- expected = [
+ expected_starlite_spans = [
{
"op": "middleware.starlite",
"description": "SamplePartialReceiveSendMiddleware",
@@ -286,40 +286,110 @@ def test_middleware_partial_receive_send(sentry_init, capture_events):
},
{
"op": "middleware.starlite.send",
- "description": "TestClientTransport.create_send..send",
+ "description": "SentryAsgiMiddleware._run_app.._sentry_wrapped_send",
"tags": {"starlite.middleware_name": "SamplePartialReceiveSendMiddleware"},
},
]
- print(transaction_event["spans"])
- idx = 0
- for span in transaction_event["spans"]:
- assert span["op"] == expected[idx]["op"]
- assert span["description"].startswith(expected[idx]["description"])
- assert span["tags"] == expected[idx]["tags"]
- idx += 1
+ def is_matching_span(expected_span, actual_span):
+ return (
+ expected_span["op"] == actual_span["op"]
+ and actual_span["description"].startswith(expected_span["description"])
+ and expected_span["tags"] == actual_span["tags"]
+ )
+
+ actual_starlite_spans = list(
+ span
+ for span in transaction_events["spans"]
+ if "middleware.starlite" in span["op"]
+ )
+ assert len(actual_starlite_spans) == 3
+
+ for expected_span in expected_starlite_spans:
+ assert any(
+ is_matching_span(expected_span, actual_span)
+ for actual_span in actual_starlite_spans
+ )
-def test_last_event_id(sentry_init, capture_events):
+def test_span_origin(sentry_init, capture_events):
sentry_init(
integrations=[StarliteIntegration()],
+ traces_sample_rate=1.0,
+ )
+
+ logging_config = LoggingMiddlewareConfig()
+ session_config = MemoryBackendConfig()
+ rate_limit_config = RateLimitConfig(rate_limit=("hour", 5))
+
+ starlite_app = starlite_app_factory(
+ middleware=[
+ session_config.middleware,
+ logging_config.middleware,
+ rate_limit_config.middleware,
+ ]
)
events = capture_events()
- def handler(request, exc):
- capture_exception(exc)
- return starlite.response.Response(last_event_id(), status_code=500)
+ client = TestClient(
+ starlite_app, raise_server_exceptions=False, base_url="http://testserver.local"
+ )
+ client.get("/message")
+
+ (_, event) = events
+
+ assert event["contexts"]["trace"]["origin"] == "auto.http.starlite"
+ for span in event["spans"]:
+ assert span["origin"] == "auto.http.starlite"
+
+
+@pytest.mark.parametrize(
+ "is_send_default_pii",
+ [
+ True,
+ False,
+ ],
+ ids=[
+ "send_default_pii=True",
+ "send_default_pii=False",
+ ],
+)
+def test_starlite_scope_user_on_exception_event(
+ sentry_init, capture_exceptions, capture_events, is_send_default_pii
+):
+ class TestUserMiddleware(AbstractMiddleware):
+ async def __call__(self, scope, receive, send):
+ scope["user"] = {
+ "email": "lennon@thebeatles.com",
+ "username": "john",
+ "id": "1",
+ }
+ await self.app(scope, receive, send)
- app = starlite_app_factory(
- debug=False, exception_handlers={HTTP_500_INTERNAL_SERVER_ERROR: handler}
+ sentry_init(
+ integrations=[StarliteIntegration()], send_default_pii=is_send_default_pii
)
+ starlite_app = starlite_app_factory(middleware=[TestUserMiddleware])
+ exceptions = capture_exceptions()
+ events = capture_events()
+
+ # This request intentionally raises an exception
+ client = TestClient(starlite_app)
+ try:
+ client.get("/some_url")
+ except Exception:
+ pass
+
+ assert len(exceptions) == 1
+ assert len(events) == 1
+ (event,) = events
- client = TestClient(app, raise_server_exceptions=False)
- response = client.get("/custom_error")
- assert response.status_code == 500
- print(events)
- event = events[-1]
- assert response.content.strip().decode("ascii").strip('"') == event["event_id"]
- (exception,) = event["exception"]["values"]
- assert exception["type"] == "Exception"
- assert exception["value"] == "Too Hot"
+ if is_send_default_pii:
+ assert "user" in event
+ assert event["user"] == {
+ "email": "lennon@thebeatles.com",
+ "username": "john",
+ "id": "1",
+ }
+ else:
+ assert "user" not in event
diff --git a/tests/integrations/statsig/__init__.py b/tests/integrations/statsig/__init__.py
new file mode 100644
index 0000000000..6abc08235b
--- /dev/null
+++ b/tests/integrations/statsig/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("statsig")
diff --git a/tests/integrations/statsig/test_statsig.py b/tests/integrations/statsig/test_statsig.py
new file mode 100644
index 0000000000..5eb2cf39f3
--- /dev/null
+++ b/tests/integrations/statsig/test_statsig.py
@@ -0,0 +1,203 @@
+import concurrent.futures as cf
+import sys
+from contextlib import contextmanager
+from statsig import statsig
+from statsig.statsig_user import StatsigUser
+from random import random
+from unittest.mock import Mock
+from sentry_sdk import start_span, start_transaction
+from tests.conftest import ApproxDict
+
+import pytest
+
+import sentry_sdk
+from sentry_sdk.integrations.statsig import StatsigIntegration
+
+
+@contextmanager
+def mock_statsig(gate_dict):
+ old_check_gate = statsig.check_gate
+
+ def mock_check_gate(user, gate, *args, **kwargs):
+ return gate_dict.get(gate, False)
+
+ statsig.check_gate = Mock(side_effect=mock_check_gate)
+
+ yield
+
+ statsig.check_gate = old_check_gate
+
+
+def test_check_gate(sentry_init, capture_events, uninstall_integration):
+ uninstall_integration(StatsigIntegration.identifier)
+
+ with mock_statsig({"hello": True, "world": False}):
+ sentry_init(integrations=[StatsigIntegration()])
+ events = capture_events()
+ user = StatsigUser(user_id="user-id")
+
+ statsig.check_gate(user, "hello")
+ statsig.check_gate(user, "world")
+ statsig.check_gate(user, "other") # unknown gates default to False.
+
+ sentry_sdk.capture_exception(Exception("something wrong!"))
+
+ assert len(events) == 1
+ assert events[0]["contexts"]["flags"] == {
+ "values": [
+ {"flag": "hello", "result": True},
+ {"flag": "world", "result": False},
+ {"flag": "other", "result": False},
+ ]
+ }
+
+
+def test_check_gate_threaded(sentry_init, capture_events, uninstall_integration):
+ uninstall_integration(StatsigIntegration.identifier)
+
+ with mock_statsig({"hello": True, "world": False}):
+ sentry_init(integrations=[StatsigIntegration()])
+ events = capture_events()
+ user = StatsigUser(user_id="user-id")
+
+ # Capture an eval before we split isolation scopes.
+ statsig.check_gate(user, "hello")
+
+ def task(flag_key):
+ # Creates a new isolation scope for the thread.
+ # This means the evaluations in each task are captured separately.
+ with sentry_sdk.isolation_scope():
+ statsig.check_gate(user, flag_key)
+ # use a tag to identify to identify events later on
+ sentry_sdk.set_tag("task_id", flag_key)
+ sentry_sdk.capture_exception(Exception("something wrong!"))
+
+ with cf.ThreadPoolExecutor(max_workers=2) as pool:
+ pool.map(task, ["world", "other"])
+
+ # Capture error in original scope
+ sentry_sdk.set_tag("task_id", "0")
+ sentry_sdk.capture_exception(Exception("something wrong!"))
+
+ assert len(events) == 3
+ events.sort(key=lambda e: e["tags"]["task_id"])
+
+ assert events[0]["contexts"]["flags"] == {
+ "values": [
+ {"flag": "hello", "result": True},
+ ]
+ }
+ assert events[1]["contexts"]["flags"] == {
+ "values": [
+ {"flag": "hello", "result": True},
+ {"flag": "other", "result": False},
+ ]
+ }
+ assert events[2]["contexts"]["flags"] == {
+ "values": [
+ {"flag": "hello", "result": True},
+ {"flag": "world", "result": False},
+ ]
+ }
+
+
+@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher")
+def test_check_gate_asyncio(sentry_init, capture_events, uninstall_integration):
+ asyncio = pytest.importorskip("asyncio")
+ uninstall_integration(StatsigIntegration.identifier)
+
+ with mock_statsig({"hello": True, "world": False}):
+ sentry_init(integrations=[StatsigIntegration()])
+ events = capture_events()
+ user = StatsigUser(user_id="user-id")
+
+ # Capture an eval before we split isolation scopes.
+ statsig.check_gate(user, "hello")
+
+ async def task(flag_key):
+ with sentry_sdk.isolation_scope():
+ statsig.check_gate(user, flag_key)
+ # use a tag to identify to identify events later on
+ sentry_sdk.set_tag("task_id", flag_key)
+ sentry_sdk.capture_exception(Exception("something wrong!"))
+
+ async def runner():
+ return asyncio.gather(task("world"), task("other"))
+
+ asyncio.run(runner())
+
+ # Capture error in original scope
+ sentry_sdk.set_tag("task_id", "0")
+ sentry_sdk.capture_exception(Exception("something wrong!"))
+
+ assert len(events) == 3
+ events.sort(key=lambda e: e["tags"]["task_id"])
+
+ assert events[0]["contexts"]["flags"] == {
+ "values": [
+ {"flag": "hello", "result": True},
+ ]
+ }
+ assert events[1]["contexts"]["flags"] == {
+ "values": [
+ {"flag": "hello", "result": True},
+ {"flag": "other", "result": False},
+ ]
+ }
+ assert events[2]["contexts"]["flags"] == {
+ "values": [
+ {"flag": "hello", "result": True},
+ {"flag": "world", "result": False},
+ ]
+ }
+
+
+def test_wraps_original(sentry_init, uninstall_integration):
+ uninstall_integration(StatsigIntegration.identifier)
+ flag_value = random() < 0.5
+
+ with mock_statsig(
+ {"test-flag": flag_value}
+ ): # patches check_gate with a Mock object.
+ mock_check_gate = statsig.check_gate
+ sentry_init(integrations=[StatsigIntegration()]) # wraps check_gate.
+ user = StatsigUser(user_id="user-id")
+
+ res = statsig.check_gate(user, "test-flag", "extra-arg", kwarg=1) # type: ignore[arg-type]
+
+ assert res == flag_value
+ assert mock_check_gate.call_args == ( # type: ignore[attr-defined]
+ (user, "test-flag", "extra-arg"),
+ {"kwarg": 1},
+ )
+
+
+def test_wrapper_attributes(sentry_init, uninstall_integration):
+ uninstall_integration(StatsigIntegration.identifier)
+ original_check_gate = statsig.check_gate
+ sentry_init(integrations=[StatsigIntegration()])
+
+ # Methods have not lost their qualified names after decoration.
+ assert statsig.check_gate.__name__ == "check_gate"
+ assert statsig.check_gate.__qualname__ == original_check_gate.__qualname__
+
+ # Clean up
+ statsig.check_gate = original_check_gate
+
+
+def test_statsig_span_integration(sentry_init, capture_events, uninstall_integration):
+ uninstall_integration(StatsigIntegration.identifier)
+
+ with mock_statsig({"hello": True}):
+ sentry_init(traces_sample_rate=1.0, integrations=[StatsigIntegration()])
+ events = capture_events()
+ user = StatsigUser(user_id="user-id")
+ with start_transaction(name="hi"):
+ with start_span(op="foo", name="bar"):
+ statsig.check_gate(user, "hello")
+ statsig.check_gate(user, "world")
+
+ (event,) = events
+ assert event["spans"][0]["data"] == ApproxDict(
+ {"flag.evaluation.hello": True, "flag.evaluation.world": False}
+ )
diff --git a/tests/integrations/stdlib/__init__.py b/tests/integrations/stdlib/__init__.py
new file mode 100644
index 0000000000..472e0151b2
--- /dev/null
+++ b/tests/integrations/stdlib/__init__.py
@@ -0,0 +1,6 @@
+import os
+import sys
+
+# Load `httplib_helpers` into the module search path to test request source path names relative to module. See
+# `test_request_source_with_module_in_search_path`
+sys.path.insert(0, os.path.join(os.path.dirname(__file__)))
diff --git a/tests/integrations/stdlib/httplib_helpers/__init__.py b/tests/integrations/stdlib/httplib_helpers/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/integrations/stdlib/httplib_helpers/helpers.py b/tests/integrations/stdlib/httplib_helpers/helpers.py
new file mode 100644
index 0000000000..875052e7b5
--- /dev/null
+++ b/tests/integrations/stdlib/httplib_helpers/helpers.py
@@ -0,0 +1,3 @@
+def get_request_with_connection(connection, url):
+ connection.request("GET", url)
+ connection.getresponse()
diff --git a/tests/integrations/stdlib/test_httplib.py b/tests/integrations/stdlib/test_httplib.py
index f6ace42ba2..cdbf6cd68c 100644
--- a/tests/integrations/stdlib/test_httplib.py
+++ b/tests/integrations/stdlib/test_httplib.py
@@ -1,37 +1,48 @@
-import random
+import os
+import datetime
+from http.client import HTTPConnection, HTTPSConnection
+from http.server import BaseHTTPRequestHandler, HTTPServer
+from socket import SocketIO
+from threading import Thread
+from urllib.error import HTTPError
+from urllib.request import urlopen
+from unittest import mock
import pytest
-try:
- # py3
- from urllib.request import urlopen
-except ImportError:
- # py2
- from urllib import urlopen
-
-try:
- # py2
- from httplib import HTTPConnection, HTTPSConnection
-except ImportError:
- # py3
- from http.client import HTTPConnection, HTTPSConnection
-
-try:
- from unittest import mock # python 3.3 and above
-except ImportError:
- import mock # python < 3.3
-
-
-from sentry_sdk import capture_message, start_transaction
-from sentry_sdk.consts import MATCH_ALL
-from sentry_sdk.tracing import Transaction
+from sentry_sdk import capture_message, start_transaction, continue_trace
+from sentry_sdk.consts import MATCH_ALL, SPANDATA
from sentry_sdk.integrations.stdlib import StdlibIntegration
-from tests.conftest import create_mock_http_server
+from tests.conftest import ApproxDict, create_mock_http_server, get_free_port
PORT = create_mock_http_server()
+class MockProxyRequestHandler(BaseHTTPRequestHandler):
+ def do_CONNECT(self):
+ self.send_response(200, "Connection Established")
+ self.end_headers()
+
+ self.rfile.readline()
+
+ response = b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"
+ self.wfile.write(response)
+ self.wfile.flush()
+
+
+def create_mock_proxy_server():
+ proxy_port = get_free_port()
+ proxy_server = HTTPServer(("localhost", proxy_port), MockProxyRequestHandler)
+ proxy_thread = Thread(target=proxy_server.serve_forever)
+ proxy_thread.daemon = True
+ proxy_thread.start()
+ return proxy_port
+
+
+PROXY_PORT = create_mock_proxy_server()
+
+
def test_crumb_capture(sentry_init, capture_events):
sentry_init(integrations=[StdlibIntegration()])
events = capture_events()
@@ -46,14 +57,60 @@ def test_crumb_capture(sentry_init, capture_events):
assert crumb["type"] == "http"
assert crumb["category"] == "httplib"
- assert crumb["data"] == {
- "url": url,
- "method": "GET",
- "status_code": 200,
- "reason": "OK",
- "http.fragment": "",
- "http.query": "",
- }
+ assert crumb["data"] == ApproxDict(
+ {
+ "url": url,
+ SPANDATA.HTTP_METHOD: "GET",
+ SPANDATA.HTTP_STATUS_CODE: 200,
+ "reason": "OK",
+ SPANDATA.HTTP_FRAGMENT: "",
+ SPANDATA.HTTP_QUERY: "",
+ }
+ )
+
+
+@pytest.mark.parametrize(
+ "status_code,level",
+ [
+ (200, None),
+ (301, None),
+ (403, "warning"),
+ (405, "warning"),
+ (500, "error"),
+ ],
+)
+def test_crumb_capture_client_error(sentry_init, capture_events, status_code, level):
+ sentry_init(integrations=[StdlibIntegration()])
+ events = capture_events()
+
+ url = f"http://localhost:{PORT}/status/{status_code}" # noqa:E231
+ try:
+ urlopen(url)
+ except HTTPError:
+ pass
+
+ capture_message("Testing!")
+
+ (event,) = events
+ (crumb,) = event["breadcrumbs"]["values"]
+
+ assert crumb["type"] == "http"
+ assert crumb["category"] == "httplib"
+
+ if level is None:
+ assert "level" not in crumb
+ else:
+ assert crumb["level"] == level
+
+ assert crumb["data"] == ApproxDict(
+ {
+ "url": url,
+ SPANDATA.HTTP_METHOD: "GET",
+ SPANDATA.HTTP_STATUS_CODE: status_code,
+ SPANDATA.HTTP_FRAGMENT: "",
+ SPANDATA.HTTP_QUERY: "",
+ }
+ )
def test_crumb_capture_hint(sentry_init, capture_events):
@@ -73,25 +130,27 @@ def before_breadcrumb(crumb, hint):
(crumb,) = event["breadcrumbs"]["values"]
assert crumb["type"] == "http"
assert crumb["category"] == "httplib"
- assert crumb["data"] == {
- "url": url,
- "method": "GET",
- "status_code": 200,
- "reason": "OK",
- "extra": "foo",
- "http.fragment": "",
- "http.query": "",
- }
+ assert crumb["data"] == ApproxDict(
+ {
+ "url": url,
+ SPANDATA.HTTP_METHOD: "GET",
+ SPANDATA.HTTP_STATUS_CODE: 200,
+ "reason": "OK",
+ "extra": "foo",
+ SPANDATA.HTTP_FRAGMENT: "",
+ SPANDATA.HTTP_QUERY: "",
+ }
+ )
-def test_empty_realurl(sentry_init, capture_events):
+def test_empty_realurl(sentry_init):
"""
Ensure that after using sentry_sdk.init you can putrequest a
None url.
"""
sentry_init(dsn="")
- HTTPConnection("example.com", port=443).putrequest("POST", None)
+ HTTPConnection("localhost", port=PORT).putrequest("POST", None)
def test_httplib_misuse(sentry_init, capture_events, request):
@@ -114,7 +173,7 @@ def test_httplib_misuse(sentry_init, capture_events, request):
conn.request("GET", "/200")
- with pytest.raises(Exception):
+ with pytest.raises(Exception): # noqa: B017
# This raises an exception, because we didn't call `getresponse` for
# the previous request yet.
#
@@ -131,14 +190,16 @@ def test_httplib_misuse(sentry_init, capture_events, request):
assert crumb["type"] == "http"
assert crumb["category"] == "httplib"
- assert crumb["data"] == {
- "url": "http://localhost:{}/200".format(PORT),
- "method": "GET",
- "status_code": 200,
- "reason": "OK",
- "http.fragment": "",
- "http.query": "",
- }
+ assert crumb["data"] == ApproxDict(
+ {
+ "url": "http://localhost:{}/200".format(PORT),
+ SPANDATA.HTTP_METHOD: "GET",
+ SPANDATA.HTTP_STATUS_CODE: 200,
+ "reason": "OK",
+ SPANDATA.HTTP_FRAGMENT: "",
+ SPANDATA.HTTP_QUERY: "",
+ }
+ )
def test_outgoing_trace_headers(sentry_init, monkeypatch):
@@ -150,14 +211,16 @@ def test_outgoing_trace_headers(sentry_init, monkeypatch):
sentry_init(traces_sample_rate=1.0)
- headers = {}
- headers["baggage"] = (
- "other-vendor-value-1=foo;bar;baz, sentry-trace_id=771a43a4192642f0b136d5159a501700, "
- "sentry-public_key=49d0f7386ad645858ae85020e393bef3, sentry-sample_rate=0.01337, "
- "sentry-user_id=Am%C3%A9lie, other-vendor-value-2=foo;bar;"
- )
+ headers = {
+ "sentry-trace": "771a43a4192642f0b136d5159a501700-1234567890abcdef-1",
+ "baggage": (
+ "other-vendor-value-1=foo;bar;baz, sentry-trace_id=771a43a4192642f0b136d5159a501700, "
+ "sentry-public_key=49d0f7386ad645858ae85020e393bef3, sentry-sample_rate=0.01337, "
+ "sentry-user_id=Am%C3%A9lie, sentry-sample_rand=0.132521102938283, other-vendor-value-2=foo;bar;"
+ ),
+ }
- transaction = Transaction.continue_from_headers(headers)
+ transaction = continue_trace(headers)
with start_transaction(
transaction=transaction,
@@ -165,7 +228,6 @@ def test_outgoing_trace_headers(sentry_init, monkeypatch):
op="greeting.sniff",
trace_id="12312012123120121231201212312012",
) as transaction:
-
HTTPSConnection("www.squirrelchasers.com").request("GET", "/top-chasers")
(request_str,) = mock_send.call_args[0]
@@ -183,17 +245,16 @@ def test_outgoing_trace_headers(sentry_init, monkeypatch):
)
assert request_headers["sentry-trace"] == expected_sentry_trace
- expected_outgoing_baggage_items = [
- "sentry-trace_id=771a43a4192642f0b136d5159a501700",
- "sentry-public_key=49d0f7386ad645858ae85020e393bef3",
- "sentry-sample_rate=0.01337",
- "sentry-user_id=Am%C3%A9lie",
- ]
-
- assert sorted(request_headers["baggage"].split(",")) == sorted(
- expected_outgoing_baggage_items
+ expected_outgoing_baggage = (
+ "sentry-trace_id=771a43a4192642f0b136d5159a501700,"
+ "sentry-public_key=49d0f7386ad645858ae85020e393bef3,"
+ "sentry-sample_rate=1.0,"
+ "sentry-user_id=Am%C3%A9lie,"
+ "sentry-sample_rand=0.132521102938283"
)
+ assert request_headers["baggage"] == expected_outgoing_baggage
+
def test_outgoing_trace_headers_head_sdk(sentry_init, monkeypatch):
# HTTPSConnection.send is passed a string containing (among other things)
@@ -202,11 +263,9 @@ def test_outgoing_trace_headers_head_sdk(sentry_init, monkeypatch):
mock_send = mock.Mock()
monkeypatch.setattr(HTTPSConnection, "send", mock_send)
- # make sure transaction is always sampled
- monkeypatch.setattr(random, "random", lambda: 0.1)
-
sentry_init(traces_sample_rate=0.5, release="foo")
- transaction = Transaction.continue_from_headers({})
+ with mock.patch("sentry_sdk.tracing_utils.Random.randrange", return_value=250000):
+ transaction = continue_trace({})
with start_transaction(transaction=transaction, name="Head SDK tx") as transaction:
HTTPSConnection("www.squirrelchasers.com").request("GET", "/top-chasers")
@@ -226,16 +285,16 @@ def test_outgoing_trace_headers_head_sdk(sentry_init, monkeypatch):
)
assert request_headers["sentry-trace"] == expected_sentry_trace
- expected_outgoing_baggage_items = [
- "sentry-trace_id=%s" % transaction.trace_id,
- "sentry-sample_rate=0.5",
- "sentry-release=foo",
- "sentry-environment=production",
- ]
+ expected_outgoing_baggage = (
+ "sentry-trace_id=%s,"
+ "sentry-sample_rand=0.250000,"
+ "sentry-environment=production,"
+ "sentry-release=foo,"
+ "sentry-sample_rate=0.5,"
+ "sentry-sampled=%s"
+ ) % (transaction.trace_id, "true" if transaction.sampled else "false")
- assert sorted(request_headers["baggage"].split(",")) == sorted(
- expected_outgoing_baggage_items
- )
+ assert request_headers["baggage"] == expected_outgoing_baggage
@pytest.mark.parametrize(
@@ -318,7 +377,7 @@ def test_option_trace_propagation_targets(
)
}
- transaction = Transaction.continue_from_headers(headers)
+ transaction = continue_trace(headers)
with start_transaction(
transaction=transaction,
@@ -326,7 +385,6 @@ def test_option_trace_propagation_targets(
op="greeting.sniff",
trace_id="12312012123120121231201212312012",
) as transaction:
-
HTTPSConnection(host).request("GET", path)
(request_str,) = mock_send.call_args[0]
@@ -342,3 +400,293 @@ def test_option_trace_propagation_targets(
else:
assert "sentry-trace" not in request_headers
assert "baggage" not in request_headers
+
+
+def test_request_source_disabled(sentry_init, capture_events):
+ sentry_options = {
+ "traces_sample_rate": 1.0,
+ "enable_http_request_source": False,
+ "http_request_source_threshold_ms": 0,
+ }
+
+ sentry_init(**sentry_options)
+
+ events = capture_events()
+
+ with start_transaction(name="foo"):
+ conn = HTTPConnection("localhost", port=PORT)
+ conn.request("GET", "/foo")
+ conn.getresponse()
+
+ (event,) = events
+
+ span = event["spans"][-1]
+ assert span["description"].startswith("GET")
+
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO not in data
+ assert SPANDATA.CODE_NAMESPACE not in data
+ assert SPANDATA.CODE_FILEPATH not in data
+ assert SPANDATA.CODE_FUNCTION not in data
+
+
+@pytest.mark.parametrize("enable_http_request_source", [None, True])
+def test_request_source_enabled(
+ sentry_init, capture_events, enable_http_request_source
+):
+ sentry_options = {
+ "traces_sample_rate": 1.0,
+ "http_request_source_threshold_ms": 0,
+ }
+ if enable_http_request_source is not None:
+ sentry_options["enable_http_request_source"] = enable_http_request_source
+
+ sentry_init(**sentry_options)
+
+ events = capture_events()
+
+ with start_transaction(name="foo"):
+ conn = HTTPConnection("localhost", port=PORT)
+ conn.request("GET", "/foo")
+ conn.getresponse()
+
+ (event,) = events
+
+ span = event["spans"][-1]
+ assert span["description"].startswith("GET")
+
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO in data
+ assert SPANDATA.CODE_NAMESPACE in data
+ assert SPANDATA.CODE_FILEPATH in data
+ assert SPANDATA.CODE_FUNCTION in data
+
+
+def test_request_source(sentry_init, capture_events):
+ sentry_init(
+ traces_sample_rate=1.0,
+ enable_http_request_source=True,
+ http_request_source_threshold_ms=0,
+ )
+
+ events = capture_events()
+
+ with start_transaction(name="foo"):
+ conn = HTTPConnection("localhost", port=PORT)
+ conn.request("GET", "/foo")
+ conn.getresponse()
+
+ (event,) = events
+
+ span = event["spans"][-1]
+ assert span["description"].startswith("GET")
+
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO in data
+ assert SPANDATA.CODE_NAMESPACE in data
+ assert SPANDATA.CODE_FILEPATH in data
+ assert SPANDATA.CODE_FUNCTION in data
+
+ assert type(data.get(SPANDATA.CODE_LINENO)) == int
+ assert data.get(SPANDATA.CODE_LINENO) > 0
+ assert data.get(SPANDATA.CODE_NAMESPACE) == "tests.integrations.stdlib.test_httplib"
+ assert data.get(SPANDATA.CODE_FILEPATH).endswith(
+ "tests/integrations/stdlib/test_httplib.py"
+ )
+
+ is_relative_path = data.get(SPANDATA.CODE_FILEPATH)[0] != os.sep
+ assert is_relative_path
+
+ assert data.get(SPANDATA.CODE_FUNCTION) == "test_request_source"
+
+
+def test_request_source_with_module_in_search_path(sentry_init, capture_events):
+ """
+ Test that request source is relative to the path of the module it ran in
+ """
+ sentry_init(
+ traces_sample_rate=1.0,
+ enable_http_request_source=True,
+ http_request_source_threshold_ms=0,
+ )
+
+ events = capture_events()
+
+ with start_transaction(name="foo"):
+ from httplib_helpers.helpers import get_request_with_connection
+
+ conn = HTTPConnection("localhost", port=PORT)
+ get_request_with_connection(conn, "/foo")
+
+ (event,) = events
+
+ span = event["spans"][-1]
+ assert span["description"].startswith("GET")
+
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO in data
+ assert SPANDATA.CODE_NAMESPACE in data
+ assert SPANDATA.CODE_FILEPATH in data
+ assert SPANDATA.CODE_FUNCTION in data
+
+ assert type(data.get(SPANDATA.CODE_LINENO)) == int
+ assert data.get(SPANDATA.CODE_LINENO) > 0
+ assert data.get(SPANDATA.CODE_NAMESPACE) == "httplib_helpers.helpers"
+ assert data.get(SPANDATA.CODE_FILEPATH) == "httplib_helpers/helpers.py"
+
+ is_relative_path = data.get(SPANDATA.CODE_FILEPATH)[0] != os.sep
+ assert is_relative_path
+
+ assert data.get(SPANDATA.CODE_FUNCTION) == "get_request_with_connection"
+
+
+def test_no_request_source_if_duration_too_short(sentry_init, capture_events):
+ sentry_init(
+ traces_sample_rate=1.0,
+ enable_http_request_source=True,
+ http_request_source_threshold_ms=100,
+ )
+
+ already_patched_putrequest = HTTPConnection.putrequest
+
+ class HttpConnectionWithPatchedSpan(HTTPConnection):
+ def putrequest(self, *args, **kwargs) -> None:
+ already_patched_putrequest(self, *args, **kwargs)
+ span = self._sentrysdk_span # type: ignore
+ span.start_timestamp = datetime.datetime(2024, 1, 1, microsecond=0)
+ span.timestamp = datetime.datetime(2024, 1, 1, microsecond=99999)
+
+ events = capture_events()
+
+ with start_transaction(name="foo"):
+ conn = HttpConnectionWithPatchedSpan("localhost", port=PORT)
+ conn.request("GET", "/foo")
+ conn.getresponse()
+
+ (event,) = events
+
+ span = event["spans"][-1]
+ assert span["description"].startswith("GET")
+
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO not in data
+ assert SPANDATA.CODE_NAMESPACE not in data
+ assert SPANDATA.CODE_FILEPATH not in data
+ assert SPANDATA.CODE_FUNCTION not in data
+
+
+def test_request_source_if_duration_over_threshold(sentry_init, capture_events):
+ sentry_init(
+ traces_sample_rate=1.0,
+ enable_http_request_source=True,
+ http_request_source_threshold_ms=100,
+ )
+
+ already_patched_putrequest = HTTPConnection.putrequest
+
+ class HttpConnectionWithPatchedSpan(HTTPConnection):
+ def putrequest(self, *args, **kwargs) -> None:
+ already_patched_putrequest(self, *args, **kwargs)
+ span = self._sentrysdk_span # type: ignore
+ span.start_timestamp = datetime.datetime(2024, 1, 1, microsecond=0)
+ span.timestamp = datetime.datetime(2024, 1, 1, microsecond=100001)
+
+ events = capture_events()
+
+ with start_transaction(name="foo"):
+ conn = HttpConnectionWithPatchedSpan("localhost", port=PORT)
+ conn.request("GET", "/foo")
+ conn.getresponse()
+
+ (event,) = events
+
+ span = event["spans"][-1]
+ assert span["description"].startswith("GET")
+
+ data = span.get("data", {})
+
+ assert SPANDATA.CODE_LINENO in data
+ assert SPANDATA.CODE_NAMESPACE in data
+ assert SPANDATA.CODE_FILEPATH in data
+ assert SPANDATA.CODE_FUNCTION in data
+
+ assert type(data.get(SPANDATA.CODE_LINENO)) == int
+ assert data.get(SPANDATA.CODE_LINENO) > 0
+ assert data.get(SPANDATA.CODE_NAMESPACE) == "tests.integrations.stdlib.test_httplib"
+ assert data.get(SPANDATA.CODE_FILEPATH).endswith(
+ "tests/integrations/stdlib/test_httplib.py"
+ )
+
+ is_relative_path = data.get(SPANDATA.CODE_FILEPATH)[0] != os.sep
+ assert is_relative_path
+
+ assert (
+ data.get(SPANDATA.CODE_FUNCTION)
+ == "test_request_source_if_duration_over_threshold"
+ )
+
+
+def test_span_origin(sentry_init, capture_events):
+ sentry_init(traces_sample_rate=1.0, debug=True)
+ events = capture_events()
+
+ with start_transaction(name="foo"):
+ conn = HTTPConnection("localhost", port=PORT)
+ conn.request("GET", "/foo")
+ conn.getresponse()
+
+ (event,) = events
+ assert event["contexts"]["trace"]["origin"] == "manual"
+
+ assert event["spans"][0]["op"] == "http.client"
+ assert event["spans"][0]["origin"] == "auto.http.stdlib.httplib"
+
+
+def test_http_timeout(monkeypatch, sentry_init, capture_envelopes):
+ mock_readinto = mock.Mock(side_effect=TimeoutError)
+ monkeypatch.setattr(SocketIO, "readinto", mock_readinto)
+
+ sentry_init(traces_sample_rate=1.0)
+
+ envelopes = capture_envelopes()
+
+ with pytest.raises(TimeoutError):
+ with start_transaction(op="op", name="name"):
+ conn = HTTPConnection("localhost", port=PORT)
+ conn.request("GET", "/bla")
+ conn.getresponse()
+
+ (transaction_envelope,) = envelopes
+ transaction = transaction_envelope.get_transaction_event()
+ assert len(transaction["spans"]) == 1
+
+ span = transaction["spans"][0]
+ assert span["op"] == "http.client"
+ assert span["description"] == f"GET http://localhost:{PORT}/bla" # noqa: E231
+
+
+@pytest.mark.parametrize("tunnel_port", [8080, None])
+def test_proxy_http_tunnel(sentry_init, capture_events, tunnel_port):
+ sentry_init(traces_sample_rate=1.0)
+ events = capture_events()
+
+ with start_transaction(name="test_transaction"):
+ conn = HTTPConnection("localhost", PROXY_PORT)
+ conn.set_tunnel("api.example.com", tunnel_port)
+ conn.request("GET", "/foo")
+ conn.getresponse()
+
+ (event,) = events
+ (span,) = event["spans"]
+
+ port_modifier = f":{tunnel_port}" if tunnel_port else ""
+ assert span["description"] == f"GET http://api.example.com{port_modifier}/foo"
+ assert span["data"]["url"] == f"http://api.example.com{port_modifier}/foo"
+ assert span["data"][SPANDATA.HTTP_METHOD] == "GET"
+ assert span["data"][SPANDATA.NETWORK_PEER_ADDRESS] == "localhost"
+ assert span["data"][SPANDATA.NETWORK_PEER_PORT] == PROXY_PORT
diff --git a/tests/integrations/stdlib/test_subprocess.py b/tests/integrations/stdlib/test_subprocess.py
index 31da043ac3..593ef8a0dc 100644
--- a/tests/integrations/stdlib/test_subprocess.py
+++ b/tests/integrations/stdlib/test_subprocess.py
@@ -2,18 +2,13 @@
import platform
import subprocess
import sys
+from collections.abc import Mapping
import pytest
from sentry_sdk import capture_message, start_transaction
-from sentry_sdk._compat import PY2
from sentry_sdk.integrations.stdlib import StdlibIntegration
-
-
-if PY2:
- from collections import Mapping
-else:
- from collections.abc import Mapping
+from tests.conftest import ApproxDict
class ImmutableDict(Mapping):
@@ -125,7 +120,7 @@ def test_subprocess_basic(
assert message_event["message"] == "hi"
- data = {"subprocess.cwd": os.getcwd()} if with_cwd else {}
+ data = ApproxDict({"subprocess.cwd": os.getcwd()} if with_cwd else {})
(crumb,) = message_event["breadcrumbs"]["values"]
assert crumb == {
@@ -179,6 +174,19 @@ def test_subprocess_basic(
assert sys.executable + " -c" in subprocess_init_span["description"]
+def test_subprocess_empty_env(sentry_init, monkeypatch):
+ monkeypatch.setenv("TEST_MARKER", "should_not_be_seen")
+ sentry_init(integrations=[StdlibIntegration()], traces_sample_rate=1.0)
+ with start_transaction(name="foo"):
+ args = [
+ sys.executable,
+ "-c",
+ "import os; print(os.environ.get('TEST_MARKER', None))",
+ ]
+ output = subprocess.check_output(args, env={}, universal_newlines=True)
+ assert "should_not_be_seen" not in output
+
+
def test_subprocess_invalid_args(sentry_init):
sentry_init(integrations=[StdlibIntegration()])
@@ -186,3 +194,33 @@ def test_subprocess_invalid_args(sentry_init):
subprocess.Popen(1)
assert "'int' object is not iterable" in str(excinfo.value)
+
+
+def test_subprocess_span_origin(sentry_init, capture_events):
+ sentry_init(integrations=[StdlibIntegration()], traces_sample_rate=1.0)
+ events = capture_events()
+
+ with start_transaction(name="foo"):
+ args = [
+ sys.executable,
+ "-c",
+ "print('hello world')",
+ ]
+ kw = {"args": args, "stdout": subprocess.PIPE}
+
+ popen = subprocess.Popen(**kw)
+ popen.communicate()
+ popen.poll()
+
+ (event,) = events
+
+ assert event["contexts"]["trace"]["origin"] == "manual"
+
+ assert event["spans"][0]["op"] == "subprocess"
+ assert event["spans"][0]["origin"] == "auto.subprocess.stdlib.subprocess"
+
+ assert event["spans"][1]["op"] == "subprocess.communicate"
+ assert event["spans"][1]["origin"] == "auto.subprocess.stdlib.subprocess"
+
+ assert event["spans"][2]["op"] == "subprocess.wait"
+ assert event["spans"][2]["origin"] == "auto.subprocess.stdlib.subprocess"
diff --git a/tests/integrations/strawberry/__init__.py b/tests/integrations/strawberry/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/integrations/strawberry/test_strawberry.py b/tests/integrations/strawberry/test_strawberry.py
new file mode 100644
index 0000000000..d3174ed857
--- /dev/null
+++ b/tests/integrations/strawberry/test_strawberry.py
@@ -0,0 +1,746 @@
+import pytest
+from typing import AsyncGenerator, Optional
+
+strawberry = pytest.importorskip("strawberry")
+pytest.importorskip("fastapi")
+pytest.importorskip("flask")
+
+
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
+from flask import Flask
+from strawberry.fastapi import GraphQLRouter
+from strawberry.flask.views import GraphQLView
+
+from sentry_sdk.consts import OP
+from sentry_sdk.integrations.fastapi import FastApiIntegration
+from sentry_sdk.integrations.flask import FlaskIntegration
+from sentry_sdk.integrations.starlette import StarletteIntegration
+from sentry_sdk.integrations.strawberry import (
+ StrawberryIntegration,
+ SentryAsyncExtension,
+ SentrySyncExtension,
+)
+from tests.conftest import ApproxDict
+
+try:
+ from strawberry.extensions.tracing import (
+ SentryTracingExtension,
+ SentryTracingExtensionSync,
+ )
+except ImportError:
+ SentryTracingExtension = None
+ SentryTracingExtensionSync = None
+
+parameterize_strawberry_test = pytest.mark.parametrize(
+ "client_factory,async_execution,framework_integrations",
+ (
+ (
+ "async_app_client_factory",
+ True,
+ [FastApiIntegration(), StarletteIntegration()],
+ ),
+ ("sync_app_client_factory", False, [FlaskIntegration()]),
+ ),
+)
+
+
+@strawberry.type
+class Query:
+ @strawberry.field
+ def hello(self) -> str:
+ return "Hello World"
+
+ @strawberry.field
+ def error(self) -> int:
+ return 1 / 0
+
+
+@strawberry.type
+class Mutation:
+ @strawberry.mutation
+ def change(self, attribute: str) -> str:
+ return attribute
+
+
+@strawberry.type
+class Message:
+ content: str
+
+
+@strawberry.type
+class Subscription:
+ @strawberry.subscription
+ async def message_added(self) -> Optional[AsyncGenerator[Message, None]]:
+ message = Message(content="Hello, world!")
+ yield message
+
+
+@pytest.fixture
+def async_app_client_factory():
+ def create_app(schema):
+ async_app = FastAPI()
+ async_app.include_router(GraphQLRouter(schema), prefix="/graphql")
+ return TestClient(async_app)
+
+ return create_app
+
+
+@pytest.fixture
+def sync_app_client_factory():
+ def create_app(schema):
+ sync_app = Flask(__name__)
+ sync_app.add_url_rule(
+ "/graphql",
+ view_func=GraphQLView.as_view("graphql_view", schema=schema),
+ )
+ return sync_app.test_client()
+
+ return create_app
+
+
+def test_async_execution_uses_async_extension(sentry_init):
+ sentry_init(integrations=[StrawberryIntegration(async_execution=True)])
+
+ schema = strawberry.Schema(Query)
+ assert SentryAsyncExtension in schema.extensions
+ assert SentrySyncExtension not in schema.extensions
+
+
+def test_sync_execution_uses_sync_extension(sentry_init):
+ sentry_init(integrations=[StrawberryIntegration(async_execution=False)])
+
+ schema = strawberry.Schema(Query)
+ assert SentrySyncExtension in schema.extensions
+ assert SentryAsyncExtension not in schema.extensions
+
+
+def test_use_sync_extension_if_not_specified(sentry_init):
+ sentry_init(integrations=[StrawberryIntegration()])
+ schema = strawberry.Schema(Query)
+ assert SentrySyncExtension in schema.extensions
+ assert SentryAsyncExtension not in schema.extensions
+
+
+@pytest.mark.skipif(
+ SentryTracingExtension is None,
+ reason="SentryTracingExtension no longer available in this Strawberry version",
+)
+def test_replace_existing_sentry_async_extension(sentry_init):
+ sentry_init(integrations=[StrawberryIntegration()])
+
+ schema = strawberry.Schema(Query, extensions=[SentryTracingExtension])
+ assert SentryTracingExtension not in schema.extensions
+ assert SentrySyncExtension not in schema.extensions
+ assert SentryAsyncExtension in schema.extensions
+
+
+@pytest.mark.skipif(
+ SentryTracingExtensionSync is None,
+ reason="SentryTracingExtensionSync no longer available in this Strawberry version",
+)
+def test_replace_existing_sentry_sync_extension(sentry_init):
+ sentry_init(integrations=[StrawberryIntegration()])
+
+ schema = strawberry.Schema(Query, extensions=[SentryTracingExtensionSync])
+ assert SentryTracingExtensionSync not in schema.extensions
+ assert SentryAsyncExtension not in schema.extensions
+ assert SentrySyncExtension in schema.extensions
+
+
+@parameterize_strawberry_test
+def test_capture_request_if_available_and_send_pii_is_on(
+ request,
+ sentry_init,
+ capture_events,
+ client_factory,
+ async_execution,
+ framework_integrations,
+):
+ sentry_init(
+ send_default_pii=True,
+ integrations=[
+ StrawberryIntegration(async_execution=async_execution),
+ ]
+ + framework_integrations,
+ )
+ events = capture_events()
+
+ schema = strawberry.Schema(Query)
+
+ client_factory = request.getfixturevalue(client_factory)
+ client = client_factory(schema)
+
+ query = "query ErrorQuery { error }"
+ client.post("/graphql", json={"query": query, "operationName": "ErrorQuery"})
+
+ assert len(events) == 1
+
+ (error_event,) = events
+
+ assert error_event["exception"]["values"][0]["mechanism"]["type"] == "strawberry"
+ assert error_event["request"]["api_target"] == "graphql"
+ assert error_event["request"]["data"] == {
+ "query": query,
+ "operationName": "ErrorQuery",
+ }
+ assert error_event["contexts"]["response"] == {
+ "data": {
+ "data": None,
+ "errors": [
+ {
+ "message": "division by zero",
+ "locations": [{"line": 1, "column": 20}],
+ "path": ["error"],
+ }
+ ],
+ }
+ }
+ assert len(error_event["breadcrumbs"]["values"]) == 1
+ assert error_event["breadcrumbs"]["values"][0]["category"] == "graphql.operation"
+ assert error_event["breadcrumbs"]["values"][0]["data"] == {
+ "operation_name": "ErrorQuery",
+ "operation_type": "query",
+ }
+
+
+@parameterize_strawberry_test
+def test_do_not_capture_request_if_send_pii_is_off(
+ request,
+ sentry_init,
+ capture_events,
+ client_factory,
+ async_execution,
+ framework_integrations,
+):
+ sentry_init(
+ integrations=[
+ StrawberryIntegration(async_execution=async_execution),
+ ]
+ + framework_integrations,
+ )
+ events = capture_events()
+
+ schema = strawberry.Schema(Query)
+
+ client_factory = request.getfixturevalue(client_factory)
+ client = client_factory(schema)
+
+ query = "query ErrorQuery { error }"
+ client.post("/graphql", json={"query": query, "operationName": "ErrorQuery"})
+
+ assert len(events) == 1
+
+ (error_event,) = events
+ assert error_event["exception"]["values"][0]["mechanism"]["type"] == "strawberry"
+ assert "data" not in error_event["request"]
+ assert "response" not in error_event["contexts"]
+
+ assert len(error_event["breadcrumbs"]["values"]) == 1
+ assert error_event["breadcrumbs"]["values"][0]["category"] == "graphql.operation"
+ assert error_event["breadcrumbs"]["values"][0]["data"] == {
+ "operation_name": "ErrorQuery",
+ "operation_type": "query",
+ }
+
+
+@parameterize_strawberry_test
+def test_breadcrumb_no_operation_name(
+ request,
+ sentry_init,
+ capture_events,
+ client_factory,
+ async_execution,
+ framework_integrations,
+):
+ sentry_init(
+ integrations=[
+ StrawberryIntegration(async_execution=async_execution),
+ ]
+ + framework_integrations,
+ )
+ events = capture_events()
+
+ schema = strawberry.Schema(Query)
+
+ client_factory = request.getfixturevalue(client_factory)
+ client = client_factory(schema)
+
+ query = "{ error }"
+ client.post("/graphql", json={"query": query})
+
+ assert len(events) == 1
+
+ (error_event,) = events
+
+ assert len(error_event["breadcrumbs"]["values"]) == 1
+ assert error_event["breadcrumbs"]["values"][0]["category"] == "graphql.operation"
+ assert error_event["breadcrumbs"]["values"][0]["data"] == {
+ "operation_name": None,
+ "operation_type": "query",
+ }
+
+
+@parameterize_strawberry_test
+def test_capture_transaction_on_error(
+ request,
+ sentry_init,
+ capture_events,
+ client_factory,
+ async_execution,
+ framework_integrations,
+):
+ sentry_init(
+ send_default_pii=True,
+ integrations=[
+ StrawberryIntegration(async_execution=async_execution),
+ ]
+ + framework_integrations,
+ traces_sample_rate=1,
+ )
+ events = capture_events()
+
+ schema = strawberry.Schema(Query)
+
+ client_factory = request.getfixturevalue(client_factory)
+ client = client_factory(schema)
+
+ query = "query ErrorQuery { error }"
+ client.post("/graphql", json={"query": query, "operationName": "ErrorQuery"})
+
+ assert len(events) == 2
+ (_, transaction_event) = events
+
+ assert transaction_event["transaction"] == "ErrorQuery"
+ assert transaction_event["contexts"]["trace"]["op"] == OP.GRAPHQL_QUERY
+ assert transaction_event["spans"]
+
+ query_spans = [
+ span for span in transaction_event["spans"] if span["op"] == OP.GRAPHQL_QUERY
+ ]
+ assert len(query_spans) == 1, "exactly one query span expected"
+ query_span = query_spans[0]
+ assert query_span["description"] == "query ErrorQuery"
+ assert query_span["data"]["graphql.operation.type"] == "query"
+ assert query_span["data"]["graphql.operation.name"] == "ErrorQuery"
+ assert query_span["data"]["graphql.document"] == query
+ assert query_span["data"]["graphql.resource_name"]
+
+ parse_spans = [
+ span for span in transaction_event["spans"] if span["op"] == OP.GRAPHQL_PARSE
+ ]
+ assert len(parse_spans) == 1, "exactly one parse span expected"
+ parse_span = parse_spans[0]
+ assert parse_span["parent_span_id"] == query_span["span_id"]
+ assert parse_span["description"] == "parsing"
+
+ validate_spans = [
+ span for span in transaction_event["spans"] if span["op"] == OP.GRAPHQL_VALIDATE
+ ]
+ assert len(validate_spans) == 1, "exactly one validate span expected"
+ validate_span = validate_spans[0]
+ assert validate_span["parent_span_id"] == query_span["span_id"]
+ assert validate_span["description"] == "validation"
+
+ resolve_spans = [
+ span for span in transaction_event["spans"] if span["op"] == OP.GRAPHQL_RESOLVE
+ ]
+ assert len(resolve_spans) == 1, "exactly one resolve span expected"
+ resolve_span = resolve_spans[0]
+ assert resolve_span["parent_span_id"] == query_span["span_id"]
+ assert resolve_span["description"] == "resolving Query.error"
+ assert resolve_span["data"] == ApproxDict(
+ {
+ "graphql.field_name": "error",
+ "graphql.parent_type": "Query",
+ "graphql.field_path": "Query.error",
+ "graphql.path": "error",
+ }
+ )
+
+
+@parameterize_strawberry_test
+def test_capture_transaction_on_success(
+ request,
+ sentry_init,
+ capture_events,
+ client_factory,
+ async_execution,
+ framework_integrations,
+):
+ sentry_init(
+ integrations=[
+ StrawberryIntegration(async_execution=async_execution),
+ ]
+ + framework_integrations,
+ traces_sample_rate=1,
+ )
+ events = capture_events()
+
+ schema = strawberry.Schema(Query)
+
+ client_factory = request.getfixturevalue(client_factory)
+ client = client_factory(schema)
+
+ query = "query GreetingQuery { hello }"
+ client.post("/graphql", json={"query": query, "operationName": "GreetingQuery"})
+
+ assert len(events) == 1
+ (transaction_event,) = events
+
+ assert transaction_event["transaction"] == "GreetingQuery"
+ assert transaction_event["contexts"]["trace"]["op"] == OP.GRAPHQL_QUERY
+ assert transaction_event["spans"]
+
+ query_spans = [
+ span for span in transaction_event["spans"] if span["op"] == OP.GRAPHQL_QUERY
+ ]
+ assert len(query_spans) == 1, "exactly one query span expected"
+ query_span = query_spans[0]
+ assert query_span["description"] == "query GreetingQuery"
+ assert query_span["data"]["graphql.operation.type"] == "query"
+ assert query_span["data"]["graphql.operation.name"] == "GreetingQuery"
+ assert query_span["data"]["graphql.document"] == query
+ assert query_span["data"]["graphql.resource_name"]
+
+ parse_spans = [
+ span for span in transaction_event["spans"] if span["op"] == OP.GRAPHQL_PARSE
+ ]
+ assert len(parse_spans) == 1, "exactly one parse span expected"
+ parse_span = parse_spans[0]
+ assert parse_span["parent_span_id"] == query_span["span_id"]
+ assert parse_span["description"] == "parsing"
+
+ validate_spans = [
+ span for span in transaction_event["spans"] if span["op"] == OP.GRAPHQL_VALIDATE
+ ]
+ assert len(validate_spans) == 1, "exactly one validate span expected"
+ validate_span = validate_spans[0]
+ assert validate_span["parent_span_id"] == query_span["span_id"]
+ assert validate_span["description"] == "validation"
+
+ resolve_spans = [
+ span for span in transaction_event["spans"] if span["op"] == OP.GRAPHQL_RESOLVE
+ ]
+ assert len(resolve_spans) == 1, "exactly one resolve span expected"
+ resolve_span = resolve_spans[0]
+ assert resolve_span["parent_span_id"] == query_span["span_id"]
+ assert resolve_span["description"] == "resolving Query.hello"
+ assert resolve_span["data"] == ApproxDict(
+ {
+ "graphql.field_name": "hello",
+ "graphql.parent_type": "Query",
+ "graphql.field_path": "Query.hello",
+ "graphql.path": "hello",
+ }
+ )
+
+
+@parameterize_strawberry_test
+def test_transaction_no_operation_name(
+ request,
+ sentry_init,
+ capture_events,
+ client_factory,
+ async_execution,
+ framework_integrations,
+):
+ sentry_init(
+ integrations=[
+ StrawberryIntegration(async_execution=async_execution),
+ ]
+ + framework_integrations,
+ traces_sample_rate=1,
+ )
+ events = capture_events()
+
+ schema = strawberry.Schema(Query)
+
+ client_factory = request.getfixturevalue(client_factory)
+ client = client_factory(schema)
+
+ query = "{ hello }"
+ client.post("/graphql", json={"query": query})
+
+ assert len(events) == 1
+ (transaction_event,) = events
+
+ if async_execution:
+ assert transaction_event["transaction"] == "/graphql"
+ else:
+ assert transaction_event["transaction"] == "graphql_view"
+
+ assert transaction_event["spans"]
+
+ query_spans = [
+ span for span in transaction_event["spans"] if span["op"] == OP.GRAPHQL_QUERY
+ ]
+ assert len(query_spans) == 1, "exactly one query span expected"
+ query_span = query_spans[0]
+ assert query_span["description"] == "query"
+ assert query_span["data"]["graphql.operation.type"] == "query"
+ assert query_span["data"]["graphql.operation.name"] is None
+ assert query_span["data"]["graphql.document"] == query
+ assert query_span["data"]["graphql.resource_name"]
+
+ parse_spans = [
+ span for span in transaction_event["spans"] if span["op"] == OP.GRAPHQL_PARSE
+ ]
+ assert len(parse_spans) == 1, "exactly one parse span expected"
+ parse_span = parse_spans[0]
+ assert parse_span["parent_span_id"] == query_span["span_id"]
+ assert parse_span["description"] == "parsing"
+
+ validate_spans = [
+ span for span in transaction_event["spans"] if span["op"] == OP.GRAPHQL_VALIDATE
+ ]
+ assert len(validate_spans) == 1, "exactly one validate span expected"
+ validate_span = validate_spans[0]
+ assert validate_span["parent_span_id"] == query_span["span_id"]
+ assert validate_span["description"] == "validation"
+
+ resolve_spans = [
+ span for span in transaction_event["spans"] if span["op"] == OP.GRAPHQL_RESOLVE
+ ]
+ assert len(resolve_spans) == 1, "exactly one resolve span expected"
+ resolve_span = resolve_spans[0]
+ assert resolve_span["parent_span_id"] == query_span["span_id"]
+ assert resolve_span["description"] == "resolving Query.hello"
+ assert resolve_span["data"] == ApproxDict(
+ {
+ "graphql.field_name": "hello",
+ "graphql.parent_type": "Query",
+ "graphql.field_path": "Query.hello",
+ "graphql.path": "hello",
+ }
+ )
+
+
+@parameterize_strawberry_test
+def test_transaction_mutation(
+ request,
+ sentry_init,
+ capture_events,
+ client_factory,
+ async_execution,
+ framework_integrations,
+):
+ sentry_init(
+ integrations=[
+ StrawberryIntegration(async_execution=async_execution),
+ ]
+ + framework_integrations,
+ traces_sample_rate=1,
+ )
+ events = capture_events()
+
+ schema = strawberry.Schema(Query, mutation=Mutation)
+
+ client_factory = request.getfixturevalue(client_factory)
+ client = client_factory(schema)
+
+ query = 'mutation Change { change(attribute: "something") }'
+ client.post("/graphql", json={"query": query})
+
+ assert len(events) == 1
+ (transaction_event,) = events
+
+ assert transaction_event["transaction"] == "Change"
+ assert transaction_event["contexts"]["trace"]["op"] == OP.GRAPHQL_MUTATION
+ assert transaction_event["spans"]
+
+ query_spans = [
+ span for span in transaction_event["spans"] if span["op"] == OP.GRAPHQL_MUTATION
+ ]
+ assert len(query_spans) == 1, "exactly one mutation span expected"
+ query_span = query_spans[0]
+ assert query_span["description"] == "mutation"
+ assert query_span["data"]["graphql.operation.type"] == "mutation"
+ assert query_span["data"]["graphql.operation.name"] is None
+ assert query_span["data"]["graphql.document"] == query
+ assert query_span["data"]["graphql.resource_name"]
+
+ parse_spans = [
+ span for span in transaction_event["spans"] if span["op"] == OP.GRAPHQL_PARSE
+ ]
+ assert len(parse_spans) == 1, "exactly one parse span expected"
+ parse_span = parse_spans[0]
+ assert parse_span["parent_span_id"] == query_span["span_id"]
+ assert parse_span["description"] == "parsing"
+
+ validate_spans = [
+ span for span in transaction_event["spans"] if span["op"] == OP.GRAPHQL_VALIDATE
+ ]
+ assert len(validate_spans) == 1, "exactly one validate span expected"
+ validate_span = validate_spans[0]
+ assert validate_span["parent_span_id"] == query_span["span_id"]
+ assert validate_span["description"] == "validation"
+
+ resolve_spans = [
+ span for span in transaction_event["spans"] if span["op"] == OP.GRAPHQL_RESOLVE
+ ]
+ assert len(resolve_spans) == 1, "exactly one resolve span expected"
+ resolve_span = resolve_spans[0]
+ assert resolve_span["parent_span_id"] == query_span["span_id"]
+ assert resolve_span["description"] == "resolving Mutation.change"
+ assert resolve_span["data"] == ApproxDict(
+ {
+ "graphql.field_name": "change",
+ "graphql.parent_type": "Mutation",
+ "graphql.field_path": "Mutation.change",
+ "graphql.path": "change",
+ }
+ )
+
+
+@parameterize_strawberry_test
+def test_handle_none_query_gracefully(
+ request,
+ sentry_init,
+ capture_events,
+ client_factory,
+ async_execution,
+ framework_integrations,
+):
+ sentry_init(
+ integrations=[
+ StrawberryIntegration(async_execution=async_execution),
+ ]
+ + framework_integrations,
+ )
+ events = capture_events()
+
+ schema = strawberry.Schema(Query)
+
+ client_factory = request.getfixturevalue(client_factory)
+ client = client_factory(schema)
+
+ client.post("/graphql", json={})
+
+ assert len(events) == 0, "expected no events to be sent to Sentry"
+
+
+@parameterize_strawberry_test
+def test_span_origin(
+ request,
+ sentry_init,
+ capture_events,
+ client_factory,
+ async_execution,
+ framework_integrations,
+):
+ """
+ Tests for OP.GRAPHQL_MUTATION, OP.GRAPHQL_PARSE, OP.GRAPHQL_VALIDATE, OP.GRAPHQL_RESOLVE,
+ """
+ sentry_init(
+ integrations=[
+ StrawberryIntegration(async_execution=async_execution),
+ ]
+ + framework_integrations,
+ traces_sample_rate=1,
+ )
+ events = capture_events()
+
+ schema = strawberry.Schema(Query, mutation=Mutation)
+
+ client_factory = request.getfixturevalue(client_factory)
+ client = client_factory(schema)
+
+ query = 'mutation Change { change(attribute: "something") }'
+ client.post("/graphql", json={"query": query})
+
+ (event,) = events
+
+ is_flask = "Flask" in str(framework_integrations[0])
+ if is_flask:
+ assert event["contexts"]["trace"]["origin"] == "auto.http.flask"
+ else:
+ assert event["contexts"]["trace"]["origin"] == "auto.http.starlette"
+
+ for span in event["spans"]:
+ if span["op"].startswith("graphql."):
+ assert span["origin"] == "auto.graphql.strawberry"
+
+
+@parameterize_strawberry_test
+def test_span_origin2(
+ request,
+ sentry_init,
+ capture_events,
+ client_factory,
+ async_execution,
+ framework_integrations,
+):
+ """
+ Tests for OP.GRAPHQL_QUERY
+ """
+ sentry_init(
+ integrations=[
+ StrawberryIntegration(async_execution=async_execution),
+ ]
+ + framework_integrations,
+ traces_sample_rate=1,
+ )
+ events = capture_events()
+
+ schema = strawberry.Schema(Query, mutation=Mutation)
+
+ client_factory = request.getfixturevalue(client_factory)
+ client = client_factory(schema)
+
+ query = "query GreetingQuery { hello }"
+ client.post("/graphql", json={"query": query, "operationName": "GreetingQuery"})
+
+ (event,) = events
+
+ is_flask = "Flask" in str(framework_integrations[0])
+ if is_flask:
+ assert event["contexts"]["trace"]["origin"] == "auto.http.flask"
+ else:
+ assert event["contexts"]["trace"]["origin"] == "auto.http.starlette"
+
+ for span in event["spans"]:
+ if span["op"].startswith("graphql."):
+ assert span["origin"] == "auto.graphql.strawberry"
+
+
+@parameterize_strawberry_test
+def test_span_origin3(
+ request,
+ sentry_init,
+ capture_events,
+ client_factory,
+ async_execution,
+ framework_integrations,
+):
+ """
+ Tests for OP.GRAPHQL_SUBSCRIPTION
+ """
+ sentry_init(
+ integrations=[
+ StrawberryIntegration(async_execution=async_execution),
+ ]
+ + framework_integrations,
+ traces_sample_rate=1,
+ )
+ events = capture_events()
+
+ schema = strawberry.Schema(Query, subscription=Subscription)
+
+ client_factory = request.getfixturevalue(client_factory)
+ client = client_factory(schema)
+
+ query = "subscription { messageAdded { content } }"
+ client.post("/graphql", json={"query": query})
+
+ (event,) = events
+
+ is_flask = "Flask" in str(framework_integrations[0])
+ if is_flask:
+ assert event["contexts"]["trace"]["origin"] == "auto.http.flask"
+ else:
+ assert event["contexts"]["trace"]["origin"] == "auto.http.starlette"
+
+ for span in event["spans"]:
+ if span["op"].startswith("graphql."):
+ assert span["origin"] == "auto.graphql.strawberry"
diff --git a/tests/integrations/sys_exit/test_sys_exit.py b/tests/integrations/sys_exit/test_sys_exit.py
new file mode 100644
index 0000000000..a9909ae3c2
--- /dev/null
+++ b/tests/integrations/sys_exit/test_sys_exit.py
@@ -0,0 +1,71 @@
+import sys
+
+import pytest
+
+from sentry_sdk.integrations.sys_exit import SysExitIntegration
+
+
+@pytest.mark.parametrize(
+ ("integration_params", "exit_status", "should_capture"),
+ (
+ ({}, 0, False),
+ ({}, 1, True),
+ ({}, None, False),
+ ({}, "unsuccessful exit", True),
+ ({"capture_successful_exits": False}, 0, False),
+ ({"capture_successful_exits": False}, 1, True),
+ ({"capture_successful_exits": False}, None, False),
+ ({"capture_successful_exits": False}, "unsuccessful exit", True),
+ ({"capture_successful_exits": True}, 0, True),
+ ({"capture_successful_exits": True}, 1, True),
+ ({"capture_successful_exits": True}, None, True),
+ ({"capture_successful_exits": True}, "unsuccessful exit", True),
+ ),
+)
+def test_sys_exit(
+ sentry_init, capture_events, integration_params, exit_status, should_capture
+):
+ sentry_init(integrations=[SysExitIntegration(**integration_params)])
+
+ events = capture_events()
+
+ # Manually catch the sys.exit rather than using pytest.raises because IDE does not recognize that pytest.raises
+ # will catch SystemExit.
+ try:
+ sys.exit(exit_status)
+ except SystemExit:
+ ...
+ else:
+ pytest.fail("Patched sys.exit did not raise SystemExit")
+
+ if should_capture:
+ (event,) = events
+ (exception_value,) = event["exception"]["values"]
+
+ assert exception_value["type"] == "SystemExit"
+ assert exception_value["value"] == (
+ str(exit_status) if exit_status is not None else ""
+ )
+ else:
+ assert len(events) == 0
+
+
+def test_sys_exit_integration_not_auto_enabled(sentry_init, capture_events):
+ sentry_init() # No SysExitIntegration
+
+ events = capture_events()
+
+ # Manually catch the sys.exit rather than using pytest.raises because IDE does not recognize that pytest.raises
+ # will catch SystemExit.
+ try:
+ sys.exit(1)
+ except SystemExit:
+ ...
+ else:
+ pytest.fail(
+ "sys.exit should not be patched, but it must have been because it did not raise SystemExit"
+ )
+
+ assert len(events) == 0, (
+ "No events should have been captured because sys.exit should not have been patched"
+ )
diff --git a/tests/integrations/test_gnu_backtrace.py b/tests/integrations/test_gnu_backtrace.py
index b91359dfa8..be7346a2c3 100644
--- a/tests/integrations/test_gnu_backtrace.py
+++ b/tests/integrations/test_gnu_backtrace.py
@@ -4,78 +4,65 @@
from sentry_sdk.integrations.gnu_backtrace import GnuBacktraceIntegration
LINES = r"""
-0. clickhouse-server(StackTrace::StackTrace()+0x16) [0x99d31a6]
-1. clickhouse-server(DB::Exception::Exception(std::__cxx11::basic_string, std::allocator > const&, int)+0x22) [0x372c822]
-10. clickhouse-server(DB::ActionsVisitor::visit(std::shared_ptr const&)+0x1a12) [0x6ae45d2]
-10. clickhouse-server(DB::InterpreterSelectQuery::executeImpl(DB::InterpreterSelectQuery::Pipeline&, std::shared_ptr const&, bool)+0x11af) [0x75c68ff]
-10. clickhouse-server(ThreadPoolImpl::worker(std::_List_iterator)+0x1ab) [0x6f90c1b]
-11. clickhouse-server() [0xae06ddf]
-11. clickhouse-server(DB::ExpressionAnalyzer::getRootActions(std::shared_ptr const&, bool, std::shared_ptr&, bool)+0xdb) [0x6a0a63b]
-11. clickhouse-server(DB::InterpreterSelectQuery::InterpreterSelectQuery(std::shared_ptr const&, DB::Context const&, std::shared_ptr const&, std::shared_ptr const&, std::vector, std::allocator >, std::allocator, std::allocator > > > const&, DB::QueryProcessingStage::Enum, unsigned long, bool)+0x5e6) [0x75c7516]
-12. /lib/x86_64-linux-gnu/libpthread.so.0(+0x8184) [0x7f3bbc568184]
-12. clickhouse-server(DB::ExpressionAnalyzer::getConstActions()+0xc9) [0x6a0b059]
-12. clickhouse-server(DB::InterpreterSelectQuery::InterpreterSelectQuery(std::shared_ptr const&, DB::Context const&, std::vector, std::allocator >, std::allocator, std::allocator > > > const&, DB::QueryProcessingStage::Enum, unsigned long, bool)+0x56) [0x75c8276]
-13. /lib/x86_64-linux-gnu/libc.so.6(clone+0x6d) [0x7f3bbbb8303d]
-13. clickhouse-server(DB::InterpreterSelectWithUnionQuery::InterpreterSelectWithUnionQuery(std::shared_ptr const&, DB::Context const&, std::vector, std::allocator >, std::allocator, std::allocator > > > const&, DB::QueryProcessingStage::Enum, unsigned long, bool)+0x7e7) [0x75d4067]
-13. clickhouse-server(DB::evaluateConstantExpression(std::shared_ptr const&, DB::Context const&)+0x3ed) [0x656bfdd]
-14. clickhouse-server(DB::InterpreterFactory::get(std::shared_ptr&, DB::Context&, DB::QueryProcessingStage::Enum)+0x3a8) [0x75b0298]
-14. clickhouse-server(DB::makeExplicitSet(DB::ASTFunction const*, DB::Block const&, bool, DB::Context const&, DB::SizeLimits const&, std::unordered_map, DB::PreparedSetKey::Hash, std::equal_to, std::allocator > > >&)+0x382) [0x6adf692]
-15. clickhouse-server() [0x7664c79]
-15. clickhouse-server(DB::ActionsVisitor::makeSet(DB::ASTFunction const*, DB::Block const&)+0x2a7) [0x6ae2227]
-16. clickhouse-server(DB::ActionsVisitor::visit(std::shared_ptr const&)+0x1973) [0x6ae4533]
-16. clickhouse-server(DB::executeQuery(std::__cxx11::basic_string, std::allocator > const&, DB::Context&, bool, DB::QueryProcessingStage::Enum)+0x8a) [0x76669fa]
-17. clickhouse-server(DB::ActionsVisitor::visit(std::shared_ptr const&)+0x1324) [0x6ae3ee4]
-17. clickhouse-server(DB::TCPHandler::runImpl()+0x4b9) [0x30973c9]
-18. clickhouse-server(DB::ExpressionAnalyzer::getRootActions(std::shared_ptr const&, bool, std::shared_ptr&, bool)+0xdb) [0x6a0a63b]
-18. clickhouse-server(DB::TCPHandler::run()+0x2b) [0x30985ab]
-19. clickhouse-server(DB::ExpressionAnalyzer::appendGroupBy(DB::ExpressionActionsChain&, bool)+0x100) [0x6a0b4f0]
-19. clickhouse-server(Poco::Net::TCPServerConnection::start()+0xf) [0x9b53e4f]
-2. clickhouse-server(DB::FunctionTuple::getReturnTypeImpl(std::vector, std::allocator > > const&) const+0x122) [0x3a2a0f2]
-2. clickhouse-server(DB::readException(DB::Exception&, DB::ReadBuffer&, std::__cxx11::basic_string, std::allocator