fix(cache): fix reasoning recovery during mode/model switch (#28)

main
Yixing Lao 2026-04-29 22:32:13 +08:00 committed by GitHub
parent cf3a9d3875
commit 5f14da32c6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 851 additions and 72 deletions

View File

@ -62,6 +62,13 @@ def message_signature(message: dict[str, Any]) -> str:
return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
def _sha256_json(payload: Any) -> str:
    """Return the hex SHA-256 digest of ``payload`` serialized as JSON.

    The payload is dumped with sorted keys, compact separators and
    non-ASCII characters left unescaped, then UTF-8 encoded and hashed.
    """
    serialized = json.dumps(
        payload,
        sort_keys=True,
        ensure_ascii=False,
        separators=(",", ":"),
    )
    digest = hashlib.sha256()
    digest.update(serialized.encode("utf-8"))
    return digest.hexdigest()
def canonical_scope_message(message: dict[str, Any]) -> dict[str, Any]:
canonical: dict[str, Any] = {"role": message.get("role")}
for key in ("content", "name", "tool_call_id", "prefix"):
@ -81,10 +88,71 @@ def conversation_scope(messages: list[dict[str, Any]], namespace: str = "") -> s
payload: Any = scope_messages
if namespace:
payload = {"namespace": namespace, "messages": scope_messages}
canonical = json.dumps(
payload, ensure_ascii=False, sort_keys=True, separators=(",", ":")
return _sha256_json(payload)
def turn_context_signature(prior_messages: list[dict[str, Any]]) -> str:
    """Hash the messages that make up the current turn of ``prior_messages``.

    The turn starts at the first message of the last run of consecutive
    ``user`` messages and extends to the end of the list. When the list
    holds no ``user`` message at all, the whole list is used. Messages whose
    role is ``system`` are left out of the hashed payload.

    Returns the hex SHA-256 digest of the canonicalized turn messages.
    """
    # Scan backwards for the most recent user message (-1 when none exists).
    last_user_index = next(
        (
            index
            for index in range(len(prior_messages) - 1, -1, -1)
            if prior_messages[index].get("role") == "user"
        ),
        -1,
    )
    start_index = 0
    if last_user_index != -1:
        start_index = last_user_index
        # Extend the turn over user messages that directly precede it, so a
        # run of back-to-back user messages is treated as one turn opener.
        while (
            start_index > 0
            and prior_messages[start_index - 1].get("role") == "user"
        ):
            start_index -= 1
    context_messages = [
        canonical_scope_message(message)
        for message in prior_messages[start_index:]
        if message.get("role") != "system"
    ]
    return _sha256_json(context_messages)
def scoped_reasoning_keys(message: dict[str, Any], scope: str) -> list[str]:
    """Build the scope-bound cache keys for ``message``.

    The list starts with the message-signature key, followed by one key per
    tool-call id and then one key per tool-call signature. Entries of
    ``tool_calls`` that are not dicts contribute no signature key.
    """
    signature_key = f"scope:{scope}:signature:{message_signature(message)}"
    id_keys = [
        f"scope:{scope}:tool_call:{tool_call_id}"
        for tool_call_id in tool_call_ids(message)
    ]
    call_signature_keys = [
        f"scope:{scope}:tool_call_signature:{tool_call_signature(tool_call)}"
        for tool_call in (message.get("tool_calls") or [])
        if isinstance(tool_call, dict)
    ]
    return [signature_key, *id_keys, *call_signature_keys]
def portable_reasoning_keys(
    message: dict[str, Any],
    cache_namespace: str,
    prior_messages: list[dict[str, Any]],
) -> list[str]:
    """Build the namespace-and-turn-bound cache keys for ``message``.

    These keys depend on ``cache_namespace`` and on the turn signature of
    ``prior_messages`` rather than on a conversation scope. An empty
    namespace yields no keys. Otherwise the list holds the message-signature
    key, then one key per tool-call id, then one key per dict tool call's
    signature.
    """
    if not cache_namespace:
        return []
    turn_signature = turn_context_signature(prior_messages)
    signature_key = (
        f"namespace:{cache_namespace}:turn:{turn_signature}:"
        f"signature:{message_signature(message)}"
    )
    id_keys = [
        f"namespace:{cache_namespace}:turn:{turn_signature}:"
        f"tool_call:{tool_call_id}"
        for tool_call_id in tool_call_ids(message)
    ]
    call_signature_keys = [
        f"namespace:{cache_namespace}:turn:{turn_signature}:"
        f"tool_call_signature:{tool_call_signature(tool_call)}"
        for tool_call in (message.get("tool_calls") or [])
        if isinstance(tool_call, dict)
    ]
    return [signature_key, *id_keys, *call_signature_keys]
class ReasoningStore:
@ -155,45 +223,65 @@ class ReasoningStore:
return None
return str(row[0])
def store_assistant_message(self, message: dict[str, Any], scope: str) -> int:
def store_assistant_message(
self,
message: dict[str, Any],
scope: str,
cache_namespace: str = "",
prior_messages: list[dict[str, Any]] | None = None,
) -> int:
if message.get("role") != "assistant":
return 0
reasoning = message.get("reasoning_content")
if not isinstance(reasoning, str):
return 0
keys = [f"scope:{scope}:signature:{message_signature(message)}"]
keys.extend(
f"scope:{scope}:tool_call:{tool_call_id}"
for tool_call_id in tool_call_ids(message)
)
keys.extend(
f"scope:{scope}:tool_call_signature:{tool_call_signature(tool_call)}"
for tool_call in (message.get("tool_calls") or [])
if isinstance(tool_call, dict)
)
keys = scoped_reasoning_keys(message, scope)
if prior_messages is not None:
keys.extend(
portable_reasoning_keys(message, cache_namespace, prior_messages)
)
keys = list(dict.fromkeys(keys))
for key in keys:
self.put(key, reasoning, message)
return len(keys)
def lookup_for_message(self, message: dict[str, Any], scope: str) -> str | None:
reasoning = self.get(f"scope:{scope}:signature:{message_signature(message)}")
if reasoning is not None:
return reasoning
for tool_call_id in tool_call_ids(message):
reasoning = self.get(f"scope:{scope}:tool_call:{tool_call_id}")
if reasoning is not None:
return reasoning
for tool_call in message.get("tool_calls") or []:
if not isinstance(tool_call, dict):
continue
reasoning = self.get(
f"scope:{scope}:tool_call_signature:{tool_call_signature(tool_call)}"
def lookup_for_message(
self,
message: dict[str, Any],
scope: str,
cache_namespace: str = "",
prior_messages: list[dict[str, Any]] | None = None,
) -> str | None:
keys = scoped_reasoning_keys(message, scope)
if prior_messages is not None:
keys.extend(
portable_reasoning_keys(message, cache_namespace, prior_messages)
)
for key in keys:
reasoning = self.get(key)
if reasoning is not None:
return reasoning
return None
def backfill_portable_aliases(
    self,
    message: dict[str, Any],
    reasoning: str,
    cache_namespace: str,
    prior_messages: list[dict[str, Any]],
) -> int:
    """Store ``reasoning`` under the portable cache keys of ``message``.

    Args:
        message: Assistant message the reasoning belongs to; it is not
            modified (a shallow copy carrying ``reasoning_content`` is
            stored instead).
        reasoning: Reasoning text to cache. Non-string values are ignored.
        cache_namespace: Namespace used to build the portable keys.
        prior_messages: Messages preceding ``message``, used for the turn
            signature inside the portable keys.

    Returns:
        The number of distinct keys written, or 0 when nothing was stored.
    """
    if not isinstance(reasoning, str):
        return 0
    # De-duplicate up front (keeping order) so the returned count matches
    # the number of writes actually performed.
    keys = list(
        dict.fromkeys(
            portable_reasoning_keys(message, cache_namespace, prior_messages)
        )
    )
    if not keys:
        return 0
    message_with_reasoning = dict(message)
    message_with_reasoning["reasoning_content"] = reasoning
    for key in keys:
        self.put(key, reasoning, message_with_reasoning)
    return len(keys)
def clear(self) -> int:
with self._lock:
row = self._conn.execute("SELECT COUNT(*) FROM reasoning_cache").fetchone()

View File

@ -160,16 +160,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
)
if prepared.recovered_reasoning_messages:
if prepared.recovery_notice:
LOG.warning(
(
"recovered request because cached reasoning_content was "
"unavailable for %s assistant message(s); omitted %s "
"older message(s) from forwarded history and will show "
"a Cursor notice"
),
prepared.recovered_reasoning_messages,
prepared.recovery_dropped_messages,
)
LOG.warning("refreshed reasoning_content history")
else:
LOG.info(
(
@ -305,7 +296,10 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
prepared.payload["messages"],
prepared.cache_namespace,
prepared.recovery_notice,
trace,
trace=trace,
record_response_scope=prepared.record_response_scope,
record_response_messages=prepared.record_response_messages,
record_response_contexts=prepared.record_response_contexts,
)
else:
sent_response = self._proxy_regular_response(
@ -314,7 +308,10 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
prepared.payload["messages"],
prepared.cache_namespace,
prepared.recovery_notice,
trace,
trace=trace,
record_response_scope=prepared.record_response_scope,
record_response_messages=prepared.record_response_messages,
record_response_contexts=prepared.record_response_contexts,
)
if not sent_response:
self._finish_trace(
@ -549,6 +546,9 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
cache_namespace: str,
recovery_notice: str | None = None,
trace: TraceRequest | None = None,
record_response_scope: str | None = None,
record_response_messages: list[dict[str, Any]] | None = None,
record_response_contexts: list[tuple[str, list[dict[str, Any]]]] | None = None,
) -> bool:
body = read_response_body(response)
upstream_body = body
@ -560,6 +560,9 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
request_messages,
cache_namespace,
content_prefix=recovery_notice,
scope=record_response_scope,
prior_messages=record_response_messages,
recording_contexts=record_response_contexts,
)
except (json.JSONDecodeError, UnicodeDecodeError) as exc:
LOG.warning("failed to rewrite upstream JSON response: %s", exc)
@ -611,6 +614,9 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
cache_namespace: str,
recovery_notice: str | None = None,
trace: TraceRequest | None = None,
record_response_scope: str | None = None,
record_response_messages: list[dict[str, Any]] | None = None,
record_response_contexts: list[tuple[str, list[dict[str, Any]]]] | None = None,
) -> bool:
if trace is not None:
trace.record_upstream_response(
@ -645,7 +651,21 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
if self.config.cursor_display_reasoning
else None
)
scope = conversation_scope(request_messages, cache_namespace)
scope = (
record_response_scope
if record_response_scope is not None
else conversation_scope(request_messages, cache_namespace)
)
response_prior_messages = (
record_response_messages
if record_response_messages is not None
else request_messages
)
response_contexts = (
record_response_contexts
if record_response_contexts is not None
else [(scope, response_prior_messages)]
)
finalized = False
pending_recovery_notice = recovery_notice
while True:
@ -660,7 +680,8 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
line,
original_model,
accumulator,
scope,
cache_namespace,
response_contexts,
display_adapter,
pending_recovery_notice,
trace,
@ -677,7 +698,15 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
if not finalized:
if self.config.verbose:
log_json("model streaming assistant messages", accumulator.messages())
stored = accumulator.store_reasoning(self.reasoning_store, scope)
stored = sum(
accumulator.store_reasoning(
self.reasoning_store,
scope,
cache_namespace,
prior_messages,
)
for scope, prior_messages in response_contexts
)
if stored:
LOG.info("stored %s streaming reasoning cache key(s)", stored)
return True
@ -687,7 +716,8 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
line: bytes,
original_model: str,
accumulator: StreamAccumulator,
scope: str,
cache_namespace: str,
response_contexts: list[tuple[str, list[dict[str, Any]]]],
display_adapter: CursorReasoningDisplayAdapter | None,
recovery_notice: str | None = None,
trace: TraceRequest | None = None,
@ -700,7 +730,15 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
if data == b"[DONE]":
if self.config.verbose:
log_json("model streaming assistant messages", accumulator.messages())
stored = accumulator.store_reasoning(self.reasoning_store, scope)
stored = sum(
accumulator.store_reasoning(
self.reasoning_store,
scope,
cache_namespace,
prior_messages,
)
for scope, prior_messages in response_contexts
)
if stored:
LOG.info("stored %s streaming reasoning cache key(s)", stored)
prefix = b""
@ -728,7 +766,15 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
if recovery_notice and inject_recovery_notice(chunk, recovery_notice):
recovery_notice = None
accumulator.ingest_chunk(chunk)
stored = accumulator.store_ready_reasoning(self.reasoning_store, scope)
stored = sum(
accumulator.store_ready_reasoning(
self.reasoning_store,
scope,
cache_namespace,
prior_messages,
)
for scope, prior_messages in response_contexts
)
if stored:
LOG.info("stored %s streaming reasoning cache key(s)", stored)
if trace is not None:

View File

@ -35,7 +35,7 @@ class StreamingChoice:
class StreamAccumulator:
def __init__(self) -> None:
self.choices: dict[int, StreamingChoice] = {}
self._stored_choices: dict[int, str] = {}
self._stored_choices: dict[tuple[int, str], str] = {}
def ingest_chunk(self, chunk: dict[str, Any]) -> None:
choices = chunk.get("choices")
@ -70,26 +70,70 @@ class StreamAccumulator:
self._merge_tool_call_deltas(choice, delta.get("tool_calls"))
def store_reasoning(self, store: ReasoningStore, scope: str) -> int:
def store_reasoning(
self,
store: ReasoningStore,
scope: str,
cache_namespace: str = "",
prior_messages: list[dict[str, Any]] | None = None,
) -> int:
stored = 0
for index, choice in self.choices.items():
stored += self._store_choice(index, choice, store, scope)
stored += self._store_choice(
index, choice, store, scope, "final", cache_namespace, prior_messages
)
return stored
def store_finished_reasoning(self, store: ReasoningStore, scope: str) -> int:
def store_finished_reasoning(
self,
store: ReasoningStore,
scope: str,
cache_namespace: str = "",
prior_messages: list[dict[str, Any]] | None = None,
) -> int:
stored = 0
for index, choice in self.choices.items():
if choice.finish_reason is not None:
stored += self._store_choice(index, choice, store, scope, "final")
stored += self._store_choice(
index,
choice,
store,
scope,
"final",
cache_namespace,
prior_messages,
)
return stored
def store_ready_reasoning(self, store: ReasoningStore, scope: str) -> int:
def store_ready_reasoning(
self,
store: ReasoningStore,
scope: str,
cache_namespace: str = "",
prior_messages: list[dict[str, Any]] | None = None,
) -> int:
stored = 0
for index, choice in self.choices.items():
if choice.finish_reason is not None:
stored += self._store_choice(index, choice, store, scope, "final")
stored += self._store_choice(
index,
choice,
store,
scope,
"final",
cache_namespace,
prior_messages,
)
elif self._has_identified_tool_calls(choice):
stored += self._store_choice(index, choice, store, scope, "tool_call")
stored += self._store_choice(
index,
choice,
store,
scope,
"tool_call",
cache_namespace,
prior_messages,
)
return stored
def messages(self) -> list[dict[str, Any]]:
@ -141,14 +185,22 @@ class StreamAccumulator:
store: ReasoningStore,
scope: str,
stage: str = "final",
cache_namespace: str = "",
prior_messages: list[dict[str, Any]] | None = None,
) -> int:
stage_rank = {"tool_call": 1, "final": 2}
previous_stage = self._stored_choices.get(index)
storage_key = (index, scope)
previous_stage = self._stored_choices.get(storage_key)
if stage_rank.get(previous_stage or "", 0) >= stage_rank.get(stage, 0):
return 0
stored = store.store_assistant_message(choice.to_message(), scope)
stored = store.store_assistant_message(
choice.to_message(),
scope,
cache_namespace,
prior_messages,
)
if stored:
self._stored_choices[index] = stage
self._stored_choices[storage_key] = stage
return stored
def _has_identified_tool_calls(self, choice: StreamingChoice) -> bool:

View File

@ -105,7 +105,10 @@ def message_summaries(payload: dict[str, Any]) -> list[dict[str, Any]]:
len(reasoning) if isinstance(reasoning, str) else 0
),
"has_recovery_notice": content.startswith(
"[deepseek-cursor-proxy] Recovered"
(
"[deepseek-cursor-proxy] Refreshed reasoning_content history.",
"[deepseek-cursor-proxy] Recovered",
)
),
}
summaries.append(summary)
@ -231,6 +234,12 @@ class TraceRequest:
"recovered_reasoning_messages": prepared.recovered_reasoning_messages,
"recovery_dropped_messages": prepared.recovery_dropped_messages,
"recovery_notice": prepared.recovery_notice,
"record_response_scope": prepared.record_response_scope,
"record_response_scopes": [
scope for scope, _messages in prepared.record_response_contexts
],
"continued_recovery_boundary": prepared.continued_recovery_boundary,
"retired_prefix_messages": prepared.retired_prefix_messages,
"reasoning_diagnostics": prepared.reasoning_diagnostics,
"recovery_steps": prepared.recovery_steps,
"upstream_request_summary": payload_summary(prepared.payload),

View File

@ -13,6 +13,7 @@ from .reasoning_store import (
message_signature,
tool_call_ids,
tool_call_signature,
turn_context_signature,
)
@ -73,10 +74,7 @@ CURSOR_THINKING_BLOCK_RE = re.compile(
re.IGNORECASE,
)
RECOVERY_NOTICE_TEXT = (
"[deepseek-cursor-proxy] Recovered this DeepSeek chat because older "
"tool-call reasoning was unavailable; continuing with recent context only."
)
RECOVERY_NOTICE_TEXT = "[deepseek-cursor-proxy] Refreshed reasoning_content history."
LEGACY_RECOVERY_NOTICE_TEXT = (
"Note: recovered this DeepSeek chat because older tool-call reasoning "
"was unavailable; continuing with recent context only."
@ -101,8 +99,15 @@ class PreparedRequest:
recovered_reasoning_messages: int = 0
recovery_dropped_messages: int = 0
recovery_notice: str | None = None
record_response_scope: str | None = None
record_response_messages: list[dict[str, Any]] = field(default_factory=list)
record_response_contexts: list[tuple[str, list[dict[str, Any]]]] = field(
default_factory=list
)
reasoning_diagnostics: list[dict[str, Any]] = field(default_factory=list)
recovery_steps: list[dict[str, Any]] = field(default_factory=list)
continued_recovery_boundary: bool = False
retired_prefix_messages: int = 0
def normalize_reasoning_effort(value: Any) -> str:
@ -260,7 +265,12 @@ def normalize_message(
)
lookup_scope = conversation_scope(prior_messages, cache_namespace)
lookup_keys = (
reasoning_lookup_keys(normalized, lookup_scope)
reasoning_lookup_keys(
normalized,
lookup_scope,
cache_namespace,
prior_messages,
)
if needs_reasoning
else []
)
@ -273,6 +283,13 @@ def normalize_message(
hit_kind = lookup_key["kind"]
normalized["reasoning_content"] = restored
patched = True
if not lookup_key.get("portable"):
store.backfill_portable_aliases(
normalized,
restored,
cache_namespace,
prior_messages,
)
break
if needs_reasoning and not patched:
missing = True
@ -315,11 +332,14 @@ def normalize_message(
def reasoning_lookup_keys(
message: dict[str, Any],
scope: str,
cache_namespace: str = "",
prior_messages: list[dict[str, Any]] | None = None,
) -> list[dict[str, Any]]:
keys = [
{
"kind": "message_signature",
"key": f"scope:{scope}:signature:{message_signature(message)}",
"portable": False,
"hit": False,
}
]
@ -328,6 +348,7 @@ def reasoning_lookup_keys(
"kind": "tool_call_id",
"tool_call_id": tool_call_id,
"key": f"scope:{scope}:tool_call:{tool_call_id}",
"portable": False,
"hit": False,
}
for tool_call_id in tool_call_ids(message)
@ -340,11 +361,57 @@ def reasoning_lookup_keys(
f"scope:{scope}:tool_call_signature:"
f"{tool_call_signature(tool_call)}"
),
"portable": False,
"hit": False,
}
for tool_call in (message.get("tool_calls") or [])
if isinstance(tool_call, dict)
)
if cache_namespace and prior_messages is not None:
turn_signature = turn_context_signature(prior_messages)
keys.append(
{
"kind": "portable_message_signature",
"key": (
f"namespace:{cache_namespace}:turn:{turn_signature}:"
f"signature:{message_signature(message)}"
),
"turn_context_signature": turn_signature,
"portable": True,
"hit": False,
}
)
keys.extend(
{
"kind": "portable_tool_call_id",
"tool_call_id": tool_call_id,
"key": (
f"namespace:{cache_namespace}:turn:{turn_signature}:"
f"tool_call:{tool_call_id}"
),
"turn_context_signature": turn_signature,
"portable": True,
"hit": False,
}
for tool_call_id in tool_call_ids(message)
)
keys.extend(
{
"kind": "portable_tool_call_signature",
"function_name": str(
(tool_call.get("function") or {}).get("name") or ""
),
"key": (
f"namespace:{cache_namespace}:turn:{turn_signature}:"
f"tool_call_signature:{tool_call_signature(tool_call)}"
),
"turn_context_signature": turn_signature,
"portable": True,
"hit": False,
}
for tool_call in (message.get("tool_calls") or [])
if isinstance(tool_call, dict)
)
return keys
@ -399,6 +466,52 @@ def leading_system_messages(messages: list[dict[str, Any]]) -> list[dict[str, An
return leading_messages
def active_messages_from_recovery_boundary(
    messages: list[dict[str, Any]],
) -> tuple[list[dict[str, Any]], int, dict[str, Any]] | None:
    """Trim ``messages`` to the part that follows the latest recovery notice.

    Returns ``None`` when no message carries a recovery notice. Otherwise
    returns a tuple of:

    * the active messages: the leading system messages, a system message
      holding ``RECOVERY_SYSTEM_CONTENT``, the closest ``user`` message
      before the boundary (if any), and everything from the boundary on;
    * the number of messages retired from the prefix (never negative);
    * a diagnostic step dict describing the boundary that was used.
    """
    # Latest message that carries a recovery notice.
    boundary = -1
    for position in range(len(messages) - 1, -1, -1):
        if has_recovery_notice(messages[position]):
            boundary = position
            break
    if boundary == -1:
        return None

    # Closest user message strictly before the boundary, kept as context.
    user_position = -1
    for position in range(boundary - 1, -1, -1):
        if messages[position].get("role") == "user":
            user_position = position
            break

    system_prefix = leading_system_messages(messages)
    has_context_user = user_position != -1
    tail = [messages[user_position]] if has_context_user else []
    tail += messages[boundary:]
    active = [
        *system_prefix,
        {"role": "system", "content": RECOVERY_SYSTEM_CONTENT},
        *tail,
    ]

    kept_context = 1 if has_context_user else 0
    retired = max(boundary - len(system_prefix) - kept_context, 0)
    step = {
        "strategy": "continued_recovery_boundary",
        "recovery_boundary_index": boundary,
        "context_user_index": user_position,
        "retired_prefix_messages": retired,
    }
    return active, retired, step
def recover_messages_from_missing_reasoning(
messages: list[dict[str, Any]],
missing_indexes: list[int],
@ -510,6 +623,12 @@ def upstream_model_for(original_model: str, config: ProxyConfig) -> str:
return config.upstream_model
def reasoning_model_family(upstream_model: str) -> str:
    """Return the family name for ``upstream_model``.

    ``deepseek-v4-pro`` and ``deepseek-v4-flash`` both map to
    ``deepseek-v4``; any other model name is returned unchanged.
    """
    v4_variants = frozenset(("deepseek-v4-pro", "deepseek-v4-flash"))
    return "deepseek-v4" if upstream_model in v4_variants else upstream_model
def reasoning_cache_namespace(
config: ProxyConfig,
upstream_model: str,
@ -522,7 +641,7 @@ def reasoning_cache_namespace(
auth_hash = hashlib.sha256(authorization.encode("utf-8")).hexdigest()
payload = {
"base_url": config.upstream_base_url,
"model": upstream_model,
"model": reasoning_model_family(upstream_model),
"thinking": thinking,
"reasoning_effort": reasoning_effort,
"authorization_hash": auth_hash,
@ -533,6 +652,22 @@ def reasoning_cache_namespace(
return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
def response_recording_contexts(
*items: tuple[str, list[dict[str, Any]]] | None,
) -> list[tuple[str, list[dict[str, Any]]]]:
contexts: list[tuple[str, list[dict[str, Any]]]] = []
seen: set[str] = set()
for item in items:
if item is None:
continue
scope, messages = item
if scope in seen:
continue
seen.add(scope)
contexts.append((scope, messages))
return contexts
def prepare_upstream_request(
payload: dict[str, Any],
config: ProxyConfig,
@ -596,19 +731,40 @@ def prepare_upstream_request(
prepared.get("reasoning_effort"),
authorization,
)
pre_repair_messages, _, _, _ = normalize_messages(
payload.get("messages"),
None,
cache_namespace,
repair_reasoning=False,
keep_reasoning=not thinking_disabled,
)
record_response_messages = pre_repair_messages
record_response_scope = conversation_scope(
record_response_messages, cache_namespace
)
messages_for_repair = pre_repair_messages
continued_recovery_boundary = False
retired_prefix_messages = 0
recovered_count = 0
recovery_dropped_messages = 0
recovery_notice = None
recovery_steps: list[dict[str, Any]] = []
if thinking_enabled and config.missing_reasoning_strategy == "recover":
boundary = active_messages_from_recovery_boundary(pre_repair_messages)
if boundary is not None:
messages_for_repair, retired_prefix_messages, boundary_step = boundary
continued_recovery_boundary = True
recovery_steps.append(boundary_step)
messages, patched_count, missing_indexes, reasoning_diagnostics = (
normalize_messages(
payload.get("messages"),
messages_for_repair,
store,
cache_namespace,
repair_reasoning=thinking_enabled,
keep_reasoning=not thinking_disabled,
)
)
recovered_count = 0
recovery_dropped_messages = 0
recovery_notice = None
recovery_steps: list[dict[str, Any]] = []
while missing_indexes and config.missing_reasoning_strategy == "recover":
recovered_messages, dropped_messages, notice, recovery_step = (
recover_messages_from_missing_reasoning(messages, missing_indexes)
@ -634,6 +790,11 @@ def prepare_upstream_request(
)
reasoning_diagnostics.extend(latest_diagnostics)
prepared["messages"] = messages
active_record_response_scope = conversation_scope(messages, cache_namespace)
record_response_contexts = response_recording_contexts(
(record_response_scope, record_response_messages),
(active_record_response_scope, messages),
)
return PreparedRequest(
payload=prepared,
@ -645,8 +806,13 @@ def prepare_upstream_request(
recovered_reasoning_messages=recovered_count,
recovery_dropped_messages=recovery_dropped_messages,
recovery_notice=recovery_notice,
record_response_scope=record_response_scope,
record_response_messages=record_response_messages,
record_response_contexts=record_response_contexts,
reasoning_diagnostics=reasoning_diagnostics,
recovery_steps=recovery_steps,
continued_recovery_boundary=continued_recovery_boundary,
retired_prefix_messages=retired_prefix_messages,
)
@ -655,6 +821,9 @@ def record_response_reasoning(
store: ReasoningStore | None,
request_messages: list[dict[str, Any]],
cache_namespace: str = "",
scope: str | None = None,
prior_messages: list[dict[str, Any]] | None = None,
recording_contexts: list[tuple[str, list[dict[str, Any]]]] | None = None,
) -> int:
if store is None:
return 0
@ -662,13 +831,28 @@ def record_response_reasoning(
choices = response_payload.get("choices")
if not isinstance(choices, list):
return stored
scope = conversation_scope(request_messages, cache_namespace)
if recording_contexts is None:
response_scope = (
scope
if scope is not None
else conversation_scope(request_messages, cache_namespace)
)
response_prior_messages = (
prior_messages if prior_messages is not None else request_messages
)
recording_contexts = [(response_scope, response_prior_messages)]
for choice in choices:
if not isinstance(choice, dict):
continue
message = choice.get("message")
if isinstance(message, dict):
stored += store.store_assistant_message(message, scope)
for response_scope, response_prior_messages in recording_contexts:
stored += store.store_assistant_message(
message,
response_scope,
cache_namespace,
response_prior_messages,
)
return stored
@ -679,13 +863,22 @@ def rewrite_response_body(
request_messages: list[dict[str, Any]],
cache_namespace: str = "",
content_prefix: str | None = None,
scope: str | None = None,
prior_messages: list[dict[str, Any]] | None = None,
recording_contexts: list[tuple[str, list[dict[str, Any]]]] | None = None,
) -> bytes:
response_payload = json.loads(body.decode("utf-8"))
if isinstance(response_payload, dict):
if content_prefix:
prefix_response_content(response_payload, content_prefix)
record_response_reasoning(
response_payload, store, request_messages, cache_namespace
response_payload,
store,
request_messages,
cache_namespace,
scope=scope,
prior_messages=prior_messages,
recording_contexts=recording_contexts,
)
if "model" in response_payload:
response_payload["model"] = original_model

View File

@ -738,7 +738,7 @@ class ProxyEndToEndTests(unittest.TestCase):
{"role": "user", "content": "Thanks, now continue."},
)
self.assertIn(
"cached reasoning_content was unavailable",
"refreshed reasoning_content history",
"\n".join(captured.output),
)

View File

@ -116,6 +116,50 @@ class StreamAccumulatorTests(unittest.TestCase):
self.assertEqual(accumulator.store_reasoning(store, scope), 0)
store.close()
def test_stores_same_streaming_choice_under_multiple_scopes(self) -> None:
    """A finished streamed choice can be stored once under each scope."""
    streamed_tool_call = {
        "index": 0,
        "id": "call_stream",
        "type": "function",
        "function": {
            "name": "lookup",
            "arguments": "{}",
        },
    }
    chunk = {
        "choices": [
            {
                "index": 0,
                "delta": {
                    "role": "assistant",
                    "reasoning_content": "Need a tool.",
                    "tool_calls": [streamed_tool_call],
                },
                "finish_reason": "tool_calls",
            }
        ]
    }
    store = ReasoningStore(":memory:")
    accumulator = StreamAccumulator()
    accumulator.ingest_chunk(chunk)

    full_scope = conversation_scope([{"role": "user", "content": "full"}])
    active_scope = conversation_scope([{"role": "user", "content": "active"}])
    scopes = (full_scope, active_scope)

    # Store under the first scope, then the second, before asserting.
    stored_counts = [
        accumulator.store_finished_reasoning(store, scope) for scope in scopes
    ]

    for count in stored_counts:
        self.assertGreater(count, 0)
    for scope in scopes:
        self.assertEqual(
            store.get(f"scope:{scope}:tool_call:call_stream"), "Need a tool."
        )
    store.close()
def test_stores_tool_call_reasoning_before_finish_reason(self) -> None:
store = ReasoningStore(":memory:")
accumulator = StreamAccumulator()

View File

@ -470,6 +470,193 @@ class TransformTests(unittest.TestCase):
prepared.payload["messages"][1]["reasoning_content"], "first reasoning"
)
def test_strict_hit_backfills_portable_cache_for_mode_switch(self) -> None:
    """A scope hit in one mode makes the reasoning findable in another mode.

    Reasoning is stored with scope keys only under the "Agent mode." prompt.
    The first request hits those keys; the second request differs only in
    its system prompt and is expected to be served by a portable key.
    """

    def prior_for(mode_prompt: str) -> list:
        # Fresh dicts per call so the two histories share no objects.
        return [
            {"role": "system", "content": mode_prompt},
            {"role": "user", "content": "set up the task"},
            {"role": "user", "content": "read README"},
        ]

    agent_prior = prior_for("Agent mode.")
    plan_prior = prior_for("Plan mode.")
    tool_call = {
        "id": "call_mode_switch",
        "type": "function",
        "function": {"name": "read_file", "arguments": '{"path":"README.md"}'},
    }

    def request_for(prior: list) -> dict:
        # The assistant turn arrives without reasoning_content.
        return {
            "model": "deepseek-v4-pro",
            "messages": [
                *prior,
                {"role": "assistant", "content": "", "tool_calls": [tool_call]},
            ],
        }

    self.store.store_assistant_message(
        {
            "role": "assistant",
            "content": "",
            "reasoning_content": "Need README before answering.",
            "tool_calls": [tool_call],
        },
        cache_scope(agent_prior),
    )

    strict_prepared = prepare_upstream_request(
        request_for(agent_prior), ProxyConfig(), self.store
    )
    portable_prepared = prepare_upstream_request(
        request_for(plan_prior), ProxyConfig(), self.store
    )

    self.assertEqual(strict_prepared.patched_reasoning_messages, 1)
    self.assertEqual(portable_prepared.patched_reasoning_messages, 1)
    self.assertEqual(portable_prepared.missing_reasoning_messages, 0)
    self.assertEqual(
        portable_prepared.payload["messages"][3]["reasoning_content"],
        "Need README before answering.",
    )
    last_hit_kind = str(portable_prepared.reasoning_diagnostics[-1]["hit_kind"])
    self.assertTrue(last_hit_kind.startswith("portable_"))
def test_portable_turn_cache_restores_final_assistant_after_tool_result(
    self,
) -> None:
    """Both assistant turns of one user turn are restored after a mode switch.

    The tool-call turn and the final answer are stored under the
    "Agent mode." history, then requested under a "Plan mode." history with
    the ``reject`` strategy.
    """
    agent_user = {"role": "user", "content": "look up project state"}
    plan_user = dict(agent_user)
    tool_call = {
        "id": "call_project_state",
        "type": "function",
        "function": {"name": "lookup", "arguments": '{"query":"state"}'},
    }
    tool_result = {
        "role": "tool",
        "tool_call_id": "call_project_state",
        "content": '{"state":"ready"}',
    }
    tool_assistant = {
        "role": "assistant",
        "content": "",
        "reasoning_content": "Need the project state.",
        "tool_calls": [tool_call],
    }
    final_assistant = {
        "role": "assistant",
        "content": "The project is ready.",
        "reasoning_content": "The tool result is enough to answer.",
    }
    agent_initial_prior = [
        {"role": "system", "content": "Agent mode."},
        agent_user,
    ]
    agent_final_prior = [*agent_initial_prior, tool_assistant, tool_result]

    # Record each assistant turn against the history that preceded it.
    recorded_turns = (
        (tool_assistant, agent_initial_prior),
        (final_assistant, agent_final_prior),
    )
    for assistant, prior in recorded_turns:
        self.store.store_assistant_message(
            assistant,
            cache_scope(prior),
            DEFAULT_CACHE_NAMESPACE,
            prior,
        )

    plan_request = {
        "model": "deepseek-v4-pro",
        "messages": [
            {"role": "system", "content": "Plan mode."},
            plan_user,
            {"role": "assistant", "content": "", "tool_calls": [tool_call]},
            tool_result,
            {"role": "assistant", "content": "The project is ready."},
            {"role": "user", "content": "continue"},
        ],
    }
    prepared = prepare_upstream_request(
        plan_request,
        ProxyConfig(missing_reasoning_strategy="reject"),
        self.store,
    )

    self.assertEqual(prepared.missing_reasoning_messages, 0)
    self.assertEqual(prepared.patched_reasoning_messages, 2)
    self.assertEqual(
        prepared.payload["messages"][4]["reasoning_content"],
        "The tool result is enough to answer.",
    )
def test_portable_turn_cache_isolated_for_reused_tool_call_id(self) -> None:
    """A tool-call id reused in two threads restores each thread's reasoning.

    Two threads store different reasoning for the same tool call; a request
    for thread A under another system prompt must get thread A's reasoning.
    """
    tool_call = {
        "id": "call_reused",
        "type": "function",
        "function": {"name": "lookup", "arguments": "{}"},
    }
    threads = (
        ("thread A", "Reasoning for thread A."),
        ("thread B", "Reasoning for thread B."),
    )
    # Store thread A first, then thread B, each under its own history.
    for user_text, reasoning_text in threads:
        prior = [
            {"role": "system", "content": "Agent mode."},
            {"role": "user", "content": user_text},
        ]
        assistant = {
            "role": "assistant",
            "content": "",
            "reasoning_content": reasoning_text,
            "tool_calls": [tool_call],
        }
        self.store.store_assistant_message(
            assistant,
            cache_scope(prior),
            DEFAULT_CACHE_NAMESPACE,
            prior,
        )

    request = {
        "model": "deepseek-v4-pro",
        "messages": [
            {"role": "system", "content": "Plan mode."},
            {"role": "user", "content": "thread A"},
            {"role": "assistant", "content": "", "tool_calls": [tool_call]},
        ],
    }
    prepared = prepare_upstream_request(request, ProxyConfig(), self.store)

    self.assertEqual(
        prepared.payload["messages"][2]["reasoning_content"],
        "Reasoning for thread A.",
    )
def test_restores_reasoning_when_cursor_drops_tool_call_id_but_keeps_function_call(
self,
) -> None:
@ -738,6 +925,10 @@ class TransformTests(unittest.TestCase):
)
self.assertEqual(prepared.missing_reasoning_messages, 0)
self.assertEqual(prepared.recovered_reasoning_messages, 0)
self.assertEqual(prepared.recovery_dropped_messages, 0)
self.assertTrue(prepared.continued_recovery_boundary)
self.assertGreater(prepared.retired_prefix_messages, 0)
self.assertIsNone(prepared.recovery_notice)
self.assertEqual(
[message["role"] for message in prepared.payload["messages"]],
@ -756,6 +947,100 @@ class TransformTests(unittest.TestCase):
},
)
def test_recovered_response_is_recorded_under_pre_recovery_scope(self) -> None:
    # Flow exercised here:
    #   1. A request whose assistant tool-call message has no
    #      reasoning_content is prepared with the "recover" strategy.
    #   2. A simulated upstream reply (with reasoning) is passed through
    #      rewrite_response_body using the recording contexts produced
    #      by step 1.
    #   3. The client replays the full history plus that reply, minus its
    #      reasoning_content; the second preparation must restore it
    #      without triggering another recovery.
    stale_call = {
        "id": "call_old",
        "type": "function",
        "function": {
            "name": "read_file",
            "arguments": '{"path":"README.md"}',
        },
    }
    fresh_call = {
        "id": "call_new",
        "type": "function",
        "function": {"name": "lookup", "arguments": '{"query":"new"}'},
    }
    opening_request = {
        "model": "deepseek-v4-pro",
        "messages": [
            {"role": "user", "content": "old model turn"},
            {"role": "assistant", "content": "", "tool_calls": [stale_call]},
            {"role": "tool", "tool_call_id": "call_old", "content": "old result"},
            {"role": "user", "content": "continue with DeepSeek"},
        ],
    }

    # Step 1: the first pass is expected to report exactly one recovery.
    first_pass = prepare_upstream_request(
        opening_request,
        ProxyConfig(missing_reasoning_strategy="recover"),
        self.store,
    )
    self.assertEqual(first_pass.recovered_reasoning_messages, 1)

    # Step 2: feed a canned non-streaming completion through the rewriter.
    upstream_reply = {
        "id": "chatcmpl-test",
        "object": "chat.completion",
        "model": "deepseek-v4-pro",
        "choices": [
            {
                "index": 0,
                "finish_reason": "tool_calls",
                "message": {
                    "role": "assistant",
                    "content": "",
                    "reasoning_content": "Need the new lookup.",
                    "tool_calls": [fresh_call],
                },
            }
        ],
    }
    rewritten_body = rewrite_response_body(
        json.dumps(upstream_reply).encode(),
        "deepseek-v4-pro",
        self.store,
        first_pass.payload["messages"],
        first_pass.cache_namespace,
        content_prefix=first_pass.recovery_notice,
        recording_contexts=first_pass.record_response_contexts,
    )
    assistant_reply = json.loads(rewritten_body)["choices"][0]["message"]

    # The reasoning must be retrievable under every recording scope.
    contexts = first_pass.record_response_contexts
    self.assertEqual(len(contexts), 2)
    reply_signature = message_signature(assistant_reply)
    for scope, _unused_messages in contexts:
        cached = self.store.get(f"scope:{scope}:signature:{reply_signature}")
        self.assertEqual(cached, "Need the new lookup.")

    # Step 3: replay the conversation with the reasoning stripped from
    # the reply.
    assistant_reply.pop("reasoning_content", None)
    follow_up_request = {
        "model": "deepseek-v4-pro",
        "messages": [
            *opening_request["messages"],
            assistant_reply,
            {"role": "tool", "tool_call_id": "call_new", "content": "new result"},
        ],
    }
    second_pass = prepare_upstream_request(
        follow_up_request,
        ProxyConfig(missing_reasoning_strategy="recover"),
        self.store,
    )

    self.assertEqual(second_pass.missing_reasoning_messages, 0)
    self.assertEqual(second_pass.recovered_reasoning_messages, 0)
    self.assertEqual(second_pass.recovery_dropped_messages, 0)
    self.assertTrue(second_pass.continued_recovery_boundary)
    self.assertGreater(second_pass.retired_prefix_messages, 0)
    restored = second_pass.payload["messages"][2]["reasoning_content"]
    self.assertEqual(restored, "Need the new lookup.")
def test_recovery_boundary_accepts_legacy_notice_text(self) -> None:
legacy_recovery_notice = (
"Note: recovered this DeepSeek chat because older tool-call reasoning "
@ -844,6 +1129,10 @@ class TransformTests(unittest.TestCase):
)
self.assertEqual(prepared.missing_reasoning_messages, 0)
self.assertEqual(prepared.recovered_reasoning_messages, 0)
self.assertEqual(prepared.recovery_dropped_messages, 0)
self.assertTrue(prepared.continued_recovery_boundary)
self.assertGreater(prepared.retired_prefix_messages, 0)
self.assertIsNone(prepared.recovery_notice)
self.assertEqual(
prepared.payload["messages"][2]["reasoning_content"],
@ -985,6 +1274,64 @@ class TransformTests(unittest.TestCase):
self.assertEqual(prepared.missing_reasoning_messages, 1)
self.assertNotIn("reasoning_content", prepared.payload["messages"][1])
def test_deepseek_pro_and_flash_share_reasoning_namespace(self) -> None:
    # Part 1: with identical config, thinking settings, effort and
    # authorization, the pro and flash models must map to the same
    # reasoning cache namespace.
    # Part 2: reasoning stored under the pro namespace must be restored
    # for a flash request that replays the same tool call.
    config = ProxyConfig()

    def namespace_for(model):
        return reasoning_cache_namespace(
            config,
            model,
            {"type": "enabled"},
            "high",
            "Bearer key-a",
        )

    pro_namespace = namespace_for("deepseek-v4-pro")
    flash_namespace = namespace_for("deepseek-v4-flash")
    self.assertEqual(pro_namespace, flash_namespace)

    history = [{"role": "user", "content": "read README"}]
    shared_call = {
        "id": "call_shared",
        "type": "function",
        "function": {
            "name": "read_file",
            "arguments": '{"path":"README.md"}',
        },
    }
    stored_assistant = {
        "role": "assistant",
        "content": "",
        "reasoning_content": "Shared DeepSeek reasoning.",
        "tool_calls": [shared_call],
    }
    self.store.store_assistant_message(
        stored_assistant,
        conversation_scope(history, pro_namespace),
        pro_namespace,
        history,
    )

    # The flash request replays the assistant turn without its reasoning.
    flash_request = {
        "model": "deepseek-v4-flash",
        "messages": [
            *history,
            {"role": "assistant", "content": "", "tool_calls": [shared_call]},
        ],
    }
    result = prepare_upstream_request(
        flash_request,
        config,
        self.store,
        authorization="Bearer key-a",
    )

    self.assertEqual(result.missing_reasoning_messages, 0)
    restored = result.payload["messages"][1]["reasoning_content"]
    self.assertEqual(restored, "Shared DeepSeek reasoning.")
def test_converted_function_message_uses_tool_schema(self) -> None:
payload = {
"model": "deepseek-v4-pro",