# Public API of this module: result wrappers, error handling, the Database
# connector, and the table_display helper (defined later in the file).
__all__ = [
    "SQLResponse",
    "SQLErrorResponse",
    "DatabaseErrorHandler",
    "Database",
    "table_display",
]
from .db_utils import *
from ..basic.request_utils import NetworkProxy
from ..basic.log_utils import get_logger
from ..basic.debug_utils import error_str, raise_mismatch, DatabaseError
from typing import Iterable, List, Dict, Tuple, Any, Union, Optional, Literal, Generator
from copy import deepcopy
from ..deps import deps

# Lazily-loaded third-party modules; populated on first use by the
# get_sa()/get_prettytable() accessors below.
_sa = None
_prettytable = None
def get_sa():
    """Return the ``sqlalchemy`` module, importing it lazily on first call."""
    global _sa
    if _sa is None:
        _sa = deps.load("sqlalchemy")
    return _sa
def get_sa_elements():
    """Return the ``sqlalchemy.sql.elements`` submodule (loaded via deps on each call)."""
    return deps.load("sqlalchemy.sql.elements")
def get_sa_schema():
    """Return the ``sqlalchemy.schema`` submodule (loaded via deps on each call)."""
    return deps.load("sqlalchemy.schema")
# NOTE: the redundant module-level re-initialization of `_prettytable` that
# previously sat here was removed — the global is already set to None at the
# top of the module.
def get_prettytable():
    """Return the ``prettytable`` module, importing it lazily on first call."""
    global _prettytable
    if _prettytable is None:
        _prettytable = deps.load("prettytable")
    return _prettytable
# Module-level logger shared by all classes and helpers in this file.
logger = get_logger(__name__)
class SQLResponse:
    """Enhanced result wrapper for SQLAlchemy CursorResult with convenient data access methods.

    Rows and metadata are eagerly materialized at construction time, so the
    data stays accessible even after the underlying cursor/connection closes.
    """

    def __init__(self, cursor_result):
        """Wrap *cursor_result* and eagerly fetch rows, columns, and counters.

        Args:
            cursor_result: A SQLAlchemy CursorResult (or compatible object
                exposing ``fetchall()``/``keys()``).
        """
        self._result = cursor_result
        self._fetched_data = None  # list of row dicts once fetched
        self._columns = None       # list of column names
        self._row_count = None     # affected-row count (-1 when unavailable)
        self._lastrowid = None     # last inserted row id (None when unavailable)
        self._fetch()

    @property
    def raw(self):
        """\
        Access the underlying SQLAlchemy CursorResult.
        Returns:
            CursorResult: The underlying SQLAlchemy CursorResult.
        Warning:
            When a connection is closed, the cursor result may no longer be available.
        """
        return self._result

    @property
    def columns(self) -> List[str]:
        """\
        Get column names from the result.
        Returns:
            List[str]: The list of column names. If the result is not available, returns an empty list.
        """
        return deepcopy(self._columns)

    def _fetch(self):
        # BUGFIX: columns must be resolved BEFORE rows — the tuple-row fallback
        # below zips row values against the column names, which previously were
        # still None at that point.
        if self._columns is None:
            try:
                self._columns = list(self._result.keys()) if hasattr(self._result, "keys") else list()
            except Exception as e:
                logger.debug(f"Failed to get columns (likely DDL operation): {error_str(e)}")
                self._columns = list()
        if self._fetched_data is None:
            try:
                self._fetched_data = list()
                rows = self._result.fetchall()
                for row in rows:
                    if hasattr(row, "_mapping"):
                        self._fetched_data.append(dict(row._mapping))
                    else:
                        # Plain tuple rows: pair values with the column names.
                        self._fetched_data.append(dict(zip(self._columns, row)))
            except Exception as e:
                logger.debug(f"Failed to fetch result (likely DDL operation): {error_str(e)}")
                self._fetched_data = list()
        if self._row_count is None:
            try:
                # BUGFIX: SQLAlchemy exposes the DBAPI count as `rowcount`
                # (previously read the nonexistent `row_count`, always -1).
                self._row_count = getattr(self._result, "rowcount", -1)
            except Exception as e:
                logger.debug(f"Failed to get row count: {error_str(e)}")
                self._row_count = -1
        if self._lastrowid is None:
            try:
                self._lastrowid = getattr(self._result, "lastrowid", None)
            except Exception as e:
                logger.debug(f"Failed to get last row ID: {error_str(e)}")
                self._lastrowid = None

    @property
    def row_count(self) -> int:
        """\
        Get the number of affected rows.
        Returns:
            int: The number of affected rows. If the result is not available, returns -1.
        """
        return self._row_count

    @property
    def lastrowid(self) -> Optional[int]:
        """\
        Get the last inserted row ID.
        Returns:
            Optional[int]: The last inserted row ID. If the result is not available, returns None.
        """
        return self._lastrowid

    def fetchall(self) -> Generator[Dict[str, Any], None, None]:
        """\
        Fetch all rows as a list of dictionaries.
        Yields:
            Dict[str, Any]: The next row as a dictionary.
        """
        yield from self._fetched_data
        return

    def _get_col_enums(self, row: Dict[str, Any], column_spec: Union[str, int]) -> Any:
        """Extract column value from row by name or index."""
        if isinstance(column_spec, str):
            if column_spec not in row:
                raise ValueError(f"Column '{column_spec}' not found in row")
            return row[column_spec]
        elif isinstance(column_spec, int):
            row_values = tuple(row.values())
            # Allow negative indices, python-style.
            if not (-len(row_values) <= column_spec < len(row_values)):
                raise ValueError(f"Column index {column_spec} out of range for row with {len(row_values)} columns")
            return row_values[column_spec]
        else:
            raise ValueError(f"Invalid column specification: {column_spec}")

    def __getitem__(self, idx: Union[int, slice, Tuple[Union[int, slice], Union[int, str]]]) -> Any:
        """Index rows (int/slice) or (row, column) pairs, e.g. ``r[0, "name"]``."""
        if isinstance(idx, (slice, int)):
            return self._fetched_data[idx]
        if isinstance(idx, tuple) and len(idx) == 2:
            row_spec, col_spec = idx
            if isinstance(row_spec, int):
                row = self._fetched_data[row_spec]
                return self._get_col_enums(row, col_spec)
            elif isinstance(row_spec, slice):
                rows = self._fetched_data[row_spec]
                return [self._get_col_enums(row, col_spec) for row in rows]
            else:
                raise ValueError(f"Invalid row specification: {row_spec}")
        raise ValueError(f"Invalid index: {idx}")

    def __len__(self) -> int:
        """\
        Get the number of rows in the result.
        Returns:
            int: The number of rows in the result.
        """
        return len(self._fetched_data)

    def to_list(self, row_fmt: Literal["dict", "tuple"] = "dict") -> Union[List[Tuple], List[Dict[str, Any]]]:
        """\
        Convert result to a list of rows.
        Args:
            row_fmt (Literal['dict', 'tuple']): The format for the rows.
        Returns:
            Union[List[Tuple], List[Dict[str, Any]]]: The result as a list of tuples or dictionaries.
        """
        if row_fmt == "dict":
            return deepcopy(self._fetched_data)
        if row_fmt == "tuple":
            return [tuple(row.values()) for row in self._fetched_data]
        raise_mismatch(["dict", "tuple"], got=row_fmt, name="row format")

    def close(self):
        """\
        Close the result cursor.
        """
        try:
            self._result.close()
        except Exception as e:
            logger.warning(f"Failed to close result cursor: {error_str(e)}")
class DatabaseErrorHandler:
    """Extensible handler for database errors with type-specific processing.

    Maps SQLAlchemy exception types to handler callables that extract a
    structured ``(error_type, short_message)`` pair and, when a Database
    instance is available, enrich messages with "did you mean" suggestions.
    """

    def __init__(self, db: Optional["Database"] = None):
        """Initialize the error handler.

        Args:
            db (Database, optional): Database instance for context-aware suggestions.
        """
        self.db = db
        self._handlers = {}
        self._register_default_handlers()

    def _register_default_handlers(self):
        """Register default error handlers for common SQLAlchemy exceptions."""
        import sqlalchemy.exc as sa_exc
        defaults = (
            (sa_exc.OperationalError, self._handle_operational_error),
            (sa_exc.ProgrammingError, self._handle_programming_error),
            (sa_exc.IntegrityError, self._handle_integrity_error),
            (sa_exc.DataError, self._handle_data_error),
        )
        for exc_type, func in defaults:
            self.register_handler(exc_type, func)

    def register_handler(self, exception_type: type, handler_func):
        """Register a custom handler for a specific exception type.

        Args:
            exception_type (type): The exception type to handle.
            handler_func (callable): Function taking (exception, query, params)
                and returning (error_type, short_message).
        """
        self._handlers[exception_type] = handler_func

    def _add_suggestions(self, error_msg: str, pattern: str, get_options_func) -> str:
        """Append suggestion lines to *error_msg* (mirrors raise_mismatch format, minus its first line)."""
        import re
        from difflib import SequenceMatcher

        found = re.search(pattern, error_msg, re.IGNORECASE)
        if not found or not self.db:
            return error_msg
        item_name = found.group(1)
        try:
            options = get_options_func()

            def _ratio(a, b):
                return SequenceMatcher(None, str(a), str(b)).ratio()

            ranked = sorted(options, key=lambda opt: _ratio(item_name, opt), reverse=True)
            best = ranked[0] if ranked and _ratio(item_name, ranked[0]) >= 0.3 else None
            extra = [f"Did you mean '{best}'?"] if best else []
            extra.append(f"Available options: {', '.join(repr(opt) for opt in options)}.")
            return error_msg + "\n" + "\n".join(extra)
        except Exception:
            pass
        return error_msg

    def _handle_operational_error(self, e: Exception, query: Optional[str], params: Optional[Any]) -> tuple:
        """Handle SQLAlchemy OperationalError exceptions."""
        orig = getattr(e, "orig", None)
        error_msg = str(orig) if orig else str(e)
        lowered = error_msg.lower()
        if "no such table" in lowered:
            # Table lookups can be enriched with suggestions from the live schema.
            msg = self._add_suggestions(error_msg, r"no such table:\s*(\w+)", self.db.db_tabs if self.db else lambda: [])
            return "TableNotFound", msg
        if "no such column" in lowered:
            # No suggestions here — that would need table context.
            return "ColumnNotFound", error_msg
        if "database is locked" in lowered:
            return "DatabaseLocked", error_msg
        if "disk i/o error" in lowered:
            return "DiskIOError", error_msg
        return "OperationalError", error_msg

    def _handle_programming_error(self, e: Exception, query: Optional[str], params: Optional[Any]) -> tuple:
        """Handle SQLAlchemy ProgrammingError exceptions (typically syntax errors)."""
        orig = getattr(e, "orig", None)
        error_msg = str(orig) if orig else str(e)
        lowered = error_msg.lower()
        if "syntax error" in lowered or "parser error" in lowered:
            return "SyntaxError", error_msg
        return "ProgrammingError", error_msg

    def _handle_integrity_error(self, e: Exception, query: Optional[str], params: Optional[Any]) -> tuple:
        """Handle SQLAlchemy IntegrityError exceptions (constraint violations)."""
        orig = getattr(e, "orig", None)
        error_msg = str(orig) if orig else str(e)
        lowered = error_msg.lower()
        if "foreign key constraint" in lowered:
            return "ForeignKeyViolation", error_msg
        if "unique constraint" in lowered or "not unique" in lowered:
            return "UniqueViolation", error_msg
        if "not null constraint" in lowered or "may not be null" in lowered:
            return "NotNullViolation", error_msg
        return "IntegrityError", error_msg

    def _handle_data_error(self, e: Exception, query: Optional[str], params: Optional[Any]) -> tuple:
        """Handle SQLAlchemy DataError exceptions (data type/value issues)."""
        orig = getattr(e, "orig", None)
        return "DataError", str(orig) if orig else str(e)

    def handle(self, e: Exception, query: Optional[str] = None, params: Optional[Any] = None) -> tuple:
        """Handle a database exception and extract structured error information.

        Returns:
            tuple: (error_type, short_message, full_message) extracted from the exception.
        """
        orig = getattr(e, "orig", None)
        full_msg = str(orig) if orig else str(e)
        # First registered handler whose exception type matches wins.
        for exc_type, handler in self._handlers.items():
            if isinstance(e, exc_type):
                error_type, short_msg = handler(e, query, params)
                return error_type, short_msg, full_msg
        # Fallback for unhandled exception types.
        return "UnknownError", str(e), full_msg
class SQLErrorResponse:
    """Structured error response for database operation failures.

    Provides a clean, structured way to return error information from
    database operations, making it easier for LLMs and tools to handle
    and present errors to users.
    """

    def __init__(
        self,
        error_type: str,
        short_message: str,
        full_message: str,
        query: Optional[str] = None,
        params: Optional[Union[Dict[str, Any], List, Tuple]] = None,
    ):
        """Initialize a SQL error response.

        Args:
            error_type (str): The type/category of error (e.g., "TableNotFound", "SyntaxError").
            short_message (str): A brief, human-readable error message.
            full_message (str): The complete error message with traceback.
            query (str, optional): The SQL query that caused the error.
            params (Union[Dict, List, Tuple], optional): The parameters used with the query.
        """
        self.error_type = error_type
        self.short_message = short_message
        self.full_message = full_message
        self.query = query
        self.params = params

    def to_string(self, include_full: bool = False) -> str:
        """Format the error as a user-friendly string.

        Args:
            include_full (bool): Whether to include the full original error message. Defaults to False.
        Returns:
            str: Formatted error message.
        """
        parts = [
            "Database query execution failed.",
            f"Error Type: {self.error_type}",
            f"Error: {self.short_message}",
        ]
        if self.query:
            parts.append(f"Query: {self.query}")
        if self.params:
            parts.append(f"Params: {self.params}")
        # Only show the original message when it adds information.
        if include_full and self.full_message != self.short_message:
            parts.append(f"Original Error: {self.full_message}")
        return "\n".join(parts)

    def __str__(self) -> str:
        return self.to_string(include_full=False)

    def __repr__(self) -> str:
        return f"SQLErrorResponse(error_type={self.error_type!r}, short_message={self.short_message!r})"
class Database(object):
    """\
    Universal Database Connector
    Provides a clean, intuitive interface for database operations across different providers
    (SQLite, PostgreSQL, DuckDB, MySQL) with standard connection management:
    1. **Basic Usage**:
       ```python
       db = Database(provider="sqlite", database=":memory:")
       result = db.execute("SELECT * FROM table")
       ```
    2. **Context Manager** (recommended for transactions):
       ```python
       with Database(provider="pg", database="mydb") as db:
           db.execute("INSERT INTO users (name) VALUES (:name)", params={"name": "Alice"})
           db.execute("UPDATE users SET active = TRUE WHERE name = :name", params={"name": "Alice"})
       # Automatically commits on success, rolls back on exception
       ```
    3. **Manual Transaction Control**:
       ```python
       db = Database(provider="sqlite", database="mydb")
       try:
           db.execute("INSERT INTO users (name) VALUES (:name)", params={"name": "Bob"}, autocommit=False)
           db.execute("UPDATE users SET active = TRUE WHERE name = :name", params={"name": "Bob"}, autocommit=False)
           db.commit()
       except Exception:
           db.rollback()
       finally:
           db.close_conn()
       ```
    The class automatically handles:
    - Database creation (PostgreSQL auto-creation if database doesn't exist)
    - Connection lifecycle management
    - SQL transpilation between different database dialects
    """
def __init__(
    self,
    database: Optional[str] = None,
    provider: Optional[str] = None,
    pool: Optional[Dict[str, Any]] = None,
    connect: bool = False,
    **kwargs,
):
    """\
    Initialize database connection.
    Args:
        database: Database name or path (':memory:' for in-memory)
        provider: Database provider ('sqlite', 'pg', 'duckdb', etc.)
        pool: Pool configuration to override provider defaults (e.g., {'pool_size': 10})
        connect: Whether to establish a connection immediately (default: False)
        **kwargs: Additional connection parameters
    """
    super().__init__()
    # Resolve user arguments into the engine config plus the raw connection
    # arguments (the latter are kept around for clone()).
    self.config, self.config_conn_args = resolve_db_config(database=database, provider=provider, pool=pool, **kwargs)
    self.dialect = self.config_conn_args.get("dialect", None)
    # NOTE: pop() mutates self.config on purpose — proxy settings must be
    # stripped before the config is handed to engine creation in _init().
    self.proxy = NetworkProxy(
        http_proxy=self.config.pop("http_proxy", None),
        https_proxy=self.config.pop("https_proxy", None),
    )
    self.sql_processor = SQLProcessor(self.dialect)
    # Set when inside `with` — blocks autocommit in _exec_sql.
    self._in_context_manager = False
    self._init(connect=connect)
def __post_init__(self):  # ad-hoc fix for dataclass
    """Enable SQLite foreign-key enforcement; no-op for other dialects."""
    if self.dialect != "sqlite":
        return
    self.execute("PRAGMA foreign_keys = ON;", autocommit=True)
def _init(self, connect: bool = False):
    """Create the engine, reset the connection slot, and optionally connect."""
    self.engine = create_database_engine(self.config, conn_args=self.config_conn_args)
    self._conn = None
    if connect:
        self.connect()
def clone(self) -> "Database":
    """Create an independent Database instance with the same configuration.

    Each clone owns its own connection, which makes it safe for parallel
    workloads where every worker needs an isolated database connection.

    Warning:
        For in-memory databases (`:memory:`), cloned instances do NOT share
        data — each clone gets its own separate in-memory database. Use a
        file-based database when parallel workers must share state.

    Returns:
        Database: A new independent Database instance.

    Example:
        ```python
        def worker(db_template, task_id):
            db = db_template.clone()  # each worker gets its own connection
            try:
                result = db.execute("SELECT * FROM tasks WHERE id = :id", params={"id": task_id})
                return result.to_list()
            finally:
                db.close()

        with ThreadPoolExecutor() as executor:
            futures = [executor.submit(worker, db, i) for i in range(10)]
        ```
    """
    database = self.config_conn_args.get("database", "")
    if database == ":memory:":
        logger.warning(
            "Cloning an in-memory database - cloned instances will NOT share data. "
            "Each clone has its own separate in-memory database. "
            "Use a file-based database for parallel operations requiring shared state."
        )
    # Everything that is not an explicit constructor argument (or derived
    # metadata) is forwarded as extra connection kwargs.
    reserved = {"database", "provider", "pool", "dialect", "driver", "url"}
    extra_kwargs = {key: val for key, val in self.config_conn_args.items() if key not in reserved}
    return Database(
        provider=self.config_conn_args.get("provider"),
        database=database,
        pool=self.config_conn_args.get("pool"),
        connect=False,
        **extra_kwargs,
    )
def connect(self):
    """Establish a database connection, reusing a live one when present.

    The connection pool (configured per dialect) handles:
    - Stale connection detection via pool_pre_ping
    - Connection recycling to prevent timeouts
    - Thread-safe connection management

    Returns:
        Connection: The SQLAlchemy connection object
    """
    needs_new = self._conn is None or self._conn.closed
    if needs_new:
        # Engine connects through the proxy context so proxied providers work.
        with self.proxy:
            self._conn = self.engine.connect()
    return self._conn
def close_conn(self, commit: bool = True):
    """\
    Close the database connection and return it to the pool.
    Args:
        commit: Whether to commit pending transaction before closing.
    """
    if self._conn is not None:
        # First settle any open transaction; failures here are logged and
        # ignored so the close below still runs.
        try:
            if self._conn.in_transaction():
                if commit:
                    self._conn.commit()
                else:
                    self._conn.rollback()
        except Exception as e:
            logger.debug(f"Transaction cleanup during close: {error_str(e, tb=False)}")
        # Then close the connection; the slot is cleared regardless of outcome
        # so a later connect() starts fresh.
        try:
            self._conn.close()
        except Exception as e:
            logger.debug(f"Connection close: {error_str(e, tb=False)}")
        finally:
            self._conn = None
@property
def connected(self):
    """True when an open (non-closed) connection is currently held."""
    if self._conn is None:
        return False
    return not self._conn.closed
@property
def conn(self):
    """Return the current connection, opening one on demand."""
    if self.connected:
        return self._conn
    self.connect()
    return self._conn
def in_transaction(self) -> bool:
    """Report whether an open connection currently holds a transaction.

    Returns:
        bool: True if in transaction, False otherwise
    """
    conn = self._conn
    if conn is None or conn.closed:
        return False
    return conn.in_transaction()
def commit(self):
    """Commit the current transaction (no-op when none is open)."""
    if not self.in_transaction():
        return
    self._conn.commit()
def rollback(self):
    """Roll back the current transaction (no-op when none is open)."""
    if not self.in_transaction():
        return
    self._conn.rollback()
def __enter__(self):
    """Context manager entry: establishes connection and begins transaction."""
    self._in_context_manager = True
    if not self.connected:
        self.connect()
    # Start a transaction only if one is not already open.
    if not self.in_transaction():
        self.conn.begin()
    return self
def __exit__(self, exc_type, exc_val, exc_tb):
    """Context manager exit: commit on success, roll back on error, then release the connection."""
    try:
        if self.in_transaction():
            if exc_type is None:
                self.commit()
            else:
                self.rollback()
    finally:
        # Transaction already finalized above — don't commit again on close.
        self.close_conn(commit=False)
        self._in_context_manager = False
    return False  # never suppress exceptions
def orm_execute(
    self,
    query,
    autocommit: Optional[bool] = False,
    **kwargs,
) -> Union[SQLResponse, None]:
    """
    Execute a SQLAlchemy ORM query or statement.
    Args:
        query: SQLAlchemy ORM statement or ClauseElement
        autocommit: Whether to run in autocommit mode (default: False - no commits after execution)
        **kwargs: Additional keyword arguments for query execution
    Returns:
        SQLResponse: Enhanced result wrapper with convenient data access
        None: For operations that don't return results (e.g., INSERT, UPDATE, DELETE)
    Examples:
        # Using SQLAlchemy ORM statements
        from sqlalchemy import select, insert, update, delete
        stmt = select(users_table).where(users_table.c.id == 1)
        result = db.orm_execute(stmt)
        stmt = insert(users_table).values(name="Alice")
        db.orm_execute(stmt, autocommit=True)
        stmt = update(users_table).where(users_table.c.id == 1).values(name="Bob")
        db.orm_execute(stmt, autocommit=True)
        stmt = delete(users_table).where(users_table.c.id == 1)
        db.orm_execute(stmt, autocommit=True)
    """
    if not isinstance(query, get_sa_elements().ClauseElement):
        raise ValueError("orm_execute only accepts SQLAlchemy ORM statements (ClauseElement)")
    try:
        # DDL and DML statements take the same execution path: the engine
        # compiles the ClauseElement itself, so no literal_binds compile step
        # is needed. (Previously this had two identical branches plus a
        # useless `finally: pass`; both removed.)
        return self._exec_sql(query, params=None, autocommit=autocommit)
    except Exception as e:
        logger.debug(f"ORM Query: {query}")
        logger.error(f"Database ORM execution failed:\n{error_str(e)}")
        # Chain the original exception for debuggability.
        raise DatabaseError(f"\nDatabase ORM execution failed:\n{error_str(e)}\nQuery: {query}\n") from e
def execute(
    self,
    query: str,
    transpile: Optional[str] = None,
    autocommit: Optional[bool] = False,
    params: Optional[Union[Dict[str, Any], List[Dict[str, Any]], Tuple]] = None,
    safe: bool = False,
    **kwargs,
) -> Union[SQLResponse, SQLErrorResponse, None]:
    """
    Execute a raw SQL query against the database.
    Args:
        query: The SQL query to execute (raw SQL string)
        transpile: Source dialect to transpile from (if different from target)
        autocommit: Whether to run in autocommit mode (default: False - no commits after execution)
        params: Query parameters (dict for named, tuple/list for positional)
        safe: If True, returns SQLErrorResponse on error instead of raising exception (default: False)
        **kwargs: Additional keyword arguments for query execution
    Returns:
        SQLResponse: Enhanced result wrapper with convenient data access
        SQLErrorResponse: Structured error response (only if safe=True)
        None: For operations that don't return results (e.g., INSERT, UPDATE, DELETE)
    Examples:
        # Simple query (uses temporary connection with autocommit)
        result = db.execute("SELECT * FROM users")
        rows = list(result.fetchall())
        # Parameterized query
        result = db.execute("SELECT * FROM users WHERE id = :id", params={"id": 1})
        # Transactional operation
        with db:
            db.execute("INSERT INTO users (name) VALUES (:name)", params={"name": "Bob"})
        # Cross-database SQL (transpile from PostgreSQL to the current dialect)
        result = db.execute("SELECT * FROM users LIMIT 10", transpile="postgresql")
        # Safe mode - returns error instead of raising
        result = db.execute("SELECT * FROM nonexistent", safe=True)
        if isinstance(result, SQLErrorResponse):
            print(result.to_string())
    Note:
        For SQLAlchemy ORM operations, use orm_execute() method instead.
    """
    # If user passes a ClauseElement to execute(), redirect to orm_execute
    # but don't pass params since ClauseElement should have its own parameters
    if isinstance(query, get_sa_elements().ClauseElement):
        if params is not None:
            logger.warning("Parameters ignored when executing ClauseElement via execute(). Use orm_execute() for ClauseElement queries.")
        return self.orm_execute(query, autocommit=autocommit, **kwargs)
    try:
        # Process string query with optional transpilation and parameters
        processed_query, processed_params = self.sql_processor.process_query(query, params, transpile_from=transpile)
        return self._exec_sql(get_sa().text(processed_query), processed_params, autocommit=autocommit, safe=safe)
    except Exception as e:
        if safe:
            # Return structured error response using the error handler
            error_handler = DatabaseErrorHandler(db=self)
            error_type, short_msg, full_msg = error_handler.handle(e, query, params)
            return SQLErrorResponse(
                error_type=error_type,
                short_message=short_msg,
                full_message=full_msg,
                query=query,
                params=params,
            )
        else:
            # Original behavior: raise exception (chained for debuggability;
            # the useless `finally: pass` was removed)
            logger.debug(f"SQL Query: {query}")
            logger.debug(f"Parameters: {params}")
            logger.error(f"Database execution failed:\n{error_str(e)}")
            raise DatabaseError(f"Database execution failed:\n{e}\nQuery: {query}\nParams: {params}\n") from e
def _exec_sql(self, query, params=None, autocommit: Optional[bool] = False, safe: bool = False) -> Optional[Union[SQLResponse, SQLErrorResponse]]:
    """\
    Internal method to execute SQL queries.
    Connection pool handles stale connection recovery via pool_pre_ping.
    Transaction management:
    - autocommit=True: Execute and commit immediately (for DDL, single statements)
    - autocommit=False: Execute without commit (for use in transactions)
    - In context manager: autocommit=True is not allowed
    Args:
        query: SQLAlchemy text() or ClauseElement
        params: Query parameters
        autocommit: Whether to commit after execution
        safe: If True, re-raise exceptions for caller to handle as SQLErrorResponse
    Returns:
        SQLResponse or None
    """
    if autocommit and self._in_context_manager:
        raise DatabaseError("Cannot use `autocommit=True` within a context manager!")
    # Outside context manager: ensure clean transaction state
    if not self._in_context_manager and self.in_transaction():
        self.commit()
    try:
        # Execute the query.
        # NOTE(review): falsy params (e.g. an empty dict) take the no-params
        # path — presumed intentional; confirm against sql_processor output.
        result = self.conn.execute(query, params) if params else self.conn.execute(query)
        response = SQLResponse(result) if result else None
        # Commit if autocommit mode
        if autocommit:
            self.commit()
        return response
    except Exception as e:
        # Always attempt to recover from invalid transaction state
        if self._conn is not None:
            try:
                # Try to rollback the connection directly (bypassing in_transaction check)
                # This handles SQLAlchemy's PendingRollbackError state
                self._conn.rollback()
            except Exception:
                # If rollback fails, close and dispose connection
                pass
            # Close the connection to return it to the pool in a clean state.
            # NOTE(review): close_conn() defaults to commit=True here; after
            # the rollback above no transaction should remain, so this is
            # presumed harmless — verify.
            self.close_conn()
        if safe:
            logger.debug(f"Database execution failed (safe mode): {error_str(e, tb=False)}")
            raise
        else:
            logger.error(f"Database execution failed: {error_str(e)}")
            raise DatabaseError(f"Database execution failed: {error_str(e)}")
# === Database Inspection Methods ===
def db_tabs(self) -> List[str]:
    """List all table names in the database.

    Returns:
        List[str]: List of table names
    """
    # Preferred path: SQLAlchemy inspector.
    try:
        return get_sa().inspect(self.conn).get_table_names()
    except Exception as e:
        logger.warning(f"Inspector failed, falling back to SQL: {error_str(e)}")
    # Fallback: dialect-specific built-in SQL.
    try:
        result = self.execute(load_builtin_sql("utils/db_tabs", dialect=self.dialect), autocommit=True)
        return [row["tab_name"] for row in result.to_list()]
    except Exception as e:
        logger.error(f"Failed to list tables: {error_str(e)}")
        return []
def db_views(self) -> List[str]:
    """List all view names in the database.

    Returns:
        List[str]: List of view names
    """
    # Preferred path: SQLAlchemy inspector.
    try:
        return get_sa().inspect(self.conn).get_view_names()
    except Exception as e:
        logger.warning(f"Inspector failed, falling back to SQL: {error_str(e)}")
    # Fallback: dialect-specific built-in SQL.
    try:
        result = self.execute(load_builtin_sql("utils/db_views", dialect=self.dialect), autocommit=True)
        return [row["view_name"] for row in result.to_list()]
    except Exception as e:
        logger.error(f"Failed to list views: {error_str(e)}")
        return []
def tab_cols(self, tab_name: str, full_info: bool = False):
    """List column information for a specific table.

    Args:
        tab_name: Name of the table
        full_info: If True, return full column information; if False, return only column names
    Returns:
        When full_info=True: List of column dictionaries with full metadata
        When full_info=False: List[str] of column names
    """
    # Preferred path: SQLAlchemy inspector.
    try:
        columns = get_sa().inspect(self.conn).get_columns(tab_name)
        return columns if full_info else [col["name"] for col in columns]
    except Exception as e:
        logger.warning(f"Inspector failed, falling back to SQL: {error_str(e)}")
    # Fallback: dialect-specific built-in SQL.
    try:
        result = self.execute(load_builtin_sql("utils/tab_cols", dialect=self.dialect, tab_name=tab_name), autocommit=True)
        rows = result.to_list()
        return rows if full_info else [row["col_name"] for row in rows]
    except Exception as e:
        logger.error(f"Failed to list columns for table {tab_name}: {error_str(e)}")
        return []
def tab_pks(self, tab_name: str) -> List[str]:
    """List primary key column names for a specific table.

    Args:
        tab_name: Name of the table
    Returns:
        List[str]: List of primary key column names
    """
    try:
        pks = get_sa().inspect(self.conn).get_pk_constraint(tab_name)
        pk_columns = pks.get("constrained_columns", []) if pks else []
        if pk_columns:
            return pk_columns
        # Inspector found no primary keys — deliberately fall through to the
        # SQL fallback below rather than returning an empty list.
    except Exception as e:
        logger.warning(f"Inspector failed, falling back to SQL: {error_str(e)}")
    try:
        # Standard execution for all dialects
        result = self.execute(load_builtin_sql("utils/tab_pks", dialect=self.dialect, tab_name=tab_name), autocommit=True)
        return [row["col_name"] for row in result.to_list()]
    except Exception as e:
        logger.error(f"Failed to list primary keys for table {tab_name}: {error_str(e)}")
        return []
def tab_fks(self, tab_name: str) -> List[Dict[str, str]]:
    """List foreign key information for a specific table.

    Args:
        tab_name: Name of the table
    Returns:
        List[Dict[str, str]]: List of foreign key information with keys:
            - col_name: Column name in the current table
            - tab_ref: Referenced table name
            - col_ref: Referenced column name
            - name: Foreign key constraint name
    """
    try:
        entries = []
        for fk in get_sa().inspect(self.conn).get_foreign_keys(tab_name):
            ref_table = fk["referred_table"]
            # One entry per constrained/referred column pair.
            for col, ref_col in zip(fk["constrained_columns"], fk["referred_columns"]):
                entries.append(
                    {
                        "col_name": col,
                        "tab_ref": ref_table,
                        "col_ref": ref_col,
                        "name": fk.get("name", f"FK_{col}_{ref_table}_{ref_col}"),
                    }
                )
        return entries
    except Exception as e:
        logger.warning(f"Inspector failed, falling back to SQL: {error_str(e)}")
    try:
        result = self.execute(load_builtin_sql("utils/tab_fks", dialect=self.dialect, tab_name=tab_name), autocommit=True)
        return result.to_list()
    except Exception as e:
        logger.error(f"Failed to list foreign keys for table {tab_name}: {error_str(e)}")
        return []
def row_count(self, tab_name: str) -> int:
    """Get row count for a specific table.

    Args:
        tab_name: Name of the table
    Returns:
        int: Number of rows in the table (0 on failure)
    """
    try:
        rows = self.execute(load_builtin_sql("utils/row_count", dialect=self.dialect, tab_name=tab_name), autocommit=True).to_list()
        return rows[0]["cnt"]
    except Exception as e:
        logger.error(f"Failed to count rows for table {tab_name}: {error_str(e)}")
        return 0
def col_type(self, tab_name: str, col_name: str) -> str:
    """Get column type for a specific column in a table.

    Args:
        tab_name: Name of the table
        col_name: Name of the column
    Returns:
        str: Column type ("" on failure)
    """
    try:
        inspector = get_sa().inspect(self.conn)
        match = next((col for col in inspector.get_columns(tab_name) if col["name"] == col_name), None)
        if match is None:
            # Raising here routes a missing column through the SQL fallback.
            raise ValueError(f"Column {col_name} not found in table {tab_name}")
        return str(match["type"])
    except Exception as e:
        logger.warning(f"Inspector failed, falling back to SQL: {error_str(e)}")
    try:
        result = self.execute(
            load_builtin_sql("utils/col_type", dialect=self.dialect, tab_name=tab_name, col_name=col_name),
            autocommit=True,
        )
        return result.to_list()[0]["col_type"]
    except Exception as e:
        logger.error(f"Failed to get column type for {tab_name}.{col_name}: {error_str(e)}")
        return ""
# === Data Analysis Methods ===
def col_distincts(self, tab_name: str, col_name: str) -> List[Any]:
    """Get distinct values for a specific column.

    Args:
        tab_name: Name of the table
        col_name: Name of the column
    Returns:
        List[Any]: List of distinct values ([] on failure)
    """
    try:
        result = self.execute(
            load_builtin_sql("utils/col_distincts", dialect=self.dialect, tab_name=tab_name, col_name=col_name),
            autocommit=True,
        )
        return [row["col_enums"] for row in result.to_list()]
    except Exception as e:
        logger.error(f"Failed to get distinct values for column {col_name} in table {tab_name}: {error_str(e)}")
        return []
def col_enums(self, tab_name: str, col_name: str) -> List[Any]:
    """Get all enumerated values for a specific column (including duplicates).

    Returns every value from the column, duplicates included. For unique
    values only, use :meth:`col_distincts` instead.

    Args:
        tab_name: Name of the table
        col_name: Name of the column
    Returns:
        List[Any]: List of all enumerated values (may contain duplicates)
    """
    try:
        result = self.execute(
            load_builtin_sql("utils/col_enums", dialect=self.dialect, tab_name=tab_name, col_name=col_name),
            autocommit=True,
        )
        return [row["col_enums"] for row in result.to_list()]
    except Exception as e:
        logger.error(f"Failed to get enumerated values for column {col_name} in table {tab_name}: {error_str(e)}")
        return []
[docs]
def col_freqs(self, tab_name: str, col_name: str) -> List[Dict[str, Any]]:
    """\
    Get value frequencies for a specific column.

    Args:
        tab_name: Name of the table
        col_name: Name of the column

    Returns:
        List[Dict[str, Any]]: List of value-frequency pairs; empty list on failure.
    """
    try:
        stmt = load_builtin_sql("utils/col_freqs", dialect=self.dialect, tab_name=tab_name, col_name=col_name)
        return self.execute(stmt, autocommit=True).to_list()
    except Exception as e:
        # Best-effort: log the failure and return an empty result.
        logger.error(f"Failed to get frequencies for column {col_name} in table {tab_name}: {error_str(e)}")
        return []
[docs]
def col_freqk(self, tab_name: str, col_name: str, topk: int = 20) -> List[Dict[str, Any]]:
    """\
    Get top-k value frequencies for a specific column.

    Args:
        tab_name: Name of the table
        col_name: Name of the column
        topk: Number of top values to return (default 20)

    Returns:
        List[Dict[str, Any]]: List of top-k value-frequency pairs; empty list on failure.
    """
    try:
        # Top-k selection is delegated to the dialect-specific builtin SQL template.
        result = self.execute(
            load_builtin_sql("utils/col_freqk", dialect=self.dialect, tab_name=tab_name, col_name=col_name, topk=topk),
            autocommit=True,
        )
        return result.to_list()
    except Exception as e:
        # Best-effort: log the failure and return an empty result rather than propagate.
        logger.error(f"Failed to get top-{topk} frequencies for column {col_name} in table {tab_name}: {error_str(e)}")
        return []
[docs]
def col_nonnulls(self, tab_name: str, col_name: str) -> List[Any]:
    """\
    Get list of non-null values for a specific column.

    Args:
        tab_name: Name of the table
        col_name: Name of the column

    Returns:
        List[Any]: List of non-null values; empty list on failure.
    """
    try:
        query = load_builtin_sql("utils/col_nonnulls", dialect=self.dialect, tab_name=tab_name, col_name=col_name)
        fetched = self.execute(query, autocommit=True).to_list()
        return [item["col_enums"] for item in fetched]
    except Exception as e:
        # Best-effort: log the failure and return an empty result.
        logger.error(f"Failed to get non-null values for column {col_name} in table {tab_name}: {error_str(e)}")
        return []
# === Database Manipulation Methods ===
[docs]
def clear_tab(self, tab_name: str) -> None:
    """\
    Clear all data from a specific table without deleting the table itself.

    Uses SQLAlchemy ORM to ensure compatibility across all database backends.

    Args:
        tab_name: Name of the table to clear

    Raises:
        Exception: If the clearing operation fails
    """
    try:
        # Use the module's lazy loader (consistent with drop_tab()/drop())
        # instead of a direct `from sqlalchemy import ...`, so sqlalchemy
        # stays an optional dependency resolved through deps.load().
        sa = get_sa()
        metadata = sa.MetaData()
        # Reflect the table definition from the live engine.
        table = sa.Table(tab_name, metadata, autoload_with=self.engine)
        delete_stmt = sa.delete(table)
        self.orm_execute(delete_stmt, autocommit=True)
        logger.info(f"Cleared table: {tab_name}")
    except Exception as e:
        logger.error(f"Failed to clear table {tab_name}: {error_str(e)}")
        # Chain the original exception so callers can inspect the root cause.
        raise Exception(f"Table clear failed for {tab_name}: {e}") from e
[docs]
def drop_tab(self, tab_name: str) -> None:
    """\
    Drop a specific table from the database.

    The table is reflected via SQLAlchemy and dropped through the ORM so the
    operation stays portable across database backends.

    Args:
        tab_name: Name of the table to drop

    Raises:
        Exception: If the drop operation fails
    """
    try:
        sa = get_sa()
        reflected = sa.Table(tab_name, sa.MetaData(), autoload_with=self.engine)
        reflected.drop(self.engine)
        logger.info(f"Dropped table: {tab_name}")
    except Exception as e:
        logger.error(f"Failed to drop table {tab_name}: {error_str(e)}")
        raise Exception(f"Table drop failed for {tab_name}: {e}")
[docs]
def drop_view(self, view_name: str) -> None:
    """\
    Drop a specific view from the database.

    Args:
        view_name: Name of the view to drop

    Raises:
        Exception: If the drop operation fails
    """
    # NOTE(review): view_name is interpolated directly into the SQL text;
    # identifiers cannot be bound as parameters, so this must only ever be
    # called with trusted, code-originated names — confirm callers.
    stmt = f"DROP VIEW IF EXISTS {view_name}"
    try:
        self.execute(stmt, autocommit=True)
        logger.info(f"Dropped view: {view_name}")
    except Exception as e:
        logger.error(f"Failed to drop view {view_name}: {error_str(e)}")
        raise Exception(f"View drop failed for {view_name}: {e}")
[docs]
def drop(self) -> None:
    """\
    Drop all tables in the database.

    Uses SQLAlchemy metadata reflection to drop all tables; if that fails,
    falls back to dropping each table individually. In every case the
    connection is closed and the engine disposed afterwards, leaving this
    object unusable until re-initialized.

    Raises:
        DatabaseError: If the database drop operation fails
    """
    # NOTE(review): despite the docstring, no DatabaseError is raised here —
    # failures are logged and swallowed. Confirm whether that is intended.
    try:
        # Use SQLAlchemy metadata to reflect and drop all tables
        metadata = get_sa().MetaData()
        metadata.reflect(bind=self.engine)
        metadata.drop_all(bind=self.engine, checkfirst=True)
        logger.info("Dropped all tables using metadata")
    except Exception as e:
        logger.warning(f"Metadata drop failed, trying fallback: {error_str(e)}")
        # Fallback: try to drop tables individually
        tables = self.db_tabs()
        for table_name in tables:
            try:
                self.drop_tab(table_name)
            except Exception as table_e:
                # Individual failures don't stop the remaining drops.
                logger.warning(f"Failed to drop table {table_name}: {error_str(table_e)}")
    finally:
        # Close connection and dispose engine. This runs on success too:
        # the bulk drop always tears down the connection/engine.
        self.close_conn(commit=False)
        if hasattr(self, "engine") and self.engine:
            self.engine.dispose()
            self.engine = None
[docs]
def init(self, connect: bool = True) -> None:
    """\
    Drop the entire database and create a new one.

    Combines drop() with database re-creation: after all tables are dropped
    and the engine disposed, the database is recreated from scratch.

    Args:
        connect: Whether to establish a new connection after re-creation.

    Raises:
        Exception: If the database initialization fails
    """
    # Tear everything down first, then rebuild via the private initializer.
    self.drop()
    self._init(connect=connect)
[docs]
def clear(self) -> None:
    """\
    Delete every row from all tables while keeping the table definitions intact.

    Delegates to clear_tab() per table so the operation works across all
    supported database backends. Failures on individual tables are logged
    and do not stop the remaining tables from being cleared.
    """
    for table_name in self.db_tabs():
        try:
            self.clear_tab(table_name)
        except Exception as e:
            logger.error(f"Failed to clear table {table_name}: {error_str(e)}")
[docs]
def close(self) -> None:
    """\
    Close the database connection (committing pending work) and dispose of the engine.
    """
    self.close_conn(commit=True)
    # Dispose the engine only if it exists and is still live.
    engine = getattr(self, "engine", None)
    if engine:
        engine.dispose()
        self.engine = None
[docs]
def table_display(
    table: Union["SQLResponse", Iterable[Dict]],
    schema: Optional[List[str]] = None,
    max_rows: int = 64,
    max_width: int = 64,
    style: Literal["DEFAULT", "MARKDOWN", "PLAIN_COLUMNS", "MSWORD_FRIENDLY", "ORGMODE", "SINGLE_BORDER", "DOUBLE_BORDER", "RANDOM"] = "DEFAULT",
    **kwargs,
):
    """\
    Render a tabular display of SQL query results or iterable dictionaries using PrettyTable.

    Args:
        table (Union[SQLResponse, Iterable[Dict]]): The table data to display. Can be a SQLResponse object
            (from a database query) or any iterable of dictionaries (e.g., list of dicts).
        schema (Optional[List[str]], optional): List of column names to use as the table schema. Only
            consulted for plain iterables — when `table` is a SQLResponse its own columns take precedence
            and this argument is ignored. If not provided, the schema is inferred from the first row.
        max_rows (int, optional): Maximum number of data rows to display. If the table has more than
            `max_rows + 1` rows, the output shows the first ceil(max_rows / 2) rows, one ellipsis row
            stating how many rows were omitted, and the last floor(max_rows / 2) rows. Defaults to 64.
        max_width (int, optional): Maximum width for each column in the output table. Defaults to 64.
        style (Literal["DEFAULT", "MARKDOWN", "PLAIN_COLUMNS", "MSWORD_FRIENDLY", "ORGMODE", "SINGLE_BORDER", "DOUBLE_BORDER", "RANDOM"], optional): The style to use for the table (supported by PrettyTable). Defaults to "DEFAULT".
        **kwargs: Additional keyword arguments passed to PrettyTable.

    Returns:
        str: A string representation of the formatted table, including the number of rows in total.

    Raises:
        ValueError: If the provided table rows do not match the schema in length.

    Example:
        >>> result = db.execute("SELECT * FROM users")
        >>> table_display(result, max_rows=5)
    """
    if isinstance(table, SQLResponse):
        # A SQLResponse carries its own column metadata, which overrides any caller-supplied schema.
        schema = table.columns
        table = table.to_list(row_fmt="dict")
    else:
        table = list(table)
        schema = schema or (list() if not table else list(table[0].keys()))
    if not all(len(row) == len(schema) for row in table):
        raise ValueError(f"Table failed to display. All rows must have the same number of columns as the schema.\nSchema: {schema}\nTable:\n{table}")
    # Load prettytable once instead of once per style entry / PrettyTable call.
    pt = get_prettytable()
    styles = {
        "DEFAULT": pt.TableStyle.DEFAULT,
        "MARKDOWN": pt.TableStyle.MARKDOWN,
        "PLAIN_COLUMNS": pt.TableStyle.PLAIN_COLUMNS,
        "MSWORD_FRIENDLY": pt.TableStyle.MSWORD_FRIENDLY,
        "ORGMODE": pt.TableStyle.ORGMODE,
        "SINGLE_BORDER": pt.TableStyle.SINGLE_BORDER,
        "DOUBLE_BORDER": pt.TableStyle.DOUBLE_BORDER,
        "RANDOM": pt.TableStyle.RANDOM,
    }
    ptable = pt.PrettyTable(schema, **kwargs)
    # BUG FIX: the fallback used to be the *string* "DEFAULT", which set_style()
    # cannot accept; fall back to the actual TableStyle.DEFAULT enum member.
    ptable.set_style(styles.get(style, pt.TableStyle.DEFAULT))
    ptable.float_format = ".6"
    ptable.max_width = max_width
    # NOTE(review): row values are taken in dict-insertion order and assumed to
    # align with `schema` (only lengths are validated above) — confirm for
    # caller-supplied schemas whose order differs from the row dicts.
    if (max_rows is not None) and (len(table) > max_rows + 1):
        # Show the head and tail of the table around a single ellipsis row.
        bottom_cnt = max_rows // 2
        top_cnt = max_rows - bottom_cnt
        omitted_cnt = len(table) - max_rows
        for row in table[:top_cnt]:
            ptable.add_row([val for _, val in zip(schema, row.values())])
        ptable.add_row([f"... ({omitted_cnt} rows omitted)" if i == 0 else "..." for i, _ in enumerate(schema)])
        for row in table[-bottom_cnt:] if bottom_cnt > 0 else []:
            ptable.add_row([val for _, val in zip(schema, row.values())])
    else:
        for row in table:
            ptable.add_row([val for _, val in zip(schema, row.values())])
    return str(ptable) + f"\n{len(table)} rows in total."