feat(trace): add request tracing and diagnostics (#27)
parent
0189b9290e
commit
cf3a9d3875
|
|
@ -140,6 +140,7 @@ celerybeat.pid
|
|||
.deepseek_cursor_reasoning.sqlite3*
|
||||
.deepseek-cursor/
|
||||
.deepseek-cursor-proxy/
|
||||
trace-dumps/
|
||||
reasoning_content.sqlite3*
|
||||
.venv
|
||||
env/
|
||||
|
|
|
|||
|
|
@ -168,6 +168,12 @@ Run without ngrok for local curl testing:
|
|||
deepseek-cursor-proxy --no-ngrok --port 9000 --verbose
|
||||
```
|
||||
|
||||
Capture full structured request traces for debugging:
|
||||
|
||||
```bash
|
||||
deepseek-cursor-proxy --verbose --trace-dir ./trace-dumps
|
||||
```
|
||||
|
||||
Use another config file:
|
||||
|
||||
```bash
|
||||
|
|
|
|||
|
|
@ -194,6 +194,7 @@ class ProxyConfig:
|
|||
cors: bool = DEFAULT_CORS
|
||||
verbose: bool = DEFAULT_VERBOSE
|
||||
ngrok: bool = DEFAULT_NGROK
|
||||
trace_dir: Path | None = None
|
||||
|
||||
@classmethod
|
||||
def from_file(
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from .config import (
|
|||
)
|
||||
from .reasoning_store import ReasoningStore, conversation_scope
|
||||
from .streaming import CursorReasoningDisplayAdapter, StreamAccumulator
|
||||
from .trace import TraceRequest, TraceWriter
|
||||
from .tunnel import NgrokTunnel, local_tunnel_target
|
||||
from .transform import (
|
||||
RECOVERY_NOTICE_CONTENT,
|
||||
|
|
@ -41,6 +42,7 @@ class RequestBodyTooLarge(ValueError):
|
|||
class DeepSeekProxyServer(ThreadingHTTPServer):
|
||||
config: ProxyConfig
|
||||
reasoning_store: ReasoningStore
|
||||
trace_writer: TraceWriter | None
|
||||
|
||||
|
||||
class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
||||
|
|
@ -54,6 +56,10 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
def reasoning_store(self) -> ReasoningStore:
|
||||
return self.server.reasoning_store # type: ignore[return-value]
|
||||
|
||||
@property
|
||||
def trace_writer(self) -> TraceWriter | None:
|
||||
return getattr(self.server, "trace_writer", None)
|
||||
|
||||
def log_message(self, fmt: str, *args: Any) -> None:
|
||||
LOG.info("%s - %s", self.address_string(), fmt % args)
|
||||
|
||||
|
|
@ -82,6 +88,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
def do_POST(self) -> None:
|
||||
started = time.monotonic()
|
||||
request_path = urlparse(self.path).path
|
||||
trace = self._start_trace(request_path)
|
||||
if self.config.verbose:
|
||||
LOG.info(
|
||||
"incoming POST %s from %s content_length=%s user_agent=%s",
|
||||
|
|
@ -93,8 +100,11 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
if request_path not in {"/chat/completions", "/v1/chat/completions"}:
|
||||
LOG.warning("rejected unsupported POST path=%s status=404", request_path)
|
||||
self._send_json(
|
||||
404, {"error": {"message": "Only /v1/chat/completions is supported"}}
|
||||
404,
|
||||
{"error": {"message": "Only /v1/chat/completions is supported"}},
|
||||
trace=trace,
|
||||
)
|
||||
self._finish_trace(trace, "rejected", http_status=404)
|
||||
return
|
||||
cursor_authorization = self._cursor_authorization()
|
||||
if cursor_authorization is None:
|
||||
|
|
@ -103,8 +113,11 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
request_path,
|
||||
)
|
||||
self._send_json(
|
||||
401, {"error": {"message": "Missing Authorization bearer token"}}
|
||||
401,
|
||||
{"error": {"message": "Missing Authorization bearer token"}},
|
||||
trace=trace,
|
||||
)
|
||||
self._finish_trace(trace, "rejected", http_status=401)
|
||||
return
|
||||
|
||||
try:
|
||||
|
|
@ -113,15 +126,20 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
LOG.warning(
|
||||
"rejected request path=%s status=413 reason=%s", request_path, exc
|
||||
)
|
||||
self._send_json(413, {"error": {"message": str(exc)}})
|
||||
self._send_json(413, {"error": {"message": str(exc)}}, trace=trace)
|
||||
self._finish_trace(trace, "rejected", http_status=413, reason=str(exc))
|
||||
return
|
||||
except ValueError as exc:
|
||||
LOG.warning(
|
||||
"rejected request path=%s status=400 reason=%s", request_path, exc
|
||||
)
|
||||
self._send_json(400, {"error": {"message": str(exc)}})
|
||||
self._send_json(400, {"error": {"message": str(exc)}}, trace=trace)
|
||||
self._finish_trace(trace, "rejected", http_status=400, reason=str(exc))
|
||||
return
|
||||
|
||||
if trace is not None:
|
||||
trace.record_cursor_body(payload)
|
||||
|
||||
if self.config.verbose:
|
||||
log_json("cursor request body", payload)
|
||||
|
||||
|
|
@ -133,6 +151,8 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
self.reasoning_store,
|
||||
authorization=cursor_authorization,
|
||||
)
|
||||
if trace is not None:
|
||||
trace.record_transform(prepared)
|
||||
if prepared.patched_reasoning_messages:
|
||||
LOG.info(
|
||||
"restored reasoning_content on %s assistant message(s)",
|
||||
|
|
@ -187,8 +207,11 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
"missing_reasoning_messages": prepared.missing_reasoning_messages,
|
||||
}
|
||||
},
|
||||
trace=trace,
|
||||
)
|
||||
self._finish_trace(trace, "rejected", http_status=409)
|
||||
return
|
||||
|
||||
LOG.info(
|
||||
"deepseek send: %s patched=%s recovered=%s",
|
||||
compact_request_stats(prepared.payload),
|
||||
|
|
@ -216,14 +239,21 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
prepared.payload, ensure_ascii=False, separators=(",", ":")
|
||||
).encode("utf-8")
|
||||
upstream_url = f"{self.config.upstream_base_url}/chat/completions"
|
||||
upstream_headers = self._upstream_headers(
|
||||
stream=bool(prepared.payload.get("stream")),
|
||||
authorization=cursor_authorization,
|
||||
)
|
||||
if trace is not None:
|
||||
trace.record_upstream_request(
|
||||
url=upstream_url,
|
||||
headers=upstream_headers,
|
||||
body_bytes=upstream_body,
|
||||
)
|
||||
request = Request(
|
||||
upstream_url,
|
||||
data=upstream_body,
|
||||
method="POST",
|
||||
headers=self._upstream_headers(
|
||||
stream=bool(prepared.payload.get("stream")),
|
||||
authorization=cursor_authorization,
|
||||
),
|
||||
headers=upstream_headers,
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
@ -237,7 +267,13 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
bool(prepared.payload.get("stream")),
|
||||
elapsed_ms(started),
|
||||
)
|
||||
self._send_upstream_error(exc)
|
||||
self._send_upstream_error(exc, trace=trace)
|
||||
self._finish_trace(
|
||||
trace,
|
||||
"upstream_error",
|
||||
http_status=exc.code,
|
||||
stream=bool(prepared.payload.get("stream")),
|
||||
)
|
||||
return
|
||||
except URLError as exc:
|
||||
LOG.warning(
|
||||
|
|
@ -246,8 +282,11 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
exc.reason,
|
||||
)
|
||||
self._send_json(
|
||||
502, {"error": {"message": f"Upstream request failed: {exc.reason}"}}
|
||||
502,
|
||||
{"error": {"message": f"Upstream request failed: {exc.reason}"}},
|
||||
trace=trace,
|
||||
)
|
||||
self._finish_trace(trace, "upstream_error", http_status=502)
|
||||
return
|
||||
|
||||
with response:
|
||||
|
|
@ -266,6 +305,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
prepared.payload["messages"],
|
||||
prepared.cache_namespace,
|
||||
prepared.recovery_notice,
|
||||
trace,
|
||||
)
|
||||
else:
|
||||
sent_response = self._proxy_regular_response(
|
||||
|
|
@ -274,8 +314,15 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
prepared.payload["messages"],
|
||||
prepared.cache_namespace,
|
||||
prepared.recovery_notice,
|
||||
trace,
|
||||
)
|
||||
if not sent_response:
|
||||
self._finish_trace(
|
||||
trace,
|
||||
"client_disconnected",
|
||||
http_status=upstream_status,
|
||||
stream=bool(prepared.payload.get("stream")),
|
||||
)
|
||||
return
|
||||
LOG.info(
|
||||
(
|
||||
|
|
@ -289,6 +336,40 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
prepared.missing_reasoning_messages,
|
||||
prepared.recovered_reasoning_messages,
|
||||
)
|
||||
self._finish_trace(
|
||||
trace,
|
||||
"completed",
|
||||
http_status=upstream_status,
|
||||
stream=bool(prepared.payload.get("stream")),
|
||||
)
|
||||
|
||||
def _start_trace(self, request_path: str) -> TraceRequest | None:
|
||||
writer = self.trace_writer
|
||||
if writer is None:
|
||||
return None
|
||||
try:
|
||||
return writer.start_request(
|
||||
method=self.command,
|
||||
path=request_path,
|
||||
client_address=self.client_address[0],
|
||||
headers={name: value for name, value in self.headers.items()},
|
||||
)
|
||||
except OSError as exc:
|
||||
LOG.warning("failed to start request trace: %s", exc)
|
||||
return None
|
||||
|
||||
def _finish_trace(
|
||||
self,
|
||||
trace: TraceRequest | None,
|
||||
status: str,
|
||||
**extra: Any,
|
||||
) -> None:
|
||||
if trace is None:
|
||||
return
|
||||
try:
|
||||
trace.finish(status, **extra)
|
||||
except OSError as exc:
|
||||
LOG.warning("failed to write request trace: %s", exc)
|
||||
|
||||
def _cursor_authorization(self) -> str | None:
|
||||
auth_header = self.headers.get("Authorization", "")
|
||||
|
|
@ -309,10 +390,25 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
self.send_header("Access-Control-Expose-Headers", "Content-Length")
|
||||
self.send_header("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
def _send_json(self, status: int, payload: dict[str, Any]) -> None:
|
||||
def _send_json(
|
||||
self,
|
||||
status: int,
|
||||
payload: dict[str, Any],
|
||||
*,
|
||||
trace: TraceRequest | None = None,
|
||||
) -> None:
|
||||
body = json.dumps(payload, ensure_ascii=False, separators=(",", ":")).encode(
|
||||
"utf-8"
|
||||
)
|
||||
if trace is not None:
|
||||
trace.record_cursor_response(
|
||||
status=status,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Content-Length": str(len(body)),
|
||||
},
|
||||
body=body,
|
||||
)
|
||||
sent_headers = self._send_response_headers(
|
||||
status,
|
||||
[
|
||||
|
|
@ -414,15 +510,31 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
headers["Accept-Language"] = accept_language
|
||||
return headers
|
||||
|
||||
def _send_upstream_error(self, exc: HTTPError) -> None:
|
||||
def _send_upstream_error(
|
||||
self,
|
||||
exc: HTTPError,
|
||||
*,
|
||||
trace: TraceRequest | None = None,
|
||||
) -> None:
|
||||
body = read_response_body(exc)
|
||||
if self.config.verbose:
|
||||
log_bytes("upstream error body", body)
|
||||
headers = {
|
||||
"Content-Type": exc.headers.get("Content-Type", "application/json"),
|
||||
"Content-Length": str(len(body)),
|
||||
}
|
||||
if trace is not None:
|
||||
trace.record_upstream_response(
|
||||
status=exc.code,
|
||||
headers={name: value for name, value in exc.headers.items()},
|
||||
body=body,
|
||||
)
|
||||
trace.record_cursor_response(status=exc.code, headers=headers, body=body)
|
||||
sent_headers = self._send_response_headers(
|
||||
exc.code,
|
||||
[
|
||||
("Content-Type", exc.headers.get("Content-Type", "application/json")),
|
||||
("Content-Length", str(len(body))),
|
||||
("Content-Type", headers["Content-Type"]),
|
||||
("Content-Length", headers["Content-Length"]),
|
||||
],
|
||||
"sending upstream error headers",
|
||||
)
|
||||
|
|
@ -436,8 +548,10 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
request_messages: list[dict[str, Any]],
|
||||
cache_namespace: str,
|
||||
recovery_notice: str | None = None,
|
||||
trace: TraceRequest | None = None,
|
||||
) -> bool:
|
||||
body = read_response_body(response)
|
||||
upstream_body = body
|
||||
try:
|
||||
body = rewrite_response_body(
|
||||
body,
|
||||
|
|
@ -454,14 +568,34 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
if self.config.verbose:
|
||||
log_bytes("cursor response body", body)
|
||||
|
||||
headers = {
|
||||
"Content-Type": response.headers.get("Content-Type", "application/json"),
|
||||
"Content-Length": str(len(body)),
|
||||
}
|
||||
if trace is not None:
|
||||
trace.record_upstream_response(
|
||||
status=getattr(response, "status", 200),
|
||||
headers=response_headers(response),
|
||||
body=upstream_body,
|
||||
stream=False,
|
||||
)
|
||||
try:
|
||||
upstream_payload = json.loads(upstream_body.decode("utf-8"))
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
upstream_payload = None
|
||||
if isinstance(upstream_payload, dict):
|
||||
trace.record_usage(upstream_payload.get("usage"))
|
||||
trace.record_cursor_response(
|
||||
status=getattr(response, "status", 200),
|
||||
headers=headers,
|
||||
body=body,
|
||||
)
|
||||
|
||||
sent_headers = self._send_response_headers(
|
||||
getattr(response, "status", 200),
|
||||
[
|
||||
(
|
||||
"Content-Type",
|
||||
response.headers.get("Content-Type", "application/json"),
|
||||
),
|
||||
("Content-Length", str(len(body))),
|
||||
("Content-Type", headers["Content-Type"]),
|
||||
("Content-Length", headers["Content-Length"]),
|
||||
],
|
||||
"sending upstream response headers",
|
||||
)
|
||||
|
|
@ -476,7 +610,22 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
request_messages: list[dict[str, Any]],
|
||||
cache_namespace: str,
|
||||
recovery_notice: str | None = None,
|
||||
trace: TraceRequest | None = None,
|
||||
) -> bool:
|
||||
if trace is not None:
|
||||
trace.record_upstream_response(
|
||||
status=getattr(response, "status", 200),
|
||||
headers=response_headers(response),
|
||||
stream=True,
|
||||
)
|
||||
trace.record_cursor_response(
|
||||
status=getattr(response, "status", 200),
|
||||
headers={
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "close",
|
||||
},
|
||||
)
|
||||
sent_headers = self._send_response_headers(
|
||||
getattr(response, "status", 200),
|
||||
[
|
||||
|
|
@ -514,7 +663,10 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
scope,
|
||||
display_adapter,
|
||||
pending_recovery_notice,
|
||||
trace,
|
||||
)
|
||||
if trace is not None:
|
||||
trace.record_stream_chunk(line, rewritten)
|
||||
if not self._write_to_client(
|
||||
rewritten, "sending streaming response chunk", flush=True
|
||||
):
|
||||
|
|
@ -538,6 +690,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
scope: str,
|
||||
display_adapter: CursorReasoningDisplayAdapter | None,
|
||||
recovery_notice: str | None = None,
|
||||
trace: TraceRequest | None = None,
|
||||
) -> tuple[bytes, bool, str | None]:
|
||||
stripped = line.strip()
|
||||
if not stripped.startswith(b"data:"):
|
||||
|
|
@ -578,6 +731,8 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
stored = accumulator.store_ready_reasoning(self.reasoning_store, scope)
|
||||
if stored:
|
||||
LOG.info("stored %s streaming reasoning cache key(s)", stored)
|
||||
if trace is not None:
|
||||
trace.record_usage(chunk.get("usage"))
|
||||
log_usage(chunk.get("usage"))
|
||||
if display_adapter is not None:
|
||||
display_adapter.rewrite_chunk(chunk)
|
||||
|
|
@ -653,6 +808,11 @@ def build_arg_parser() -> argparse.ArgumentParser:
|
|||
default=None,
|
||||
help="Log detailed request lifecycle metadata and full payloads",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trace-dir",
|
||||
type=Path,
|
||||
help="Write full structured request traces to this directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--display-reasoning",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
|
|
@ -891,6 +1051,13 @@ def read_response_body(response: Any) -> bytes:
|
|||
return body
|
||||
|
||||
|
||||
def response_headers(response: Any) -> dict[str, str]:
    """Return ``response.headers`` as a plain ``str -> str`` dict.

    An empty dict is returned when the response has no ``headers`` attribute
    or when that attribute does not expose an ``items`` method.
    """
    raw_headers = getattr(response, "headers", {})
    if not hasattr(raw_headers, "items"):
        return {}
    plain: dict[str, str] = {}
    for header_name, header_value in raw_headers.items():
        plain[str(header_name)] = str(header_value)
    return plain
|
||||
|
||||
|
||||
def warn_if_insecure_upstream(url: str) -> None:
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme != "http":
|
||||
|
|
@ -930,6 +1097,8 @@ def main(argv: list[str] | None = None) -> int:
|
|||
updates["ngrok"] = args.ngrok
|
||||
if args.verbose is not None:
|
||||
updates["verbose"] = args.verbose
|
||||
if args.trace_dir is not None:
|
||||
updates["trace_dir"] = args.trace_dir
|
||||
if args.display_reasoning is not None:
|
||||
updates["cursor_display_reasoning"] = args.display_reasoning
|
||||
if args.cors is not None:
|
||||
|
|
@ -960,9 +1129,18 @@ def main(argv: list[str] | None = None) -> int:
|
|||
LOG.info("cleared %s reasoning cache row(s)", deleted)
|
||||
store.close()
|
||||
return 0
|
||||
trace_writer: TraceWriter | None = None
|
||||
if config.trace_dir is not None:
|
||||
try:
|
||||
trace_writer = TraceWriter(config.trace_dir)
|
||||
except OSError as exc:
|
||||
LOG.error("failed to initialize trace directory: %s", exc)
|
||||
store.close()
|
||||
return 2
|
||||
server = DeepSeekProxyServer((config.host, config.port), DeepSeekProxyHandler)
|
||||
server.config = config
|
||||
server.reasoning_store = store
|
||||
server.trace_writer = trace_writer
|
||||
|
||||
LOG.info("listening on http://%s:%s/v1", config.host, config.port)
|
||||
LOG.info(
|
||||
|
|
@ -988,6 +1166,9 @@ def main(argv: list[str] | None = None) -> int:
|
|||
)
|
||||
else:
|
||||
LOG.info("logging mode=normal metadata=safe_summaries bodies=false")
|
||||
if trace_writer is not None:
|
||||
LOG.info("trace session directory: %s", trace_writer.session_dir)
|
||||
LOG.warning("trace logging enabled; prompts and code will be written to disk")
|
||||
|
||||
tunnel: NgrokTunnel | None = None
|
||||
if config.ngrok:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,318 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
|
||||
TRACE_SCHEMA_VERSION = 1
|
||||
|
||||
|
||||
def utc_now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string with millisecond precision."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat(timespec="milliseconds")
|
||||
|
||||
|
||||
def sha256_text(value: str) -> str:
    """Return the hex SHA-256 digest of ``value`` encoded as UTF-8."""
    digest = hashlib.sha256()
    digest.update(value.encode("utf-8"))
    return digest.hexdigest()
|
||||
|
||||
|
||||
def authorization_summary(authorization: str | None) -> dict[str, Any]:
    """Describe a credential value without revealing it.

    The result always has a ``present`` flag; when the value is non-empty it
    also carries the SHA-256 hex digest of the value in place of the value.
    """
    summary: dict[str, Any] = {"present": bool(authorization)}
    if authorization:
        summary["sha256"] = sha256_text(authorization)
    return summary
|
||||
|
||||
|
||||
# Header names (lower-cased) whose values are credentials and must never be
# written to a trace file verbatim.
_SENSITIVE_HEADER_NAMES = frozenset(
    {
        "authorization",
        "proxy-authorization",
        "cookie",
        "set-cookie",
        "x-api-key",
        "api-key",
    }
)


def sanitized_headers(headers: dict[str, str] | None) -> dict[str, Any]:
    """Return a copy of ``headers`` that is safe to persist in a trace.

    Values of credential-bearing headers (matched case-insensitively against
    ``_SENSITIVE_HEADER_NAMES``) are replaced by ``authorization_summary``
    output, i.e. a presence flag plus a SHA-256 digest. All other headers are
    copied unchanged. ``None`` or an empty mapping yields ``{}``.
    """
    if not headers:
        return {}
    return {
        name: (
            authorization_summary(value)
            if name.lower() in _SENSITIVE_HEADER_NAMES
            else value
        )
        for name, value in headers.items()
    }
|
||||
|
||||
|
||||
def jsonable_body(body: bytes) -> dict[str, Any]:
    """Convert a raw HTTP body into a JSON-serializable trace entry.

    The bytes are decoded as UTF-8 with undecodable bytes replaced. If the
    text parses as JSON the result is ``{"json": <parsed value>}``; otherwise
    it is ``{"text": <decoded text>}``.
    """
    text = body.decode("utf-8", errors="replace")
    try:
        payload = json.loads(text)
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError subclass. Catching the base
        # class and RecursionError keeps a pathological body (huge integer
        # literal, extreme nesting) from raising out of a diagnostics helper.
        return {"text": text}
    return {"json": payload}
|
||||
|
||||
|
||||
def tool_names(payload: dict[str, Any]) -> list[str]:
    """List the function name of every entry in ``payload["tools"]``.

    One string is produced per tool, in order. A tool that is not a dict, has
    no dict under ``"function"``, or has a falsy name contributes ``""``.
    When ``tools`` is not a list the result is empty.
    """
    tools = payload.get("tools")
    if not isinstance(tools, list):
        return []

    def name_of(tool: Any) -> str:
        function = tool.get("function") if isinstance(tool, dict) else None
        if not isinstance(function, dict):
            return ""
        return str(function.get("name") or "")

    return [name_of(tool) for tool in tools]
|
||||
|
||||
|
||||
def content_stats(content: Any) -> dict[str, Any]:
    """Return the length and SHA-256 digest of a message's content.

    Strings are measured as-is, ``None`` counts as the empty string, and any
    other value is first rendered as JSON with sorted keys.
    """
    if isinstance(content, str):
        rendered = content
    elif content is None:
        rendered = ""
    else:
        rendered = json.dumps(content, ensure_ascii=False, sort_keys=True)
    return {"length": len(rendered), "sha256": sha256_text(rendered)}
|
||||
|
||||
|
||||
def message_summaries(payload: dict[str, Any]) -> list[dict[str, Any]]:
    """Summarize every entry of ``payload["messages"]`` without its full text.

    Dict messages are reduced to role, content length/digest, tool-call ids,
    reasoning-content presence and length, and whether the content starts
    with the proxy's recovery-notice prefix. Non-dict entries are reported by
    index and type name only. A non-list ``messages`` value yields ``[]``.
    """
    messages = payload.get("messages")
    if not isinstance(messages, list):
        return []

    def summarize(position: int, entry: Any) -> dict[str, Any]:
        if not isinstance(entry, dict):
            return {"index": position, "type": type(entry).__name__}
        calls = entry.get("tool_calls")
        # Only tool calls that are dicts with a truthy id contribute an id.
        call_ids = [
            str(call["id"])
            for call in (calls if isinstance(calls, list) else [])
            if isinstance(call, dict) and call.get("id")
        ]
        reasoning = entry.get("reasoning_content")
        reasoning_is_text = isinstance(reasoning, str)
        leading_text = str(entry.get("content") or "")
        return {
            "index": position,
            "role": entry.get("role"),
            "content": content_stats(entry.get("content")),
            "has_tool_calls": bool(call_ids or calls),
            "tool_call_ids": call_ids,
            "tool_call_id": entry.get("tool_call_id"),
            "has_reasoning_content": reasoning_is_text,
            "reasoning_content_length": len(reasoning) if reasoning_is_text else 0,
            "has_recovery_notice": leading_text.startswith(
                "[deepseek-cursor-proxy] Recovered"
            ),
        }

    return [summarize(position, entry) for position, entry in enumerate(messages)]
|
||||
|
||||
|
||||
def payload_summary(payload: dict[str, Any]) -> dict[str, Any]:
    """Build a compact overview of a chat-completions payload.

    The overview holds the model, the stream flag, message and tool counts,
    tool names, the content digest of each system message, and the
    per-message summaries.
    """
    messages = payload.get("messages")
    if not isinstance(messages, list):
        messages = []

    system_hashes = []
    for message in messages:
        if isinstance(message, dict) and message.get("role") == "system":
            system_hashes.append(content_stats(message.get("content"))["sha256"])

    tools = payload.get("tools")
    tool_count = len(tools) if isinstance(tools, list) else 0

    return {
        "model": payload.get("model"),
        "stream": bool(payload.get("stream")),
        "message_count": len(messages),
        "tool_count": tool_count,
        "tool_names": tool_names(payload),
        "system_prompt_hashes": system_hashes,
        "messages": message_summaries(payload),
    }
|
||||
|
||||
|
||||
def write_json_private(path: Path, payload: dict[str, Any]) -> None:
    """Write ``payload`` as pretty-printed JSON to ``path`` with mode 0o600.

    The document is written to a hidden sibling temp file and then moved over
    ``path``, so a reader never observes a partially written file. The temp
    file is created with owner-only permissions from the start; the content
    is therefore never readable by others, even briefly, under a permissive
    umask. If writing or renaming fails, the temp file is removed again.

    Raises:
        OSError: if the temp file cannot be created, written, or renamed.
    """
    serialized = (
        json.dumps(payload, ensure_ascii=False, indent=2, sort_keys=True) + "\n"
    )
    tmp_path = path.with_name(f".{path.name}.tmp")
    # A leftover temp file from an interrupted run would make the exclusive
    # create below fail, so clear it first.
    tmp_path.unlink(missing_ok=True)
    fd = os.open(tmp_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
    try:
        try:
            handle = os.fdopen(fd, "w", encoding="utf-8")
        except BaseException:
            os.close(fd)
            raise
        with handle:
            handle.write(serialized)
        tmp_path.replace(path)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise
    # The umask can only strip bits from the requested mode; restore the exact
    # owner read/write mode on the final file.
    path.chmod(0o600)
|
||||
|
||||
|
||||
class TraceWriter:
    """Owns one trace session directory and hands out per-request records.

    Constructing a writer creates ``base_dir`` if needed, creates a fresh
    session sub-directory named from the UTC timestamp and process id, and
    writes a ``manifest.json`` into it.
    """

    def __init__(self, base_dir: str | Path) -> None:
        self.base_dir = Path(base_dir).expanduser()
        self.base_dir.mkdir(mode=0o700, parents=True, exist_ok=True)
        stamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S.%fZ")
        self.session_dir = self.base_dir / (stamp + f"-pid{os.getpid()}")
        # No exist_ok here: a clash with an existing session directory raises.
        self.session_dir.mkdir(mode=0o700)
        self._lock = threading.Lock()
        self._next_sequence = 1
        self._write_manifest()

    def start_request(
        self,
        *,
        method: str,
        path: str,
        client_address: str,
        headers: dict[str, str],
    ) -> "TraceRequest":
        """Allocate the next sequence number and return a new ``TraceRequest``.

        The sequence counter is advanced under the writer's lock. The request
        headers are stored in sanitized form.
        """
        with self._lock:
            sequence = self._next_sequence
            self._next_sequence = sequence + 1
        request_section = {
            "method": method,
            "path": path,
            "client_address": client_address,
            "headers": sanitized_headers(headers),
        }
        record = {
            "schema_version": TRACE_SCHEMA_VERSION,
            "sequence": sequence,
            "created_at": utc_now_iso(),
            "request": request_section,
            "transform": {},
            "upstream": {},
            "cursor_response": {},
            "completion": {},
        }
        return TraceRequest(
            writer=self,
            sequence=sequence,
            path=self.session_dir / f"request-{sequence:06d}.json",
            data=record,
        )

    def _write_manifest(self) -> None:
        """Write the session's ``manifest.json`` describing this writer."""
        manifest = {
            "schema_version": TRACE_SCHEMA_VERSION,
            "created_at": utc_now_iso(),
            "pid": os.getpid(),
            "base_dir": str(self.base_dir),
            "session_dir": str(self.session_dir),
            "format": "one JSON file per proxied POST request",
        }
        write_json_private(self.session_dir / "manifest.json", manifest)
|
||||
|
||||
|
||||
@dataclass
class TraceRequest:
    """In-memory trace of a single request, written to ``path`` by ``finish``.

    ``data`` is the JSON document being assembled; the ``record_*`` methods
    fill in its ``request``, ``transform``, ``upstream`` and
    ``cursor_response`` sections.
    """

    writer: TraceWriter
    sequence: int
    path: Path
    data: dict[str, Any]
    # Monotonic start time, used only to compute the elapsed time in finish().
    _started: float = field(default_factory=time.monotonic)
    # Set once the trace has been written, making later finish() calls no-ops.
    _finished: bool = False

    @staticmethod
    def _response_entry(
        status: int,
        headers: dict[str, str] | None,
        body: bytes | None,
        stream: bool | None = None,
    ) -> dict[str, Any]:
        """Build a response section containing only the fields that were given."""
        entry: dict[str, Any] = {"status": status}
        if headers is not None:
            entry["headers"] = sanitized_headers(headers)
        if stream is not None:
            entry["stream"] = stream
        if body is not None:
            entry["body"] = jsonable_body(body)
        return entry

    def record_cursor_body(self, payload: dict[str, Any]) -> None:
        """Store the incoming request body together with its summary."""
        request_section = self.data["request"]
        request_section["body"] = payload
        request_section["summary"] = payload_summary(payload)

    def record_transform(self, prepared: Any) -> None:
        """Store the outcome of request preparation, including the upstream body."""
        copied_fields = (
            "original_model",
            "upstream_model",
            "cache_namespace",
            "patched_reasoning_messages",
            "missing_reasoning_messages",
            "recovered_reasoning_messages",
            "recovery_dropped_messages",
            "recovery_notice",
            "reasoning_diagnostics",
            "recovery_steps",
        )
        section = {name: getattr(prepared, name) for name in copied_fields}
        section["upstream_request_summary"] = payload_summary(prepared.payload)
        section["upstream_request_body"] = prepared.payload
        self.data["transform"] = section

    def record_upstream_request(
        self,
        *,
        url: str,
        headers: dict[str, str],
        body_bytes: bytes,
    ) -> None:
        """Store the upstream URL, sanitized headers, and the body size in bytes."""
        entry = {
            "url": url,
            "headers": sanitized_headers(headers),
            "body_bytes": len(body_bytes),
        }
        self.data["upstream"]["request"] = entry

    def record_upstream_response(
        self,
        *,
        status: int,
        headers: dict[str, str] | None = None,
        body: bytes | None = None,
        stream: bool | None = None,
    ) -> None:
        """Replace the recorded upstream response with the given fields."""
        self.data["upstream"]["response"] = self._response_entry(
            status, headers, body, stream
        )

    def record_cursor_response(
        self,
        *,
        status: int,
        headers: dict[str, str] | None = None,
        body: bytes | None = None,
    ) -> None:
        """Merge the given fields into the recorded client response.

        This updates rather than replaces the section, so other keys already
        present in it (such as recorded stream chunks) are kept.
        """
        self.data["cursor_response"].update(
            self._response_entry(status, headers, body)
        )

    def record_stream_chunk(self, upstream_line: bytes, cursor_line: bytes) -> None:
        """Append one upstream line and its rewritten client line, sharing an index."""
        upstream_chunks = self.data["upstream"].setdefault(
            "stream", {"chunks": []}
        )["chunks"]
        cursor_chunks = self.data["cursor_response"].setdefault(
            "stream", {"chunks": []}
        )["chunks"]
        position = len(upstream_chunks)
        for chunks, raw_line in (
            (upstream_chunks, upstream_line),
            (cursor_chunks, cursor_line),
        ):
            chunks.append(
                {
                    "index": position,
                    "line": raw_line.decode("utf-8", errors="replace"),
                }
            )

    def record_usage(self, usage: Any) -> None:
        """Store upstream usage figures; values that are not dicts are ignored."""
        if not isinstance(usage, dict):
            return
        self.data["upstream"]["usage"] = usage

    def finish(self, status: str, **extra: Any) -> None:
        """Record completion details and write the trace file, at most once.

        ``extra`` entries are merged over the standard completion fields.
        The trace is only marked finished after the write succeeds.
        """
        if self._finished:
            return
        self.data["completion"] = {
            "status": status,
            "finished_at": utc_now_iso(),
            "elapsed_ms": round((time.monotonic() - self._started) * 1000),
            **extra,
        }
        write_json_private(self.path, self.data)
        self._finished = True
|
||||
|
|
@ -1,13 +1,19 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from .config import ProxyConfig
|
||||
from .reasoning_store import ReasoningStore, conversation_scope
|
||||
from .reasoning_store import (
|
||||
ReasoningStore,
|
||||
conversation_scope,
|
||||
message_signature,
|
||||
tool_call_ids,
|
||||
tool_call_signature,
|
||||
)
|
||||
|
||||
|
||||
SUPPORTED_REQUEST_FIELDS = {
|
||||
|
|
@ -95,6 +101,8 @@ class PreparedRequest:
|
|||
recovered_reasoning_messages: int = 0
|
||||
recovery_dropped_messages: int = 0
|
||||
recovery_notice: str | None = None
|
||||
reasoning_diagnostics: list[dict[str, Any]] = field(default_factory=list)
|
||||
recovery_steps: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
def normalize_reasoning_effort(value: Any) -> str:
|
||||
|
|
@ -214,7 +222,7 @@ def normalize_message(
|
|||
cache_namespace: str,
|
||||
repair_reasoning: bool,
|
||||
keep_reasoning: bool,
|
||||
) -> tuple[dict[str, Any], bool, bool]:
|
||||
) -> tuple[dict[str, Any], bool, bool, dict[str, Any] | None]:
|
||||
if not isinstance(message, dict):
|
||||
message = {"role": "user", "content": str(message)}
|
||||
normalized = {key: value for key, value in message.items() if key in MESSAGE_FIELDS}
|
||||
|
|
@ -239,6 +247,7 @@ def normalize_message(
|
|||
|
||||
patched = False
|
||||
missing = False
|
||||
diagnostic: dict[str, Any] | None = None
|
||||
if normalized["role"] == "assistant":
|
||||
if not keep_reasoning:
|
||||
normalized.pop("reasoning_content", None)
|
||||
|
|
@ -249,22 +258,94 @@ def normalize_message(
|
|||
needs_reasoning = assistant_needs_reasoning_for_tool_context(
|
||||
normalized, prior_messages
|
||||
)
|
||||
lookup_scope = conversation_scope(prior_messages, cache_namespace)
|
||||
lookup_keys = (
|
||||
reasoning_lookup_keys(normalized, lookup_scope)
|
||||
if needs_reasoning
|
||||
else []
|
||||
)
|
||||
hit_kind = None
|
||||
if needs_reasoning and store is not None:
|
||||
restored = store.lookup_for_message(
|
||||
normalized,
|
||||
conversation_scope(prior_messages, cache_namespace),
|
||||
)
|
||||
if restored is not None:
|
||||
normalized["reasoning_content"] = restored
|
||||
patched = True
|
||||
for lookup_key in lookup_keys:
|
||||
restored = store.get(str(lookup_key["key"]))
|
||||
if restored is not None:
|
||||
lookup_key["hit"] = True
|
||||
hit_kind = lookup_key["kind"]
|
||||
normalized["reasoning_content"] = restored
|
||||
patched = True
|
||||
break
|
||||
if needs_reasoning and not patched:
|
||||
missing = True
|
||||
if needs_reasoning:
|
||||
diagnostic = {
|
||||
"message_index": len(prior_messages),
|
||||
"role": "assistant",
|
||||
"needs_reasoning": True,
|
||||
"had_reasoning_content": False,
|
||||
"patched": patched,
|
||||
"missing": missing,
|
||||
"lookup_scope": lookup_scope,
|
||||
"message_signature": message_signature(normalized),
|
||||
"tool_call_ids": tool_call_ids(normalized),
|
||||
"lookup_keys": lookup_keys,
|
||||
"hit_kind": hit_kind,
|
||||
}
|
||||
elif assistant_needs_reasoning_for_tool_context(normalized, prior_messages):
|
||||
diagnostic = {
|
||||
"message_index": len(prior_messages),
|
||||
"role": "assistant",
|
||||
"needs_reasoning": True,
|
||||
"had_reasoning_content": True,
|
||||
"patched": False,
|
||||
"missing": False,
|
||||
"lookup_scope": conversation_scope(prior_messages, cache_namespace),
|
||||
"message_signature": message_signature(normalized),
|
||||
"tool_call_ids": tool_call_ids(normalized),
|
||||
"lookup_keys": [],
|
||||
"hit_kind": "request",
|
||||
}
|
||||
|
||||
allowed_fields = ROLE_MESSAGE_FIELDS.get(str(normalized["role"]), MESSAGE_FIELDS)
|
||||
normalized = {
|
||||
key: value for key, value in normalized.items() if key in allowed_fields
|
||||
}
|
||||
return normalized, patched, missing
|
||||
return normalized, patched, missing, diagnostic
|
||||
|
||||
|
||||
def reasoning_lookup_keys(
    message: dict[str, Any],
    scope: str,
) -> list[dict[str, Any]]:
    """Return the ordered candidate store keys for one message's reasoning.

    The list holds, in this order:

    1. one ``message_signature`` entry built from ``message_signature(message)``;
    2. one ``tool_call_id`` entry per id yielded by ``tool_call_ids(message)``;
    3. one ``tool_call_signature`` entry per dict item in
       ``message["tool_calls"]`` (non-dict items are skipped).

    Every entry carries the full ``key`` string (prefixed with the given
    *scope*) and starts with ``"hit": False``.
    """
    candidates: list[dict[str, Any]] = [
        {
            "kind": "message_signature",
            "key": f"scope:{scope}:signature:{message_signature(message)}",
            "hit": False,
        }
    ]

    for call_id in tool_call_ids(message):
        candidates.append(
            {
                "kind": "tool_call_id",
                "tool_call_id": call_id,
                "key": f"scope:{scope}:tool_call:{call_id}",
                "hit": False,
            }
        )

    for call in message.get("tool_calls") or []:
        if not isinstance(call, dict):
            continue
        # A missing or falsy "function"/"name" degrades to an empty string.
        function_name = str((call.get("function") or {}).get("name") or "")
        signature = tool_call_signature(call)
        candidates.append(
            {
                "kind": "tool_call_signature",
                "function_name": function_name,
                "key": f"scope:{scope}:tool_call_signature:{signature}",
                "hit": False,
            }
        )

    return candidates
|
||||
|
||||
|
||||
def normalize_messages(
|
||||
|
|
@ -273,14 +354,15 @@ def normalize_messages(
|
|||
cache_namespace: str,
|
||||
repair_reasoning: bool,
|
||||
keep_reasoning: bool,
|
||||
) -> tuple[list[dict[str, Any]], int, list[int]]:
|
||||
) -> tuple[list[dict[str, Any]], int, list[int], list[dict[str, Any]]]:
|
||||
if not isinstance(messages, list):
|
||||
return [], 0, []
|
||||
return [], 0, [], []
|
||||
normalized_messages: list[dict[str, Any]] = []
|
||||
patched_count = 0
|
||||
missing_indexes: list[int] = []
|
||||
diagnostics: list[dict[str, Any]] = []
|
||||
for message in messages:
|
||||
normalized, patched, missing = normalize_message(
|
||||
normalized, patched, missing, diagnostic = normalize_message(
|
||||
message,
|
||||
store,
|
||||
normalized_messages,
|
||||
|
|
@ -293,7 +375,9 @@ def normalize_messages(
|
|||
patched_count += 1
|
||||
if missing:
|
||||
missing_indexes.append(len(normalized_messages) - 1)
|
||||
return normalized_messages, patched_count, missing_indexes
|
||||
if diagnostic is not None:
|
||||
diagnostics.append(diagnostic)
|
||||
return normalized_messages, patched_count, missing_indexes, diagnostics
|
||||
|
||||
|
||||
def has_recovery_notice(message: dict[str, Any]) -> bool:
|
||||
|
|
@ -318,7 +402,7 @@ def leading_system_messages(messages: list[dict[str, Any]]) -> list[dict[str, An
|
|||
def recover_messages_from_missing_reasoning(
|
||||
messages: list[dict[str, Any]],
|
||||
missing_indexes: list[int],
|
||||
) -> tuple[list[dict[str, Any]], int, str | None]:
|
||||
) -> tuple[list[dict[str, Any]], int, str | None, dict[str, Any]]:
|
||||
recovery_boundary_index = next(
|
||||
(
|
||||
index
|
||||
|
|
@ -351,7 +435,19 @@ def recover_messages_from_missing_reasoning(
|
|||
omitted_messages = (
|
||||
recovery_boundary_index - len(leading_messages) - kept_context_messages
|
||||
)
|
||||
return recovered, omitted_messages, None
|
||||
return (
|
||||
recovered,
|
||||
omitted_messages,
|
||||
None,
|
||||
{
|
||||
"strategy": "recovery_boundary",
|
||||
"missing_indexes": missing_indexes,
|
||||
"recovery_boundary_index": recovery_boundary_index,
|
||||
"context_user_index": context_user_index,
|
||||
"dropped_messages": omitted_messages,
|
||||
"notice": None,
|
||||
},
|
||||
)
|
||||
|
||||
last_user_index = next(
|
||||
(
|
||||
|
|
@ -362,13 +458,35 @@ def recover_messages_from_missing_reasoning(
|
|||
-1,
|
||||
)
|
||||
if last_user_index == -1:
|
||||
return messages, 0, None
|
||||
return (
|
||||
messages,
|
||||
0,
|
||||
None,
|
||||
{
|
||||
"strategy": "none",
|
||||
"missing_indexes": missing_indexes,
|
||||
"last_user_index": None,
|
||||
"dropped_messages": 0,
|
||||
"notice": None,
|
||||
},
|
||||
)
|
||||
|
||||
recovered = leading_system_messages(messages)
|
||||
omitted_messages = len(messages) - len(recovered) - 1
|
||||
recovered.append({"role": "system", "content": RECOVERY_SYSTEM_CONTENT})
|
||||
recovered.append(messages[last_user_index])
|
||||
return recovered, omitted_messages, RECOVERY_NOTICE_CONTENT
|
||||
return (
|
||||
recovered,
|
||||
omitted_messages,
|
||||
RECOVERY_NOTICE_CONTENT,
|
||||
{
|
||||
"strategy": "latest_user",
|
||||
"missing_indexes": missing_indexes,
|
||||
"last_user_index": last_user_index,
|
||||
"dropped_messages": omitted_messages,
|
||||
"notice": RECOVERY_NOTICE_CONTENT,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def assistant_needs_reasoning_for_tool_context(
|
||||
|
|
@ -478,33 +596,43 @@ def prepare_upstream_request(
|
|||
prepared.get("reasoning_effort"),
|
||||
authorization,
|
||||
)
|
||||
messages, patched_count, missing_indexes = normalize_messages(
|
||||
payload.get("messages"),
|
||||
store,
|
||||
cache_namespace,
|
||||
repair_reasoning=thinking_enabled,
|
||||
keep_reasoning=not thinking_disabled,
|
||||
messages, patched_count, missing_indexes, reasoning_diagnostics = (
|
||||
normalize_messages(
|
||||
payload.get("messages"),
|
||||
store,
|
||||
cache_namespace,
|
||||
repair_reasoning=thinking_enabled,
|
||||
keep_reasoning=not thinking_disabled,
|
||||
)
|
||||
)
|
||||
recovered_count = 0
|
||||
recovery_dropped_messages = 0
|
||||
recovery_notice = None
|
||||
recovery_steps: list[dict[str, Any]] = []
|
||||
while missing_indexes and config.missing_reasoning_strategy == "recover":
|
||||
recovered_messages, dropped_messages, notice = (
|
||||
recovered_messages, dropped_messages, notice, recovery_step = (
|
||||
recover_messages_from_missing_reasoning(messages, missing_indexes)
|
||||
)
|
||||
recovery_steps.append(recovery_step)
|
||||
if not dropped_messages:
|
||||
break
|
||||
recovered_count += len(missing_indexes)
|
||||
recovery_dropped_messages += dropped_messages
|
||||
if notice:
|
||||
recovery_notice = notice
|
||||
messages, patched_count, missing_indexes = normalize_messages(
|
||||
(
|
||||
messages,
|
||||
patched_count,
|
||||
missing_indexes,
|
||||
latest_diagnostics,
|
||||
) = normalize_messages(
|
||||
recovered_messages,
|
||||
store,
|
||||
cache_namespace,
|
||||
repair_reasoning=thinking_enabled,
|
||||
keep_reasoning=not thinking_disabled,
|
||||
)
|
||||
reasoning_diagnostics.extend(latest_diagnostics)
|
||||
prepared["messages"] = messages
|
||||
|
||||
return PreparedRequest(
|
||||
|
|
@ -517,6 +645,8 @@ def prepare_upstream_request(
|
|||
recovered_reasoning_messages=recovered_count,
|
||||
recovery_dropped_messages=recovery_dropped_messages,
|
||||
recovery_notice=recovery_notice,
|
||||
reasoning_diagnostics=reasoning_diagnostics,
|
||||
recovery_steps=recovery_steps,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ class ConfigTests(unittest.TestCase):
|
|||
home / ".deepseek-cursor-proxy" / "reasoning_content.sqlite3",
|
||||
)
|
||||
self.assertEqual(ProxyConfig().ngrok, DEFAULT_NGROK)
|
||||
self.assertIsNone(ProxyConfig().trace_dir)
|
||||
|
||||
def test_missing_default_config_file_is_populated(self) -> None:
|
||||
with TemporaryDirectory() as temp_dir:
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ from __future__ import annotations
|
|||
from dataclasses import replace
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
import json
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
|
|
@ -16,6 +18,7 @@ from deepseek_cursor_proxy.reasoning_store import (
|
|||
message_signature,
|
||||
)
|
||||
from deepseek_cursor_proxy.server import DeepSeekProxyHandler, DeepSeekProxyServer
|
||||
from deepseek_cursor_proxy.trace import TraceWriter
|
||||
from deepseek_cursor_proxy.transform import (
|
||||
RECOVERY_NOTICE_CONTENT,
|
||||
reasoning_cache_namespace,
|
||||
|
|
@ -249,6 +252,12 @@ class ReasoningStreamingDeepSeekHandler(BaseHTTPRequestHandler):
|
|||
"created": 1,
|
||||
"model": "deepseek-v4-pro",
|
||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 5,
|
||||
"total_tokens": 15,
|
||||
"completion_tokens_details": {"reasoning_tokens": 3},
|
||||
},
|
||||
},
|
||||
]
|
||||
for chunk in chunks:
|
||||
|
|
@ -476,6 +485,17 @@ class ServerFixture:
|
|||
self.thread.join(timeout=5)
|
||||
|
||||
|
||||
def read_single_trace(session_dir: Path) -> dict:
    """Wait briefly for a trace file in *session_dir* and return it parsed.

    Polls for ``request-*.json`` files every 10 ms for up to two seconds,
    stopping as soon as at least one match exists.

    Raises:
        AssertionError: if the number of matching files is not exactly one
            (including the case where none appeared before the deadline).
    """
    give_up_at = time.monotonic() + 2
    while True:
        candidates = sorted(session_dir.glob("request-*.json"))
        if candidates or time.monotonic() >= give_up_at:
            break
        time.sleep(0.01)

    if len(candidates) != 1:
        raise AssertionError(f"expected one trace file, found {candidates}")

    (only_trace,) = candidates
    return json.loads(only_trace.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
class ProxyEndToEndTests(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
FakeDeepSeekHandler.requests = []
|
||||
|
|
@ -598,6 +618,47 @@ class ProxyEndToEndTests(unittest.TestCase):
|
|||
self.assertIn("What is tomorrow's date?", output)
|
||||
self.assertNotIn("sk-from-cursor", output)
|
||||
|
||||
def test_trace_captures_full_non_streaming_replay_without_api_key(self) -> None:
    """A non-streaming request is traced end to end without the API key.

    Sends the first Cursor request with ``api_key="sk-from-cursor"``, reads
    the single trace file produced, and checks the recorded request body,
    transformed upstream body, upstream response and Cursor response. The
    serialized trace must not contain the key string.
    """
    with TemporaryDirectory() as temp_dir:
        trace_writer = TraceWriter(temp_dir)
        self.proxy.server.trace_writer = trace_writer

        endpoint = f"{self.proxy.url}/v1/chat/completions"
        status, payload = post_json(
            endpoint,
            first_cursor_request(),
            api_key="sk-from-cursor",
        )

        trace = read_single_trace(trace_writer.session_dir)
        serialized = json.dumps(trace)

        self.assertEqual(status, 200)

        returned_tool_call = payload["choices"][0]["message"]["tool_calls"][0]
        self.assertEqual(returned_tool_call["id"], "call_date")

        self.assertEqual(trace["completion"]["status"], "completed")

        first_request_message = trace["request"]["body"]["messages"][0]
        self.assertEqual(
            first_request_message["content"],
            "What is tomorrow's date?",
        )

        upstream_request_body = trace["transform"]["upstream_request_body"]
        self.assertEqual(upstream_request_body["model"], "deepseek-v4-pro")

        upstream_json = trace["upstream"]["response"]["body"]["json"]
        self.assertEqual(
            upstream_json["choices"][0]["message"]["reasoning_content"],
            TOOL_REASONING,
        )

        cursor_json = trace["cursor_response"]["body"]["json"]
        self.assertEqual(
            cursor_json["choices"][0]["message"]["reasoning_content"],
            TOOL_REASONING,
        )

        self.assertNotIn("sk-from-cursor", serialized)
|
||||
|
||||
def test_proxy_rejects_missing_cursor_bearer_token(self) -> None:
|
||||
request = Request(
|
||||
f"{self.proxy.url}/v1/chat/completions",
|
||||
|
|
@ -681,6 +742,32 @@ class ProxyEndToEndTests(unittest.TestCase):
|
|||
"\n".join(captured.output),
|
||||
)
|
||||
|
||||
def test_trace_captures_recovery_diagnostics(self) -> None:
    """The trace records recovery steps and per-message reasoning diagnostics.

    Posts a request whose assistant messages all lack reasoning, then checks
    the ``transform`` section of the single trace file: two recovered
    messages, a first recovery step using the ``latest_user`` strategy, and
    at least two diagnostics flagged ``missing`` that carry ``lookup_keys``.
    """
    with TemporaryDirectory() as temp_dir:
        trace_writer = TraceWriter(temp_dir)
        self.proxy.server.trace_writer = trace_writer

        endpoint = f"{self.proxy.url}/v1/chat/completions"
        status, _ = post_json(
            endpoint,
            third_cursor_request_missing_all_reasoning(),
        )

        trace = read_single_trace(trace_writer.session_dir)

        self.assertEqual(status, 200)

        transform = trace["transform"]
        self.assertEqual(transform["recovered_reasoning_messages"], 2)

        first_step = transform["recovery_steps"][0]
        self.assertEqual(first_step["strategy"], "latest_user")

        missing_diagnostics = []
        for diagnostic in transform["reasoning_diagnostics"]:
            if diagnostic["missing"]:
                missing_diagnostics.append(diagnostic)

        self.assertGreaterEqual(len(missing_diagnostics), 2)
        self.assertIn("lookup_keys", missing_diagnostics[0])
|
||||
|
||||
def test_proxy_keeps_deepseek_context_after_recovery_boundary(self) -> None:
|
||||
status, first = post_json(
|
||||
f"{self.proxy.url}/v1/chat/completions",
|
||||
|
|
@ -948,6 +1035,45 @@ class ReasoningStreamingProxyTests(unittest.TestCase):
|
|||
"Need context.",
|
||||
)
|
||||
|
||||
def test_trace_captures_streaming_replay_chunks(self) -> None:
    """A streaming request records both upstream and Cursor-facing chunks.

    After draining the streamed response, the single trace file must be
    marked completed, its first upstream chunk line must mention
    ``reasoning_content``, its first Cursor chunk line must contain
    ``<think>``, and the upstream usage must report three reasoning tokens.
    """
    request_body = {
        "model": "deepseek-v4-pro",
        "stream": True,
        "messages": [{"role": "user", "content": "stream reasoning"}],
    }
    request_headers = {
        "Authorization": "Bearer sk-cursor-test",
        "Content-Type": "application/json",
    }

    with TemporaryDirectory() as temp_dir:
        trace_writer = TraceWriter(temp_dir)
        self.proxy.server.trace_writer = trace_writer

        streaming_request = Request(
            f"{self.proxy.url}/v1/chat/completions",
            data=json.dumps(request_body).encode("utf-8"),
            method="POST",
            headers=request_headers,
        )

        # Drain the whole stream so the proxy finishes the request.
        with urlopen(streaming_request, timeout=2) as response:
            response.read()

        trace = read_single_trace(trace_writer.session_dir)

        self.assertEqual(trace["completion"]["status"], "completed")

        first_upstream_chunk = trace["upstream"]["stream"]["chunks"][0]
        self.assertIn("reasoning_content", first_upstream_chunk["line"])

        first_cursor_chunk = trace["cursor_response"]["stream"]["chunks"][0]
        self.assertIn("<think>", first_cursor_chunk["line"])

        token_details = trace["upstream"]["usage"]["completion_tokens_details"]
        self.assertEqual(token_details["reasoning_tokens"], 3)
|
||||
|
||||
def test_streaming_recovery_notice_is_visible_in_cursor_content(self) -> None:
|
||||
payload = third_cursor_request_missing_all_reasoning()
|
||||
payload["stream"] = True
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||
from io import BytesIO
|
||||
import gzip
|
||||
import json
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
import unittest
|
||||
import zlib
|
||||
|
|
@ -80,6 +81,8 @@ class ServerTests(unittest.TestCase):
|
|||
"--no-verbose",
|
||||
"--no-display-reasoning",
|
||||
"--cors",
|
||||
"--trace-dir",
|
||||
"/tmp/dcp-traces",
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -87,6 +90,7 @@ class ServerTests(unittest.TestCase):
|
|||
self.assertFalse(args.verbose)
|
||||
self.assertFalse(args.display_reasoning)
|
||||
self.assertTrue(args.cors)
|
||||
self.assertEqual(args.trace_dir, Path("/tmp/dcp-traces"))
|
||||
|
||||
def test_read_response_body_handles_gzip(self) -> None:
|
||||
body = gzip.compress(b'{"ok":true}')
|
||||
|
|
|
|||
|
|
@ -0,0 +1,63 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import stat
|
||||
from tempfile import TemporaryDirectory
|
||||
import unittest
|
||||
|
||||
from deepseek_cursor_proxy.trace import TraceWriter
|
||||
|
||||
|
||||
class TraceWriterTests(unittest.TestCase):
    """Unit tests for the files ``TraceWriter`` produces in its session dir."""

    def test_writes_manifest_and_numbered_request_files(self) -> None:
        """Two finished requests yield a manifest and two numbered files.

        Also checks that the first request file has permission bits 0o600.
        """
        with TemporaryDirectory() as temp_dir:
            writer = TraceWriter(temp_dir)

            # Start both requests before finishing either; each call gets
            # its own freshly built headers dict.
            first, second = (
                writer.start_request(
                    method="POST",
                    path="/v1/chat/completions",
                    client_address="127.0.0.1",
                    headers={"User-Agent": "Cursor/1.0"},
                )
                for _ in range(2)
            )
            first.finish("completed", http_status=200)
            second.finish("completed", http_status=200)

            session_dir = writer.session_dir
            self.assertTrue((session_dir / "manifest.json").exists())
            self.assertTrue((session_dir / "request-000001.json").exists())
            self.assertTrue((session_dir / "request-000002.json").exists())

            first_file_mode = (session_dir / "request-000001.json").stat().st_mode
            self.assertEqual(stat.S_IMODE(first_file_mode), 0o600)

    def test_authorization_header_is_redacted(self) -> None:
        """The bearer token is not written; a presence flag and sha256 are."""
        with TemporaryDirectory() as temp_dir:
            writer = TraceWriter(temp_dir)
            trace = writer.start_request(
                method="POST",
                path="/v1/chat/completions",
                client_address="127.0.0.1",
                headers={"Authorization": "Bearer sk-secret"},
            )
            trace.finish("completed", http_status=200)

            payload = json.loads(trace.path.read_text(encoding="utf-8"))
            serialized = json.dumps(payload)

            self.assertNotIn("sk-secret", serialized)

            recorded_authorization = payload["request"]["headers"]["Authorization"]
            self.assertEqual(recorded_authorization["present"], True)
            self.assertIn("sha256", recorded_authorization)
|
||||
|
||||
|
||||
# Run the tests in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
|
||||
Loading…
Reference in New Issue