fix: prevent recovery cascade and improve Stop-scenario reasoning lookup (#25)
parent
7bdf177e0f
commit
4eebf78351
|
|
@ -46,6 +46,17 @@ def tool_call_ids(message: dict[str, Any]) -> list[str]:
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
|
|
||||||
|
def tool_call_names(message: dict[str, Any]) -> list[str]:
    """Collect the function names of a message's tool calls, in order.

    Entries of ``message["tool_calls"]`` are skipped when the entry is not a
    dict, when its ``function`` field is not a dict, or when the function
    name is missing or falsy. Every kept name is converted with ``str``.
    Duplicate names are preserved.
    """
    collected: list[str] = []
    for entry in message.get("tool_calls") or []:
        # Non-dict entries carry no usable function description.
        func = entry.get("function") if isinstance(entry, dict) else None
        if not isinstance(func, dict):
            continue
        name = func.get("name")
        if name:
            collected.append(str(name))
    return collected
|
||||||
|
|
||||||
|
|
||||||
def message_signature(message: dict[str, Any]) -> str:
|
def message_signature(message: dict[str, Any]) -> str:
|
||||||
tool_calls = [
|
tool_calls = [
|
||||||
normalize_tool_call(tool_call)
|
normalize_tool_call(tool_call)
|
||||||
|
|
@ -125,6 +136,13 @@ def scoped_reasoning_keys(message: dict[str, Any], scope: str) -> list[str]:
|
||||||
for tool_call in (message.get("tool_calls") or [])
|
for tool_call in (message.get("tool_calls") or [])
|
||||||
if isinstance(tool_call, dict)
|
if isinstance(tool_call, dict)
|
||||||
)
|
)
|
||||||
|
# Recovery-of-last-resort key. Catches the case where a streaming response
|
||||||
|
# was interrupted (user pressed Stop) before the tool_call.id chunk arrived,
|
||||||
|
# so neither tool_call_id nor tool_call_signature (which canonicalizes
|
||||||
|
# arguments) survives the round-trip through Cursor's transcript.
|
||||||
|
keys.extend(
|
||||||
|
f"scope:{scope}:tool_name:{tool_name}" for tool_name in tool_call_names(message)
|
||||||
|
)
|
||||||
return keys
|
return keys
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -152,6 +170,10 @@ def portable_reasoning_keys(
|
||||||
for tool_call in (message.get("tool_calls") or [])
|
for tool_call in (message.get("tool_calls") or [])
|
||||||
if isinstance(tool_call, dict)
|
if isinstance(tool_call, dict)
|
||||||
)
|
)
|
||||||
|
keys.extend(
|
||||||
|
f"namespace:{cache_namespace}:turn:{turn_signature}:" f"tool_name:{tool_name}"
|
||||||
|
for tool_name in tool_call_names(message)
|
||||||
|
)
|
||||||
return keys
|
return keys
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -679,54 +679,64 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
||||||
)
|
)
|
||||||
finalized = False
|
finalized = False
|
||||||
pending_recovery_notice = recovery_notice
|
pending_recovery_notice = recovery_notice
|
||||||
while True:
|
try:
|
||||||
try:
|
while True:
|
||||||
line = response.readline()
|
try:
|
||||||
except (HTTPException, OSError) as exc:
|
line = response.readline()
|
||||||
LOG.warning("upstream streaming response read failed: %s", exc)
|
except (HTTPException, OSError) as exc:
|
||||||
return ProxyResponseResult(False, usage)
|
LOG.warning("upstream streaming response read failed: %s", exc)
|
||||||
if not line:
|
return ProxyResponseResult(False, usage)
|
||||||
break
|
if not line:
|
||||||
(
|
break
|
||||||
rewritten,
|
(
|
||||||
finalized,
|
rewritten,
|
||||||
pending_recovery_notice,
|
finalized,
|
||||||
chunk_usage,
|
pending_recovery_notice,
|
||||||
) = self._rewrite_sse_line(
|
chunk_usage,
|
||||||
line,
|
) = self._rewrite_sse_line(
|
||||||
original_model,
|
line,
|
||||||
accumulator,
|
original_model,
|
||||||
cache_namespace,
|
accumulator,
|
||||||
response_contexts,
|
|
||||||
display_adapter,
|
|
||||||
pending_recovery_notice,
|
|
||||||
trace,
|
|
||||||
)
|
|
||||||
if chunk_usage is not None:
|
|
||||||
usage = chunk_usage
|
|
||||||
if trace is not None:
|
|
||||||
trace.record_stream_chunk(line, rewritten)
|
|
||||||
if not self._write_to_client(
|
|
||||||
rewritten, "sending streaming response chunk", flush=True
|
|
||||||
):
|
|
||||||
return ProxyResponseResult(False, usage)
|
|
||||||
if finalized:
|
|
||||||
break
|
|
||||||
|
|
||||||
if not finalized:
|
|
||||||
if self.config.verbose:
|
|
||||||
log_json("model streaming assistant messages", accumulator.messages())
|
|
||||||
stored = sum(
|
|
||||||
accumulator.store_reasoning(
|
|
||||||
self.reasoning_store,
|
|
||||||
scope,
|
|
||||||
cache_namespace,
|
cache_namespace,
|
||||||
prior_messages,
|
response_contexts,
|
||||||
|
display_adapter,
|
||||||
|
pending_recovery_notice,
|
||||||
|
trace,
|
||||||
)
|
)
|
||||||
for scope, prior_messages in response_contexts
|
if chunk_usage is not None:
|
||||||
)
|
usage = chunk_usage
|
||||||
if self.config.verbose and stored:
|
if trace is not None:
|
||||||
LOG.info("stored %s streaming reasoning cache key(s)", stored)
|
trace.record_stream_chunk(line, rewritten)
|
||||||
|
if not self._write_to_client(
|
||||||
|
rewritten, "sending streaming response chunk", flush=True
|
||||||
|
):
|
||||||
|
return ProxyResponseResult(False, usage)
|
||||||
|
if finalized:
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
# Store partial reasoning whenever the stream exits without
|
||||||
|
# the upstream's [DONE] terminator (client disconnect, upstream
|
||||||
|
# read failure, exception). Without this, a Stop pressed mid-stream
|
||||||
|
# would discard any reasoning the proxy received but never cached.
|
||||||
|
if not finalized:
|
||||||
|
if self.config.verbose:
|
||||||
|
log_json(
|
||||||
|
"model streaming assistant messages", accumulator.messages()
|
||||||
|
)
|
||||||
|
stored = sum(
|
||||||
|
accumulator.store_reasoning(
|
||||||
|
self.reasoning_store,
|
||||||
|
ctx_scope,
|
||||||
|
cache_namespace,
|
||||||
|
prior_messages,
|
||||||
|
)
|
||||||
|
for ctx_scope, prior_messages in response_contexts
|
||||||
|
)
|
||||||
|
if self.config.verbose and stored:
|
||||||
|
LOG.info(
|
||||||
|
"stored %s streaming reasoning cache key(s) before exit",
|
||||||
|
stored,
|
||||||
|
)
|
||||||
return ProxyResponseResult(True, usage)
|
return ProxyResponseResult(True, usage)
|
||||||
|
|
||||||
def _rewrite_sse_line(
|
def _rewrite_sse_line(
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ from .reasoning_store import (
|
||||||
conversation_scope,
|
conversation_scope,
|
||||||
message_signature,
|
message_signature,
|
||||||
tool_call_ids,
|
tool_call_ids,
|
||||||
|
tool_call_names,
|
||||||
tool_call_signature,
|
tool_call_signature,
|
||||||
turn_context_signature,
|
turn_context_signature,
|
||||||
)
|
)
|
||||||
|
|
@ -383,6 +384,16 @@ def reasoning_lookup_keys(
|
||||||
for tool_call in (message.get("tool_calls") or [])
|
for tool_call in (message.get("tool_calls") or [])
|
||||||
if isinstance(tool_call, dict)
|
if isinstance(tool_call, dict)
|
||||||
)
|
)
|
||||||
|
keys.extend(
|
||||||
|
{
|
||||||
|
"kind": "tool_name",
|
||||||
|
"function_name": tool_name,
|
||||||
|
"key": f"scope:{scope}:tool_name:{tool_name}",
|
||||||
|
"portable": False,
|
||||||
|
"hit": False,
|
||||||
|
}
|
||||||
|
for tool_name in tool_call_names(message)
|
||||||
|
)
|
||||||
if cache_namespace and prior_messages is not None:
|
if cache_namespace and prior_messages is not None:
|
||||||
turn_signature = turn_context_signature(prior_messages)
|
turn_signature = turn_context_signature(prior_messages)
|
||||||
keys.append(
|
keys.append(
|
||||||
|
|
@ -428,6 +439,20 @@ def reasoning_lookup_keys(
|
||||||
for tool_call in (message.get("tool_calls") or [])
|
for tool_call in (message.get("tool_calls") or [])
|
||||||
if isinstance(tool_call, dict)
|
if isinstance(tool_call, dict)
|
||||||
)
|
)
|
||||||
|
keys.extend(
|
||||||
|
{
|
||||||
|
"kind": "portable_tool_name",
|
||||||
|
"function_name": tool_name,
|
||||||
|
"key": (
|
||||||
|
f"namespace:{cache_namespace}:turn:{turn_signature}:"
|
||||||
|
f"tool_name:{tool_name}"
|
||||||
|
),
|
||||||
|
"turn_context_signature": turn_signature,
|
||||||
|
"portable": True,
|
||||||
|
"hit": False,
|
||||||
|
}
|
||||||
|
for tool_name in tool_call_names(message)
|
||||||
|
)
|
||||||
return keys
|
return keys
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -715,5 +715,215 @@ class CrossModeAndModelTests(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StopMidStreamingToolCallTests(unittest.TestCase):
    """Regression for the 'Stop pressed during streaming tool-call arguments'
    scenario. When the upstream stream is cut off before the tool_call.id
    chunk arrives, the cached message has tool_calls with no IDs. Cursor
    synthesises its own ID for its bookkeeping, so the next request looks
    nothing like the cached message at the id/signature/message-content
    levels. The tool_name fallback is the only thing that can rescue this."""

    def setUp(self) -> None:
        # Fresh in-memory store per test so cached reasoning never leaks
        # between test methods.
        self.store = ReasoningStore(":memory:")
        self.config = ProxyConfig(missing_reasoning_strategy="recover")

    def _cache_partial_response(
        self, payload: dict, reasoning: str, args_fragment: str
    ) -> None:
        """Cache an interrupted assistant reply for the request ``payload``.

        Prepares ``payload`` as an upstream request, then feeds
        ``rewrite_response_body`` a response whose single ``grep_search``
        tool call has no ``id`` and only a fragment of its arguments,
        simulating a Stop pressed before the id chunk arrived.

        Args:
            payload: Chat-completions request body for the interrupted turn.
            reasoning: ``reasoning_content`` carried by the partial reply.
            args_fragment: Truncated tool-call ``arguments`` string.
        """
        prepared = prepare_upstream_request(payload, self.config, self.store)
        partial_response = {
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "",
                        "reasoning_content": reasoning,
                        "tool_calls": [
                            {
                                # Deliberately no "id": the stream was cut
                                # before that chunk was received.
                                "type": "function",
                                "function": {
                                    "name": "grep_search",
                                    "arguments": args_fragment,
                                },
                            }
                        ],
                    },
                }
            ]
        }
        rewrite_response_body(
            json.dumps(partial_response).encode("utf-8"),
            original_model=prepared.original_model,
            store=self.store,
            request_messages=prepared.record_response_messages,
            cache_namespace=prepared.cache_namespace,
            scope=prepared.record_response_scope,
            prior_messages=prepared.record_response_messages,
            recording_contexts=prepared.record_response_contexts,
        )

    def test_tool_name_fallback_restores_reasoning_when_id_missing(self) -> None:
        # Turn 1 prepares an upstream request and caches a partial assistant
        # message simulating a Stop before id arrived.
        first_payload = {
            "model": "deepseek-v4-pro",
            "messages": [
                {"role": "system", "content": "sys"},
                {"role": "user", "content": "u1"},
            ],
        }
        self._cache_partial_response(first_payload, "Need to grep.", '{"q":')

        # Turn 2: Cursor saved the partial response with a synthesised id and
        # its own best guess for the arguments.
        second_payload = {
            "model": "deepseek-v4-pro",
            "messages": [
                {"role": "system", "content": "sys"},
                {"role": "user", "content": "u1"},
                {
                    "role": "assistant",
                    "content": "",
                    "tool_calls": [
                        {
                            "id": "cursor-synth-1",
                            "type": "function",
                            "function": {
                                "name": "grep_search",
                                "arguments": '{"q":"foo"}',
                            },
                        }
                    ],
                },
                {
                    "role": "tool",
                    "tool_call_id": "cursor-synth-1",
                    "content": "match",
                },
                {"role": "user", "content": "u2"},
            ],
        }
        second_prepared = prepare_upstream_request(
            second_payload,
            self.config,
            self.store,
        )

        self.assertEqual(second_prepared.patched_reasoning_messages, 1)
        self.assertEqual(second_prepared.missing_reasoning_messages, 0)
        self.assertIsNone(second_prepared.recovery_notice)
        self.assertEqual(
            second_prepared.payload["messages"][2]["reasoning_content"],
            "Need to grep.",
        )

    def test_tool_name_keys_are_isolated_across_distinct_turns(self) -> None:
        # Two separate turns each interrupt with the same function name.
        # The strict scope already differs (each turn has more prior
        # messages) so the two cached entries should not collide and the
        # second turn's reasoning must not leak into the first turn's slot.
        turn_a_payload = {
            "model": "deepseek-v4-pro",
            "messages": [
                {"role": "system", "content": "sys"},
                {"role": "user", "content": "u-A"},
            ],
        }
        self._cache_partial_response(turn_a_payload, "Reasoning A.", '{"q":')

        turn_b_payload = {
            "model": "deepseek-v4-pro",
            "messages": [
                {"role": "system", "content": "sys"},
                {"role": "user", "content": "u-A"},
                {
                    "role": "assistant",
                    "content": "",
                    "tool_calls": [
                        {
                            "id": "synth-A",
                            "type": "function",
                            "function": {
                                "name": "grep_search",
                                "arguments": '{"q":"a"}',
                            },
                        }
                    ],
                },
                {"role": "tool", "tool_call_id": "synth-A", "content": "ra"},
                {"role": "user", "content": "u-B"},
            ],
        }
        self._cache_partial_response(turn_b_payload, "Reasoning B.", '{"q":')

        # Now look up turn A's assistant under its own scope. It must still
        # return Reasoning A and never Reasoning B (no scope collision).
        recovery_payload = {
            "model": "deepseek-v4-pro",
            "messages": [
                {"role": "system", "content": "sys"},
                {"role": "user", "content": "u-A"},
                {
                    "role": "assistant",
                    "content": "",
                    "tool_calls": [
                        {
                            "id": "synth-A",
                            "type": "function",
                            "function": {
                                "name": "grep_search",
                                "arguments": '{"q":"a"}',
                            },
                        }
                    ],
                },
                {"role": "tool", "tool_call_id": "synth-A", "content": "ra"},
                {"role": "user", "content": "u-A2"},
            ],
        }
        prepared = prepare_upstream_request(
            recovery_payload, self.config, self.store
        )
        self.assertEqual(
            prepared.payload["messages"][2]["reasoning_content"],
            "Reasoning A.",
        )
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue