fix(cache): fix reasoning recovery during mode/model switch (#28)
parent
cf3a9d3875
commit
5f14da32c6
|
|
@ -62,6 +62,13 @@ def message_signature(message: dict[str, Any]) -> str:
|
|||
return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def _sha256_json(payload: Any) -> str:
|
||||
canonical = json.dumps(
|
||||
payload, ensure_ascii=False, sort_keys=True, separators=(",", ":")
|
||||
)
|
||||
return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def canonical_scope_message(message: dict[str, Any]) -> dict[str, Any]:
|
||||
canonical: dict[str, Any] = {"role": message.get("role")}
|
||||
for key in ("content", "name", "tool_call_id", "prefix"):
|
||||
|
|
@ -81,10 +88,71 @@ def conversation_scope(messages: list[dict[str, Any]], namespace: str = "") -> s
|
|||
payload: Any = scope_messages
|
||||
if namespace:
|
||||
payload = {"namespace": namespace, "messages": scope_messages}
|
||||
canonical = json.dumps(
|
||||
payload, ensure_ascii=False, sort_keys=True, separators=(",", ":")
|
||||
return _sha256_json(payload)
|
||||
|
||||
|
||||
def turn_context_signature(prior_messages: list[dict[str, Any]]) -> str:
|
||||
last_user_index = next(
|
||||
(
|
||||
index
|
||||
for index in range(len(prior_messages) - 1, -1, -1)
|
||||
if prior_messages[index].get("role") == "user"
|
||||
),
|
||||
-1,
|
||||
)
|
||||
return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
|
||||
start_index = 0
|
||||
if last_user_index != -1:
|
||||
start_index = last_user_index
|
||||
while start_index > 0 and prior_messages[start_index - 1].get("role") == "user":
|
||||
start_index -= 1
|
||||
|
||||
context_messages = [
|
||||
canonical_scope_message(message)
|
||||
for message in prior_messages[start_index:]
|
||||
if message.get("role") != "system"
|
||||
]
|
||||
return _sha256_json(context_messages)
|
||||
|
||||
|
||||
def scoped_reasoning_keys(message: dict[str, Any], scope: str) -> list[str]:
|
||||
keys = [f"scope:{scope}:signature:{message_signature(message)}"]
|
||||
keys.extend(
|
||||
f"scope:{scope}:tool_call:{tool_call_id}"
|
||||
for tool_call_id in tool_call_ids(message)
|
||||
)
|
||||
keys.extend(
|
||||
f"scope:{scope}:tool_call_signature:{tool_call_signature(tool_call)}"
|
||||
for tool_call in (message.get("tool_calls") or [])
|
||||
if isinstance(tool_call, dict)
|
||||
)
|
||||
return keys
|
||||
|
||||
|
||||
def portable_reasoning_keys(
|
||||
message: dict[str, Any],
|
||||
cache_namespace: str,
|
||||
prior_messages: list[dict[str, Any]],
|
||||
) -> list[str]:
|
||||
if not cache_namespace:
|
||||
return []
|
||||
|
||||
turn_signature = turn_context_signature(prior_messages)
|
||||
keys = [
|
||||
f"namespace:{cache_namespace}:turn:{turn_signature}:"
|
||||
f"signature:{message_signature(message)}"
|
||||
]
|
||||
keys.extend(
|
||||
f"namespace:{cache_namespace}:turn:{turn_signature}:"
|
||||
f"tool_call:{tool_call_id}"
|
||||
for tool_call_id in tool_call_ids(message)
|
||||
)
|
||||
keys.extend(
|
||||
f"namespace:{cache_namespace}:turn:{turn_signature}:"
|
||||
f"tool_call_signature:{tool_call_signature(tool_call)}"
|
||||
for tool_call in (message.get("tool_calls") or [])
|
||||
if isinstance(tool_call, dict)
|
||||
)
|
||||
return keys
|
||||
|
||||
|
||||
class ReasoningStore:
|
||||
|
|
@ -155,45 +223,65 @@ class ReasoningStore:
|
|||
return None
|
||||
return str(row[0])
|
||||
|
||||
def store_assistant_message(self, message: dict[str, Any], scope: str) -> int:
|
||||
def store_assistant_message(
|
||||
self,
|
||||
message: dict[str, Any],
|
||||
scope: str,
|
||||
cache_namespace: str = "",
|
||||
prior_messages: list[dict[str, Any]] | None = None,
|
||||
) -> int:
|
||||
if message.get("role") != "assistant":
|
||||
return 0
|
||||
reasoning = message.get("reasoning_content")
|
||||
if not isinstance(reasoning, str):
|
||||
return 0
|
||||
|
||||
keys = [f"scope:{scope}:signature:{message_signature(message)}"]
|
||||
keys.extend(
|
||||
f"scope:{scope}:tool_call:{tool_call_id}"
|
||||
for tool_call_id in tool_call_ids(message)
|
||||
)
|
||||
keys.extend(
|
||||
f"scope:{scope}:tool_call_signature:{tool_call_signature(tool_call)}"
|
||||
for tool_call in (message.get("tool_calls") or [])
|
||||
if isinstance(tool_call, dict)
|
||||
)
|
||||
keys = scoped_reasoning_keys(message, scope)
|
||||
if prior_messages is not None:
|
||||
keys.extend(
|
||||
portable_reasoning_keys(message, cache_namespace, prior_messages)
|
||||
)
|
||||
keys = list(dict.fromkeys(keys))
|
||||
for key in keys:
|
||||
self.put(key, reasoning, message)
|
||||
return len(keys)
|
||||
|
||||
def lookup_for_message(self, message: dict[str, Any], scope: str) -> str | None:
|
||||
reasoning = self.get(f"scope:{scope}:signature:{message_signature(message)}")
|
||||
if reasoning is not None:
|
||||
return reasoning
|
||||
for tool_call_id in tool_call_ids(message):
|
||||
reasoning = self.get(f"scope:{scope}:tool_call:{tool_call_id}")
|
||||
if reasoning is not None:
|
||||
return reasoning
|
||||
for tool_call in message.get("tool_calls") or []:
|
||||
if not isinstance(tool_call, dict):
|
||||
continue
|
||||
reasoning = self.get(
|
||||
f"scope:{scope}:tool_call_signature:{tool_call_signature(tool_call)}"
|
||||
def lookup_for_message(
|
||||
self,
|
||||
message: dict[str, Any],
|
||||
scope: str,
|
||||
cache_namespace: str = "",
|
||||
prior_messages: list[dict[str, Any]] | None = None,
|
||||
) -> str | None:
|
||||
keys = scoped_reasoning_keys(message, scope)
|
||||
if prior_messages is not None:
|
||||
keys.extend(
|
||||
portable_reasoning_keys(message, cache_namespace, prior_messages)
|
||||
)
|
||||
for key in keys:
|
||||
reasoning = self.get(key)
|
||||
if reasoning is not None:
|
||||
return reasoning
|
||||
return None
|
||||
|
||||
def backfill_portable_aliases(
|
||||
self,
|
||||
message: dict[str, Any],
|
||||
reasoning: str,
|
||||
cache_namespace: str,
|
||||
prior_messages: list[dict[str, Any]],
|
||||
) -> int:
|
||||
if not isinstance(reasoning, str):
|
||||
return 0
|
||||
keys = portable_reasoning_keys(message, cache_namespace, prior_messages)
|
||||
if not keys:
|
||||
return 0
|
||||
message_with_reasoning = dict(message)
|
||||
message_with_reasoning["reasoning_content"] = reasoning
|
||||
for key in dict.fromkeys(keys):
|
||||
self.put(key, reasoning, message_with_reasoning)
|
||||
return len(keys)
|
||||
|
||||
def clear(self) -> int:
|
||||
with self._lock:
|
||||
row = self._conn.execute("SELECT COUNT(*) FROM reasoning_cache").fetchone()
|
||||
|
|
|
|||
|
|
@ -160,16 +160,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
)
|
||||
if prepared.recovered_reasoning_messages:
|
||||
if prepared.recovery_notice:
|
||||
LOG.warning(
|
||||
(
|
||||
"recovered request because cached reasoning_content was "
|
||||
"unavailable for %s assistant message(s); omitted %s "
|
||||
"older message(s) from forwarded history and will show "
|
||||
"a Cursor notice"
|
||||
),
|
||||
prepared.recovered_reasoning_messages,
|
||||
prepared.recovery_dropped_messages,
|
||||
)
|
||||
LOG.warning("refreshed reasoning_content history")
|
||||
else:
|
||||
LOG.info(
|
||||
(
|
||||
|
|
@ -305,7 +296,10 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
prepared.payload["messages"],
|
||||
prepared.cache_namespace,
|
||||
prepared.recovery_notice,
|
||||
trace,
|
||||
trace=trace,
|
||||
record_response_scope=prepared.record_response_scope,
|
||||
record_response_messages=prepared.record_response_messages,
|
||||
record_response_contexts=prepared.record_response_contexts,
|
||||
)
|
||||
else:
|
||||
sent_response = self._proxy_regular_response(
|
||||
|
|
@ -314,7 +308,10 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
prepared.payload["messages"],
|
||||
prepared.cache_namespace,
|
||||
prepared.recovery_notice,
|
||||
trace,
|
||||
trace=trace,
|
||||
record_response_scope=prepared.record_response_scope,
|
||||
record_response_messages=prepared.record_response_messages,
|
||||
record_response_contexts=prepared.record_response_contexts,
|
||||
)
|
||||
if not sent_response:
|
||||
self._finish_trace(
|
||||
|
|
@ -549,6 +546,9 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
cache_namespace: str,
|
||||
recovery_notice: str | None = None,
|
||||
trace: TraceRequest | None = None,
|
||||
record_response_scope: str | None = None,
|
||||
record_response_messages: list[dict[str, Any]] | None = None,
|
||||
record_response_contexts: list[tuple[str, list[dict[str, Any]]]] | None = None,
|
||||
) -> bool:
|
||||
body = read_response_body(response)
|
||||
upstream_body = body
|
||||
|
|
@ -560,6 +560,9 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
request_messages,
|
||||
cache_namespace,
|
||||
content_prefix=recovery_notice,
|
||||
scope=record_response_scope,
|
||||
prior_messages=record_response_messages,
|
||||
recording_contexts=record_response_contexts,
|
||||
)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as exc:
|
||||
LOG.warning("failed to rewrite upstream JSON response: %s", exc)
|
||||
|
|
@ -611,6 +614,9 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
cache_namespace: str,
|
||||
recovery_notice: str | None = None,
|
||||
trace: TraceRequest | None = None,
|
||||
record_response_scope: str | None = None,
|
||||
record_response_messages: list[dict[str, Any]] | None = None,
|
||||
record_response_contexts: list[tuple[str, list[dict[str, Any]]]] | None = None,
|
||||
) -> bool:
|
||||
if trace is not None:
|
||||
trace.record_upstream_response(
|
||||
|
|
@ -645,7 +651,21 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
if self.config.cursor_display_reasoning
|
||||
else None
|
||||
)
|
||||
scope = conversation_scope(request_messages, cache_namespace)
|
||||
scope = (
|
||||
record_response_scope
|
||||
if record_response_scope is not None
|
||||
else conversation_scope(request_messages, cache_namespace)
|
||||
)
|
||||
response_prior_messages = (
|
||||
record_response_messages
|
||||
if record_response_messages is not None
|
||||
else request_messages
|
||||
)
|
||||
response_contexts = (
|
||||
record_response_contexts
|
||||
if record_response_contexts is not None
|
||||
else [(scope, response_prior_messages)]
|
||||
)
|
||||
finalized = False
|
||||
pending_recovery_notice = recovery_notice
|
||||
while True:
|
||||
|
|
@ -660,7 +680,8 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
line,
|
||||
original_model,
|
||||
accumulator,
|
||||
scope,
|
||||
cache_namespace,
|
||||
response_contexts,
|
||||
display_adapter,
|
||||
pending_recovery_notice,
|
||||
trace,
|
||||
|
|
@ -677,7 +698,15 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
if not finalized:
|
||||
if self.config.verbose:
|
||||
log_json("model streaming assistant messages", accumulator.messages())
|
||||
stored = accumulator.store_reasoning(self.reasoning_store, scope)
|
||||
stored = sum(
|
||||
accumulator.store_reasoning(
|
||||
self.reasoning_store,
|
||||
scope,
|
||||
cache_namespace,
|
||||
prior_messages,
|
||||
)
|
||||
for scope, prior_messages in response_contexts
|
||||
)
|
||||
if stored:
|
||||
LOG.info("stored %s streaming reasoning cache key(s)", stored)
|
||||
return True
|
||||
|
|
@ -687,7 +716,8 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
line: bytes,
|
||||
original_model: str,
|
||||
accumulator: StreamAccumulator,
|
||||
scope: str,
|
||||
cache_namespace: str,
|
||||
response_contexts: list[tuple[str, list[dict[str, Any]]]],
|
||||
display_adapter: CursorReasoningDisplayAdapter | None,
|
||||
recovery_notice: str | None = None,
|
||||
trace: TraceRequest | None = None,
|
||||
|
|
@ -700,7 +730,15 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
if data == b"[DONE]":
|
||||
if self.config.verbose:
|
||||
log_json("model streaming assistant messages", accumulator.messages())
|
||||
stored = accumulator.store_reasoning(self.reasoning_store, scope)
|
||||
stored = sum(
|
||||
accumulator.store_reasoning(
|
||||
self.reasoning_store,
|
||||
scope,
|
||||
cache_namespace,
|
||||
prior_messages,
|
||||
)
|
||||
for scope, prior_messages in response_contexts
|
||||
)
|
||||
if stored:
|
||||
LOG.info("stored %s streaming reasoning cache key(s)", stored)
|
||||
prefix = b""
|
||||
|
|
@ -728,7 +766,15 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
if recovery_notice and inject_recovery_notice(chunk, recovery_notice):
|
||||
recovery_notice = None
|
||||
accumulator.ingest_chunk(chunk)
|
||||
stored = accumulator.store_ready_reasoning(self.reasoning_store, scope)
|
||||
stored = sum(
|
||||
accumulator.store_ready_reasoning(
|
||||
self.reasoning_store,
|
||||
scope,
|
||||
cache_namespace,
|
||||
prior_messages,
|
||||
)
|
||||
for scope, prior_messages in response_contexts
|
||||
)
|
||||
if stored:
|
||||
LOG.info("stored %s streaming reasoning cache key(s)", stored)
|
||||
if trace is not None:
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ class StreamingChoice:
|
|||
class StreamAccumulator:
|
||||
def __init__(self) -> None:
|
||||
self.choices: dict[int, StreamingChoice] = {}
|
||||
self._stored_choices: dict[int, str] = {}
|
||||
self._stored_choices: dict[tuple[int, str], str] = {}
|
||||
|
||||
def ingest_chunk(self, chunk: dict[str, Any]) -> None:
|
||||
choices = chunk.get("choices")
|
||||
|
|
@ -70,26 +70,70 @@ class StreamAccumulator:
|
|||
|
||||
self._merge_tool_call_deltas(choice, delta.get("tool_calls"))
|
||||
|
||||
def store_reasoning(self, store: ReasoningStore, scope: str) -> int:
|
||||
def store_reasoning(
|
||||
self,
|
||||
store: ReasoningStore,
|
||||
scope: str,
|
||||
cache_namespace: str = "",
|
||||
prior_messages: list[dict[str, Any]] | None = None,
|
||||
) -> int:
|
||||
stored = 0
|
||||
for index, choice in self.choices.items():
|
||||
stored += self._store_choice(index, choice, store, scope)
|
||||
stored += self._store_choice(
|
||||
index, choice, store, scope, "final", cache_namespace, prior_messages
|
||||
)
|
||||
return stored
|
||||
|
||||
def store_finished_reasoning(self, store: ReasoningStore, scope: str) -> int:
|
||||
def store_finished_reasoning(
|
||||
self,
|
||||
store: ReasoningStore,
|
||||
scope: str,
|
||||
cache_namespace: str = "",
|
||||
prior_messages: list[dict[str, Any]] | None = None,
|
||||
) -> int:
|
||||
stored = 0
|
||||
for index, choice in self.choices.items():
|
||||
if choice.finish_reason is not None:
|
||||
stored += self._store_choice(index, choice, store, scope, "final")
|
||||
stored += self._store_choice(
|
||||
index,
|
||||
choice,
|
||||
store,
|
||||
scope,
|
||||
"final",
|
||||
cache_namespace,
|
||||
prior_messages,
|
||||
)
|
||||
return stored
|
||||
|
||||
def store_ready_reasoning(self, store: ReasoningStore, scope: str) -> int:
|
||||
def store_ready_reasoning(
|
||||
self,
|
||||
store: ReasoningStore,
|
||||
scope: str,
|
||||
cache_namespace: str = "",
|
||||
prior_messages: list[dict[str, Any]] | None = None,
|
||||
) -> int:
|
||||
stored = 0
|
||||
for index, choice in self.choices.items():
|
||||
if choice.finish_reason is not None:
|
||||
stored += self._store_choice(index, choice, store, scope, "final")
|
||||
stored += self._store_choice(
|
||||
index,
|
||||
choice,
|
||||
store,
|
||||
scope,
|
||||
"final",
|
||||
cache_namespace,
|
||||
prior_messages,
|
||||
)
|
||||
elif self._has_identified_tool_calls(choice):
|
||||
stored += self._store_choice(index, choice, store, scope, "tool_call")
|
||||
stored += self._store_choice(
|
||||
index,
|
||||
choice,
|
||||
store,
|
||||
scope,
|
||||
"tool_call",
|
||||
cache_namespace,
|
||||
prior_messages,
|
||||
)
|
||||
return stored
|
||||
|
||||
def messages(self) -> list[dict[str, Any]]:
|
||||
|
|
@ -141,14 +185,22 @@ class StreamAccumulator:
|
|||
store: ReasoningStore,
|
||||
scope: str,
|
||||
stage: str = "final",
|
||||
cache_namespace: str = "",
|
||||
prior_messages: list[dict[str, Any]] | None = None,
|
||||
) -> int:
|
||||
stage_rank = {"tool_call": 1, "final": 2}
|
||||
previous_stage = self._stored_choices.get(index)
|
||||
storage_key = (index, scope)
|
||||
previous_stage = self._stored_choices.get(storage_key)
|
||||
if stage_rank.get(previous_stage or "", 0) >= stage_rank.get(stage, 0):
|
||||
return 0
|
||||
stored = store.store_assistant_message(choice.to_message(), scope)
|
||||
stored = store.store_assistant_message(
|
||||
choice.to_message(),
|
||||
scope,
|
||||
cache_namespace,
|
||||
prior_messages,
|
||||
)
|
||||
if stored:
|
||||
self._stored_choices[index] = stage
|
||||
self._stored_choices[storage_key] = stage
|
||||
return stored
|
||||
|
||||
def _has_identified_tool_calls(self, choice: StreamingChoice) -> bool:
|
||||
|
|
|
|||
|
|
@ -105,7 +105,10 @@ def message_summaries(payload: dict[str, Any]) -> list[dict[str, Any]]:
|
|||
len(reasoning) if isinstance(reasoning, str) else 0
|
||||
),
|
||||
"has_recovery_notice": content.startswith(
|
||||
"[deepseek-cursor-proxy] Recovered"
|
||||
(
|
||||
"[deepseek-cursor-proxy] Refreshed reasoning_content history.",
|
||||
"[deepseek-cursor-proxy] Recovered",
|
||||
)
|
||||
),
|
||||
}
|
||||
summaries.append(summary)
|
||||
|
|
@ -231,6 +234,12 @@ class TraceRequest:
|
|||
"recovered_reasoning_messages": prepared.recovered_reasoning_messages,
|
||||
"recovery_dropped_messages": prepared.recovery_dropped_messages,
|
||||
"recovery_notice": prepared.recovery_notice,
|
||||
"record_response_scope": prepared.record_response_scope,
|
||||
"record_response_scopes": [
|
||||
scope for scope, _messages in prepared.record_response_contexts
|
||||
],
|
||||
"continued_recovery_boundary": prepared.continued_recovery_boundary,
|
||||
"retired_prefix_messages": prepared.retired_prefix_messages,
|
||||
"reasoning_diagnostics": prepared.reasoning_diagnostics,
|
||||
"recovery_steps": prepared.recovery_steps,
|
||||
"upstream_request_summary": payload_summary(prepared.payload),
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from .reasoning_store import (
|
|||
message_signature,
|
||||
tool_call_ids,
|
||||
tool_call_signature,
|
||||
turn_context_signature,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -73,10 +74,7 @@ CURSOR_THINKING_BLOCK_RE = re.compile(
|
|||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
RECOVERY_NOTICE_TEXT = (
|
||||
"[deepseek-cursor-proxy] Recovered this DeepSeek chat because older "
|
||||
"tool-call reasoning was unavailable; continuing with recent context only."
|
||||
)
|
||||
RECOVERY_NOTICE_TEXT = "[deepseek-cursor-proxy] Refreshed reasoning_content history."
|
||||
LEGACY_RECOVERY_NOTICE_TEXT = (
|
||||
"Note: recovered this DeepSeek chat because older tool-call reasoning "
|
||||
"was unavailable; continuing with recent context only."
|
||||
|
|
@ -101,8 +99,15 @@ class PreparedRequest:
|
|||
recovered_reasoning_messages: int = 0
|
||||
recovery_dropped_messages: int = 0
|
||||
recovery_notice: str | None = None
|
||||
record_response_scope: str | None = None
|
||||
record_response_messages: list[dict[str, Any]] = field(default_factory=list)
|
||||
record_response_contexts: list[tuple[str, list[dict[str, Any]]]] = field(
|
||||
default_factory=list
|
||||
)
|
||||
reasoning_diagnostics: list[dict[str, Any]] = field(default_factory=list)
|
||||
recovery_steps: list[dict[str, Any]] = field(default_factory=list)
|
||||
continued_recovery_boundary: bool = False
|
||||
retired_prefix_messages: int = 0
|
||||
|
||||
|
||||
def normalize_reasoning_effort(value: Any) -> str:
|
||||
|
|
@ -260,7 +265,12 @@ def normalize_message(
|
|||
)
|
||||
lookup_scope = conversation_scope(prior_messages, cache_namespace)
|
||||
lookup_keys = (
|
||||
reasoning_lookup_keys(normalized, lookup_scope)
|
||||
reasoning_lookup_keys(
|
||||
normalized,
|
||||
lookup_scope,
|
||||
cache_namespace,
|
||||
prior_messages,
|
||||
)
|
||||
if needs_reasoning
|
||||
else []
|
||||
)
|
||||
|
|
@ -273,6 +283,13 @@ def normalize_message(
|
|||
hit_kind = lookup_key["kind"]
|
||||
normalized["reasoning_content"] = restored
|
||||
patched = True
|
||||
if not lookup_key.get("portable"):
|
||||
store.backfill_portable_aliases(
|
||||
normalized,
|
||||
restored,
|
||||
cache_namespace,
|
||||
prior_messages,
|
||||
)
|
||||
break
|
||||
if needs_reasoning and not patched:
|
||||
missing = True
|
||||
|
|
@ -315,11 +332,14 @@ def normalize_message(
|
|||
def reasoning_lookup_keys(
|
||||
message: dict[str, Any],
|
||||
scope: str,
|
||||
cache_namespace: str = "",
|
||||
prior_messages: list[dict[str, Any]] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
keys = [
|
||||
{
|
||||
"kind": "message_signature",
|
||||
"key": f"scope:{scope}:signature:{message_signature(message)}",
|
||||
"portable": False,
|
||||
"hit": False,
|
||||
}
|
||||
]
|
||||
|
|
@ -328,6 +348,7 @@ def reasoning_lookup_keys(
|
|||
"kind": "tool_call_id",
|
||||
"tool_call_id": tool_call_id,
|
||||
"key": f"scope:{scope}:tool_call:{tool_call_id}",
|
||||
"portable": False,
|
||||
"hit": False,
|
||||
}
|
||||
for tool_call_id in tool_call_ids(message)
|
||||
|
|
@ -340,11 +361,57 @@ def reasoning_lookup_keys(
|
|||
f"scope:{scope}:tool_call_signature:"
|
||||
f"{tool_call_signature(tool_call)}"
|
||||
),
|
||||
"portable": False,
|
||||
"hit": False,
|
||||
}
|
||||
for tool_call in (message.get("tool_calls") or [])
|
||||
if isinstance(tool_call, dict)
|
||||
)
|
||||
if cache_namespace and prior_messages is not None:
|
||||
turn_signature = turn_context_signature(prior_messages)
|
||||
keys.append(
|
||||
{
|
||||
"kind": "portable_message_signature",
|
||||
"key": (
|
||||
f"namespace:{cache_namespace}:turn:{turn_signature}:"
|
||||
f"signature:{message_signature(message)}"
|
||||
),
|
||||
"turn_context_signature": turn_signature,
|
||||
"portable": True,
|
||||
"hit": False,
|
||||
}
|
||||
)
|
||||
keys.extend(
|
||||
{
|
||||
"kind": "portable_tool_call_id",
|
||||
"tool_call_id": tool_call_id,
|
||||
"key": (
|
||||
f"namespace:{cache_namespace}:turn:{turn_signature}:"
|
||||
f"tool_call:{tool_call_id}"
|
||||
),
|
||||
"turn_context_signature": turn_signature,
|
||||
"portable": True,
|
||||
"hit": False,
|
||||
}
|
||||
for tool_call_id in tool_call_ids(message)
|
||||
)
|
||||
keys.extend(
|
||||
{
|
||||
"kind": "portable_tool_call_signature",
|
||||
"function_name": str(
|
||||
(tool_call.get("function") or {}).get("name") or ""
|
||||
),
|
||||
"key": (
|
||||
f"namespace:{cache_namespace}:turn:{turn_signature}:"
|
||||
f"tool_call_signature:{tool_call_signature(tool_call)}"
|
||||
),
|
||||
"turn_context_signature": turn_signature,
|
||||
"portable": True,
|
||||
"hit": False,
|
||||
}
|
||||
for tool_call in (message.get("tool_calls") or [])
|
||||
if isinstance(tool_call, dict)
|
||||
)
|
||||
return keys
|
||||
|
||||
|
||||
|
|
@ -399,6 +466,52 @@ def leading_system_messages(messages: list[dict[str, Any]]) -> list[dict[str, An
|
|||
return leading_messages
|
||||
|
||||
|
||||
def active_messages_from_recovery_boundary(
|
||||
messages: list[dict[str, Any]],
|
||||
) -> tuple[list[dict[str, Any]], int, dict[str, Any]] | None:
|
||||
recovery_boundary_index = next(
|
||||
(
|
||||
index
|
||||
for index in range(len(messages) - 1, -1, -1)
|
||||
if has_recovery_notice(messages[index])
|
||||
),
|
||||
-1,
|
||||
)
|
||||
if recovery_boundary_index == -1:
|
||||
return None
|
||||
|
||||
context_user_index = next(
|
||||
(
|
||||
index
|
||||
for index in range(recovery_boundary_index - 1, -1, -1)
|
||||
if messages[index].get("role") == "user"
|
||||
),
|
||||
-1,
|
||||
)
|
||||
leading_messages = leading_system_messages(messages)
|
||||
recovered_tail = []
|
||||
if context_user_index != -1:
|
||||
recovered_tail.append(messages[context_user_index])
|
||||
recovered_tail.extend(messages[recovery_boundary_index:])
|
||||
active_messages = [
|
||||
*leading_messages,
|
||||
{"role": "system", "content": RECOVERY_SYSTEM_CONTENT},
|
||||
*recovered_tail,
|
||||
]
|
||||
kept_context_messages = 1 if context_user_index != -1 else 0
|
||||
retired_messages = (
|
||||
recovery_boundary_index - len(leading_messages) - kept_context_messages
|
||||
)
|
||||
retired_messages = max(retired_messages, 0)
|
||||
step = {
|
||||
"strategy": "continued_recovery_boundary",
|
||||
"recovery_boundary_index": recovery_boundary_index,
|
||||
"context_user_index": context_user_index,
|
||||
"retired_prefix_messages": retired_messages,
|
||||
}
|
||||
return active_messages, retired_messages, step
|
||||
|
||||
|
||||
def recover_messages_from_missing_reasoning(
|
||||
messages: list[dict[str, Any]],
|
||||
missing_indexes: list[int],
|
||||
|
|
@ -510,6 +623,12 @@ def upstream_model_for(original_model: str, config: ProxyConfig) -> str:
|
|||
return config.upstream_model
|
||||
|
||||
|
||||
def reasoning_model_family(upstream_model: str) -> str:
|
||||
if upstream_model in {"deepseek-v4-pro", "deepseek-v4-flash"}:
|
||||
return "deepseek-v4"
|
||||
return upstream_model
|
||||
|
||||
|
||||
def reasoning_cache_namespace(
|
||||
config: ProxyConfig,
|
||||
upstream_model: str,
|
||||
|
|
@ -522,7 +641,7 @@ def reasoning_cache_namespace(
|
|||
auth_hash = hashlib.sha256(authorization.encode("utf-8")).hexdigest()
|
||||
payload = {
|
||||
"base_url": config.upstream_base_url,
|
||||
"model": upstream_model,
|
||||
"model": reasoning_model_family(upstream_model),
|
||||
"thinking": thinking,
|
||||
"reasoning_effort": reasoning_effort,
|
||||
"authorization_hash": auth_hash,
|
||||
|
|
@ -533,6 +652,22 @@ def reasoning_cache_namespace(
|
|||
return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def response_recording_contexts(
|
||||
*items: tuple[str, list[dict[str, Any]]] | None,
|
||||
) -> list[tuple[str, list[dict[str, Any]]]]:
|
||||
contexts: list[tuple[str, list[dict[str, Any]]]] = []
|
||||
seen: set[str] = set()
|
||||
for item in items:
|
||||
if item is None:
|
||||
continue
|
||||
scope, messages = item
|
||||
if scope in seen:
|
||||
continue
|
||||
seen.add(scope)
|
||||
contexts.append((scope, messages))
|
||||
return contexts
|
||||
|
||||
|
||||
def prepare_upstream_request(
|
||||
payload: dict[str, Any],
|
||||
config: ProxyConfig,
|
||||
|
|
@ -596,19 +731,40 @@ def prepare_upstream_request(
|
|||
prepared.get("reasoning_effort"),
|
||||
authorization,
|
||||
)
|
||||
pre_repair_messages, _, _, _ = normalize_messages(
|
||||
payload.get("messages"),
|
||||
None,
|
||||
cache_namespace,
|
||||
repair_reasoning=False,
|
||||
keep_reasoning=not thinking_disabled,
|
||||
)
|
||||
record_response_messages = pre_repair_messages
|
||||
record_response_scope = conversation_scope(
|
||||
record_response_messages, cache_namespace
|
||||
)
|
||||
messages_for_repair = pre_repair_messages
|
||||
continued_recovery_boundary = False
|
||||
retired_prefix_messages = 0
|
||||
recovered_count = 0
|
||||
recovery_dropped_messages = 0
|
||||
recovery_notice = None
|
||||
recovery_steps: list[dict[str, Any]] = []
|
||||
if thinking_enabled and config.missing_reasoning_strategy == "recover":
|
||||
boundary = active_messages_from_recovery_boundary(pre_repair_messages)
|
||||
if boundary is not None:
|
||||
messages_for_repair, retired_prefix_messages, boundary_step = boundary
|
||||
continued_recovery_boundary = True
|
||||
recovery_steps.append(boundary_step)
|
||||
|
||||
messages, patched_count, missing_indexes, reasoning_diagnostics = (
|
||||
normalize_messages(
|
||||
payload.get("messages"),
|
||||
messages_for_repair,
|
||||
store,
|
||||
cache_namespace,
|
||||
repair_reasoning=thinking_enabled,
|
||||
keep_reasoning=not thinking_disabled,
|
||||
)
|
||||
)
|
||||
recovered_count = 0
|
||||
recovery_dropped_messages = 0
|
||||
recovery_notice = None
|
||||
recovery_steps: list[dict[str, Any]] = []
|
||||
while missing_indexes and config.missing_reasoning_strategy == "recover":
|
||||
recovered_messages, dropped_messages, notice, recovery_step = (
|
||||
recover_messages_from_missing_reasoning(messages, missing_indexes)
|
||||
|
|
@ -634,6 +790,11 @@ def prepare_upstream_request(
|
|||
)
|
||||
reasoning_diagnostics.extend(latest_diagnostics)
|
||||
prepared["messages"] = messages
|
||||
active_record_response_scope = conversation_scope(messages, cache_namespace)
|
||||
record_response_contexts = response_recording_contexts(
|
||||
(record_response_scope, record_response_messages),
|
||||
(active_record_response_scope, messages),
|
||||
)
|
||||
|
||||
return PreparedRequest(
|
||||
payload=prepared,
|
||||
|
|
@ -645,8 +806,13 @@ def prepare_upstream_request(
|
|||
recovered_reasoning_messages=recovered_count,
|
||||
recovery_dropped_messages=recovery_dropped_messages,
|
||||
recovery_notice=recovery_notice,
|
||||
record_response_scope=record_response_scope,
|
||||
record_response_messages=record_response_messages,
|
||||
record_response_contexts=record_response_contexts,
|
||||
reasoning_diagnostics=reasoning_diagnostics,
|
||||
recovery_steps=recovery_steps,
|
||||
continued_recovery_boundary=continued_recovery_boundary,
|
||||
retired_prefix_messages=retired_prefix_messages,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -655,6 +821,9 @@ def record_response_reasoning(
|
|||
store: ReasoningStore | None,
|
||||
request_messages: list[dict[str, Any]],
|
||||
cache_namespace: str = "",
|
||||
scope: str | None = None,
|
||||
prior_messages: list[dict[str, Any]] | None = None,
|
||||
recording_contexts: list[tuple[str, list[dict[str, Any]]]] | None = None,
|
||||
) -> int:
|
||||
if store is None:
|
||||
return 0
|
||||
|
|
@ -662,13 +831,28 @@ def record_response_reasoning(
|
|||
choices = response_payload.get("choices")
|
||||
if not isinstance(choices, list):
|
||||
return stored
|
||||
scope = conversation_scope(request_messages, cache_namespace)
|
||||
if recording_contexts is None:
|
||||
response_scope = (
|
||||
scope
|
||||
if scope is not None
|
||||
else conversation_scope(request_messages, cache_namespace)
|
||||
)
|
||||
response_prior_messages = (
|
||||
prior_messages if prior_messages is not None else request_messages
|
||||
)
|
||||
recording_contexts = [(response_scope, response_prior_messages)]
|
||||
for choice in choices:
|
||||
if not isinstance(choice, dict):
|
||||
continue
|
||||
message = choice.get("message")
|
||||
if isinstance(message, dict):
|
||||
stored += store.store_assistant_message(message, scope)
|
||||
for response_scope, response_prior_messages in recording_contexts:
|
||||
stored += store.store_assistant_message(
|
||||
message,
|
||||
response_scope,
|
||||
cache_namespace,
|
||||
response_prior_messages,
|
||||
)
|
||||
return stored
|
||||
|
||||
|
||||
|
|
@ -679,13 +863,22 @@ def rewrite_response_body(
|
|||
request_messages: list[dict[str, Any]],
|
||||
cache_namespace: str = "",
|
||||
content_prefix: str | None = None,
|
||||
scope: str | None = None,
|
||||
prior_messages: list[dict[str, Any]] | None = None,
|
||||
recording_contexts: list[tuple[str, list[dict[str, Any]]]] | None = None,
|
||||
) -> bytes:
|
||||
response_payload = json.loads(body.decode("utf-8"))
|
||||
if isinstance(response_payload, dict):
|
||||
if content_prefix:
|
||||
prefix_response_content(response_payload, content_prefix)
|
||||
record_response_reasoning(
|
||||
response_payload, store, request_messages, cache_namespace
|
||||
response_payload,
|
||||
store,
|
||||
request_messages,
|
||||
cache_namespace,
|
||||
scope=scope,
|
||||
prior_messages=prior_messages,
|
||||
recording_contexts=recording_contexts,
|
||||
)
|
||||
if "model" in response_payload:
|
||||
response_payload["model"] = original_model
|
||||
|
|
|
|||
|
|
@ -738,7 +738,7 @@ class ProxyEndToEndTests(unittest.TestCase):
|
|||
{"role": "user", "content": "Thanks, now continue."},
|
||||
)
|
||||
self.assertIn(
|
||||
"cached reasoning_content was unavailable",
|
||||
"refreshed reasoning_content history",
|
||||
"\n".join(captured.output),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -116,6 +116,50 @@ class StreamAccumulatorTests(unittest.TestCase):
|
|||
self.assertEqual(accumulator.store_reasoning(store, scope), 0)
|
||||
store.close()
|
||||
|
||||
def test_stores_same_streaming_choice_under_multiple_scopes(self) -> None:
|
||||
store = ReasoningStore(":memory:")
|
||||
accumulator = StreamAccumulator()
|
||||
accumulator.ingest_chunk(
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"role": "assistant",
|
||||
"reasoning_content": "Need a tool.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": 0,
|
||||
"id": "call_stream",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "lookup",
|
||||
"arguments": "{}",
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
"finish_reason": "tool_calls",
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
first_scope = conversation_scope([{"role": "user", "content": "full"}])
|
||||
second_scope = conversation_scope([{"role": "user", "content": "active"}])
|
||||
first_stored = accumulator.store_finished_reasoning(store, first_scope)
|
||||
second_stored = accumulator.store_finished_reasoning(store, second_scope)
|
||||
|
||||
self.assertGreater(first_stored, 0)
|
||||
self.assertGreater(second_stored, 0)
|
||||
self.assertEqual(
|
||||
store.get(f"scope:{first_scope}:tool_call:call_stream"), "Need a tool."
|
||||
)
|
||||
self.assertEqual(
|
||||
store.get(f"scope:{second_scope}:tool_call:call_stream"), "Need a tool."
|
||||
)
|
||||
store.close()
|
||||
|
||||
def test_stores_tool_call_reasoning_before_finish_reason(self) -> None:
|
||||
store = ReasoningStore(":memory:")
|
||||
accumulator = StreamAccumulator()
|
||||
|
|
|
|||
|
|
@ -470,6 +470,193 @@ class TransformTests(unittest.TestCase):
|
|||
prepared.payload["messages"][1]["reasoning_content"], "first reasoning"
|
||||
)
|
||||
|
||||
def test_strict_hit_backfills_portable_cache_for_mode_switch(self) -> None:
|
||||
agent_prior = [
|
||||
{"role": "system", "content": "Agent mode."},
|
||||
{"role": "user", "content": "set up the task"},
|
||||
{"role": "user", "content": "read README"},
|
||||
]
|
||||
plan_prior = [
|
||||
{"role": "system", "content": "Plan mode."},
|
||||
{"role": "user", "content": "set up the task"},
|
||||
{"role": "user", "content": "read README"},
|
||||
]
|
||||
tool_call = {
|
||||
"id": "call_mode_switch",
|
||||
"type": "function",
|
||||
"function": {"name": "read_file", "arguments": '{"path":"README.md"}'},
|
||||
}
|
||||
assistant_message = {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": "Need README before answering.",
|
||||
"tool_calls": [tool_call],
|
||||
}
|
||||
self.store.store_assistant_message(
|
||||
assistant_message,
|
||||
cache_scope(agent_prior),
|
||||
)
|
||||
|
||||
strict_prepared = prepare_upstream_request(
|
||||
{
|
||||
"model": "deepseek-v4-pro",
|
||||
"messages": [
|
||||
*agent_prior,
|
||||
{"role": "assistant", "content": "", "tool_calls": [tool_call]},
|
||||
],
|
||||
},
|
||||
ProxyConfig(),
|
||||
self.store,
|
||||
)
|
||||
portable_prepared = prepare_upstream_request(
|
||||
{
|
||||
"model": "deepseek-v4-pro",
|
||||
"messages": [
|
||||
*plan_prior,
|
||||
{"role": "assistant", "content": "", "tool_calls": [tool_call]},
|
||||
],
|
||||
},
|
||||
ProxyConfig(),
|
||||
self.store,
|
||||
)
|
||||
|
||||
self.assertEqual(strict_prepared.patched_reasoning_messages, 1)
|
||||
self.assertEqual(portable_prepared.patched_reasoning_messages, 1)
|
||||
self.assertEqual(portable_prepared.missing_reasoning_messages, 0)
|
||||
self.assertEqual(
|
||||
portable_prepared.payload["messages"][3]["reasoning_content"],
|
||||
"Need README before answering.",
|
||||
)
|
||||
self.assertTrue(
|
||||
str(portable_prepared.reasoning_diagnostics[-1]["hit_kind"]).startswith(
|
||||
"portable_"
|
||||
)
|
||||
)
|
||||
|
||||
def test_portable_turn_cache_restores_final_assistant_after_tool_result(
|
||||
self,
|
||||
) -> None:
|
||||
agent_user = {"role": "user", "content": "look up project state"}
|
||||
plan_user = dict(agent_user)
|
||||
tool_call = {
|
||||
"id": "call_project_state",
|
||||
"type": "function",
|
||||
"function": {"name": "lookup", "arguments": '{"query":"state"}'},
|
||||
}
|
||||
tool_result = {
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_project_state",
|
||||
"content": '{"state":"ready"}',
|
||||
}
|
||||
tool_assistant = {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": "Need the project state.",
|
||||
"tool_calls": [tool_call],
|
||||
}
|
||||
final_assistant = {
|
||||
"role": "assistant",
|
||||
"content": "The project is ready.",
|
||||
"reasoning_content": "The tool result is enough to answer.",
|
||||
}
|
||||
agent_initial_prior = [
|
||||
{"role": "system", "content": "Agent mode."},
|
||||
agent_user,
|
||||
]
|
||||
agent_final_prior = [*agent_initial_prior, tool_assistant, tool_result]
|
||||
self.store.store_assistant_message(
|
||||
tool_assistant,
|
||||
cache_scope(agent_initial_prior),
|
||||
DEFAULT_CACHE_NAMESPACE,
|
||||
agent_initial_prior,
|
||||
)
|
||||
self.store.store_assistant_message(
|
||||
final_assistant,
|
||||
cache_scope(agent_final_prior),
|
||||
DEFAULT_CACHE_NAMESPACE,
|
||||
agent_final_prior,
|
||||
)
|
||||
|
||||
prepared = prepare_upstream_request(
|
||||
{
|
||||
"model": "deepseek-v4-pro",
|
||||
"messages": [
|
||||
{"role": "system", "content": "Plan mode."},
|
||||
plan_user,
|
||||
{"role": "assistant", "content": "", "tool_calls": [tool_call]},
|
||||
tool_result,
|
||||
{"role": "assistant", "content": "The project is ready."},
|
||||
{"role": "user", "content": "continue"},
|
||||
],
|
||||
},
|
||||
ProxyConfig(missing_reasoning_strategy="reject"),
|
||||
self.store,
|
||||
)
|
||||
|
||||
self.assertEqual(prepared.missing_reasoning_messages, 0)
|
||||
self.assertEqual(prepared.patched_reasoning_messages, 2)
|
||||
self.assertEqual(
|
||||
prepared.payload["messages"][4]["reasoning_content"],
|
||||
"The tool result is enough to answer.",
|
||||
)
|
||||
|
||||
def test_portable_turn_cache_isolated_for_reused_tool_call_id(self) -> None:
|
||||
tool_call = {
|
||||
"id": "call_reused",
|
||||
"type": "function",
|
||||
"function": {"name": "lookup", "arguments": "{}"},
|
||||
}
|
||||
assistant_a = {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": "Reasoning for thread A.",
|
||||
"tool_calls": [tool_call],
|
||||
}
|
||||
assistant_b = {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": "Reasoning for thread B.",
|
||||
"tool_calls": [tool_call],
|
||||
}
|
||||
prior_a = [
|
||||
{"role": "system", "content": "Agent mode."},
|
||||
{"role": "user", "content": "thread A"},
|
||||
]
|
||||
prior_b = [
|
||||
{"role": "system", "content": "Agent mode."},
|
||||
{"role": "user", "content": "thread B"},
|
||||
]
|
||||
self.store.store_assistant_message(
|
||||
assistant_a,
|
||||
cache_scope(prior_a),
|
||||
DEFAULT_CACHE_NAMESPACE,
|
||||
prior_a,
|
||||
)
|
||||
self.store.store_assistant_message(
|
||||
assistant_b,
|
||||
cache_scope(prior_b),
|
||||
DEFAULT_CACHE_NAMESPACE,
|
||||
prior_b,
|
||||
)
|
||||
|
||||
prepared = prepare_upstream_request(
|
||||
{
|
||||
"model": "deepseek-v4-pro",
|
||||
"messages": [
|
||||
{"role": "system", "content": "Plan mode."},
|
||||
{"role": "user", "content": "thread A"},
|
||||
{"role": "assistant", "content": "", "tool_calls": [tool_call]},
|
||||
],
|
||||
},
|
||||
ProxyConfig(),
|
||||
self.store,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
prepared.payload["messages"][2]["reasoning_content"],
|
||||
"Reasoning for thread A.",
|
||||
)
|
||||
|
||||
def test_restores_reasoning_when_cursor_drops_tool_call_id_but_keeps_function_call(
|
||||
self,
|
||||
) -> None:
|
||||
|
|
@ -738,6 +925,10 @@ class TransformTests(unittest.TestCase):
|
|||
)
|
||||
|
||||
self.assertEqual(prepared.missing_reasoning_messages, 0)
|
||||
self.assertEqual(prepared.recovered_reasoning_messages, 0)
|
||||
self.assertEqual(prepared.recovery_dropped_messages, 0)
|
||||
self.assertTrue(prepared.continued_recovery_boundary)
|
||||
self.assertGreater(prepared.retired_prefix_messages, 0)
|
||||
self.assertIsNone(prepared.recovery_notice)
|
||||
self.assertEqual(
|
||||
[message["role"] for message in prepared.payload["messages"]],
|
||||
|
|
@ -756,6 +947,100 @@ class TransformTests(unittest.TestCase):
|
|||
},
|
||||
)
|
||||
|
||||
def test_recovered_response_is_recorded_under_pre_recovery_scope(self) -> None:
|
||||
old_tool_call = {
|
||||
"id": "call_old",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "read_file",
|
||||
"arguments": '{"path":"README.md"}',
|
||||
},
|
||||
}
|
||||
new_tool_call = {
|
||||
"id": "call_new",
|
||||
"type": "function",
|
||||
"function": {"name": "lookup", "arguments": '{"query":"new"}'},
|
||||
}
|
||||
first_payload = {
|
||||
"model": "deepseek-v4-pro",
|
||||
"messages": [
|
||||
{"role": "user", "content": "old model turn"},
|
||||
{"role": "assistant", "content": "", "tool_calls": [old_tool_call]},
|
||||
{"role": "tool", "tool_call_id": "call_old", "content": "old result"},
|
||||
{"role": "user", "content": "continue with DeepSeek"},
|
||||
],
|
||||
}
|
||||
first_recovered = prepare_upstream_request(
|
||||
first_payload,
|
||||
ProxyConfig(missing_reasoning_strategy="recover"),
|
||||
self.store,
|
||||
)
|
||||
self.assertEqual(first_recovered.recovered_reasoning_messages, 1)
|
||||
|
||||
response_body = json.dumps(
|
||||
{
|
||||
"id": "chatcmpl-test",
|
||||
"object": "chat.completion",
|
||||
"model": "deepseek-v4-pro",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"finish_reason": "tool_calls",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": "Need the new lookup.",
|
||||
"tool_calls": [new_tool_call],
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
).encode()
|
||||
rewritten = rewrite_response_body(
|
||||
response_body,
|
||||
"deepseek-v4-pro",
|
||||
self.store,
|
||||
first_recovered.payload["messages"],
|
||||
first_recovered.cache_namespace,
|
||||
content_prefix=first_recovered.recovery_notice,
|
||||
recording_contexts=first_recovered.record_response_contexts,
|
||||
)
|
||||
recovered_assistant = json.loads(rewritten)["choices"][0]["message"]
|
||||
self.assertEqual(len(first_recovered.record_response_contexts), 2)
|
||||
for scope, _messages in first_recovered.record_response_contexts:
|
||||
self.assertEqual(
|
||||
self.store.get(
|
||||
f"scope:{scope}:signature:{message_signature(recovered_assistant)}"
|
||||
),
|
||||
"Need the new lookup.",
|
||||
)
|
||||
recovered_assistant.pop("reasoning_content", None)
|
||||
|
||||
second_payload = {
|
||||
"model": "deepseek-v4-pro",
|
||||
"messages": [
|
||||
*first_payload["messages"],
|
||||
recovered_assistant,
|
||||
{"role": "tool", "tool_call_id": "call_new", "content": "new result"},
|
||||
],
|
||||
}
|
||||
|
||||
second_prepared = prepare_upstream_request(
|
||||
second_payload,
|
||||
ProxyConfig(missing_reasoning_strategy="recover"),
|
||||
self.store,
|
||||
)
|
||||
|
||||
self.assertEqual(second_prepared.missing_reasoning_messages, 0)
|
||||
self.assertEqual(second_prepared.recovered_reasoning_messages, 0)
|
||||
self.assertEqual(second_prepared.recovery_dropped_messages, 0)
|
||||
self.assertTrue(second_prepared.continued_recovery_boundary)
|
||||
self.assertGreater(second_prepared.retired_prefix_messages, 0)
|
||||
self.assertEqual(
|
||||
second_prepared.payload["messages"][2]["reasoning_content"],
|
||||
"Need the new lookup.",
|
||||
)
|
||||
|
||||
def test_recovery_boundary_accepts_legacy_notice_text(self) -> None:
|
||||
legacy_recovery_notice = (
|
||||
"Note: recovered this DeepSeek chat because older tool-call reasoning "
|
||||
|
|
@ -844,6 +1129,10 @@ class TransformTests(unittest.TestCase):
|
|||
)
|
||||
|
||||
self.assertEqual(prepared.missing_reasoning_messages, 0)
|
||||
self.assertEqual(prepared.recovered_reasoning_messages, 0)
|
||||
self.assertEqual(prepared.recovery_dropped_messages, 0)
|
||||
self.assertTrue(prepared.continued_recovery_boundary)
|
||||
self.assertGreater(prepared.retired_prefix_messages, 0)
|
||||
self.assertIsNone(prepared.recovery_notice)
|
||||
self.assertEqual(
|
||||
prepared.payload["messages"][2]["reasoning_content"],
|
||||
|
|
@ -985,6 +1274,64 @@ class TransformTests(unittest.TestCase):
|
|||
self.assertEqual(prepared.missing_reasoning_messages, 1)
|
||||
self.assertNotIn("reasoning_content", prepared.payload["messages"][1])
|
||||
|
||||
def test_deepseek_pro_and_flash_share_reasoning_namespace(self) -> None:
|
||||
config = ProxyConfig()
|
||||
namespace_pro = reasoning_cache_namespace(
|
||||
config,
|
||||
"deepseek-v4-pro",
|
||||
{"type": "enabled"},
|
||||
"high",
|
||||
"Bearer key-a",
|
||||
)
|
||||
namespace_flash = reasoning_cache_namespace(
|
||||
config,
|
||||
"deepseek-v4-flash",
|
||||
{"type": "enabled"},
|
||||
"high",
|
||||
"Bearer key-a",
|
||||
)
|
||||
self.assertEqual(namespace_pro, namespace_flash)
|
||||
|
||||
prior = [{"role": "user", "content": "read README"}]
|
||||
tool_call = {
|
||||
"id": "call_shared",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "read_file",
|
||||
"arguments": '{"path":"README.md"}',
|
||||
},
|
||||
}
|
||||
self.store.store_assistant_message(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": "Shared DeepSeek reasoning.",
|
||||
"tool_calls": [tool_call],
|
||||
},
|
||||
conversation_scope(prior, namespace_pro),
|
||||
namespace_pro,
|
||||
prior,
|
||||
)
|
||||
|
||||
prepared = prepare_upstream_request(
|
||||
{
|
||||
"model": "deepseek-v4-flash",
|
||||
"messages": [
|
||||
*prior,
|
||||
{"role": "assistant", "content": "", "tool_calls": [tool_call]},
|
||||
],
|
||||
},
|
||||
config,
|
||||
self.store,
|
||||
authorization="Bearer key-a",
|
||||
)
|
||||
|
||||
self.assertEqual(prepared.missing_reasoning_messages, 0)
|
||||
self.assertEqual(
|
||||
prepared.payload["messages"][1]["reasoning_content"],
|
||||
"Shared DeepSeek reasoning.",
|
||||
)
|
||||
|
||||
def test_converted_function_message_uses_tool_schema(self) -> None:
|
||||
payload = {
|
||||
"model": "deepseek-v4-pro",
|
||||
|
|
|
|||
Loading…
Reference in New Issue