diff --git a/src/deepseek_cursor_proxy/server.py b/src/deepseek_cursor_proxy/server.py index d7e8f86..a30e8b8 100644 --- a/src/deepseek_cursor_proxy/server.py +++ b/src/deepseek_cursor_proxy/server.py @@ -64,9 +64,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): request_path, self.client_address[0], ) - self.send_response(204) - self._send_cors_headers() - self.end_headers() + self._send_response_headers(204, [], "sending CORS preflight response") def do_GET(self) -> None: request_path = urlparse(self.path).path @@ -249,19 +247,21 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): elapsed_ms(started), ) if prepared.payload.get("stream"): - self._proxy_streaming_response( + sent_response = self._proxy_streaming_response( response, prepared.original_model, prepared.payload["messages"], prepared.cache_namespace, ) else: - self._proxy_regular_response( + sent_response = self._proxy_regular_response( response, prepared.original_model, prepared.payload["messages"], prepared.cache_namespace, ) + if not sent_response: + return LOG.info( ( "request complete status=%s stream=%s elapsed_ms=%s " @@ -297,15 +297,49 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): body = json.dumps(payload, ensure_ascii=False, separators=(",", ":")).encode( "utf-8" ) + sent_headers = self._send_response_headers( + status, + [ + ("Content-Type", "application/json"), + ("Content-Length", str(len(body))), + ], + "sending JSON response headers", + ) + if sent_headers: + self._write_to_client(body, "sending JSON response body") + + def _send_response_headers( + self, + status: int, + headers: list[tuple[str, str]], + disconnect_context: str, + ) -> bool: try: self.send_response(status) self._send_cors_headers() - self.send_header("Content-Type", "application/json") - self.send_header("Content-Length", str(len(body))) + for name, value in headers: + self.send_header(name, value) self.end_headers() - self.wfile.write(body) except (BrokenPipeError, ConnectionError) as exc: - LOG.warning("client disconnected before response could be sent: %s", exc) + LOG.warning("client disconnected while %s: %s", disconnect_context, exc) + return False + return True + + def _write_to_client( + self, + body: bytes, + disconnect_context: str, + *, + flush: bool = False, + ) -> bool: + try: + self.wfile.write(body) + if flush: + self.wfile.flush() + except (BrokenPipeError, ConnectionError) as exc: + LOG.warning("client disconnected while %s: %s", disconnect_context, exc) + return False + return True def _send_models(self) -> None: created = int(time.time()) @@ -368,17 +402,16 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): body = read_response_body(exc) if self.config.verbose: log_bytes("upstream error body", body) - try: - self.send_response(exc.code) - self._send_cors_headers() - self.send_header( - "Content-Type", exc.headers.get("Content-Type", "application/json") - ) - self.send_header("Content-Length", str(len(body))) - self.end_headers() - self.wfile.write(body) - except (BrokenPipeError, ConnectionError) as write_err: - LOG.warning("client disconnected before upstream error could be sent: %s", write_err) + sent_headers = self._send_response_headers( + exc.code, + [ + ("Content-Type", exc.headers.get("Content-Type", "application/json")), + ("Content-Length", str(len(body))), + ], + "sending upstream error headers", + ) + if sent_headers: + self._write_to_client(body, "sending upstream error body") def _proxy_regular_response( self, @@ -386,7 +419,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): original_model: str, request_messages: list[dict[str, Any]], cache_namespace: str, - ) -> None: + ) -> bool: body = read_response_body(response) try: body = rewrite_response_body( @@ -403,14 +436,20 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): if self.config.verbose: log_bytes("cursor response body", body) - self.send_response(getattr(response, "status", 200)) - self._send_cors_headers() - self.send_header( - "Content-Type", response.headers.get("Content-Type", "application/json") + sent_headers = self._send_response_headers( + getattr(response, "status", 200), + [ + ( + "Content-Type", + response.headers.get("Content-Type", "application/json"), + ), + ("Content-Length", str(len(body))), + ], + "sending upstream response headers", ) - self.send_header("Content-Length", str(len(body))) - self.end_headers() - self.wfile.write(body) + if not sent_headers: + return False + return self._write_to_client(body, "sending upstream response body") def _proxy_streaming_response( self, @@ -418,13 +457,18 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): original_model: str, request_messages: list[dict[str, Any]], cache_namespace: str, - ) -> None: - self.send_response(getattr(response, "status", 200)) - self._send_cors_headers() - self.send_header("Content-Type", "text/event-stream") - self.send_header("Cache-Control", "no-cache") - self.send_header("Connection", "close") - self.end_headers() + ) -> bool: + sent_headers = self._send_response_headers( + getattr(response, "status", 200), + [ + ("Content-Type", "text/event-stream"), + ("Cache-Control", "no-cache"), + ("Connection", "close"), + ], + "sending streaming response headers", + ) + if not sent_headers: + return False self.close_connection = True accumulator = StreamAccumulator() @@ -442,8 +486,10 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): rewritten, finalized = self._rewrite_sse_line( line, original_model, accumulator, scope, display_adapter ) - self.wfile.write(rewritten) - self.wfile.flush() + if not self._write_to_client( + rewritten, "sending streaming response chunk", flush=True + ): + return False if finalized: break @@ -453,6 +499,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): stored = accumulator.store_reasoning(self.reasoning_store, scope) if stored: LOG.info("stored %s streaming reasoning cache key(s)", stored) + return True def _rewrite_sse_line( self, @@ -487,7 +534,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler): if isinstance(chunk, dict): accumulator.ingest_chunk(chunk) - stored = accumulator.store_finished_reasoning(self.reasoning_store, scope) + stored = accumulator.store_ready_reasoning(self.reasoning_store, scope) if stored: LOG.info("stored %s streaming reasoning cache key(s)", stored) log_usage(chunk.get("usage")) diff --git a/src/deepseek_cursor_proxy/streaming.py b/src/deepseek_cursor_proxy/streaming.py index b00de7e..aafee1e 100644 --- a/src/deepseek_cursor_proxy/streaming.py +++ b/src/deepseek_cursor_proxy/streaming.py @@ -35,7 +35,7 @@ class StreamingChoice: class StreamAccumulator: def __init__(self) -> None: self.choices: dict[int, StreamingChoice] = {} - self._stored_choices: set[int] = set() + self._stored_choices: dict[int, str] = {} def ingest_chunk(self, chunk: dict[str, Any]) -> None: choices = chunk.get("choices") @@ -80,7 +80,16 @@ class StreamAccumulator: stored = 0 for index, choice in self.choices.items(): if choice.finish_reason is not None: - stored += self._store_choice(index, choice, store, scope) + stored += self._store_choice(index, choice, store, scope, "final") + return stored + + def store_ready_reasoning(self, store: ReasoningStore, scope: str) -> int: + stored = 0 + for index, choice in self.choices.items(): + if choice.finish_reason is not None: + stored += self._store_choice(index, choice, store, scope, "final") + elif self._has_identified_tool_calls(choice): + stored += self._store_choice(index, choice, store, scope, "tool_call") return stored def messages(self) -> list[dict[str, Any]]: @@ -131,14 +140,22 @@ class StreamAccumulator: choice: StreamingChoice, store: ReasoningStore, scope: str, + stage: str = "final", ) -> int: - if index in self._stored_choices: + stage_rank = {"tool_call": 1, "final": 2} + previous_stage = self._stored_choices.get(index) + if stage_rank.get(previous_stage or "", 0) >= stage_rank.get(stage, 0): return 0 stored = store.store_assistant_message(choice.to_message(), scope) if stored: - self._stored_choices.add(index) + self._stored_choices[index] = stage return stored + def _has_identified_tool_calls(self, choice: StreamingChoice) -> bool: + if not choice.has_reasoning_content or not choice.tool_calls: + return False + return all(bool(tool_call.get("id")) for tool_call in choice.tool_calls) + class CursorReasoningDisplayAdapter: """Mirror reasoning_content into content for Cursor's visible thinking UI path.""" diff --git a/tests/test_proxy_end_to_end.py b/tests/test_proxy_end_to_end.py index 598ea98..d1c3b1e 100644 --- a/tests/test_proxy_end_to_end.py +++ b/tests/test_proxy_end_to_end.py @@ -313,6 +313,8 @@ class ToolCallStreamingBeforeDoneDeepSeekHandler(BaseHTTPRequestHandler): for chunk in chunks: self.wfile.write(f"data: {json.dumps(chunk)}\n\n".encode("utf-8")) self.wfile.flush() + if chunk["choices"][0]["finish_reason"] is None: + time.sleep(0.2) time.sleep(1) self.wfile.write(b"data: [DONE]\n\n") self.wfile.flush() @@ -992,6 +994,84 @@ class StreamingToolRaceProxyTests(unittest.TestCase): payload["choices"][0]["message"]["content"], "stream follow-up accepted" ) + def test_streaming_tool_reasoning_is_available_before_finish_reason(self) -> None: + request_messages = [{"role": "user", "content": "stream tool"}] + request = Request( + f"{self.proxy.url}/v1/chat/completions", + data=json.dumps( + { + "model": "deepseek-v4-pro", + "stream": True, + "messages": request_messages, + "tools": [ + { + "type": "function", + "function": { + "name": "lookup", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + } + ).encode("utf-8"), + method="POST", + headers={ + "Authorization": "Bearer sk-cursor-test", + "Content-Type": "application/json", + }, + ) + + with urlopen(request, timeout=3) as response: + while True: + line = response.readline().decode("utf-8") + self.assertNotEqual(line, "") + if '"tool_calls"' in line: + break + + status, payload = post_json( + f"{self.proxy.url}/v1/chat/completions", + { + "model": "deepseek-v4-pro", + "messages": [ + *request_messages, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_stream_tool", + "type": "function", + "function": { + "name": "lookup", + "arguments": "{}", + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call_stream_tool", + "content": "tool result", + }, + ], + "tools": [ + { + "type": "function", + "function": { + "name": "lookup", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + }, + ) + response.read() + + self.assertEqual(status, 200, payload) + self.assertEqual( + payload["choices"][0]["message"]["content"], "stream follow-up accepted" + ) + def first_cursor_request() -> dict: return { diff --git a/tests/test_server.py b/tests/test_server.py index fea48ef..8b1d675 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -2,21 +2,67 @@ from __future__ import annotations from io import BytesIO import gzip +import json +from types import SimpleNamespace import unittest import zlib -from deepseek_cursor_proxy.server import read_response_body, summarize_chat_payload +from deepseek_cursor_proxy.config import ProxyConfig +from deepseek_cursor_proxy.reasoning_store import ReasoningStore +from deepseek_cursor_proxy.server import ( + DeepSeekProxyHandler, + read_response_body, + summarize_chat_payload, +) class FakeResponse: - def __init__(self, body: bytes, encoding: str = "") -> None: + def __init__(self, body: bytes, encoding: str = "", status: int = 200) -> None: self._body = BytesIO(body) self.headers = {"Content-Encoding": encoding} if encoding else {} + self.status = status def read(self) -> bytes: return self._body.read() +class FakeStreamingResponse: + status = 200 + headers = {"Content-Type": "text/event-stream"} + + def __init__(self, lines: list[bytes]) -> None: + self._lines = lines + self.readline_calls = 0 + + def readline(self) -> bytes: + self.readline_calls += 1 + if not self._lines: + return b"" + return self._lines.pop(0) + + +class BrokenPipeWfile: + def write(self, body: bytes) -> None: + raise BrokenPipeError("test disconnect") + + def flush(self) -> None: + raise BrokenPipeError("test disconnect") + + +def make_proxy_handler(wfile: object) -> DeepSeekProxyHandler: + handler = object.__new__(DeepSeekProxyHandler) + handler.server = SimpleNamespace( + config=ProxyConfig(), + reasoning_store=ReasoningStore(":memory:"), + ) + handler.wfile = wfile + handler.close_connection = False + handler.send_response = lambda status: None + handler.send_header = lambda name, value: None + handler.end_headers = lambda: None + return handler + + class ServerTests(unittest.TestCase): def test_read_response_body_handles_gzip(self) -> None: body = gzip.compress(b'{"ok":true}') @@ -47,6 +93,71 @@ class ServerTests(unittest.TestCase): self.assertIn("tools=1", summary) self.assertNotIn("secret prompt", summary) + def test_regular_response_handles_client_disconnect(self) -> None: + handler = make_proxy_handler(BrokenPipeWfile()) + body = json.dumps( + { + "id": "chatcmpl-test", + "object": "chat.completion", + "model": "deepseek-v4-pro", + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "message": {"role": "assistant", "content": "ok"}, + } + ], + } + ).encode("utf-8") + + try: + with self.assertLogs("deepseek_cursor_proxy", level="WARNING") as captured: + sent = handler._proxy_regular_response( + FakeResponse(body), + "deepseek-v4-pro", + [{"role": "user", "content": "hi"}], + "cache-namespace", + ) + finally: + handler.server.reasoning_store.close() + + self.assertFalse(sent) + self.assertIn("sending upstream response body", "\n".join(captured.output)) + + def test_streaming_response_stops_on_client_disconnect(self) -> None: + handler = make_proxy_handler(BrokenPipeWfile()) + chunk = { + "id": "chatcmpl-stream", + "model": "deepseek-v4-pro", + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": "hello"}, + } + ], + } + response = FakeStreamingResponse( + [ + f"data: {json.dumps(chunk)}\n\n".encode("utf-8"), + b"data: [DONE]\n\n", + ] + ) + + try: + with self.assertLogs("deepseek_cursor_proxy", level="WARNING") as captured: + sent = handler._proxy_streaming_response( + response, + "deepseek-v4-pro", + [{"role": "user", "content": "hi"}], + "cache-namespace", + ) + finally: + handler.server.reasoning_store.close() + + self.assertFalse(sent) + self.assertEqual(response.readline_calls, 1) + self.assertIn("sending streaming response chunk", "\n".join(captured.output)) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 26b7cd3..cd2a2de 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -116,6 +116,75 @@ class StreamAccumulatorTests(unittest.TestCase): self.assertEqual(accumulator.store_reasoning(store, scope), 0) store.close() + def test_stores_tool_call_reasoning_before_finish_reason(self) -> None: + store = ReasoningStore(":memory:") + accumulator = StreamAccumulator() + accumulator.ingest_chunk( + { + "choices": [ + { + "index": 0, + "delta": { + "role": "assistant", + "reasoning_content": "Need a tool.", + }, + } + ] + } + ) + accumulator.ingest_chunk( + { + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [ + { + "index": 0, + "id": "call_stream", + "type": "function", + "function": { + "name": "lookup", + "arguments": '{"query"', + }, + } + ], + }, + } + ] + } + ) + + scope = conversation_scope([{"role": "user", "content": "lookup"}]) + stored = accumulator.store_ready_reasoning(store, scope) + + self.assertGreater(stored, 0) + self.assertEqual( + store.get(f"scope:{scope}:tool_call:call_stream"), "Need a tool." + ) + + accumulator.ingest_chunk( + { + "choices": [ + { + "index": 0, + "finish_reason": "tool_calls", + "delta": { + "tool_calls": [ + { + "index": 0, + "function": {"arguments": ':"README"}'}, + } + ], + }, + } + ] + } + ) + + self.assertGreater(accumulator.store_ready_reasoning(store, scope), 0) + store.close() + def test_stores_empty_reasoning_content_when_stream_field_is_present( self, ) -> None: