deepseek-cursor-proxy/src/deepseek_cursor_proxy/transform.py

957 lines
32 KiB
Python

from __future__ import annotations
from dataclasses import dataclass, field
import hashlib
import json
import logging
import re
from typing import Any
from .config import ProxyConfig
from .reasoning_store import (
ReasoningStore,
conversation_scope,
message_signature,
tool_call_ids,
tool_call_signature,
turn_context_signature,
)
from .streaming import fold_reasoning_into_content
# Logger shared by the proxy; named explicitly rather than via __name__.
LOG = logging.getLogger("deepseek_cursor_proxy")
# Top-level request keys that prepare_upstream_request copies verbatim into
# the upstream payload. Any other key is dropped; a warning is logged unless
# it is one of the legacy/alias keys translated separately
# (max_completion_tokens, functions, function_call).
SUPPORTED_REQUEST_FIELDS = {
    "model",
    "messages",
    "stream",
    "stream_options",
    "max_tokens",
    "response_format",
    "stop",
    "tools",
    "tool_choice",
    "thinking",
    "reasoning_effort",
    "temperature",
    "top_p",
    "presence_penalty",
    "frequency_penalty",
    "logprobs",
    "top_logprobs",
    # Standard OpenAI Chat Completions fields that DeepSeek either honors or
    # safely ignores. Cursor and most OpenAI SDKs send these unconditionally,
    # so forwarding keeps clients happy and avoids log spam.
    "user",
    "seed",
    "n",
    "logit_bias",
}
# Union of all per-message keys that may survive normalisation. Used as the
# first filter in normalize_message and as the fallback allow-list for roles
# missing from ROLE_MESSAGE_FIELDS.
MESSAGE_FIELDS = {
    "role",
    "content",
    "name",
    "tool_call_id",
    "tool_calls",
    "reasoning_content",
    "prefix",
}
# Per-role allow-list applied as the last step of normalize_message. Only
# assistant messages keep tool_calls / reasoning_content / prefix; only tool
# messages keep tool_call_id.
ROLE_MESSAGE_FIELDS = {
    "system": {"role", "content", "name"},
    "user": {"role", "content", "name"},
    "assistant": {
        "role",
        "content",
        "name",
        "tool_calls",
        "reasoning_content",
        "prefix",
    },
    "tool": {"role", "content", "tool_call_id"},
}
# Maps client reasoning_effort values (case-insensitive after stripping) onto
# the two values this proxy forwards: "low"/"medium"/"high" -> "high" and
# "max"/"xhigh" -> "max". normalize_reasoning_effort falls back to "high" for
# anything not listed here.
EFFORT_ALIASES = {
    "low": "high",
    "medium": "high",
    "high": "high",
    "max": "max",
    "xhigh": "max",
}
# Matches reasoning blocks embedded in assistant message content so that
# strip_cursor_thinking_blocks can remove them before history goes upstream.
# Two shapes are recognised: <think>/<thinking> tags, and a <details> element
# whose <summary> is "Thinking". An unterminated block matches through the end
# of the text (\Z), and whitespace following a block is consumed with it.
# VERBOSE makes the newlines in the pattern layout insignificant.
CURSOR_THINKING_BLOCK_RE = re.compile(
    r"""
(?:
<(?:think|thinking)\b[^>]*>[\s\S]*?(?:</(?:think|thinking)>|\Z)
|
<details\b[^>]*>\s*
<summary\b[^>]*>\s*Thinking\s*</summary>
[\s\S]*?(?:</details>|\Z)
)\s*
""",
    re.IGNORECASE | re.VERBOSE,
)
# Marker text of the recovery notice. has_recovery_notice looks for it at the
# start of assistant content to locate a recovery boundary in echoed history,
# and strip_recovery_notice_for_upstream removes it before sending upstream.
RECOVERY_NOTICE_TEXT = "[deepseek-cursor-proxy] Refreshed reasoning_content history."
# Notice as returned by the "latest_user" recovery strategy (marker followed
# by a blank line). NOTE(review): presumably prefixed onto the response via
# prefix_response_content by the caller -- confirm.
RECOVERY_NOTICE_CONTENT = f"{RECOVERY_NOTICE_TEXT}\n\n"
# System message inserted in place of history that recovery omitted.
RECOVERY_SYSTEM_CONTENT = (
    "deepseek-cursor-proxy recovered this request because older DeepSeek "
    "thinking-mode tool-call reasoning_content was unavailable. Older "
    "unrecoverable tool-call history was omitted; continue using only the "
    "remaining recovered context."
)
@dataclass(frozen=True)
class PreparedRequest:
    """Result of prepare_upstream_request: the upstream payload plus the
    bookkeeping produced while repairing reasoning and recovering history."""

    # Request body to send upstream.
    payload: dict[str, Any]
    # Model name from the client request (or config.upstream_model if absent).
    original_model: str
    # Model name actually placed in the upstream payload.
    upstream_model: str
    # Hash produced by reasoning_cache_namespace for this request.
    cache_namespace: str
    # Assistant messages whose reasoning_content was restored from the store
    # (final normalisation pass).
    patched_reasoning_messages: int
    # Assistant messages still lacking required reasoning after recovery.
    missing_reasoning_messages: int
    # Sum of missing-reasoning counts across all recovery iterations.
    recovered_reasoning_messages: int = 0
    # Total history messages dropped by recovery.
    recovery_dropped_messages: int = 0
    # Notice text to show the client, or None when no notice is needed.
    recovery_notice: str | None = None
    # Conversation scope of the client's full (pre-repair) history.
    record_response_scope: str | None = None
    # Client's full history, normalised without reasoning repair.
    record_response_messages: list[dict[str, Any]] = field(default_factory=list)
    # Distinct (scope, messages) pairs under which the response's reasoning
    # should be recorded.
    record_response_contexts: list[tuple[str, list[dict[str, Any]]]] = field(
        default_factory=list
    )
    # Per-assistant-message lookup diagnostics from normalize_message.
    reasoning_diagnostics: list[dict[str, Any]] = field(default_factory=list)
    # One entry per recovery decision taken.
    recovery_steps: list[dict[str, Any]] = field(default_factory=list)
    # True when an earlier recovery notice in the history was used as a
    # boundary to trim the request.
    continued_recovery_boundary: bool = False
    # Number of history messages retired by that boundary.
    retired_prefix_messages: int = 0
def normalize_reasoning_effort(value: Any) -> str:
if not isinstance(value, str):
return "high"
return EFFORT_ALIASES.get(value.strip().lower(), "high")
def extract_text_content(content: Any) -> str | None:
if content is None or isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
continue
if not isinstance(item, dict):
parts.append(str(item))
continue
item_type = item.get("type")
text = item.get("text") or item.get("content")
if item_type in {"text", "input_text"} and isinstance(text, str):
parts.append(text)
elif isinstance(text, str):
parts.append(text)
elif item_type:
parts.append(f"[{item_type} omitted by DeepSeek text proxy]")
return "\n".join(part for part in parts if part)
if isinstance(content, (dict, tuple)):
return json.dumps(content, ensure_ascii=False, sort_keys=True)
return str(content)
def strip_cursor_thinking_blocks(content: str) -> str:
    """Remove every block matched by ``CURSOR_THINKING_BLOCK_RE`` from
    *content*, then trim leading carriage returns and newlines."""
    without_blocks = CURSOR_THINKING_BLOCK_RE.sub("", content)
    return without_blocks.lstrip("\r\n")
def normalize_tool_call(tool_call: Any) -> dict[str, Any]:
if not isinstance(tool_call, dict):
tool_call = {}
function = tool_call.get("function") or {}
if not isinstance(function, dict):
function = {}
arguments = function.get("arguments", "")
if not isinstance(arguments, str):
arguments = json.dumps(arguments, ensure_ascii=False, sort_keys=True)
normalized: dict[str, Any] = {
"id": str(tool_call.get("id") or ""),
"type": tool_call.get("type") or "function",
"function": {
"name": str(function.get("name") or ""),
"arguments": arguments,
},
}
if not normalized["id"]:
normalized.pop("id")
return normalized
def normalize_tool(tool: Any) -> dict[str, Any]:
if not isinstance(tool, dict):
return {
"type": "function",
"function": {"name": "", "description": "", "parameters": {}},
}
normalized = dict(tool)
normalized["type"] = normalized.get("type") or "function"
function = normalized.get("function")
if isinstance(function, dict):
normalized["function"] = function
return normalized
def legacy_function_to_tool(function: Any) -> dict[str, Any]:
if not isinstance(function, dict):
function = {}
return {"type": "function", "function": function}
def convert_function_call(function_call: Any) -> Any:
if isinstance(function_call, str):
if function_call in {"auto", "none", "required"}:
return function_call
return None
if isinstance(function_call, dict) and function_call.get("name"):
return {
"type": "function",
"function": {"name": str(function_call["name"])},
}
return None
def normalize_tool_choice(tool_choice: Any) -> Any:
if isinstance(tool_choice, str):
if tool_choice in {"auto", "none", "required"}:
return tool_choice
return None
if isinstance(tool_choice, dict):
if tool_choice.get("type") == "function":
function = tool_choice.get("function")
if isinstance(function, dict) and function.get("name"):
return {
"type": "function",
"function": {"name": str(function["name"])},
}
return tool_choice
return tool_choice
def normalize_message(
    message: Any,
    store: ReasoningStore | None,
    prior_messages: list[dict[str, Any]],
    cache_namespace: str,
    repair_reasoning: bool,
    keep_reasoning: bool,
) -> tuple[dict[str, Any], bool, bool, dict[str, Any] | None]:
    """Normalise one chat message and optionally restore its reasoning_content.

    Args:
        message: Raw message from the client. Non-dict values are wrapped as
            a user message whose content is ``str(message)``.
        store: Reasoning cache queried for missing reasoning_content, or None
            to skip lookups.
        prior_messages: Already-normalised messages that precede this one.
            Used to decide whether reasoning is required and to derive the
            lookup scope and keys.
        cache_namespace: Namespace hash mixed into lookup scopes and keys.
        repair_reasoning: When True (and keep_reasoning is True), an assistant
            message without string reasoning_content is looked up in store.
        keep_reasoning: When False, reasoning_content is removed from
            assistant messages.

    Returns:
        ``(normalized, patched, missing, diagnostic)``:
        ``patched`` is True when reasoning_content was restored from store;
        ``missing`` is True when reasoning was required but not found;
        ``diagnostic`` describes the lookup for assistant messages that need
        reasoning on the repair path, otherwise None.
    """
    if not isinstance(message, dict):
        message = {"role": "user", "content": str(message)}
    # First keep any key valid for some role; the stricter per-role
    # allow-list is applied just before returning.
    normalized = {key: value for key, value in message.items() if key in MESSAGE_FIELDS}
    role = normalized.get("role") or "user"
    normalized["role"] = role
    # Legacy "function" messages are forwarded with the "tool" role.
    if role == "function":
        normalized["role"] = "tool"
    # Flatten structured content to text; standard roles always get a string.
    if "content" in normalized:
        normalized["content"] = extract_text_content(normalized["content"]) or ""
    elif normalized["role"] in {"assistant", "tool", "system", "user"}:
        normalized["content"] = ""
    # Remove thinking blocks that were rendered into earlier assistant text.
    if normalized["role"] == "assistant" and isinstance(normalized.get("content"), str):
        normalized["content"] = strip_cursor_thinking_blocks(normalized["content"])
    if normalized.get("tool_calls"):
        normalized["tool_calls"] = [
            normalize_tool_call(tool_call)
            for tool_call in normalized.get("tool_calls") or []
        ]
    patched = False
    missing = False
    diagnostic: dict[str, Any] | None = None
    if normalized["role"] == "assistant":
        if not keep_reasoning:
            # Reasoning is not forwarded at all: drop any client-supplied copy.
            normalized.pop("reasoning_content", None)
        elif repair_reasoning:
            reasoning = normalized.get("reasoning_content")
            if not isinstance(reasoning, str):
                # No usable reasoning on the message. Discard whatever was
                # there and, if this turn belongs to a tool-call exchange,
                # try to restore it from the cache.
                normalized.pop("reasoning_content", None)
                needs_reasoning = assistant_needs_reasoning_for_tool_context(
                    normalized, prior_messages
                )
                lookup_scope = conversation_scope(prior_messages, cache_namespace)
                lookup_keys = (
                    reasoning_lookup_keys(
                        normalized,
                        lookup_scope,
                        cache_namespace,
                        prior_messages,
                    )
                    if needs_reasoning
                    else []
                )
                hit_kind = None
                if needs_reasoning and store is not None:
                    # Keys are tried in priority order; the first hit wins.
                    for lookup_key in lookup_keys:
                        restored = store.get(str(lookup_key["key"]))
                        if restored is not None:
                            lookup_key["hit"] = True
                            hit_kind = lookup_key["kind"]
                            normalized["reasoning_content"] = restored
                            patched = True
                            # On a scope-local (non-portable) hit, let the store
                            # backfill its portable aliases (presumably so later
                            # lookups with a different scope also hit -- see
                            # ReasoningStore.backfill_portable_aliases).
                            if not lookup_key.get("portable"):
                                store.backfill_portable_aliases(
                                    normalized,
                                    restored,
                                    cache_namespace,
                                    prior_messages,
                                )
                            break
                if needs_reasoning and not patched:
                    missing = True
                if needs_reasoning:
                    diagnostic = {
                        "message_index": len(prior_messages),
                        "role": "assistant",
                        "needs_reasoning": True,
                        "had_reasoning_content": False,
                        "patched": patched,
                        "missing": missing,
                        "lookup_scope": lookup_scope,
                        "message_signature": message_signature(normalized),
                        "tool_call_ids": tool_call_ids(normalized),
                        "lookup_keys": lookup_keys,
                        "hit_kind": hit_kind,
                    }
            elif assistant_needs_reasoning_for_tool_context(normalized, prior_messages):
                # The client already supplied reasoning; record that for the
                # diagnostics (hit_kind "request", no store lookups).
                diagnostic = {
                    "message_index": len(prior_messages),
                    "role": "assistant",
                    "needs_reasoning": True,
                    "had_reasoning_content": True,
                    "patched": False,
                    "missing": False,
                    "lookup_scope": conversation_scope(prior_messages, cache_namespace),
                    "message_signature": message_signature(normalized),
                    "tool_call_ids": tool_call_ids(normalized),
                    "lookup_keys": [],
                    "hit_kind": "request",
                }
    # Apply the per-role allow-list; unknown roles fall back to MESSAGE_FIELDS.
    allowed_fields = ROLE_MESSAGE_FIELDS.get(str(normalized["role"]), MESSAGE_FIELDS)
    normalized = {
        key: value for key, value in normalized.items() if key in allowed_fields
    }
    return normalized, patched, missing, diagnostic
def reasoning_lookup_keys(
    message: dict[str, Any],
    scope: str,
    cache_namespace: str = "",
    prior_messages: list[dict[str, Any]] | None = None,
) -> list[dict[str, Any]]:
    """Build the ordered list of cache keys to try for *message*'s reasoning.

    Scope-local keys come first: the message signature, then one key per
    tool-call id, then one per tool-call signature (dict tool calls only).
    When ``cache_namespace`` is non-empty and ``prior_messages`` is given,
    the same three kinds are appended as "portable" keys. Those are tied to
    the namespace and the turn-context signature of ``prior_messages``
    instead of ``scope``.

    Each entry has at least ``kind``, ``key``, ``portable`` and ``hit``
    (initially False). normalize_message tries entries in list order and
    sets ``hit`` on the first one found.
    """
    # Scope-local key for the whole message.
    keys = [
        {
            "kind": "message_signature",
            "key": f"scope:{scope}:signature:{message_signature(message)}",
            "portable": False,
            "hit": False,
        }
    ]
    # Scope-local keys per tool-call id.
    keys.extend(
        {
            "kind": "tool_call_id",
            "tool_call_id": tool_call_id,
            "key": f"scope:{scope}:tool_call:{tool_call_id}",
            "portable": False,
            "hit": False,
        }
        for tool_call_id in tool_call_ids(message)
    )
    # Scope-local keys per tool-call signature.
    keys.extend(
        {
            "kind": "tool_call_signature",
            "function_name": str((tool_call.get("function") or {}).get("name") or ""),
            "key": (
                f"scope:{scope}:tool_call_signature:"
                f"{tool_call_signature(tool_call)}"
            ),
            "portable": False,
            "hit": False,
        }
        for tool_call in (message.get("tool_calls") or [])
        if isinstance(tool_call, dict)
    )
    if cache_namespace and prior_messages is not None:
        # Portable keys: same three kinds, keyed by namespace + turn context.
        turn_signature = turn_context_signature(prior_messages)
        keys.append(
            {
                "kind": "portable_message_signature",
                "key": (
                    f"namespace:{cache_namespace}:turn:{turn_signature}:"
                    f"signature:{message_signature(message)}"
                ),
                "turn_context_signature": turn_signature,
                "portable": True,
                "hit": False,
            }
        )
        keys.extend(
            {
                "kind": "portable_tool_call_id",
                "tool_call_id": tool_call_id,
                "key": (
                    f"namespace:{cache_namespace}:turn:{turn_signature}:"
                    f"tool_call:{tool_call_id}"
                ),
                "turn_context_signature": turn_signature,
                "portable": True,
                "hit": False,
            }
            for tool_call_id in tool_call_ids(message)
        )
        keys.extend(
            {
                "kind": "portable_tool_call_signature",
                "function_name": str(
                    (tool_call.get("function") or {}).get("name") or ""
                ),
                "key": (
                    f"namespace:{cache_namespace}:turn:{turn_signature}:"
                    f"tool_call_signature:{tool_call_signature(tool_call)}"
                ),
                "turn_context_signature": turn_signature,
                "portable": True,
                "hit": False,
            }
            for tool_call in (message.get("tool_calls") or [])
            if isinstance(tool_call, dict)
        )
    return keys
def normalize_messages(
messages: Any,
store: ReasoningStore | None,
cache_namespace: str,
repair_reasoning: bool,
keep_reasoning: bool,
) -> tuple[list[dict[str, Any]], int, list[int], list[dict[str, Any]]]:
if not isinstance(messages, list):
return [], 0, [], []
normalized_messages: list[dict[str, Any]] = []
patched_count = 0
missing_indexes: list[int] = []
diagnostics: list[dict[str, Any]] = []
for message in messages:
normalized, patched, missing, diagnostic = normalize_message(
message,
store,
normalized_messages,
cache_namespace,
repair_reasoning,
keep_reasoning,
)
normalized_messages.append(normalized)
if patched:
patched_count += 1
if missing:
missing_indexes.append(len(normalized_messages) - 1)
if diagnostic is not None:
diagnostics.append(diagnostic)
return normalized_messages, patched_count, missing_indexes, diagnostics
def has_recovery_notice(message: dict[str, Any]) -> bool:
content = message.get("content")
return (
message.get("role") == "assistant"
and isinstance(content, str)
and content.startswith(RECOVERY_NOTICE_TEXT)
)
def strip_recovery_notice_for_upstream(
messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Cursor echoes the proxy's recovery notice back to us in later turns.
The notice serves as a boundary marker for the proxy, but DeepSeek must
not see proxy-generated prose. Return a copy with assistant prefixes
stripped; leave the input untouched so cache scopes/recording contexts
keep matching the with-prefix history that Cursor will send next time."""
stripped: list[dict[str, Any]] = []
for message in messages:
if message.get("role") != "assistant":
stripped.append(message)
continue
content = message.get("content")
if not isinstance(content, str) or not content.startswith(RECOVERY_NOTICE_TEXT):
stripped.append(message)
continue
cleaned = dict(message)
cleaned["content"] = content[len(RECOVERY_NOTICE_TEXT) :].lstrip("\r\n")
stripped.append(cleaned)
return stripped
def leading_system_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
leading_messages: list[dict[str, Any]] = []
for message in messages:
if message.get("role") == "system":
leading_messages.append(message)
continue
break
return leading_messages
def active_messages_from_recovery_boundary(
messages: list[dict[str, Any]],
) -> tuple[list[dict[str, Any]], int, dict[str, Any]] | None:
recovery_boundary_index = next(
(
index
for index in range(len(messages) - 1, -1, -1)
if has_recovery_notice(messages[index])
),
-1,
)
if recovery_boundary_index == -1:
return None
context_user_index = next(
(
index
for index in range(recovery_boundary_index - 1, -1, -1)
if messages[index].get("role") == "user"
),
-1,
)
leading_messages = leading_system_messages(messages)
recovered_tail = []
if context_user_index != -1:
recovered_tail.append(messages[context_user_index])
recovered_tail.extend(messages[recovery_boundary_index:])
active_messages = [
*leading_messages,
{"role": "system", "content": RECOVERY_SYSTEM_CONTENT},
*recovered_tail,
]
kept_context_messages = 1 if context_user_index != -1 else 0
retired_messages = (
recovery_boundary_index - len(leading_messages) - kept_context_messages
)
retired_messages = max(retired_messages, 0)
step = {
"strategy": "continued_recovery_boundary",
"recovery_boundary_index": recovery_boundary_index,
"context_user_index": context_user_index,
"retired_prefix_messages": retired_messages,
}
return active_messages, retired_messages, step
def recover_messages_from_missing_reasoning(
messages: list[dict[str, Any]],
missing_indexes: list[int],
) -> tuple[list[dict[str, Any]], int, str | None, dict[str, Any]]:
recovery_boundary_index = next(
(
index
for index in range(len(messages) - 1, -1, -1)
if has_recovery_notice(messages[index])
and any(missing_index < index for missing_index in missing_indexes)
),
-1,
)
if recovery_boundary_index != -1:
context_user_index = next(
(
index
for index in range(recovery_boundary_index - 1, -1, -1)
if messages[index].get("role") == "user"
),
-1,
)
leading_messages = leading_system_messages(messages)
recovered_tail = []
if context_user_index != -1:
recovered_tail.append(messages[context_user_index])
recovered_tail.extend(messages[recovery_boundary_index:])
recovered = [
*leading_messages,
{"role": "system", "content": RECOVERY_SYSTEM_CONTENT},
*recovered_tail,
]
kept_context_messages = 1 if context_user_index != -1 else 0
omitted_messages = (
recovery_boundary_index - len(leading_messages) - kept_context_messages
)
return (
recovered,
omitted_messages,
None,
{
"strategy": "recovery_boundary",
"missing_indexes": missing_indexes,
"recovery_boundary_index": recovery_boundary_index,
"context_user_index": context_user_index,
"dropped_messages": omitted_messages,
"notice": None,
},
)
last_user_index = next(
(
index
for index in range(len(messages) - 1, -1, -1)
if messages[index].get("role") == "user"
),
-1,
)
if last_user_index == -1:
return (
messages,
0,
None,
{
"strategy": "none",
"missing_indexes": missing_indexes,
"last_user_index": None,
"dropped_messages": 0,
"notice": None,
},
)
recovered = leading_system_messages(messages)
omitted_messages = len(messages) - len(recovered) - 1
recovered.append({"role": "system", "content": RECOVERY_SYSTEM_CONTENT})
recovered.append(messages[last_user_index])
return (
recovered,
omitted_messages,
RECOVERY_NOTICE_CONTENT,
{
"strategy": "latest_user",
"missing_indexes": missing_indexes,
"last_user_index": last_user_index,
"dropped_messages": omitted_messages,
"notice": RECOVERY_NOTICE_CONTENT,
},
)
def assistant_needs_reasoning_for_tool_context(
message: dict[str, Any],
prior_messages: list[dict[str, Any]],
) -> bool:
if message.get("tool_calls"):
return True
for prior_message in reversed(prior_messages):
role = prior_message.get("role")
if role == "tool":
return True
if role in {"user", "system"}:
return False
return False
def upstream_model_for(original_model: str, config: ProxyConfig) -> str:
if original_model.startswith("deepseek-"):
return original_model
LOG.warning(
"rewriting non-DeepSeek model %r to configured fallback %r",
original_model,
config.upstream_model,
)
return config.upstream_model
def reasoning_model_family(upstream_model: str) -> str:
    """Return the model identifier used in the reasoning cache namespace.

    The v4 "pro" and "flash" variants share one family name; every other
    model is its own family.
    """
    v4_variants = ("deepseek-v4-pro", "deepseek-v4-flash")
    return "deepseek-v4" if upstream_model in v4_variants else upstream_model
def reasoning_cache_namespace(
    config: ProxyConfig,
    upstream_model: str,
    thinking: Any,
    reasoning_effort: Any,
    authorization: str | None = None,
) -> str:
    """Derive a stable hex digest that partitions the reasoning cache.

    The digest covers the upstream base URL, the model family, the thinking
    and reasoning-effort settings, and a SHA-256 of the Authorization value
    (an empty string when none is given). The raw credential is never part of
    the hashed document.
    """

    def _sha256_hex(text: str) -> str:
        return hashlib.sha256(text.encode("utf-8")).hexdigest()

    fingerprint = {
        "authorization_hash": _sha256_hex(authorization) if authorization else "",
        "base_url": config.upstream_base_url,
        "model": reasoning_model_family(upstream_model),
        "reasoning_effort": reasoning_effort,
        "thinking": thinking,
    }
    canonical = json.dumps(
        fingerprint, ensure_ascii=False, sort_keys=True, separators=(",", ":")
    )
    return _sha256_hex(canonical)
def response_recording_contexts(
*items: tuple[str, list[dict[str, Any]]] | None,
) -> list[tuple[str, list[dict[str, Any]]]]:
contexts: list[tuple[str, list[dict[str, Any]]]] = []
seen: set[str] = set()
for item in items:
if item is None:
continue
scope, messages = item
if scope in seen:
continue
seen.add(scope)
contexts.append((scope, messages))
return contexts
def prepare_upstream_request(
    payload: dict[str, Any],
    config: ProxyConfig,
    store: ReasoningStore | None,
    authorization: str | None = None,
) -> PreparedRequest:
    """Translate a client Chat Completions request into an upstream request.

    Steps:
      1. Copy supported top-level fields, warn about dropped ones, and map the
         legacy aliases (max_completion_tokens, functions, function_call).
      2. Set the upstream model, the thinking mode from config, and (when
         thinking is enabled) a normalised reasoning_effort.
      3. Normalise messages and restore reasoning_content for tool-call turns
         from ``store``. With the "recover" strategy, first trim history at
         an earlier recovery boundary, then keep dropping history while
         reasoning is still missing.
      4. Strip the proxy's recovery notice from assistant content before the
         messages go upstream.

    Args:
        payload: Decoded JSON request body from the client.
        config: Proxy configuration (fallback model, thinking mode, reasoning
            effort, missing-reasoning strategy, ...).
        store: Reasoning cache, or None to disable restoration.
        authorization: Client Authorization value; it only enters the cache
            namespace through its SHA-256 hash.

    Returns:
        PreparedRequest with the upstream payload plus repair/recovery
        bookkeeping.
    """
    original_model = str(payload.get("model") or config.upstream_model)
    upstream_model = upstream_model_for(original_model, config)
    prepared = {
        key: value for key, value in payload.items() if key in SUPPORTED_REQUEST_FIELDS
    }
    # Legacy/alias fields are translated below, so they are not reported.
    dropped_fields = sorted(
        key
        for key in payload.keys()
        if key not in SUPPORTED_REQUEST_FIELDS
        and key not in {"max_completion_tokens", "functions", "function_call"}
    )
    if dropped_fields:
        LOG.warning(
            "dropping unsupported request field(s): %s", ", ".join(dropped_fields)
        )
    if "max_tokens" not in prepared and "max_completion_tokens" in payload:
        prepared["max_tokens"] = payload["max_completion_tokens"]
    prepared["model"] = upstream_model
    if prepared.get("stream"):
        stream_options = prepared.get("stream_options")
        if not isinstance(stream_options, dict):
            stream_options = {}
        else:
            stream_options = dict(stream_options)
        # Always ask for the final usage chunk when streaming.
        stream_options["include_usage"] = True
        prepared["stream_options"] = stream_options
    else:
        # stream_options is only valid together with stream=true in the
        # OpenAI-compatible API; never forward it on non-streaming requests.
        prepared.pop("stream_options", None)
    if "tools" in prepared and isinstance(prepared["tools"], list):
        prepared["tools"] = [normalize_tool(tool) for tool in prepared["tools"]]
    elif isinstance(payload.get("functions"), list):
        prepared["tools"] = [
            legacy_function_to_tool(function) for function in payload["functions"]
        ]
    if "tool_choice" in prepared:
        tool_choice = normalize_tool_choice(prepared["tool_choice"])
        if tool_choice is None:
            prepared.pop("tool_choice", None)
        else:
            prepared["tool_choice"] = tool_choice
    elif "function_call" in payload:
        tool_choice = convert_function_call(payload.get("function_call"))
        if tool_choice is not None:
            prepared["tool_choice"] = tool_choice
    # The thinking mode comes from config and overrides any client value.
    prepared["thinking"] = {"type": config.thinking}
    thinking_enabled = config.thinking == "enabled"
    thinking_disabled = config.thinking == "disabled"
    if thinking_enabled:
        prepared["reasoning_effort"] = normalize_reasoning_effort(
            prepared.get("reasoning_effort") or config.reasoning_effort
        )
    cache_namespace = reasoning_cache_namespace(
        config,
        upstream_model,
        prepared.get("thinking"),
        prepared.get("reasoning_effort"),
        authorization,
    )
    # First pass without repair: the history as the client sent it. The
    # response is recorded against this history too.
    pre_repair_messages, _, _, _ = normalize_messages(
        payload.get("messages"),
        None,
        cache_namespace,
        repair_reasoning=False,
        keep_reasoning=not thinking_disabled,
    )
    record_response_messages = pre_repair_messages
    record_response_scope = conversation_scope(
        record_response_messages, cache_namespace
    )
    messages_for_repair = pre_repair_messages
    continued_recovery_boundary = False
    retired_prefix_messages = 0
    recovered_count = 0
    recovery_dropped_messages = 0
    recovery_notice = None
    recovery_steps: list[dict[str, Any]] = []
    if thinking_enabled and config.missing_reasoning_strategy == "recover":
        # If an earlier turn already recovered, only send history from that
        # boundary on.
        boundary = active_messages_from_recovery_boundary(pre_repair_messages)
        if boundary is not None:
            messages_for_repair, retired_prefix_messages, boundary_step = boundary
            continued_recovery_boundary = True
            recovery_steps.append(boundary_step)
    messages, patched_count, missing_indexes, reasoning_diagnostics = (
        normalize_messages(
            messages_for_repair,
            store,
            cache_namespace,
            repair_reasoning=thinking_enabled,
            keep_reasoning=not thinking_disabled,
        )
    )
    # Each iteration either drops at least one message or stops, so the loop
    # terminates.
    while missing_indexes and config.missing_reasoning_strategy == "recover":
        recovered_messages, dropped_messages, notice, recovery_step = (
            recover_messages_from_missing_reasoning(messages, missing_indexes)
        )
        recovery_steps.append(recovery_step)
        if not dropped_messages:
            break
        recovered_count += len(missing_indexes)
        recovery_dropped_messages += dropped_messages
        if notice:
            recovery_notice = notice
        (
            messages,
            patched_count,
            missing_indexes,
            latest_diagnostics,
        ) = normalize_messages(
            recovered_messages,
            store,
            cache_namespace,
            repair_reasoning=thinking_enabled,
            keep_reasoning=not thinking_disabled,
        )
        reasoning_diagnostics.extend(latest_diagnostics)
    active_record_response_scope = conversation_scope(messages, cache_namespace)
    record_response_contexts = response_recording_contexts(
        (record_response_scope, record_response_messages),
        (active_record_response_scope, messages),
    )
    prepared["messages"] = strip_recovery_notice_for_upstream(messages)
    return PreparedRequest(
        payload=prepared,
        original_model=original_model,
        upstream_model=upstream_model,
        cache_namespace=cache_namespace,
        patched_reasoning_messages=patched_count,
        missing_reasoning_messages=len(missing_indexes),
        recovered_reasoning_messages=recovered_count,
        recovery_dropped_messages=recovery_dropped_messages,
        recovery_notice=recovery_notice,
        record_response_scope=record_response_scope,
        record_response_messages=record_response_messages,
        record_response_contexts=record_response_contexts,
        reasoning_diagnostics=reasoning_diagnostics,
        recovery_steps=recovery_steps,
        continued_recovery_boundary=continued_recovery_boundary,
        retired_prefix_messages=retired_prefix_messages,
    )
def record_response_reasoning(
response_payload: dict[str, Any],
store: ReasoningStore | None,
request_messages: list[dict[str, Any]],
cache_namespace: str = "",
scope: str | None = None,
prior_messages: list[dict[str, Any]] | None = None,
recording_contexts: list[tuple[str, list[dict[str, Any]]]] | None = None,
) -> int:
if store is None:
return 0
stored = 0
choices = response_payload.get("choices")
if not isinstance(choices, list):
return stored
if recording_contexts is None:
response_scope = (
scope
if scope is not None
else conversation_scope(request_messages, cache_namespace)
)
response_prior_messages = (
prior_messages if prior_messages is not None else request_messages
)
recording_contexts = [(response_scope, response_prior_messages)]
for choice in choices:
if not isinstance(choice, dict):
continue
message = choice.get("message")
if isinstance(message, dict):
for response_scope, response_prior_messages in recording_contexts:
stored += store.store_assistant_message(
message,
response_scope,
cache_namespace,
response_prior_messages,
)
return stored
def rewrite_response_body(
body: bytes,
original_model: str,
store: ReasoningStore | None,
request_messages: list[dict[str, Any]],
cache_namespace: str = "",
content_prefix: str | None = None,
scope: str | None = None,
prior_messages: list[dict[str, Any]] | None = None,
recording_contexts: list[tuple[str, list[dict[str, Any]]]] | None = None,
display_reasoning: bool = False,
collapsible_reasoning: bool = True,
) -> bytes:
response_payload = json.loads(body.decode("utf-8"))
if isinstance(response_payload, dict):
if content_prefix:
prefix_response_content(response_payload, content_prefix)
record_response_reasoning(
response_payload,
store,
request_messages,
cache_namespace,
scope=scope,
prior_messages=prior_messages,
recording_contexts=recording_contexts,
)
if display_reasoning:
fold_reasoning_into_content(response_payload, collapsible_reasoning)
if "model" in response_payload:
response_payload["model"] = original_model
return json.dumps(
response_payload, ensure_ascii=False, separators=(",", ":")
).encode("utf-8")
def prefix_response_content(response_payload: dict[str, Any], prefix: str) -> bool:
choices = response_payload.get("choices")
if not isinstance(choices, list):
return False
for choice in choices:
if not isinstance(choice, dict):
continue
message = choice.get("message")
if not isinstance(message, dict):
continue
content = message.get("content")
message["content"] = prefix + (content if isinstance(content, str) else "")
return True
return False