fix(proxy): handle disconnects and tool-call reasoning race (#12)

main
Yixing Lao 2026-04-26 16:04:53 +08:00 committed by GitHub
parent 69366d8bd5
commit 4b66e5f081
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 368 additions and 44 deletions

View File

@ -64,9 +64,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
request_path, request_path,
self.client_address[0], self.client_address[0],
) )
self.send_response(204) self._send_response_headers(204, [], "sending CORS preflight response")
self._send_cors_headers()
self.end_headers()
def do_GET(self) -> None: def do_GET(self) -> None:
request_path = urlparse(self.path).path request_path = urlparse(self.path).path
@ -249,19 +247,21 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
elapsed_ms(started), elapsed_ms(started),
) )
if prepared.payload.get("stream"): if prepared.payload.get("stream"):
self._proxy_streaming_response( sent_response = self._proxy_streaming_response(
response, response,
prepared.original_model, prepared.original_model,
prepared.payload["messages"], prepared.payload["messages"],
prepared.cache_namespace, prepared.cache_namespace,
) )
else: else:
self._proxy_regular_response( sent_response = self._proxy_regular_response(
response, response,
prepared.original_model, prepared.original_model,
prepared.payload["messages"], prepared.payload["messages"],
prepared.cache_namespace, prepared.cache_namespace,
) )
if not sent_response:
return
LOG.info( LOG.info(
( (
"request complete status=%s stream=%s elapsed_ms=%s " "request complete status=%s stream=%s elapsed_ms=%s "
@ -297,15 +297,49 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
body = json.dumps(payload, ensure_ascii=False, separators=(",", ":")).encode( body = json.dumps(payload, ensure_ascii=False, separators=(",", ":")).encode(
"utf-8" "utf-8"
) )
sent_headers = self._send_response_headers(
status,
[
("Content-Type", "application/json"),
("Content-Length", str(len(body))),
],
"sending JSON response headers",
)
if sent_headers:
self._write_to_client(body, "sending JSON response body")
def _send_response_headers(
self,
status: int,
headers: list[tuple[str, str]],
disconnect_context: str,
) -> bool:
try: try:
self.send_response(status) self.send_response(status)
self._send_cors_headers() self._send_cors_headers()
self.send_header("Content-Type", "application/json") for name, value in headers:
self.send_header("Content-Length", str(len(body))) self.send_header(name, value)
self.end_headers() self.end_headers()
self.wfile.write(body)
except (BrokenPipeError, ConnectionError) as exc: except (BrokenPipeError, ConnectionError) as exc:
LOG.warning("client disconnected before response could be sent: %s", exc) LOG.warning("client disconnected while %s: %s", disconnect_context, exc)
return False
return True
def _write_to_client(
self,
body: bytes,
disconnect_context: str,
*,
flush: bool = False,
) -> bool:
try:
self.wfile.write(body)
if flush:
self.wfile.flush()
except (BrokenPipeError, ConnectionError) as exc:
LOG.warning("client disconnected while %s: %s", disconnect_context, exc)
return False
return True
def _send_models(self) -> None: def _send_models(self) -> None:
created = int(time.time()) created = int(time.time())
@ -368,17 +402,16 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
body = read_response_body(exc) body = read_response_body(exc)
if self.config.verbose: if self.config.verbose:
log_bytes("upstream error body", body) log_bytes("upstream error body", body)
try: sent_headers = self._send_response_headers(
self.send_response(exc.code) exc.code,
self._send_cors_headers() [
self.send_header( ("Content-Type", exc.headers.get("Content-Type", "application/json")),
"Content-Type", exc.headers.get("Content-Type", "application/json") ("Content-Length", str(len(body))),
) ],
self.send_header("Content-Length", str(len(body))) "sending upstream error headers",
self.end_headers() )
self.wfile.write(body) if sent_headers:
except (BrokenPipeError, ConnectionError) as write_err: self._write_to_client(body, "sending upstream error body")
LOG.warning("client disconnected before upstream error could be sent: %s", write_err)
def _proxy_regular_response( def _proxy_regular_response(
self, self,
@ -386,7 +419,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
original_model: str, original_model: str,
request_messages: list[dict[str, Any]], request_messages: list[dict[str, Any]],
cache_namespace: str, cache_namespace: str,
) -> None: ) -> bool:
body = read_response_body(response) body = read_response_body(response)
try: try:
body = rewrite_response_body( body = rewrite_response_body(
@ -403,14 +436,20 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
if self.config.verbose: if self.config.verbose:
log_bytes("cursor response body", body) log_bytes("cursor response body", body)
self.send_response(getattr(response, "status", 200)) sent_headers = self._send_response_headers(
self._send_cors_headers() getattr(response, "status", 200),
self.send_header( [
"Content-Type", response.headers.get("Content-Type", "application/json") (
"Content-Type",
response.headers.get("Content-Type", "application/json"),
),
("Content-Length", str(len(body))),
],
"sending upstream response headers",
) )
self.send_header("Content-Length", str(len(body))) if not sent_headers:
self.end_headers() return False
self.wfile.write(body) return self._write_to_client(body, "sending upstream response body")
def _proxy_streaming_response( def _proxy_streaming_response(
self, self,
@ -418,13 +457,18 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
original_model: str, original_model: str,
request_messages: list[dict[str, Any]], request_messages: list[dict[str, Any]],
cache_namespace: str, cache_namespace: str,
) -> None: ) -> bool:
self.send_response(getattr(response, "status", 200)) sent_headers = self._send_response_headers(
self._send_cors_headers() getattr(response, "status", 200),
self.send_header("Content-Type", "text/event-stream") [
self.send_header("Cache-Control", "no-cache") ("Content-Type", "text/event-stream"),
self.send_header("Connection", "close") ("Cache-Control", "no-cache"),
self.end_headers() ("Connection", "close"),
],
"sending streaming response headers",
)
if not sent_headers:
return False
self.close_connection = True self.close_connection = True
accumulator = StreamAccumulator() accumulator = StreamAccumulator()
@ -442,8 +486,10 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
rewritten, finalized = self._rewrite_sse_line( rewritten, finalized = self._rewrite_sse_line(
line, original_model, accumulator, scope, display_adapter line, original_model, accumulator, scope, display_adapter
) )
self.wfile.write(rewritten) if not self._write_to_client(
self.wfile.flush() rewritten, "sending streaming response chunk", flush=True
):
return False
if finalized: if finalized:
break break
@ -453,6 +499,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
stored = accumulator.store_reasoning(self.reasoning_store, scope) stored = accumulator.store_reasoning(self.reasoning_store, scope)
if stored: if stored:
LOG.info("stored %s streaming reasoning cache key(s)", stored) LOG.info("stored %s streaming reasoning cache key(s)", stored)
return True
def _rewrite_sse_line( def _rewrite_sse_line(
self, self,
@ -487,7 +534,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
if isinstance(chunk, dict): if isinstance(chunk, dict):
accumulator.ingest_chunk(chunk) accumulator.ingest_chunk(chunk)
stored = accumulator.store_finished_reasoning(self.reasoning_store, scope) stored = accumulator.store_ready_reasoning(self.reasoning_store, scope)
if stored: if stored:
LOG.info("stored %s streaming reasoning cache key(s)", stored) LOG.info("stored %s streaming reasoning cache key(s)", stored)
log_usage(chunk.get("usage")) log_usage(chunk.get("usage"))

View File

@ -35,7 +35,7 @@ class StreamingChoice:
class StreamAccumulator: class StreamAccumulator:
def __init__(self) -> None: def __init__(self) -> None:
self.choices: dict[int, StreamingChoice] = {} self.choices: dict[int, StreamingChoice] = {}
self._stored_choices: set[int] = set() self._stored_choices: dict[int, str] = {}
def ingest_chunk(self, chunk: dict[str, Any]) -> None: def ingest_chunk(self, chunk: dict[str, Any]) -> None:
choices = chunk.get("choices") choices = chunk.get("choices")
@ -80,7 +80,16 @@ class StreamAccumulator:
stored = 0 stored = 0
for index, choice in self.choices.items(): for index, choice in self.choices.items():
if choice.finish_reason is not None: if choice.finish_reason is not None:
stored += self._store_choice(index, choice, store, scope) stored += self._store_choice(index, choice, store, scope, "final")
return stored
def store_ready_reasoning(self, store: ReasoningStore, scope: str) -> int:
stored = 0
for index, choice in self.choices.items():
if choice.finish_reason is not None:
stored += self._store_choice(index, choice, store, scope, "final")
elif self._has_identified_tool_calls(choice):
stored += self._store_choice(index, choice, store, scope, "tool_call")
return stored return stored
def messages(self) -> list[dict[str, Any]]: def messages(self) -> list[dict[str, Any]]:
@ -131,14 +140,22 @@ class StreamAccumulator:
choice: StreamingChoice, choice: StreamingChoice,
store: ReasoningStore, store: ReasoningStore,
scope: str, scope: str,
stage: str = "final",
) -> int: ) -> int:
if index in self._stored_choices: stage_rank = {"tool_call": 1, "final": 2}
previous_stage = self._stored_choices.get(index)
if stage_rank.get(previous_stage or "", 0) >= stage_rank.get(stage, 0):
return 0 return 0
stored = store.store_assistant_message(choice.to_message(), scope) stored = store.store_assistant_message(choice.to_message(), scope)
if stored: if stored:
self._stored_choices.add(index) self._stored_choices[index] = stage
return stored return stored
def _has_identified_tool_calls(self, choice: StreamingChoice) -> bool:
if not choice.has_reasoning_content or not choice.tool_calls:
return False
return all(bool(tool_call.get("id")) for tool_call in choice.tool_calls)
class CursorReasoningDisplayAdapter: class CursorReasoningDisplayAdapter:
"""Mirror reasoning_content into content for Cursor's visible thinking UI path.""" """Mirror reasoning_content into content for Cursor's visible thinking UI path."""

View File

@ -313,6 +313,8 @@ class ToolCallStreamingBeforeDoneDeepSeekHandler(BaseHTTPRequestHandler):
for chunk in chunks: for chunk in chunks:
self.wfile.write(f"data: {json.dumps(chunk)}\n\n".encode("utf-8")) self.wfile.write(f"data: {json.dumps(chunk)}\n\n".encode("utf-8"))
self.wfile.flush() self.wfile.flush()
if chunk["choices"][0]["finish_reason"] is None:
time.sleep(0.2)
time.sleep(1) time.sleep(1)
self.wfile.write(b"data: [DONE]\n\n") self.wfile.write(b"data: [DONE]\n\n")
self.wfile.flush() self.wfile.flush()
@ -992,6 +994,84 @@ class StreamingToolRaceProxyTests(unittest.TestCase):
payload["choices"][0]["message"]["content"], "stream follow-up accepted" payload["choices"][0]["message"]["content"], "stream follow-up accepted"
) )
def test_streaming_tool_reasoning_is_available_before_finish_reason(self) -> None:
request_messages = [{"role": "user", "content": "stream tool"}]
request = Request(
f"{self.proxy.url}/v1/chat/completions",
data=json.dumps(
{
"model": "deepseek-v4-pro",
"stream": True,
"messages": request_messages,
"tools": [
{
"type": "function",
"function": {
"name": "lookup",
"parameters": {"type": "object", "properties": {}},
},
}
],
}
).encode("utf-8"),
method="POST",
headers={
"Authorization": "Bearer sk-cursor-test",
"Content-Type": "application/json",
},
)
with urlopen(request, timeout=3) as response:
while True:
line = response.readline().decode("utf-8")
self.assertNotEqual(line, "")
if '"tool_calls"' in line:
break
status, payload = post_json(
f"{self.proxy.url}/v1/chat/completions",
{
"model": "deepseek-v4-pro",
"messages": [
*request_messages,
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_stream_tool",
"type": "function",
"function": {
"name": "lookup",
"arguments": "{}",
},
}
],
},
{
"role": "tool",
"tool_call_id": "call_stream_tool",
"content": "tool result",
},
],
"tools": [
{
"type": "function",
"function": {
"name": "lookup",
"parameters": {"type": "object", "properties": {}},
},
}
],
},
)
response.read()
self.assertEqual(status, 200, payload)
self.assertEqual(
payload["choices"][0]["message"]["content"], "stream follow-up accepted"
)
def first_cursor_request() -> dict: def first_cursor_request() -> dict:
return { return {

View File

@ -2,21 +2,67 @@ from __future__ import annotations
from io import BytesIO from io import BytesIO
import gzip import gzip
import json
from types import SimpleNamespace
import unittest import unittest
import zlib import zlib
from deepseek_cursor_proxy.server import read_response_body, summarize_chat_payload from deepseek_cursor_proxy.config import ProxyConfig
from deepseek_cursor_proxy.reasoning_store import ReasoningStore
from deepseek_cursor_proxy.server import (
DeepSeekProxyHandler,
read_response_body,
summarize_chat_payload,
)
class FakeResponse: class FakeResponse:
def __init__(self, body: bytes, encoding: str = "") -> None: def __init__(self, body: bytes, encoding: str = "", status: int = 200) -> None:
self._body = BytesIO(body) self._body = BytesIO(body)
self.headers = {"Content-Encoding": encoding} if encoding else {} self.headers = {"Content-Encoding": encoding} if encoding else {}
self.status = status
def read(self) -> bytes: def read(self) -> bytes:
return self._body.read() return self._body.read()
class FakeStreamingResponse:
status = 200
headers = {"Content-Type": "text/event-stream"}
def __init__(self, lines: list[bytes]) -> None:
self._lines = lines
self.readline_calls = 0
def readline(self) -> bytes:
self.readline_calls += 1
if not self._lines:
return b""
return self._lines.pop(0)
class BrokenPipeWfile:
def write(self, body: bytes) -> None:
raise BrokenPipeError("test disconnect")
def flush(self) -> None:
raise BrokenPipeError("test disconnect")
def make_proxy_handler(wfile: object) -> DeepSeekProxyHandler:
handler = object.__new__(DeepSeekProxyHandler)
handler.server = SimpleNamespace(
config=ProxyConfig(),
reasoning_store=ReasoningStore(":memory:"),
)
handler.wfile = wfile
handler.close_connection = False
handler.send_response = lambda status: None
handler.send_header = lambda name, value: None
handler.end_headers = lambda: None
return handler
class ServerTests(unittest.TestCase): class ServerTests(unittest.TestCase):
def test_read_response_body_handles_gzip(self) -> None: def test_read_response_body_handles_gzip(self) -> None:
body = gzip.compress(b'{"ok":true}') body = gzip.compress(b'{"ok":true}')
@ -47,6 +93,71 @@ class ServerTests(unittest.TestCase):
self.assertIn("tools=1", summary) self.assertIn("tools=1", summary)
self.assertNotIn("secret prompt", summary) self.assertNotIn("secret prompt", summary)
def test_regular_response_handles_client_disconnect(self) -> None:
handler = make_proxy_handler(BrokenPipeWfile())
body = json.dumps(
{
"id": "chatcmpl-test",
"object": "chat.completion",
"model": "deepseek-v4-pro",
"choices": [
{
"index": 0,
"finish_reason": "stop",
"message": {"role": "assistant", "content": "ok"},
}
],
}
).encode("utf-8")
try:
with self.assertLogs("deepseek_cursor_proxy", level="WARNING") as captured:
sent = handler._proxy_regular_response(
FakeResponse(body),
"deepseek-v4-pro",
[{"role": "user", "content": "hi"}],
"cache-namespace",
)
finally:
handler.server.reasoning_store.close()
self.assertFalse(sent)
self.assertIn("sending upstream response body", "\n".join(captured.output))
def test_streaming_response_stops_on_client_disconnect(self) -> None:
handler = make_proxy_handler(BrokenPipeWfile())
chunk = {
"id": "chatcmpl-stream",
"model": "deepseek-v4-pro",
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": "hello"},
}
],
}
response = FakeStreamingResponse(
[
f"data: {json.dumps(chunk)}\n\n".encode("utf-8"),
b"data: [DONE]\n\n",
]
)
try:
with self.assertLogs("deepseek_cursor_proxy", level="WARNING") as captured:
sent = handler._proxy_streaming_response(
response,
"deepseek-v4-pro",
[{"role": "user", "content": "hi"}],
"cache-namespace",
)
finally:
handler.server.reasoning_store.close()
self.assertFalse(sent)
self.assertEqual(response.readline_calls, 1)
self.assertIn("sending streaming response chunk", "\n".join(captured.output))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -116,6 +116,75 @@ class StreamAccumulatorTests(unittest.TestCase):
self.assertEqual(accumulator.store_reasoning(store, scope), 0) self.assertEqual(accumulator.store_reasoning(store, scope), 0)
store.close() store.close()
def test_stores_tool_call_reasoning_before_finish_reason(self) -> None:
store = ReasoningStore(":memory:")
accumulator = StreamAccumulator()
accumulator.ingest_chunk(
{
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"reasoning_content": "Need a tool.",
},
}
]
}
)
accumulator.ingest_chunk(
{
"choices": [
{
"index": 0,
"delta": {
"tool_calls": [
{
"index": 0,
"id": "call_stream",
"type": "function",
"function": {
"name": "lookup",
"arguments": '{"query"',
},
}
],
},
}
]
}
)
scope = conversation_scope([{"role": "user", "content": "lookup"}])
stored = accumulator.store_ready_reasoning(store, scope)
self.assertGreater(stored, 0)
self.assertEqual(
store.get(f"scope:{scope}:tool_call:call_stream"), "Need a tool."
)
accumulator.ingest_chunk(
{
"choices": [
{
"index": 0,
"finish_reason": "tool_calls",
"delta": {
"tool_calls": [
{
"index": 0,
"function": {"arguments": ':"README"}'},
}
],
},
}
]
}
)
self.assertGreater(accumulator.store_ready_reasoning(store, scope), 0)
store.close()
def test_stores_empty_reasoning_content_when_stream_field_is_present( def test_stores_empty_reasoning_content_when_stream_field_is_present(
self, self,
) -> None: ) -> None: