diff --git a/src/deepseek_cursor_proxy/reasoning_store.py b/src/deepseek_cursor_proxy/reasoning_store.py index 385c974..dc30ac4 100644 --- a/src/deepseek_cursor_proxy/reasoning_store.py +++ b/src/deepseek_cursor_proxy/reasoning_store.py @@ -62,6 +62,13 @@ def message_signature(message: dict[str, Any]) -> str: return hashlib.sha256(canonical.encode("utf-8")).hexdigest() +def _sha256_json(payload: Any) -> str: + canonical = json.dumps( + payload, ensure_ascii=False, sort_keys=True, separators=(",", ":") + ) + return hashlib.sha256(canonical.encode("utf-8")).hexdigest() + + def canonical_scope_message(message: dict[str, Any]) -> dict[str, Any]: canonical: dict[str, Any] = {"role": message.get("role")} for key in ("content", "name", "tool_call_id", "prefix"): @@ -81,10 +88,71 @@ def conversation_scope(messages: list[dict[str, Any]], namespace: str = "") -> s payload: Any = scope_messages if namespace: payload = {"namespace": namespace, "messages": scope_messages} - canonical = json.dumps( - payload, ensure_ascii=False, sort_keys=True, separators=(",", ":") + return _sha256_json(payload) + + +def turn_context_signature(prior_messages: list[dict[str, Any]]) -> str: + last_user_index = next( + ( + index + for index in range(len(prior_messages) - 1, -1, -1) + if prior_messages[index].get("role") == "user" + ), + -1, ) - return hashlib.sha256(canonical.encode("utf-8")).hexdigest() + start_index = 0 + if last_user_index != -1: + start_index = last_user_index + while start_index > 0 and prior_messages[start_index - 1].get("role") == "user": + start_index -= 1 + + context_messages = [ + canonical_scope_message(message) + for message in prior_messages[start_index:] + if message.get("role") != "system" + ] + return _sha256_json(context_messages) + + +def scoped_reasoning_keys(message: dict[str, Any], scope: str) -> list[str]: + keys = [f"scope:{scope}:signature:{message_signature(message)}"] + keys.extend( + f"scope:{scope}:tool_call:{tool_call_id}" + for tool_call_id in tool_call_ids(message) + ) + keys.extend( + f"scope:{scope}:tool_call_signature:{tool_call_signature(tool_call)}" + for tool_call in (message.get("tool_calls") or []) + if isinstance(tool_call, dict) + ) + return keys + + +def portable_reasoning_keys( + message: dict[str, Any], + cache_namespace: str, + prior_messages: list[dict[str, Any]], +) -> list[str]: + if not cache_namespace: + return [] + + turn_signature = turn_context_signature(prior_messages) + keys = [ + f"namespace:{cache_namespace}:turn:{turn_signature}:" + f"signature:{message_signature(message)}" + ] + keys.extend( + f"namespace:{cache_namespace}:turn:{turn_signature}:" + f"tool_call:{tool_call_id}" + for tool_call_id in tool_call_ids(message) + ) + keys.extend( + f"namespace:{cache_namespace}:turn:{turn_signature}:" + f"tool_call_signature:{tool_call_signature(tool_call)}" + for tool_call in (message.get("tool_calls") or []) + if isinstance(tool_call, dict) + ) + return keys class ReasoningStore: @@ -155,45 +223,65 @@ class ReasoningStore: return None return str(row[0]) - def store_assistant_message(self, message: dict[str, Any], scope: str) -> int: + def store_assistant_message( + self, + message: dict[str, Any], + scope: str, + cache_namespace: str = "", + prior_messages: list[dict[str, Any]] | None = None, + ) -> int: if message.get("role") != "assistant": return 0 reasoning = message.get("reasoning_content") if not isinstance(reasoning, str): return 0 - keys = [f"scope:{scope}:signature:{message_signature(message)}"] - keys.extend( - f"scope:{scope}:tool_call:{tool_call_id}" - for tool_call_id in tool_call_ids(message) - ) - keys.extend( - f"scope:{scope}:tool_call_signature:{tool_call_signature(tool_call)}" - for tool_call in (message.get("tool_calls") or []) - if isinstance(tool_call, dict) - ) + keys = scoped_reasoning_keys(message, scope) + if prior_messages is not None: + keys.extend( + portable_reasoning_keys(message, cache_namespace, prior_messages) + ) + keys = list(dict.fromkeys(keys)) for key in keys: self.put(key, reasoning, message) return len(keys) - def lookup_for_message(self, message: dict[str, Any], scope: str) -> str | None: - reasoning = self.get(f"scope:{scope}:signature:{message_signature(message)}") - if reasoning is not None: - return reasoning - for tool_call_id in tool_call_ids(message): - reasoning = self.get(f"scope:{scope}:tool_call:{tool_call_id}") - if reasoning is not None: - return reasoning - for tool_call in message.get("tool_calls") or []: - if not isinstance(tool_call, dict): - continue - reasoning = self.get( - f"scope:{scope}:tool_call_signature:{tool_call_signature(tool_call)}" + def lookup_for_message( + self, + message: dict[str, Any], + scope: str, + cache_namespace: str = "", + prior_messages: list[dict[str, Any]] | None = None, + ) -> str | None: + keys = scoped_reasoning_keys(message, scope) + if prior_messages is not None: + keys.extend( + portable_reasoning_keys(message, cache_namespace, prior_messages) ) + for key in keys: + reasoning = self.get(key) if reasoning is not None: return reasoning return None + def backfill_portable_aliases( + self, + message: dict[str, Any], + reasoning: str, + cache_namespace: str, + prior_messages: list[dict[str, Any]], + ) -> int: + if not isinstance(reasoning, str): + return 0 + keys = portable_reasoning_keys(message, cache_namespace, prior_messages) + if not keys: + return 0 + message_with_reasoning = dict(message) + message_with_reasoning["reasoning_content"] = reasoning + for key in dict.fromkeys(keys): + self.put(key, reasoning, message_with_reasoning) + return len(keys) + def clear(self) -> int: with self._lock: row = self._conn.execute("SELECT COUNT(*) FROM reasoning_cache").fetchone() diff --git a/src/deepseek_cursor_proxy/server.py b/src/deepseek_cursor_proxy/server.py index 3726bb4..39be7ee 100644 --- a/src/deepseek_cursor_proxy/server.py +++ b/src/deepseek_cursor_proxy/server.py @@ -160,16 +160,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): ) if prepared.recovered_reasoning_messages: if prepared.recovery_notice: - LOG.warning( - ( - "recovered request because cached reasoning_content was " - "unavailable for %s assistant message(s); omitted %s " - "older message(s) from forwarded history and will show " - "a Cursor notice" - ), - prepared.recovered_reasoning_messages, - prepared.recovery_dropped_messages, - ) + LOG.warning("refreshed reasoning_content history") else: LOG.info( ( @@ -305,7 +296,10 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): prepared.payload["messages"], prepared.cache_namespace, prepared.recovery_notice, - trace, + trace=trace, + record_response_scope=prepared.record_response_scope, + record_response_messages=prepared.record_response_messages, + record_response_contexts=prepared.record_response_contexts, ) else: sent_response = self._proxy_regular_response( @@ -314,7 +308,10 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): prepared.payload["messages"], prepared.cache_namespace, prepared.recovery_notice, - trace, + trace=trace, + record_response_scope=prepared.record_response_scope, + record_response_messages=prepared.record_response_messages, + record_response_contexts=prepared.record_response_contexts, ) if not sent_response: self._finish_trace( @@ -549,6 +546,9 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): cache_namespace: str, recovery_notice: str | None = None, trace: TraceRequest | None = None, + record_response_scope: str | None = None, + record_response_messages: list[dict[str, Any]] | None = None, + record_response_contexts: list[tuple[str, list[dict[str, Any]]]] | None = None, ) -> bool: body = read_response_body(response) upstream_body = body @@ -560,6 +560,9 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): request_messages, cache_namespace, content_prefix=recovery_notice, + scope=record_response_scope, + prior_messages=record_response_messages, + recording_contexts=record_response_contexts, ) except (json.JSONDecodeError, UnicodeDecodeError) as exc: LOG.warning("failed to rewrite upstream JSON response: %s", exc) @@ -611,6 +614,9 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): cache_namespace: str, recovery_notice: str | None = None, trace: TraceRequest | None = None, + record_response_scope: str | None = None, + record_response_messages: list[dict[str, Any]] | None = None, + record_response_contexts: list[tuple[str, list[dict[str, Any]]]] | None = None, ) -> bool: if trace is not None: trace.record_upstream_response( @@ -645,7 +651,21 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): if self.config.cursor_display_reasoning else None ) - scope = conversation_scope(request_messages, cache_namespace) + scope = ( + record_response_scope + if record_response_scope is not None + else conversation_scope(request_messages, cache_namespace) + ) + response_prior_messages = ( + record_response_messages + if record_response_messages is not None + else request_messages + ) + response_contexts = ( + record_response_contexts + if record_response_contexts is not None + else [(scope, response_prior_messages)] + ) finalized = False pending_recovery_notice = recovery_notice while True: @@ -660,7 +680,8 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): line, original_model, accumulator, - scope, + cache_namespace, + response_contexts, display_adapter, pending_recovery_notice, trace, @@ -677,7 +698,15 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): if not finalized: if self.config.verbose: log_json("model streaming assistant messages", accumulator.messages()) - stored = accumulator.store_reasoning(self.reasoning_store, scope) + stored = sum( + accumulator.store_reasoning( + self.reasoning_store, + scope, + cache_namespace, + prior_messages, + ) + for scope, prior_messages in response_contexts + ) if stored: LOG.info("stored %s streaming reasoning cache key(s)", stored) return True @@ -687,7 +716,8 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): line: bytes, original_model: str, accumulator: StreamAccumulator, - scope: str, + cache_namespace: str, + response_contexts: list[tuple[str, list[dict[str, Any]]]], display_adapter: CursorReasoningDisplayAdapter | None, recovery_notice: str | None = None, trace: TraceRequest | None = None, @@ -700,7 +730,15 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): if data == b"[DONE]": if self.config.verbose: log_json("model streaming assistant messages", accumulator.messages()) - stored = accumulator.store_reasoning(self.reasoning_store, scope) + stored = sum( + accumulator.store_reasoning( + self.reasoning_store, + scope, + cache_namespace, + prior_messages, + ) + for scope, prior_messages in response_contexts + ) if stored: LOG.info("stored %s streaming reasoning cache key(s)", stored) prefix = b"" @@ -728,7 +766,15 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): if recovery_notice and inject_recovery_notice(chunk, recovery_notice): recovery_notice = None accumulator.ingest_chunk(chunk) - stored = accumulator.store_ready_reasoning(self.reasoning_store, scope) + stored = sum( + accumulator.store_ready_reasoning( + self.reasoning_store, + scope, + cache_namespace, + prior_messages, + ) + for scope, prior_messages in response_contexts + ) if stored: LOG.info("stored %s streaming reasoning cache key(s)", stored) if trace is not None: diff --git a/src/deepseek_cursor_proxy/streaming.py b/src/deepseek_cursor_proxy/streaming.py index aafee1e..520be63 100644 --- a/src/deepseek_cursor_proxy/streaming.py +++ b/src/deepseek_cursor_proxy/streaming.py @@ -35,7 +35,7 @@ class StreamingChoice: class StreamAccumulator: def __init__(self) -> None: self.choices: dict[int, StreamingChoice] = {} - self._stored_choices: dict[int, str] = {} + self._stored_choices: dict[tuple[int, str], str] = {} def ingest_chunk(self, chunk: dict[str, Any]) -> None: choices = chunk.get("choices") @@ -70,26 +70,70 @@ class StreamAccumulator: self._merge_tool_call_deltas(choice, delta.get("tool_calls")) - def store_reasoning(self, store: ReasoningStore, scope: str) -> int: + def store_reasoning( + self, + store: ReasoningStore, + scope: str, + cache_namespace: str = "", + prior_messages: list[dict[str, Any]] | None = None, + ) -> int: stored = 0 for index, choice in self.choices.items(): - stored += self._store_choice(index, choice, store, scope) + stored += self._store_choice( + index, choice, store, scope, "final", cache_namespace, prior_messages + ) return stored - def store_finished_reasoning(self, store: ReasoningStore, scope: str) -> int: + def store_finished_reasoning( + self, + store: ReasoningStore, + scope: str, + cache_namespace: str = "", + prior_messages: list[dict[str, Any]] | None = None, + ) -> int: stored = 0 for index, choice in self.choices.items(): if choice.finish_reason is not None: - stored += self._store_choice(index, choice, store, scope, "final") + stored += self._store_choice( + index, + choice, + store, + scope, + "final", + cache_namespace, + prior_messages, + ) return stored - def store_ready_reasoning(self, store: ReasoningStore, scope: str) -> int: + def store_ready_reasoning( + self, + store: ReasoningStore, + scope: str, + cache_namespace: str = "", + prior_messages: list[dict[str, Any]] | None = None, + ) -> int: stored = 0 for index, choice in self.choices.items(): if choice.finish_reason is not None: - stored += self._store_choice(index, choice, store, scope, "final") + stored += self._store_choice( + index, + choice, + store, + scope, + "final", + cache_namespace, + prior_messages, + ) elif self._has_identified_tool_calls(choice): - stored += self._store_choice(index, choice, store, scope, "tool_call") + stored += self._store_choice( + index, + choice, + store, + scope, + "tool_call", + cache_namespace, + prior_messages, + ) return stored def messages(self) -> list[dict[str, Any]]: @@ -141,14 +185,22 @@ class StreamAccumulator: store: ReasoningStore, scope: str, stage: str = "final", + cache_namespace: str = "", + prior_messages: list[dict[str, Any]] | None = None, ) -> int: stage_rank = {"tool_call": 1, "final": 2} - previous_stage = self._stored_choices.get(index) + storage_key = (index, scope) + previous_stage = self._stored_choices.get(storage_key) if stage_rank.get(previous_stage or "", 0) >= stage_rank.get(stage, 0): return 0 - stored = store.store_assistant_message(choice.to_message(), scope) + stored = store.store_assistant_message( + choice.to_message(), + scope, + cache_namespace, + prior_messages, + ) if stored: - self._stored_choices[index] = stage + self._stored_choices[storage_key] = stage return stored def _has_identified_tool_calls(self, choice: StreamingChoice) -> bool: diff --git a/src/deepseek_cursor_proxy/trace.py b/src/deepseek_cursor_proxy/trace.py index fa86680..ec2fa11 100644 --- a/src/deepseek_cursor_proxy/trace.py +++ b/src/deepseek_cursor_proxy/trace.py @@ -105,7 +105,10 @@ def message_summaries(payload: dict[str, Any]) -> list[dict[str, Any]]: len(reasoning) if isinstance(reasoning, str) else 0 ), "has_recovery_notice": content.startswith( - "[deepseek-cursor-proxy] Recovered" + ( + "[deepseek-cursor-proxy] Refreshed reasoning_content history.", + "[deepseek-cursor-proxy] Recovered", + ) ), } summaries.append(summary) @@ -231,6 +234,12 @@ class TraceRequest: "recovered_reasoning_messages": prepared.recovered_reasoning_messages, "recovery_dropped_messages": prepared.recovery_dropped_messages, "recovery_notice": prepared.recovery_notice, + "record_response_scope": prepared.record_response_scope, + "record_response_scopes": [ + scope for scope, _messages in prepared.record_response_contexts + ], + "continued_recovery_boundary": prepared.continued_recovery_boundary, + "retired_prefix_messages": prepared.retired_prefix_messages, "reasoning_diagnostics": prepared.reasoning_diagnostics, "recovery_steps": prepared.recovery_steps, "upstream_request_summary": payload_summary(prepared.payload), diff --git a/src/deepseek_cursor_proxy/transform.py b/src/deepseek_cursor_proxy/transform.py index 7e4ac74..3dda5a5 100644 --- a/src/deepseek_cursor_proxy/transform.py +++ b/src/deepseek_cursor_proxy/transform.py @@ -13,6 +13,7 @@ from .reasoning_store import ( message_signature, tool_call_ids, tool_call_signature, + turn_context_signature, ) @@ -73,10 +74,7 @@ CURSOR_THINKING_BLOCK_RE = re.compile( re.IGNORECASE, ) -RECOVERY_NOTICE_TEXT = ( - "[deepseek-cursor-proxy] Recovered this DeepSeek chat because older " - "tool-call reasoning was unavailable; continuing with recent context only." -) +RECOVERY_NOTICE_TEXT = "[deepseek-cursor-proxy] Refreshed reasoning_content history." LEGACY_RECOVERY_NOTICE_TEXT = ( "Note: recovered this DeepSeek chat because older tool-call reasoning " "was unavailable; continuing with recent context only." @@ -101,8 +99,15 @@ class PreparedRequest: recovered_reasoning_messages: int = 0 recovery_dropped_messages: int = 0 recovery_notice: str | None = None + record_response_scope: str | None = None + record_response_messages: list[dict[str, Any]] = field(default_factory=list) + record_response_contexts: list[tuple[str, list[dict[str, Any]]]] = field( + default_factory=list + ) reasoning_diagnostics: list[dict[str, Any]] = field(default_factory=list) recovery_steps: list[dict[str, Any]] = field(default_factory=list) + continued_recovery_boundary: bool = False + retired_prefix_messages: int = 0 def normalize_reasoning_effort(value: Any) -> str: @@ -260,7 +265,12 @@ def normalize_message( ) lookup_scope = conversation_scope(prior_messages, cache_namespace) lookup_keys = ( - reasoning_lookup_keys(normalized, lookup_scope) + reasoning_lookup_keys( + normalized, + lookup_scope, + cache_namespace, + prior_messages, + ) if needs_reasoning else [] ) @@ -273,6 +283,13 @@ def normalize_message( hit_kind = lookup_key["kind"] normalized["reasoning_content"] = restored patched = True + if not lookup_key.get("portable"): + store.backfill_portable_aliases( + normalized, + restored, + cache_namespace, + prior_messages, + ) break if needs_reasoning and not patched: missing = True @@ -315,11 +332,14 @@ def normalize_message( def reasoning_lookup_keys( message: dict[str, Any], scope: str, + cache_namespace: str = "", + prior_messages: list[dict[str, Any]] | None = None, ) -> list[dict[str, Any]]: keys = [ { "kind": "message_signature", "key": f"scope:{scope}:signature:{message_signature(message)}", + "portable": False, "hit": False, } ] @@ -328,6 +348,7 @@ def reasoning_lookup_keys( "kind": "tool_call_id", "tool_call_id": tool_call_id, "key": f"scope:{scope}:tool_call:{tool_call_id}", + "portable": False, "hit": False, } for tool_call_id in tool_call_ids(message) @@ -340,11 +361,57 @@ def reasoning_lookup_keys( f"scope:{scope}:tool_call_signature:" f"{tool_call_signature(tool_call)}" ), + "portable": False, "hit": False, } for tool_call in (message.get("tool_calls") or []) if isinstance(tool_call, dict) ) + if cache_namespace and prior_messages is not None: + turn_signature = turn_context_signature(prior_messages) + keys.append( + { + "kind": "portable_message_signature", + "key": ( + f"namespace:{cache_namespace}:turn:{turn_signature}:" + f"signature:{message_signature(message)}" + ), + "turn_context_signature": turn_signature, + "portable": True, + "hit": False, + } + ) + keys.extend( + { + "kind": "portable_tool_call_id", + "tool_call_id": tool_call_id, + "key": ( + f"namespace:{cache_namespace}:turn:{turn_signature}:" + f"tool_call:{tool_call_id}" + ), + "turn_context_signature": turn_signature, + "portable": True, + "hit": False, + } + for tool_call_id in tool_call_ids(message) + ) + keys.extend( + { + "kind": "portable_tool_call_signature", + "function_name": str( + (tool_call.get("function") or {}).get("name") or "" + ), + "key": ( + f"namespace:{cache_namespace}:turn:{turn_signature}:" + f"tool_call_signature:{tool_call_signature(tool_call)}" + ), + "turn_context_signature": turn_signature, + "portable": True, + "hit": False, + } + for tool_call in (message.get("tool_calls") or []) + if isinstance(tool_call, dict) + ) return keys @@ -399,6 +466,52 @@ def leading_system_messages(messages: list[dict[str, Any]]) -> list[dict[str, An return leading_messages +def active_messages_from_recovery_boundary( + messages: list[dict[str, Any]], +) -> tuple[list[dict[str, Any]], int, dict[str, Any]] | None: + recovery_boundary_index = next( + ( + index + for index in range(len(messages) - 1, -1, -1) + if has_recovery_notice(messages[index]) + ), + -1, + ) + if recovery_boundary_index == -1: + return None + + context_user_index = next( + ( + index + for index in range(recovery_boundary_index - 1, -1, -1) + if messages[index].get("role") == "user" + ), + -1, + ) + leading_messages = leading_system_messages(messages) + recovered_tail = [] + if context_user_index != -1: + recovered_tail.append(messages[context_user_index]) + recovered_tail.extend(messages[recovery_boundary_index:]) + active_messages = [ + *leading_messages, + {"role": "system", "content": RECOVERY_SYSTEM_CONTENT}, + *recovered_tail, + ] + kept_context_messages = 1 if context_user_index != -1 else 0 + retired_messages = ( + recovery_boundary_index - len(leading_messages) - kept_context_messages + ) + retired_messages = max(retired_messages, 0) + step = { + "strategy": "continued_recovery_boundary", + "recovery_boundary_index": recovery_boundary_index, + "context_user_index": context_user_index, + "retired_prefix_messages": retired_messages, + } + return active_messages, retired_messages, step + + def recover_messages_from_missing_reasoning( messages: list[dict[str, Any]], missing_indexes: list[int], @@ -510,6 +623,12 @@ def upstream_model_for(original_model: str, config: ProxyConfig) -> str: return config.upstream_model +def reasoning_model_family(upstream_model: str) -> str: + if upstream_model in {"deepseek-v4-pro", "deepseek-v4-flash"}: + return "deepseek-v4" + return upstream_model + + def reasoning_cache_namespace( config: ProxyConfig, upstream_model: str, @@ -522,7 +641,7 @@ def reasoning_cache_namespace( auth_hash = hashlib.sha256(authorization.encode("utf-8")).hexdigest() payload = { "base_url": config.upstream_base_url, - "model": upstream_model, + "model": reasoning_model_family(upstream_model), "thinking": thinking, "reasoning_effort": reasoning_effort, "authorization_hash": auth_hash, @@ -533,6 +652,22 @@ def reasoning_cache_namespace( return hashlib.sha256(canonical.encode("utf-8")).hexdigest() +def response_recording_contexts( + *items: tuple[str, list[dict[str, Any]]] | None, +) -> list[tuple[str, list[dict[str, Any]]]]: + contexts: list[tuple[str, list[dict[str, Any]]]] = [] + seen: set[str] = set() + for item in items: + if item is None: + continue + scope, messages = item + if scope in seen: + continue + seen.add(scope) + contexts.append((scope, messages)) + return contexts + + def prepare_upstream_request( payload: dict[str, Any], config: ProxyConfig, @@ -596,19 +731,40 @@ def prepare_upstream_request( prepared.get("reasoning_effort"), authorization, ) + pre_repair_messages, _, _, _ = normalize_messages( + payload.get("messages"), + None, + cache_namespace, + repair_reasoning=False, + keep_reasoning=not thinking_disabled, + ) + record_response_messages = pre_repair_messages + record_response_scope = conversation_scope( + record_response_messages, cache_namespace + ) + messages_for_repair = pre_repair_messages + continued_recovery_boundary = False + retired_prefix_messages = 0 + recovered_count = 0 + recovery_dropped_messages = 0 + recovery_notice = None + recovery_steps: list[dict[str, Any]] = [] + if thinking_enabled and config.missing_reasoning_strategy == "recover": + boundary = active_messages_from_recovery_boundary(pre_repair_messages) + if boundary is not None: + messages_for_repair, retired_prefix_messages, boundary_step = boundary + continued_recovery_boundary = True + recovery_steps.append(boundary_step) + messages, patched_count, missing_indexes, reasoning_diagnostics = ( normalize_messages( - payload.get("messages"), + messages_for_repair, store, cache_namespace, repair_reasoning=thinking_enabled, keep_reasoning=not thinking_disabled, ) ) - recovered_count = 0 - recovery_dropped_messages = 0 - recovery_notice = None - recovery_steps: list[dict[str, Any]] = [] while missing_indexes and config.missing_reasoning_strategy == "recover": recovered_messages, dropped_messages, notice, recovery_step = ( recover_messages_from_missing_reasoning(messages, missing_indexes) @@ -634,6 +790,11 @@ def prepare_upstream_request( ) reasoning_diagnostics.extend(latest_diagnostics) prepared["messages"] = messages + active_record_response_scope = conversation_scope(messages, cache_namespace) + record_response_contexts = response_recording_contexts( + (record_response_scope, record_response_messages), + (active_record_response_scope, messages), + ) return PreparedRequest( payload=prepared, @@ -645,8 +806,13 @@ def prepare_upstream_request( recovered_reasoning_messages=recovered_count, recovery_dropped_messages=recovery_dropped_messages, recovery_notice=recovery_notice, + record_response_scope=record_response_scope, + record_response_messages=record_response_messages, + record_response_contexts=record_response_contexts, reasoning_diagnostics=reasoning_diagnostics, recovery_steps=recovery_steps, + continued_recovery_boundary=continued_recovery_boundary, + retired_prefix_messages=retired_prefix_messages, ) @@ -655,6 +821,9 @@ def record_response_reasoning( store: ReasoningStore | None, request_messages: list[dict[str, Any]], cache_namespace: str = "", + scope: str | None = None, + prior_messages: list[dict[str, Any]] | None = None, + recording_contexts: list[tuple[str, list[dict[str, Any]]]] | None = None, ) -> int: if store is None: return 0 @@ -662,13 +831,28 @@ def record_response_reasoning( choices = response_payload.get("choices") if not isinstance(choices, list): return stored - scope = conversation_scope(request_messages, cache_namespace) + if recording_contexts is None: + response_scope = ( + scope + if scope is not None + else conversation_scope(request_messages, cache_namespace) + ) + response_prior_messages = ( + prior_messages if prior_messages is not None else request_messages + ) + recording_contexts = [(response_scope, response_prior_messages)] for choice in choices: if not isinstance(choice, dict): continue message = choice.get("message") if isinstance(message, dict): - stored += store.store_assistant_message(message, scope) + for response_scope, response_prior_messages in recording_contexts: + stored += store.store_assistant_message( + message, + response_scope, + cache_namespace, + response_prior_messages, + ) return stored @@ -679,13 +863,22 @@ def rewrite_response_body( request_messages: list[dict[str, Any]], cache_namespace: str = "", content_prefix: str | None = None, + scope: str | None = None, + prior_messages: list[dict[str, Any]] | None = None, + recording_contexts: list[tuple[str, list[dict[str, Any]]]] | None = None, ) -> bytes: response_payload = json.loads(body.decode("utf-8")) if isinstance(response_payload, dict): if content_prefix: prefix_response_content(response_payload, content_prefix) record_response_reasoning( - response_payload, store, request_messages, cache_namespace + response_payload, + store, + request_messages, + cache_namespace, + scope=scope, + prior_messages=prior_messages, + recording_contexts=recording_contexts, ) if "model" in response_payload: response_payload["model"] = original_model diff --git a/tests/test_proxy_end_to_end.py b/tests/test_proxy_end_to_end.py index 739da7f..b7e2da4 100644 --- a/tests/test_proxy_end_to_end.py +++ b/tests/test_proxy_end_to_end.py @@ -738,7 +738,7 @@ class ProxyEndToEndTests(unittest.TestCase): {"role": "user", "content": "Thanks, now continue."}, ) self.assertIn( - "cached reasoning_content was unavailable", + "refreshed reasoning_content history", "\n".join(captured.output), ) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index cd2a2de..ade94c0 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -116,6 +116,50 @@ class StreamAccumulatorTests(unittest.TestCase): self.assertEqual(accumulator.store_reasoning(store, scope), 0) store.close() + def test_stores_same_streaming_choice_under_multiple_scopes(self) -> None: + store = ReasoningStore(":memory:") + accumulator = StreamAccumulator() + accumulator.ingest_chunk( + { + "choices": [ + { + "index": 0, + "delta": { + "role": "assistant", + "reasoning_content": "Need a tool.", + "tool_calls": [ + { + "index": 0, + "id": "call_stream", + "type": "function", + "function": { + "name": "lookup", + "arguments": "{}", + }, + } + ], + }, + "finish_reason": "tool_calls", + } + ] + } + ) + + first_scope = conversation_scope([{"role": "user", "content": "full"}]) + second_scope = conversation_scope([{"role": "user", "content": "active"}]) + first_stored = accumulator.store_finished_reasoning(store, first_scope) + second_stored = accumulator.store_finished_reasoning(store, second_scope) + + self.assertGreater(first_stored, 0) + self.assertGreater(second_stored, 0) + self.assertEqual( + store.get(f"scope:{first_scope}:tool_call:call_stream"), "Need a tool." + ) + self.assertEqual( + store.get(f"scope:{second_scope}:tool_call:call_stream"), "Need a tool." + ) + store.close() + def test_stores_tool_call_reasoning_before_finish_reason(self) -> None: store = ReasoningStore(":memory:") accumulator = StreamAccumulator() diff --git a/tests/test_transform.py b/tests/test_transform.py index f6cbc81..ffcb01c 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -470,6 +470,193 @@ class TransformTests(unittest.TestCase): prepared.payload["messages"][1]["reasoning_content"], "first reasoning" ) + def test_strict_hit_backfills_portable_cache_for_mode_switch(self) -> None: + agent_prior = [ + {"role": "system", "content": "Agent mode."}, + {"role": "user", "content": "set up the task"}, + {"role": "user", "content": "read README"}, + ] + plan_prior = [ + {"role": "system", "content": "Plan mode."}, + {"role": "user", "content": "set up the task"}, + {"role": "user", "content": "read README"}, + ] + tool_call = { + "id": "call_mode_switch", + "type": "function", + "function": {"name": "read_file", "arguments": '{"path":"README.md"}'}, + } + assistant_message = { + "role": "assistant", + "content": "", + "reasoning_content": "Need README before answering.", + "tool_calls": [tool_call], + } + self.store.store_assistant_message( + assistant_message, + cache_scope(agent_prior), + ) + + strict_prepared = prepare_upstream_request( + { + "model": "deepseek-v4-pro", + "messages": [ + *agent_prior, + {"role": "assistant", "content": "", "tool_calls": [tool_call]}, + ], + }, + ProxyConfig(), + self.store, + ) + portable_prepared = prepare_upstream_request( + { + "model": "deepseek-v4-pro", + "messages": [ + *plan_prior, + {"role": "assistant", "content": "", "tool_calls": [tool_call]}, + ], + }, + ProxyConfig(), + self.store, + ) + + self.assertEqual(strict_prepared.patched_reasoning_messages, 1) + self.assertEqual(portable_prepared.patched_reasoning_messages, 1) + self.assertEqual(portable_prepared.missing_reasoning_messages, 0) + self.assertEqual( + portable_prepared.payload["messages"][3]["reasoning_content"], + "Need README before answering.", + ) + self.assertTrue( + str(portable_prepared.reasoning_diagnostics[-1]["hit_kind"]).startswith( + "portable_" + ) + ) + + def test_portable_turn_cache_restores_final_assistant_after_tool_result( + self, + ) -> None: + agent_user = {"role": "user", "content": "look up project state"} + plan_user = dict(agent_user) + tool_call = { + "id": "call_project_state", + "type": "function", + "function": {"name": "lookup", "arguments": '{"query":"state"}'}, + } + tool_result = { + "role": "tool", + "tool_call_id": "call_project_state", + "content": '{"state":"ready"}', + } + tool_assistant = { + "role": "assistant", + "content": "", + "reasoning_content": "Need the project state.", + "tool_calls": [tool_call], + } + final_assistant = { + "role": "assistant", + "content": "The project is ready.", + "reasoning_content": "The tool result is enough to answer.", + } + agent_initial_prior = [ + {"role": "system", "content": "Agent mode."}, + agent_user, + ] + agent_final_prior = [*agent_initial_prior, tool_assistant, tool_result] + self.store.store_assistant_message( + tool_assistant, + cache_scope(agent_initial_prior), + DEFAULT_CACHE_NAMESPACE, + agent_initial_prior, + ) + self.store.store_assistant_message( + final_assistant, + cache_scope(agent_final_prior), + DEFAULT_CACHE_NAMESPACE, + agent_final_prior, + ) + + prepared = prepare_upstream_request( + { + "model": "deepseek-v4-pro", + "messages": [ + {"role": "system", "content": "Plan mode."}, + plan_user, + {"role": "assistant", "content": "", "tool_calls": [tool_call]}, + tool_result, + {"role": "assistant", "content": "The project is ready."}, + {"role": "user", "content": "continue"}, + ], + }, + ProxyConfig(missing_reasoning_strategy="reject"), + self.store, + ) + + self.assertEqual(prepared.missing_reasoning_messages, 0) + self.assertEqual(prepared.patched_reasoning_messages, 2) + self.assertEqual( + prepared.payload["messages"][4]["reasoning_content"], + "The tool result is enough to answer.", + ) + + def test_portable_turn_cache_isolated_for_reused_tool_call_id(self) -> None: + tool_call = { + "id": "call_reused", + "type": "function", + "function": {"name": "lookup", "arguments": "{}"}, + } + assistant_a = { + "role": "assistant", + "content": "", + "reasoning_content": "Reasoning for thread A.", + "tool_calls": [tool_call], + } + assistant_b = { + "role": "assistant", + "content": "", + "reasoning_content": "Reasoning for thread B.", + "tool_calls": [tool_call], + } + prior_a = [ + {"role": "system", "content": "Agent mode."}, + {"role": "user", "content": "thread A"}, + ] + prior_b = [ + {"role": "system", "content": "Agent mode."}, + {"role": "user", "content": "thread B"}, + ] + self.store.store_assistant_message( + assistant_a, + cache_scope(prior_a), + DEFAULT_CACHE_NAMESPACE, + prior_a, + ) + self.store.store_assistant_message( + assistant_b, + cache_scope(prior_b), + DEFAULT_CACHE_NAMESPACE, + prior_b, + ) + + prepared = prepare_upstream_request( + { + "model": "deepseek-v4-pro", + "messages": [ + {"role": "system", "content": "Plan mode."}, + {"role": "user", "content": "thread A"}, + {"role": "assistant", "content": "", "tool_calls": [tool_call]}, + ], + }, + ProxyConfig(), + self.store, + ) + + self.assertEqual( + prepared.payload["messages"][2]["reasoning_content"], + "Reasoning for thread A.", + ) + def test_restores_reasoning_when_cursor_drops_tool_call_id_but_keeps_function_call( self, ) -> None: @@ -738,6 +925,10 @@ class TransformTests(unittest.TestCase): ) self.assertEqual(prepared.missing_reasoning_messages, 0) + self.assertEqual(prepared.recovered_reasoning_messages, 0) + self.assertEqual(prepared.recovery_dropped_messages, 0) + self.assertTrue(prepared.continued_recovery_boundary) + self.assertGreater(prepared.retired_prefix_messages, 0) self.assertIsNone(prepared.recovery_notice) self.assertEqual( [message["role"] for message in prepared.payload["messages"]], @@ -756,6 +947,100 @@ class TransformTests(unittest.TestCase): }, ) + def test_recovered_response_is_recorded_under_pre_recovery_scope(self) -> None: + old_tool_call = { + "id": "call_old", + "type": "function", + "function": { + "name": "read_file", + "arguments": '{"path":"README.md"}', + }, + } + new_tool_call = { + "id": "call_new", + "type": "function", + "function": {"name": "lookup", "arguments": '{"query":"new"}'}, + } + first_payload = { + "model": "deepseek-v4-pro", + "messages": [ + {"role": "user", "content": "old model turn"}, + {"role": "assistant", "content": "", "tool_calls": [old_tool_call]}, + {"role": "tool", "tool_call_id": "call_old", "content": "old result"}, + {"role": "user", "content": "continue with DeepSeek"}, + ], + } + first_recovered = prepare_upstream_request( + first_payload, + ProxyConfig(missing_reasoning_strategy="recover"), + self.store, + ) + self.assertEqual(first_recovered.recovered_reasoning_messages, 1) + + response_body = json.dumps( + { + "id": "chatcmpl-test", + "object": "chat.completion", + "model": "deepseek-v4-pro", + "choices": [ + { + "index": 0, + "finish_reason": "tool_calls", + "message": { + "role": "assistant", + "content": "", + "reasoning_content": "Need the new lookup.", + "tool_calls": [new_tool_call], + }, + } + ], + } + ).encode() + rewritten = rewrite_response_body( + response_body, + "deepseek-v4-pro", + self.store, + first_recovered.payload["messages"], + first_recovered.cache_namespace, + content_prefix=first_recovered.recovery_notice, + recording_contexts=first_recovered.record_response_contexts, + ) + recovered_assistant = json.loads(rewritten)["choices"][0]["message"] + self.assertEqual(len(first_recovered.record_response_contexts), 2) + for scope, _messages in first_recovered.record_response_contexts: + self.assertEqual( + self.store.get( + f"scope:{scope}:signature:{message_signature(recovered_assistant)}" + ), + "Need the new lookup.", + ) + recovered_assistant.pop("reasoning_content", None) + + second_payload = { + "model": "deepseek-v4-pro", + "messages": [ + *first_payload["messages"], + recovered_assistant, + {"role": "tool", "tool_call_id": "call_new", "content": "new result"}, + ], + } + + second_prepared = prepare_upstream_request( + second_payload, + ProxyConfig(missing_reasoning_strategy="recover"), + self.store, + ) + + self.assertEqual(second_prepared.missing_reasoning_messages, 0) + self.assertEqual(second_prepared.recovered_reasoning_messages, 0) + self.assertEqual(second_prepared.recovery_dropped_messages, 0) + self.assertTrue(second_prepared.continued_recovery_boundary) + self.assertGreater(second_prepared.retired_prefix_messages, 0) + self.assertEqual( + second_prepared.payload["messages"][2]["reasoning_content"], + "Need the new lookup.", + ) + def test_recovery_boundary_accepts_legacy_notice_text(self) -> None: legacy_recovery_notice = ( "Note: recovered this DeepSeek chat because older tool-call reasoning " @@ -844,6 +1129,10 @@ class TransformTests(unittest.TestCase): ) self.assertEqual(prepared.missing_reasoning_messages, 0) + self.assertEqual(prepared.recovered_reasoning_messages, 0) + self.assertEqual(prepared.recovery_dropped_messages, 0) + self.assertTrue(prepared.continued_recovery_boundary) + self.assertGreater(prepared.retired_prefix_messages, 0) self.assertIsNone(prepared.recovery_notice) self.assertEqual( prepared.payload["messages"][2]["reasoning_content"], @@ -985,6 +1274,64 @@ class TransformTests(unittest.TestCase): self.assertEqual(prepared.missing_reasoning_messages, 1) self.assertNotIn("reasoning_content", prepared.payload["messages"][1]) + def test_deepseek_pro_and_flash_share_reasoning_namespace(self) -> None: + config = ProxyConfig() + namespace_pro = reasoning_cache_namespace( + config, + "deepseek-v4-pro", + {"type": "enabled"}, + "high", + "Bearer key-a", + ) + namespace_flash = reasoning_cache_namespace( + config, + "deepseek-v4-flash", + {"type": "enabled"}, + "high", + "Bearer key-a", + ) + self.assertEqual(namespace_pro, namespace_flash) + + prior = [{"role": "user", "content": "read README"}] + tool_call = { + "id": "call_shared", + "type": "function", + "function": { + "name": "read_file", + "arguments": '{"path":"README.md"}', + }, + } + self.store.store_assistant_message( + { + "role": "assistant", + "content": "", + "reasoning_content": "Shared DeepSeek reasoning.", + "tool_calls": [tool_call], + }, + conversation_scope(prior, namespace_pro), + namespace_pro, + prior, + ) + + prepared = prepare_upstream_request( + { + "model": "deepseek-v4-flash", + "messages": [ + *prior, + {"role": "assistant", "content": "", "tool_calls": [tool_call]}, + ], + }, + config, + self.store, + authorization="Bearer key-a", + ) + + self.assertEqual(prepared.missing_reasoning_messages, 0) + self.assertEqual( + prepared.payload["messages"][1]["reasoning_content"], + "Shared DeepSeek reasoning.", + ) + def test_converted_function_message_uses_tool_schema(self) -> None: payload = { "model": "deepseek-v4-pro",