fix: prevent recovery cascade and improve Stop-scenario reasoning lookup (#25)
parent
7bdf177e0f
commit
4eebf78351
|
|
@ -46,6 +46,17 @@ def tool_call_ids(message: dict[str, Any]) -> list[str]:
|
|||
return ids
|
||||
|
||||
|
||||
def tool_call_names(message: dict[str, Any]) -> list[str]:
|
||||
names: list[str] = []
|
||||
for tool_call in message.get("tool_calls") or []:
|
||||
if not isinstance(tool_call, dict):
|
||||
continue
|
||||
function = tool_call.get("function")
|
||||
if isinstance(function, dict) and function.get("name"):
|
||||
names.append(str(function["name"]))
|
||||
return names
|
||||
|
||||
|
||||
def message_signature(message: dict[str, Any]) -> str:
|
||||
tool_calls = [
|
||||
normalize_tool_call(tool_call)
|
||||
|
|
@ -125,6 +136,13 @@ def scoped_reasoning_keys(message: dict[str, Any], scope: str) -> list[str]:
|
|||
for tool_call in (message.get("tool_calls") or [])
|
||||
if isinstance(tool_call, dict)
|
||||
)
|
||||
# Recovery-of-last-resort key. Catches the case where a streaming response
|
||||
# was interrupted (user pressed Stop) before the tool_call.id chunk arrived,
|
||||
# so neither tool_call_id nor tool_call_signature (which canonicalizes
|
||||
# arguments) survives the round-trip through Cursor's transcript.
|
||||
keys.extend(
|
||||
f"scope:{scope}:tool_name:{tool_name}" for tool_name in tool_call_names(message)
|
||||
)
|
||||
return keys
|
||||
|
||||
|
||||
|
|
@ -152,6 +170,10 @@ def portable_reasoning_keys(
|
|||
for tool_call in (message.get("tool_calls") or [])
|
||||
if isinstance(tool_call, dict)
|
||||
)
|
||||
keys.extend(
|
||||
f"namespace:{cache_namespace}:turn:{turn_signature}:" f"tool_name:{tool_name}"
|
||||
for tool_name in tool_call_names(message)
|
||||
)
|
||||
return keys
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -679,54 +679,64 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
)
|
||||
finalized = False
|
||||
pending_recovery_notice = recovery_notice
|
||||
while True:
|
||||
try:
|
||||
line = response.readline()
|
||||
except (HTTPException, OSError) as exc:
|
||||
LOG.warning("upstream streaming response read failed: %s", exc)
|
||||
return ProxyResponseResult(False, usage)
|
||||
if not line:
|
||||
break
|
||||
(
|
||||
rewritten,
|
||||
finalized,
|
||||
pending_recovery_notice,
|
||||
chunk_usage,
|
||||
) = self._rewrite_sse_line(
|
||||
line,
|
||||
original_model,
|
||||
accumulator,
|
||||
cache_namespace,
|
||||
response_contexts,
|
||||
display_adapter,
|
||||
pending_recovery_notice,
|
||||
trace,
|
||||
)
|
||||
if chunk_usage is not None:
|
||||
usage = chunk_usage
|
||||
if trace is not None:
|
||||
trace.record_stream_chunk(line, rewritten)
|
||||
if not self._write_to_client(
|
||||
rewritten, "sending streaming response chunk", flush=True
|
||||
):
|
||||
return ProxyResponseResult(False, usage)
|
||||
if finalized:
|
||||
break
|
||||
|
||||
if not finalized:
|
||||
if self.config.verbose:
|
||||
log_json("model streaming assistant messages", accumulator.messages())
|
||||
stored = sum(
|
||||
accumulator.store_reasoning(
|
||||
self.reasoning_store,
|
||||
scope,
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
line = response.readline()
|
||||
except (HTTPException, OSError) as exc:
|
||||
LOG.warning("upstream streaming response read failed: %s", exc)
|
||||
return ProxyResponseResult(False, usage)
|
||||
if not line:
|
||||
break
|
||||
(
|
||||
rewritten,
|
||||
finalized,
|
||||
pending_recovery_notice,
|
||||
chunk_usage,
|
||||
) = self._rewrite_sse_line(
|
||||
line,
|
||||
original_model,
|
||||
accumulator,
|
||||
cache_namespace,
|
||||
prior_messages,
|
||||
response_contexts,
|
||||
display_adapter,
|
||||
pending_recovery_notice,
|
||||
trace,
|
||||
)
|
||||
for scope, prior_messages in response_contexts
|
||||
)
|
||||
if self.config.verbose and stored:
|
||||
LOG.info("stored %s streaming reasoning cache key(s)", stored)
|
||||
if chunk_usage is not None:
|
||||
usage = chunk_usage
|
||||
if trace is not None:
|
||||
trace.record_stream_chunk(line, rewritten)
|
||||
if not self._write_to_client(
|
||||
rewritten, "sending streaming response chunk", flush=True
|
||||
):
|
||||
return ProxyResponseResult(False, usage)
|
||||
if finalized:
|
||||
break
|
||||
finally:
|
||||
# Store partial reasoning whenever the stream exits without
|
||||
# the upstream's [DONE] terminator (client disconnect, upstream
|
||||
# read failure, exception). Without this, a Stop pressed mid-stream
|
||||
# would discard any reasoning the proxy received but never cached.
|
||||
if not finalized:
|
||||
if self.config.verbose:
|
||||
log_json(
|
||||
"model streaming assistant messages", accumulator.messages()
|
||||
)
|
||||
stored = sum(
|
||||
accumulator.store_reasoning(
|
||||
self.reasoning_store,
|
||||
ctx_scope,
|
||||
cache_namespace,
|
||||
prior_messages,
|
||||
)
|
||||
for ctx_scope, prior_messages in response_contexts
|
||||
)
|
||||
if self.config.verbose and stored:
|
||||
LOG.info(
|
||||
"stored %s streaming reasoning cache key(s) before exit",
|
||||
stored,
|
||||
)
|
||||
return ProxyResponseResult(True, usage)
|
||||
|
||||
def _rewrite_sse_line(
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from .reasoning_store import (
|
|||
conversation_scope,
|
||||
message_signature,
|
||||
tool_call_ids,
|
||||
tool_call_names,
|
||||
tool_call_signature,
|
||||
turn_context_signature,
|
||||
)
|
||||
|
|
@ -383,6 +384,16 @@ def reasoning_lookup_keys(
|
|||
for tool_call in (message.get("tool_calls") or [])
|
||||
if isinstance(tool_call, dict)
|
||||
)
|
||||
keys.extend(
|
||||
{
|
||||
"kind": "tool_name",
|
||||
"function_name": tool_name,
|
||||
"key": f"scope:{scope}:tool_name:{tool_name}",
|
||||
"portable": False,
|
||||
"hit": False,
|
||||
}
|
||||
for tool_name in tool_call_names(message)
|
||||
)
|
||||
if cache_namespace and prior_messages is not None:
|
||||
turn_signature = turn_context_signature(prior_messages)
|
||||
keys.append(
|
||||
|
|
@ -428,6 +439,20 @@ def reasoning_lookup_keys(
|
|||
for tool_call in (message.get("tool_calls") or [])
|
||||
if isinstance(tool_call, dict)
|
||||
)
|
||||
keys.extend(
|
||||
{
|
||||
"kind": "portable_tool_name",
|
||||
"function_name": tool_name,
|
||||
"key": (
|
||||
f"namespace:{cache_namespace}:turn:{turn_signature}:"
|
||||
f"tool_name:{tool_name}"
|
||||
),
|
||||
"turn_context_signature": turn_signature,
|
||||
"portable": True,
|
||||
"hit": False,
|
||||
}
|
||||
for tool_name in tool_call_names(message)
|
||||
)
|
||||
return keys
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -715,5 +715,215 @@ class CrossModeAndModelTests(unittest.TestCase):
|
|||
)
|
||||
|
||||
|
||||
class StopMidStreamingToolCallTests(unittest.TestCase):
|
||||
"""Regression for the 'Stop pressed during streaming tool-call arguments'
|
||||
scenario. When the upstream stream is cut off before the tool_call.id
|
||||
chunk arrives, the cached message has tool_calls with no IDs. Cursor
|
||||
synthesises its own ID for its bookkeeping, so the next request looks
|
||||
nothing like the cached message at the id/signature/message-content
|
||||
levels. The tool_name fallback is the only thing that can rescue this."""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.store = ReasoningStore(":memory:")
|
||||
|
||||
def test_tool_name_fallback_restores_reasoning_when_id_missing(self) -> None:
|
||||
# Turn 1 prepares an upstream request and caches a partial assistant
|
||||
# message simulating a Stop before id arrived.
|
||||
first_payload = {
|
||||
"model": "deepseek-v4-pro",
|
||||
"messages": [
|
||||
{"role": "system", "content": "sys"},
|
||||
{"role": "user", "content": "u1"},
|
||||
],
|
||||
}
|
||||
first_prepared = prepare_upstream_request(
|
||||
first_payload,
|
||||
ProxyConfig(missing_reasoning_strategy="recover"),
|
||||
self.store,
|
||||
)
|
||||
|
||||
partial_response = {
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": "Need to grep.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "grep_search",
|
||||
"arguments": '{"q":',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
rewrite_response_body(
|
||||
json.dumps(partial_response).encode("utf-8"),
|
||||
original_model=first_prepared.original_model,
|
||||
store=self.store,
|
||||
request_messages=first_prepared.record_response_messages,
|
||||
cache_namespace=first_prepared.cache_namespace,
|
||||
scope=first_prepared.record_response_scope,
|
||||
prior_messages=first_prepared.record_response_messages,
|
||||
recording_contexts=first_prepared.record_response_contexts,
|
||||
)
|
||||
|
||||
# Turn 2: Cursor saved the partial response with a synthesised id and
|
||||
# its own best guess for the arguments.
|
||||
second_payload = {
|
||||
"model": "deepseek-v4-pro",
|
||||
"messages": [
|
||||
{"role": "system", "content": "sys"},
|
||||
{"role": "user", "content": "u1"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "cursor-synth-1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "grep_search",
|
||||
"arguments": '{"q":"foo"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "cursor-synth-1",
|
||||
"content": "match",
|
||||
},
|
||||
{"role": "user", "content": "u2"},
|
||||
],
|
||||
}
|
||||
second_prepared = prepare_upstream_request(
|
||||
second_payload,
|
||||
ProxyConfig(missing_reasoning_strategy="recover"),
|
||||
self.store,
|
||||
)
|
||||
|
||||
self.assertEqual(second_prepared.patched_reasoning_messages, 1)
|
||||
self.assertEqual(second_prepared.missing_reasoning_messages, 0)
|
||||
self.assertIsNone(second_prepared.recovery_notice)
|
||||
self.assertEqual(
|
||||
second_prepared.payload["messages"][2]["reasoning_content"],
|
||||
"Need to grep.",
|
||||
)
|
||||
|
||||
def test_tool_name_keys_are_isolated_across_distinct_turns(self) -> None:
|
||||
# Two separate turns each interrupt with the same function name.
|
||||
# The strict scope already differs (each turn has more prior
|
||||
# messages) so the two cached entries should not collide and the
|
||||
# second turn's reasoning must not leak into the first turn's slot.
|
||||
config = ProxyConfig(missing_reasoning_strategy="recover")
|
||||
|
||||
def cache_partial(payload: dict, reasoning: str, args_fragment: str) -> dict:
|
||||
prepared = prepare_upstream_request(payload, config, self.store)
|
||||
response = {
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": reasoning,
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "grep_search",
|
||||
"arguments": args_fragment,
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
rewrite_response_body(
|
||||
json.dumps(response).encode("utf-8"),
|
||||
original_model=prepared.original_model,
|
||||
store=self.store,
|
||||
request_messages=prepared.record_response_messages,
|
||||
cache_namespace=prepared.cache_namespace,
|
||||
scope=prepared.record_response_scope,
|
||||
prior_messages=prepared.record_response_messages,
|
||||
recording_contexts=prepared.record_response_contexts,
|
||||
)
|
||||
return prepared
|
||||
|
||||
turn_a_payload = {
|
||||
"model": "deepseek-v4-pro",
|
||||
"messages": [
|
||||
{"role": "system", "content": "sys"},
|
||||
{"role": "user", "content": "u-A"},
|
||||
],
|
||||
}
|
||||
cache_partial(turn_a_payload, "Reasoning A.", '{"q":')
|
||||
|
||||
turn_b_payload = {
|
||||
"model": "deepseek-v4-pro",
|
||||
"messages": [
|
||||
{"role": "system", "content": "sys"},
|
||||
{"role": "user", "content": "u-A"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "synth-A",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "grep_search",
|
||||
"arguments": '{"q":"a"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "synth-A", "content": "ra"},
|
||||
{"role": "user", "content": "u-B"},
|
||||
],
|
||||
}
|
||||
cache_partial(turn_b_payload, "Reasoning B.", '{"q":')
|
||||
|
||||
# Now look up turn A's assistant under its own scope. It must still
|
||||
# return Reasoning A and never Reasoning B (no scope collision).
|
||||
recovery_payload = {
|
||||
"model": "deepseek-v4-pro",
|
||||
"messages": [
|
||||
{"role": "system", "content": "sys"},
|
||||
{"role": "user", "content": "u-A"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "synth-A",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "grep_search",
|
||||
"arguments": '{"q":"a"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "synth-A", "content": "ra"},
|
||||
{"role": "user", "content": "u-A2"},
|
||||
],
|
||||
}
|
||||
prepared = prepare_upstream_request(recovery_payload, config, self.store)
|
||||
self.assertEqual(
|
||||
prepared.payload["messages"][2]["reasoning_content"],
|
||||
"Reasoning A.",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Reference in New Issue