
企业微信

飞书
选择您喜欢的方式加入群聊

扫码添加咨询专家
AI 生成的 SQL 可能因为表名错误、字段不存在、语法错误等原因执行失败。如何让 Agent 从错误中学习,自动纠正并重试?
AskTable 的 Agent 自我纠错机制,通过 错误检测 + Prompt 调整 + Case 收集 的闭环,实现了持续优化。
表名错误:
-- AI 生成 SELECT * FROM order -- 错误:应该是 orders -- 错误信息 Table 'order' doesn't exist
字段不存在:
-- AI 生成 SELECT user_name FROM users -- 错误:应该是 username -- 错误信息 Unknown column 'user_name' in 'field list'
语法错误:
-- AI 生成 SELECT * FROM orders WHERE -- 错误:WHERE 后缺少条件 -- 错误信息 You have an error in your SQL syntax
直接返回错误:
无限重试:
加载图表中...
class SQLError(Enum):
    """Categories of SQL execution failure produced by ``classify_error``."""

    TABLE_NOT_FOUND = "table_not_found"
    COLUMN_NOT_FOUND = "column_not_found"
    SYNTAX_ERROR = "syntax_error"
    PERMISSION_DENIED = "permission_denied"
    TIMEOUT = "timeout"
    UNKNOWN = "unknown"


def classify_error(error_message: str) -> SQLError:
    """Map a database error message onto a ``SQLError`` category.

    Matching is a case-insensitive substring search. Markers are tried in
    the order listed below and the first one found wins, so a message that
    contains several markers is classified by the earliest entry.

    Args:
        error_message: Text of the error raised while executing SQL.

    Returns:
        The matching category, or ``SQLError.UNKNOWN`` when no marker occurs
        in the message.
    """
    lowered = error_message.lower()

    # Order matters: it fixes the precedence between overlapping markers.
    markers = (
        ("doesn't exist", SQLError.TABLE_NOT_FOUND),
        ("unknown column", SQLError.COLUMN_NOT_FOUND),
        ("syntax", SQLError.SYNTAX_ERROR),
        ("permission denied", SQLError.PERMISSION_DENIED),
        ("timeout", SQLError.TIMEOUT),
    )
    for needle, category in markers:
        if needle in lowered:
            return category
    return SQLError.UNKNOWN
async def execute_with_retry(
    question: str,
    datasource: DataSourceAdmin,
    max_retries: int = 3,
) -> QueryResult:
    """Generate SQL for ``question`` and execute it, retrying on failure.

    Each failed attempt is appended to an error history that is handed back
    to ``generate_sql`` on the next attempt, so the generator can see what
    was tried and why it failed.

    Args:
        question: Natural-language question to answer.
        datasource: Data source used both to describe the schema to the
            generator and to execute the generated SQL.
        max_retries: Total number of attempts (not additional retries).
            Must be at least 1.

    Returns:
        A ``QueryResult`` holding the SQL that succeeded and its result.

    Raises:
        ValueError: If ``max_retries`` is smaller than 1.
        Exception: The error from the final attempt is re-raised unchanged
            after the failure has been recorded via ``save_bad_case``.
    """
    if max_retries < 1:
        # Without this guard the loop body never runs and the function
        # would fall through and return None instead of a QueryResult.
        raise ValueError(f"max_retries must be at least 1, got {max_retries}")

    error_history: list[dict] = []

    for attempt in range(max_retries):
        # Reset per attempt: if generate_sql itself raises, the handler below
        # must not touch an unbound name (first attempt) or blame the
        # statement left over from an earlier attempt.
        sql = ""
        try:
            # Only generation and execution are retried.
            sql = await generate_sql(
                question,
                datasource,
                error_history=error_history,
            )
            df = await datasource.execute_sql(sql)
        except Exception as e:
            # Record the failure so the next generation can correct it.
            error_type = classify_error(str(e))
            error_history.append({
                "sql": sql,
                "error": str(e),
                "error_type": error_type,
                "attempt": attempt + 1,
            })

            # Out of attempts: persist the bad case and surface the error.
            if attempt == max_retries - 1:
                await save_bad_case(question, sql, str(e))
                raise

            log.warning(f"SQL execution failed (attempt {attempt + 1}): {e}")
        else:
            # The query already succeeded. Case collection is kept outside
            # the retried section so that a bookkeeping failure is neither
            # fed back to the generator as a SQL error nor able to end up
            # with valid SQL being stored as a bad case.
            try:
                await save_good_case(question, sql, df)
            except Exception as save_error:
                log.warning(f"Failed to save good case: {save_error}")
            return QueryResult(sql=sql, dataframe=df)
async def generate_sql(
    question: str,
    datasource: DataSourceAdmin,
    error_history: list[dict] | None = None,
) -> str:
    """Ask the LLM for a SQL query, including earlier failures in the prompt.

    Args:
        question: Natural-language question to answer.
        datasource: Data source whose ``to_markdown()`` output is embedded
            in the prompt.
        error_history: Earlier failed attempts. Each entry is read for the
            keys ``attempt``, ``sql``, ``error`` and ``error_type``. When
            empty or ``None`` only the base prompt is sent.

    Returns:
        The value stored under ``"sql"`` in the LLM response.
    """
    # NOTE(review): whitespace inside the prompt templates is reproduced
    # exactly as it appears in this file - confirm the intended layout.
    sections = [f""" 数据库:{datasource.to_markdown()} 问题:{question} 请生成 SQL 查询。 """]

    if error_history:
        # One section per failed attempt, in the order they occurred.
        sections.append("\n\n## 错误历史\n")
        sections.extend(
            f""" 尝试 {entry['attempt']}: SQL: {entry['sql']} 错误: {entry['error']} 错误类型: {entry['error_type']} 请根据错误信息修正 SQL。 """
            for entry in error_history
        )

    reply = await llm.generate("".join(sections))
    return reply["sql"]
async def save_good_case(
    question: str,
    sql: str,
    dataframe: pd.DataFrame,
):
    """Insert a question/SQL pair into ``training_pairs`` with status 'good'.

    Args:
        question: The natural-language question.
        sql: The SQL that answered it.
        dataframe: Accepted but not used by this function; nothing from it
            is written to the table.
    """
    insert_stmt = """ INSERT INTO training_pairs (question, sql, status, created_at) VALUES (:question, :sql, 'good', NOW()) """
    bind_params = {"question": question, "sql": sql}
    await db.execute(insert_stmt, bind_params)


async def save_bad_case(
    question: str,
    sql: str,
    error: str,
):
    """Insert a question/SQL pair into ``training_pairs`` with status 'bad'.

    Args:
        question: The natural-language question.
        sql: The SQL that failed.
        error: The error text stored alongside the pair.
    """
    insert_stmt = """ INSERT INTO training_pairs (question, sql, error, status, created_at) VALUES (:question, :sql, :error, 'bad', NOW()) """
    bind_params = {"question": question, "sql": sql, "error": error}
    await db.execute(insert_stmt, bind_params)
# 第 1 次尝试 sql_1 = "SELECT * FROM order" error_1 = "Table 'order' doesn't exist" # 第 2 次尝试(根据错误调整) sql_2 = "SELECT * FROM orders" # 成功!
# 第 1 次尝试 sql_1 = "SELECT user_name FROM users" error_1 = "Unknown column 'user_name'" # 第 2 次尝试 sql_2 = "SELECT username FROM users" # 成功!
max_retries = 3  # retry at most 3 times
效果:
error_cache: dict[str, str] = {} def get_cached_fix(sql: str, error: str) -> str | None: """获取缓存的修复方案""" cache_key = f"{sql}:{error}" return error_cache.get(cache_key)
效果:
AskTable Agent 的自我纠错机制,通过 错误检测 + Prompt 调整与重试 + Case 收集 的闭环,实现了:
✅ 自动修复:常见错误自动纠正 ✅ 持续优化:从错误中学习 ✅ 用户体验:减少手动干预 ✅ 成本控制:限制重试次数
相关阅读:
技术交流: