import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.patches import FancyBboxPatch, FancyArrowPatch, PathPatch
from matplotlib.path import Path
import numpy as np

# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False

# 配色
PINK = '#FFB6C1'
BLUE = '#AED6F1'
METRIC_BOX = '#FFD54F'
METRIC_BG = '#FFECB3'
STAGE_BG = '#F8F9FA'
LABEL_BG = '#E8E8E8'
ARROW_MAIN = '#5D6D7E'
ARROW_LIGHT = '#8899AA'

fig, ax = plt.subplots(1, 1, figsize=(14, 18))
ax.set_xlim(0, 14)
ax.set_ylim(0, 18)
ax.axis('off')

# 标题
ax.text(7, 17.3, '技术路线图', fontsize=22, fontweight='bold', ha='center', va='center')

# 阶段数据
stages = [
    {
        'title': '第一阶段：数据准备与预处理',
        'y': 14.8,
        'items': [
            ('PlantVillage数据集获取', PINK),
            ('数据清洗与增强', PINK),
            ('Non-IID数据划分', BLUE),
        ]
    },
    {
        'title': '第二阶段：联邦学习框架搭建',
        'y': 11.6,
        'items': [
            ('Flower框架部署', PINK),
            ('客户端-服务器架构', BLUE),
            ('通信协议实现', BLUE),
        ]
    },
    {
        'title': '第三阶段：算法实现与优化',
        'y': 8.4,
        'items': [
            ('FedAvg基线实现', PINK),
            ('FedProx近端约束', PINK),
            ('动态衰减机制', BLUE),
            ('Mixup数据增强', BLUE),
        ]
    },
    {
        'title': '第四阶段：实验设计与执行',
        'y': 5.0,
        'items': [
            ('异构性对比实验', PINK),
            ('消融实验', PINK),
            ('收敛性分析', BLUE),
        ]
    },
    {
        'title': '第五阶段：评估与总结',
        'y': 1.8,
        'items': [
            ('准确率', METRIC_BOX),
            ('F1分数', METRIC_BOX),
            ('通信效率', METRIC_BOX),
            ('论文撰写', BLUE),
        ],
        'is_metric': True
    },
]

def draw_stage(ax, stage, idx):
    y = stage['y']
    items = stage['items']
    is_metric = stage.get('is_metric', False)
    
    # 虚线框背景
    bg_color = METRIC_BG if is_metric else STAGE_BG
    frame = FancyBboxPatch((0.8, y - 1.8), 12.4, 2.6,
                           boxstyle="round,pad=0.05,rounding_size=0.3",
                           facecolor=bg_color, edgecolor='#CCCCCC',
                           linestyle='--', linewidth=1.5, alpha=0.6)
    ax.add_patch(frame)
    
    # 标签背景框
    label_box = FancyBboxPatch((1.0, y + 0.35), 12.0, 0.55,
                               boxstyle="round,pad=0.02,rounding_size=0.15",
                               facecolor=LABEL_BG, edgecolor='none', alpha=0.7)
    ax.add_patch(label_box)
    
    # 阶段标题
    ax.text(7, y + 0.6, stage['title'], fontsize=13, fontweight='bold', 
            ha='center', va='center', color='#2C3E50')
    
    # 内容框
    n = len(items)
    box_width = 2.6
    total_width = n * box_width + (n - 1) * 0.3
    start_x = 7 - total_width / 2
    
    for i, (text, color) in enumerate(items):
        x = start_x + i * (box_width + 0.3) + box_width / 2
        
        # 阴影效果
        shadow = FancyBboxPatch((x - box_width/2 + 0.06, y - 1.1 - 0.06), 
                                box_width, 0.9,
                                boxstyle="round,pad=0.02,rounding_size=0.2",
                                facecolor='#D0D0D0', edgecolor='none', alpha=0.5)
        ax.add_patch(shadow)
        
        # 主框
        box = FancyBboxPatch((x - box_width/2, y - 1.1), box_width, 0.9,
                             boxstyle="round,pad=0.02,rounding_size=0.2",
                             facecolor=color, edgecolor='#888888', linewidth=1.2)
        ax.add_patch(box)
        
        ax.text(x, y - 0.65, text, fontsize=10, ha='center', va='center',
                color='#2C3E50', fontweight='medium')

def draw_fancy_arrow(ax, y_start, y_end):
    """绘制双层圆润箭头：外层浅色宽箭头 + 内层深色细箭头"""
    x = 7
    y1 = y_start - 1.85
    y2 = y_end + 0.9
    
    # 外层：浅色宽箭头（阴影/光晕效果）
    outer_arrow = FancyArrowPatch(
        (x, y1), (x, y2),
        arrowstyle="Simple,head_length=18,head_width=16,tail_width=10",
        facecolor=ARROW_LIGHT,
        edgecolor='none',
        alpha=0.5,
        zorder=1
    )
    ax.add_patch(outer_arrow)
    
    # 内层：深色细箭头
    inner_arrow = FancyArrowPatch(
        (x, y1), (x, y2),
        arrowstyle="Simple,head_length=14,head_width=12,tail_width=6",
        facecolor=ARROW_MAIN,
        edgecolor='#4D5D6E',
        linewidth=0.5,
        zorder=2
    )
    ax.add_patch(inner_arrow)

# 绘制所有阶段
for idx, stage in enumerate(stages):
    draw_stage(ax, stage, idx)

# 绘制箭头
arrow_positions = [
    (14.8, 11.6),
    (11.6, 8.4),
    (8.4, 5.0),
    (5.0, 1.8),
]

for y_start, y_end in arrow_positions:
    draw_fancy_arrow(ax, y_start, y_end)

plt.tight_layout()
plt.savefig('/home/ubuntu/技术路线图_v6.png', dpi=150, bbox_inches='tight', 
            facecolor='white', edgecolor='none')
plt.close()
print("✅ v6 已生成：双层箭头（外层光晕 + 内层实心）")
