feat(trace): add request tracing and diagnostics (#27)

main
Yixing Lao 2026-04-29 19:20:56 +08:00 committed by GitHub
parent 0189b9290e
commit cf3a9d3875
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 877 additions and 46 deletions

1
.gitignore vendored
View File

@ -140,6 +140,7 @@ celerybeat.pid
.deepseek_cursor_reasoning.sqlite3* .deepseek_cursor_reasoning.sqlite3*
.deepseek-cursor/ .deepseek-cursor/
.deepseek-cursor-proxy/ .deepseek-cursor-proxy/
trace-dumps/
reasoning_content.sqlite3* reasoning_content.sqlite3*
.venv .venv
env/ env/

View File

@ -168,6 +168,12 @@ Run without ngrok for local curl testing:
deepseek-cursor-proxy --no-ngrok --port 9000 --verbose deepseek-cursor-proxy --no-ngrok --port 9000 --verbose
``` ```
Capture full structured request traces for debugging:
```bash
deepseek-cursor-proxy --verbose --trace-dir ./trace-dumps
```
Use another config file: Use another config file:
```bash ```bash

View File

@ -194,6 +194,7 @@ class ProxyConfig:
cors: bool = DEFAULT_CORS cors: bool = DEFAULT_CORS
verbose: bool = DEFAULT_VERBOSE verbose: bool = DEFAULT_VERBOSE
ngrok: bool = DEFAULT_NGROK ngrok: bool = DEFAULT_NGROK
trace_dir: Path | None = None
@classmethod @classmethod
def from_file( def from_file(

View File

@ -23,6 +23,7 @@ from .config import (
) )
from .reasoning_store import ReasoningStore, conversation_scope from .reasoning_store import ReasoningStore, conversation_scope
from .streaming import CursorReasoningDisplayAdapter, StreamAccumulator from .streaming import CursorReasoningDisplayAdapter, StreamAccumulator
from .trace import TraceRequest, TraceWriter
from .tunnel import NgrokTunnel, local_tunnel_target from .tunnel import NgrokTunnel, local_tunnel_target
from .transform import ( from .transform import (
RECOVERY_NOTICE_CONTENT, RECOVERY_NOTICE_CONTENT,
@ -41,6 +42,7 @@ class RequestBodyTooLarge(ValueError):
class DeepSeekProxyServer(ThreadingHTTPServer): class DeepSeekProxyServer(ThreadingHTTPServer):
config: ProxyConfig config: ProxyConfig
reasoning_store: ReasoningStore reasoning_store: ReasoningStore
trace_writer: TraceWriter | None
class DeepSeekProxyHandler(BaseHTTPRequestHandler): class DeepSeekProxyHandler(BaseHTTPRequestHandler):
@ -54,6 +56,10 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
def reasoning_store(self) -> ReasoningStore: def reasoning_store(self) -> ReasoningStore:
return self.server.reasoning_store # type: ignore[return-value] return self.server.reasoning_store # type: ignore[return-value]
@property
def trace_writer(self) -> TraceWriter | None:
    """Return the server's trace writer, or ``None`` when the server has none.

    ``getattr`` with a default is used so a server object that never had
    ``trace_writer`` assigned is treated the same as tracing being disabled.
    """
    return getattr(self.server, "trace_writer", None)
def log_message(self, fmt: str, *args: Any) -> None: def log_message(self, fmt: str, *args: Any) -> None:
LOG.info("%s - %s", self.address_string(), fmt % args) LOG.info("%s - %s", self.address_string(), fmt % args)
@ -82,6 +88,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
def do_POST(self) -> None: def do_POST(self) -> None:
started = time.monotonic() started = time.monotonic()
request_path = urlparse(self.path).path request_path = urlparse(self.path).path
trace = self._start_trace(request_path)
if self.config.verbose: if self.config.verbose:
LOG.info( LOG.info(
"incoming POST %s from %s content_length=%s user_agent=%s", "incoming POST %s from %s content_length=%s user_agent=%s",
@ -93,8 +100,11 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
if request_path not in {"/chat/completions", "/v1/chat/completions"}: if request_path not in {"/chat/completions", "/v1/chat/completions"}:
LOG.warning("rejected unsupported POST path=%s status=404", request_path) LOG.warning("rejected unsupported POST path=%s status=404", request_path)
self._send_json( self._send_json(
404, {"error": {"message": "Only /v1/chat/completions is supported"}} 404,
{"error": {"message": "Only /v1/chat/completions is supported"}},
trace=trace,
) )
self._finish_trace(trace, "rejected", http_status=404)
return return
cursor_authorization = self._cursor_authorization() cursor_authorization = self._cursor_authorization()
if cursor_authorization is None: if cursor_authorization is None:
@ -103,8 +113,11 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
request_path, request_path,
) )
self._send_json( self._send_json(
401, {"error": {"message": "Missing Authorization bearer token"}} 401,
{"error": {"message": "Missing Authorization bearer token"}},
trace=trace,
) )
self._finish_trace(trace, "rejected", http_status=401)
return return
try: try:
@ -113,15 +126,20 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
LOG.warning( LOG.warning(
"rejected request path=%s status=413 reason=%s", request_path, exc "rejected request path=%s status=413 reason=%s", request_path, exc
) )
self._send_json(413, {"error": {"message": str(exc)}}) self._send_json(413, {"error": {"message": str(exc)}}, trace=trace)
self._finish_trace(trace, "rejected", http_status=413, reason=str(exc))
return return
except ValueError as exc: except ValueError as exc:
LOG.warning( LOG.warning(
"rejected request path=%s status=400 reason=%s", request_path, exc "rejected request path=%s status=400 reason=%s", request_path, exc
) )
self._send_json(400, {"error": {"message": str(exc)}}) self._send_json(400, {"error": {"message": str(exc)}}, trace=trace)
self._finish_trace(trace, "rejected", http_status=400, reason=str(exc))
return return
if trace is not None:
trace.record_cursor_body(payload)
if self.config.verbose: if self.config.verbose:
log_json("cursor request body", payload) log_json("cursor request body", payload)
@ -133,6 +151,8 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
self.reasoning_store, self.reasoning_store,
authorization=cursor_authorization, authorization=cursor_authorization,
) )
if trace is not None:
trace.record_transform(prepared)
if prepared.patched_reasoning_messages: if prepared.patched_reasoning_messages:
LOG.info( LOG.info(
"restored reasoning_content on %s assistant message(s)", "restored reasoning_content on %s assistant message(s)",
@ -187,8 +207,11 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
"missing_reasoning_messages": prepared.missing_reasoning_messages, "missing_reasoning_messages": prepared.missing_reasoning_messages,
} }
}, },
trace=trace,
) )
self._finish_trace(trace, "rejected", http_status=409)
return return
LOG.info( LOG.info(
"deepseek send: %s patched=%s recovered=%s", "deepseek send: %s patched=%s recovered=%s",
compact_request_stats(prepared.payload), compact_request_stats(prepared.payload),
@ -216,14 +239,21 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
prepared.payload, ensure_ascii=False, separators=(",", ":") prepared.payload, ensure_ascii=False, separators=(",", ":")
).encode("utf-8") ).encode("utf-8")
upstream_url = f"{self.config.upstream_base_url}/chat/completions" upstream_url = f"{self.config.upstream_base_url}/chat/completions"
upstream_headers = self._upstream_headers(
stream=bool(prepared.payload.get("stream")),
authorization=cursor_authorization,
)
if trace is not None:
trace.record_upstream_request(
url=upstream_url,
headers=upstream_headers,
body_bytes=upstream_body,
)
request = Request( request = Request(
upstream_url, upstream_url,
data=upstream_body, data=upstream_body,
method="POST", method="POST",
headers=self._upstream_headers( headers=upstream_headers,
stream=bool(prepared.payload.get("stream")),
authorization=cursor_authorization,
),
) )
try: try:
@ -237,7 +267,13 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
bool(prepared.payload.get("stream")), bool(prepared.payload.get("stream")),
elapsed_ms(started), elapsed_ms(started),
) )
self._send_upstream_error(exc) self._send_upstream_error(exc, trace=trace)
self._finish_trace(
trace,
"upstream_error",
http_status=exc.code,
stream=bool(prepared.payload.get("stream")),
)
return return
except URLError as exc: except URLError as exc:
LOG.warning( LOG.warning(
@ -246,8 +282,11 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
exc.reason, exc.reason,
) )
self._send_json( self._send_json(
502, {"error": {"message": f"Upstream request failed: {exc.reason}"}} 502,
{"error": {"message": f"Upstream request failed: {exc.reason}"}},
trace=trace,
) )
self._finish_trace(trace, "upstream_error", http_status=502)
return return
with response: with response:
@ -266,6 +305,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
prepared.payload["messages"], prepared.payload["messages"],
prepared.cache_namespace, prepared.cache_namespace,
prepared.recovery_notice, prepared.recovery_notice,
trace,
) )
else: else:
sent_response = self._proxy_regular_response( sent_response = self._proxy_regular_response(
@ -274,8 +314,15 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
prepared.payload["messages"], prepared.payload["messages"],
prepared.cache_namespace, prepared.cache_namespace,
prepared.recovery_notice, prepared.recovery_notice,
trace,
) )
if not sent_response: if not sent_response:
self._finish_trace(
trace,
"client_disconnected",
http_status=upstream_status,
stream=bool(prepared.payload.get("stream")),
)
return return
LOG.info( LOG.info(
( (
@ -289,6 +336,40 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
prepared.missing_reasoning_messages, prepared.missing_reasoning_messages,
prepared.recovered_reasoning_messages, prepared.recovered_reasoning_messages,
) )
self._finish_trace(
trace,
"completed",
http_status=upstream_status,
stream=bool(prepared.payload.get("stream")),
)
def _start_trace(self, request_path: str) -> TraceRequest | None:
    """Open a trace record for the current request, if tracing is enabled.

    Args:
        request_path: Path component of the request URL.

    Returns:
        A new ``TraceRequest``, or ``None`` when no trace writer is
        configured or when starting the trace raised ``OSError``.
    """
    trace_writer = self.trace_writer
    if trace_writer is None:
        return None
    try:
        trace = trace_writer.start_request(
            method=self.command,
            path=request_path,
            client_address=self.client_address[0],
            headers=dict(self.headers.items()),
        )
    except OSError as exc:
        # Tracing is best-effort: log the failure and serve the request untraced.
        LOG.warning("failed to start request trace: %s", exc)
        return None
    return trace
def _finish_trace(
    self,
    trace: TraceRequest | None,
    status: str,
    **extra: Any,
) -> None:
    """Finalize ``trace`` with ``status`` and any extra completion fields.

    Does nothing when ``trace`` is ``None``. An ``OSError`` raised while
    finishing the trace is logged as a warning and not propagated.
    """
    if trace is not None:
        try:
            trace.finish(status, **extra)
        except OSError as exc:
            LOG.warning("failed to write request trace: %s", exc)
def _cursor_authorization(self) -> str | None: def _cursor_authorization(self) -> str | None:
auth_header = self.headers.get("Authorization", "") auth_header = self.headers.get("Authorization", "")
@ -309,10 +390,25 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
self.send_header("Access-Control-Expose-Headers", "Content-Length") self.send_header("Access-Control-Expose-Headers", "Content-Length")
self.send_header("Access-Control-Allow-Credentials", "true") self.send_header("Access-Control-Allow-Credentials", "true")
def _send_json(self, status: int, payload: dict[str, Any]) -> None: def _send_json(
self,
status: int,
payload: dict[str, Any],
*,
trace: TraceRequest | None = None,
) -> None:
body = json.dumps(payload, ensure_ascii=False, separators=(",", ":")).encode( body = json.dumps(payload, ensure_ascii=False, separators=(",", ":")).encode(
"utf-8" "utf-8"
) )
if trace is not None:
trace.record_cursor_response(
status=status,
headers={
"Content-Type": "application/json",
"Content-Length": str(len(body)),
},
body=body,
)
sent_headers = self._send_response_headers( sent_headers = self._send_response_headers(
status, status,
[ [
@ -414,15 +510,31 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
headers["Accept-Language"] = accept_language headers["Accept-Language"] = accept_language
return headers return headers
def _send_upstream_error(self, exc: HTTPError) -> None: def _send_upstream_error(
self,
exc: HTTPError,
*,
trace: TraceRequest | None = None,
) -> None:
body = read_response_body(exc) body = read_response_body(exc)
if self.config.verbose: if self.config.verbose:
log_bytes("upstream error body", body) log_bytes("upstream error body", body)
headers = {
"Content-Type": exc.headers.get("Content-Type", "application/json"),
"Content-Length": str(len(body)),
}
if trace is not None:
trace.record_upstream_response(
status=exc.code,
headers={name: value for name, value in exc.headers.items()},
body=body,
)
trace.record_cursor_response(status=exc.code, headers=headers, body=body)
sent_headers = self._send_response_headers( sent_headers = self._send_response_headers(
exc.code, exc.code,
[ [
("Content-Type", exc.headers.get("Content-Type", "application/json")), ("Content-Type", headers["Content-Type"]),
("Content-Length", str(len(body))), ("Content-Length", headers["Content-Length"]),
], ],
"sending upstream error headers", "sending upstream error headers",
) )
@ -436,8 +548,10 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
request_messages: list[dict[str, Any]], request_messages: list[dict[str, Any]],
cache_namespace: str, cache_namespace: str,
recovery_notice: str | None = None, recovery_notice: str | None = None,
trace: TraceRequest | None = None,
) -> bool: ) -> bool:
body = read_response_body(response) body = read_response_body(response)
upstream_body = body
try: try:
body = rewrite_response_body( body = rewrite_response_body(
body, body,
@ -454,14 +568,34 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
if self.config.verbose: if self.config.verbose:
log_bytes("cursor response body", body) log_bytes("cursor response body", body)
headers = {
"Content-Type": response.headers.get("Content-Type", "application/json"),
"Content-Length": str(len(body)),
}
if trace is not None:
trace.record_upstream_response(
status=getattr(response, "status", 200),
headers=response_headers(response),
body=upstream_body,
stream=False,
)
try:
upstream_payload = json.loads(upstream_body.decode("utf-8"))
except (json.JSONDecodeError, UnicodeDecodeError):
upstream_payload = None
if isinstance(upstream_payload, dict):
trace.record_usage(upstream_payload.get("usage"))
trace.record_cursor_response(
status=getattr(response, "status", 200),
headers=headers,
body=body,
)
sent_headers = self._send_response_headers( sent_headers = self._send_response_headers(
getattr(response, "status", 200), getattr(response, "status", 200),
[ [
( ("Content-Type", headers["Content-Type"]),
"Content-Type", ("Content-Length", headers["Content-Length"]),
response.headers.get("Content-Type", "application/json"),
),
("Content-Length", str(len(body))),
], ],
"sending upstream response headers", "sending upstream response headers",
) )
@ -476,7 +610,22 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
request_messages: list[dict[str, Any]], request_messages: list[dict[str, Any]],
cache_namespace: str, cache_namespace: str,
recovery_notice: str | None = None, recovery_notice: str | None = None,
trace: TraceRequest | None = None,
) -> bool: ) -> bool:
if trace is not None:
trace.record_upstream_response(
status=getattr(response, "status", 200),
headers=response_headers(response),
stream=True,
)
trace.record_cursor_response(
status=getattr(response, "status", 200),
headers={
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "close",
},
)
sent_headers = self._send_response_headers( sent_headers = self._send_response_headers(
getattr(response, "status", 200), getattr(response, "status", 200),
[ [
@ -514,7 +663,10 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
scope, scope,
display_adapter, display_adapter,
pending_recovery_notice, pending_recovery_notice,
trace,
) )
if trace is not None:
trace.record_stream_chunk(line, rewritten)
if not self._write_to_client( if not self._write_to_client(
rewritten, "sending streaming response chunk", flush=True rewritten, "sending streaming response chunk", flush=True
): ):
@ -538,6 +690,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
scope: str, scope: str,
display_adapter: CursorReasoningDisplayAdapter | None, display_adapter: CursorReasoningDisplayAdapter | None,
recovery_notice: str | None = None, recovery_notice: str | None = None,
trace: TraceRequest | None = None,
) -> tuple[bytes, bool, str | None]: ) -> tuple[bytes, bool, str | None]:
stripped = line.strip() stripped = line.strip()
if not stripped.startswith(b"data:"): if not stripped.startswith(b"data:"):
@ -578,6 +731,8 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
stored = accumulator.store_ready_reasoning(self.reasoning_store, scope) stored = accumulator.store_ready_reasoning(self.reasoning_store, scope)
if stored: if stored:
LOG.info("stored %s streaming reasoning cache key(s)", stored) LOG.info("stored %s streaming reasoning cache key(s)", stored)
if trace is not None:
trace.record_usage(chunk.get("usage"))
log_usage(chunk.get("usage")) log_usage(chunk.get("usage"))
if display_adapter is not None: if display_adapter is not None:
display_adapter.rewrite_chunk(chunk) display_adapter.rewrite_chunk(chunk)
@ -653,6 +808,11 @@ def build_arg_parser() -> argparse.ArgumentParser:
default=None, default=None,
help="Log detailed request lifecycle metadata and full payloads", help="Log detailed request lifecycle metadata and full payloads",
) )
parser.add_argument(
"--trace-dir",
type=Path,
help="Write full structured request traces to this directory",
)
parser.add_argument( parser.add_argument(
"--display-reasoning", "--display-reasoning",
action=argparse.BooleanOptionalAction, action=argparse.BooleanOptionalAction,
@ -891,6 +1051,13 @@ def read_response_body(response: Any) -> bytes:
return body return body
def response_headers(response: Any) -> dict[str, str]:
    """Return the headers of ``response`` as a plain ``str`` -> ``str`` dict.

    An object without a ``headers`` attribute, or whose ``headers`` has no
    ``items`` attribute, yields an empty dict.
    """
    raw_headers = getattr(response, "headers", {})
    if not hasattr(raw_headers, "items"):
        return {}
    collected: dict[str, str] = {}
    for header_name, header_value in raw_headers.items():
        collected[str(header_name)] = str(header_value)
    return collected
def warn_if_insecure_upstream(url: str) -> None: def warn_if_insecure_upstream(url: str) -> None:
parsed = urlparse(url) parsed = urlparse(url)
if parsed.scheme != "http": if parsed.scheme != "http":
@ -930,6 +1097,8 @@ def main(argv: list[str] | None = None) -> int:
updates["ngrok"] = args.ngrok updates["ngrok"] = args.ngrok
if args.verbose is not None: if args.verbose is not None:
updates["verbose"] = args.verbose updates["verbose"] = args.verbose
if args.trace_dir is not None:
updates["trace_dir"] = args.trace_dir
if args.display_reasoning is not None: if args.display_reasoning is not None:
updates["cursor_display_reasoning"] = args.display_reasoning updates["cursor_display_reasoning"] = args.display_reasoning
if args.cors is not None: if args.cors is not None:
@ -960,9 +1129,18 @@ def main(argv: list[str] | None = None) -> int:
LOG.info("cleared %s reasoning cache row(s)", deleted) LOG.info("cleared %s reasoning cache row(s)", deleted)
store.close() store.close()
return 0 return 0
trace_writer: TraceWriter | None = None
if config.trace_dir is not None:
try:
trace_writer = TraceWriter(config.trace_dir)
except OSError as exc:
LOG.error("failed to initialize trace directory: %s", exc)
store.close()
return 2
server = DeepSeekProxyServer((config.host, config.port), DeepSeekProxyHandler) server = DeepSeekProxyServer((config.host, config.port), DeepSeekProxyHandler)
server.config = config server.config = config
server.reasoning_store = store server.reasoning_store = store
server.trace_writer = trace_writer
LOG.info("listening on http://%s:%s/v1", config.host, config.port) LOG.info("listening on http://%s:%s/v1", config.host, config.port)
LOG.info( LOG.info(
@ -988,6 +1166,9 @@ def main(argv: list[str] | None = None) -> int:
) )
else: else:
LOG.info("logging mode=normal metadata=safe_summaries bodies=false") LOG.info("logging mode=normal metadata=safe_summaries bodies=false")
if trace_writer is not None:
LOG.info("trace session directory: %s", trace_writer.session_dir)
LOG.warning("trace logging enabled; prompts and code will be written to disk")
tunnel: NgrokTunnel | None = None tunnel: NgrokTunnel | None = None
if config.ngrok: if config.ngrok:

View File

@ -0,0 +1,318 @@
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime, timezone
import hashlib
import json
import os
from pathlib import Path
import threading
import time
from typing import Any
TRACE_SCHEMA_VERSION = 1
def utc_now_iso() -> str:
    """Return the current UTC time as an ISO 8601 string with millisecond precision."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat(timespec="milliseconds")
def sha256_text(value: str) -> str:
    """Return the hexadecimal SHA-256 digest of ``value`` encoded as UTF-8."""
    return hashlib.sha256(value.encode("utf-8")).hexdigest()


def authorization_summary(authorization: str | None) -> dict[str, Any]:
    """Describe a credential value without exposing it.

    Returns ``{"present": False}`` for a missing or empty value; otherwise
    ``{"present": True, "sha256": <hex digest of the value>}``.
    """
    if not authorization:
        return {"present": False}
    return {"present": True, "sha256": sha256_text(authorization)}


# Lowercase names of headers whose values are credentials or session secrets.
# Their values are replaced by a presence flag and digest before being traced.
_SENSITIVE_HEADER_NAMES = frozenset(
    {
        "authorization",
        "proxy-authorization",
        "cookie",
        "set-cookie",
        "x-api-key",
        "api-key",
    }
)


def sanitized_headers(headers: dict[str, str] | None) -> dict[str, Any]:
    """Return a copy of ``headers`` that is safe to store in a trace.

    Headers whose name (compared case-insensitively) is listed in
    ``_SENSITIVE_HEADER_NAMES`` have their value replaced by
    ``authorization_summary(value)``; every other header is copied unchanged.
    ``None`` or an empty mapping yields an empty dict.
    """
    if not headers:
        return {}
    return {
        name: (
            authorization_summary(value)
            if name.lower() in _SENSITIVE_HEADER_NAMES
            else value
        )
        for name, value in headers.items()
    }
def jsonable_body(body: bytes) -> dict[str, Any]:
    """Convert a raw body into a JSON-serializable wrapper.

    The bytes are decoded as UTF-8 with undecodable sequences replaced.
    Returns ``{"json": <parsed value>}`` when the text parses as JSON and
    ``{"text": <decoded text>}`` otherwise.
    """
    decoded = body.decode("utf-8", errors="replace")
    try:
        parsed = json.loads(decoded)
    except json.JSONDecodeError:
        return {"text": decoded}
    else:
        return {"json": parsed}
def tool_names(payload: dict[str, Any]) -> list[str]:
    """List the function name of every entry in ``payload["tools"]``.

    One string is produced per tool, in order. A tool that is not a dict,
    whose ``function`` is not a dict, or whose name is missing or falsy
    contributes an empty string. A non-list ``tools`` value yields ``[]``.
    """
    tools = payload.get("tools")
    if not isinstance(tools, list):
        return []
    collected: list[str] = []
    for entry in tools:
        spec = entry.get("function") if isinstance(entry, dict) else None
        label = spec.get("name") if isinstance(spec, dict) else None
        collected.append(str(label or ""))
    return collected
def content_stats(content: Any) -> dict[str, Any]:
    """Return the length and SHA-256 digest of a message's content.

    Strings are measured as-is, ``None`` is treated as the empty string, and
    any other value is first serialized to JSON with sorted keys.
    """
    if isinstance(content, str):
        rendered = content
    elif content is None:
        rendered = ""
    else:
        rendered = json.dumps(content, ensure_ascii=False, sort_keys=True)
    return {"length": len(rendered), "sha256": sha256_text(rendered)}
def message_summaries(payload: dict[str, Any]) -> list[dict[str, Any]]:
    """Summarize every entry of ``payload["messages"]`` without copying text.

    Each dict message yields its role, content length/digest, tool-call
    identifiers, reasoning-content presence and length, and whether its
    content starts with the proxy's recovery-notice prefix. A non-dict entry
    yields only its index and type name. A non-list ``messages`` yields ``[]``.
    """
    messages = payload.get("messages")
    if not isinstance(messages, list):
        return []

    def summarize(position: int, entry: Any) -> dict[str, Any]:
        if not isinstance(entry, dict):
            return {"index": position, "type": type(entry).__name__}

        calls = entry.get("tool_calls")
        if isinstance(calls, list):
            call_ids = [
                str(call["id"])
                for call in calls
                if isinstance(call, dict) and call.get("id")
            ]
        else:
            call_ids = []

        reasoning_text = entry.get("reasoning_content")
        has_reasoning = isinstance(reasoning_text, str)
        raw_content = entry.get("content")
        leading_text = str(raw_content or "")
        return {
            "index": position,
            "role": entry.get("role"),
            "content": content_stats(raw_content),
            "has_tool_calls": bool(call_ids or calls),
            "tool_call_ids": call_ids,
            "tool_call_id": entry.get("tool_call_id"),
            "has_reasoning_content": has_reasoning,
            "reasoning_content_length": len(reasoning_text) if has_reasoning else 0,
            "has_recovery_notice": leading_text.startswith(
                "[deepseek-cursor-proxy] Recovered"
            ),
        }

    return [summarize(position, entry) for position, entry in enumerate(messages)]
def payload_summary(payload: dict[str, Any]) -> dict[str, Any]:
    """Build a compact overview of a chat-completions payload.

    The overview holds the model, the stream flag, message and tool counts,
    tool names, the content digests of system messages, and a per-message
    summary. Non-list ``messages`` or ``tools`` values count as empty.
    """
    tools = payload.get("tools")
    raw_messages = payload.get("messages")
    messages = raw_messages if isinstance(raw_messages, list) else []

    system_hashes = []
    for message in messages:
        if isinstance(message, dict) and message.get("role") == "system":
            system_hashes.append(content_stats(message.get("content"))["sha256"])

    return {
        "model": payload.get("model"),
        "stream": bool(payload.get("stream")),
        "message_count": len(messages),
        "tool_count": len(tools) if isinstance(tools, list) else 0,
        "tool_names": tool_names(payload),
        "system_prompt_hashes": system_hashes,
        "messages": message_summaries(payload),
    }
def write_json_private(path: Path, payload: dict[str, Any]) -> None:
    """Write ``payload`` to ``path`` as pretty-printed JSON readable only by the owner.

    The document is serialized first, written to a hidden sibling temp file,
    and then renamed over ``path`` so readers never see a partial file.

    Args:
        path: Destination file; its parent directory must already exist.
        payload: JSON-serializable mapping to write.

    Raises:
        OSError: If the temp file cannot be created, written, or renamed.
        TypeError, ValueError: If ``payload`` is not JSON-serializable.
    """
    # Serialize before touching the filesystem so a bad payload leaves no file.
    text = json.dumps(payload, ensure_ascii=False, indent=2, sort_keys=True) + "\n"
    tmp_path = path.with_name(f".{path.name}.tmp")
    # Create the temp file with mode 0600 from the start, so its contents are
    # never exposed under the umask-default permissions.
    fd = os.open(tmp_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as handle:
            handle.write(text)
        # A stale temp file from an earlier crash keeps its old mode; reset it.
        tmp_path.chmod(0o600)
        tmp_path.replace(path)
    finally:
        # After a successful rename the temp name is gone; on failure this
        # removes the leftover temp file.
        tmp_path.unlink(missing_ok=True)
    path.chmod(0o600)
class TraceWriter:
    """Owns one trace session directory and hands out numbered request traces.

    Constructing a writer creates ``base_dir`` if needed, creates a fresh
    session subdirectory named from the current UTC timestamp and process id,
    and writes a ``manifest.json`` describing the session.
    """

    def __init__(self, base_dir: str | Path) -> None:
        """Create the session directory under ``base_dir`` and its manifest.

        Raises:
            OSError: If a directory or the manifest cannot be created.
        """
        root = Path(base_dir).expanduser()
        self.base_dir = root
        root.mkdir(mode=0o700, parents=True, exist_ok=True)

        stamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S.%fZ")
        self.session_dir = root / (stamp + f"-pid{os.getpid()}")
        # No exist_ok here: an existing session directory raises.
        self.session_dir.mkdir(mode=0o700)

        # The lock guards the sequence counter handed out by start_request.
        self._lock = threading.Lock()
        self._next_sequence = 1
        self._write_manifest()

    def start_request(
        self,
        *,
        method: str,
        path: str,
        client_address: str,
        headers: dict[str, str],
    ) -> "TraceRequest":
        """Allocate the next sequence number and return a new ``TraceRequest``.

        The returned trace targets ``request-NNNNNN.json`` in the session
        directory and is pre-filled with the request line, client address,
        and sanitized headers.
        """
        with self._lock:
            sequence = self._next_sequence
            self._next_sequence = sequence + 1

        initial_record = {
            "schema_version": TRACE_SCHEMA_VERSION,
            "sequence": sequence,
            "created_at": utc_now_iso(),
            "request": {
                "method": method,
                "path": path,
                "client_address": client_address,
                "headers": sanitized_headers(headers),
            },
            "transform": {},
            "upstream": {},
            "cursor_response": {},
            "completion": {},
        }
        return TraceRequest(
            writer=self,
            sequence=sequence,
            path=self.session_dir / f"request-{sequence:06d}.json",
            data=initial_record,
        )

    def _write_manifest(self) -> None:
        """Write ``manifest.json`` describing this session into the session directory."""
        manifest = {
            "schema_version": TRACE_SCHEMA_VERSION,
            "created_at": utc_now_iso(),
            "pid": os.getpid(),
            "base_dir": str(self.base_dir),
            "session_dir": str(self.session_dir),
            "format": "one JSON file per proxied POST request",
        }
        write_json_private(self.session_dir / "manifest.json", manifest)
@dataclass
class TraceRequest:
    """In-memory record of one proxied request, written to disk by ``finish``.

    ``data`` is the JSON document being assembled; the ``record_*`` methods
    fill in its sections and ``finish`` writes it to ``path`` once.
    """

    writer: TraceWriter
    sequence: int
    path: Path
    data: dict[str, Any]
    # Monotonic timestamp taken at construction; used for elapsed_ms.
    _started: float = field(default_factory=time.monotonic)
    # Set after the trace file has been written.
    _finished: bool = False

    def record_cursor_body(self, payload: dict[str, Any]) -> None:
        """Store the incoming request body and its summary."""
        request_section = self.data["request"]
        request_section["body"] = payload
        request_section["summary"] = payload_summary(payload)

    def record_transform(self, prepared: Any) -> None:
        """Store the prepared upstream request and its repair diagnostics."""
        copied_attributes = (
            "original_model",
            "upstream_model",
            "cache_namespace",
            "patched_reasoning_messages",
            "missing_reasoning_messages",
            "recovered_reasoning_messages",
            "recovery_dropped_messages",
            "recovery_notice",
            "reasoning_diagnostics",
            "recovery_steps",
        )
        snapshot: dict[str, Any] = {
            name: getattr(prepared, name) for name in copied_attributes
        }
        snapshot["upstream_request_summary"] = payload_summary(prepared.payload)
        snapshot["upstream_request_body"] = prepared.payload
        self.data["transform"] = snapshot

    def record_upstream_request(
        self,
        *,
        url: str,
        headers: dict[str, str],
        body_bytes: bytes,
    ) -> None:
        """Store the upstream URL, sanitized headers, and body size in bytes."""
        outgoing = {
            "url": url,
            "headers": sanitized_headers(headers),
            "body_bytes": len(body_bytes),
        }
        self.data["upstream"]["request"] = outgoing

    def record_upstream_response(
        self,
        *,
        status: int,
        headers: dict[str, str] | None = None,
        body: bytes | None = None,
        stream: bool | None = None,
    ) -> None:
        """Store the upstream response; arguments left as ``None`` are omitted."""
        entry: dict[str, Any] = {"status": status}
        if headers is not None:
            entry["headers"] = sanitized_headers(headers)
        if stream is not None:
            entry["stream"] = stream
        if body is not None:
            entry["body"] = jsonable_body(body)
        self.data["upstream"]["response"] = entry

    def record_cursor_response(
        self,
        *,
        status: int,
        headers: dict[str, str] | None = None,
        body: bytes | None = None,
    ) -> None:
        """Merge the client-facing response into the ``cursor_response`` section."""
        entry: dict[str, Any] = {"status": status}
        if headers is not None:
            entry["headers"] = sanitized_headers(headers)
        if body is not None:
            entry["body"] = jsonable_body(body)
        # update() keeps keys already present, such as recorded stream chunks.
        self.data["cursor_response"].update(entry)

    def record_stream_chunk(self, upstream_line: bytes, cursor_line: bytes) -> None:
        """Append one upstream line and its rewritten client line under a shared index."""
        upstream_chunks = self.data["upstream"].setdefault(
            "stream", {"chunks": []}
        )["chunks"]
        cursor_chunks = self.data["cursor_response"].setdefault(
            "stream", {"chunks": []}
        )["chunks"]
        position = len(upstream_chunks)
        for chunks, raw_line in (
            (upstream_chunks, upstream_line),
            (cursor_chunks, cursor_line),
        ):
            chunks.append(
                {
                    "index": position,
                    "line": raw_line.decode("utf-8", errors="replace"),
                }
            )

    def record_usage(self, usage: Any) -> None:
        """Store ``usage`` under ``upstream`` when it is a dict; ignore anything else."""
        if not isinstance(usage, dict):
            return
        self.data["upstream"]["usage"] = usage

    def finish(self, status: str, **extra: Any) -> None:
        """Record the completion section and write the trace file once.

        Calls after a successful write are ignored. ``extra`` keys are merged
        into the completion section and override the built-in ones.
        """
        if self._finished:
            return
        self.data["completion"] = {
            "status": status,
            "finished_at": utc_now_iso(),
            "elapsed_ms": round((time.monotonic() - self._started) * 1000),
            **extra,
        }
        write_json_private(self.path, self.data)
        # Only marked finished after the write succeeded.
        self._finished = True

View File

@ -1,13 +1,19 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass, field
import hashlib import hashlib
import json import json
import re import re
from typing import Any from typing import Any
from .config import ProxyConfig from .config import ProxyConfig
from .reasoning_store import ReasoningStore, conversation_scope from .reasoning_store import (
ReasoningStore,
conversation_scope,
message_signature,
tool_call_ids,
tool_call_signature,
)
SUPPORTED_REQUEST_FIELDS = { SUPPORTED_REQUEST_FIELDS = {
@ -95,6 +101,8 @@ class PreparedRequest:
recovered_reasoning_messages: int = 0 recovered_reasoning_messages: int = 0
recovery_dropped_messages: int = 0 recovery_dropped_messages: int = 0
recovery_notice: str | None = None recovery_notice: str | None = None
reasoning_diagnostics: list[dict[str, Any]] = field(default_factory=list)
recovery_steps: list[dict[str, Any]] = field(default_factory=list)
def normalize_reasoning_effort(value: Any) -> str: def normalize_reasoning_effort(value: Any) -> str:
@ -214,7 +222,7 @@ def normalize_message(
cache_namespace: str, cache_namespace: str,
repair_reasoning: bool, repair_reasoning: bool,
keep_reasoning: bool, keep_reasoning: bool,
) -> tuple[dict[str, Any], bool, bool]: ) -> tuple[dict[str, Any], bool, bool, dict[str, Any] | None]:
if not isinstance(message, dict): if not isinstance(message, dict):
message = {"role": "user", "content": str(message)} message = {"role": "user", "content": str(message)}
normalized = {key: value for key, value in message.items() if key in MESSAGE_FIELDS} normalized = {key: value for key, value in message.items() if key in MESSAGE_FIELDS}
@ -239,6 +247,7 @@ def normalize_message(
patched = False patched = False
missing = False missing = False
diagnostic: dict[str, Any] | None = None
if normalized["role"] == "assistant": if normalized["role"] == "assistant":
if not keep_reasoning: if not keep_reasoning:
normalized.pop("reasoning_content", None) normalized.pop("reasoning_content", None)
@ -249,22 +258,94 @@ def normalize_message(
needs_reasoning = assistant_needs_reasoning_for_tool_context( needs_reasoning = assistant_needs_reasoning_for_tool_context(
normalized, prior_messages normalized, prior_messages
) )
lookup_scope = conversation_scope(prior_messages, cache_namespace)
lookup_keys = (
reasoning_lookup_keys(normalized, lookup_scope)
if needs_reasoning
else []
)
hit_kind = None
if needs_reasoning and store is not None: if needs_reasoning and store is not None:
restored = store.lookup_for_message( for lookup_key in lookup_keys:
normalized, restored = store.get(str(lookup_key["key"]))
conversation_scope(prior_messages, cache_namespace), if restored is not None:
) lookup_key["hit"] = True
if restored is not None: hit_kind = lookup_key["kind"]
normalized["reasoning_content"] = restored normalized["reasoning_content"] = restored
patched = True patched = True
break
if needs_reasoning and not patched: if needs_reasoning and not patched:
missing = True missing = True
if needs_reasoning:
diagnostic = {
"message_index": len(prior_messages),
"role": "assistant",
"needs_reasoning": True,
"had_reasoning_content": False,
"patched": patched,
"missing": missing,
"lookup_scope": lookup_scope,
"message_signature": message_signature(normalized),
"tool_call_ids": tool_call_ids(normalized),
"lookup_keys": lookup_keys,
"hit_kind": hit_kind,
}
elif assistant_needs_reasoning_for_tool_context(normalized, prior_messages):
diagnostic = {
"message_index": len(prior_messages),
"role": "assistant",
"needs_reasoning": True,
"had_reasoning_content": True,
"patched": False,
"missing": False,
"lookup_scope": conversation_scope(prior_messages, cache_namespace),
"message_signature": message_signature(normalized),
"tool_call_ids": tool_call_ids(normalized),
"lookup_keys": [],
"hit_kind": "request",
}
allowed_fields = ROLE_MESSAGE_FIELDS.get(str(normalized["role"]), MESSAGE_FIELDS) allowed_fields = ROLE_MESSAGE_FIELDS.get(str(normalized["role"]), MESSAGE_FIELDS)
normalized = { normalized = {
key: value for key, value in normalized.items() if key in allowed_fields key: value for key, value in normalized.items() if key in allowed_fields
} }
return normalized, patched, missing return normalized, patched, missing, diagnostic
def reasoning_lookup_keys(
    message: dict[str, Any],
    scope: str,
) -> list[dict[str, Any]]:
    """Build the ordered list of store-key candidates for *message*.

    Every candidate is a dict holding a ``kind`` label, the fully scoped
    ``key`` string, and a ``hit`` flag that starts out ``False``.
    Candidates are emitted in this order:

    1. one ``message_signature`` entry for the whole message,
    2. one ``tool_call_id`` entry per id returned by ``tool_call_ids``,
    3. one ``tool_call_signature`` entry per dict-shaped item in
       ``message["tool_calls"]`` (non-dict items are ignored).
    """
    candidates: list[dict[str, Any]] = [
        {
            "kind": "message_signature",
            "key": f"scope:{scope}:signature:{message_signature(message)}",
            "hit": False,
        }
    ]

    for call_id in tool_call_ids(message):
        candidates.append(
            {
                "kind": "tool_call_id",
                "tool_call_id": call_id,
                "key": f"scope:{scope}:tool_call:{call_id}",
                "hit": False,
            }
        )

    # A missing or null "tool_calls" field is treated as "no tool calls".
    for call in message.get("tool_calls") or []:
        if not isinstance(call, dict):
            continue
        function_name = str((call.get("function") or {}).get("name") or "")
        signature_key = (
            f"scope:{scope}:tool_call_signature:"
            f"{tool_call_signature(call)}"
        )
        candidates.append(
            {
                "kind": "tool_call_signature",
                "function_name": function_name,
                "key": signature_key,
                "hit": False,
            }
        )

    return candidates
def normalize_messages( def normalize_messages(
@ -273,14 +354,15 @@ def normalize_messages(
cache_namespace: str, cache_namespace: str,
repair_reasoning: bool, repair_reasoning: bool,
keep_reasoning: bool, keep_reasoning: bool,
) -> tuple[list[dict[str, Any]], int, list[int]]: ) -> tuple[list[dict[str, Any]], int, list[int], list[dict[str, Any]]]:
if not isinstance(messages, list): if not isinstance(messages, list):
return [], 0, [] return [], 0, [], []
normalized_messages: list[dict[str, Any]] = [] normalized_messages: list[dict[str, Any]] = []
patched_count = 0 patched_count = 0
missing_indexes: list[int] = [] missing_indexes: list[int] = []
diagnostics: list[dict[str, Any]] = []
for message in messages: for message in messages:
normalized, patched, missing = normalize_message( normalized, patched, missing, diagnostic = normalize_message(
message, message,
store, store,
normalized_messages, normalized_messages,
@ -293,7 +375,9 @@ def normalize_messages(
patched_count += 1 patched_count += 1
if missing: if missing:
missing_indexes.append(len(normalized_messages) - 1) missing_indexes.append(len(normalized_messages) - 1)
return normalized_messages, patched_count, missing_indexes if diagnostic is not None:
diagnostics.append(diagnostic)
return normalized_messages, patched_count, missing_indexes, diagnostics
def has_recovery_notice(message: dict[str, Any]) -> bool: def has_recovery_notice(message: dict[str, Any]) -> bool:
@ -318,7 +402,7 @@ def leading_system_messages(messages: list[dict[str, Any]]) -> list[dict[str, An
def recover_messages_from_missing_reasoning( def recover_messages_from_missing_reasoning(
messages: list[dict[str, Any]], messages: list[dict[str, Any]],
missing_indexes: list[int], missing_indexes: list[int],
) -> tuple[list[dict[str, Any]], int, str | None]: ) -> tuple[list[dict[str, Any]], int, str | None, dict[str, Any]]:
recovery_boundary_index = next( recovery_boundary_index = next(
( (
index index
@ -351,7 +435,19 @@ def recover_messages_from_missing_reasoning(
omitted_messages = ( omitted_messages = (
recovery_boundary_index - len(leading_messages) - kept_context_messages recovery_boundary_index - len(leading_messages) - kept_context_messages
) )
return recovered, omitted_messages, None return (
recovered,
omitted_messages,
None,
{
"strategy": "recovery_boundary",
"missing_indexes": missing_indexes,
"recovery_boundary_index": recovery_boundary_index,
"context_user_index": context_user_index,
"dropped_messages": omitted_messages,
"notice": None,
},
)
last_user_index = next( last_user_index = next(
( (
@ -362,13 +458,35 @@ def recover_messages_from_missing_reasoning(
-1, -1,
) )
if last_user_index == -1: if last_user_index == -1:
return messages, 0, None return (
messages,
0,
None,
{
"strategy": "none",
"missing_indexes": missing_indexes,
"last_user_index": None,
"dropped_messages": 0,
"notice": None,
},
)
recovered = leading_system_messages(messages) recovered = leading_system_messages(messages)
omitted_messages = len(messages) - len(recovered) - 1 omitted_messages = len(messages) - len(recovered) - 1
recovered.append({"role": "system", "content": RECOVERY_SYSTEM_CONTENT}) recovered.append({"role": "system", "content": RECOVERY_SYSTEM_CONTENT})
recovered.append(messages[last_user_index]) recovered.append(messages[last_user_index])
return recovered, omitted_messages, RECOVERY_NOTICE_CONTENT return (
recovered,
omitted_messages,
RECOVERY_NOTICE_CONTENT,
{
"strategy": "latest_user",
"missing_indexes": missing_indexes,
"last_user_index": last_user_index,
"dropped_messages": omitted_messages,
"notice": RECOVERY_NOTICE_CONTENT,
},
)
def assistant_needs_reasoning_for_tool_context( def assistant_needs_reasoning_for_tool_context(
@ -478,33 +596,43 @@ def prepare_upstream_request(
prepared.get("reasoning_effort"), prepared.get("reasoning_effort"),
authorization, authorization,
) )
messages, patched_count, missing_indexes = normalize_messages( messages, patched_count, missing_indexes, reasoning_diagnostics = (
payload.get("messages"), normalize_messages(
store, payload.get("messages"),
cache_namespace, store,
repair_reasoning=thinking_enabled, cache_namespace,
keep_reasoning=not thinking_disabled, repair_reasoning=thinking_enabled,
keep_reasoning=not thinking_disabled,
)
) )
recovered_count = 0 recovered_count = 0
recovery_dropped_messages = 0 recovery_dropped_messages = 0
recovery_notice = None recovery_notice = None
recovery_steps: list[dict[str, Any]] = []
while missing_indexes and config.missing_reasoning_strategy == "recover": while missing_indexes and config.missing_reasoning_strategy == "recover":
recovered_messages, dropped_messages, notice = ( recovered_messages, dropped_messages, notice, recovery_step = (
recover_messages_from_missing_reasoning(messages, missing_indexes) recover_messages_from_missing_reasoning(messages, missing_indexes)
) )
recovery_steps.append(recovery_step)
if not dropped_messages: if not dropped_messages:
break break
recovered_count += len(missing_indexes) recovered_count += len(missing_indexes)
recovery_dropped_messages += dropped_messages recovery_dropped_messages += dropped_messages
if notice: if notice:
recovery_notice = notice recovery_notice = notice
messages, patched_count, missing_indexes = normalize_messages( (
messages,
patched_count,
missing_indexes,
latest_diagnostics,
) = normalize_messages(
recovered_messages, recovered_messages,
store, store,
cache_namespace, cache_namespace,
repair_reasoning=thinking_enabled, repair_reasoning=thinking_enabled,
keep_reasoning=not thinking_disabled, keep_reasoning=not thinking_disabled,
) )
reasoning_diagnostics.extend(latest_diagnostics)
prepared["messages"] = messages prepared["messages"] = messages
return PreparedRequest( return PreparedRequest(
@ -517,6 +645,8 @@ def prepare_upstream_request(
recovered_reasoning_messages=recovered_count, recovered_reasoning_messages=recovered_count,
recovery_dropped_messages=recovery_dropped_messages, recovery_dropped_messages=recovery_dropped_messages,
recovery_notice=recovery_notice, recovery_notice=recovery_notice,
reasoning_diagnostics=reasoning_diagnostics,
recovery_steps=recovery_steps,
) )

View File

@ -39,6 +39,7 @@ class ConfigTests(unittest.TestCase):
home / ".deepseek-cursor-proxy" / "reasoning_content.sqlite3", home / ".deepseek-cursor-proxy" / "reasoning_content.sqlite3",
) )
self.assertEqual(ProxyConfig().ngrok, DEFAULT_NGROK) self.assertEqual(ProxyConfig().ngrok, DEFAULT_NGROK)
self.assertIsNone(ProxyConfig().trace_dir)
def test_missing_default_config_file_is_populated(self) -> None: def test_missing_default_config_file_is_populated(self) -> None:
with TemporaryDirectory() as temp_dir: with TemporaryDirectory() as temp_dir:

View File

@ -3,6 +3,8 @@ from __future__ import annotations
from dataclasses import replace from dataclasses import replace
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
import json import json
from pathlib import Path
from tempfile import TemporaryDirectory
import threading import threading
import time import time
import unittest import unittest
@ -16,6 +18,7 @@ from deepseek_cursor_proxy.reasoning_store import (
message_signature, message_signature,
) )
from deepseek_cursor_proxy.server import DeepSeekProxyHandler, DeepSeekProxyServer from deepseek_cursor_proxy.server import DeepSeekProxyHandler, DeepSeekProxyServer
from deepseek_cursor_proxy.trace import TraceWriter
from deepseek_cursor_proxy.transform import ( from deepseek_cursor_proxy.transform import (
RECOVERY_NOTICE_CONTENT, RECOVERY_NOTICE_CONTENT,
reasoning_cache_namespace, reasoning_cache_namespace,
@ -249,6 +252,12 @@ class ReasoningStreamingDeepSeekHandler(BaseHTTPRequestHandler):
"created": 1, "created": 1,
"model": "deepseek-v4-pro", "model": "deepseek-v4-pro",
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15,
"completion_tokens_details": {"reasoning_tokens": 3},
},
}, },
] ]
for chunk in chunks: for chunk in chunks:
@ -476,6 +485,17 @@ class ServerFixture:
self.thread.join(timeout=5) self.thread.join(timeout=5)
def read_single_trace(session_dir: Path, timeout: float = 2.0) -> dict:
    """Wait for exactly one ``request-*.json`` trace file and return it parsed.

    The directory is polled every 10 ms until a trace file shows up or
    *timeout* seconds have elapsed.

    Args:
        session_dir: Directory expected to receive the trace file.
        timeout: Maximum number of seconds to wait. Defaults to 2 seconds.

    Returns:
        The decoded JSON content of the single trace file.

    Raises:
        AssertionError: If no trace file appears before the deadline, or if
            more than one trace file is present.
        json.JSONDecodeError: If the single trace file still does not contain
            valid JSON once the deadline has passed.
    """
    deadline = time.monotonic() + timeout
    while True:
        trace_files = sorted(session_dir.glob("request-*.json"))
        if len(trace_files) > 1:
            raise AssertionError(f"expected one trace file, found {trace_files}")
        if trace_files:
            try:
                return json.loads(trace_files[0].read_text(encoding="utf-8"))
            except json.JSONDecodeError:
                # The file may be visible before the writer has finished
                # flushing it (assumption: the writer is not shown here), so
                # keep retrying until the deadline instead of failing at once.
                if time.monotonic() >= deadline:
                    raise
        elif time.monotonic() >= deadline:
            raise AssertionError(f"expected one trace file, found {trace_files}")
        time.sleep(0.01)
class ProxyEndToEndTests(unittest.TestCase): class ProxyEndToEndTests(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
FakeDeepSeekHandler.requests = [] FakeDeepSeekHandler.requests = []
@ -598,6 +618,47 @@ class ProxyEndToEndTests(unittest.TestCase):
self.assertIn("What is tomorrow's date?", output) self.assertIn("What is tomorrow's date?", output)
self.assertNotIn("sk-from-cursor", output) self.assertNotIn("sk-from-cursor", output)
def test_trace_captures_full_non_streaming_replay_without_api_key(self) -> None:
    """Trace a non-streaming request end to end and check the key is absent.

    The written trace must contain the incoming request body, the
    transformed upstream request, and both response bodies, while the
    bearer token sent by the client must not appear anywhere in it.
    """
    with TemporaryDirectory() as trace_root:
        trace_writer = TraceWriter(trace_root)
        self.proxy.server.trace_writer = trace_writer

        status, payload = post_json(
            f"{self.proxy.url}/v1/chat/completions",
            first_cursor_request(),
            api_key="sk-from-cursor",
        )
        trace = read_single_trace(trace_writer.session_dir)
        trace_text = json.dumps(trace)

        self.assertEqual(status, 200)
        returned_call = payload["choices"][0]["message"]["tool_calls"][0]
        self.assertEqual(returned_call["id"], "call_date")
        self.assertEqual(trace["completion"]["status"], "completed")

        first_request_message = trace["request"]["body"]["messages"][0]
        self.assertEqual(
            first_request_message["content"],
            "What is tomorrow's date?",
        )

        upstream_request = trace["transform"]["upstream_request_body"]
        self.assertEqual(upstream_request["model"], "deepseek-v4-pro")

        upstream_body = trace["upstream"]["response"]["body"]["json"]
        self.assertEqual(
            upstream_body["choices"][0]["message"]["reasoning_content"],
            TOOL_REASONING,
        )

        cursor_body = trace["cursor_response"]["body"]["json"]
        self.assertEqual(
            cursor_body["choices"][0]["message"]["reasoning_content"],
            TOOL_REASONING,
        )

        # The client's API key must never be persisted in the trace.
        self.assertNotIn("sk-from-cursor", trace_text)
def test_proxy_rejects_missing_cursor_bearer_token(self) -> None: def test_proxy_rejects_missing_cursor_bearer_token(self) -> None:
request = Request( request = Request(
f"{self.proxy.url}/v1/chat/completions", f"{self.proxy.url}/v1/chat/completions",
@ -681,6 +742,32 @@ class ProxyEndToEndTests(unittest.TestCase):
"\n".join(captured.output), "\n".join(captured.output),
) )
def test_trace_captures_recovery_diagnostics(self) -> None:
    """Check that a recovered request records its recovery details.

    The request has no reasoning content, so the trace is expected to show
    two recovered messages, a first recovery step using the
    ``latest_user`` strategy, and at least two diagnostics flagged missing.
    """
    with TemporaryDirectory() as trace_root:
        trace_writer = TraceWriter(trace_root)
        self.proxy.server.trace_writer = trace_writer

        status, _ = post_json(
            f"{self.proxy.url}/v1/chat/completions",
            third_cursor_request_missing_all_reasoning(),
        )
        trace = read_single_trace(trace_writer.session_dir)

        self.assertEqual(status, 200)
        self.assertEqual(trace["transform"]["recovered_reasoning_messages"], 2)

        first_step = trace["transform"]["recovery_steps"][0]
        self.assertEqual(first_step["strategy"], "latest_user")

        unresolved = []
        for entry in trace["transform"]["reasoning_diagnostics"]:
            if entry["missing"]:
                unresolved.append(entry)
        self.assertGreaterEqual(len(unresolved), 2)
        self.assertIn("lookup_keys", unresolved[0])
def test_proxy_keeps_deepseek_context_after_recovery_boundary(self) -> None: def test_proxy_keeps_deepseek_context_after_recovery_boundary(self) -> None:
status, first = post_json( status, first = post_json(
f"{self.proxy.url}/v1/chat/completions", f"{self.proxy.url}/v1/chat/completions",
@ -948,6 +1035,45 @@ class ReasoningStreamingProxyTests(unittest.TestCase):
"Need context.", "Need context.",
) )
def test_trace_captures_streaming_replay_chunks(self) -> None:
    """Check that a streaming request records both chunk streams and usage.

    The first upstream chunk must mention ``reasoning_content``, the first
    chunk sent back to the client must contain ``<think>``, and the
    upstream usage must report three reasoning tokens.
    """
    request_body = {
        "model": "deepseek-v4-pro",
        "stream": True,
        "messages": [{"role": "user", "content": "stream reasoning"}],
    }
    request_headers = {
        "Authorization": "Bearer sk-cursor-test",
        "Content-Type": "application/json",
    }

    with TemporaryDirectory() as trace_root:
        trace_writer = TraceWriter(trace_root)
        self.proxy.server.trace_writer = trace_writer

        request = Request(
            f"{self.proxy.url}/v1/chat/completions",
            data=json.dumps(request_body).encode("utf-8"),
            method="POST",
            headers=request_headers,
        )
        # Drain the whole stream before looking for the trace file.
        with urlopen(request, timeout=2) as response:
            response.read()
        trace = read_single_trace(trace_writer.session_dir)

        self.assertEqual(trace["completion"]["status"], "completed")

        first_upstream_chunk = trace["upstream"]["stream"]["chunks"][0]
        self.assertIn("reasoning_content", first_upstream_chunk["line"])

        first_cursor_chunk = trace["cursor_response"]["stream"]["chunks"][0]
        self.assertIn("<think>", first_cursor_chunk["line"])

        token_details = trace["upstream"]["usage"]["completion_tokens_details"]
        self.assertEqual(token_details["reasoning_tokens"], 3)
def test_streaming_recovery_notice_is_visible_in_cursor_content(self) -> None: def test_streaming_recovery_notice_is_visible_in_cursor_content(self) -> None:
payload = third_cursor_request_missing_all_reasoning() payload = third_cursor_request_missing_all_reasoning()
payload["stream"] = True payload["stream"] = True

View File

@ -3,6 +3,7 @@ from __future__ import annotations
from io import BytesIO from io import BytesIO
import gzip import gzip
import json import json
from pathlib import Path
from types import SimpleNamespace from types import SimpleNamespace
import unittest import unittest
import zlib import zlib
@ -80,6 +81,8 @@ class ServerTests(unittest.TestCase):
"--no-verbose", "--no-verbose",
"--no-display-reasoning", "--no-display-reasoning",
"--cors", "--cors",
"--trace-dir",
"/tmp/dcp-traces",
] ]
) )
@ -87,6 +90,7 @@ class ServerTests(unittest.TestCase):
self.assertFalse(args.verbose) self.assertFalse(args.verbose)
self.assertFalse(args.display_reasoning) self.assertFalse(args.display_reasoning)
self.assertTrue(args.cors) self.assertTrue(args.cors)
self.assertEqual(args.trace_dir, Path("/tmp/dcp-traces"))
def test_read_response_body_handles_gzip(self) -> None: def test_read_response_body_handles_gzip(self) -> None:
body = gzip.compress(b'{"ok":true}') body = gzip.compress(b'{"ok":true}')

63
tests/test_trace.py Normal file
View File

@ -0,0 +1,63 @@
from __future__ import annotations
import json
import stat
from tempfile import TemporaryDirectory
import unittest
from deepseek_cursor_proxy.trace import TraceWriter
class TraceWriterTests(unittest.TestCase):
    """Unit tests for the files that ``TraceWriter`` puts on disk."""

    @staticmethod
    def _begin_request(writer, headers):
        """Start a traced POST to the chat-completions path from localhost."""
        return writer.start_request(
            method="POST",
            path="/v1/chat/completions",
            client_address="127.0.0.1",
            headers=headers,
        )

    def test_writes_manifest_and_numbered_request_files(self) -> None:
        """Two finished requests yield a manifest and two numbered files.

        The first request file must also have mode ``0o600``.
        """
        with TemporaryDirectory() as trace_root:
            writer = TraceWriter(trace_root)
            earlier = self._begin_request(writer, {"User-Agent": "Cursor/1.0"})
            later = self._begin_request(writer, {"User-Agent": "Cursor/1.0"})
            earlier.finish("completed", http_status=200)
            later.finish("completed", http_status=200)

            expected_names = (
                "manifest.json",
                "request-000001.json",
                "request-000002.json",
            )
            for file_name in expected_names:
                self.assertTrue((writer.session_dir / file_name).exists())

            first_request_file = writer.session_dir / "request-000001.json"
            permission_bits = stat.S_IMODE(first_request_file.stat().st_mode)
            self.assertEqual(permission_bits, 0o600)

    def test_authorization_header_is_redacted(self) -> None:
        """The bearer token is not stored; a presence flag and hash are."""
        with TemporaryDirectory() as trace_root:
            writer = TraceWriter(trace_root)
            trace = self._begin_request(
                writer, {"Authorization": "Bearer sk-secret"}
            )
            trace.finish("completed", http_status=200)

            payload = json.loads(trace.path.read_text(encoding="utf-8"))
            serialized = json.dumps(payload)
            recorded_header = payload["request"]["headers"]["Authorization"]

            self.assertNotIn("sk-secret", serialized)
            self.assertEqual(recorded_header["present"], True)
            self.assertIn("sha256", recorded_header)
if __name__ == "__main__":
unittest.main()