Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

psycopg: Patch server_cursor_factory #3181

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@

_logger = logging.getLogger(__name__)
_OTEL_CURSOR_FACTORY_KEY = "_otel_orig_cursor_factory"
_OTEL_SERVER_CURSOR_FACTORY_KEY = "_otel_orig_server_cursor_factory"


class PsycopgInstrumentor(BaseInstrumentor):
Expand Down Expand Up @@ -257,9 +258,17 @@ def instrument_connection(connection, tracer_provider=None):
setattr(
connection, _OTEL_CURSOR_FACTORY_KEY, connection.cursor_factory
)
setattr(
connection,
_OTEL_SERVER_CURSOR_FACTORY_KEY,
connection.server_cursor_factory,
)
connection.cursor_factory = _new_cursor_factory(
tracer_provider=tracer_provider
)
connection.server_cursor_factory = _new_cursor_factory(
tracer_provider=tracer_provider
)
connection._is_instrumented_by_opentelemetry = True
else:
_logger.warning(
Expand All @@ -273,6 +282,9 @@ def uninstrument_connection(connection):
connection.cursor_factory = getattr(
connection, _OTEL_CURSOR_FACTORY_KEY, None
)
connection.server_cursor_factory = getattr(
connection, _OTEL_SERVER_CURSOR_FACTORY_KEY, None
)

return connection

Expand All @@ -293,6 +305,12 @@ def wrapped_connection(
kwargs["cursor_factory"] = _new_cursor_factory(**new_factory_kwargs)
connection = connect_method(*args, **kwargs)
self.get_connection_attributes(connection)

connection.server_cursor_factory = _new_cursor_factory(
db_api=self,
base_factory=getattr(connection, "server_cursor_factory", None),
)

return connection


Expand All @@ -313,6 +331,11 @@ async def wrapped_connection(
)
connection = await connect_method(*args, **kwargs)
self.get_connection_attributes(connection)

connection.server_cursor_factory = _new_cursor_async_factory(
db_api=self,
base_factory=getattr(connection, "server_cursor_factory", None),
)
return connection


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import types
from typing import Optional
from unittest import IsolatedAsyncioTestCase, mock

import psycopg
Expand Down Expand Up @@ -83,10 +84,14 @@ class MockConnection:

def __init__(self, *args, **kwargs):
self.cursor_factory = kwargs.pop("cursor_factory", None)
self.server_cursor_factory = lambda _: MockCursor()

def cursor(self):
if self.cursor_factory:
def cursor(self, name: Optional[str] = None):
if not name and self.cursor_factory:
return self.cursor_factory(self)

if name and self.server_cursor_factory:
return self.server_cursor_factory(self)
return MockCursor()

def get_dsn_parameters(self): # pylint: disable=no-self-use
Expand All @@ -102,15 +107,18 @@ class MockAsyncConnection:

def __init__(self, *args, **kwargs):
self.cursor_factory = kwargs.pop("cursor_factory", None)
self.server_cursor_factory = lambda _: MockAsyncCursor()

@staticmethod
async def connect(*args, **kwargs):
return MockAsyncConnection(**kwargs)

def cursor(self):
if self.cursor_factory:
cur = self.cursor_factory(self)
return cur
def cursor(self, name: Optional[str] = None):
if not name and self.cursor_factory:
return self.cursor_factory(self)

if name and self.server_cursor_factory:
return self.server_cursor_factory(self)
return MockAsyncCursor()

def execute(self, query, params=None, *, prepare=None, binary=False):
Expand Down Expand Up @@ -197,6 +205,36 @@ def test_instrumentor(self):
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

def test_instrumentor_with_named_cursor(self):
PsycopgInstrumentor().instrument()

cnx = psycopg.connect(database="test")

cursor = cnx.cursor(name="named_cursor")

query = "SELECT * FROM test"
cursor.execute(query)

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]

# Check version and name in span's instrumentation info
self.assertEqualSpanInstrumentationScope(
span, opentelemetry.instrumentation.psycopg
)

# check that no spans are generated after uninstrument
PsycopgInstrumentor().uninstrument()

cnx = psycopg.connect(database="test")
cursor = cnx.cursor(name="named_cursor")
query = "SELECT * FROM test"
cursor.execute(query)

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

# pylint: disable=unused-argument
def test_instrumentor_with_connection_class(self):
PsycopgInstrumentor().instrument()
Expand Down Expand Up @@ -228,6 +266,36 @@ def test_instrumentor_with_connection_class(self):
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

def test_instrumentor_with_connection_class_and_named_cursor(self):
PsycopgInstrumentor().instrument()

cnx = psycopg.Connection.connect(database="test")

cursor = cnx.cursor(name="named_cursor")

query = "SELECT * FROM test"
cursor.execute(query)

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]

# Check version and name in span's instrumentation info
self.assertEqualSpanInstrumentationScope(
span, opentelemetry.instrumentation.psycopg
)

# check that no spans are generated after uninstrument
PsycopgInstrumentor().uninstrument()

cnx = psycopg.Connection.connect(database="test")
cursor = cnx.cursor(name="named_cursor")
query = "SELECT * FROM test"
cursor.execute(query)

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

def test_span_name(self):
PsycopgInstrumentor().instrument()

Expand Down Expand Up @@ -314,6 +382,23 @@ def test_instrument_connection(self):
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

# pylint: disable=unused-argument
def test_instrument_connection_with_named_cursor(self):
cnx = psycopg.connect(database="test")
query = "SELECT * FROM test"
cursor = cnx.cursor(name="named_cursor")
cursor.execute(query)

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 0)

cnx = PsycopgInstrumentor().instrument_connection(cnx)
cursor = cnx.cursor(name="named_cursor")
cursor.execute(query)

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

# pylint: disable=unused-argument
def test_instrument_connection_with_instrument(self):
cnx = psycopg.connect(database="test")
Expand Down Expand Up @@ -368,6 +453,25 @@ def test_uninstrument_connection_with_instrument_connection(self):
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

def test_uninstrument_connection_with_instrument_connection_and_named_cursor(
self,
):
cnx = psycopg.connect(database="test")
PsycopgInstrumentor().instrument_connection(cnx)
query = "SELECT * FROM test"
cursor = cnx.cursor(name="named_cursor")
cursor.execute(query)

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

cnx = PsycopgInstrumentor().uninstrument_connection(cnx)
cursor = cnx.cursor(name="named_cursor")
cursor.execute(query)

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

@mock.patch("opentelemetry.instrumentation.dbapi.wrap_connect")
def test_sqlcommenter_enabled(self, event_mocked):
cnx = psycopg.connect(database="test")
Expand Down Expand Up @@ -419,6 +523,33 @@ async def test_async_connection():
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

async def test_wrap_async_connection_class_with_named_cursor(self):
PsycopgInstrumentor().instrument()

async def test_async_connection():
acnx = await psycopg.AsyncConnection.connect("test")
async with acnx as cnx:
async with cnx.cursor(name="named_cursor") as cursor:
await cursor.execute("SELECT * FROM test")

await test_async_connection()
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]

# Check version and name in span's instrumentation info
self.assertEqualSpanInstrumentationScope(
span, opentelemetry.instrumentation.psycopg
)

# check that no spans are generated after uninstrument
PsycopgInstrumentor().uninstrument()

await test_async_connection()

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

# pylint: disable=unused-argument
async def test_instrumentor_with_async_connection_class(self):
PsycopgInstrumentor().instrument()
Expand Down
Loading