
企业微信

飞书
选择您喜欢的方式加入群聊

扫码添加咨询专家
在数据可视化场景中,用户往往需要多次调整图表才能达到理想效果。传统的做法是重新生成整个图表,但这会丢失之前的上下文和用户意图。AskTable 的 ChartImprovementAgent 采用了一种更智能的方法:保留原始上下文,理解改进意图,迭代优化图表。
用户在使用 AI 生成图表时,常见的改进需求包括:
如果每次都重新生成,LLM 可能会:
ChartImprovementAgent 在初始化时接收完整的上下文信息:
class ChartImprovementAgent:
def __init__(
self,
parent_nodes_context: list[dict], # 父节点数据
original_question: str, # 原始问题
original_code: str, # 原始 JSX 代码
improvement_request: str, # 改进请求
):
self.parent_nodes_context = parent_nodes_context
self.original_question = original_question
self.original_code = original_code
self.improvement_request = improvement_request
父节点上下文包含:
parent_nodes_context = [
{
"id": "node_123",
"sql": "SELECT region, SUM(sales) as total FROM orders GROUP BY region",
"description": "各地区销售汇总",
"dataframe": {
"columns": ["region", "total"],
"sample_data": [
{"region": "华东", "total": 1000000},
{"region": "华北", "total": 800000},
]
}
}
]
将上下文信息格式化为系统 Prompt:
# 格式化父节点信息
parent_info_parts = []
for i, node in enumerate(parent_nodes_context, 1):
parent_info_parts.append(f"Node {i} (ID: {node['id']}):")
if node.get("description"):
parent_info_parts.append(f" Description: {node['description']}")
if node.get("sql"):
parent_info_parts.append(f" SQL: {node['sql']}")
if node.get("dataframe"):
df = node["dataframe"]
parent_info_parts.append(f" Columns: {', '.join(df.get('columns', []))}")
if df.get("sample_data"):
parent_info_parts.append(f" Sample rows: {len(df['sample_data'])}")
parent_info_parts.append("")
formatted_parent_info = "\n".join(parent_info_parts)
# 构建系统 Prompt
system_prompt = get_prompt("agent/canvas/edit_chart").compile(
formatted_parent_info=formatted_parent_info,
original_question=original_question,
original_code=original_code,
)
Agent 提供
submit_improved_chart 工具,用于提交改进后的代码:
def submit_improved_chart(
self,
question: str = Field(
...,
description="The rewritten question that naturally incorporates both original requirement and the improvement",
),
description: str = Field(
...,
description="Brief description of the improvements made (1-2 sentences), or error reason if code is None",
),
code: str | None = Field(
None,
description="The improved JSX code for the chart component. Set to None if improvement failed.",
),
) -> str:
if code is not None:
try:
# 编译 JSX 代码
self.compiled_code = compile_jsx(code)
self.source_code = code
log.info("Improved chart code compiled successfully")
except Exception as e:
log.error(f"Failed to compile improved chart code: {str(e)}")
raise ValueError(f"Code compilation failed: {str(e)}")
else:
# 改进失败,记录原因
self.compiled_code = None
self.source_code = None
log.warning(f"Chart improvement failed: {description}")
self.submitted_description = description
self.submitted_question = question
status_msg = "error" if code is None else "success"
return json.dumps({
"question": question,
"description": description,
"status": status_msg,
"has_code": code is not None,
})
Agent 需要将原始问题和改进请求合并为新的问题描述:
示例:
这样可以保持问题的完整性,便于后续理解和维护。
原始代码:
<code>
import { BarChart } from '@/components/charts';
function RegionSales() {
const data = load_dataframe('df_region_sales');
return (
<BarChart
data={data}
xField="region"
yField="total"
title="各地区销售额"
/>
);
}
</code>
改进请求:「改成折线图」
改进后代码:
<code>
import { LineChart } from '@/components/charts';
function RegionSales() {
const data = load_dataframe('df_region_sales');
return (
<LineChart
data={data}
xField="region"
yField="total"
title="各地区销售额"
smooth={true}
/>
);
}
</code>
改进请求:「只显示销售额前 5 的地区」
改进后代码:
<code>
import { BarChart } from '@/components/charts';
function RegionSales() {
const data = load_dataframe('df_region_sales');
// 按销售额降序排列,取前 5
const topData = data
.sort((a, b) => b.total - a.total)
.slice(0, 5);
return (
<BarChart
data={topData}
xField="region"
yField="total"
title="销售额 Top 5 地区"
/>
);
}
</code>
改进请求:「使用蓝色系配色,图例放在底部」
改进后代码:
<code>
import { BarChart } from '@/components/charts';
function RegionSales() {
const data = load_dataframe('df_region_sales');
return (
<BarChart
data={data}
xField="region"
yField="total"
title="各地区销售额"
color="#1890ff"
legend={{
position: 'bottom'
}}
/>
);
}
</code>
改进请求:「添加平均线,标注最高值,使用渐变色」
改进后代码:
<code>
import { BarChart } from '@/components/charts';
function RegionSales() {
const data = load_dataframe('df_region_sales');
// 计算平均值
const average = data.reduce((sum, item) => sum + item.total, 0) / data.length;
// 找到最高值
const maxItem = data.reduce((max, item) => item.total > max.total ? item : max);
return (
<div>
<BarChart
data={data}
xField="region"
yField="total"
title="各地区销售额"
color={{
type: 'gradient',
colors: ['#1890ff', '#52c41a']
}}
annotations={[
{
type: 'line',
start: ['min', average],
end: ['max', average],
style: { stroke: '#ff4d4f', lineDash: [4, 4] },
text: { content: `平均值: ${average.toFixed(0)}`, position: 'end' }
},
{
type: 'text',
position: [maxItem.region, maxItem.total],
content: `最高: ${maxItem.total}`,
style: { fill: '#ff4d4f', fontWeight: 'bold' }
}
]}
/>
</div>
);
}
</code>
如果改进后的代码无法编译,Agent 会返回错误信息:
try:
self.compiled_code = compile_jsx(code)
except Exception as e:
raise ValueError(f"Code compilation failed: {str(e)}")
如果 LLM 判断改进请求无法实现,可以返回
code=None:
def submit_improved_chart(
self,
question: str,
description: str,
code: str | None = None, # None 表示改进失败
):
if code is None:
# 记录失败原因
self.compiled_code = None
self.source_code = None
log.warning(f"Chart improvement failed: {description}")
示例:
code=None, description="当前图表库不支持 3D 效果"确保改进后的代码仍然引用正确的数据源:
# 提取 load_dataframe 引用
load_df_pattern = r"load_dataframe\(\s*['\"]( df_[A-Za-z0-9]+)['\"]\s*\)"
referenced_dataframes = re.findall(load_df_pattern, code)
# 验证数据源是否存在
missing_ids = set(referenced_dataframes) - set(self.data_workspace.keys())
if missing_ids:
raise ValueError(f"Referenced dataframes {missing_ids} are not in the data workspace")
| 特性 | ChartNodeAgent | ChartImprovementAgent |
|---|---|---|
| 用途 | 从零生成图表 | 改进现有图表 |
| 输入 | 用户问题 + 数据 | 原始代码 + 改进请求 |
| 上下文 | 父节点数据 | 父节点数据 + 原始代码 + 原始问题 |
| 输出 | 新的 JSX 代码 | 改进后的 JSX 代码 + 重写的问题 |
| 问题描述 | 用户原始问题 | 合并原始问题和改进请求 |
用户可以多次迭代改进图表:
用户: 展示各地区销售额
→ 生成柱状图
用户: 改成折线图
→ 改进为折线图
用户: 只显示前 5 名
→ 添加数据筛选
用户: 使用蓝色系配色
→ 调整颜色方案
每次改进都保留之前的所有优化:
// 第一次改进:改成折线图
<LineChart ... />
// 第二次改进:只显示前 5 名
const topData = data.slice(0, 5);
<LineChart data={topData} ... />
// 第三次改进:使用蓝色系配色
const topData = data.slice(0, 5);
<LineChart data={topData} color="#1890ff" ... />
如果改进失败,保留原始图表:
def get_result(self) -> dict:
if self.source_code is None:
return {
"code": None,
"compiled_code": None,
"question": self.submitted_question or self.original_question,
"description": self.submitted_description,
"status": "error",
"error": self.submitted_description,
}
return {
"code": self.source_code,
"compiled_code": self.compiled_code,
"question": self.submitted_question,
"description": self.submitted_description,
"status": "success",
"error": None,
}
只编译改进后的代码,不重新编译整个项目:
self.compiled_code = compile_jsx(code) # 单文件编译
复用父节点数据,避免重复查询:
# 父节点数据已经包含 DataFrame
parent_nodes_context = [
{
"id": "node_123",
"dataframe": cached_dataframe # 复用缓存
}
]
多个改进请求可以并行处理:
# 并行处理多个改进
agents = [
ChartImprovementAgent(..., improvement_request="改成折线图"),
ChartImprovementAgent(..., improvement_request="只显示前 5 名"),
]
results = await asyncio.gather(*[agent.run() for agent in agents])
AskTable 的 ChartImprovementAgent 通过以下技术实现了智能的图表迭代优化:
这种设计不仅提升了用户体验,还保证了图表改进的可靠性和一致性,是 AI 驱动数据可视化系统的重要组成部分。