from pydantic import BaseModel, Field, computed_field
from typing import Sequence, Optional
import re

# NOTE: pydantic uses class docstrings and `description=` strings as JSON-schema
# descriptions, so those texts are part of what the model is prompted with —
# edit them deliberately. Maintainer notes live in `#` comments instead.


class TaskPlanContext(BaseModel):
    """Background context relevant to the task plan"""

    initial_context: str = Field(
        min_length=1,
        description="Background information about the database/tables and previous conversations relevant to the task.",
    )
    assumptions: Sequence[str] = Field(
        description="Assumptions made while working on the task.",
    )
    # NOTE(review): there is no default, so pydantic v2 treats this field as
    # required-but-nullable (the key must be present; None is accepted).
    # Confirm that is intended rather than `default=None`.
    constraints: Optional[Sequence[str]] = Field(
        description="Constraints that apply to the task.",
    )


class TaskPlanResponse(BaseModel):
    """Structured plan to achieve the task objective"""

    goal: str = Field(
        min_length=1,
        description="Single-sentence objective the plan must achieve.",
    )
    reflection: str = Field(
        min_length=1,
        description="High-level natural-language reasoning describing the user's request and the intended solution approach.",
    )
    context: TaskPlanContext = Field(
        description="Background context relevant to the task plan.",
    )
    # At least one step is required (min_length=1).
    # NOTE(review): the pattern text reads 'Step : '; it may have lost a
    # placeholder (e.g. a step number) somewhere upstream — confirm.
    steps: Sequence[str] = Field(
        min_length=1,
        description="Ordered list of steps to execute that follow the 'Step : ' pattern.",
    )


# The "<|im_sep|>" token is rewritten to a Markdown fence before parsing,
# presumably because some models emit it in place of one.
_IM_SEP_TOKEN_PATTERN = re.compile(re.escape("<|im_sep|>"))

# Optional language tag right after an opening fence. `python`, `python3` and
# `py` (any case) are always dropped, even when code follows on the same line
# (e.g. "```python print(1)```"). Any other single-word info string is dropped
# only when it is alone on the fence line, so code written directly after a
# bare fence is not swallowed.
_FENCE_LANG_TAG = r"(?:(?:python3?|py)\b[ \t]*|[\w+-]*[ \t]*(?=\r?\n))?"

# First complete fenced block; group 1 is its body (trailing whitespace
# before the closing fence excluded). DOTALL lets the body span lines.
_CODE_BLOCK_PATTERN = re.compile(
    r"```" + _FENCE_LANG_TAG + r"(.*?)\s*```",
    flags=re.DOTALL | re.IGNORECASE,
)

# An opening fence (plus tag) at the very start of the text; used only when no
# complete block was found, e.g. a truncated reply that never closes its fence.
_OPEN_FENCE_PATTERN = re.compile(r"\A```" + _FENCE_LANG_TAG, flags=re.IGNORECASE)

# Names whose appearance in generated code causes the line to be commented
# out. Some entries (`eval`, `exec`) are builtins rather than modules.
_FORBIDDEN_MODULES = (
    "subprocess",
    "sys",
    "eval",
    "exec",
    "socket",
    "urllib",
    "shutil",
    "pickle",
    "ctypes",
    "multiprocessing",
    "tempfile",
    "glob",
    "pty",
    "commands",
    "cgi",
    "cgitb",
    "xml.etree.ElementTree",
    "builtins",
)

# Matches a whole line that mentions a forbidden name as a whole word (case
# sensitive) anywhere — in imports, attribute access, identifiers or strings.
# Lines whose first character is "#" are skipped; indented comments are not.
# Group 1 is the full line, so the substitution can prefix it with a comment.
#
# NOTE(review): this is a best-effort textual filter, not a sandbox: `os`,
# `importlib` and `__import__` are not listed, and `\b` never matches inside
# `__builtins__` (underscore is a word character). Commenting out an indented
# line can also leave an empty block and an IndentationError. Do not rely on it
# for isolation when executing generated code.
_FORBIDDEN_MODULE_PATTERN = re.compile(
    r"^((?:[^#].*)?\b(" + "|".join(map(re.escape, _FORBIDDEN_MODULES)) + r")\b.*)$",
    flags=re.MULTILINE,
)


def _strip_unterminated_fence(text: str) -> str:
    """Remove a stray leading and/or trailing Markdown fence from *text*.

    Used when *text* contains no complete fenced block. A leading fence loses
    its language tag too; text without any fence is returned stripped.
    """
    text = _OPEN_FENCE_PATTERN.sub("", text, count=1)
    if text.endswith("```"):
        text = text[: -len("```")]
    return text.strip()


class CodeGenerationResponse(BaseModel):
    """Code generation response structure"""

    code: str = Field(description="The generated code snippet to accomplish the task")
    explanation: str = Field(description="Explanation of the generated code and its functionality")

    # Exposed as a computed field, so it is included when the model is
    # serialized. Recomputed on each access.
    @computed_field(return_type=str)
    @property
    def parsed_code(self) -> str:
        """Extracts the code snippet without any
        surrounding text"""
        normalised = _IM_SEP_TOKEN_PATTERN.sub("```", self.code).strip()
        match = _CODE_BLOCK_PATTERN.search(normalised)
        if match:
            # Only the first fenced block is used.
            candidate = match.group(1).strip()
        else:
            # No complete block: fall back to the whole text, minus any
            # half-open fence that would otherwise be a syntax error.
            candidate = _strip_unterminated_fence(normalised)
        sanitised = _FORBIDDEN_MODULE_PATTERN.sub(r"# not allowed \1", candidate)
        return sanitised.strip()


class RankResponse(BaseModel):
    """Code ranking response structure"""

    # Validated to the inclusive range 1..10.
    rank: int = Field(
        ge=1, le=10, description="Rank of the code snippet from 1 (best) to 10 (worst)"
    )