fix(proxy): handle disconnects and tool-call reasoning race (#12)
parent
69366d8bd5
commit
4b66e5f081
|
|
@ -64,9 +64,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
||||||
request_path,
|
request_path,
|
||||||
self.client_address[0],
|
self.client_address[0],
|
||||||
)
|
)
|
||||||
self.send_response(204)
|
self._send_response_headers(204, [], "sending CORS preflight response")
|
||||||
self._send_cors_headers()
|
|
||||||
self.end_headers()
|
|
||||||
|
|
||||||
def do_GET(self) -> None:
|
def do_GET(self) -> None:
|
||||||
request_path = urlparse(self.path).path
|
request_path = urlparse(self.path).path
|
||||||
|
|
@ -249,19 +247,21 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
||||||
elapsed_ms(started),
|
elapsed_ms(started),
|
||||||
)
|
)
|
||||||
if prepared.payload.get("stream"):
|
if prepared.payload.get("stream"):
|
||||||
self._proxy_streaming_response(
|
sent_response = self._proxy_streaming_response(
|
||||||
response,
|
response,
|
||||||
prepared.original_model,
|
prepared.original_model,
|
||||||
prepared.payload["messages"],
|
prepared.payload["messages"],
|
||||||
prepared.cache_namespace,
|
prepared.cache_namespace,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._proxy_regular_response(
|
sent_response = self._proxy_regular_response(
|
||||||
response,
|
response,
|
||||||
prepared.original_model,
|
prepared.original_model,
|
||||||
prepared.payload["messages"],
|
prepared.payload["messages"],
|
||||||
prepared.cache_namespace,
|
prepared.cache_namespace,
|
||||||
)
|
)
|
||||||
|
if not sent_response:
|
||||||
|
return
|
||||||
LOG.info(
|
LOG.info(
|
||||||
(
|
(
|
||||||
"request complete status=%s stream=%s elapsed_ms=%s "
|
"request complete status=%s stream=%s elapsed_ms=%s "
|
||||||
|
|
@ -297,15 +297,49 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
||||||
body = json.dumps(payload, ensure_ascii=False, separators=(",", ":")).encode(
|
body = json.dumps(payload, ensure_ascii=False, separators=(",", ":")).encode(
|
||||||
"utf-8"
|
"utf-8"
|
||||||
)
|
)
|
||||||
|
sent_headers = self._send_response_headers(
|
||||||
|
status,
|
||||||
|
[
|
||||||
|
("Content-Type", "application/json"),
|
||||||
|
("Content-Length", str(len(body))),
|
||||||
|
],
|
||||||
|
"sending JSON response headers",
|
||||||
|
)
|
||||||
|
if sent_headers:
|
||||||
|
self._write_to_client(body, "sending JSON response body")
|
||||||
|
|
||||||
|
def _send_response_headers(
|
||||||
|
self,
|
||||||
|
status: int,
|
||||||
|
headers: list[tuple[str, str]],
|
||||||
|
disconnect_context: str,
|
||||||
|
) -> bool:
|
||||||
try:
|
try:
|
||||||
self.send_response(status)
|
self.send_response(status)
|
||||||
self._send_cors_headers()
|
self._send_cors_headers()
|
||||||
self.send_header("Content-Type", "application/json")
|
for name, value in headers:
|
||||||
self.send_header("Content-Length", str(len(body)))
|
self.send_header(name, value)
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
self.wfile.write(body)
|
|
||||||
except (BrokenPipeError, ConnectionError) as exc:
|
except (BrokenPipeError, ConnectionError) as exc:
|
||||||
LOG.warning("client disconnected before response could be sent: %s", exc)
|
LOG.warning("client disconnected while %s: %s", disconnect_context, exc)
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _write_to_client(
|
||||||
|
self,
|
||||||
|
body: bytes,
|
||||||
|
disconnect_context: str,
|
||||||
|
*,
|
||||||
|
flush: bool = False,
|
||||||
|
) -> bool:
|
||||||
|
try:
|
||||||
|
self.wfile.write(body)
|
||||||
|
if flush:
|
||||||
|
self.wfile.flush()
|
||||||
|
except (BrokenPipeError, ConnectionError) as exc:
|
||||||
|
LOG.warning("client disconnected while %s: %s", disconnect_context, exc)
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
def _send_models(self) -> None:
|
def _send_models(self) -> None:
|
||||||
created = int(time.time())
|
created = int(time.time())
|
||||||
|
|
@ -368,17 +402,16 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
||||||
body = read_response_body(exc)
|
body = read_response_body(exc)
|
||||||
if self.config.verbose:
|
if self.config.verbose:
|
||||||
log_bytes("upstream error body", body)
|
log_bytes("upstream error body", body)
|
||||||
try:
|
sent_headers = self._send_response_headers(
|
||||||
self.send_response(exc.code)
|
exc.code,
|
||||||
self._send_cors_headers()
|
[
|
||||||
self.send_header(
|
("Content-Type", exc.headers.get("Content-Type", "application/json")),
|
||||||
"Content-Type", exc.headers.get("Content-Type", "application/json")
|
("Content-Length", str(len(body))),
|
||||||
)
|
],
|
||||||
self.send_header("Content-Length", str(len(body)))
|
"sending upstream error headers",
|
||||||
self.end_headers()
|
)
|
||||||
self.wfile.write(body)
|
if sent_headers:
|
||||||
except (BrokenPipeError, ConnectionError) as write_err:
|
self._write_to_client(body, "sending upstream error body")
|
||||||
LOG.warning("client disconnected before upstream error could be sent: %s", write_err)
|
|
||||||
|
|
||||||
def _proxy_regular_response(
|
def _proxy_regular_response(
|
||||||
self,
|
self,
|
||||||
|
|
@ -386,7 +419,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
||||||
original_model: str,
|
original_model: str,
|
||||||
request_messages: list[dict[str, Any]],
|
request_messages: list[dict[str, Any]],
|
||||||
cache_namespace: str,
|
cache_namespace: str,
|
||||||
) -> None:
|
) -> bool:
|
||||||
body = read_response_body(response)
|
body = read_response_body(response)
|
||||||
try:
|
try:
|
||||||
body = rewrite_response_body(
|
body = rewrite_response_body(
|
||||||
|
|
@ -403,14 +436,20 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
||||||
if self.config.verbose:
|
if self.config.verbose:
|
||||||
log_bytes("cursor response body", body)
|
log_bytes("cursor response body", body)
|
||||||
|
|
||||||
self.send_response(getattr(response, "status", 200))
|
sent_headers = self._send_response_headers(
|
||||||
self._send_cors_headers()
|
getattr(response, "status", 200),
|
||||||
self.send_header(
|
[
|
||||||
"Content-Type", response.headers.get("Content-Type", "application/json")
|
(
|
||||||
|
"Content-Type",
|
||||||
|
response.headers.get("Content-Type", "application/json"),
|
||||||
|
),
|
||||||
|
("Content-Length", str(len(body))),
|
||||||
|
],
|
||||||
|
"sending upstream response headers",
|
||||||
)
|
)
|
||||||
self.send_header("Content-Length", str(len(body)))
|
if not sent_headers:
|
||||||
self.end_headers()
|
return False
|
||||||
self.wfile.write(body)
|
return self._write_to_client(body, "sending upstream response body")
|
||||||
|
|
||||||
def _proxy_streaming_response(
|
def _proxy_streaming_response(
|
||||||
self,
|
self,
|
||||||
|
|
@ -418,13 +457,18 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
||||||
original_model: str,
|
original_model: str,
|
||||||
request_messages: list[dict[str, Any]],
|
request_messages: list[dict[str, Any]],
|
||||||
cache_namespace: str,
|
cache_namespace: str,
|
||||||
) -> None:
|
) -> bool:
|
||||||
self.send_response(getattr(response, "status", 200))
|
sent_headers = self._send_response_headers(
|
||||||
self._send_cors_headers()
|
getattr(response, "status", 200),
|
||||||
self.send_header("Content-Type", "text/event-stream")
|
[
|
||||||
self.send_header("Cache-Control", "no-cache")
|
("Content-Type", "text/event-stream"),
|
||||||
self.send_header("Connection", "close")
|
("Cache-Control", "no-cache"),
|
||||||
self.end_headers()
|
("Connection", "close"),
|
||||||
|
],
|
||||||
|
"sending streaming response headers",
|
||||||
|
)
|
||||||
|
if not sent_headers:
|
||||||
|
return False
|
||||||
self.close_connection = True
|
self.close_connection = True
|
||||||
|
|
||||||
accumulator = StreamAccumulator()
|
accumulator = StreamAccumulator()
|
||||||
|
|
@ -442,8 +486,10 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
||||||
rewritten, finalized = self._rewrite_sse_line(
|
rewritten, finalized = self._rewrite_sse_line(
|
||||||
line, original_model, accumulator, scope, display_adapter
|
line, original_model, accumulator, scope, display_adapter
|
||||||
)
|
)
|
||||||
self.wfile.write(rewritten)
|
if not self._write_to_client(
|
||||||
self.wfile.flush()
|
rewritten, "sending streaming response chunk", flush=True
|
||||||
|
):
|
||||||
|
return False
|
||||||
if finalized:
|
if finalized:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
@ -453,6 +499,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
||||||
stored = accumulator.store_reasoning(self.reasoning_store, scope)
|
stored = accumulator.store_reasoning(self.reasoning_store, scope)
|
||||||
if stored:
|
if stored:
|
||||||
LOG.info("stored %s streaming reasoning cache key(s)", stored)
|
LOG.info("stored %s streaming reasoning cache key(s)", stored)
|
||||||
|
return True
|
||||||
|
|
||||||
def _rewrite_sse_line(
|
def _rewrite_sse_line(
|
||||||
self,
|
self,
|
||||||
|
|
@ -487,7 +534,7 @@ class DeepSeekProxyHandler(BaseHTTPRequestHandler):
|
||||||
|
|
||||||
if isinstance(chunk, dict):
|
if isinstance(chunk, dict):
|
||||||
accumulator.ingest_chunk(chunk)
|
accumulator.ingest_chunk(chunk)
|
||||||
stored = accumulator.store_finished_reasoning(self.reasoning_store, scope)
|
stored = accumulator.store_ready_reasoning(self.reasoning_store, scope)
|
||||||
if stored:
|
if stored:
|
||||||
LOG.info("stored %s streaming reasoning cache key(s)", stored)
|
LOG.info("stored %s streaming reasoning cache key(s)", stored)
|
||||||
log_usage(chunk.get("usage"))
|
log_usage(chunk.get("usage"))
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,7 @@ class StreamingChoice:
|
||||||
class StreamAccumulator:
|
class StreamAccumulator:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.choices: dict[int, StreamingChoice] = {}
|
self.choices: dict[int, StreamingChoice] = {}
|
||||||
self._stored_choices: set[int] = set()
|
self._stored_choices: dict[int, str] = {}
|
||||||
|
|
||||||
def ingest_chunk(self, chunk: dict[str, Any]) -> None:
|
def ingest_chunk(self, chunk: dict[str, Any]) -> None:
|
||||||
choices = chunk.get("choices")
|
choices = chunk.get("choices")
|
||||||
|
|
@ -80,7 +80,16 @@ class StreamAccumulator:
|
||||||
stored = 0
|
stored = 0
|
||||||
for index, choice in self.choices.items():
|
for index, choice in self.choices.items():
|
||||||
if choice.finish_reason is not None:
|
if choice.finish_reason is not None:
|
||||||
stored += self._store_choice(index, choice, store, scope)
|
stored += self._store_choice(index, choice, store, scope, "final")
|
||||||
|
return stored
|
||||||
|
|
||||||
|
def store_ready_reasoning(self, store: ReasoningStore, scope: str) -> int:
|
||||||
|
stored = 0
|
||||||
|
for index, choice in self.choices.items():
|
||||||
|
if choice.finish_reason is not None:
|
||||||
|
stored += self._store_choice(index, choice, store, scope, "final")
|
||||||
|
elif self._has_identified_tool_calls(choice):
|
||||||
|
stored += self._store_choice(index, choice, store, scope, "tool_call")
|
||||||
return stored
|
return stored
|
||||||
|
|
||||||
def messages(self) -> list[dict[str, Any]]:
|
def messages(self) -> list[dict[str, Any]]:
|
||||||
|
|
@ -131,14 +140,22 @@ class StreamAccumulator:
|
||||||
choice: StreamingChoice,
|
choice: StreamingChoice,
|
||||||
store: ReasoningStore,
|
store: ReasoningStore,
|
||||||
scope: str,
|
scope: str,
|
||||||
|
stage: str = "final",
|
||||||
) -> int:
|
) -> int:
|
||||||
if index in self._stored_choices:
|
stage_rank = {"tool_call": 1, "final": 2}
|
||||||
|
previous_stage = self._stored_choices.get(index)
|
||||||
|
if stage_rank.get(previous_stage or "", 0) >= stage_rank.get(stage, 0):
|
||||||
return 0
|
return 0
|
||||||
stored = store.store_assistant_message(choice.to_message(), scope)
|
stored = store.store_assistant_message(choice.to_message(), scope)
|
||||||
if stored:
|
if stored:
|
||||||
self._stored_choices.add(index)
|
self._stored_choices[index] = stage
|
||||||
return stored
|
return stored
|
||||||
|
|
||||||
|
def _has_identified_tool_calls(self, choice: StreamingChoice) -> bool:
|
||||||
|
if not choice.has_reasoning_content or not choice.tool_calls:
|
||||||
|
return False
|
||||||
|
return all(bool(tool_call.get("id")) for tool_call in choice.tool_calls)
|
||||||
|
|
||||||
|
|
||||||
class CursorReasoningDisplayAdapter:
|
class CursorReasoningDisplayAdapter:
|
||||||
"""Mirror reasoning_content into content for Cursor's visible thinking UI path."""
|
"""Mirror reasoning_content into content for Cursor's visible thinking UI path."""
|
||||||
|
|
|
||||||
|
|
@ -313,6 +313,8 @@ class ToolCallStreamingBeforeDoneDeepSeekHandler(BaseHTTPRequestHandler):
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
self.wfile.write(f"data: {json.dumps(chunk)}\n\n".encode("utf-8"))
|
self.wfile.write(f"data: {json.dumps(chunk)}\n\n".encode("utf-8"))
|
||||||
self.wfile.flush()
|
self.wfile.flush()
|
||||||
|
if chunk["choices"][0]["finish_reason"] is None:
|
||||||
|
time.sleep(0.2)
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
self.wfile.write(b"data: [DONE]\n\n")
|
self.wfile.write(b"data: [DONE]\n\n")
|
||||||
self.wfile.flush()
|
self.wfile.flush()
|
||||||
|
|
@ -992,6 +994,84 @@ class StreamingToolRaceProxyTests(unittest.TestCase):
|
||||||
payload["choices"][0]["message"]["content"], "stream follow-up accepted"
|
payload["choices"][0]["message"]["content"], "stream follow-up accepted"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_streaming_tool_reasoning_is_available_before_finish_reason(self) -> None:
|
||||||
|
request_messages = [{"role": "user", "content": "stream tool"}]
|
||||||
|
request = Request(
|
||||||
|
f"{self.proxy.url}/v1/chat/completions",
|
||||||
|
data=json.dumps(
|
||||||
|
{
|
||||||
|
"model": "deepseek-v4-pro",
|
||||||
|
"stream": True,
|
||||||
|
"messages": request_messages,
|
||||||
|
"tools": [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "lookup",
|
||||||
|
"parameters": {"type": "object", "properties": {}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
).encode("utf-8"),
|
||||||
|
method="POST",
|
||||||
|
headers={
|
||||||
|
"Authorization": "Bearer sk-cursor-test",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
with urlopen(request, timeout=3) as response:
|
||||||
|
while True:
|
||||||
|
line = response.readline().decode("utf-8")
|
||||||
|
self.assertNotEqual(line, "")
|
||||||
|
if '"tool_calls"' in line:
|
||||||
|
break
|
||||||
|
|
||||||
|
status, payload = post_json(
|
||||||
|
f"{self.proxy.url}/v1/chat/completions",
|
||||||
|
{
|
||||||
|
"model": "deepseek-v4-pro",
|
||||||
|
"messages": [
|
||||||
|
*request_messages,
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_stream_tool",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "lookup",
|
||||||
|
"arguments": "{}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_stream_tool",
|
||||||
|
"content": "tool result",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"tools": [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "lookup",
|
||||||
|
"parameters": {"type": "object", "properties": {}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
response.read()
|
||||||
|
|
||||||
|
self.assertEqual(status, 200, payload)
|
||||||
|
self.assertEqual(
|
||||||
|
payload["choices"][0]["message"]["content"], "stream follow-up accepted"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def first_cursor_request() -> dict:
|
def first_cursor_request() -> dict:
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
|
|
@ -2,21 +2,67 @@ from __future__ import annotations
|
||||||
|
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
import gzip
|
import gzip
|
||||||
|
import json
|
||||||
|
from types import SimpleNamespace
|
||||||
import unittest
|
import unittest
|
||||||
import zlib
|
import zlib
|
||||||
|
|
||||||
from deepseek_cursor_proxy.server import read_response_body, summarize_chat_payload
|
from deepseek_cursor_proxy.config import ProxyConfig
|
||||||
|
from deepseek_cursor_proxy.reasoning_store import ReasoningStore
|
||||||
|
from deepseek_cursor_proxy.server import (
|
||||||
|
DeepSeekProxyHandler,
|
||||||
|
read_response_body,
|
||||||
|
summarize_chat_payload,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FakeResponse:
|
class FakeResponse:
|
||||||
def __init__(self, body: bytes, encoding: str = "") -> None:
|
def __init__(self, body: bytes, encoding: str = "", status: int = 200) -> None:
|
||||||
self._body = BytesIO(body)
|
self._body = BytesIO(body)
|
||||||
self.headers = {"Content-Encoding": encoding} if encoding else {}
|
self.headers = {"Content-Encoding": encoding} if encoding else {}
|
||||||
|
self.status = status
|
||||||
|
|
||||||
def read(self) -> bytes:
|
def read(self) -> bytes:
|
||||||
return self._body.read()
|
return self._body.read()
|
||||||
|
|
||||||
|
|
||||||
|
class FakeStreamingResponse:
|
||||||
|
status = 200
|
||||||
|
headers = {"Content-Type": "text/event-stream"}
|
||||||
|
|
||||||
|
def __init__(self, lines: list[bytes]) -> None:
|
||||||
|
self._lines = lines
|
||||||
|
self.readline_calls = 0
|
||||||
|
|
||||||
|
def readline(self) -> bytes:
|
||||||
|
self.readline_calls += 1
|
||||||
|
if not self._lines:
|
||||||
|
return b""
|
||||||
|
return self._lines.pop(0)
|
||||||
|
|
||||||
|
|
||||||
|
class BrokenPipeWfile:
|
||||||
|
def write(self, body: bytes) -> None:
|
||||||
|
raise BrokenPipeError("test disconnect")
|
||||||
|
|
||||||
|
def flush(self) -> None:
|
||||||
|
raise BrokenPipeError("test disconnect")
|
||||||
|
|
||||||
|
|
||||||
|
def make_proxy_handler(wfile: object) -> DeepSeekProxyHandler:
|
||||||
|
handler = object.__new__(DeepSeekProxyHandler)
|
||||||
|
handler.server = SimpleNamespace(
|
||||||
|
config=ProxyConfig(),
|
||||||
|
reasoning_store=ReasoningStore(":memory:"),
|
||||||
|
)
|
||||||
|
handler.wfile = wfile
|
||||||
|
handler.close_connection = False
|
||||||
|
handler.send_response = lambda status: None
|
||||||
|
handler.send_header = lambda name, value: None
|
||||||
|
handler.end_headers = lambda: None
|
||||||
|
return handler
|
||||||
|
|
||||||
|
|
||||||
class ServerTests(unittest.TestCase):
|
class ServerTests(unittest.TestCase):
|
||||||
def test_read_response_body_handles_gzip(self) -> None:
|
def test_read_response_body_handles_gzip(self) -> None:
|
||||||
body = gzip.compress(b'{"ok":true}')
|
body = gzip.compress(b'{"ok":true}')
|
||||||
|
|
@ -47,6 +93,71 @@ class ServerTests(unittest.TestCase):
|
||||||
self.assertIn("tools=1", summary)
|
self.assertIn("tools=1", summary)
|
||||||
self.assertNotIn("secret prompt", summary)
|
self.assertNotIn("secret prompt", summary)
|
||||||
|
|
||||||
|
def test_regular_response_handles_client_disconnect(self) -> None:
|
||||||
|
handler = make_proxy_handler(BrokenPipeWfile())
|
||||||
|
body = json.dumps(
|
||||||
|
{
|
||||||
|
"id": "chatcmpl-test",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"model": "deepseek-v4-pro",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": "stop",
|
||||||
|
"message": {"role": "assistant", "content": "ok"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
).encode("utf-8")
|
||||||
|
|
||||||
|
try:
|
||||||
|
with self.assertLogs("deepseek_cursor_proxy", level="WARNING") as captured:
|
||||||
|
sent = handler._proxy_regular_response(
|
||||||
|
FakeResponse(body),
|
||||||
|
"deepseek-v4-pro",
|
||||||
|
[{"role": "user", "content": "hi"}],
|
||||||
|
"cache-namespace",
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
handler.server.reasoning_store.close()
|
||||||
|
|
||||||
|
self.assertFalse(sent)
|
||||||
|
self.assertIn("sending upstream response body", "\n".join(captured.output))
|
||||||
|
|
||||||
|
def test_streaming_response_stops_on_client_disconnect(self) -> None:
|
||||||
|
handler = make_proxy_handler(BrokenPipeWfile())
|
||||||
|
chunk = {
|
||||||
|
"id": "chatcmpl-stream",
|
||||||
|
"model": "deepseek-v4-pro",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"role": "assistant", "content": "hello"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
response = FakeStreamingResponse(
|
||||||
|
[
|
||||||
|
f"data: {json.dumps(chunk)}\n\n".encode("utf-8"),
|
||||||
|
b"data: [DONE]\n\n",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with self.assertLogs("deepseek_cursor_proxy", level="WARNING") as captured:
|
||||||
|
sent = handler._proxy_streaming_response(
|
||||||
|
response,
|
||||||
|
"deepseek-v4-pro",
|
||||||
|
[{"role": "user", "content": "hi"}],
|
||||||
|
"cache-namespace",
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
handler.server.reasoning_store.close()
|
||||||
|
|
||||||
|
self.assertFalse(sent)
|
||||||
|
self.assertEqual(response.readline_calls, 1)
|
||||||
|
self.assertIn("sending streaming response chunk", "\n".join(captured.output))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
||||||
|
|
@ -116,6 +116,75 @@ class StreamAccumulatorTests(unittest.TestCase):
|
||||||
self.assertEqual(accumulator.store_reasoning(store, scope), 0)
|
self.assertEqual(accumulator.store_reasoning(store, scope), 0)
|
||||||
store.close()
|
store.close()
|
||||||
|
|
||||||
|
def test_stores_tool_call_reasoning_before_finish_reason(self) -> None:
|
||||||
|
store = ReasoningStore(":memory:")
|
||||||
|
accumulator = StreamAccumulator()
|
||||||
|
accumulator.ingest_chunk(
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {
|
||||||
|
"role": "assistant",
|
||||||
|
"reasoning_content": "Need a tool.",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
accumulator.ingest_chunk(
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"id": "call_stream",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "lookup",
|
||||||
|
"arguments": '{"query"',
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
scope = conversation_scope([{"role": "user", "content": "lookup"}])
|
||||||
|
stored = accumulator.store_ready_reasoning(store, scope)
|
||||||
|
|
||||||
|
self.assertGreater(stored, 0)
|
||||||
|
self.assertEqual(
|
||||||
|
store.get(f"scope:{scope}:tool_call:call_stream"), "Need a tool."
|
||||||
|
)
|
||||||
|
|
||||||
|
accumulator.ingest_chunk(
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": "tool_calls",
|
||||||
|
"delta": {
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"function": {"arguments": ':"README"}'},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertGreater(accumulator.store_ready_reasoning(store, scope), 0)
|
||||||
|
store.close()
|
||||||
|
|
||||||
def test_stores_empty_reasoning_content_when_stream_field_is_present(
|
def test_stores_empty_reasoning_content_when_stream_field_is_present(
|
||||||
self,
|
self,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue