
企业微信

飞书
选择您喜欢的方式加入群聊

扫码添加咨询专家
在 Text-to-SQL 场景中,最大的挑战不是生成 SQL 语法,而是如何让 AI 找到正确的表和字段。一个企业数据库可能有数百张表、数千个字段,如何从用户的自然语言问题中精准定位到相关的元数据?
AskTable 的 Schema Linking 引擎,通过 向量检索 + 全文搜索 + Training Pair 复用 的混合策略,实现了高准确率的元数据检索。
本文将深入剖析这套引擎的设计与实现。
问题:用户问"上个月销售额最高的产品是什么?"
数据库:
- 表:orders, products, customers, sales_records, ...
- 字段:order_id, product_name, sale_amount, created_at, ...
AI 需要回答:
- 应该查询哪张表?(orders? sales_records?)
- 销售额对应哪个字段?(sale_amount? total_price?)
- 时间对应哪个字段?(created_at? order_date?)
Schema Linking 的任务:从海量元数据中,找出与问题相关的表和字段。
全量传递:
# Pass every table and field to the LLM
prompt = f"数据库有以下表:{all_tables}\n问题:{question}"
问题:
关键词匹配:
# Simple string matching: keep tables whose name contains the keyword
relevant_tables = [t for t in all_tables if keyword in t.name]
问题:
Qdrant(向量检索):
Meilisearch(全文搜索):
Training Pair(示例复用):
async def _rewrite_question(self, question: Question) -> None:
    """Rewrite the user question into keywords and sub-queries.

    Nothing happens when ``question.subqueries`` is already populated.
    Otherwise the ``query.extract_keywords_from_question`` prompt is run
    and its ``keywords`` / ``subqueries`` results are stored on
    ``question`` in place.
    """
    if question.subqueries:
        return

    extracted = await prompt_generate(
        "query.extract_keywords_from_question",
        QUESTION=question.text,
        SPECIFICATION=question.specification,
        EVIDENCE=question.evidence,
    )
    question.keywords = extracted["keywords"]
    question.subqueries = extracted["subqueries"]
    log.info(f"extracted keywords: {question.keywords}")
示例:
# Input
question = "上个月销售额最高的产品是什么?"
# Output
keywords = ["销售额", "产品", "上个月"]
subqueries = [
"查询销售记录表",
"按产品分组统计销售额",
"筛选上个月的数据"
]
关键点:
async def _retrieve_fields(self, queries: list[str]) -> list[RetrievedMetaEntity]:
"""通过向量检索查找相关字段"""
if not queries:
log.warning("No subqueries provided, skipping field retrieval")
return []
return await self.ds.retrieve_fields_by_question(queries)
Qdrant 检索实现:
async def retrieve_fields_by_question(
    self, queries: list[str]
) -> list[RetrievedMetaEntity]:
    """Retrieve candidate fields from Qdrant for every sub-query.

    All queries are embedded in one call, then each vector is searched in
    the ``meta_{self.id}`` collection (``limit=20``, ``score_threshold=0.7``).
    Hits from all searches are merged: a point returned by several queries
    is kept once with its highest score, and the merged list is ordered by
    descending score.

    Args:
        queries: Sub-queries derived from the user question.

    Returns:
        A list of dicts with ``id``, ``payload`` and ``score`` keys; empty
        when ``queries`` is empty.
    """
    # Function-scope import keeps this block self-contained.
    import asyncio

    if not queries:
        return []

    # 1. Turn the queries into vectors.
    query_vectors = await self.embedding_model.encode(queries)

    # 2. Search Qdrant once per vector; previously only the first vector was
    #    used, so every other sub-query was silently ignored.
    per_query_hits = await asyncio.gather(
        *(
            self.qdrant_client.search(
                collection_name=f"meta_{self.id}",
                query_vector=vector,
                limit=20,
                score_threshold=0.7,
            )
            for vector in query_vectors
        )
    )

    # 3. Merge, keeping the best-scoring hit per point id.
    best_by_id: dict = {}
    for hits in per_query_hits:
        for hit in hits:
            known = best_by_id.get(hit.id)
            if known is None or hit.score > known["score"]:
                best_by_id[hit.id] = {
                    "id": hit.id,
                    "payload": hit.payload,
                    "score": hit.score,
                }

    # sorted() is stable, so a single-query call keeps the order Qdrant returned.
    return sorted(best_by_id.values(), key=lambda entry: entry["score"], reverse=True)
关键点:
# score_threshold=0.7 filters out low-relevance results.
async def _retrieve_values(self, keywords: list[str]) -> list[RetrievedMetaEntity]:
    """Look up field values matching the keywords via full-text search.

    Returns an empty list (after logging a warning) when the value index is
    not configured (``config.aisearch_host`` / ``config.aisearch_master_key``
    missing) or when no keywords are given; otherwise delegates to
    ``self.ds.retrieve_values_by_question``.
    """
    if not (config.aisearch_host and config.aisearch_master_key):
        log.warning("Value index is not enabled, skipping value retrieval")
        return []

    if not keywords:
        log.warning("No keywords provided, skipping value retrieval")
        return []

    return await self.ds.retrieve_values_by_question(keywords)
Meilisearch 检索实现:
async def retrieve_values_by_question(
    self, keywords: list[str]
) -> list[RetrievedMetaEntity]:
    """Full-text search for field values matching the keywords.

    The keywords are joined with spaces and searched in the
    ``values_{self.id}`` Meilisearch index (``limit=50``).

    Args:
        keywords: Keywords extracted from the user question.

    Returns:
        A list of dicts with ``id``, ``payload`` (schema/table/field name,
        the matched ``value`` and ``type="value"``) and ``score`` keys;
        empty when ``keywords`` is empty.
    """
    if not keywords:
        return []

    # 1. Build the search query.
    query = " ".join(keywords)

    # 2. Search Meilisearch. "id" must be listed in attributesToRetrieve and
    #    the ranking score must be requested explicitly, because both
    #    hit["id"] and hit["_rankingScore"] are read below.
    results = await self.meilisearch_client.index(f"values_{self.id}").search(
        query,
        limit=50,
        attributesToRetrieve=[
            "id",
            "schema_name",
            "table_name",
            "field_name",
            "value",
        ],
        showRankingScore=True,
    )

    # 3. Convert the hits into the common retrieved-entity shape.
    return [
        {
            "id": hit["id"],
            "payload": {
                "schema_name": hit["schema_name"],
                "table_name": hit["table_name"],
                "field_name": hit["field_name"],
                "value": hit["value"],
                "type": "value",
            },
            "score": hit["_rankingScore"],
        }
        for hit in results["hits"]
    ]
关键点:
async def _retrieve_examples(self, query: str) -> list[TrainingPair]:
    """Fetch historical question/SQL pairs similar to ``query``.

    The lookup is scoped to the current datasource and, when a role is set
    on this instance, to that role's id (``None`` otherwise).
    """
    role_id = None if not self.role else self.role.id
    return await retrieve_training_pairs(
        datasource_id=self.ds.id,
        query=query,
        role_id=role_id,
    )
Training Pair 存储结构:
class TrainingPair(TypedDict):
    """A stored question/SQL example together with its similarity score."""

    # The historical question.
    question: str
    # The SQL that answered it.
    sql: str
    # Similarity score.
    score: float
示例:
# Current question
question = "上个月销售额最高的产品是什么?"
# Similar questions that were retrieved
training_pairs = [
{
"question": "本月销售额最高的商品是哪个?",
"sql": "SELECT product_name, SUM(amount) FROM sales WHERE month = CURRENT_MONTH GROUP BY product_name ORDER BY SUM(amount) DESC LIMIT 1",
"score": 0.92,
},
{
"question": "去年销量最好的产品?",
"sql": "SELECT product_id, COUNT(*) FROM orders WHERE year = LAST_YEAR GROUP BY product_id ORDER BY COUNT(*) DESC LIMIT 1",
"score": 0.85,
},
]
关键点:
def _merge_values_fields(hits: list[RetrievedMetaEntity]) -> list[MetaEntity]:
"""合并字段和值检索结果"""
fields_buckets: dict[tuple, set] = {}
for hit in hits:
index = (
hit["payload"]["schema_name"],
hit["payload"]["table_name"],
hit["payload"]["field_name"],
)
if not fields_buckets.get(index):
fields_buckets[index] = set()
bucket = fields_buckets[index]
if hit["payload"]["type"] == "value":
bucket.add(hit["payload"]["value"])
fields_list: list[MetaEntity] = []
for index, values in fields_buckets.items():
fields_list.append(
{
"schema_name": index[0],
"table_name": index[1],
"field_name": index[2],
"sample_values": list(values),
}
)
return fields_list
关键点:
def _add_context_to_meta(meta: MetaAdmin, entities: list[MetaEntity]):
"""将检索到的字段值注入到元数据描述中"""
for entity in entities:
if schema := meta.schemas.get(entity["schema_name"]):
if table := schema.tables.get(entity["table_name"]):
if field := table.fields.get(entity["field_name"]):
values = [f'"{v}"' for v in entity["sample_values"]]
if values:
if field.curr_desc:
field.curr_desc += f"(e.g. {','.join(values)})"
else:
field.curr_desc = f"(e.g. {','.join(values)})"
效果:
# Original metadata
field = {
"name": "status",
"type": "VARCHAR",
"description": "订单状态"
}
# After the context has been injected
field = {
"name": "status",
"type": "VARCHAR",
"description": "订单状态(e.g. \"已完成\",\"待支付\",\"已取消\")"
}
关键点:
async def _pick_tables(
    self,
    meta_candidate: MetaAdmin,
    specification: str,
    training_pairs: list[TrainingPair],
) -> list[tuple[str, str]]:
    """Ask the LLM to select the tables relevant to the question.

    Runs the ``query.select_tables_by_question`` prompt with the candidate
    metadata rendered as markdown, the question specification and the
    training pairs, then splits each returned ``schema.table`` name.

    Args:
        meta_candidate: Candidate metadata shown to the LLM.
        specification: The question specification.
        training_pairs: Similar question/SQL examples given as context.

    Returns:
        ``(schema, table)`` tuples for the selected tables.

    Raises:
        errors.NoDataToQuery: If the LLM returns no table names, or none of
            them has the ``schema.table`` form.
    """
    # 1. Let the LLM choose the most relevant tables.
    selection = await prompt_generate(
        "query.select_tables_by_question",
        meta_data=meta_candidate.to_markdown(),
        question=specification,
        translation_examples=dict_to_markdown(training_pairs),
    )
    table_of_interest = selection["table_names"]
    if not table_of_interest:
        raise errors.NoDataToQuery(params={"message": "No data to query"})
    log.info(f"relevant table names: {table_of_interest}")

    # 2. Validate the name format. LLM output is not guaranteed to contain
    #    the "schema.table" separator; unpacking str.split() would raise a
    #    bare ValueError on such a name, so skip it with a warning instead.
    pairs: list[tuple[str, str]] = []
    for table_name in table_of_interest:
        schema, separator, table = table_name.partition(".")
        if not separator:
            log.warning(f"ignoring table name without schema: {table_name!r}")
            continue
        pairs.append((schema, table))

    if not pairs:
        raise errors.NoDataToQuery(params={"message": "No data to query"})
    return pairs
关键点:
AskTable 根据数据库规模,自动选择最优的 Schema Linking 模式:
适用场景:
策略:
async def _naive_link(
    self, accessible_meta: MetaAdmin, question: Question
) -> MetaContext:
    """Naive mode: hand over all accessible metadata without table filtering.

    Only values and training examples are retrieved; the matched values are
    injected into ``accessible_meta`` as description hints, and the whole
    metadata is returned together with the training pairs.
    """
    value_hits = await self._retrieve_values(question.keywords or [])
    examples = await self._retrieve_examples(question.specification)

    _add_context_to_meta(accessible_meta, _merge_values_fields(value_hits))

    return {"meta": accessible_meta, "training_pairs": examples}
优势:
适用场景:
策略:
async def _rag_link(
    self, accessible_meta: MetaAdmin, question: Question
) -> MetaContext:
    """RAG mode: narrow the metadata using vector and full-text retrieval.

    Fields, values and training examples are retrieved and merged into
    entities, which are injected into ``accessible_meta`` as description
    hints. The tables hit by those entities are kept; when more than three
    tables are hit, the LLM is asked to pick among the hit fields first.
    """
    # 1. Retrieve fields, values and examples.
    field_hits = await self._retrieve_fields(question.subqueries or [])
    value_hits = await self._retrieve_values(question.keywords or [])
    pairs = await self._retrieve_examples(question.specification)

    # 2. Merge everything into per-field entities and annotate the metadata.
    example_hits = _training_pair_to_entities(pairs, self.ds.dialect)
    entities = _merge_values_fields(value_hits + field_hits + example_hits)
    _add_context_to_meta(accessible_meta, entities)

    # 3. Collect the distinct tables that were hit.
    hit_table_names: set[tuple[str, str]] = {
        (entity["schema_name"], entity["table_name"]) for entity in entities
    }

    # 4. Few tables: keep them all. Too many: let the LLM rerank.
    if len(hit_table_names) <= 3:
        selected_tables = list(hit_table_names)
    else:
        full_names = _get_field_full_names_from_entities(entities)
        hit_fields = accessible_meta.filter_fields_by_names(
            [convert_full_name_to_tuple(name) for name in full_names]
        )
        selected_tables = await self._pick_tables(
            hit_fields, question.specification, pairs
        )

    return {
        "meta": accessible_meta.filter_tables_by_names(selected_tables),
        "training_pairs": pairs,
    }
优势:
适用场景:
策略:
async def _reasoning_link(
    self, accessible_meta: MetaAdmin, question: Question
) -> MetaContext:
    """Reasoning mode: the LLM chooses the tables from the full metadata.

    Values and training examples are retrieved, the values are injected into
    ``accessible_meta`` as description hints, and the LLM then selects the
    tables to keep.
    """
    # 1. Retrieve values and examples.
    value_hits = await self._retrieve_values(question.keywords or [])
    pairs = await self._retrieve_examples(question.specification)

    # 2. Inject the matched values into the metadata.
    _add_context_to_meta(accessible_meta, _merge_values_fields(value_hits))

    # 3. Let the LLM pick the relevant tables.
    chosen_tables = await self._pick_tables(
        accessible_meta, question.specification, pairs
    )

    return {
        "meta": accessible_meta.filter_tables_by_names(chosen_tables),
        "training_pairs": pairs,
    }
优势:
async def link(self, question: Question) -> MetaContext:
    """Run schema linking in the configured mode.

    In ``auto`` mode the strategy is chosen from the size of the accessible
    metadata: naive for at most 3 tables and 100 fields, reasoning for at
    most 7 tables and 300 fields, RAG otherwise. The other modes force the
    corresponding strategy.

    Raises:
        ValueError: If ``config.at_schema_linking_mode`` is not one of the
            known modes. Previously this case fell through and returned
            ``None`` despite the declared return type.
    """
    accessible_meta = self._get_accessible_meta()
    mode = config.at_schema_linking_mode

    # Resolve "auto" to a concrete mode first so dispatch happens in one place.
    if mode == SchemaLinkingMode.auto:
        if accessible_meta.table_count <= 3 and accessible_meta.field_count <= 100:
            mode = SchemaLinkingMode.naive
        elif accessible_meta.table_count <= 7 and accessible_meta.field_count <= 300:
            mode = SchemaLinkingMode.reasoning
        else:
            mode = SchemaLinkingMode.rag

    if mode == SchemaLinkingMode.naive:
        return await self._naive_link(accessible_meta, question)
    if mode == SchemaLinkingMode.rag:
        return await self._rag_link(accessible_meta, question)
    if mode == SchemaLinkingMode.reasoning:
        return await self._reasoning_link(accessible_meta, question)

    raise ValueError(f"Unsupported schema linking mode: {mode!r}")
HNSW 索引:
# Qdrant configuration
collection_config = {
"vectors": {
"size": 1536, # OpenAI embedding dimension
"distance": "Cosine",
},
"hnsw_config": {
"m": 16, # number of connections
"ef_construct": 100, # search depth at build time
},
}
效果:
# Search for several queries in one batch: encode them together, then
# run one Qdrant search per vector concurrently
query_vectors = await self.embedding_model.encode(queries)
results = await asyncio.gather(*[
self.qdrant_client.search(
collection_name=f"meta_{self.id}",
query_vector=vec,
limit=20,
)
for vec in query_vectors
])
效果:
# Cache embedding results.
# functools.lru_cache must not wrap an ``async def``: it would cache the
# coroutine object, and a second call with the same text would then await an
# already-awaited coroutine and raise RuntimeError. The awaited result is
# cached instead, in a plain dict used as an LRU (dicts keep insertion order).
_EMBEDDING_CACHE: dict[str, list[float]] = {}
_EMBEDDING_CACHE_MAXSIZE = 1000


async def get_embedding(text: str) -> list[float]:
    """Return the embedding of ``text``, reusing a cached result when possible.

    At most ``_EMBEDDING_CACHE_MAXSIZE`` results are kept; the least
    recently used entry is evicted first.
    """
    if text in _EMBEDDING_CACHE:
        # Re-insert to mark the entry as most recently used.
        vector = _EMBEDDING_CACHE.pop(text)
        _EMBEDDING_CACHE[text] = vector
        return vector

    vector = await embedding_model.encode(text)
    _EMBEDDING_CACHE[text] = vector
    if len(_EMBEDDING_CACHE) > _EMBEDDING_CACHE_MAXSIZE:
        # The first key in iteration order is the least recently used one.
        _EMBEDDING_CACHE.pop(next(iter(_EMBEDDING_CACHE)))
    return vector
效果:
# Question
question = "有多少个用户?"
# Schema Linking result
meta = {
"tables": [
{
"name": "users",
"fields": [
{"name": "id", "type": "INT"},
{"name": "name", "type": "VARCHAR"},
]
}
]
}
# Generated SQL
sql = "SELECT COUNT(*) FROM users"
# Question
question = "上个月销售额最高的产品是什么?"
# Schema Linking result
meta = {
"tables": [
{
"name": "orders",
"fields": [
{"name": "product_id", "type": "INT"},
{"name": "amount", "type": "DECIMAL", "description": "销售额(e.g. \"1000.00\",\"2500.50\")"},
{"name": "created_at", "type": "TIMESTAMP"},
]
},
{
"name": "products",
"fields": [
{"name": "id", "type": "INT"},
{"name": "name", "type": "VARCHAR", "description": "产品名称(e.g. \"iPhone\",\"MacBook\")"},
]
}
],
"training_pairs": [
{
"question": "本月销售额最高的商品?",
"sql": "SELECT p.name, SUM(o.amount) FROM orders o JOIN products p ON o.product_id = p.id WHERE MONTH(o.created_at) = MONTH(NOW()) GROUP BY p.name ORDER BY SUM(o.amount) DESC LIMIT 1"
}
]
}
# Generated SQL
# NOTE(review): DATE_SUB(NOW(), INTERVAL 1 MONTH) selects the trailing month up
# to now, not the previous calendar month the question asks about — confirm
# which is intended.
sql = """
SELECT p.name, SUM(o.amount) as total_amount
FROM orders o
JOIN products p ON o.product_id = p.id
WHERE o.created_at >= DATE_SUB(NOW(), INTERVAL 1 MONTH)
GROUP BY p.name
ORDER BY total_amount DESC
LIMIT 1
"""
AskTable 的 Schema Linking 引擎,通过 向量检索 + 全文搜索 + Training Pair 复用 的混合策略,实现了:
✅ 高准确率:多路混合检索提升召回率和精准度
✅ 低延迟:HNSW 索引 + 批量检索 < 50ms
✅ 自适应:根据数据库规模自动选择最优模式
✅ 可扩展:支持大型数据库(1000+ 表)
相关阅读:
技术交流: