Refactor: Move backend files to backend/ directory and split .gitignore

This commit is contained in:
Yunxiao Xu
2026-02-11 17:40:44 -08:00
parent 48924affa0
commit 7a69133e26
96 changed files with 144 additions and 176 deletions

View File

@@ -0,0 +1,12 @@
# Public surface of the ea_chatbot.utils package.
from .db_client import DBClient
from .llm_factory import get_llm_model
from .logging import get_logger, LangChainLoggingHandler
from . import helpers

# Names re-exported via "from ea_chatbot.utils import *".
__all__ = [
    "DBClient",
    "get_llm_model",
    "get_logger",
    "LangChainLoggingHandler",
    "helpers"
]

View File

@@ -0,0 +1,234 @@
from typing import Optional, Dict, Any, List, TYPE_CHECKING
import yaml
import json
import os
from ea_chatbot.utils.db_client import DBClient
if TYPE_CHECKING:
from ea_chatbot.types import DBSettings
def _get_table_checksum(db_client: DBClient, table: str) -> str:
"""Calculates the checksum of the table using DML statistics from pg_stat_user_tables."""
query = f"""
SELECT md5(concat_ws('|', n_tup_ins, n_tup_upd, n_tup_del)) AS dml_hash
FROM pg_stat_user_tables
WHERE schemaname = 'public' AND relname = '{table}';"""
try:
return str(db_client.query_df(query).iloc[0, 0])
except Exception:
return "unknown_checksum"
def _update_checksum_file(filepath: str, table: str, checksum: str):
"""Updates the checksum file with the new checksum for the table."""
checksums = {}
if os.path.exists(filepath):
with open(filepath, 'r') as f:
for line in f:
if ':' in line:
k, v = line.strip().split(':', 1)
checksums[k] = v
checksums[table] = checksum
with open(filepath, 'w') as f:
for k, v in checksums.items():
f.write(f"{k}:{v}")
def get_data_summary(data_dir: str = "data") -> Optional[str]:
    """Return the contents of ``<data_dir>/inspection.yaml``, or None if absent."""
    summary_path = os.path.join(data_dir, "inspection.yaml")
    if not os.path.exists(summary_path):
        return None
    with open(summary_path, 'r') as f:
        return f.read()
def get_primary_key(db_client: DBClient, table_name: str) -> Optional[str]:
    """
    Dynamically identifies the primary key of the table.

    Args:
        db_client: Connected database client.
        table_name: Table to look up in ``information_schema``.

    Returns:
        The column name of the primary key, or None if not found or on
        query failure (a warning is printed in that case).
    """
    # Bind the table name as a parameter instead of f-string interpolation —
    # avoids SQL injection and quoting problems with unusual table names.
    query = """
    SELECT kcu.column_name
    FROM information_schema.key_column_usage AS kcu
    JOIN information_schema.table_constraints AS tc
        ON kcu.constraint_name = tc.constraint_name
        AND kcu.table_schema = tc.table_schema
    WHERE kcu.table_name = :table_name
        AND tc.constraint_type = 'PRIMARY KEY'
    LIMIT 1;
    """
    try:
        df = db_client.query_df(query, {"table_name": table_name})
        if not df.empty:
            return str(df.iloc[0, 0])
    except Exception as e:
        print(f"Warning: Could not determine primary key for {table_name}: {e}")
    return None
def inspect_db_table(
    db_client: Optional[DBClient]=None,
    db_settings: Optional["DBSettings"]=None,
    data_dir: str = "data",
    force_update: bool = False
) -> Optional[str]:
    """
    Inspects the database table, generates statistics for each column,
    and saves the inspection results to a YAML file locally.
    Improvements:
    - Dynamic Primary Key Discovery
    - Cardinality (Unique Counts)
    - Categorical Sample Values for low cardinality columns
    - Robust Quoting

    Args:
        db_client: Existing client; if None, one is built from db_settings.
        db_settings: Connection settings used only when db_client is None.
        data_dir: Directory holding "inspection.yaml" and the "checksum" file.
        force_update: Regenerate even when the table checksum is unchanged.

    Returns:
        The YAML inspection content as a string, or None on any failure
        (failures are reported via print, not raised).

    NOTE(review): table and column names are interpolated directly into
    SQL below — assumed to come from trusted configuration, not user input.
    """
    inspection_file = os.path.join(data_dir, "inspection.yaml")
    checksum_file = os.path.join(data_dir, "checksum")
    # Initialize DB Client
    if db_client is None:
        if db_settings is None:
            print("Error: Either db_client or db_settings must be provided.")
            return None
        try:
            db_client = DBClient(db_settings)
        except Exception as e:
            print(f"Failed to create DBClient: {e}")
            return None
    # The target table comes from the client's settings, not a parameter.
    table_name = db_client.settings.get('table')
    if not table_name:
        print("Error: Table name must be specified in DBSettings.")
        return None
    os.makedirs(data_dir, exist_ok=True)
    # Checksum verification: skip regeneration when the saved checksum
    # still matches the table's current DML-stats checksum.
    new_checksum = _get_table_checksum(db_client, table_name)
    has_changed = True
    if os.path.exists(checksum_file):
        try:
            with open(checksum_file, 'r') as f:
                saved_checksums = f.read().strip()
            if f"{table_name}:{new_checksum}" in saved_checksums:
                has_changed = False
        except Exception:
            pass # Force update on read error
    if not has_changed and not force_update:
        # Nothing changed: return the cached summary instead of re-querying.
        return get_data_summary(data_dir)
    print(f"Regenerating inspection file for table '{table_name}'...")
    # Fetch Table Metadata
    try:
        # Get columns and types
        columns_query = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';"
        columns_df = db_client.query_df(columns_query)
        # Get Row Counts
        total_rows_df = db_client.query_df(f'SELECT COUNT(*) FROM "{table_name}"')
        total_rows = int(total_rows_df.iloc[0, 0])
        # Dynamic Primary Key
        primary_key = get_primary_key(db_client, table_name)
        # Get First/Last Rows (if PK exists)
        first_row_df = None
        last_row_df = None
        if primary_key:
            first_row_df = db_client.query_df(f'SELECT * FROM "{table_name}" ORDER BY "{primary_key}" ASC LIMIT 1')
            last_row_df = db_client.query_df(f'SELECT * FROM "{table_name}" ORDER BY "{primary_key}" DESC LIMIT 1')
    except Exception as e:
        print(f"Failed to retrieve basic table info: {e}")
        return None
    # Per-column statistics keyed by column name (plus 'primary_key').
    stats_dict: Dict[str, Any] = {}
    if primary_key:
        stats_dict['primary_key'] = primary_key
    for _, row in columns_df.iterrows():
        col_name = row['column_name']
        dtype = row['data_type']
        try:
            # Count Values
            # Using robust quoting
            count_df = db_client.query_df(f'SELECT COUNT("{col_name}") FROM "{table_name}"')
            count_val = int(count_df.iloc[0,0])
            # Count Unique (Cardinality)
            unique_df = db_client.query_df(f'SELECT COUNT(DISTINCT "{col_name}") FROM "{table_name}"')
            unique_count = int(unique_df.iloc[0,0])
            col_stats: Dict[str, Any] = {
                'dtype': dtype,
                'count_of_values': count_val,
                'count_of_nulls': total_rows - count_val,
                'unique_count': unique_count
            }
            if count_val == 0:
                # All-null column: no further stats are meaningful.
                stats_dict[col_name] = col_stats
                continue
            # Numerical Stats (dtype substring match on common numeric types)
            if any(t in dtype for t in ('int', 'float', 'numeric', 'double', 'real', 'decimal')):
                stats_query = f'SELECT AVG("{col_name}"), MIN("{col_name}"), MAX("{col_name}") FROM "{table_name}"'
                stats_df = db_client.query_df(stats_query)
                if not stats_df.empty:
                    col_stats['mean'] = float(stats_df.iloc[0,0]) if stats_df.iloc[0,0] is not None else None
                    col_stats['min'] = float(stats_df.iloc[0,1]) if stats_df.iloc[0,1] is not None else None
                    col_stats['max'] = float(stats_df.iloc[0,2]) if stats_df.iloc[0,2] is not None else None
            # Temporal Stats
            elif any(t in dtype for t in ('date', 'timestamp')):
                stats_query = f'SELECT MIN("{col_name}"), MAX("{col_name}") FROM "{table_name}"'
                stats_df = db_client.query_df(stats_query)
                if not stats_df.empty:
                    col_stats['min'] = str(stats_df.iloc[0,0])
                    col_stats['max'] = str(stats_df.iloc[0,1])
            # Categorical/Text Stats
            else:
                # Sample values if cardinality is low (< 20)
                if 0 < unique_count < 20:
                    distinct_query = f'SELECT DISTINCT "{col_name}" FROM "{table_name}" ORDER BY "{col_name}" LIMIT 20'
                    distinct_df = db_client.query_df(distinct_query)
                    col_stats['distinct_values'] = distinct_df.iloc[:, 0].tolist()
            # First/last values by primary-key order, when a PK was found.
            if first_row_df is not None and not first_row_df.empty and col_name in first_row_df.columns:
                col_stats['first_value'] = str(first_row_df.iloc[0][col_name])
            if last_row_df is not None and not last_row_df.empty and col_name in last_row_df.columns:
                col_stats['last_value'] = str(last_row_df.iloc[0][col_name])
            stats_dict[col_name] = col_stats
        except Exception as e:
            # Best-effort per column: a failing column is skipped, not fatal.
            print(f"Warning: Could not process column {col_name}: {e}")
    # Load existing inspections to merge (if multiple tables)
    existing_inspections = {}
    if os.path.exists(inspection_file):
        try:
            with open(inspection_file, 'r') as f:
                existing_inspections = yaml.safe_load(f) or {}
            # Backup old file
            os.rename(inspection_file, inspection_file + ".old")
        except Exception:
            pass
    existing_inspections[table_name] = stats_dict
    # Save new inspection
    inspection_content = yaml.dump(existing_inspections, sort_keys=False, default_flow_style=False)
    with open(inspection_file, 'w') as f:
        f.write(inspection_content)
    # Update Checksum
    _update_checksum_file(checksum_file, table_name, new_checksum)
    print(f"Inspection saved to {inspection_file}")
    return inspection_content

View File

@@ -0,0 +1,21 @@
from typing import Optional, TYPE_CHECKING
import pandas as pd
from sqlalchemy import create_engine, text
if TYPE_CHECKING:
from ea_chatbot.types import DBSettings
class DBClient:
    """Thin wrapper around a SQLAlchemy engine for a PostgreSQL database."""

    def __init__(self, settings: "DBSettings"):
        # Settings mapping is expected to provide: user, pswd, host, port, db
        # (and 'table' for consumers that query a specific table).
        self.settings = settings
        self._engine = self._create_engine()

    def _create_engine(self):
        """Build the SQLAlchemy engine from the stored settings."""
        # Bug fix: URL-encode the credentials. Special characters such as
        # '@', ':' or '/' in the user or password would otherwise corrupt
        # the connection URL.
        from urllib.parse import quote_plus
        user = quote_plus(str(self.settings['user']))
        pswd = quote_plus(str(self.settings['pswd']))
        url = f"postgresql://{user}:{pswd}@{self.settings['host']}:{self.settings['port']}/{self.settings['db']}"
        return create_engine(url)

    def query_df(self, sql: str, params: Optional[dict] = None) -> pd.DataFrame:
        """Execute *sql* (with optional named :params) and return a DataFrame."""
        with self._engine.connect() as conn:
            result = conn.execute(text(sql), params or {})
            df = pd.DataFrame(result.fetchall(), columns=result.keys())
            return df

View File

@@ -0,0 +1,73 @@
from typing import Optional, TYPE_CHECKING, Dict, Any
from datetime import datetime, timezone
import yaml
import json
if TYPE_CHECKING:
from ea_chatbot.graph.state import AgentState
def ordinal(n: int) -> str:
    """Return *n* with its English ordinal suffix ("1st", "22nd", "13th").

    Bug fix: the teens exception must be tested against ``n % 100``, not
    ``n`` itself — otherwise 111, 212, 1013, ... get "st"/"nd"/"rd"
    instead of the correct "th".
    """
    if 11 <= n % 100 <= 13:
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
    return f"{n}{suffix}"
def get_readable_date(date_obj: Optional[datetime] = None, tz: Optional[timezone] = None) -> str:
    """Format a datetime as e.g. "Mon 1st of Jan 2024".

    Defaults to the current UTC time; if *tz* is given, the value is
    converted to that zone before formatting.
    """
    when = datetime.now(timezone.utc) if date_obj is None else date_obj
    if tz:
        when = when.astimezone(tz)
    day_with_suffix = ordinal(when.day)
    return when.strftime(f"%a {day_with_suffix} of %b %Y")
def to_yaml(json_str: str, indent: int = 2) -> str:
    """
    Best-effort conversion of a JSON string (potentially malformed LLM
    output) into YAML.

    Tries a direct parse first, then a naive repair (single → double
    quotes). If both fail, the input is returned unchanged.
    """
    if not json_str:
        return ""
    try:
        parsed = json.loads(json_str)
    except json.JSONDecodeError:
        try:
            parsed = json.loads(json_str.replace("'", '"'))
        except Exception:
            # Unparseable even after repair — hand back the raw string.
            return json_str
    return yaml.dump(parsed, indent=indent, sort_keys=False)
def merge_agent_state(current_state: "AgentState", update: Dict[str, Any]) -> "AgentState":
    """
    Apply a partial *update* onto *current_state*, mimicking LangGraph's
    reduction logic:

    - "messages" and "plots" lists are concatenated onto the existing list.
    - The "dfs" dict is shallow-merged (update wins on key clashes).
    - Any explicit None, and every other key, simply overwrites.

    Returns a new state; *current_state* is not mutated.
    """
    merged = current_state.copy()
    for field, incoming in update.items():
        if incoming is None:
            merged[field] = None
        elif field in ("messages", "plots") and isinstance(incoming, list):
            existing = merged.get(field, [])
            if not isinstance(existing, list):
                existing = []
            merged[field] = existing + incoming
        elif field == "dfs" and isinstance(incoming, dict):
            existing = merged.get(field, {})
            combined = existing.copy() if isinstance(existing, dict) else {}
            combined.update(incoming)
            merged[field] = combined
        else:
            merged[field] = incoming
    return merged

View File

@@ -0,0 +1,36 @@
def get_llm_model(config: LLMConfig, callbacks: Optional[List[BaseCallbackHandler]] = None) -> BaseChatModel:
    """
    Factory function to get a LangChain chat model based on configuration.
    Args:
        config: LLMConfig object containing model settings.
        callbacks: Optional list of LangChain callback handlers.
    Returns:
        Initialized BaseChatModel instance.
    Raises:
        ValueError: If the provider is not supported.
    """
    model_kwargs = {
        "temperature": config.temperature,
        "max_tokens": config.max_tokens,
    }
    # Provider-specific options override the generic ones on key clashes.
    model_kwargs.update(config.provider_specific)
    # Drop unset options so the library defaults take over.
    model_kwargs = {name: value for name, value in model_kwargs.items() if value is not None}
    provider = config.provider.lower()
    if provider == "openai":
        return ChatOpenAI(model=config.model, callbacks=callbacks, **model_kwargs)
    if provider in ("google", "google_genai"):
        return ChatGoogleGenerativeAI(model=config.model, callbacks=callbacks, **model_kwargs)
    raise ValueError(f"Unsupported LLM provider: {config.provider}")

View File

@@ -0,0 +1,141 @@
import json
import os
import logging
from rich.logging import RichHandler
from langchain_core.callbacks import BaseCallbackHandler
from logging.handlers import RotatingFileHandler
from typing import Any, Optional, Dict, List
class LangChainLoggingHandler(BaseCallbackHandler):
    """LangChain callback handler that mirrors LLM lifecycle events to a logger."""

    def __init__(self, logger: Optional[logging.Logger] = None):
        # Fall back to this package's dedicated "langchain" logger.
        self.logger = logger if logger is not None else get_logger("langchain")

    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any:
        # "serialized" may be empty or lack a name depending on the caller;
        # try kwargs before falling back to a generic label.
        name = serialized.get("name") or kwargs.get("name") or "LLM"
        self.logger.info(f"[bold blue]LLM Started:[/bold blue] {name}")

    def on_llm_end(self, response: Any, **kwargs: Any) -> Any:
        output = getattr(response, "llm_output", {}) or {}
        name = output.get("model_name") or "LLM"
        usage = output.get("token_usage", {})
        parts = [f"[bold green]LLM Ended:[/bold green] {name}"]
        if usage:
            prompt_toks = usage.get("prompt_tokens", 0)
            completion_toks = usage.get("completion_tokens", 0)
            total_toks = usage.get("total_tokens", 0)
            parts.append(f" | [yellow]Tokens: {total_toks}[/yellow] ({prompt_toks} prompt, {completion_toks} completion)")
        self.logger.info("".join(parts))

    def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> Any:
        self.logger.error(f"[bold red]LLM Error:[/bold red] {str(error)}")
class ContextLoggerAdapter(logging.LoggerAdapter):
    """LoggerAdapter that merges its bound context into every record's extra."""

    def process(self, msg: Any, kwargs: Any) -> tuple[Any, Any]:
        # Per-call "extra" wins over the adapter-level context on key clashes.
        combined = self.extra.copy()
        combined.update(kwargs.pop("extra", {}))
        kwargs["extra"] = combined
        return msg, kwargs
class FlexibleJSONEncoder(json.JSONEncoder):
    """JSON encoder that degrades gracefully for Pydantic models and plain objects."""

    def default(self, obj: Any) -> Any:
        """Convert one unserializable object into a JSON-compatible value."""
        if hasattr(obj, 'model_dump'):  # Pydantic v2
            return obj.model_dump()
        if hasattr(obj, 'dict'):  # Pydantic v1
            return obj.dict()
        if hasattr(obj, '__dict__'):
            return self.serialize_custom_object(obj)
        # Bug fix: the previous dict/list branches recursively called
        # default() on every contained value, which raised TypeError for
        # already-serializable leaves (ints, strings, ...). JSONEncoder
        # walks containers itself and only calls default() for
        # unserializable leaves, so we simply defer to super() here.
        return super().default(obj)

    def serialize_custom_object(self, obj: Any) -> dict:
        """Snapshot obj.__dict__ and tag it with the originating class name."""
        obj_dict = obj.__dict__.copy()
        obj_dict['__custom_class__'] = obj.__class__.__name__
        return obj_dict
class JsonFormatter(logging.Formatter):
    """Formatter that renders each log record as a single JSON object."""

    # Attributes every LogRecord carries by default; anything else on the
    # record must have arrived via the `extra=` logging mechanism.
    _STANDARD_ATTRS = frozenset({
        'args', 'asctime', 'created', 'exc_info', 'exc_text', 'filename',
        'funcName', 'levelname', 'levelno', 'lineno', 'module',
        'msecs', 'message', 'msg', 'name', 'pathname', 'process',
        'processName', 'relativeCreated', 'stack_info', 'thread', 'threadName'
    })

    def format(self, record: logging.LogRecord) -> str:
        payload = {
            "timestamp": self.formatTime(record, self.datefmt),
            "level": record.levelname,
            "message": record.getMessage(),
            "module": record.module,
            "name": record.name,
        }
        if record.exc_info:
            payload["exception"] = self.formatException(record.exc_info)
        # Pass through any custom fields attached via `extra=`.
        for attr, value in record.__dict__.items():
            if attr not in self._STANDARD_ATTRS:
                payload[attr] = value
        return json.dumps(payload, cls=FlexibleJSONEncoder)
def get_logger(name: str = "ea_chatbot", level: Optional[str] = None, log_file: Optional[str] = None) -> logging.Logger:
    """Return a logger in the "ea_chatbot" hierarchy.

    On first use this configures the root "ea_chatbot" logger with a Rich
    console handler; when *log_file* is given, a rotating JSON file handler
    is attached (at most once), even if the console was set up earlier.
    """
    # Keep all loggers under the "ea_chatbot" namespace for shared handlers.
    if name == "ea_chatbot" or name.startswith("ea_chatbot."):
        full_name = name
    else:
        full_name = f"ea_chatbot.{name}"
    root_logger = logging.getLogger("ea_chatbot")
    if not root_logger.handlers:
        # First configuration of the hierarchy root; default level is INFO.
        root_logger.setLevel(getattr(logging, (level or "INFO").upper(), logging.INFO))
        console_handler = RichHandler(
            rich_tracebacks=True,
            markup=True,
            show_time=False,
            show_path=False
        )
        root_logger.addHandler(console_handler)
        # Avoid duplicate output through the global root logger.
        root_logger.propagate = False
    # A file handler may be requested after the console was configured,
    # so this check runs on every call — but attaches at most one.
    if log_file and not any(isinstance(h, RotatingFileHandler) for h in root_logger.handlers):
        os.makedirs(os.path.dirname(log_file), exist_ok=True)
        file_handler = RotatingFileHandler(
            log_file, maxBytes=5*1024*1024, backupCount=3
        )
        file_handler.setFormatter(JsonFormatter())
        root_logger.addHandler(file_handler)
    logger = logging.getLogger(full_name)
    # An explicit level on a sub-logger overrides the inherited root level.
    if level:
        logger.setLevel(getattr(logging, level.upper(), logging.INFO))
    return logger