fix: prevent recovery cascade and improve Stop-scenario reasoning lookup (#25)

main
Anudor 2026-05-02 00:17:58 +08:00 committed by GitHub
parent 7bdf177e0f
commit 4eebf78351
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 313 additions and 46 deletions

View File

@ -46,6 +46,17 @@ def tool_call_ids(message: dict[str, Any]) -> list[str]:
return ids return ids
def tool_call_names(message: dict[str, Any]) -> list[str]:
names: list[str] = []
for tool_call in message.get("tool_calls") or []:
if not isinstance(tool_call, dict):
continue
function = tool_call.get("function")
if isinstance(function, dict) and function.get("name"):
names.append(str(function["name"]))
return names
def message_signature(message: dict[str, Any]) -> str: def message_signature(message: dict[str, Any]) -> str:
tool_calls = [ tool_calls = [
normalize_tool_call(tool_call) normalize_tool_call(tool_call)
@ -125,6 +136,13 @@ def scoped_reasoning_keys(message: dict[str, Any], scope: str) -> list[str]:
for tool_call in (message.get("tool_calls") or []) for tool_call in (message.get("tool_calls") or [])
if isinstance(tool_call, dict) if isinstance(tool_call, dict)
) )
# Recovery-of-last-resort key. Catches the case where a streaming response
# was interrupted (user pressed Stop) before the tool_call.id chunk arrived,
# so neither tool_call_id nor tool_call_signature (which canonicalizes
# arguments) survives the round-trip through Cursor's transcript.
keys.extend(
f"scope:{scope}:tool_name:{tool_name}" for tool_name in tool_call_names(message)
)
return keys return keys
@ -152,6 +170,10 @@ def portable_reasoning_keys(
for tool_call in (message.get("tool_calls") or []) for tool_call in (message.get("tool_calls") or [])
if isinstance(tool_call, dict) if isinstance(tool_call, dict)
) )
keys.extend(
f"namespace:{cache_namespace}:turn:{turn_signature}:" f"tool_name:{tool_name}"
for tool_name in tool_call_names(message)
)
return keys return keys

View File

@ -679,54 +679,64 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
) )
finalized = False finalized = False
pending_recovery_notice = recovery_notice pending_recovery_notice = recovery_notice
while True: try:
try: while True:
line = response.readline() try:
except (HTTPException, OSError) as exc: line = response.readline()
LOG.warning("upstream streaming response read failed: %s", exc) except (HTTPException, OSError) as exc:
return ProxyResponseResult(False, usage) LOG.warning("upstream streaming response read failed: %s", exc)
if not line: return ProxyResponseResult(False, usage)
break if not line:
( break
rewritten, (
finalized, rewritten,
pending_recovery_notice, finalized,
chunk_usage, pending_recovery_notice,
) = self._rewrite_sse_line( chunk_usage,
line, ) = self._rewrite_sse_line(
original_model, line,
accumulator, original_model,
cache_namespace, accumulator,
response_contexts,
display_adapter,
pending_recovery_notice,
trace,
)
if chunk_usage is not None:
usage = chunk_usage
if trace is not None:
trace.record_stream_chunk(line, rewritten)
if not self._write_to_client(
rewritten, "sending streaming response chunk", flush=True
):
return ProxyResponseResult(False, usage)
if finalized:
break
if not finalized:
if self.config.verbose:
log_json("model streaming assistant messages", accumulator.messages())
stored = sum(
accumulator.store_reasoning(
self.reasoning_store,
scope,
cache_namespace, cache_namespace,
prior_messages, response_contexts,
display_adapter,
pending_recovery_notice,
trace,
) )
for scope, prior_messages in response_contexts if chunk_usage is not None:
) usage = chunk_usage
if self.config.verbose and stored: if trace is not None:
LOG.info("stored %s streaming reasoning cache key(s)", stored) trace.record_stream_chunk(line, rewritten)
if not self._write_to_client(
rewritten, "sending streaming response chunk", flush=True
):
return ProxyResponseResult(False, usage)
if finalized:
break
finally:
# Store partial reasoning whenever the stream exits without
# the upstream's [DONE] terminator (client disconnect, upstream
# read failure, exception). Without this, a Stop pressed mid-stream
# would discard any reasoning the proxy received but never cached.
if not finalized:
if self.config.verbose:
log_json(
"model streaming assistant messages", accumulator.messages()
)
stored = sum(
accumulator.store_reasoning(
self.reasoning_store,
ctx_scope,
cache_namespace,
prior_messages,
)
for ctx_scope, prior_messages in response_contexts
)
if self.config.verbose and stored:
LOG.info(
"stored %s streaming reasoning cache key(s) before exit",
stored,
)
return ProxyResponseResult(True, usage) return ProxyResponseResult(True, usage)
def _rewrite_sse_line( def _rewrite_sse_line(

View File

@ -13,6 +13,7 @@ from .reasoning_store import (
conversation_scope, conversation_scope,
message_signature, message_signature,
tool_call_ids, tool_call_ids,
tool_call_names,
tool_call_signature, tool_call_signature,
turn_context_signature, turn_context_signature,
) )
@ -383,6 +384,16 @@ def reasoning_lookup_keys(
for tool_call in (message.get("tool_calls") or []) for tool_call in (message.get("tool_calls") or [])
if isinstance(tool_call, dict) if isinstance(tool_call, dict)
) )
keys.extend(
{
"kind": "tool_name",
"function_name": tool_name,
"key": f"scope:{scope}:tool_name:{tool_name}",
"portable": False,
"hit": False,
}
for tool_name in tool_call_names(message)
)
if cache_namespace and prior_messages is not None: if cache_namespace and prior_messages is not None:
turn_signature = turn_context_signature(prior_messages) turn_signature = turn_context_signature(prior_messages)
keys.append( keys.append(
@ -428,6 +439,20 @@ def reasoning_lookup_keys(
for tool_call in (message.get("tool_calls") or []) for tool_call in (message.get("tool_calls") or [])
if isinstance(tool_call, dict) if isinstance(tool_call, dict)
) )
keys.extend(
{
"kind": "portable_tool_name",
"function_name": tool_name,
"key": (
f"namespace:{cache_namespace}:turn:{turn_signature}:"
f"tool_name:{tool_name}"
),
"turn_context_signature": turn_signature,
"portable": True,
"hit": False,
}
for tool_name in tool_call_names(message)
)
return keys return keys

View File

@ -715,5 +715,215 @@ class CrossModeAndModelTests(unittest.TestCase):
) )
class StopMidStreamingToolCallTests(unittest.TestCase):
"""Regression for the 'Stop pressed during streaming tool-call arguments'
scenario. When the upstream stream is cut off before the tool_call.id
chunk arrives, the cached message has tool_calls with no IDs. Cursor
synthesises its own ID for its bookkeeping, so the next request looks
nothing like the cached message at the id/signature/message-content
levels. The tool_name fallback is the only thing that can rescue this."""
def setUp(self) -> None:
self.store = ReasoningStore(":memory:")
def test_tool_name_fallback_restores_reasoning_when_id_missing(self) -> None:
# Turn 1 prepares an upstream request and caches a partial assistant
# message simulating a Stop before id arrived.
first_payload = {
"model": "deepseek-v4-pro",
"messages": [
{"role": "system", "content": "sys"},
{"role": "user", "content": "u1"},
],
}
first_prepared = prepare_upstream_request(
first_payload,
ProxyConfig(missing_reasoning_strategy="recover"),
self.store,
)
partial_response = {
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "",
"reasoning_content": "Need to grep.",
"tool_calls": [
{
"type": "function",
"function": {
"name": "grep_search",
"arguments": '{"q":',
},
}
],
},
}
]
}
rewrite_response_body(
json.dumps(partial_response).encode("utf-8"),
original_model=first_prepared.original_model,
store=self.store,
request_messages=first_prepared.record_response_messages,
cache_namespace=first_prepared.cache_namespace,
scope=first_prepared.record_response_scope,
prior_messages=first_prepared.record_response_messages,
recording_contexts=first_prepared.record_response_contexts,
)
# Turn 2: Cursor saved the partial response with a synthesised id and
# its own best guess for the arguments.
second_payload = {
"model": "deepseek-v4-pro",
"messages": [
{"role": "system", "content": "sys"},
{"role": "user", "content": "u1"},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "cursor-synth-1",
"type": "function",
"function": {
"name": "grep_search",
"arguments": '{"q":"foo"}',
},
}
],
},
{
"role": "tool",
"tool_call_id": "cursor-synth-1",
"content": "match",
},
{"role": "user", "content": "u2"},
],
}
second_prepared = prepare_upstream_request(
second_payload,
ProxyConfig(missing_reasoning_strategy="recover"),
self.store,
)
self.assertEqual(second_prepared.patched_reasoning_messages, 1)
self.assertEqual(second_prepared.missing_reasoning_messages, 0)
self.assertIsNone(second_prepared.recovery_notice)
self.assertEqual(
second_prepared.payload["messages"][2]["reasoning_content"],
"Need to grep.",
)
def test_tool_name_keys_are_isolated_across_distinct_turns(self) -> None:
# Two separate turns each interrupt with the same function name.
# The strict scope already differs (each turn has more prior
# messages) so the two cached entries should not collide and the
# second turn's reasoning must not leak into the first turn's slot.
config = ProxyConfig(missing_reasoning_strategy="recover")
def cache_partial(payload: dict, reasoning: str, args_fragment: str) -> dict:
prepared = prepare_upstream_request(payload, config, self.store)
response = {
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "",
"reasoning_content": reasoning,
"tool_calls": [
{
"type": "function",
"function": {
"name": "grep_search",
"arguments": args_fragment,
},
}
],
},
}
]
}
rewrite_response_body(
json.dumps(response).encode("utf-8"),
original_model=prepared.original_model,
store=self.store,
request_messages=prepared.record_response_messages,
cache_namespace=prepared.cache_namespace,
scope=prepared.record_response_scope,
prior_messages=prepared.record_response_messages,
recording_contexts=prepared.record_response_contexts,
)
return prepared
turn_a_payload = {
"model": "deepseek-v4-pro",
"messages": [
{"role": "system", "content": "sys"},
{"role": "user", "content": "u-A"},
],
}
cache_partial(turn_a_payload, "Reasoning A.", '{"q":')
turn_b_payload = {
"model": "deepseek-v4-pro",
"messages": [
{"role": "system", "content": "sys"},
{"role": "user", "content": "u-A"},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "synth-A",
"type": "function",
"function": {
"name": "grep_search",
"arguments": '{"q":"a"}',
},
}
],
},
{"role": "tool", "tool_call_id": "synth-A", "content": "ra"},
{"role": "user", "content": "u-B"},
],
}
cache_partial(turn_b_payload, "Reasoning B.", '{"q":')
# Now look up turn A's assistant under its own scope. It must still
# return Reasoning A and never Reasoning B (no scope collision).
recovery_payload = {
"model": "deepseek-v4-pro",
"messages": [
{"role": "system", "content": "sys"},
{"role": "user", "content": "u-A"},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "synth-A",
"type": "function",
"function": {
"name": "grep_search",
"arguments": '{"q":"a"}',
},
}
],
},
{"role": "tool", "tool_call_id": "synth-A", "content": "ra"},
{"role": "user", "content": "u-A2"},
],
}
prepared = prepare_upstream_request(recovery_payload, config, self.store)
self.assertEqual(
prepared.payload["messages"][2]["reasoning_content"],
"Reasoning A.",
)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()