From 9bb45b7a316a0422b063be72688d76fecd78629e Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Fri, 24 Apr 2026 22:36:28 +0800 Subject: [PATCH] feat(proxy): fail closed with namespace-isolated reasoning cache (#6) --- README.md | 19 +- config.example.yaml | 6 + src/deepseek_cursor_proxy/config.py | 46 ++++ src/deepseek_cursor_proxy/reasoning_store.py | 68 ++++- src/deepseek_cursor_proxy/server.py | 147 ++++++++-- src/deepseek_cursor_proxy/streaming.py | 30 ++- src/deepseek_cursor_proxy/transform.py | 150 +++++++---- tests/test_config.py | 8 + tests/test_proxy_end_to_end.py | 211 ++++++++++++++- tests/test_reasoning_store.py | 46 +++- tests/test_streaming.py | 78 ++++++ tests/test_transform.py | 266 +++++++++++++++++-- 12 files changed, 975 insertions(+), 100 deletions(-) diff --git a/README.md b/README.md index d85b8dc..4ab85e5 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ Compatibility proxy connecting Cursor to DeepSeek thinking models (`deepseek-v4- ## What It Does -- ✅ Caches DeepSeek `reasoning_content` from regular and streamed responses, then restores it on later tool-call turns when Cursor omits it. See [DeepSeek docs](https://api-docs.deepseek.com/guides/thinking_mode#tool-calls) for more details. +- ✅ Caches DeepSeek `reasoning_content` from regular and streamed responses, then restores it on later tool-call turns when Cursor omits it. If the exact original reasoning is unavailable, the proxy fails closed instead of sending a fake placeholder. See [DeepSeek docs](https://api-docs.deepseek.com/guides/thinking_mode#tool-calls) for more details. - ✅ Mirrors streamed `reasoning_content` into Cursor-visible `...` text so that thinking tokens are shown in Cursor's UI. For BYOK/proxy mode, Cursor renders this as normal text, not as a native collapsible thinking block. - ✅ Starts an ngrok tunnel so Cursor can reach the local proxy through a public HTTPS URL. - ✅ Provides other compatibility fixes to make DeepSeek models run well in Cursor. @@ -53,6 +53,8 @@ In Cursor, add the DeepSeek custom model and point it at this proxy: - API Key: your DeepSeek API key - Base URL: your ngrok HTTPS URL with the `/v1` API version path +The proxy respects the DeepSeek model name Cursor sends, such as `deepseek-v4-pro` or `deepseek-v4-flash`. The `model` field in `config.yaml` is only the fallback used when a request does not include a model. + For example, if ngrok dashboard shows `https://example.ngrok-free.app`, use: ```text @@ -94,6 +96,15 @@ Select `deepseek-v4-pro` in Cursor and use chat or agent mode as usual. ![Chatting with DeepSeek in Cursor](assets/cursor_chat.png) +## How It Works + +DeepSeek's [thinking mode](https://api-docs.deepseek.com/guides/thinking_mode#tool-calls) requires `reasoning_content` from assistant messages in tool-call sequences to be passed back in later requests. Cursor may omit this field, causing DeepSeek to return a 400 error. This proxy sits between Cursor and DeepSeek (`Cursor → ngrok → proxy → DeepSeek API`) and repairs requests when it has the exact original reasoning cached. + +- Core fix: every DeepSeek response, streaming or non-streaming, has its `reasoning_content` stored in a local SQLite cache keyed by message signature, tool-call ID, and tool-call function signature. On outgoing thinking-mode requests, the proxy restores missing `reasoning_content` for tool-call-related assistant messages and sends the complete history to DeepSeek. If the cache is cold, such as after a proxy restart, it returns a local error instead of fabricating reasoning. +- Multi-conversation isolation: cache keys are scoped by a SHA-256 hash of the canonical conversation prefix (roles, content, tool calls, excluding `reasoning_content`) plus the upstream model/configuration and an API-key hash. Concurrent or interleaved threads with different histories get different scopes, so reused tool-call IDs do not collide. Byte-identical cloned histories are indistinguishable unless Cursor sends a differentiating history. +- DeepSeek [prefix caching](https://api-docs.deepseek.com/guides/kv_cache) compatibility: the proxy does not inject synthetic thread IDs, timestamps, or cache-control messages into the prompt. When it restores cached reasoning, it restores the exact original string, preserving repeated prefixes for DeepSeek's automatic best-effort context cache. +- Additional compatibility fixes: the proxy converts legacy `functions`/`function_call` fields to `tools`/`tool_choice`, preserves required and named tool-choice semantics, normalizes `reasoning_effort` aliases per DeepSeek docs, strips mirrored `` blocks from assistant content, converts multi-part content arrays to plain text, logs DeepSeek prompt-cache usage when available, and mirrors `reasoning_content` into Cursor-visible `...` blocks for thinking display. + ## Debugging Run with verbose output: @@ -114,6 +125,12 @@ Use another config file: deepseek-cursor-proxy --config ./dev.config.yaml ``` +Clear the local reasoning cache: + +```bash +deepseek-cursor-proxy --clear-reasoning-cache +``` + Run tests: ```bash diff --git a/config.example.yaml b/config.example.yaml index f5c950c..4aa6462 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -1,6 +1,8 @@ # This file was created automatically at ~/.deepseek-cursor-proxy/config.yaml. # API keys are read from Cursor's Authorization header and forwarded upstream. +# `model` is the fallback when a request has no model; Cursor's requested +# DeepSeek model name is otherwise respected. base_url: https://api.deepseek.com model: deepseek-v4-pro thinking: enabled @@ -12,5 +14,9 @@ port: 9000 ngrok: true verbose: false request_timeout: 300 +max_request_body_bytes: 20971520 +cors: false reasoning_content_path: reasoning_content.sqlite3 +reasoning_cache_max_age_seconds: 604800 +reasoning_cache_max_rows: 10000 diff --git a/src/deepseek_cursor_proxy/config.py b/src/deepseek_cursor_proxy/config.py index dc00a06..50f23f2 100644 --- a/src/deepseek_cursor_proxy/config.py +++ b/src/deepseek_cursor_proxy/config.py @@ -18,6 +18,8 @@ MISSING = object() DEFAULT_CONFIG_TEXT = """# This file was created automatically at ~/.deepseek-cursor-proxy/config.yaml. # API keys are read from Cursor's Authorization header and forwarded upstream. +# `model` is the fallback when a request has no model; Cursor's requested +# DeepSeek model name is otherwise respected. base_url: https://api.deepseek.com model: deepseek-v4-pro thinking: enabled @@ -29,8 +31,12 @@ port: 9000 ngrok: true verbose: false request_timeout: 300 +max_request_body_bytes: 20971520 +cors: false reasoning_content_path: reasoning_content.sqlite3 +reasoning_cache_max_age_seconds: 604800 +reasoning_cache_max_rows: 10000 """ @@ -163,8 +169,12 @@ class ProxyConfig: thinking: str = "enabled" reasoning_effort: str = "high" request_timeout: float = 300.0 + max_request_body_bytes: int = 20 * 1024 * 1024 reasoning_content_path: Path = field(default_factory=default_reasoning_content_path) + reasoning_cache_max_age_seconds: int = 7 * 24 * 60 * 60 + reasoning_cache_max_rows: int = 10000 cursor_display_reasoning: bool = True + cors: bool = False verbose: bool = False ngrok: bool = False @@ -260,6 +270,15 @@ class ProxyConfig: ), 300.0, ), + max_request_body_bytes=as_int( + setting_value( + settings, + live_env, + "max_request_body_bytes", + "PROXY_MAX_REQUEST_BODY_BYTES", + ), + 20 * 1024 * 1024, + ), reasoning_content_path=as_path( setting_value( settings, @@ -270,6 +289,24 @@ class ProxyConfig: default_reasoning_content_path(), config_dir, ), + reasoning_cache_max_age_seconds=as_int( + setting_value( + settings, + live_env, + "reasoning_cache_max_age_seconds", + "REASONING_CACHE_MAX_AGE_SECONDS", + ), + 7 * 24 * 60 * 60, + ), + reasoning_cache_max_rows=as_int( + setting_value( + settings, + live_env, + "reasoning_cache_max_rows", + "REASONING_CACHE_MAX_ROWS", + ), + 10000, + ), cursor_display_reasoning=as_bool( setting_value( settings, @@ -279,6 +316,15 @@ class ProxyConfig: ), True, ), + cors=as_bool( + setting_value( + settings, + live_env, + "cors", + "PROXY_CORS", + ), + False, + ), verbose=as_bool( setting_value( settings, diff --git a/src/deepseek_cursor_proxy/reasoning_store.py b/src/deepseek_cursor_proxy/reasoning_store.py index 14f8997..385c974 100644 --- a/src/deepseek_cursor_proxy/reasoning_store.py +++ b/src/deepseek_cursor_proxy/reasoning_store.py @@ -76,8 +76,11 @@ def canonical_scope_message(message: dict[str, Any]) -> dict[str, Any]: return canonical -def conversation_scope(messages: list[dict[str, Any]]) -> str: - payload = [canonical_scope_message(message) for message in messages] +def conversation_scope(messages: list[dict[str, Any]], namespace: str = "") -> str: + scope_messages = [canonical_scope_message(message) for message in messages] + payload: Any = scope_messages + if namespace: + payload = {"namespace": namespace, "messages": scope_messages} canonical = json.dumps( payload, ensure_ascii=False, sort_keys=True, separators=(",", ":") ) @@ -85,7 +88,14 @@ def conversation_scope(messages: list[dict[str, Any]]) -> str: class ReasoningStore: - def __init__(self, reasoning_content_path: str | Path) -> None: + def __init__( + self, + reasoning_content_path: str | Path, + max_age_seconds: int | None = None, + max_rows: int | None = None, + ) -> None: + self.max_age_seconds = max_age_seconds + self.max_rows = max_rows if str(reasoning_content_path) == ":memory:": self.reasoning_content_path: str | Path = ":memory:" else: @@ -110,13 +120,14 @@ class ReasoningStore: """ ) self._conn.commit() + self.prune() def close(self) -> None: with self._lock: self._conn.close() def put(self, key: str, reasoning: str, message: dict[str, Any]) -> None: - if not reasoning: + if not isinstance(reasoning, str): return message_json = json.dumps(message, ensure_ascii=False, sort_keys=True) with self._lock: @@ -131,6 +142,7 @@ class ReasoningStore: """, (key, reasoning, message_json, time.time()), ) + self._prune_locked() self._conn.commit() def get(self, key: str) -> str | None: @@ -147,7 +159,7 @@ class ReasoningStore: if message.get("role") != "assistant": return 0 reasoning = message.get("reasoning_content") - if not isinstance(reasoning, str) or not reasoning: + if not isinstance(reasoning, str): return 0 keys = [f"scope:{scope}:signature:{message_signature(message)}"] @@ -166,11 +178,11 @@ class ReasoningStore: def lookup_for_message(self, message: dict[str, Any], scope: str) -> str | None: reasoning = self.get(f"scope:{scope}:signature:{message_signature(message)}") - if reasoning: + if reasoning is not None: return reasoning for tool_call_id in tool_call_ids(message): reasoning = self.get(f"scope:{scope}:tool_call:{tool_call_id}") - if reasoning: + if reasoning is not None: return reasoning for tool_call in message.get("tool_calls") or []: if not isinstance(tool_call, dict): @@ -178,6 +190,46 @@ class ReasoningStore: reasoning = self.get( f"scope:{scope}:tool_call_signature:{tool_call_signature(tool_call)}" ) - if reasoning: + if reasoning is not None: return reasoning return None + + def clear(self) -> int: + with self._lock: + row = self._conn.execute("SELECT COUNT(*) FROM reasoning_cache").fetchone() + count = int(row[0] if row else 0) + self._conn.execute("DELETE FROM reasoning_cache") + self._conn.commit() + return count + + def prune(self) -> int: + with self._lock: + deleted = self._prune_locked() + self._conn.commit() + return deleted + + def _prune_locked(self) -> int: + deleted = 0 + if self.max_age_seconds is not None and self.max_age_seconds > 0: + cutoff = time.time() - self.max_age_seconds + cursor = self._conn.execute( + "DELETE FROM reasoning_cache WHERE created_at < ?", + (cutoff,), + ) + deleted += cursor.rowcount if cursor.rowcount != -1 else 0 + + if self.max_rows is not None and self.max_rows > 0: + cursor = self._conn.execute( + """ + DELETE FROM reasoning_cache + WHERE key NOT IN ( + SELECT key + FROM reasoning_cache + ORDER BY created_at DESC + LIMIT ? + ) + """, + (self.max_rows,), + ) + deleted += cursor.rowcount if cursor.rowcount != -1 else 0 + return deleted diff --git a/src/deepseek_cursor_proxy/server.py b/src/deepseek_cursor_proxy/server.py index 339622a..4fdce0e 100644 --- a/src/deepseek_cursor_proxy/server.py +++ b/src/deepseek_cursor_proxy/server.py @@ -29,6 +29,10 @@ from .transform import prepare_upstream_request, rewrite_response_body LOG = logging.getLogger("deepseek_cursor_proxy") +class RequestBodyTooLarge(ValueError): + pass + + class DeepSeekProxyServer(ThreadingHTTPServer): config: ProxyConfig reasoning_store: ReasoningStore @@ -102,6 +106,12 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): try: payload = self._read_json_body() + except RequestBodyTooLarge as exc: + LOG.warning( + "rejected request path=%s status=413 reason=%s", request_path, exc + ) + self._send_json(413, {"error": {"message": str(exc)}}) + return except ValueError as exc: LOG.warning( "rejected request path=%s status=400 reason=%s", request_path, exc @@ -114,28 +124,49 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): LOG.info("cursor request: %s", summarize_chat_payload(payload)) - prepared = prepare_upstream_request(payload, self.config, self.reasoning_store) + prepared = prepare_upstream_request( + payload, + self.config, + self.reasoning_store, + authorization=cursor_authorization, + ) if prepared.patched_reasoning_messages: LOG.info( "restored reasoning_content on %s assistant message(s)", prepared.patched_reasoning_messages, ) - if prepared.fallback_reasoning_messages: + if prepared.missing_reasoning_messages: LOG.warning( - "added compatibility reasoning_content placeholder on %s uncached assistant message(s)", - prepared.fallback_reasoning_messages, + "rejected request path=%s status=409 reason=missing_reasoning_content count=%s", + request_path, + prepared.missing_reasoning_messages, ) + self._send_json( + 409, + { + "error": { + "message": ( + "Missing cached DeepSeek reasoning_content for a " + "thinking-mode tool-call history. Retry the tool-call " + "turn so the proxy can capture the original reasoning." + ), + "type": "missing_reasoning_content", + "code": "missing_reasoning_content", + } + }, + ) + return if self.config.verbose: LOG.info( ( "upstream request metadata: original_model=%s upstream_model=%s " - "patched_reasoning=%s fallback_reasoning=%s %s" + "patched_reasoning=%s missing_reasoning=%s %s" ), prepared.original_model, prepared.upstream_model, prepared.patched_reasoning_messages, - prepared.fallback_reasoning_messages, + prepared.missing_reasoning_messages, summarize_chat_payload(prepared.payload), ) @@ -191,22 +222,28 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): ) if prepared.payload.get("stream"): self._proxy_streaming_response( - response, prepared.original_model, prepared.payload["messages"] + response, + prepared.original_model, + prepared.payload["messages"], + prepared.cache_namespace, ) else: self._proxy_regular_response( - response, prepared.original_model, prepared.payload["messages"] + response, + prepared.original_model, + prepared.payload["messages"], + prepared.cache_namespace, ) LOG.info( ( "request complete status=%s stream=%s elapsed_ms=%s " - "patched_reasoning=%s fallback_reasoning=%s" + "patched_reasoning=%s missing_reasoning=%s" ), upstream_status, bool(prepared.payload.get("stream")), elapsed_ms(started), prepared.patched_reasoning_messages, - prepared.fallback_reasoning_messages, + prepared.missing_reasoning_messages, ) def _cursor_authorization(self) -> str | None: @@ -217,6 +254,8 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): return f"Bearer {token.strip()}" def _send_cors_headers(self) -> None: + if not self.config.cors: + return self.send_header("Access-Control-Allow-Origin", "*") self.send_header("Access-Control-Allow-Methods", "POST, GET, OPTIONS") self.send_header( @@ -239,18 +278,37 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): def _send_models(self) -> None: created = int(time.time()) + model_ids = list( + dict.fromkeys( + [ + self.config.upstream_model, + "deepseek-v4-pro", + "deepseek-v4-flash", + ] + ) + ) models = [ { - "id": self.config.upstream_model, + "id": model_id, "object": "model", "created": created, "owned_by": "deepseek", } + for model_id in model_ids ] self._send_json(200, {"object": "list", "data": models}) def _read_json_body(self) -> dict[str, Any]: - length = int(self.headers.get("Content-Length") or 0) + try: + length = int(self.headers.get("Content-Length") or 0) + except ValueError as exc: + raise ValueError("Invalid Content-Length") from exc + if length < 0: + raise ValueError("Invalid Content-Length") + if length > self.config.max_request_body_bytes: + raise RequestBodyTooLarge( + f"Request body is too large; limit is {self.config.max_request_body_bytes} bytes" + ) raw_body = self.rfile.read(length) if not raw_body: raise ValueError("Request body is empty") @@ -293,14 +351,20 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): response: Any, original_model: str, request_messages: list[dict[str, Any]], + cache_namespace: str, ) -> None: body = read_response_body(response) try: body = rewrite_response_body( - body, original_model, self.reasoning_store, request_messages + body, + original_model, + self.reasoning_store, + request_messages, + cache_namespace, ) except (json.JSONDecodeError, UnicodeDecodeError) as exc: LOG.warning("failed to rewrite upstream JSON response: %s", exc) + log_cache_usage_from_body(body) if self.config.verbose: log_bytes("cursor response body", body) @@ -319,6 +383,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): response: Any, original_model: str, request_messages: list[dict[str, Any]], + cache_namespace: str, ) -> None: self.send_response(getattr(response, "status", 200)) self._send_cors_headers() @@ -334,7 +399,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): if self.config.cursor_display_reasoning else None ) - scope = conversation_scope(request_messages) + scope = conversation_scope(request_messages, cache_namespace) finalized = False while True: line = response.readline() @@ -388,6 +453,10 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): if isinstance(chunk, dict): accumulator.ingest_chunk(chunk) + stored = accumulator.store_finished_reasoning(self.reasoning_store, scope) + if stored: + LOG.info("stored %s streaming reasoning cache key(s)", stored) + log_cache_usage(chunk.get("usage")) if display_adapter is not None: display_adapter.rewrite_chunk(chunk) if "model" in chunk: @@ -421,7 +490,7 @@ def build_arg_parser() -> argparse.ArgumentParser: ) parser.add_argument( "--model", - help="Upstream DeepSeek model, default from config, DEEPSEEK_MODEL, or deepseek-v4-pro", + help="Fallback DeepSeek model when the request has no model, default from config, DEEPSEEK_MODEL, or deepseek-v4-pro", ) parser.add_argument( "--base-url", @@ -450,6 +519,11 @@ def build_arg_parser() -> argparse.ArgumentParser: action="store_true", help="Do not mirror reasoning_content into Cursor-visible content", ) + parser.add_argument( + "--clear-reasoning-cache", + action="store_true", + help="Clear the local reasoning_content SQLite cache and exit", + ) return parser @@ -474,6 +548,25 @@ def log_bytes(label: str, body: bytes) -> None: log_json(label, payload) +def log_cache_usage_from_body(body: bytes) -> None: + try: + payload = json.loads(body.decode("utf-8")) + except (json.JSONDecodeError, UnicodeDecodeError): + return + if isinstance(payload, dict): + log_cache_usage(payload.get("usage")) + + +def log_cache_usage(usage: Any) -> None: + if not isinstance(usage, dict): + return + hit = usage.get("prompt_cache_hit_tokens") + miss = usage.get("prompt_cache_miss_tokens") + if hit is None and miss is None: + return + LOG.info("deepseek prompt cache: hit_tokens=%s miss_tokens=%s", hit, miss) + + def sse_data(payload: dict[str, Any]) -> bytes: return ( b"data: " @@ -509,6 +602,16 @@ def read_response_body(response: Any) -> bytes: return body +def warn_if_insecure_upstream(url: str) -> None: + parsed = urlparse(url) + if parsed.scheme != "http": + return + host = parsed.hostname or "" + if host in {"127.0.0.1", "localhost", "::1"}: + return + LOG.warning("upstream base_url uses plain HTTP; bearer tokens may be exposed") + + def main(argv: list[str] | None = None) -> int: logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s" @@ -539,14 +642,24 @@ def main(argv: list[str] | None = None) -> int: if updates: config = replace(config, **updates) - store = ReasoningStore(config.reasoning_content_path) + warn_if_insecure_upstream(config.upstream_base_url) + store = ReasoningStore( + config.reasoning_content_path, + max_age_seconds=config.reasoning_cache_max_age_seconds, + max_rows=config.reasoning_cache_max_rows, + ) + if args.clear_reasoning_cache: + deleted = store.clear() + LOG.info("cleared %s reasoning cache row(s)", deleted) + store.close() + return 0 server = DeepSeekProxyServer((config.host, config.port), DeepSeekProxyHandler) server.config = config server.reasoning_store = store LOG.info("listening on http://%s:%s/v1", config.host, config.port) LOG.info( - "forwarding to %s/chat/completions as %s", + "forwarding to %s/chat/completions default_model=%s", config.upstream_base_url, config.upstream_model, ) diff --git a/src/deepseek_cursor_proxy/streaming.py b/src/deepseek_cursor_proxy/streaming.py index 9221401..b00de7e 100644 --- a/src/deepseek_cursor_proxy/streaming.py +++ b/src/deepseek_cursor_proxy/streaming.py @@ -16,6 +16,7 @@ class StreamingChoice: role: str = "assistant" content: str = "" reasoning_content: str = "" + has_reasoning_content: bool = False tool_calls: list[dict[str, Any]] = field(default_factory=list) finish_reason: str | None = None @@ -24,7 +25,7 @@ class StreamingChoice: "role": self.role, "content": self.content, } - if self.reasoning_content: + if self.has_reasoning_content: message["reasoning_content"] = self.reasoning_content if self.tool_calls: message["tool_calls"] = self.tool_calls @@ -34,6 +35,7 @@ class StreamingChoice: class StreamAccumulator: def __init__(self) -> None: self.choices: dict[int, StreamingChoice] = {} + self._stored_choices: set[int] = set() def ingest_chunk(self, chunk: dict[str, Any]) -> None: choices = chunk.get("choices") @@ -63,14 +65,22 @@ class StreamAccumulator: reasoning_content = delta.get("reasoning_content") if isinstance(reasoning_content, str): + choice.has_reasoning_content = True choice.reasoning_content += reasoning_content self._merge_tool_call_deltas(choice, delta.get("tool_calls")) def store_reasoning(self, store: ReasoningStore, scope: str) -> int: stored = 0 - for choice in self.choices.values(): - stored += store.store_assistant_message(choice.to_message(), scope) + for index, choice in self.choices.items(): + stored += self._store_choice(index, choice, store, scope) + return stored + + def store_finished_reasoning(self, store: ReasoningStore, scope: str) -> int: + stored = 0 + for index, choice in self.choices.items(): + if choice.finish_reason is not None: + stored += self._store_choice(index, choice, store, scope) return stored def messages(self) -> list[dict[str, Any]]: @@ -115,6 +125,20 @@ class StreamAccumulator: function_delta["arguments"] ) + def _store_choice( + self, + index: int, + choice: StreamingChoice, + store: ReasoningStore, + scope: str, + ) -> int: + if index in self._stored_choices: + return 0 + stored = store.store_assistant_message(choice.to_message(), scope) + if stored: + self._stored_choices.add(index) + return stored + class CursorReasoningDisplayAdapter: """Mirror reasoning_content into content for Cursor's visible thinking UI path.""" diff --git a/src/deepseek_cursor_proxy/transform.py b/src/deepseek_cursor_proxy/transform.py index 51945c6..c7177b1 100644 --- a/src/deepseek_cursor_proxy/transform.py +++ b/src/deepseek_cursor_proxy/transform.py @@ -1,6 +1,7 @@ from __future__ import annotations from dataclasses import dataclass +import hashlib import json import re from typing import Any @@ -72,8 +73,9 @@ class PreparedRequest: payload: dict[str, Any] original_model: str upstream_model: str + cache_namespace: str patched_reasoning_messages: int - fallback_reasoning_messages: int + missing_reasoning_messages: int def normalize_reasoning_effort(value: Any) -> str: @@ -158,26 +160,30 @@ def legacy_function_to_tool(function: Any) -> dict[str, Any]: def convert_function_call(function_call: Any) -> Any: if isinstance(function_call, str): - if function_call in {"auto", "none"}: + if function_call in {"auto", "none", "required"}: return function_call - if function_call == "required": - return "auto" return None if isinstance(function_call, dict) and function_call.get("name"): - return "auto" + return { + "type": "function", + "function": {"name": str(function_call["name"])}, + } return None def normalize_tool_choice(tool_choice: Any) -> Any: if isinstance(tool_choice, str): - if tool_choice in {"auto", "none"}: + if tool_choice in {"auto", "none", "required"}: return tool_choice - if tool_choice == "required": - return "auto" return None if isinstance(tool_choice, dict): if tool_choice.get("type") == "function": - return "auto" + function = tool_choice.get("function") + if isinstance(function, dict) and function.get("name"): + return { + "type": "function", + "function": {"name": str(function["name"])}, + } return tool_choice return tool_choice @@ -186,6 +192,9 @@ def normalize_message( message: Any, store: ReasoningStore | None, prior_messages: list[dict[str, Any]], + cache_namespace: str, + repair_reasoning: bool, + keep_reasoning: bool, ) -> tuple[dict[str, Any], bool, bool]: if not isinstance(message, dict): message = {"role": "user", "content": str(message)} @@ -210,49 +219,62 @@ def normalize_message( ] patched = False - fallback = False + missing = False if normalized["role"] == "assistant": - reasoning = normalized.get("reasoning_content") - if not isinstance(reasoning, str) or not reasoning: + if not keep_reasoning: normalized.pop("reasoning_content", None) - if store is not None: - restored = store.lookup_for_message( - normalized, conversation_scope(prior_messages) + elif repair_reasoning: + reasoning = normalized.get("reasoning_content") + if not isinstance(reasoning, str): + normalized.pop("reasoning_content", None) + needs_reasoning = assistant_needs_reasoning_for_tool_context( + normalized, prior_messages ) - if restored: - normalized["reasoning_content"] = restored - patched = True - if not patched and assistant_needs_reasoning_for_tool_context( - normalized, prior_messages - ): - normalized["reasoning_content"] = fallback_reasoning_content(normalized) - fallback = True + if needs_reasoning and store is not None: + restored = store.lookup_for_message( + normalized, + conversation_scope(prior_messages, cache_namespace), + ) + if restored is not None: + normalized["reasoning_content"] = restored + patched = True + if needs_reasoning and not patched: + missing = True allowed_fields = ROLE_MESSAGE_FIELDS.get(str(normalized["role"]), MESSAGE_FIELDS) normalized = { key: value for key, value in normalized.items() if key in allowed_fields } - return normalized, patched, fallback + return normalized, patched, missing def normalize_messages( - messages: Any, store: ReasoningStore | None + messages: Any, + store: ReasoningStore | None, + cache_namespace: str, + repair_reasoning: bool, + keep_reasoning: bool, ) -> tuple[list[dict[str, Any]], int, int]: if not isinstance(messages, list): return [], 0, 0 normalized_messages: list[dict[str, Any]] = [] patched_count = 0 - fallback_count = 0 + missing_count = 0 for message in messages: - normalized, patched, fallback = normalize_message( - message, store, normalized_messages + normalized, patched, missing = normalize_message( + message, + store, + normalized_messages, + cache_namespace, + repair_reasoning, + keep_reasoning, ) normalized_messages.append(normalized) if patched: patched_count += 1 - if fallback: - fallback_count += 1 - return normalized_messages, patched_count, fallback_count + if missing: + missing_count += 1 + return normalized_messages, patched_count, missing_count def assistant_needs_reasoning_for_tool_context( @@ -270,22 +292,40 @@ def assistant_needs_reasoning_for_tool_context( return False -def fallback_reasoning_content(message: dict[str, Any]) -> str: - if message.get("tool_calls"): - return "Compatibility placeholder: Cursor omitted DeepSeek reasoning_content for this tool-call turn." - return "Compatibility placeholder: Cursor omitted DeepSeek reasoning_content for this tool-result turn." - - def upstream_model_for(original_model: str, config: ProxyConfig) -> str: - if config.allow_model_passthrough and original_model.startswith("deepseek-"): + if original_model.startswith("deepseek-"): return original_model return config.upstream_model +def reasoning_cache_namespace( + config: ProxyConfig, + upstream_model: str, + thinking: Any, + reasoning_effort: Any, + authorization: str | None = None, +) -> str: + auth_hash = "" + if authorization: + auth_hash = hashlib.sha256(authorization.encode("utf-8")).hexdigest() + payload = { + "base_url": config.upstream_base_url, + "model": upstream_model, + "thinking": thinking, + "reasoning_effort": reasoning_effort, + "authorization_hash": auth_hash, + } + canonical = json.dumps( + payload, ensure_ascii=False, sort_keys=True, separators=(",", ":") + ) + return hashlib.sha256(canonical.encode("utf-8")).hexdigest() + + def prepare_upstream_request( payload: dict[str, Any], config: ProxyConfig, store: ReasoningStore | None, + authorization: str | None = None, ) -> PreparedRequest: original_model = str(payload.get("model") or config.upstream_model) upstream_model = upstream_model_for(original_model, config) @@ -297,10 +337,6 @@ def prepare_upstream_request( prepared["max_tokens"] = payload["max_completion_tokens"] prepared["model"] = upstream_model - messages, patched_count, fallback_count = normalize_messages( - payload.get("messages"), store - ) - prepared["messages"] = messages if "tools" in prepared and isinstance(prepared["tools"], list): prepared["tools"] = [normalize_tool(tool) for tool in prepared["tools"]] @@ -325,17 +361,37 @@ def prepare_upstream_request( thinking = prepared.get("thinking") thinking_enabled = isinstance(thinking, dict) and thinking.get("type") == "enabled" + thinking_disabled = ( + isinstance(thinking, dict) and thinking.get("type") == "disabled" + ) if thinking_enabled: prepared["reasoning_effort"] = normalize_reasoning_effort( prepared.get("reasoning_effort") or config.reasoning_effort ) + cache_namespace = reasoning_cache_namespace( + config, + upstream_model, + prepared.get("thinking"), + prepared.get("reasoning_effort"), + authorization, + ) + messages, patched_count, missing_count = normalize_messages( + payload.get("messages"), + store, + cache_namespace, + repair_reasoning=thinking_enabled, + keep_reasoning=not thinking_disabled, + ) + prepared["messages"] = messages + return PreparedRequest( payload=prepared, original_model=original_model, upstream_model=upstream_model, + cache_namespace=cache_namespace, patched_reasoning_messages=patched_count, - fallback_reasoning_messages=fallback_count, + missing_reasoning_messages=missing_count, ) @@ -343,6 +399,7 @@ def record_response_reasoning( response_payload: dict[str, Any], store: ReasoningStore | None, request_messages: list[dict[str, Any]], + cache_namespace: str = "", ) -> int: if store is None: return 0 @@ -350,7 +407,7 @@ def record_response_reasoning( choices = response_payload.get("choices") if not isinstance(choices, list): return stored - scope = conversation_scope(request_messages) + scope = conversation_scope(request_messages, cache_namespace) for choice in choices: if not isinstance(choice, dict): continue @@ -365,10 +422,13 @@ def rewrite_response_body( original_model: str, store: ReasoningStore | None, request_messages: list[dict[str, Any]], + cache_namespace: str = "", ) -> bytes: response_payload = json.loads(body.decode("utf-8")) if isinstance(response_payload, dict): - record_response_reasoning(response_payload, store, request_messages) + record_response_reasoning( + response_payload, store, request_messages, cache_namespace + ) if "model" in response_payload: response_payload["model"] = original_model return json.dumps( diff --git a/tests/test_config.py b/tests/test_config.py index 4394c9a..82a4d03 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -140,12 +140,20 @@ class ConfigTests(unittest.TestCase): env={ "PROXY_VERBOSE": "true", "PROXY_NGROK": "yes", + "PROXY_CORS": "true", + "PROXY_MAX_REQUEST_BODY_BYTES": "1234", + "REASONING_CACHE_MAX_AGE_SECONDS": "60", + "REASONING_CACHE_MAX_ROWS": "50", }, config_path=Path("/does/not/exist"), ) self.assertTrue(config.verbose) self.assertTrue(config.ngrok) + self.assertTrue(config.cors) + self.assertEqual(config.max_request_body_bytes, 1234) + self.assertEqual(config.reasoning_cache_max_age_seconds, 60) + self.assertEqual(config.reasoning_cache_max_rows, 50) def test_cursor_reasoning_display_can_be_disabled_from_config(self) -> None: with TemporaryDirectory() as temp_dir: diff --git a/tests/test_proxy_end_to_end.py b/tests/test_proxy_end_to_end.py index b1aa714..d45a005 100644 --- a/tests/test_proxy_end_to_end.py +++ b/tests/test_proxy_end_to_end.py @@ -16,6 +16,7 @@ from deepseek_cursor_proxy.reasoning_store import ( message_signature, ) from deepseek_cursor_proxy.server import DeepSeekProxyHandler, DeepSeekProxyServer +from deepseek_cursor_proxy.transform import reasoning_cache_namespace TOOL_REASONING = "I need the current date before answering." @@ -253,6 +254,85 @@ class ReasoningStreamingDeepSeekHandler(BaseHTTPRequestHandler): self.wfile.flush() +class ToolCallStreamingBeforeDoneDeepSeekHandler(BaseHTTPRequestHandler): + requests: list[dict] = [] + + def log_message(self, fmt: str, *args: object) -> None: + return + + def do_POST(self) -> None: + length = int(self.headers.get("Content-Length") or 0) + payload = json.loads(self.rfile.read(length).decode("utf-8")) + self.__class__.requests.append(payload) + + if payload.get("stream"): + self.send_response(200) + self.send_header("Content-Type", "text/event-stream") + self.end_headers() + chunks = [ + { + "id": "chatcmpl-stream-tool", + "object": "chat.completion.chunk", + "created": 1, + "model": "deepseek-v4-pro", + "choices": [ + { + "index": 0, + "delta": { + "role": "assistant", + "reasoning_content": "Streamed tool reasoning.", + "tool_calls": [ + { + "index": 0, + "id": "call_stream_tool", + "type": "function", + "function": { + "name": "lookup", + "arguments": "{}", + }, + } + ], + }, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl-stream-tool", + "object": "chat.completion.chunk", + "created": 1, + "model": "deepseek-v4-pro", + "choices": [ + {"index": 0, "delta": {}, "finish_reason": "tool_calls"} + ], + }, + ] + for chunk in chunks: + self.wfile.write(f"data: {json.dumps(chunk)}\n\n".encode("utf-8")) + self.wfile.flush() + time.sleep(1) + self.wfile.write(b"data: [DONE]\n\n") + self.wfile.flush() + return + + messages = payload.get("messages", []) + if ( + len(messages) >= 2 + and messages[1].get("reasoning_content") == "Streamed tool reasoning." + ): + self._send_json(200, plain_response("stream follow-up accepted")) + return + self._send_json(400, {"error": {"message": "missing streamed reasoning"}}) + + def _send_json(self, status: int, payload: dict) -> None: + body = json.dumps(payload).encode("utf-8") + self.send_response(status) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) + + def tool_call_response() -> dict: return { "id": "chatcmpl-tool", @@ -511,7 +591,21 @@ class ProxyEndToEndTests(unittest.TestCase): self.assertEqual(caught.exception.code, 401) self.assertEqual(FakeDeepSeekHandler.requests, []) - def test_proxy_adds_fallback_reasoning_for_uncached_cursor_tool_history( + def test_proxy_rejects_oversized_request_body(self) -> None: + self.proxy.server.config = replace( + self.proxy.server.config, max_request_body_bytes=10 + ) + + status, payload = post_json( + f"{self.proxy.url}/v1/chat/completions", + first_cursor_request(), + ) + + self.assertEqual(status, 413) + self.assertIn("too large", payload["error"]["message"]) + self.assertEqual(FakeDeepSeekHandler.requests, []) + + def test_proxy_rejects_uncached_cursor_tool_history_without_placeholder( self, ) -> None: status, _ = post_json( @@ -519,9 +613,8 @@ class ProxyEndToEndTests(unittest.TestCase): second_cursor_request(include_reasoning=False), ) - self.assertEqual(status, 200) - upstream_messages = FakeDeepSeekHandler.requests[0]["messages"] - self.assertIn("reasoning_content", upstream_messages[1]) + self.assertEqual(status, 409) + self.assertEqual(FakeDeepSeekHandler.requests, []) class InterleavedConversationTests(unittest.TestCase): @@ -737,10 +830,17 @@ class ReasoningStreamingProxyTests(unittest.TestCase): "content": FINAL_CONTENT, "reasoning_content": "Need context.", } + cache_namespace = reasoning_cache_namespace( + self.proxy.server.config, + "deepseek-v4-pro", + {"type": "enabled"}, + "high", + "Bearer sk-cursor-test", + ) self.assertEqual( self.store.get( "scope:" - + conversation_scope(request_messages) + + conversation_scope(request_messages, cache_namespace) + ":signature:" + message_signature(stored_message) ), @@ -748,6 +848,107 @@ class ReasoningStreamingProxyTests(unittest.TestCase): ) +class StreamingToolRaceProxyTests(unittest.TestCase): + def setUp(self) -> None: + ToolCallStreamingBeforeDoneDeepSeekHandler.requests = [] + self.upstream = ServerFixture( + ThreadingHTTPServer( + ("127.0.0.1", 0), ToolCallStreamingBeforeDoneDeepSeekHandler + ) + ).start() + self.store = ReasoningStore(":memory:") + proxy = DeepSeekProxyServer(("127.0.0.1", 0), DeepSeekProxyHandler) + proxy.config = ProxyConfig( + upstream_base_url=self.upstream.url, + upstream_model="deepseek-v4-pro", + ) + proxy.reasoning_store = self.store + self.proxy = ServerFixture(proxy).start() + + def tearDown(self) -> None: + self.proxy.close() + self.upstream.close() + self.store.close() + + def test_streaming_tool_reasoning_is_available_before_done(self) -> None: + request_messages = [{"role": "user", "content": "stream tool"}] + request = Request( + f"{self.proxy.url}/v1/chat/completions", + data=json.dumps( + { + "model": "deepseek-v4-pro", + "stream": True, + "messages": request_messages, + "tools": [ + { + "type": "function", + "function": { + "name": "lookup", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + } + ).encode("utf-8"), + method="POST", + headers={ + "Authorization": "Bearer sk-cursor-test", + "Content-Type": "application/json", + }, + ) + + with urlopen(request, timeout=3) as response: + while True: + line = response.readline().decode("utf-8") + self.assertNotEqual(line, "") + if '"finish_reason":"tool_calls"' in line: + break + + status, payload = post_json( + f"{self.proxy.url}/v1/chat/completions", + { + "model": "deepseek-v4-pro", + "messages": [ + *request_messages, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_stream_tool", + "type": "function", + "function": { + "name": "lookup", + "arguments": "{}", + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call_stream_tool", + "content": "tool result", + }, + ], + "tools": [ + { + "type": "function", + "function": { + "name": "lookup", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + }, + ) + response.read() + + self.assertEqual(status, 200, payload) + self.assertEqual( + payload["choices"][0]["message"]["content"], "stream follow-up accepted" + ) + + def first_cursor_request() -> dict: return { "model": "deepseek-v4-pro", diff --git a/tests/test_reasoning_store.py b/tests/test_reasoning_store.py index add2d18..9a00307 100644 --- a/tests/test_reasoning_store.py +++ b/tests/test_reasoning_store.py @@ -5,7 +5,7 @@ import stat from tempfile import TemporaryDirectory import unittest -from deepseek_cursor_proxy.reasoning_store import ReasoningStore +from deepseek_cursor_proxy.reasoning_store import ReasoningStore, conversation_scope class ReasoningStoreTests(unittest.TestCase): @@ -21,6 +21,50 @@ class ReasoningStoreTests(unittest.TestCase): self.assertTrue(reasoning_content_path.exists()) self.assertEqual(stat.S_IMODE(reasoning_content_path.stat().st_mode), 0o600) + def test_store_prunes_to_max_rows_and_can_clear(self) -> None: + store = ReasoningStore(":memory:", max_rows=2) + try: + store.put("a", "reasoning a", {"role": "assistant"}) + store.put("b", "reasoning b", {"role": "assistant"}) + store.put("c", "reasoning c", {"role": "assistant"}) + + self.assertIsNone(store.get("a")) + self.assertEqual(store.get("b"), "reasoning b") + self.assertEqual(store.get("c"), "reasoning c") + self.assertEqual(store.clear(), 2) + self.assertIsNone(store.get("b")) + self.assertIsNone(store.get("c")) + finally: + store.close() + + def test_empty_reasoning_content_is_stored_as_present_value(self) -> None: + store = ReasoningStore(":memory:") + try: + scope = conversation_scope([{"role": "user", "content": "lookup"}]) + tool_call = { + "id": "call_empty", + "type": "function", + "function": {"name": "lookup", "arguments": "{}"}, + } + message = { + "role": "assistant", + "content": "", + "reasoning_content": "", + "tool_calls": [tool_call], + } + + self.assertGreater(store.store_assistant_message(message, scope), 0) + self.assertEqual(store.get(f"scope:{scope}:tool_call:call_empty"), "") + self.assertEqual( + store.lookup_for_message( + {"role": "assistant", "content": "", "tool_calls": [tool_call]}, + scope, + ), + "", + ) + finally: + store.close() + if __name__ == "__main__": unittest.main() diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 01ad47d..26b7cd3 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -77,6 +77,84 @@ class StreamAccumulatorTests(unittest.TestCase): ) store.close() + def test_stores_reasoning_when_choice_finishes_before_done(self) -> None: + store = ReasoningStore(":memory:") + accumulator = StreamAccumulator() + accumulator.ingest_chunk( + { + "choices": [ + { + "index": 0, + "delta": { + "role": "assistant", + "reasoning_content": "Need a tool.", + "tool_calls": [ + { + "index": 0, + "id": "call_stream", + "type": "function", + "function": { + "name": "lookup", + "arguments": "{}", + }, + } + ], + }, + "finish_reason": "tool_calls", + } + ] + } + ) + + scope = conversation_scope([{"role": "user", "content": "lookup"}]) + stored = accumulator.store_finished_reasoning(store, scope) + + self.assertGreater(stored, 0) + self.assertEqual( + store.get(f"scope:{scope}:tool_call:call_stream"), "Need a tool." + ) + self.assertEqual(accumulator.store_reasoning(store, scope), 0) + store.close() + + def test_stores_empty_reasoning_content_when_stream_field_is_present( + self, + ) -> None: + store = ReasoningStore(":memory:") + accumulator = StreamAccumulator() + accumulator.ingest_chunk( + { + "choices": [ + { + "index": 0, + "delta": { + "role": "assistant", + "reasoning_content": "", + "tool_calls": [ + { + "index": 0, + "id": "call_empty", + "type": "function", + "function": { + "name": "lookup", + "arguments": "{}", + }, + } + ], + }, + "finish_reason": "tool_calls", + } + ] + } + ) + + scope = conversation_scope([{"role": "user", "content": "lookup"}]) + stored = accumulator.store_finished_reasoning(store, scope) + + self.assertGreater(stored, 0) + self.assertEqual(store.get(f"scope:{scope}:tool_call:call_empty"), "") + self.assertEqual(accumulator.messages()[0]["reasoning_content"], "") + store.close() + def test_returns_accumulated_messages_for_logging(self) -> None: accumulator = StreamAccumulator() accumulator.ingest_chunk( diff --git a/tests/test_transform.py b/tests/test_transform.py index 910be42..074862a 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -8,11 +8,25 @@ from deepseek_cursor_proxy.reasoning_store import ReasoningStore, conversation_s from deepseek_cursor_proxy.transform import ( extract_text_content, prepare_upstream_request, + reasoning_cache_namespace, rewrite_response_body, strip_cursor_thinking_blocks, ) +DEFAULT_CONFIG = ProxyConfig() +DEFAULT_CACHE_NAMESPACE = reasoning_cache_namespace( + DEFAULT_CONFIG, + "deepseek-v4-pro", + {"type": "enabled"}, + "high", +) + + +def cache_scope(messages: list[dict]) -> str: + return conversation_scope(messages, DEFAULT_CACHE_NAMESPACE) + + class TransformTests(unittest.TestCase): def setUp(self) -> None: self.store = ReasoningStore(":memory:") @@ -75,19 +89,30 @@ class TransformTests(unittest.TestCase): prepared = prepare_upstream_request(payload, config, self.store) self.assertEqual(prepared.original_model, "deepseek-v4-flash") - self.assertEqual(prepared.upstream_model, "deepseek-v4-pro") - self.assertEqual(prepared.payload["model"], "deepseek-v4-pro") + self.assertEqual(prepared.upstream_model, "deepseek-v4-flash") + self.assertEqual(prepared.payload["model"], "deepseek-v4-flash") self.assertEqual(prepared.payload["thinking"], {"type": "enabled"}) self.assertEqual(prepared.payload["reasoning_effort"], "high") self.assertEqual(prepared.payload["max_tokens"], 123) self.assertEqual(prepared.payload["tools"][0]["type"], "function") self.assertEqual( prepared.payload["tool_choice"], - "auto", + {"type": "function", "function": {"name": "lookup"}}, ) self.assertNotIn("parallel_tool_calls", prepared.payload) - def test_normalizes_unsupported_required_tool_choice_to_auto(self) -> None: + def test_uses_config_model_only_when_request_model_is_missing(self) -> None: + prepared = prepare_upstream_request( + {"messages": [{"role": "user", "content": "hi"}]}, + ProxyConfig(upstream_model="deepseek-v4-flash"), + self.store, + ) + + self.assertEqual(prepared.original_model, "deepseek-v4-flash") + self.assertEqual(prepared.upstream_model, "deepseek-v4-flash") + self.assertEqual(prepared.payload["model"], "deepseek-v4-flash") + + def test_preserves_required_tool_choice(self) -> None: payload = { "model": "deepseek-v4-pro", "messages": [{"role": "user", "content": "call a tool"}], @@ -97,7 +122,25 @@ class TransformTests(unittest.TestCase): prepared = prepare_upstream_request(payload, ProxyConfig(), self.store) - self.assertEqual(prepared.payload["tool_choice"], "auto") + self.assertEqual(prepared.payload["tool_choice"], "required") + + def test_preserves_named_tool_choice(self) -> None: + payload = { + "model": "deepseek-v4-pro", + "messages": [{"role": "user", "content": "call lookup"}], + "tools": [{"type": "function", "function": {"name": "lookup"}}], + "tool_choice": { + "type": "function", + "function": {"name": "lookup"}, + }, + } + + prepared = prepare_upstream_request(payload, ProxyConfig(), self.store) + + self.assertEqual( + prepared.payload["tool_choice"], + {"type": "function", "function": {"name": "lookup"}}, + ) def test_restores_reasoning_content_for_cached_tool_call(self) -> None: prior_messages = [{"role": "user", "content": "read README"}] @@ -117,7 +160,7 @@ class TransformTests(unittest.TestCase): ], } self.store.store_assistant_message( - assistant_message, conversation_scope(prior_messages) + assistant_message, cache_scope(prior_messages) ) payload = { @@ -151,6 +194,81 @@ class TransformTests(unittest.TestCase): "Need the file contents before answering.", ) + def test_accepts_empty_reasoning_content_when_present_for_tool_call( + self, + ) -> None: + payload = { + "model": "deepseek-v4-pro", + "messages": [ + {"role": "user", "content": "read README"}, + { + "role": "assistant", + "content": "", + "reasoning_content": "", + "tool_calls": [ + { + "id": "call_empty", + "type": "function", + "function": { + "name": "read_file", + "arguments": '{"path":"README.md"}', + }, + } + ], + }, + {"role": "tool", "tool_call_id": "call_empty", "content": "file text"}, + ], + } + + prepared = prepare_upstream_request(payload, ProxyConfig(), self.store) + + self.assertEqual(prepared.patched_reasoning_messages, 0) + self.assertEqual(prepared.missing_reasoning_messages, 0) + self.assertIn("reasoning_content", prepared.payload["messages"][1]) + self.assertEqual(prepared.payload["messages"][1]["reasoning_content"], "") + + def test_restores_empty_reasoning_content_from_cache(self) -> None: + prior_messages = [{"role": "user", "content": "read README"}] + tool_call = { + "id": "call_empty", + "type": "function", + "function": { + "name": "read_file", + "arguments": '{"path":"README.md"}', + }, + } + self.store.store_assistant_message( + { + "role": "assistant", + "content": "", + "reasoning_content": "", + "tool_calls": [tool_call], + }, + cache_scope(prior_messages), + ) + + prepared = prepare_upstream_request( + { + "model": "deepseek-v4-pro", + "messages": [ + *prior_messages, + {"role": "assistant", "content": "", "tool_calls": [tool_call]}, + { + "role": "tool", + "tool_call_id": "call_empty", + "content": "file text", + }, + ], + }, + ProxyConfig(), + self.store, + ) + + self.assertEqual(prepared.patched_reasoning_messages, 1) + self.assertEqual(prepared.missing_reasoning_messages, 0) + self.assertIn("reasoning_content", prepared.payload["messages"][1]) + self.assertEqual(prepared.payload["messages"][1]["reasoning_content"], "") + def test_restores_reasoning_content_for_cached_final_tool_turn_message( self, ) -> None: @@ -179,7 +297,7 @@ class TransformTests(unittest.TestCase): "reasoning_content": "The tool result is enough to answer.", } self.store.store_assistant_message( - assistant_message, conversation_scope(prior_messages) + assistant_message, cache_scope(prior_messages) ) payload = { @@ -235,8 +353,8 @@ class TransformTests(unittest.TestCase): prior_a = [{"role": "user", "content": "thread A"}] prior_b = [{"role": "user", "content": "thread B"}] - self.store.store_assistant_message(assistant_a, conversation_scope(prior_a)) - self.store.store_assistant_message(assistant_b, conversation_scope(prior_b)) + self.store.store_assistant_message(assistant_a, cache_scope(prior_a)) + self.store.store_assistant_message(assistant_b, cache_scope(prior_b)) payload_a = { "model": "deepseek-v4-pro", @@ -267,7 +385,7 @@ class TransformTests(unittest.TestCase): def test_exact_message_signature_wins_over_tool_call_id_fallback(self) -> None: prior = [{"role": "user", "content": "same conversation prefix"}] - scope = conversation_scope(prior) + scope = cache_scope(prior) first_tool_call = { "id": "call_reused", "type": "function", @@ -336,7 +454,7 @@ class TransformTests(unittest.TestCase): } ], } - self.store.store_assistant_message(assistant_message, conversation_scope(prior)) + self.store.store_assistant_message(assistant_message, cache_scope(prior)) payload = { "model": "deepseek-v4-pro", @@ -386,7 +504,7 @@ class TransformTests(unittest.TestCase): "reasoning_content": "Need to call the file tool.", "tool_calls": [tool_call], }, - conversation_scope(prior), + cache_scope(prior), ) prepared = prepare_upstream_request( @@ -412,7 +530,7 @@ class TransformTests(unittest.TestCase): "Need to call the file tool.", ) - def test_adds_fallback_reasoning_for_uncached_assistant_tool_call(self) -> None: + def test_reports_missing_reasoning_for_uncached_assistant_tool_call(self) -> None: payload = { "model": "deepseek-v4-pro", "messages": [ @@ -442,10 +560,10 @@ class TransformTests(unittest.TestCase): prepared = prepare_upstream_request(payload, ProxyConfig(), self.store) self.assertEqual(prepared.patched_reasoning_messages, 0) - self.assertEqual(prepared.fallback_reasoning_messages, 1) - self.assertIn("reasoning_content", prepared.payload["messages"][1]) + self.assertEqual(prepared.missing_reasoning_messages, 1) + self.assertNotIn("reasoning_content", prepared.payload["messages"][1]) - def test_adds_fallback_reasoning_for_uncached_assistant_after_tool_result( + def test_reports_missing_reasoning_for_uncached_assistant_after_tool_result( self, ) -> None: payload = { @@ -479,10 +597,10 @@ class TransformTests(unittest.TestCase): prepared = prepare_upstream_request(payload, ProxyConfig(), self.store) - self.assertEqual(prepared.fallback_reasoning_messages, 1) - self.assertIn("reasoning_content", prepared.payload["messages"][3]) + self.assertEqual(prepared.missing_reasoning_messages, 1) + self.assertNotIn("reasoning_content", prepared.payload["messages"][3]) - def test_does_not_add_fallback_reasoning_for_plain_chat_history(self) -> None: + def test_does_not_report_missing_reasoning_for_plain_chat_history(self) -> None: payload = { "model": "deepseek-v4-pro", "messages": [ @@ -494,7 +612,86 @@ class TransformTests(unittest.TestCase): prepared = prepare_upstream_request(payload, ProxyConfig(), self.store) - self.assertEqual(prepared.fallback_reasoning_messages, 0) + self.assertEqual(prepared.missing_reasoning_messages, 0) + self.assertNotIn("reasoning_content", prepared.payload["messages"][1]) + + def test_does_not_repair_reasoning_when_thinking_is_disabled(self) -> None: + payload = { + "model": "deepseek-v4-pro", + "messages": [ + {"role": "user", "content": "read README"}, + { + "role": "assistant", + "content": "", + "reasoning_content": "Should be removed in non-thinking mode.", + "tool_calls": [ + { + "id": "call_uncached", + "type": "function", + "function": { + "name": "read_file", + "arguments": '{"path":"README.md"}', + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call_uncached", + "content": "file text", + }, + ], + } + + prepared = prepare_upstream_request( + payload, ProxyConfig(thinking="disabled"), self.store + ) + + self.assertEqual(prepared.missing_reasoning_messages, 0) + self.assertNotIn("reasoning_content", prepared.payload["messages"][1]) + + def test_reasoning_cache_is_namespaced_by_authorization(self) -> None: + config = ProxyConfig() + prior = [{"role": "user", "content": "read README"}] + namespace_a = reasoning_cache_namespace( + config, + config.upstream_model, + {"type": "enabled"}, + "high", + "Bearer key-a", + ) + tool_call = { + "id": "call_123", + "type": "function", + "function": { + "name": "read_file", + "arguments": '{"path":"README.md"}', + }, + } + self.store.store_assistant_message( + { + "role": "assistant", + "content": "", + "reasoning_content": "Reasoning for key A.", + "tool_calls": [tool_call], + }, + conversation_scope(prior, namespace_a), + ) + + prepared = prepare_upstream_request( + { + "model": "deepseek-v4-pro", + "messages": [ + *prior, + {"role": "assistant", "content": "", "tool_calls": [tool_call]}, + ], + }, + config, + self.store, + authorization="Bearer key-b", + ) + + self.assertEqual(prepared.missing_reasoning_messages, 1) self.assertNotIn("reasoning_content", prepared.payload["messages"][1]) def test_converted_function_message_uses_tool_schema(self) -> None: @@ -561,6 +758,35 @@ class TransformTests(unittest.TestCase): "I need to inspect the repo.", ) + def test_rewrite_response_preserves_prompt_cache_usage_fields(self) -> None: + body = json.dumps( + { + "id": "chatcmpl-test", + "object": "chat.completion", + "model": "deepseek-v4-pro", + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "message": {"role": "assistant", "content": "ok"}, + } + ], + "usage": { + "prompt_tokens": 10, + "prompt_cache_hit_tokens": 6, + "prompt_cache_miss_tokens": 4, + "completion_tokens": 1, + "total_tokens": 11, + }, + } + ).encode() + + rewritten = rewrite_response_body(body, "deepseek-v4-flash", self.store, []) + payload = json.loads(rewritten) + + self.assertEqual(payload["usage"]["prompt_cache_hit_tokens"], 6) + self.assertEqual(payload["usage"]["prompt_cache_miss_tokens"], 4) + if __name__ == "__main__": unittest.main()