84 lines
2.7 KiB
Python
84 lines
2.7 KiB
Python
from pydantic import BaseModel, Field, computed_field
|
|
from typing import Sequence, Optional
|
|
import re
|
|
|
|
class TaskPlanContext(BaseModel):
    '''Background context relevant to the task plan'''

    # The docstring above is left verbatim on purpose: pydantic publishes it as
    # this model's JSON-schema "description".

    # Free-text background; must be a non-empty string.
    initial_context: str = Field(
        description="Background information about the database/tables and previous conversations relevant to the task.",
        min_length=1,
    )

    # Assumptions the plan relies on; no length bound, so an empty sequence is valid.
    assumptions: Sequence[str] = Field(description="Assumptions made while working on the task.")

    # No default is given, so (pydantic v2) the key must be present, though None is accepted.
    constraints: Optional[Sequence[str]] = Field(description="Constraints that apply to the task.")
|
|
|
|
class TaskPlanResponse(BaseModel):
    '''Structured plan to achieve the task objective'''

    # The docstring above is left verbatim on purpose: pydantic publishes it as
    # this model's JSON-schema "description".

    # Non-empty one-line objective.
    goal: str = Field(description="Single-sentence objective the plan must achieve.", min_length=1)

    # Non-empty reasoning summary about the request and the approach.
    reflection: str = Field(
        description="High-level natural-language reasoning describing the user's request and the intended solution approach.",
        min_length=1,
    )

    # Nested background model, validated with its own field rules.
    context: TaskPlanContext = Field(description="Background context relevant to the task plan.")

    # At least one step is required. The "Step <number>: <detail>" format is only
    # requested through the description; nothing here validates it.
    steps: Sequence[str] = Field(
        description="Ordered list of steps to execute that follow the 'Step <number>: <detail>' pattern.",
        min_length=1,
    )
|
|
|
|
_IM_SEP_TOKEN_PATTERN = re.compile(re.escape("<|im_sep|>"))
|
|
_CODE_BLOCK_PATTERN = re.compile(r"```(?:python\s*)?(.*?)\s*```", re.DOTALL)
|
|
_FORBIDDEN_MODULES = (
|
|
"subprocess",
|
|
"sys",
|
|
"eval",
|
|
"exec",
|
|
"socket",
|
|
"urllib",
|
|
"shutil",
|
|
"pickle",
|
|
"ctypes",
|
|
"multiprocessing",
|
|
"tempfile",
|
|
"glob",
|
|
"pty",
|
|
"commands",
|
|
"cgi",
|
|
"cgitb",
|
|
"xml.etree.ElementTree",
|
|
"builtins",
|
|
)
|
|
_FORBIDDEN_MODULE_PATTERN = re.compile(
|
|
r"^((?:[^#].*)?\b(" + "|".join(map(re.escape, _FORBIDDEN_MODULES)) + r")\b.*)$",
|
|
flags=re.MULTILINE,
|
|
)
|
|
|
|
class CodeGenerationResponse(BaseModel):
    '''Code generation response structure'''

    # Docstrings in this class are left verbatim: pydantic publishes them as
    # schema descriptions (class and computed field).

    # Raw generated text; parsed_code tolerates ``` fences and "<|im_sep|>" tokens around it.
    code: str = Field(description="The generated code snippet to accomplish the task")
    # Free-text explanation returned alongside the code.
    explanation: str = Field(description="Explanation of the generated code and its functionality")

    @computed_field(return_type=str)
    @property
    def parsed_code(self) -> str:
        '''Extracts the code snippet without any surrounding text'''
        # Turn "<|im_sep|>" tokens into ``` so they are handled like ordinary fences.
        text = _IM_SEP_TOKEN_PATTERN.sub("```", self.code).strip()

        # Prefer the body of the first fenced block; fall back to the whole text.
        fenced = _CODE_BLOCK_PATTERN.search(text)
        if fenced is not None:
            text = fenced.group(1).strip()

        # Comment out every line that mentions a denylisted name, then trim.
        return _FORBIDDEN_MODULE_PATTERN.sub(r"# not allowed \1", text).strip()
|
|
|
|
class RankResponse(BaseModel):
    '''Code ranking response structure'''

    # The docstring above is left verbatim: pydantic publishes it as the
    # JSON-schema "description".

    # Inclusive range 1..10 enforced at validation time; smaller is better.
    rank: int = Field(
        description="Rank of the code snippet from 1 (best) to 10 (worst)",
        le=10,
        ge=1,
    )
|