fix(proxy): handle disconnects and tool-call reasoning race (#12)
parent
69366d8bd5
commit
4b66e5f081
|
|
@ -64,9 +64,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
request_path,
|
||||
self.client_address[0],
|
||||
)
|
||||
self.send_response(204)
|
||||
self._send_cors_headers()
|
||||
self.end_headers()
|
||||
self._send_response_headers(204, [], "sending CORS preflight response")
|
||||
|
||||
def do_GET(self) -> None:
|
||||
request_path = urlparse(self.path).path
|
||||
|
|
@ -249,19 +247,21 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
elapsed_ms(started),
|
||||
)
|
||||
if prepared.payload.get("stream"):
|
||||
self._proxy_streaming_response(
|
||||
sent_response = self._proxy_streaming_response(
|
||||
response,
|
||||
prepared.original_model,
|
||||
prepared.payload["messages"],
|
||||
prepared.cache_namespace,
|
||||
)
|
||||
else:
|
||||
self._proxy_regular_response(
|
||||
sent_response = self._proxy_regular_response(
|
||||
response,
|
||||
prepared.original_model,
|
||||
prepared.payload["messages"],
|
||||
prepared.cache_namespace,
|
||||
)
|
||||
if not sent_response:
|
||||
return
|
||||
LOG.info(
|
||||
(
|
||||
"request complete status=%s stream=%s elapsed_ms=%s "
|
||||
|
|
@ -297,15 +297,49 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
body = json.dumps(payload, ensure_ascii=False, separators=(",", ":")).encode(
|
||||
"utf-8"
|
||||
)
|
||||
sent_headers = self._send_response_headers(
|
||||
status,
|
||||
[
|
||||
("Content-Type", "application/json"),
|
||||
("Content-Length", str(len(body))),
|
||||
],
|
||||
"sending JSON response headers",
|
||||
)
|
||||
if sent_headers:
|
||||
self._write_to_client(body, "sending JSON response body")
|
||||
|
||||
def _send_response_headers(
|
||||
self,
|
||||
status: int,
|
||||
headers: list[tuple[str, str]],
|
||||
disconnect_context: str,
|
||||
) -> bool:
|
||||
try:
|
||||
self.send_response(status)
|
||||
self._send_cors_headers()
|
||||
self.send_header("Content-Type", "application/json")
|
||||
self.send_header("Content-Length", str(len(body)))
|
||||
for name, value in headers:
|
||||
self.send_header(name, value)
|
||||
self.end_headers()
|
||||
self.wfile.write(body)
|
||||
except (BrokenPipeError, ConnectionError) as exc:
|
||||
LOG.warning("client disconnected before response could be sent: %s", exc)
|
||||
LOG.warning("client disconnected while %s: %s", disconnect_context, exc)
|
||||
return False
|
||||
return True
|
||||
|
||||
def _write_to_client(
|
||||
self,
|
||||
body: bytes,
|
||||
disconnect_context: str,
|
||||
*,
|
||||
flush: bool = False,
|
||||
) -> bool:
|
||||
try:
|
||||
self.wfile.write(body)
|
||||
if flush:
|
||||
self.wfile.flush()
|
||||
except (BrokenPipeError, ConnectionError) as exc:
|
||||
LOG.warning("client disconnected while %s: %s", disconnect_context, exc)
|
||||
return False
|
||||
return True
|
||||
|
||||
def _send_models(self) -> None:
|
||||
created = int(time.time())
|
||||
|
|
@ -368,17 +402,16 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
body = read_response_body(exc)
|
||||
if self.config.verbose:
|
||||
log_bytes("upstream error body", body)
|
||||
try:
|
||||
self.send_response(exc.code)
|
||||
self._send_cors_headers()
|
||||
self.send_header(
|
||||
"Content-Type", exc.headers.get("Content-Type", "application/json")
|
||||
)
|
||||
self.send_header("Content-Length", str(len(body)))
|
||||
self.end_headers()
|
||||
self.wfile.write(body)
|
||||
except (BrokenPipeError, ConnectionError) as write_err:
|
||||
LOG.warning("client disconnected before upstream error could be sent: %s", write_err)
|
||||
sent_headers = self._send_response_headers(
|
||||
exc.code,
|
||||
[
|
||||
("Content-Type", exc.headers.get("Content-Type", "application/json")),
|
||||
("Content-Length", str(len(body))),
|
||||
],
|
||||
"sending upstream error headers",
|
||||
)
|
||||
if sent_headers:
|
||||
self._write_to_client(body, "sending upstream error body")
|
||||
|
||||
def _proxy_regular_response(
|
||||
self,
|
||||
|
|
@ -386,7 +419,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
original_model: str,
|
||||
request_messages: list[dict[str, Any]],
|
||||
cache_namespace: str,
|
||||
) -> None:
|
||||
) -> bool:
|
||||
body = read_response_body(response)
|
||||
try:
|
||||
body = rewrite_response_body(
|
||||
|
|
@ -403,14 +436,20 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
if self.config.verbose:
|
||||
log_bytes("cursor response body", body)
|
||||
|
||||
self.send_response(getattr(response, "status", 200))
|
||||
self._send_cors_headers()
|
||||
self.send_header(
|
||||
"Content-Type", response.headers.get("Content-Type", "application/json")
|
||||
sent_headers = self._send_response_headers(
|
||||
getattr(response, "status", 200),
|
||||
[
|
||||
(
|
||||
"Content-Type",
|
||||
response.headers.get("Content-Type", "application/json"),
|
||||
),
|
||||
("Content-Length", str(len(body))),
|
||||
],
|
||||
"sending upstream response headers",
|
||||
)
|
||||
self.send_header("Content-Length", str(len(body)))
|
||||
self.end_headers()
|
||||
self.wfile.write(body)
|
||||
if not sent_headers:
|
||||
return False
|
||||
return self._write_to_client(body, "sending upstream response body")
|
||||
|
||||
def _proxy_streaming_response(
|
||||
self,
|
||||
|
|
@ -418,13 +457,18 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
original_model: str,
|
||||
request_messages: list[dict[str, Any]],
|
||||
cache_namespace: str,
|
||||
) -> None:
|
||||
self.send_response(getattr(response, "status", 200))
|
||||
self._send_cors_headers()
|
||||
self.send_header("Content-Type", "text/event-stream")
|
||||
self.send_header("Cache-Control", "no-cache")
|
||||
self.send_header("Connection", "close")
|
||||
self.end_headers()
|
||||
) -> bool:
|
||||
sent_headers = self._send_response_headers(
|
||||
getattr(response, "status", 200),
|
||||
[
|
||||
("Content-Type", "text/event-stream"),
|
||||
("Cache-Control", "no-cache"),
|
||||
("Connection", "close"),
|
||||
],
|
||||
"sending streaming response headers",
|
||||
)
|
||||
if not sent_headers:
|
||||
return False
|
||||
self.close_connection = True
|
||||
|
||||
accumulator = StreamAccumulator()
|
||||
|
|
@ -442,8 +486,10 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
rewritten, finalized = self._rewrite_sse_line(
|
||||
line, original_model, accumulator, scope, display_adapter
|
||||
)
|
||||
self.wfile.write(rewritten)
|
||||
self.wfile.flush()
|
||||
if not self._write_to_client(
|
||||
rewritten, "sending streaming response chunk", flush=True
|
||||
):
|
||||
return False
|
||||
if finalized:
|
||||
break
|
||||
|
||||
|
|
@ -453,6 +499,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
stored = accumulator.store_reasoning(self.reasoning_store, scope)
|
||||
if stored:
|
||||
LOG.info("stored %s streaming reasoning cache key(s)", stored)
|
||||
return True
|
||||
|
||||
def _rewrite_sse_line(
|
||||
self,
|
||||
|
|
@ -487,7 +534,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
|||
|
||||
if isinstance(chunk, dict):
|
||||
accumulator.ingest_chunk(chunk)
|
||||
stored = accumulator.store_finished_reasoning(self.reasoning_store, scope)
|
||||
stored = accumulator.store_ready_reasoning(self.reasoning_store, scope)
|
||||
if stored:
|
||||
LOG.info("stored %s streaming reasoning cache key(s)", stored)
|
||||
log_usage(chunk.get("usage"))
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ class StreamingChoice:
|
|||
class StreamAccumulator:
|
||||
def __init__(self) -> None:
|
||||
self.choices: dict[int, StreamingChoice] = {}
|
||||
self._stored_choices: set[int] = set()
|
||||
self._stored_choices: dict[int, str] = {}
|
||||
|
||||
def ingest_chunk(self, chunk: dict[str, Any]) -> None:
|
||||
choices = chunk.get("choices")
|
||||
|
|
@ -80,7 +80,16 @@ class StreamAccumulator:
|
|||
stored = 0
|
||||
for index, choice in self.choices.items():
|
||||
if choice.finish_reason is not None:
|
||||
stored += self._store_choice(index, choice, store, scope)
|
||||
stored += self._store_choice(index, choice, store, scope, "final")
|
||||
return stored
|
||||
|
||||
def store_ready_reasoning(self, store: ReasoningStore, scope: str) -> int:
|
||||
stored = 0
|
||||
for index, choice in self.choices.items():
|
||||
if choice.finish_reason is not None:
|
||||
stored += self._store_choice(index, choice, store, scope, "final")
|
||||
elif self._has_identified_tool_calls(choice):
|
||||
stored += self._store_choice(index, choice, store, scope, "tool_call")
|
||||
return stored
|
||||
|
||||
def messages(self) -> list[dict[str, Any]]:
|
||||
|
|
@ -131,14 +140,22 @@ class StreamAccumulator:
|
|||
choice: StreamingChoice,
|
||||
store: ReasoningStore,
|
||||
scope: str,
|
||||
stage: str = "final",
|
||||
) -> int:
|
||||
if index in self._stored_choices:
|
||||
stage_rank = {"tool_call": 1, "final": 2}
|
||||
previous_stage = self._stored_choices.get(index)
|
||||
if stage_rank.get(previous_stage or "", 0) >= stage_rank.get(stage, 0):
|
||||
return 0
|
||||
stored = store.store_assistant_message(choice.to_message(), scope)
|
||||
if stored:
|
||||
self._stored_choices.add(index)
|
||||
self._stored_choices[index] = stage
|
||||
return stored
|
||||
|
||||
def _has_identified_tool_calls(self, choice: StreamingChoice) -> bool:
|
||||
if not choice.has_reasoning_content or not choice.tool_calls:
|
||||
return False
|
||||
return all(bool(tool_call.get("id")) for tool_call in choice.tool_calls)
|
||||
|
||||
|
||||
class CursorReasoningDisplayAdapter:
|
||||
"""Mirror reasoning_content into content for Cursor's visible thinking UI path."""
|
||||
|
|
|
|||
|
|
@ -313,6 +313,8 @@ class ToolCallStreamingBeforeDoneDeepSeekHandler(BaseHTTPRequestHandler):
|
|||
for chunk in chunks:
|
||||
self.wfile.write(f"data: {json.dumps(chunk)}\n\n".encode("utf-8"))
|
||||
self.wfile.flush()
|
||||
if chunk["choices"][0]["finish_reason"] is None:
|
||||
time.sleep(0.2)
|
||||
time.sleep(1)
|
||||
self.wfile.write(b"data: [DONE]\n\n")
|
||||
self.wfile.flush()
|
||||
|
|
@ -992,6 +994,84 @@ class StreamingToolRaceProxyTests(unittest.TestCase):
|
|||
payload["choices"][0]["message"]["content"], "stream follow-up accepted"
|
||||
)
|
||||
|
||||
def test_streaming_tool_reasoning_is_available_before_finish_reason(self) -> None:
|
||||
request_messages = [{"role": "user", "content": "stream tool"}]
|
||||
request = Request(
|
||||
f"{self.proxy.url}/v1/chat/completions",
|
||||
data=json.dumps(
|
||||
{
|
||||
"model": "deepseek-v4-pro",
|
||||
"stream": True,
|
||||
"messages": request_messages,
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "lookup",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
).encode("utf-8"),
|
||||
method="POST",
|
||||
headers={
|
||||
"Authorization": "Bearer sk-cursor-test",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
with urlopen(request, timeout=3) as response:
|
||||
while True:
|
||||
line = response.readline().decode("utf-8")
|
||||
self.assertNotEqual(line, "")
|
||||
if '"tool_calls"' in line:
|
||||
break
|
||||
|
||||
status, payload = post_json(
|
||||
f"{self.proxy.url}/v1/chat/completions",
|
||||
{
|
||||
"model": "deepseek-v4-pro",
|
||||
"messages": [
|
||||
*request_messages,
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_stream_tool",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "lookup",
|
||||
"arguments": "{}",
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_stream_tool",
|
||||
"content": "tool result",
|
||||
},
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "lookup",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
response.read()
|
||||
|
||||
self.assertEqual(status, 200, payload)
|
||||
self.assertEqual(
|
||||
payload["choices"][0]["message"]["content"], "stream follow-up accepted"
|
||||
)
|
||||
|
||||
|
||||
def first_cursor_request() -> dict:
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -2,21 +2,67 @@ from __future__ import annotations
|
|||
|
||||
from io import BytesIO
|
||||
import gzip
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
import unittest
|
||||
import zlib
|
||||
|
||||
from deepseek_cursor_proxy.server import read_response_body, summarize_chat_payload
|
||||
from deepseek_cursor_proxy.config import ProxyConfig
|
||||
from deepseek_cursor_proxy.reasoning_store import ReasoningStore
|
||||
from deepseek_cursor_proxy.server import (
|
||||
DeepSeekProxyHandler,
|
||||
read_response_body,
|
||||
summarize_chat_payload,
|
||||
)
|
||||
|
||||
|
||||
class FakeResponse:
|
||||
def __init__(self, body: bytes, encoding: str = "") -> None:
|
||||
def __init__(self, body: bytes, encoding: str = "", status: int = 200) -> None:
|
||||
self._body = BytesIO(body)
|
||||
self.headers = {"Content-Encoding": encoding} if encoding else {}
|
||||
self.status = status
|
||||
|
||||
def read(self) -> bytes:
|
||||
return self._body.read()
|
||||
|
||||
|
||||
class FakeStreamingResponse:
|
||||
status = 200
|
||||
headers = {"Content-Type": "text/event-stream"}
|
||||
|
||||
def __init__(self, lines: list[bytes]) -> None:
|
||||
self._lines = lines
|
||||
self.readline_calls = 0
|
||||
|
||||
def readline(self) -> bytes:
|
||||
self.readline_calls += 1
|
||||
if not self._lines:
|
||||
return b""
|
||||
return self._lines.pop(0)
|
||||
|
||||
|
||||
class BrokenPipeWfile:
|
||||
def write(self, body: bytes) -> None:
|
||||
raise BrokenPipeError("test disconnect")
|
||||
|
||||
def flush(self) -> None:
|
||||
raise BrokenPipeError("test disconnect")
|
||||
|
||||
|
||||
def make_proxy_handler(wfile: object) -> DeepSeekProxyHandler:
|
||||
handler = object.__new__(DeepSeekProxyHandler)
|
||||
handler.server = SimpleNamespace(
|
||||
config=ProxyConfig(),
|
||||
reasoning_store=ReasoningStore(":memory:"),
|
||||
)
|
||||
handler.wfile = wfile
|
||||
handler.close_connection = False
|
||||
handler.send_response = lambda status: None
|
||||
handler.send_header = lambda name, value: None
|
||||
handler.end_headers = lambda: None
|
||||
return handler
|
||||
|
||||
|
||||
class ServerTests(unittest.TestCase):
|
||||
def test_read_response_body_handles_gzip(self) -> None:
|
||||
body = gzip.compress(b'{"ok":true}')
|
||||
|
|
@ -47,6 +93,71 @@ class ServerTests(unittest.TestCase):
|
|||
self.assertIn("tools=1", summary)
|
||||
self.assertNotIn("secret prompt", summary)
|
||||
|
||||
def test_regular_response_handles_client_disconnect(self) -> None:
|
||||
handler = make_proxy_handler(BrokenPipeWfile())
|
||||
body = json.dumps(
|
||||
{
|
||||
"id": "chatcmpl-test",
|
||||
"object": "chat.completion",
|
||||
"model": "deepseek-v4-pro",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"finish_reason": "stop",
|
||||
"message": {"role": "assistant", "content": "ok"},
|
||||
}
|
||||
],
|
||||
}
|
||||
).encode("utf-8")
|
||||
|
||||
try:
|
||||
with self.assertLogs("deepseek_cursor_proxy", level="WARNING") as captured:
|
||||
sent = handler._proxy_regular_response(
|
||||
FakeResponse(body),
|
||||
"deepseek-v4-pro",
|
||||
[{"role": "user", "content": "hi"}],
|
||||
"cache-namespace",
|
||||
)
|
||||
finally:
|
||||
handler.server.reasoning_store.close()
|
||||
|
||||
self.assertFalse(sent)
|
||||
self.assertIn("sending upstream response body", "\n".join(captured.output))
|
||||
|
||||
def test_streaming_response_stops_on_client_disconnect(self) -> None:
|
||||
handler = make_proxy_handler(BrokenPipeWfile())
|
||||
chunk = {
|
||||
"id": "chatcmpl-stream",
|
||||
"model": "deepseek-v4-pro",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"role": "assistant", "content": "hello"},
|
||||
}
|
||||
],
|
||||
}
|
||||
response = FakeStreamingResponse(
|
||||
[
|
||||
f"data: {json.dumps(chunk)}\n\n".encode("utf-8"),
|
||||
b"data: [DONE]\n\n",
|
||||
]
|
||||
)
|
||||
|
||||
try:
|
||||
with self.assertLogs("deepseek_cursor_proxy", level="WARNING") as captured:
|
||||
sent = handler._proxy_streaming_response(
|
||||
response,
|
||||
"deepseek-v4-pro",
|
||||
[{"role": "user", "content": "hi"}],
|
||||
"cache-namespace",
|
||||
)
|
||||
finally:
|
||||
handler.server.reasoning_store.close()
|
||||
|
||||
self.assertFalse(sent)
|
||||
self.assertEqual(response.readline_calls, 1)
|
||||
self.assertIn("sending streaming response chunk", "\n".join(captured.output))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -116,6 +116,75 @@ class StreamAccumulatorTests(unittest.TestCase):
|
|||
self.assertEqual(accumulator.store_reasoning(store, scope), 0)
|
||||
store.close()
|
||||
|
||||
def test_stores_tool_call_reasoning_before_finish_reason(self) -> None:
|
||||
store = ReasoningStore(":memory:")
|
||||
accumulator = StreamAccumulator()
|
||||
accumulator.ingest_chunk(
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"role": "assistant",
|
||||
"reasoning_content": "Need a tool.",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
accumulator.ingest_chunk(
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": 0,
|
||||
"id": "call_stream",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "lookup",
|
||||
"arguments": '{"query"',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
scope = conversation_scope([{"role": "user", "content": "lookup"}])
|
||||
stored = accumulator.store_ready_reasoning(store, scope)
|
||||
|
||||
self.assertGreater(stored, 0)
|
||||
self.assertEqual(
|
||||
store.get(f"scope:{scope}:tool_call:call_stream"), "Need a tool."
|
||||
)
|
||||
|
||||
accumulator.ingest_chunk(
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"finish_reason": "tool_calls",
|
||||
"delta": {
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": 0,
|
||||
"function": {"arguments": ':"README"}'},
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
self.assertGreater(accumulator.store_ready_reasoning(store, scope), 0)
|
||||
store.close()
|
||||
|
||||
def test_stores_empty_reasoning_content_when_stream_field_is_present(
|
||||
self,
|
||||
) -> None:
|
||||
|
|
|
|||
Loading…
Reference in New Issue