Skip to content
2 changes: 1 addition & 1 deletion redis/_parsers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def parse_error(cls, response):
exception_class = cls.EXCEPTION_CLASSES[error_code]
if isinstance(exception_class, dict):
exception_class = exception_class.get(response, ResponseError)
return exception_class(response)
return exception_class(response, status_code=error_code)
return ResponseError(response)

def on_disconnect(self):
Expand Down
222 changes: 211 additions & 11 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
AfterPubSubConnectionInstantiationEvent,
AfterSingleConnectionInstantiationEvent,
ClientType,
EventDispatcher, AfterCommandExecutionEvent,
EventDispatcher, AfterCommandExecutionEvent, OnErrorEvent,
)
from redis.exceptions import (
ConnectionError,
Expand Down Expand Up @@ -640,7 +640,14 @@ def _send_command_parse_response(self, conn, command_name, *args, **options):
conn.send_command(*args, **options)
return self.parse_response(conn, command_name, **options)

def _close_connection(self, conn) -> None:
def _close_connection(
self,
conn,
error: Optional[Exception] = None,
failure_count: Optional[int] = None,
start_time: Optional[float] = None,
command_name: Optional[str] = None,
) -> None:
"""
Close the connection before retrying.

Expand All @@ -650,7 +657,27 @@ def _close_connection(self, conn) -> None:
After we disconnect the connection, it will try to reconnect and
do a health check as part of the send_command logic (on connection level).
"""
if error and failure_count <= conn.retry.get_retries():
self._event_dispatcher.dispatch(
AfterCommandExecutionEvent(
command_name=command_name,
duration_seconds=time.monotonic() - start_time,
server_address=conn.host,
server_port=conn.port,
db_namespace=str(conn.db),
error=error,
retry_attempts=failure_count,
)
)

self._event_dispatcher.dispatch(
OnErrorEvent(
error=error,
server_address=conn.host,
server_port=conn.port,
retry_attempts=failure_count,
)
)
conn.disconnect()

# COMMAND EXECUTION AND PROTOCOL PARSING
Expand All @@ -673,7 +700,14 @@ def _execute_command(self, *args, **options):
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda _: self._close_connection(conn),
lambda error, failure_count: self._close_connection(
conn,
error,
failure_count,
start_time,
command_name,
),
with_failure_count=True
)

self._event_dispatcher.dispatch(
Expand All @@ -697,6 +731,14 @@ def _execute_command(self, *args, **options):
error=e,
)
)
self._event_dispatcher.dispatch(
OnErrorEvent(
error=e,
server_address=conn.host,
server_port=conn.port,
is_internal=False,
)
)
raise

finally:
Expand Down Expand Up @@ -974,13 +1016,42 @@ def clean_health_check_responses(self) -> None:
)
ttl -= 1

def _reconnect(self, conn) -> None:
def _reconnect(
self,
conn,
error: Optional[Exception] = None,
failure_count: Optional[int] = None,
start_time: Optional[float] = None,
command_name: Optional[str] = None,
) -> None:
"""
The supported exceptions are already checked in the
retry object so we don't need to do it here.

In this error handler we are trying to reconnect to the server.
"""
if error and failure_count <= conn.retry.get_retries():
if command_name:
self._event_dispatcher.dispatch(
AfterCommandExecutionEvent(
command_name=command_name,
duration_seconds=time.monotonic() - start_time,
server_address=conn.host,
server_port=conn.port,
db_namespace=str(conn.db),
error=error,
retry_attempts=failure_count,
)
)

self._event_dispatcher.dispatch(
OnErrorEvent(
error=error,
server_address=conn.host,
server_port=conn.port,
retry_attempts=failure_count,
)
)
conn.disconnect()
conn.connect()

Expand All @@ -996,12 +1067,60 @@ def _execute(self, conn, command, *args, **kwargs):
if conn.should_reconnect():
self._reconnect(conn)

response = conn.retry.call_with_retry(
lambda: command(*args, **kwargs),
lambda _: self._reconnect(conn),
)
if not len(args) == 0:
command_name = args[0]
else:
command_name = None

return response
# Start timing for observability
start_time = time.monotonic()

try:
response = conn.retry.call_with_retry(
lambda: command(*args, **kwargs),
lambda error, failure_count: self._reconnect(
conn,
error,
failure_count,
start_time,
command_name,
),
with_failure_count=True
)

if command_name:
self._event_dispatcher.dispatch(
AfterCommandExecutionEvent(
command_name=command_name,
duration_seconds=time.monotonic() - start_time,
server_address=conn.host,
server_port=conn.port,
db_namespace=str(conn.db),
)
)

return response
except Exception as e:
if command_name:
self._event_dispatcher.dispatch(
AfterCommandExecutionEvent(
command_name=command_name,
duration_seconds=time.monotonic() - start_time,
server_address=conn.host,
server_port=conn.port,
db_namespace=str(conn.db),
error=e,
)
)
self._event_dispatcher.dispatch(
OnErrorEvent(
error=e,
server_address=conn.host,
server_port=conn.port,
is_internal=False,
)
)
raise

def parse_response(self, block=True, timeout=0):
"""Parse the response from a publish/subscribe command"""
Expand Down Expand Up @@ -1494,6 +1613,9 @@ def _disconnect_reset_raise_on_watching(
self,
conn: AbstractConnection,
error: Exception,
failure_count: Optional[int] = None,
start_time: Optional[float] = None,
command_name: Optional[str] = None,
) -> None:
"""
Close the connection, reset watching state and
Expand All @@ -1505,6 +1627,27 @@ def _disconnect_reset_raise_on_watching(
After we disconnect the connection, it will try to reconnect and
do a health check as part of the send_command logic (on connection level).
"""
if error and failure_count <= conn.retry.get_retries():
self._event_dispatcher.dispatch(
AfterCommandExecutionEvent(
command_name=command_name,
duration_seconds=time.monotonic() - start_time,
server_address=conn.host,
server_port=conn.port,
db_namespace=str(conn.db),
error=error,
retry_attempts=failure_count,
)
)

self._event_dispatcher.dispatch(
OnErrorEvent(
error=error,
server_address=conn.host,
server_port=conn.port,
retry_attempts=failure_count,
)
)
conn.disconnect()

# if we were already watching a variable, the watch is no longer
Expand Down Expand Up @@ -1538,7 +1681,14 @@ def immediate_execute_command(self, *args, **options):
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda error: self._disconnect_reset_raise_on_watching(conn, error),
lambda error, failure_count: self._disconnect_reset_raise_on_watching(
conn,
error,
failure_count,
start_time,
command_name,
),
with_failure_count=True
)

self._event_dispatcher.dispatch(
Expand All @@ -1563,6 +1713,14 @@ def immediate_execute_command(self, *args, **options):
error=e,
)
)
self._event_dispatcher.dispatch(
OnErrorEvent(
error=e,
server_address=conn.host,
server_port=conn.port,
is_internal=False
)
)
raise


Expand Down Expand Up @@ -1709,6 +1867,10 @@ def _disconnect_raise_on_watching(
self,
conn: AbstractConnection,
error: Exception,
failure_count: Optional[int] = None,
start_time: Optional[float] = None,
command_name: Optional[str] = None,
batch_size: Optional[int] = None,
) -> None:
"""
Close the connection, raise an exception if we were watching.
Expand All @@ -1719,6 +1881,28 @@ def _disconnect_raise_on_watching(
After we disconnect the connection, it will try to reconnect and
do a health check as part of the send_command logic (on connection level).
"""
if error and failure_count <= conn.retry.get_retries():
self._event_dispatcher.dispatch(
AfterCommandExecutionEvent(
command_name=command_name,
duration_seconds=time.monotonic() - start_time,
server_address=conn.host,
server_port=conn.port,
db_namespace=str(conn.db),
error=error,
retry_attempts=failure_count,
batch_size=batch_size,
)
)

self._event_dispatcher.dispatch(
OnErrorEvent(
error=error,
server_address=conn.host,
server_port=conn.port,
retry_attempts=failure_count
)
)
conn.disconnect()
# if we were watching a variable, the watch is no longer valid
# since this connection has died. raise a WatchError, which
Expand Down Expand Up @@ -1755,7 +1939,15 @@ def execute(self, raise_on_error: bool = True) -> List[Any]:
try:
response = conn.retry.call_with_retry(
lambda: execute(conn, stack, raise_on_error),
lambda error: self._disconnect_raise_on_watching(conn, error),
lambda error, failure_count: self._disconnect_raise_on_watching(
conn,
error,
failure_count,
start_time,
operation_name,
len(stack),
),
with_failure_count=True
)

self._event_dispatcher.dispatch(
Expand All @@ -1781,6 +1973,14 @@ def execute(self, raise_on_error: bool = True) -> List[Any]:
batch_size=len(stack),
)
)
self._event_dispatcher.dispatch(
OnErrorEvent(
error=e,
server_address=conn.host,
server_port=conn.port,
is_internal=False
)
)
raise

finally:
Expand Down
Loading