From cf3a9d3875591b5f17a82e86c81e114bd61ca520 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Wed, 29 Apr 2026 19:20:56 +0800 Subject: [PATCH] feat(trace): add request tracing and diagnostics (#27) --- .gitignore | 1 + README.md | 6 + src/deepseek_cursor_proxy/config.py | 1 + src/deepseek_cursor_proxy/server.py | 219 +++++++++++++++-- src/deepseek_cursor_proxy/trace.py | 318 +++++++++++++++++++++++++ src/deepseek_cursor_proxy/transform.py | 184 +++++++++++--- tests/test_config.py | 1 + tests/test_proxy_end_to_end.py | 126 ++++++++++ tests/test_server.py | 4 + tests/test_trace.py | 63 +++++ 10 files changed, 877 insertions(+), 46 deletions(-) create mode 100644 src/deepseek_cursor_proxy/trace.py create mode 100644 tests/test_trace.py diff --git a/.gitignore b/.gitignore index 6ea72f7..8966628 100644 --- a/.gitignore +++ b/.gitignore @@ -140,6 +140,7 @@ celerybeat.pid .deepseek_cursor_reasoning.sqlite3* .deepseek-cursor/ .deepseek-cursor-proxy/ +trace-dumps/ reasoning_content.sqlite3* .venv env/ diff --git a/README.md b/README.md index 01bb4e5..f5eb68e 100644 --- a/README.md +++ b/README.md @@ -168,6 +168,12 @@ Run without ngrok for local curl testing: deepseek-cursor-proxy --no-ngrok --port 9000 --verbose ``` +Capture full structured request traces for debugging: + +```bash +deepseek-cursor-proxy --verbose --trace-dir ./trace-dumps +``` + Use another config file: ```bash diff --git a/src/deepseek_cursor_proxy/config.py b/src/deepseek_cursor_proxy/config.py index aaa53c9..fb4981c 100644 --- a/src/deepseek_cursor_proxy/config.py +++ b/src/deepseek_cursor_proxy/config.py @@ -194,6 +194,7 @@ class ProxyConfig: cors: bool = DEFAULT_CORS verbose: bool = DEFAULT_VERBOSE ngrok: bool = DEFAULT_NGROK + trace_dir: Path | None = None @classmethod def from_file( diff --git a/src/deepseek_cursor_proxy/server.py b/src/deepseek_cursor_proxy/server.py index 6914bf6..3726bb4 100644 --- a/src/deepseek_cursor_proxy/server.py +++ b/src/deepseek_cursor_proxy/server.py @@ -23,6 +23,7 @@ from .config import ( ) from .reasoning_store import ReasoningStore, conversation_scope from .streaming import CursorReasoningDisplayAdapter, StreamAccumulator +from .trace import TraceRequest, TraceWriter from .tunnel import NgrokTunnel, local_tunnel_target from .transform import ( RECOVERY_NOTICE_CONTENT, @@ -41,6 +42,7 @@ class RequestBodyTooLarge(ValueError): class DeepSeekProxyServer(ThreadingHTTPServer): config: ProxyConfig reasoning_store: ReasoningStore + trace_writer: TraceWriter | None class DeepSeekProxyHandler(BaseHTTPRequestHandler): @@ -54,6 +56,10 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): def reasoning_store(self) -> ReasoningStore: return self.server.reasoning_store # type: ignore[return-value] + @property + def trace_writer(self) -> TraceWriter | None: + return getattr(self.server, "trace_writer", None) + def log_message(self, fmt: str, *args: Any) -> None: LOG.info("%s - %s", self.address_string(), fmt % args) @@ -82,6 +88,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): def do_POST(self) -> None: started = time.monotonic() request_path = urlparse(self.path).path + trace = self._start_trace(request_path) if self.config.verbose: LOG.info( "incoming POST %s from %s content_length=%s user_agent=%s", @@ -93,8 +100,11 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): if request_path not in {"/chat/completions", "/v1/chat/completions"}: LOG.warning("rejected unsupported POST path=%s status=404", request_path) self._send_json( - 404, {"error": {"message": "Only /v1/chat/completions is supported"}} + 404, + {"error": {"message": "Only /v1/chat/completions is supported"}}, + trace=trace, ) + self._finish_trace(trace, "rejected", http_status=404) return cursor_authorization = self._cursor_authorization() if cursor_authorization is None: @@ -103,8 +113,11 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): request_path, ) self._send_json( - 401, {"error": {"message": "Missing Authorization bearer token"}} + 401, + {"error": {"message": "Missing Authorization bearer token"}}, + trace=trace, ) + self._finish_trace(trace, "rejected", http_status=401) return try: @@ -113,15 +126,20 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): LOG.warning( "rejected request path=%s status=413 reason=%s", request_path, exc ) - self._send_json(413, {"error": {"message": str(exc)}}) + self._send_json(413, {"error": {"message": str(exc)}}, trace=trace) + self._finish_trace(trace, "rejected", http_status=413, reason=str(exc)) return except ValueError as exc: LOG.warning( "rejected request path=%s status=400 reason=%s", request_path, exc ) - self._send_json(400, {"error": {"message": str(exc)}}) + self._send_json(400, {"error": {"message": str(exc)}}, trace=trace) + self._finish_trace(trace, "rejected", http_status=400, reason=str(exc)) return + if trace is not None: + trace.record_cursor_body(payload) + if self.config.verbose: log_json("cursor request body", payload) @@ -133,6 +151,8 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): self.reasoning_store, authorization=cursor_authorization, ) + if trace is not None: + trace.record_transform(prepared) if prepared.patched_reasoning_messages: LOG.info( "restored reasoning_content on %s assistant message(s)", @@ -187,8 +207,11 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): "missing_reasoning_messages": prepared.missing_reasoning_messages, } }, + trace=trace, ) + self._finish_trace(trace, "rejected", http_status=409) return + LOG.info( "deepseek send: %s patched=%s recovered=%s", compact_request_stats(prepared.payload), @@ -216,14 +239,21 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): prepared.payload, ensure_ascii=False, separators=(",", ":") ).encode("utf-8") upstream_url = f"{self.config.upstream_base_url}/chat/completions" + upstream_headers = self._upstream_headers( + stream=bool(prepared.payload.get("stream")), + authorization=cursor_authorization, + ) + if trace is not None: + trace.record_upstream_request( + url=upstream_url, + headers=upstream_headers, + body_bytes=upstream_body, + ) request = Request( upstream_url, data=upstream_body, method="POST", - headers=self._upstream_headers( - stream=bool(prepared.payload.get("stream")), - authorization=cursor_authorization, - ), + headers=upstream_headers, ) try: @@ -237,7 +267,13 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): bool(prepared.payload.get("stream")), elapsed_ms(started), ) - self._send_upstream_error(exc) + self._send_upstream_error(exc, trace=trace) + self._finish_trace( + trace, + "upstream_error", + http_status=exc.code, + stream=bool(prepared.payload.get("stream")), + ) return except URLError as exc: LOG.warning( @@ -246,8 +282,11 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): exc.reason, ) self._send_json( - 502, {"error": {"message": f"Upstream request failed: {exc.reason}"}} + 502, + {"error": {"message": f"Upstream request failed: {exc.reason}"}}, + trace=trace, ) + self._finish_trace(trace, "upstream_error", http_status=502) return with response: @@ -266,6 +305,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): prepared.payload["messages"], prepared.cache_namespace, prepared.recovery_notice, + trace, ) else: sent_response = self._proxy_regular_response( @@ -274,8 +314,15 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): prepared.payload["messages"], prepared.cache_namespace, prepared.recovery_notice, + trace, ) if not sent_response: + self._finish_trace( + trace, + "client_disconnected", + http_status=upstream_status, + stream=bool(prepared.payload.get("stream")), + ) return LOG.info( ( @@ -289,6 +336,40 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): prepared.missing_reasoning_messages, prepared.recovered_reasoning_messages, ) + self._finish_trace( + trace, + "completed", + http_status=upstream_status, + stream=bool(prepared.payload.get("stream")), + ) + + def _start_trace(self, request_path: str) -> TraceRequest | None: + writer = self.trace_writer + if writer is None: + return None + try: + return writer.start_request( + method=self.command, + path=request_path, + client_address=self.client_address[0], + headers={name: value for name, value in self.headers.items()}, + ) + except OSError as exc: + LOG.warning("failed to start request trace: %s", exc) + return None + + def _finish_trace( + self, + trace: TraceRequest | None, + status: str, + **extra: Any, + ) -> None: + if trace is None: + return + try: + trace.finish(status, **extra) + except OSError as exc: + LOG.warning("failed to write request trace: %s", exc) def _cursor_authorization(self) -> str | None: auth_header = self.headers.get("Authorization", "") @@ -309,10 +390,25 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): self.send_header("Access-Control-Expose-Headers", "Content-Length") self.send_header("Access-Control-Allow-Credentials", "true") - def _send_json(self, status: int, payload: dict[str, Any]) -> None: + def _send_json( + self, + status: int, + payload: dict[str, Any], + *, + trace: TraceRequest | None = None, + ) -> None: body = json.dumps(payload, ensure_ascii=False, separators=(",", ":")).encode( "utf-8" ) + if trace is not None: + trace.record_cursor_response( + status=status, + headers={ + "Content-Type": "application/json", + "Content-Length": str(len(body)), + }, + body=body, + ) sent_headers = self._send_response_headers( status, [ @@ -414,15 +510,31 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): headers["Accept-Language"] = accept_language return headers - def _send_upstream_error(self, exc: HTTPError) -> None: + def _send_upstream_error( + self, + exc: HTTPError, + *, + trace: TraceRequest | None = None, + ) -> None: body = read_response_body(exc) if self.config.verbose: log_bytes("upstream error body", body) + headers = { + "Content-Type": exc.headers.get("Content-Type", "application/json"), + "Content-Length": str(len(body)), + } + if trace is not None: + trace.record_upstream_response( + status=exc.code, + headers={name: value for name, value in exc.headers.items()}, + body=body, + ) + trace.record_cursor_response(status=exc.code, headers=headers, body=body) sent_headers = self._send_response_headers( exc.code, [ - ("Content-Type", exc.headers.get("Content-Type", "application/json")), - ("Content-Length", str(len(body))), + ("Content-Type", headers["Content-Type"]), + ("Content-Length", headers["Content-Length"]), ], "sending upstream error headers", ) @@ -436,8 +548,10 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): request_messages: list[dict[str, Any]], cache_namespace: str, recovery_notice: str | None = None, + trace: TraceRequest | None = None, ) -> bool: body = read_response_body(response) + upstream_body = body try: body = rewrite_response_body( body, @@ -454,14 +568,34 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): if self.config.verbose: log_bytes("cursor response body", body) + headers = { + "Content-Type": response.headers.get("Content-Type", "application/json"), + "Content-Length": str(len(body)), + } + if trace is not None: + trace.record_upstream_response( + status=getattr(response, "status", 200), + headers=response_headers(response), + body=upstream_body, + stream=False, + ) + try: + upstream_payload = json.loads(upstream_body.decode("utf-8")) + except (json.JSONDecodeError, UnicodeDecodeError): + upstream_payload = None + if isinstance(upstream_payload, dict): + trace.record_usage(upstream_payload.get("usage")) + trace.record_cursor_response( + status=getattr(response, "status", 200), + headers=headers, + body=body, + ) + sent_headers = self._send_response_headers( getattr(response, "status", 200), [ - ( - "Content-Type", - response.headers.get("Content-Type", "application/json"), - ), - ("Content-Length", str(len(body))), + ("Content-Type", headers["Content-Type"]), + ("Content-Length", headers["Content-Length"]), ], "sending upstream response headers", ) @@ -476,7 +610,22 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): request_messages: list[dict[str, Any]], cache_namespace: str, recovery_notice: str | None = None, + trace: TraceRequest | None = None, ) -> bool: + if trace is not None: + trace.record_upstream_response( + status=getattr(response, "status", 200), + headers=response_headers(response), + stream=True, + ) + trace.record_cursor_response( + status=getattr(response, "status", 200), + headers={ + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "close", + }, + ) sent_headers = self._send_response_headers( getattr(response, "status", 200), [ @@ -514,7 +663,10 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): scope, display_adapter, pending_recovery_notice, + trace, ) + if trace is not None: + trace.record_stream_chunk(line, rewritten) if not self._write_to_client( rewritten, "sending streaming response chunk", flush=True ): @@ -538,6 +690,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): scope: str, display_adapter: CursorReasoningDisplayAdapter | None, recovery_notice: str | None = None, + trace: TraceRequest | None = None, ) -> tuple[bytes, bool, str | None]: stripped = line.strip() if not stripped.startswith(b"data:"): @@ -578,6 +731,8 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): stored = accumulator.store_ready_reasoning(self.reasoning_store, scope) if stored: LOG.info("stored %s streaming reasoning cache key(s)", stored) + if trace is not None: + trace.record_usage(chunk.get("usage")) log_usage(chunk.get("usage")) if display_adapter is not None: display_adapter.rewrite_chunk(chunk) @@ -653,6 +808,11 @@ def build_arg_parser() -> argparse.ArgumentParser: default=None, help="Log detailed request lifecycle metadata and full payloads", ) + parser.add_argument( + "--trace-dir", + type=Path, + help="Write full structured request traces to this directory", + ) parser.add_argument( "--display-reasoning", action=argparse.BooleanOptionalAction, @@ -891,6 +1051,13 @@ def read_response_body(response: Any) -> bytes: return body +def response_headers(response: Any) -> dict[str, str]: + headers = getattr(response, "headers", {}) + if hasattr(headers, "items"): + return {str(name): str(value) for name, value in headers.items()} + return {} + + def warn_if_insecure_upstream(url: str) -> None: parsed = urlparse(url) if parsed.scheme != "http": @@ -930,6 +1097,8 @@ def main(argv: list[str] | None = None) -> int: updates["ngrok"] = args.ngrok if args.verbose is not None: updates["verbose"] = args.verbose + if args.trace_dir is not None: + updates["trace_dir"] = args.trace_dir if args.display_reasoning is not None: updates["cursor_display_reasoning"] = args.display_reasoning if args.cors is not None: @@ -960,9 +1129,18 @@ def main(argv: list[str] | None = None) -> int: LOG.info("cleared %s reasoning cache row(s)", deleted) store.close() return 0 + trace_writer: TraceWriter | None = None + if config.trace_dir is not None: + try: + trace_writer = TraceWriter(config.trace_dir) + except OSError as exc: + LOG.error("failed to initialize trace directory: %s", exc) + store.close() + return 2 server = DeepSeekProxyServer((config.host, config.port), DeepSeekProxyHandler) server.config = config server.reasoning_store = store + server.trace_writer = trace_writer LOG.info("listening on http://%s:%s/v1", config.host, config.port) LOG.info( @@ -988,6 +1166,9 @@ def main(argv: list[str] | None = None) -> int: ) else: LOG.info("logging mode=normal metadata=safe_summaries bodies=false") + if trace_writer is not None: + LOG.info("trace session directory: %s", trace_writer.session_dir) + LOG.warning("trace logging enabled; prompts and code will be written to disk") tunnel: NgrokTunnel | None = None if config.ngrok: diff --git a/src/deepseek_cursor_proxy/trace.py b/src/deepseek_cursor_proxy/trace.py new file mode 100644 index 0000000..fa86680 --- /dev/null +++ b/src/deepseek_cursor_proxy/trace.py @@ -0,0 +1,318 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime, timezone +import hashlib +import json +import os +from pathlib import Path +import threading +import time +from typing import Any + + +TRACE_SCHEMA_VERSION = 1 + + +def utc_now_iso() -> str: + return datetime.now(timezone.utc).isoformat(timespec="milliseconds") + + +def sha256_text(value: str) -> str: + return hashlib.sha256(value.encode("utf-8")).hexdigest() + + +def authorization_summary(authorization: str | None) -> dict[str, Any]: + if not authorization: + return {"present": False} + return {"present": True, "sha256": sha256_text(authorization)} + + +def sanitized_headers(headers: dict[str, str] | None) -> dict[str, Any]: + if not headers: + return {} + sanitized: dict[str, Any] = {} + for name, value in headers.items(): + if name.lower() == "authorization": + sanitized[name] = authorization_summary(value) + else: + sanitized[name] = value + return sanitized + + +def jsonable_body(body: bytes) -> dict[str, Any]: + text = body.decode("utf-8", errors="replace") + try: + payload = json.loads(text) + except json.JSONDecodeError: + return {"text": text} + return {"json": payload} + + +def tool_names(payload: dict[str, Any]) -> list[str]: + names: list[str] = [] + tools = payload.get("tools") + if not isinstance(tools, list): + return names + for tool in tools: + if not isinstance(tool, dict): + names.append("") + continue + function = tool.get("function") + if isinstance(function, dict): + names.append(str(function.get("name") or "")) + else: + names.append("") + return names + + +def content_stats(content: Any) -> dict[str, Any]: + if content is None: + text = "" + elif isinstance(content, str): + text = content + else: + text = json.dumps(content, ensure_ascii=False, sort_keys=True) + return {"length": len(text), "sha256": sha256_text(text)} + + +def message_summaries(payload: dict[str, Any]) -> list[dict[str, Any]]: + messages = payload.get("messages") + if not isinstance(messages, list): + return [] + summaries: list[dict[str, Any]] = [] + for index, message in enumerate(messages): + if not isinstance(message, dict): + summaries.append({"index": index, "type": type(message).__name__}) + continue + tool_calls = message.get("tool_calls") + tool_call_ids = [] + if isinstance(tool_calls, list): + for tool_call in tool_calls: + if isinstance(tool_call, dict) and tool_call.get("id"): + tool_call_ids.append(str(tool_call["id"])) + reasoning = message.get("reasoning_content") + content = str(message.get("content") or "") + summary: dict[str, Any] = { + "index": index, + "role": message.get("role"), + "content": content_stats(message.get("content")), + "has_tool_calls": bool(tool_call_ids or tool_calls), + "tool_call_ids": tool_call_ids, + "tool_call_id": message.get("tool_call_id"), + "has_reasoning_content": isinstance(reasoning, str), + "reasoning_content_length": ( + len(reasoning) if isinstance(reasoning, str) else 0 + ), + "has_recovery_notice": content.startswith( + "[deepseek-cursor-proxy] Recovered" + ), + } + summaries.append(summary) + return summaries + + +def payload_summary(payload: dict[str, Any]) -> dict[str, Any]: + messages = payload.get("messages") + if not isinstance(messages, list): + messages = [] + system_hashes = [ + content_stats(message.get("content"))["sha256"] + for message in messages + if isinstance(message, dict) and message.get("role") == "system" + ] + return { + "model": payload.get("model"), + "stream": bool(payload.get("stream")), + "message_count": len(messages), + "tool_count": ( + len(payload.get("tools") or []) + if isinstance(payload.get("tools"), list) + else 0 + ), + "tool_names": tool_names(payload), + "system_prompt_hashes": system_hashes, + "messages": message_summaries(payload), + } + + +def write_json_private(path: Path, payload: dict[str, Any]) -> None: + tmp_path = path.with_name(f".{path.name}.tmp") + tmp_path.write_text( + json.dumps(payload, ensure_ascii=False, indent=2, sort_keys=True) + "\n", + encoding="utf-8", + ) + tmp_path.chmod(0o600) + tmp_path.replace(path) + path.chmod(0o600) + + +class TraceWriter: + def __init__(self, base_dir: str | Path) -> None: + self.base_dir = Path(base_dir).expanduser() + self.base_dir.mkdir(mode=0o700, parents=True, exist_ok=True) + session_name = ( + datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S.%fZ") + + f"-pid{os.getpid()}" + ) + self.session_dir = self.base_dir / session_name + self.session_dir.mkdir(mode=0o700) + self._lock = threading.Lock() + self._next_sequence = 1 + self._write_manifest() + + def start_request( + self, + *, + method: str, + path: str, + client_address: str, + headers: dict[str, str], + ) -> "TraceRequest": + with self._lock: + sequence = self._next_sequence + self._next_sequence += 1 + trace_path = self.session_dir / f"request-{sequence:06d}.json" + return TraceRequest( + writer=self, + sequence=sequence, + path=trace_path, + data={ + "schema_version": TRACE_SCHEMA_VERSION, + "sequence": sequence, + "created_at": utc_now_iso(), + "request": { + "method": method, + "path": path, + "client_address": client_address, + "headers": sanitized_headers(headers), + }, + "transform": {}, + "upstream": {}, + "cursor_response": {}, + "completion": {}, + }, + ) + + def _write_manifest(self) -> None: + write_json_private( + self.session_dir / "manifest.json", + { + "schema_version": TRACE_SCHEMA_VERSION, + "created_at": utc_now_iso(), + "pid": os.getpid(), + "base_dir": str(self.base_dir), + "session_dir": str(self.session_dir), + "format": "one JSON file per proxied POST request", + }, + ) + + +@dataclass +class TraceRequest: + writer: TraceWriter + sequence: int + path: Path + data: dict[str, Any] + _started: float = field(default_factory=time.monotonic) + _finished: bool = False + + def record_cursor_body(self, payload: dict[str, Any]) -> None: + self.data["request"]["body"] = payload + self.data["request"]["summary"] = payload_summary(payload) + + def record_transform(self, prepared: Any) -> None: + self.data["transform"] = { + "original_model": prepared.original_model, + "upstream_model": prepared.upstream_model, + "cache_namespace": prepared.cache_namespace, + "patched_reasoning_messages": prepared.patched_reasoning_messages, + "missing_reasoning_messages": prepared.missing_reasoning_messages, + "recovered_reasoning_messages": prepared.recovered_reasoning_messages, + "recovery_dropped_messages": prepared.recovery_dropped_messages, + "recovery_notice": prepared.recovery_notice, + "reasoning_diagnostics": prepared.reasoning_diagnostics, + "recovery_steps": prepared.recovery_steps, + "upstream_request_summary": payload_summary(prepared.payload), + "upstream_request_body": prepared.payload, + } + + def record_upstream_request( + self, + *, + url: str, + headers: dict[str, str], + body_bytes: bytes, + ) -> None: + self.data["upstream"]["request"] = { + "url": url, + "headers": sanitized_headers(headers), + "body_bytes": len(body_bytes), + } + + def record_upstream_response( + self, + *, + status: int, + headers: dict[str, str] | None = None, + body: bytes | None = None, + stream: bool | None = None, + ) -> None: + response: dict[str, Any] = {"status": status} + if headers is not None: + response["headers"] = sanitized_headers(headers) + if stream is not None: + response["stream"] = stream + if body is not None: + response["body"] = jsonable_body(body) + self.data["upstream"]["response"] = response + + def record_cursor_response( + self, + *, + status: int, + headers: dict[str, str] | None = None, + body: bytes | None = None, + ) -> None: + response: dict[str, Any] = {"status": status} + if headers is not None: + response["headers"] = sanitized_headers(headers) + if body is not None: + response["body"] = jsonable_body(body) + self.data["cursor_response"].update(response) + + def record_stream_chunk(self, upstream_line: bytes, cursor_line: bytes) -> None: + upstream_stream = self.data["upstream"].setdefault("stream", {"chunks": []}) + cursor_stream = self.data["cursor_response"].setdefault( + "stream", {"chunks": []} + ) + index = len(upstream_stream["chunks"]) + upstream_stream["chunks"].append( + { + "index": index, + "line": upstream_line.decode("utf-8", errors="replace"), + } + ) + cursor_stream["chunks"].append( + { + "index": index, + "line": cursor_line.decode("utf-8", errors="replace"), + } + ) + + def record_usage(self, usage: Any) -> None: + if isinstance(usage, dict): + self.data["upstream"]["usage"] = usage + + def finish(self, status: str, **extra: Any) -> None: + if self._finished: + return + completion = { + "status": status, + "finished_at": utc_now_iso(), + "elapsed_ms": round((time.monotonic() - self._started) * 1000), + } + completion.update(extra) + self.data["completion"] = completion + write_json_private(self.path, self.data) + self._finished = True diff --git a/src/deepseek_cursor_proxy/transform.py b/src/deepseek_cursor_proxy/transform.py index 7ba2eb5..7e4ac74 100644 --- a/src/deepseek_cursor_proxy/transform.py +++ b/src/deepseek_cursor_proxy/transform.py @@ -1,13 +1,19 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field import hashlib import json import re from typing import Any from .config import ProxyConfig -from .reasoning_store import ReasoningStore, conversation_scope +from .reasoning_store import ( + ReasoningStore, + conversation_scope, + message_signature, + tool_call_ids, + tool_call_signature, +) SUPPORTED_REQUEST_FIELDS = { @@ -95,6 +101,8 @@ class PreparedRequest: recovered_reasoning_messages: int = 0 recovery_dropped_messages: int = 0 recovery_notice: str | None = None + reasoning_diagnostics: list[dict[str, Any]] = field(default_factory=list) + recovery_steps: list[dict[str, Any]] = field(default_factory=list) def normalize_reasoning_effort(value: Any) -> str: @@ -214,7 +222,7 @@ def normalize_message( cache_namespace: str, repair_reasoning: bool, keep_reasoning: bool, -) -> tuple[dict[str, Any], bool, bool]: +) -> tuple[dict[str, Any], bool, bool, dict[str, Any] | None]: if not isinstance(message, dict): message = {"role": "user", "content": str(message)} normalized = {key: value for key, value in message.items() if key in MESSAGE_FIELDS} @@ -239,6 +247,7 @@ def normalize_message( patched = False missing = False + diagnostic: dict[str, Any] | None = None if normalized["role"] == "assistant": if not keep_reasoning: normalized.pop("reasoning_content", None) @@ -249,22 +258,94 @@ def normalize_message( needs_reasoning = assistant_needs_reasoning_for_tool_context( normalized, prior_messages ) + lookup_scope = conversation_scope(prior_messages, cache_namespace) + lookup_keys = ( + reasoning_lookup_keys(normalized, lookup_scope) + if needs_reasoning + else [] + ) + hit_kind = None if needs_reasoning and store is not None: - restored = store.lookup_for_message( - normalized, - conversation_scope(prior_messages, cache_namespace), - ) - if restored is not None: - normalized["reasoning_content"] = restored - patched = True + for lookup_key in lookup_keys: + restored = store.get(str(lookup_key["key"])) + if restored is not None: + lookup_key["hit"] = True + hit_kind = lookup_key["kind"] + normalized["reasoning_content"] = restored + patched = True + break if needs_reasoning and not patched: missing = True + if needs_reasoning: + diagnostic = { + "message_index": len(prior_messages), + "role": "assistant", + "needs_reasoning": True, + "had_reasoning_content": False, + "patched": patched, + "missing": missing, + "lookup_scope": lookup_scope, + "message_signature": message_signature(normalized), + "tool_call_ids": tool_call_ids(normalized), + "lookup_keys": lookup_keys, + "hit_kind": hit_kind, + } + elif assistant_needs_reasoning_for_tool_context(normalized, prior_messages): + diagnostic = { + "message_index": len(prior_messages), + "role": "assistant", + "needs_reasoning": True, + "had_reasoning_content": True, + "patched": False, + "missing": False, + "lookup_scope": conversation_scope(prior_messages, cache_namespace), + "message_signature": message_signature(normalized), + "tool_call_ids": tool_call_ids(normalized), + "lookup_keys": [], + "hit_kind": "request", + } allowed_fields = ROLE_MESSAGE_FIELDS.get(str(normalized["role"]), MESSAGE_FIELDS) normalized = { key: value for key, value in normalized.items() if key in allowed_fields } - return normalized, patched, missing + return normalized, patched, missing, diagnostic + + +def reasoning_lookup_keys( + message: dict[str, Any], + scope: str, +) -> list[dict[str, Any]]: + keys = [ + { + "kind": "message_signature", + "key": f"scope:{scope}:signature:{message_signature(message)}", + "hit": False, + } + ] + keys.extend( + { + "kind": "tool_call_id", + "tool_call_id": tool_call_id, + "key": f"scope:{scope}:tool_call:{tool_call_id}", + "hit": False, + } + for tool_call_id in tool_call_ids(message) + ) + keys.extend( + { + "kind": "tool_call_signature", + "function_name": str((tool_call.get("function") or {}).get("name") or ""), + "key": ( + f"scope:{scope}:tool_call_signature:" + f"{tool_call_signature(tool_call)}" + ), + "hit": False, + } + for tool_call in (message.get("tool_calls") or []) + if isinstance(tool_call, dict) + ) + return keys def normalize_messages( @@ -273,14 +354,15 @@ def normalize_messages( cache_namespace: str, repair_reasoning: bool, keep_reasoning: bool, -) -> tuple[list[dict[str, Any]], int, list[int]]: +) -> tuple[list[dict[str, Any]], int, list[int], list[dict[str, Any]]]: if not isinstance(messages, list): - return [], 0, [] + return [], 0, [], [] normalized_messages: list[dict[str, Any]] = [] patched_count = 0 missing_indexes: list[int] = [] + diagnostics: list[dict[str, Any]] = [] for message in messages: - normalized, patched, missing = normalize_message( + normalized, patched, missing, diagnostic = normalize_message( message, store, normalized_messages, @@ -293,7 +375,9 @@ def normalize_messages( patched_count += 1 if missing: missing_indexes.append(len(normalized_messages) - 1) - return normalized_messages, patched_count, missing_indexes + if diagnostic is not None: + diagnostics.append(diagnostic) + return normalized_messages, patched_count, missing_indexes, diagnostics def has_recovery_notice(message: dict[str, Any]) -> bool: @@ -318,7 +402,7 @@ def leading_system_messages(messages: list[dict[str, Any]]) -> list[dict[str, An def recover_messages_from_missing_reasoning( messages: list[dict[str, Any]], missing_indexes: list[int], -) -> tuple[list[dict[str, Any]], int, str | None]: +) -> tuple[list[dict[str, Any]], int, str | None, dict[str, Any]]: recovery_boundary_index = next( ( index @@ -351,7 +435,19 @@ def recover_messages_from_missing_reasoning( omitted_messages = ( recovery_boundary_index - len(leading_messages) - kept_context_messages ) - return recovered, omitted_messages, None + return ( + recovered, + omitted_messages, + None, + { + "strategy": "recovery_boundary", + "missing_indexes": missing_indexes, + "recovery_boundary_index": recovery_boundary_index, + "context_user_index": context_user_index, + "dropped_messages": omitted_messages, + "notice": None, + }, + ) last_user_index = next( ( @@ -362,13 +458,35 @@ def recover_messages_from_missing_reasoning( -1, ) if last_user_index == -1: - return messages, 0, None + return ( + messages, + 0, + None, + { + "strategy": "none", + "missing_indexes": missing_indexes, + "last_user_index": None, + "dropped_messages": 0, + "notice": None, + }, + ) recovered = leading_system_messages(messages) omitted_messages = len(messages) - len(recovered) - 1 recovered.append({"role": "system", "content": RECOVERY_SYSTEM_CONTENT}) recovered.append(messages[last_user_index]) - return recovered, omitted_messages, RECOVERY_NOTICE_CONTENT + return ( + recovered, + omitted_messages, + RECOVERY_NOTICE_CONTENT, + { + "strategy": "latest_user", + "missing_indexes": missing_indexes, + "last_user_index": last_user_index, + "dropped_messages": omitted_messages, + "notice": RECOVERY_NOTICE_CONTENT, + }, + ) def assistant_needs_reasoning_for_tool_context( @@ -478,33 +596,43 @@ def prepare_upstream_request( prepared.get("reasoning_effort"), authorization, ) - messages, patched_count, missing_indexes = normalize_messages( - payload.get("messages"), - store, - cache_namespace, - repair_reasoning=thinking_enabled, - keep_reasoning=not thinking_disabled, + messages, patched_count, missing_indexes, reasoning_diagnostics = ( + normalize_messages( + payload.get("messages"), + store, + cache_namespace, + repair_reasoning=thinking_enabled, + keep_reasoning=not thinking_disabled, + ) ) recovered_count = 0 recovery_dropped_messages = 0 recovery_notice = None + recovery_steps: list[dict[str, Any]] = [] while missing_indexes and config.missing_reasoning_strategy == "recover": - recovered_messages, dropped_messages, notice = ( + recovered_messages, dropped_messages, notice, recovery_step = ( recover_messages_from_missing_reasoning(messages, missing_indexes) ) + recovery_steps.append(recovery_step) if not dropped_messages: break recovered_count += len(missing_indexes) recovery_dropped_messages += dropped_messages if notice: recovery_notice = notice - messages, patched_count, missing_indexes = normalize_messages( + ( + messages, + patched_count, + missing_indexes, + latest_diagnostics, + ) = normalize_messages( recovered_messages, store, cache_namespace, repair_reasoning=thinking_enabled, keep_reasoning=not thinking_disabled, ) + reasoning_diagnostics.extend(latest_diagnostics) prepared["messages"] = messages return PreparedRequest( @@ -517,6 +645,8 @@ def prepare_upstream_request( recovered_reasoning_messages=recovered_count, recovery_dropped_messages=recovery_dropped_messages, recovery_notice=recovery_notice, + reasoning_diagnostics=reasoning_diagnostics, + recovery_steps=recovery_steps, ) diff --git a/tests/test_config.py b/tests/test_config.py index a303053..956d2c9 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -39,6 +39,7 @@ class ConfigTests(unittest.TestCase): home / ".deepseek-cursor-proxy" / "reasoning_content.sqlite3", ) self.assertEqual(ProxyConfig().ngrok, DEFAULT_NGROK) + self.assertIsNone(ProxyConfig().trace_dir) def test_missing_default_config_file_is_populated(self) -> None: with TemporaryDirectory() as temp_dir: diff --git a/tests/test_proxy_end_to_end.py b/tests/test_proxy_end_to_end.py index 72a18b2..739da7f 100644 --- a/tests/test_proxy_end_to_end.py +++ b/tests/test_proxy_end_to_end.py @@ -3,6 +3,8 @@ from __future__ import annotations from dataclasses import replace from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer import json +from pathlib import Path +from tempfile import TemporaryDirectory import threading import time import unittest @@ -16,6 +18,7 @@ from deepseek_cursor_proxy.reasoning_store import ( message_signature, ) from deepseek_cursor_proxy.server import DeepSeekProxyHandler, DeepSeekProxyServer +from deepseek_cursor_proxy.trace import TraceWriter from deepseek_cursor_proxy.transform import ( RECOVERY_NOTICE_CONTENT, reasoning_cache_namespace, @@ -249,6 +252,12 @@ class ReasoningStreamingDeepSeekHandler(BaseHTTPRequestHandler): "created": 1, "model": "deepseek-v4-pro", "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + "completion_tokens_details": {"reasoning_tokens": 3}, + }, }, ] for chunk in chunks: @@ -476,6 +485,17 @@ class ServerFixture: self.thread.join(timeout=5) +def read_single_trace(session_dir: Path) -> dict: + deadline = time.monotonic() + 2 + trace_files = sorted(session_dir.glob("request-*.json")) + while not trace_files and time.monotonic() < deadline: + time.sleep(0.01) + trace_files = sorted(session_dir.glob("request-*.json")) + if len(trace_files) != 1: + raise AssertionError(f"expected one trace file, found {trace_files}") + return json.loads(trace_files[0].read_text(encoding="utf-8")) + + class ProxyEndToEndTests(unittest.TestCase): def setUp(self) -> None: FakeDeepSeekHandler.requests = [] @@ -598,6 +618,47 @@ class ProxyEndToEndTests(unittest.TestCase): self.assertIn("What is tomorrow's date?", output) self.assertNotIn("sk-from-cursor", output) + def test_trace_captures_full_non_streaming_replay_without_api_key(self) -> None: + with TemporaryDirectory() as temp_dir: + writer = TraceWriter(temp_dir) + self.proxy.server.trace_writer = writer + + status, payload = post_json( + f"{self.proxy.url}/v1/chat/completions", + first_cursor_request(), + api_key="sk-from-cursor", + ) + + trace = read_single_trace(writer.session_dir) + serialized = json.dumps(trace) + + self.assertEqual(status, 200) + self.assertEqual( + payload["choices"][0]["message"]["tool_calls"][0]["id"], "call_date" + ) + self.assertEqual(trace["completion"]["status"], "completed") + self.assertEqual( + trace["request"]["body"]["messages"][0]["content"], + "What is tomorrow's date?", + ) + self.assertEqual( + trace["transform"]["upstream_request_body"]["model"], + "deepseek-v4-pro", + ) + self.assertEqual( + trace["upstream"]["response"]["body"]["json"]["choices"][0]["message"][ + "reasoning_content" + ], + TOOL_REASONING, + ) + self.assertEqual( + trace["cursor_response"]["body"]["json"]["choices"][0]["message"][ + "reasoning_content" + ], + TOOL_REASONING, + ) + self.assertNotIn("sk-from-cursor", serialized) + def test_proxy_rejects_missing_cursor_bearer_token(self) -> None: request = Request( f"{self.proxy.url}/v1/chat/completions", @@ -681,6 +742,32 @@ class ProxyEndToEndTests(unittest.TestCase): "\n".join(captured.output), ) + def test_trace_captures_recovery_diagnostics(self) -> None: + with TemporaryDirectory() as temp_dir: + writer = TraceWriter(temp_dir) + self.proxy.server.trace_writer = writer + + status, _ = post_json( + f"{self.proxy.url}/v1/chat/completions", + third_cursor_request_missing_all_reasoning(), + ) + + trace = read_single_trace(writer.session_dir) + + self.assertEqual(status, 200) + self.assertEqual(trace["transform"]["recovered_reasoning_messages"], 2) + self.assertEqual( + trace["transform"]["recovery_steps"][0]["strategy"], + "latest_user", + ) + missing_diagnostics = [ + item + for item in trace["transform"]["reasoning_diagnostics"] + if item["missing"] + ] + self.assertGreaterEqual(len(missing_diagnostics), 2) + self.assertIn("lookup_keys", missing_diagnostics[0]) + def test_proxy_keeps_deepseek_context_after_recovery_boundary(self) -> None: status, first = post_json( f"{self.proxy.url}/v1/chat/completions", @@ -948,6 +1035,45 @@ class ReasoningStreamingProxyTests(unittest.TestCase): "Need context.", ) + def test_trace_captures_streaming_replay_chunks(self) -> None: + with TemporaryDirectory() as temp_dir: + writer = TraceWriter(temp_dir) + self.proxy.server.trace_writer = writer + request = Request( + f"{self.proxy.url}/v1/chat/completions", + data=json.dumps( + { + "model": "deepseek-v4-pro", + "stream": True, + "messages": [{"role": "user", "content": "stream reasoning"}], + } + ).encode("utf-8"), + method="POST", + headers={ + "Authorization": "Bearer sk-cursor-test", + "Content-Type": "application/json", + }, + ) + + with urlopen(request, timeout=2) as response: + response.read() + + trace = read_single_trace(writer.session_dir) + + self.assertEqual(trace["completion"]["status"], "completed") + self.assertIn( + "reasoning_content", + trace["upstream"]["stream"]["chunks"][0]["line"], + ) + self.assertIn( + "", + trace["cursor_response"]["stream"]["chunks"][0]["line"], + ) + self.assertEqual( + trace["upstream"]["usage"]["completion_tokens_details"]["reasoning_tokens"], + 3, + ) + def test_streaming_recovery_notice_is_visible_in_cursor_content(self) -> None: payload = third_cursor_request_missing_all_reasoning() payload["stream"] = True diff --git a/tests/test_server.py b/tests/test_server.py index 25240f6..8bcc6a3 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -3,6 +3,7 @@ from __future__ import annotations from io import BytesIO import gzip import json +from pathlib import Path from types import SimpleNamespace import unittest import zlib @@ -80,6 +81,8 @@ class ServerTests(unittest.TestCase): "--no-verbose", "--no-display-reasoning", "--cors", + "--trace-dir", + "/tmp/dcp-traces", ] ) @@ -87,6 +90,7 @@ class ServerTests(unittest.TestCase): self.assertFalse(args.verbose) self.assertFalse(args.display_reasoning) self.assertTrue(args.cors) + self.assertEqual(args.trace_dir, Path("/tmp/dcp-traces")) def test_read_response_body_handles_gzip(self) -> None: body = gzip.compress(b'{"ok":true}') diff --git a/tests/test_trace.py b/tests/test_trace.py new file mode 100644 index 0000000..d2be77b --- /dev/null +++ b/tests/test_trace.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import json +import stat +from tempfile import TemporaryDirectory +import unittest + +from deepseek_cursor_proxy.trace import TraceWriter + + +class TraceWriterTests(unittest.TestCase): + def test_writes_manifest_and_numbered_request_files(self) -> None: + with TemporaryDirectory() as temp_dir: + writer = TraceWriter(temp_dir) + first = writer.start_request( + method="POST", + path="/v1/chat/completions", + client_address="127.0.0.1", + headers={"User-Agent": "Cursor/1.0"}, + ) + second = writer.start_request( + method="POST", + path="/v1/chat/completions", + client_address="127.0.0.1", + headers={"User-Agent": "Cursor/1.0"}, + ) + first.finish("completed", http_status=200) + second.finish("completed", http_status=200) + + self.assertTrue((writer.session_dir / "manifest.json").exists()) + self.assertTrue((writer.session_dir / "request-000001.json").exists()) + self.assertTrue((writer.session_dir / "request-000002.json").exists()) + self.assertEqual( + stat.S_IMODE( + (writer.session_dir / "request-000001.json").stat().st_mode + ), + 0o600, + ) + + def test_authorization_header_is_redacted(self) -> None: + with TemporaryDirectory() as temp_dir: + writer = TraceWriter(temp_dir) + trace = writer.start_request( + method="POST", + path="/v1/chat/completions", + client_address="127.0.0.1", + headers={"Authorization": "Bearer sk-secret"}, + ) + trace.finish("completed", http_status=200) + + payload = json.loads(trace.path.read_text(encoding="utf-8")) + serialized = json.dumps(payload) + + self.assertNotIn("sk-secret", serialized) + self.assertEqual( + payload["request"]["headers"]["Authorization"]["present"], + True, + ) + self.assertIn("sha256", payload["request"]["headers"]["Authorization"]) + + +if __name__ == "__main__": + unittest.main()