import os
import warnings

import matplotlib
matplotlib.use("Agg")  # select the non-interactive backend before pyplot is imported
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties, fontManager

# Locate a CJK-capable font so the Chinese labels can be rendered.
_candidates = [
    '/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc',
    '/usr/share/fonts/truetype/wqy/wqy-microhei.ttc',
]
font_path = next((p for p in _candidates if os.path.exists(p)), None)

if font_path:
    # Register the file with matplotlib and make it the default family, so text
    # drawn without an explicit fontproperties also uses it.
    fontManager.addfont(font_path)
    fp = FontProperties(fname=font_path)
    plt.rcParams['font.family'] = fp.get_name()
else:
    # Best effort: continue with the default font, but say so up front. Without
    # a CJK font the Chinese labels may come out as missing-glyph boxes.
    warnings.warn(
        f"No CJK font found (looked for: {', '.join(_candidates)}); "
        "Chinese text may not render correctly.",
        stacklevel=1,
    )
    fp = FontProperties()

# Draw negative numbers with an ASCII hyphen-minus instead of U+2212
# (a common workaround for CJK fonts that lack that glyph).
plt.rcParams['axes.unicode_minus'] = False

# ── Palette: the original pink/blue plus golden yellow for the evaluation metrics ──
PINK, BLUE = '#FFB6C1', '#AED6F1'
METRIC_COLOR = '#FFECB3'  # pale-yellow fill of the metric panel
METRIC_BOX = '#FFD54F'    # golden-yellow fill of each metric bar

# ── Size constants (data units; the x axis spans 0..16) ──
TITLE_H, STAGE_H = 0.9, 0.85
SUB_BOX_H, SUB_H = 4.1, 3.6   # sub-module panel height / bar height
LABEL_H = 0.6
METRIC_H = SUB_H + 0.5
RESULT_H = 0.8
ARROW_LEN = 0.5
# Vertical gaps between consecutive rows:
# title -> stages -> sub-modules -> labels -> metrics -> result
GAP_12, GAP_23, GAP_34, GAP_45, GAP_56 = 0.9, 0.7, 0.5, 0.5, 0.7

col_w = 16 / 5                # width of one of the five stage columns
sub_w = 0.88                  # width of one vertical sub-module bar
metric_box_x, metric_box_w = 1.2, 13.6

# ── Row positions, computed top-down starting from the title ──
title_y = 20.0
stage_y = title_y - GAP_12 - STAGE_H
sub_top = stage_y - GAP_23
sub_bot = sub_top - SUB_BOX_H
label_y = sub_bot - GAP_34 - LABEL_H / 2
metric_top = label_y - LABEL_H / 2 - GAP_45
metric_y = metric_top - METRIC_H
result_y = metric_y - GAP_56 - RESULT_H

# Figure height: 2.5 in per data unit of content, but never below 32 in.
total_height = title_y + TITLE_H - result_y + 1.0
fig_h = max(32, int(total_height * 2.5))

fig, ax = plt.subplots(figsize=(24, fig_h))
ax.set_xlim(0, 16)
ax.set_ylim(result_y - 0.5, title_y + TITLE_H + 0.5)
ax.set_axis_off()

def draw_rect_box(ax, x, y, w, h, color, linestyle='-', edgecolor='#333333', lw=2.0):
    """Add a filled rectangle with rounded corners to *ax*.

    Args:
        ax: Axes to draw on.
        x, y: lower-left corner in data coordinates (before padding).
        w, h: width and height in data units.
        color: face colour.
        linestyle: edge line style.
        edgecolor: edge colour.
        lw: edge line width in points. Earlier versions accepted this
            argument but ignored it (a fixed 2.8 was used); it is now applied.
    """
    rect = mpatches.FancyBboxPatch(
        (x, y), w, h,
        # pad=0.07 enlarges the drawn box by 0.07 data units on each side.
        boxstyle="round,pad=0.07",
        facecolor=color, edgecolor=edgecolor,
        linewidth=lw, linestyle=linestyle)
    ax.add_patch(rect)

def draw_text(ax, x, y, text, fp, fontsize=14, ha='center', va='center', rotation=0, linespacing=1.3):
    """Place black text on *ax* at data coordinates (x, y).

    Multi-line strings are centred line by line (``multialignment='center'``).

    Args:
        ax: Axes to draw on.
        x, y: anchor point in data coordinates.
        text: string to render; may contain newlines.
        fp: FontProperties used to render the glyphs.
        fontsize: font size in points.
        ha, va: horizontal / vertical alignment relative to the anchor.
        rotation: rotation in degrees.
        linespacing: line spacing as a multiple of the font size.
    """
    style = {
        'ha': ha,
        'va': va,
        'fontproperties': fp,
        'fontsize': fontsize,
        'rotation': rotation,
        'multialignment': 'center',
        'linespacing': linespacing,
        'color': 'black',
    }
    ax.text(x, y, text, **style)

def draw_arrow(ax, x1, y1, x2, y2, lw=2.8):
    """Draw a straight arrow with a filled head from (x1, y1) to (x2, y2).

    Implemented as an empty-text ``annotate``: the arrow runs from the
    ``xytext`` point (tail) to the ``xy`` point (head).

    Args:
        ax: Axes to draw on.
        x1, y1: tail position in data coordinates.
        x2, y2: head (tip) position in data coordinates.
        lw: shaft line width in points.
    """
    tail = (x1, y1)
    head = (x2, y2)
    props = {
        'arrowstyle': '-|>,head_length=0.6,head_width=0.35',
        'color': '#444444',
        'lw': lw,
        'connectionstyle': 'arc3,rad=0',
    }
    ax.annotate('', xy=head, xytext=tail, arrowprops=props)

def draw_fork_arrows(ax, x_center, y_top, x_targets, y_bottom, lw=2.8,
                     color='#444444', r=0.25):
    """Draw a fan-out connector from one source down to several targets.

    A vertical trunk drops from ``(x_center, y_top)`` to a horizontal bus at
    ``y_mid``. For each target x, the connector leaves the bus through a
    rounded corner (a quadratic Bezier curve), drops vertically, and ends in
    a short arrow whose tip sits at ``(x, y_bottom)``. A target within 0.1 of
    the trunk gets a straight drop with no corner.

    Args:
        ax: Axes to draw on.
        x_center: x of the trunk (the source).
        y_top: y where the trunk starts.
        x_targets: iterable of target x coordinates.
        y_bottom: y of each arrow tip.
        lw: stroke width in points.
        color: stroke colour.
        r: corner radius in data units. It is clamped per target so the
            corner never extends past the trunk.
    """
    from matplotlib.path import Path

    line_kw = dict(color=color, linewidth=lw, solid_capstyle='round', zorder=4)
    # The bus sits slightly above the midpoint between source and targets.
    y_mid = (y_top + y_bottom) / 2 + 0.1
    # The straight drop ends here and the short arrowhead segment begins.
    y_stub = y_bottom + 0.15

    # Trunk: from the source down to the bus.
    ax.plot([x_center, x_center], [y_top, y_mid], **line_kw)

    for x in x_targets:
        dx = x - x_center
        if abs(dx) < 0.1:
            # A target directly below the trunk gets a straight drop.
            ax.plot([x, x], [y_mid, y_stub], **line_kw)
        else:
            # +1 when the target is left of the trunk (the corner opens rightwards).
            side = 1 if dx < 0 else -1
            rr = min(r, abs(dx))
            x_corner = x + side * rr

            # Bus segment from the trunk to the start of this target's corner.
            ax.plot([x_center, x_corner], [y_mid, y_mid], **line_kw)

            # Rounded corner, with the bus/drop intersection as the control point.
            corner = Path([(x_corner, y_mid), (x, y_mid), (x, y_mid - rr)],
                          [Path.MOVETO, Path.CURVE3, Path.CURVE3])
            ax.add_patch(mpatches.PathPatch(corner, facecolor='none', edgecolor=color,
                                            linewidth=lw, capstyle='round', zorder=4))

            # Vertical drop from below the corner to the arrowhead stub.
            ax.plot([x, x], [y_mid - rr, y_stub], **line_kw)

        # Short arrow providing the head at (x, y_bottom).
        ax.annotate('', xy=(x, y_bottom), xytext=(x, y_stub),
                    arrowprops=dict(arrowstyle='-|>,head_length=0.6,head_width=0.35',
                                    color=color, lw=lw))

# x centre of each of the five stage columns
stage_centers = [col_w * k + col_w / 2 for k in range(5)]

# ── Title box (pink) ──
draw_rect_box(ax, 1.2, title_y, 13.6, TITLE_H, PINK, lw=2.5)
draw_text(ax, 8, title_y + TITLE_H / 2, '联邦农作物图像识别协同训练系统研究', fp, fontsize=43)

# ── Five stage boxes (blue), one per column ──
stage_labels = [
    '数据准备阶段', '模型构建阶段', '优化策略设计阶段', '实验验证阶段', '可视化系统实现阶段',
]
for col, (stage_name, centre_x) in enumerate(zip(stage_labels, stage_centers)):
    left_edge = col_w * col
    draw_rect_box(ax, left_edge + 0.12, stage_y, col_w - 0.24, STAGE_H, BLUE, lw=2.0)
    draw_text(ax, centre_x, stage_y + STAGE_H / 2, stage_name, fp, fontsize=30)

# Fan-out connector: title -> each of the five stages
draw_fork_arrows(ax, 8, title_y - 0.1, stage_centers, stage_y + STAGE_H + 0.1)

# ── Sub-module bars: pink bars on a pale-blue, dashed-outline panel per column ──
submodules = [
    ['标\n准\n数\n据\n集\n准\n备', '自\n建\n数\n据\n集\n准\n备', '扩\n展\n泛\n化\n集\n准\n备'],
    ['主\n干\n网\n络\n迁\n移\n适\n配', '预\n训\n练\n参\n数\n冻\n结', '联\n邦\n聚\n合\n框\n架\n构\n建'],
    ['动\n态\n近\n端\n约\n束\n衰\n减\n机\n制', 'MixUp\n混\n合\n数\n据\n增\n强'],
    ['基\n线\n算\n法\n对\n比\n实\n验', '消\n融\n实\n验\n分\n析', '主\n干\n网\n络\n对\n比\n泛\n化\n评\n估'],
    ['联\n邦\n训\n练\n监\n控\n模\n块', '农\n作\n物\n分\n类\n识\n别\n模\n块'],
]

for col, modules in enumerate(submodules):
    left_edge = col_w * col
    count = len(modules)
    centre_x = stage_centers[col]

    # Translucent blue panel with a dashed outline behind this column's bars
    panel = mpatches.Rectangle((left_edge + 0.08, sub_bot), col_w - 0.16, SUB_BOX_H,
                               fill=True, facecolor=BLUE, alpha=0.3,
                               edgecolor='#666666', linewidth=1.5, linestyle='--')
    ax.add_patch(panel)

    # Stage box -> panel: an ARROW_LEN-long arrow centred in the GAP_23 gap
    draw_arrow(ax, centre_x, stage_y - (GAP_23 - ARROW_LEN) / 2,
               centre_x, sub_top + (GAP_23 - ARROW_LEN) / 2)

    # Equal spacing before, between and after the bars inside the panel
    spacing = (col_w - 0.16 - count * sub_w) / (count + 1)
    bar_y = sub_bot + (SUB_BOX_H - SUB_H) / 2
    for slot, caption in enumerate(modules):
        bar_x = left_edge + 0.08 + spacing * (slot + 1) + sub_w * slot
        draw_rect_box(ax, bar_x, bar_y, sub_w, SUB_H, PINK)
        draw_text(ax, bar_x + sub_w / 2, bar_y + SUB_H / 2, caption, fp,
                  fontsize=30, rotation=0, linespacing=1.1)

    # Panel -> label row beneath it
    draw_arrow(ax, centre_x, sub_bot - (GAP_34 - ARROW_LEN) / 2,
               centre_x, label_y + LABEL_H / 2 + (GAP_34 - ARROW_LEN) / 2)

# ── Label row: blue-backed captions, chained left to right by arrows ──
row_labels = ['数据支撑', '模型基础', '性能优化', '效果验证', '系统实现']
last_idx = len(row_labels) - 1
for idx, caption in enumerate(row_labels):
    cx_label = stage_centers[idx]
    draw_rect_box(ax, cx_label - 0.7, label_y - 0.25, 1.4, 0.5, BLUE, lw=1.0)
    draw_text(ax, cx_label, label_y, caption, fp, fontsize=28)
    if idx != last_idx:
        # Horizontal arrow from this caption to the next one
        draw_arrow(ax, cx_label + 0.9, label_y, stage_centers[idx + 1] - 0.9, label_y)

# ── Arrow from the performance-optimisation label (column index 2) to the metric panel ──
perf_x = stage_centers[2]
half_slack = (GAP_45 - ARROW_LEN) / 2  # centres the ARROW_LEN arrow inside the GAP_45 gap
draw_arrow(ax, perf_x, label_y - LABEL_H / 2 - half_slack,
           perf_x, metric_top + half_slack)

# ── Evaluation-metric bars (golden yellow) on a dashed pale-yellow panel ──
metrics = ['识\n别\n准\n确\n率', '收\n敛\n速\n度', '泛\n化\n能\n力', '推\n理\n时\n间', '预\n处\n理\n时\n间']
metric_panel = mpatches.Rectangle((metric_box_x, metric_y), metric_box_w, METRIC_H,
                                  fill=True, facecolor=METRIC_COLOR, alpha=0.3,
                                  edgecolor='#666666', linewidth=1.5, linestyle='--')
ax.add_patch(metric_panel)

m_sub_w = 0.88
n_m = len(metrics)
# Equal spacing before, between and after the bars across the panel width
m_gap = (metric_box_w - n_m * m_sub_w) / (n_m + 1)
m_sy = metric_y + (METRIC_H - SUB_H) / 2
for slot, metric_text in enumerate(metrics):
    bar_x = metric_box_x + m_gap * (slot + 1) + m_sub_w * slot
    draw_rect_box(ax, bar_x, m_sy, m_sub_w, SUB_H, METRIC_BOX)
    draw_text(ax, bar_x + m_sub_w / 2, m_sy + SUB_H / 2, metric_text, fp,
              fontsize=30, rotation=0, linespacing=1.1)

# ── Metric panel -> summary box ──
arrow_pad = (GAP_56 - ARROW_LEN) / 2  # centres the arrow inside the GAP_56 gap
draw_arrow(ax, 8, metric_y - arrow_pad,
           8, result_y + RESULT_H + arrow_pad)

# ── Summary box (blue) ──
draw_rect_box(ax, 1.2, result_y, 13.6, RESULT_H, BLUE, lw=2.5)
draw_text(ax, 8, result_y + RESULT_H / 2, '分类系统性能提升与应用验证', fp, fontsize=30)

# Write the diagram to the user's home directory, release the figure, and report the file size.
output_path = os.path.expanduser('~/技术路线图_v13.png')
fig.savefig(output_path, dpi=150, bbox_inches='tight', facecolor='white')
plt.close(fig)

size = os.path.getsize(output_path)
print(f"保存成功: {output_path} ({size} bytes)")
