from __future__ import annotations from dataclasses import dataclass, field import hashlib import json import logging import re from typing import Any from .config import ProxyConfig from .reasoning_store import ( ReasoningStore, conversation_scope, message_signature, tool_call_ids, tool_call_signature, turn_context_signature, ) from .streaming import fold_reasoning_into_content LOG = logging.getLogger("deepseek_cursor_proxy") SUPPORTED_REQUEST_FIELDS = { "model", "messages", "stream", "stream_options", "max_tokens", "response_format", "stop", "tools", "tool_choice", "thinking", "reasoning_effort", "temperature", "top_p", "presence_penalty", "frequency_penalty", "logprobs", "top_logprobs", # Standard OpenAI Chat Completions fields that DeepSeek either honors or # safely ignores. Cursor and most OpenAI SDKs send these unconditionally, # so forwarding keeps clients happy and avoids log spam. "user", "seed", "n", "logit_bias", } MESSAGE_FIELDS = { "role", "content", "name", "tool_call_id", "tool_calls", "reasoning_content", "prefix", } ROLE_MESSAGE_FIELDS = { "system": {"role", "content", "name"}, "user": {"role", "content", "name"}, "assistant": { "role", "content", "name", "tool_calls", "reasoning_content", "prefix", }, "tool": {"role", "content", "tool_call_id"}, } EFFORT_ALIASES = { "low": "high", "medium": "high", "high": "high", "max": "max", "xhigh": "max", } CURSOR_THINKING_BLOCK_RE = re.compile( r""" (?: <(?:think|thinking)\b[^>]*>[\s\S]*?(?:|\Z) | ]*>\s* ]*>\s*Thinking\s* [\s\S]*?(?:|\Z) )\s* """, re.IGNORECASE | re.VERBOSE, ) RECOVERY_NOTICE_TEXT = "[deepseek-cursor-proxy] Refreshed reasoning_content history." RECOVERY_NOTICE_CONTENT = f"{RECOVERY_NOTICE_TEXT}\n\n" RECOVERY_SYSTEM_CONTENT = ( "deepseek-cursor-proxy recovered this request because older DeepSeek " "thinking-mode tool-call reasoning_content was unavailable. Older " "unrecoverable tool-call history was omitted; continue using only the " "remaining recovered context." 
) @dataclass(frozen=True) class PreparedRequest: payload: dict[str, Any] original_model: str upstream_model: str cache_namespace: str patched_reasoning_messages: int missing_reasoning_messages: int recovered_reasoning_messages: int = 0 recovery_dropped_messages: int = 0 recovery_notice: str | None = None record_response_scope: str | None = None record_response_messages: list[dict[str, Any]] = field(default_factory=list) record_response_contexts: list[tuple[str, list[dict[str, Any]]]] = field( default_factory=list ) reasoning_diagnostics: list[dict[str, Any]] = field(default_factory=list) recovery_steps: list[dict[str, Any]] = field(default_factory=list) continued_recovery_boundary: bool = False retired_prefix_messages: int = 0 def normalize_reasoning_effort(value: Any) -> str: if not isinstance(value, str): return "high" return EFFORT_ALIASES.get(value.strip().lower(), "high") def extract_text_content(content: Any) -> str | None: if content is None or isinstance(content, str): return content if isinstance(content, list): parts: list[str] = [] for item in content: if isinstance(item, str): parts.append(item) continue if not isinstance(item, dict): parts.append(str(item)) continue item_type = item.get("type") text = item.get("text") or item.get("content") if item_type in {"text", "input_text"} and isinstance(text, str): parts.append(text) elif isinstance(text, str): parts.append(text) elif item_type: parts.append(f"[{item_type} omitted by DeepSeek text proxy]") return "\n".join(part for part in parts if part) if isinstance(content, (dict, tuple)): return json.dumps(content, ensure_ascii=False, sort_keys=True) return str(content) def strip_cursor_thinking_blocks(content: str) -> str: return CURSOR_THINKING_BLOCK_RE.sub("", content).lstrip("\r\n") def normalize_tool_call(tool_call: Any) -> dict[str, Any]: if not isinstance(tool_call, dict): tool_call = {} function = tool_call.get("function") or {} if not isinstance(function, dict): function = {} arguments = 
function.get("arguments", "") if not isinstance(arguments, str): arguments = json.dumps(arguments, ensure_ascii=False, sort_keys=True) normalized: dict[str, Any] = { "id": str(tool_call.get("id") or ""), "type": tool_call.get("type") or "function", "function": { "name": str(function.get("name") or ""), "arguments": arguments, }, } if not normalized["id"]: normalized.pop("id") return normalized def normalize_tool(tool: Any) -> dict[str, Any]: if not isinstance(tool, dict): return { "type": "function", "function": {"name": "", "description": "", "parameters": {}}, } normalized = dict(tool) normalized["type"] = normalized.get("type") or "function" function = normalized.get("function") if isinstance(function, dict): normalized["function"] = function return normalized def legacy_function_to_tool(function: Any) -> dict[str, Any]: if not isinstance(function, dict): function = {} return {"type": "function", "function": function} def convert_function_call(function_call: Any) -> Any: if isinstance(function_call, str): if function_call in {"auto", "none", "required"}: return function_call return None if isinstance(function_call, dict) and function_call.get("name"): return { "type": "function", "function": {"name": str(function_call["name"])}, } return None def normalize_tool_choice(tool_choice: Any) -> Any: if isinstance(tool_choice, str): if tool_choice in {"auto", "none", "required"}: return tool_choice return None if isinstance(tool_choice, dict): if tool_choice.get("type") == "function": function = tool_choice.get("function") if isinstance(function, dict) and function.get("name"): return { "type": "function", "function": {"name": str(function["name"])}, } return tool_choice return tool_choice def normalize_message( message: Any, store: ReasoningStore | None, prior_messages: list[dict[str, Any]], cache_namespace: str, repair_reasoning: bool, keep_reasoning: bool, ) -> tuple[dict[str, Any], bool, bool, dict[str, Any] | None]: if not isinstance(message, dict): message = 
{"role": "user", "content": str(message)} normalized = {key: value for key, value in message.items() if key in MESSAGE_FIELDS} role = normalized.get("role") or "user" normalized["role"] = role if role == "function": normalized["role"] = "tool" if "content" in normalized: normalized["content"] = extract_text_content(normalized["content"]) or "" elif normalized["role"] in {"assistant", "tool", "system", "user"}: normalized["content"] = "" if normalized["role"] == "assistant" and isinstance(normalized.get("content"), str): normalized["content"] = strip_cursor_thinking_blocks(normalized["content"]) if normalized.get("tool_calls"): normalized["tool_calls"] = [ normalize_tool_call(tool_call) for tool_call in normalized.get("tool_calls") or [] ] patched = False missing = False diagnostic: dict[str, Any] | None = None if normalized["role"] == "assistant": if not keep_reasoning: normalized.pop("reasoning_content", None) elif repair_reasoning: reasoning = normalized.get("reasoning_content") if not isinstance(reasoning, str): normalized.pop("reasoning_content", None) needs_reasoning = assistant_needs_reasoning_for_tool_context( normalized, prior_messages ) lookup_scope = conversation_scope(prior_messages, cache_namespace) lookup_keys = ( reasoning_lookup_keys( normalized, lookup_scope, cache_namespace, prior_messages, ) if needs_reasoning else [] ) hit_kind = None if needs_reasoning and store is not None: for lookup_key in lookup_keys: restored = store.get(str(lookup_key["key"])) if restored is not None: lookup_key["hit"] = True hit_kind = lookup_key["kind"] normalized["reasoning_content"] = restored patched = True if not lookup_key.get("portable"): store.backfill_portable_aliases( normalized, restored, cache_namespace, prior_messages, ) break if needs_reasoning and not patched: missing = True if needs_reasoning: diagnostic = { "message_index": len(prior_messages), "role": "assistant", "needs_reasoning": True, "had_reasoning_content": False, "patched": patched, "missing": 
missing, "lookup_scope": lookup_scope, "message_signature": message_signature(normalized), "tool_call_ids": tool_call_ids(normalized), "lookup_keys": lookup_keys, "hit_kind": hit_kind, } elif assistant_needs_reasoning_for_tool_context(normalized, prior_messages): diagnostic = { "message_index": len(prior_messages), "role": "assistant", "needs_reasoning": True, "had_reasoning_content": True, "patched": False, "missing": False, "lookup_scope": conversation_scope(prior_messages, cache_namespace), "message_signature": message_signature(normalized), "tool_call_ids": tool_call_ids(normalized), "lookup_keys": [], "hit_kind": "request", } allowed_fields = ROLE_MESSAGE_FIELDS.get(str(normalized["role"]), MESSAGE_FIELDS) normalized = { key: value for key, value in normalized.items() if key in allowed_fields } return normalized, patched, missing, diagnostic def reasoning_lookup_keys( message: dict[str, Any], scope: str, cache_namespace: str = "", prior_messages: list[dict[str, Any]] | None = None, ) -> list[dict[str, Any]]: keys = [ { "kind": "message_signature", "key": f"scope:{scope}:signature:{message_signature(message)}", "portable": False, "hit": False, } ] keys.extend( { "kind": "tool_call_id", "tool_call_id": tool_call_id, "key": f"scope:{scope}:tool_call:{tool_call_id}", "portable": False, "hit": False, } for tool_call_id in tool_call_ids(message) ) keys.extend( { "kind": "tool_call_signature", "function_name": str((tool_call.get("function") or {}).get("name") or ""), "key": ( f"scope:{scope}:tool_call_signature:" f"{tool_call_signature(tool_call)}" ), "portable": False, "hit": False, } for tool_call in (message.get("tool_calls") or []) if isinstance(tool_call, dict) ) if cache_namespace and prior_messages is not None: turn_signature = turn_context_signature(prior_messages) keys.append( { "kind": "portable_message_signature", "key": ( f"namespace:{cache_namespace}:turn:{turn_signature}:" f"signature:{message_signature(message)}" ), "turn_context_signature": 
turn_signature, "portable": True, "hit": False, } ) keys.extend( { "kind": "portable_tool_call_id", "tool_call_id": tool_call_id, "key": ( f"namespace:{cache_namespace}:turn:{turn_signature}:" f"tool_call:{tool_call_id}" ), "turn_context_signature": turn_signature, "portable": True, "hit": False, } for tool_call_id in tool_call_ids(message) ) keys.extend( { "kind": "portable_tool_call_signature", "function_name": str( (tool_call.get("function") or {}).get("name") or "" ), "key": ( f"namespace:{cache_namespace}:turn:{turn_signature}:" f"tool_call_signature:{tool_call_signature(tool_call)}" ), "turn_context_signature": turn_signature, "portable": True, "hit": False, } for tool_call in (message.get("tool_calls") or []) if isinstance(tool_call, dict) ) return keys def normalize_messages( messages: Any, store: ReasoningStore | None, cache_namespace: str, repair_reasoning: bool, keep_reasoning: bool, ) -> tuple[list[dict[str, Any]], int, list[int], list[dict[str, Any]]]: if not isinstance(messages, list): return [], 0, [], [] normalized_messages: list[dict[str, Any]] = [] patched_count = 0 missing_indexes: list[int] = [] diagnostics: list[dict[str, Any]] = [] for message in messages: normalized, patched, missing, diagnostic = normalize_message( message, store, normalized_messages, cache_namespace, repair_reasoning, keep_reasoning, ) normalized_messages.append(normalized) if patched: patched_count += 1 if missing: missing_indexes.append(len(normalized_messages) - 1) if diagnostic is not None: diagnostics.append(diagnostic) return normalized_messages, patched_count, missing_indexes, diagnostics def has_recovery_notice(message: dict[str, Any]) -> bool: content = message.get("content") return ( message.get("role") == "assistant" and isinstance(content, str) and content.startswith(RECOVERY_NOTICE_TEXT) ) def strip_recovery_notice_for_upstream( messages: list[dict[str, Any]], ) -> list[dict[str, Any]]: """Cursor echoes the proxy's recovery notice back to us in later turns. 
The notice serves as a boundary marker for the proxy, but DeepSeek must not see proxy-generated prose. Return a copy with assistant prefixes stripped; leave the input untouched so cache scopes/recording contexts keep matching the with-prefix history that Cursor will send next time.""" stripped: list[dict[str, Any]] = [] for message in messages: if message.get("role") != "assistant": stripped.append(message) continue content = message.get("content") if not isinstance(content, str) or not content.startswith(RECOVERY_NOTICE_TEXT): stripped.append(message) continue cleaned = dict(message) cleaned["content"] = content[len(RECOVERY_NOTICE_TEXT) :].lstrip("\r\n") stripped.append(cleaned) return stripped def leading_system_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: leading_messages: list[dict[str, Any]] = [] for message in messages: if message.get("role") == "system": leading_messages.append(message) continue break return leading_messages def active_messages_from_recovery_boundary( messages: list[dict[str, Any]], ) -> tuple[list[dict[str, Any]], int, dict[str, Any]] | None: recovery_boundary_index = next( ( index for index in range(len(messages) - 1, -1, -1) if has_recovery_notice(messages[index]) ), -1, ) if recovery_boundary_index == -1: return None context_user_index = next( ( index for index in range(recovery_boundary_index - 1, -1, -1) if messages[index].get("role") == "user" ), -1, ) leading_messages = leading_system_messages(messages) recovered_tail = [] if context_user_index != -1: recovered_tail.append(messages[context_user_index]) recovered_tail.extend(messages[recovery_boundary_index:]) active_messages = [ *leading_messages, {"role": "system", "content": RECOVERY_SYSTEM_CONTENT}, *recovered_tail, ] kept_context_messages = 1 if context_user_index != -1 else 0 retired_messages = ( recovery_boundary_index - len(leading_messages) - kept_context_messages ) retired_messages = max(retired_messages, 0) step = { "strategy": 
"continued_recovery_boundary", "recovery_boundary_index": recovery_boundary_index, "context_user_index": context_user_index, "retired_prefix_messages": retired_messages, } return active_messages, retired_messages, step def recover_messages_from_missing_reasoning( messages: list[dict[str, Any]], missing_indexes: list[int], ) -> tuple[list[dict[str, Any]], int, str | None, dict[str, Any]]: recovery_boundary_index = next( ( index for index in range(len(messages) - 1, -1, -1) if has_recovery_notice(messages[index]) and any(missing_index < index for missing_index in missing_indexes) ), -1, ) if recovery_boundary_index != -1: context_user_index = next( ( index for index in range(recovery_boundary_index - 1, -1, -1) if messages[index].get("role") == "user" ), -1, ) leading_messages = leading_system_messages(messages) recovered_tail = [] if context_user_index != -1: recovered_tail.append(messages[context_user_index]) recovered_tail.extend(messages[recovery_boundary_index:]) recovered = [ *leading_messages, {"role": "system", "content": RECOVERY_SYSTEM_CONTENT}, *recovered_tail, ] kept_context_messages = 1 if context_user_index != -1 else 0 omitted_messages = ( recovery_boundary_index - len(leading_messages) - kept_context_messages ) return ( recovered, omitted_messages, None, { "strategy": "recovery_boundary", "missing_indexes": missing_indexes, "recovery_boundary_index": recovery_boundary_index, "context_user_index": context_user_index, "dropped_messages": omitted_messages, "notice": None, }, ) last_user_index = next( ( index for index in range(len(messages) - 1, -1, -1) if messages[index].get("role") == "user" ), -1, ) if last_user_index == -1: return ( messages, 0, None, { "strategy": "none", "missing_indexes": missing_indexes, "last_user_index": None, "dropped_messages": 0, "notice": None, }, ) recovered = leading_system_messages(messages) omitted_messages = len(messages) - len(recovered) - 1 recovered.append({"role": "system", "content": RECOVERY_SYSTEM_CONTENT}) 
recovered.append(messages[last_user_index]) return ( recovered, omitted_messages, RECOVERY_NOTICE_CONTENT, { "strategy": "latest_user", "missing_indexes": missing_indexes, "last_user_index": last_user_index, "dropped_messages": omitted_messages, "notice": RECOVERY_NOTICE_CONTENT, }, ) def assistant_needs_reasoning_for_tool_context( message: dict[str, Any], prior_messages: list[dict[str, Any]], ) -> bool: if message.get("tool_calls"): return True for prior_message in reversed(prior_messages): role = prior_message.get("role") if role == "tool": return True if role in {"user", "system"}: return False return False def upstream_model_for(original_model: str, config: ProxyConfig) -> str: if original_model.startswith("deepseek-"): return original_model LOG.warning( "rewriting non-DeepSeek model %r to configured fallback %r", original_model, config.upstream_model, ) return config.upstream_model def reasoning_model_family(upstream_model: str) -> str: if upstream_model in {"deepseek-v4-pro", "deepseek-v4-flash"}: return "deepseek-v4" return upstream_model def reasoning_cache_namespace( config: ProxyConfig, upstream_model: str, thinking: Any, reasoning_effort: Any, authorization: str | None = None, ) -> str: auth_hash = "" if authorization: auth_hash = hashlib.sha256(authorization.encode("utf-8")).hexdigest() payload = { "base_url": config.upstream_base_url, "model": reasoning_model_family(upstream_model), "thinking": thinking, "reasoning_effort": reasoning_effort, "authorization_hash": auth_hash, } canonical = json.dumps( payload, ensure_ascii=False, sort_keys=True, separators=(",", ":") ) return hashlib.sha256(canonical.encode("utf-8")).hexdigest() def response_recording_contexts( *items: tuple[str, list[dict[str, Any]]] | None, ) -> list[tuple[str, list[dict[str, Any]]]]: contexts: list[tuple[str, list[dict[str, Any]]]] = [] seen: set[str] = set() for item in items: if item is None: continue scope, messages = item if scope in seen: continue seen.add(scope) 
contexts.append((scope, messages)) return contexts def prepare_upstream_request( payload: dict[str, Any], config: ProxyConfig, store: ReasoningStore | None, authorization: str | None = None, ) -> PreparedRequest: original_model = str(payload.get("model") or config.upstream_model) upstream_model = upstream_model_for(original_model, config) prepared = { key: value for key, value in payload.items() if key in SUPPORTED_REQUEST_FIELDS } dropped_fields = sorted( key for key in payload.keys() if key not in SUPPORTED_REQUEST_FIELDS and key not in {"max_completion_tokens", "functions", "function_call"} ) if dropped_fields: LOG.warning( "dropping unsupported request field(s): %s", ", ".join(dropped_fields) ) if "max_tokens" not in prepared and "max_completion_tokens" in payload: prepared["max_tokens"] = payload["max_completion_tokens"] prepared["model"] = upstream_model if prepared.get("stream"): stream_options = prepared.get("stream_options") if not isinstance(stream_options, dict): stream_options = {} else: stream_options = dict(stream_options) stream_options["include_usage"] = True prepared["stream_options"] = stream_options if "tools" in prepared and isinstance(prepared["tools"], list): prepared["tools"] = [normalize_tool(tool) for tool in prepared["tools"]] elif isinstance(payload.get("functions"), list): prepared["tools"] = [ legacy_function_to_tool(function) for function in payload["functions"] ] if "tool_choice" in prepared: tool_choice = normalize_tool_choice(prepared["tool_choice"]) if tool_choice is None: prepared.pop("tool_choice", None) else: prepared["tool_choice"] = tool_choice elif "function_call" in payload: tool_choice = convert_function_call(payload.get("function_call")) if tool_choice is not None: prepared["tool_choice"] = tool_choice prepared["thinking"] = {"type": config.thinking} thinking_enabled = config.thinking == "enabled" thinking_disabled = config.thinking == "disabled" if thinking_enabled: prepared["reasoning_effort"] = 
normalize_reasoning_effort( prepared.get("reasoning_effort") or config.reasoning_effort ) cache_namespace = reasoning_cache_namespace( config, upstream_model, prepared.get("thinking"), prepared.get("reasoning_effort"), authorization, ) pre_repair_messages, _, _, _ = normalize_messages( payload.get("messages"), None, cache_namespace, repair_reasoning=False, keep_reasoning=not thinking_disabled, ) record_response_messages = pre_repair_messages record_response_scope = conversation_scope( record_response_messages, cache_namespace ) messages_for_repair = pre_repair_messages continued_recovery_boundary = False retired_prefix_messages = 0 recovered_count = 0 recovery_dropped_messages = 0 recovery_notice = None recovery_steps: list[dict[str, Any]] = [] if thinking_enabled and config.missing_reasoning_strategy == "recover": boundary = active_messages_from_recovery_boundary(pre_repair_messages) if boundary is not None: messages_for_repair, retired_prefix_messages, boundary_step = boundary continued_recovery_boundary = True recovery_steps.append(boundary_step) messages, patched_count, missing_indexes, reasoning_diagnostics = ( normalize_messages( messages_for_repair, store, cache_namespace, repair_reasoning=thinking_enabled, keep_reasoning=not thinking_disabled, ) ) while missing_indexes and config.missing_reasoning_strategy == "recover": recovered_messages, dropped_messages, notice, recovery_step = ( recover_messages_from_missing_reasoning(messages, missing_indexes) ) recovery_steps.append(recovery_step) if not dropped_messages: break recovered_count += len(missing_indexes) recovery_dropped_messages += dropped_messages if notice: recovery_notice = notice ( messages, patched_count, missing_indexes, latest_diagnostics, ) = normalize_messages( recovered_messages, store, cache_namespace, repair_reasoning=thinking_enabled, keep_reasoning=not thinking_disabled, ) reasoning_diagnostics.extend(latest_diagnostics) active_record_response_scope = conversation_scope(messages, 
cache_namespace) record_response_contexts = response_recording_contexts( (record_response_scope, record_response_messages), (active_record_response_scope, messages), ) prepared["messages"] = strip_recovery_notice_for_upstream(messages) return PreparedRequest( payload=prepared, original_model=original_model, upstream_model=upstream_model, cache_namespace=cache_namespace, patched_reasoning_messages=patched_count, missing_reasoning_messages=len(missing_indexes), recovered_reasoning_messages=recovered_count, recovery_dropped_messages=recovery_dropped_messages, recovery_notice=recovery_notice, record_response_scope=record_response_scope, record_response_messages=record_response_messages, record_response_contexts=record_response_contexts, reasoning_diagnostics=reasoning_diagnostics, recovery_steps=recovery_steps, continued_recovery_boundary=continued_recovery_boundary, retired_prefix_messages=retired_prefix_messages, ) def record_response_reasoning( response_payload: dict[str, Any], store: ReasoningStore | None, request_messages: list[dict[str, Any]], cache_namespace: str = "", scope: str | None = None, prior_messages: list[dict[str, Any]] | None = None, recording_contexts: list[tuple[str, list[dict[str, Any]]]] | None = None, ) -> int: if store is None: return 0 stored = 0 choices = response_payload.get("choices") if not isinstance(choices, list): return stored if recording_contexts is None: response_scope = ( scope if scope is not None else conversation_scope(request_messages, cache_namespace) ) response_prior_messages = ( prior_messages if prior_messages is not None else request_messages ) recording_contexts = [(response_scope, response_prior_messages)] for choice in choices: if not isinstance(choice, dict): continue message = choice.get("message") if isinstance(message, dict): for response_scope, response_prior_messages in recording_contexts: stored += store.store_assistant_message( message, response_scope, cache_namespace, response_prior_messages, ) return stored 
def rewrite_response_body(
    body: bytes,
    original_model: str,
    store: ReasoningStore | None,
    request_messages: list[dict[str, Any]],
    cache_namespace: str = "",
    content_prefix: str | None = None,
    scope: str | None = None,
    prior_messages: list[dict[str, Any]] | None = None,
    recording_contexts: list[tuple[str, list[dict[str, Any]]]] | None = None,
    display_reasoning: bool = False,
    collapsible_reasoning: bool = True,
) -> bytes:
    """Rewrite a non-streaming upstream response body for the client.

    The body is parsed as UTF-8 JSON. For a JSON object, in this order: the
    optional ``content_prefix`` is prepended to the first choice's content,
    the response reasoning is recorded, reasoning is folded into the content
    when ``display_reasoning`` is set, and an existing ``model`` field is
    replaced with ``original_model``. Any other JSON value is passed through.

    Returns:
        The compact, UTF-8 encoded JSON re-serialization.
    """

    def serialize(value: Any) -> bytes:
        compact = json.dumps(value, ensure_ascii=False, separators=(",", ":"))
        return compact.encode("utf-8")

    decoded = json.loads(body.decode("utf-8"))
    if not isinstance(decoded, dict):
        return serialize(decoded)

    if content_prefix:
        prefix_response_content(decoded, content_prefix)
    # Recording runs after prefixing, so the prefixed content is what is stored.
    record_response_reasoning(
        decoded,
        store,
        request_messages,
        cache_namespace,
        scope=scope,
        prior_messages=prior_messages,
        recording_contexts=recording_contexts,
    )
    if display_reasoning:
        fold_reasoning_into_content(decoded, collapsible_reasoning)
    if "model" in decoded:
        decoded["model"] = original_model
    return serialize(decoded)


def prefix_response_content(response_payload: dict[str, Any], prefix: str) -> bool:
    """Prepend ``prefix`` to the content of the first choice with a message.

    Only the first choice whose ``message`` is a dict is modified, in place;
    non-string content is replaced by ``prefix`` alone.

    Returns:
        ``True`` if a message was modified, ``False`` otherwise.
    """
    choices = response_payload.get("choices")
    if not isinstance(choices, list):
        return False
    target = next(
        (
            choice["message"]
            for choice in choices
            if isinstance(choice, dict) and isinstance(choice.get("message"), dict)
        ),
        None,
    )
    if target is None:
        return False
    existing = target.get("content")
    target["content"] = prefix + existing if isinstance(existing, str) else prefix
    return True