fix(proxy): handle disconnects and tool-call reasoning race (#12)

main
Yixing Lao 2026-04-26 16:04:53 +08:00 committed by GitHub
parent 69366d8bd5
commit 4b66e5f081
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 368 additions and 44 deletions

View File

@ -64,9 +64,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
request_path,
self.client_address[0],
)
self.send_response(204)
self._send_cors_headers()
self.end_headers()
self._send_response_headers(204, [], "sending CORS preflight response")
def do_GET(self) -> None:
request_path = urlparse(self.path).path
@ -249,19 +247,21 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
elapsed_ms(started),
)
if prepared.payload.get("stream"):
self._proxy_streaming_response(
sent_response = self._proxy_streaming_response(
response,
prepared.original_model,
prepared.payload["messages"],
prepared.cache_namespace,
)
else:
self._proxy_regular_response(
sent_response = self._proxy_regular_response(
response,
prepared.original_model,
prepared.payload["messages"],
prepared.cache_namespace,
)
if not sent_response:
return
LOG.info(
(
"request complete status=%s stream=%s elapsed_ms=%s "
@ -297,15 +297,49 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
body = json.dumps(payload, ensure_ascii=False, separators=(",", ":")).encode(
"utf-8"
)
sent_headers = self._send_response_headers(
status,
[
("Content-Type", "application/json"),
("Content-Length", str(len(body))),
],
"sending JSON response headers",
)
if sent_headers:
self._write_to_client(body, "sending JSON response body")
def _send_response_headers(
self,
status: int,
headers: list[tuple[str, str]],
disconnect_context: str,
) -> bool:
try:
self.send_response(status)
self._send_cors_headers()
self.send_header("Content-Type", "application/json")
self.send_header("Content-Length", str(len(body)))
for name, value in headers:
self.send_header(name, value)
self.end_headers()
self.wfile.write(body)
except (BrokenPipeError, ConnectionError) as exc:
LOG.warning("client disconnected before response could be sent: %s", exc)
LOG.warning("client disconnected while %s: %s", disconnect_context, exc)
return False
return True
def _write_to_client(
self,
body: bytes,
disconnect_context: str,
*,
flush: bool = False,
) -> bool:
try:
self.wfile.write(body)
if flush:
self.wfile.flush()
except (BrokenPipeError, ConnectionError) as exc:
LOG.warning("client disconnected while %s: %s", disconnect_context, exc)
return False
return True
def _send_models(self) -> None:
created = int(time.time())
@ -368,17 +402,16 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
body = read_response_body(exc)
if self.config.verbose:
log_bytes("upstream error body", body)
try:
self.send_response(exc.code)
self._send_cors_headers()
self.send_header(
"Content-Type", exc.headers.get("Content-Type", "application/json")
sent_headers = self._send_response_headers(
exc.code,
[
("Content-Type", exc.headers.get("Content-Type", "application/json")),
("Content-Length", str(len(body))),
],
"sending upstream error headers",
)
self.send_header("Content-Length", str(len(body)))
self.end_headers()
self.wfile.write(body)
except (BrokenPipeError, ConnectionError) as write_err:
LOG.warning("client disconnected before upstream error could be sent: %s", write_err)
if sent_headers:
self._write_to_client(body, "sending upstream error body")
def _proxy_regular_response(
self,
@ -386,7 +419,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
original_model: str,
request_messages: list[dict[str, Any]],
cache_namespace: str,
) -> None:
) -> bool:
body = read_response_body(response)
try:
body = rewrite_response_body(
@ -403,14 +436,20 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
if self.config.verbose:
log_bytes("cursor response body", body)
self.send_response(getattr(response, "status", 200))
self._send_cors_headers()
self.send_header(
"Content-Type", response.headers.get("Content-Type", "application/json")
sent_headers = self._send_response_headers(
getattr(response, "status", 200),
[
(
"Content-Type",
response.headers.get("Content-Type", "application/json"),
),
("Content-Length", str(len(body))),
],
"sending upstream response headers",
)
self.send_header("Content-Length", str(len(body)))
self.end_headers()
self.wfile.write(body)
if not sent_headers:
return False
return self._write_to_client(body, "sending upstream response body")
def _proxy_streaming_response(
self,
@ -418,13 +457,18 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
original_model: str,
request_messages: list[dict[str, Any]],
cache_namespace: str,
) -> None:
self.send_response(getattr(response, "status", 200))
self._send_cors_headers()
self.send_header("Content-Type", "text/event-stream")
self.send_header("Cache-Control", "no-cache")
self.send_header("Connection", "close")
self.end_headers()
) -> bool:
sent_headers = self._send_response_headers(
getattr(response, "status", 200),
[
("Content-Type", "text/event-stream"),
("Cache-Control", "no-cache"),
("Connection", "close"),
],
"sending streaming response headers",
)
if not sent_headers:
return False
self.close_connection = True
accumulator = StreamAccumulator()
@ -442,8 +486,10 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
rewritten, finalized = self._rewrite_sse_line(
line, original_model, accumulator, scope, display_adapter
)
self.wfile.write(rewritten)
self.wfile.flush()
if not self._write_to_client(
rewritten, "sending streaming response chunk", flush=True
):
return False
if finalized:
break
@ -453,6 +499,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
stored = accumulator.store_reasoning(self.reasoning_store, scope)
if stored:
LOG.info("stored %s streaming reasoning cache key(s)", stored)
return True
def _rewrite_sse_line(
self,
@ -487,7 +534,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
if isinstance(chunk, dict):
accumulator.ingest_chunk(chunk)
stored = accumulator.store_finished_reasoning(self.reasoning_store, scope)
stored = accumulator.store_ready_reasoning(self.reasoning_store, scope)
if stored:
LOG.info("stored %s streaming reasoning cache key(s)", stored)
log_usage(chunk.get("usage"))

View File

@ -35,7 +35,7 @@ class StreamingChoice:
class StreamAccumulator:
def __init__(self) -> None:
self.choices: dict[int, StreamingChoice] = {}
self._stored_choices: set[int] = set()
self._stored_choices: dict[int, str] = {}
def ingest_chunk(self, chunk: dict[str, Any]) -> None:
choices = chunk.get("choices")
@ -80,7 +80,16 @@ class StreamAccumulator:
stored = 0
for index, choice in self.choices.items():
if choice.finish_reason is not None:
stored += self._store_choice(index, choice, store, scope)
stored += self._store_choice(index, choice, store, scope, "final")
return stored
def store_ready_reasoning(self, store: ReasoningStore, scope: str) -> int:
stored = 0
for index, choice in self.choices.items():
if choice.finish_reason is not None:
stored += self._store_choice(index, choice, store, scope, "final")
elif self._has_identified_tool_calls(choice):
stored += self._store_choice(index, choice, store, scope, "tool_call")
return stored
def messages(self) -> list[dict[str, Any]]:
@ -131,14 +140,22 @@ class StreamAccumulator:
choice: StreamingChoice,
store: ReasoningStore,
scope: str,
stage: str = "final",
) -> int:
if index in self._stored_choices:
stage_rank = {"tool_call": 1, "final": 2}
previous_stage = self._stored_choices.get(index)
if stage_rank.get(previous_stage or "", 0) >= stage_rank.get(stage, 0):
return 0
stored = store.store_assistant_message(choice.to_message(), scope)
if stored:
self._stored_choices.add(index)
self._stored_choices[index] = stage
return stored
def _has_identified_tool_calls(self, choice: StreamingChoice) -> bool:
if not choice.has_reasoning_content or not choice.tool_calls:
return False
return all(bool(tool_call.get("id")) for tool_call in choice.tool_calls)
class CursorReasoningDisplayAdapter:
"""Mirror reasoning_content into content for Cursor's visible thinking UI path."""

View File

@ -313,6 +313,8 @@ class ToolCallStreamingBeforeDoneDeepSeekHandler(BaseHTTPRequestHandler):
for chunk in chunks:
self.wfile.write(f"data: {json.dumps(chunk)}\n\n".encode("utf-8"))
self.wfile.flush()
if chunk["choices"][0]["finish_reason"] is None:
time.sleep(0.2)
time.sleep(1)
self.wfile.write(b"data: [DONE]\n\n")
self.wfile.flush()
@ -992,6 +994,84 @@ class StreamingToolRaceProxyTests(unittest.TestCase):
payload["choices"][0]["message"]["content"], "stream follow-up accepted"
)
def test_streaming_tool_reasoning_is_available_before_finish_reason(self) -> None:
request_messages = [{"role": "user", "content": "stream tool"}]
request = Request(
f"{self.proxy.url}/v1/chat/completions",
data=json.dumps(
{
"model": "deepseek-v4-pro",
"stream": True,
"messages": request_messages,
"tools": [
{
"type": "function",
"function": {
"name": "lookup",
"parameters": {"type": "object", "properties": {}},
},
}
],
}
).encode("utf-8"),
method="POST",
headers={
"Authorization": "Bearer sk-cursor-test",
"Content-Type": "application/json",
},
)
with urlopen(request, timeout=3) as response:
while True:
line = response.readline().decode("utf-8")
self.assertNotEqual(line, "")
if '"tool_calls"' in line:
break
status, payload = post_json(
f"{self.proxy.url}/v1/chat/completions",
{
"model": "deepseek-v4-pro",
"messages": [
*request_messages,
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_stream_tool",
"type": "function",
"function": {
"name": "lookup",
"arguments": "{}",
},
}
],
},
{
"role": "tool",
"tool_call_id": "call_stream_tool",
"content": "tool result",
},
],
"tools": [
{
"type": "function",
"function": {
"name": "lookup",
"parameters": {"type": "object", "properties": {}},
},
}
],
},
)
response.read()
self.assertEqual(status, 200, payload)
self.assertEqual(
payload["choices"][0]["message"]["content"], "stream follow-up accepted"
)
def first_cursor_request() -> dict:
return {

View File

@ -2,21 +2,67 @@ from __future__ import annotations
from io import BytesIO
import gzip
import json
from types import SimpleNamespace
import unittest
import zlib
from deepseek_cursor_proxy.server import read_response_body, summarize_chat_payload
from deepseek_cursor_proxy.config import ProxyConfig
from deepseek_cursor_proxy.reasoning_store import ReasoningStore
from deepseek_cursor_proxy.server import (
DeepSeekProxyHandler,
read_response_body,
summarize_chat_payload,
)
class FakeResponse:
def __init__(self, body: bytes, encoding: str = "") -> None:
def __init__(self, body: bytes, encoding: str = "", status: int = 200) -> None:
self._body = BytesIO(body)
self.headers = {"Content-Encoding": encoding} if encoding else {}
self.status = status
def read(self) -> bytes:
return self._body.read()
class FakeStreamingResponse:
status = 200
headers = {"Content-Type": "text/event-stream"}
def __init__(self, lines: list[bytes]) -> None:
self._lines = lines
self.readline_calls = 0
def readline(self) -> bytes:
self.readline_calls += 1
if not self._lines:
return b""
return self._lines.pop(0)
class BrokenPipeWfile:
def write(self, body: bytes) -> None:
raise BrokenPipeError("test disconnect")
def flush(self) -> None:
raise BrokenPipeError("test disconnect")
def make_proxy_handler(wfile: object) -> DeepSeekProxyHandler:
handler = object.__new__(DeepSeekProxyHandler)
handler.server = SimpleNamespace(
config=ProxyConfig(),
reasoning_store=ReasoningStore(":memory:"),
)
handler.wfile = wfile
handler.close_connection = False
handler.send_response = lambda status: None
handler.send_header = lambda name, value: None
handler.end_headers = lambda: None
return handler
class ServerTests(unittest.TestCase):
def test_read_response_body_handles_gzip(self) -> None:
body = gzip.compress(b'{"ok":true}')
@ -47,6 +93,71 @@ class ServerTests(unittest.TestCase):
self.assertIn("tools=1", summary)
self.assertNotIn("secret prompt", summary)
def test_regular_response_handles_client_disconnect(self) -> None:
handler = make_proxy_handler(BrokenPipeWfile())
body = json.dumps(
{
"id": "chatcmpl-test",
"object": "chat.completion",
"model": "deepseek-v4-pro",
"choices": [
{
"index": 0,
"finish_reason": "stop",
"message": {"role": "assistant", "content": "ok"},
}
],
}
).encode("utf-8")
try:
with self.assertLogs("deepseek_cursor_proxy", level="WARNING") as captured:
sent = handler._proxy_regular_response(
FakeResponse(body),
"deepseek-v4-pro",
[{"role": "user", "content": "hi"}],
"cache-namespace",
)
finally:
handler.server.reasoning_store.close()
self.assertFalse(sent)
self.assertIn("sending upstream response body", "\n".join(captured.output))
def test_streaming_response_stops_on_client_disconnect(self) -> None:
handler = make_proxy_handler(BrokenPipeWfile())
chunk = {
"id": "chatcmpl-stream",
"model": "deepseek-v4-pro",
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": "hello"},
}
],
}
response = FakeStreamingResponse(
[
f"data: {json.dumps(chunk)}\n\n".encode("utf-8"),
b"data: [DONE]\n\n",
]
)
try:
with self.assertLogs("deepseek_cursor_proxy", level="WARNING") as captured:
sent = handler._proxy_streaming_response(
response,
"deepseek-v4-pro",
[{"role": "user", "content": "hi"}],
"cache-namespace",
)
finally:
handler.server.reasoning_store.close()
self.assertFalse(sent)
self.assertEqual(response.readline_calls, 1)
self.assertIn("sending streaming response chunk", "\n".join(captured.output))
if __name__ == "__main__":
unittest.main()

View File

@ -116,6 +116,75 @@ class StreamAccumulatorTests(unittest.TestCase):
self.assertEqual(accumulator.store_reasoning(store, scope), 0)
store.close()
def test_stores_tool_call_reasoning_before_finish_reason(self) -> None:
store = ReasoningStore(":memory:")
accumulator = StreamAccumulator()
accumulator.ingest_chunk(
{
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"reasoning_content": "Need a tool.",
},
}
]
}
)
accumulator.ingest_chunk(
{
"choices": [
{
"index": 0,
"delta": {
"tool_calls": [
{
"index": 0,
"id": "call_stream",
"type": "function",
"function": {
"name": "lookup",
"arguments": '{"query"',
},
}
],
},
}
]
}
)
scope = conversation_scope([{"role": "user", "content": "lookup"}])
stored = accumulator.store_ready_reasoning(store, scope)
self.assertGreater(stored, 0)
self.assertEqual(
store.get(f"scope:{scope}:tool_call:call_stream"), "Need a tool."
)
accumulator.ingest_chunk(
{
"choices": [
{
"index": 0,
"finish_reason": "tool_calls",
"delta": {
"tool_calls": [
{
"index": 0,
"function": {"arguments": ':"README"}'},
}
],
},
}
]
}
)
self.assertGreater(accumulator.store_ready_reasoning(store, scope), 0)
store.close()
def test_stores_empty_reasoning_content_when_stream_field_is_present(
self,
) -> None: