"""
Generate the RoPE (rotary position embedding) frequency-decomposition figure,
LaTeX label ``fig:ch2_rope_frequency``.

Three panels:
    (a) rotation frequency theta_i of each dimension pair (log scale)
    (b) per-pair positional response cos(theta_i * delta) vs. relative distance
    (c) rotation trajectories of a high- and a low-frequency pair in the
        complex plane

The figure is written to ``assets/2_rope_frequency.pdf`` (relative to the
current working directory); the ``assets`` directory is created if missing.
"""
from pathlib import Path

import matplotlib as mpl  # noqa: F401  (kept from the original; unused here)
import matplotlib.pyplot as plt
import numpy as np

# ---------- Global style (CJK-capable fonts) ----------
plt.rcParams.update({
    # Font families tried in order; the first three cover the Chinese labels.
    "font.family": ["Arial Unicode MS", "Songti SC", "SimSun", "serif"],
    "mathtext.fontset": "cm",
    "font.size": 9,
    "axes.labelsize": 10,
    "axes.titlesize": 11,
    "legend.fontsize": 8,
    "figure.dpi": 150,
    # Draw minus signs as ASCII hyphen-minus instead of U+2212, which some
    # CJK fonts lack.
    "axes.unicode_minus": False,
})

# ---------- RoPE parameters ----------
d_h = 64              # per-head dimension
omega = 10000.0       # RoPE base
n_pairs = d_h // 2    # dimensions are rotated in pairs (32 pairs for d_h=64)
indices = np.arange(n_pairs)
# theta_i = base^(-2i/d_h): equals 1 at i=0 and shrinks geometrically with i.
theta = omega ** (-2.0 * indices / d_h)

# ---------- Figure ----------
fig, axes = plt.subplots(1, 3, figsize=(14, 4.0))

# Shared annotation box: white, semi-opaque, borderless.
bbox_style = dict(boxstyle="round,pad=0.25", fc="white", ec="none", alpha=0.85)

# ===== (a) Rotation frequency per dimension pair =====
ax = axes[0]
ax.semilogy(indices, theta, "o-", color="#2563EB", markersize=3.5, linewidth=1.5)
ax.set_xlabel(r"维度对索引 $i$")
ax.set_ylabel(r"旋转频率 $\theta_i$(对数刻度)", labelpad=2)
ax.tick_params(axis="y", pad=1)
ax.set_title("(a)", fontweight="bold", loc="left")

# Label the high-frequency (small i) and low-frequency (large i) regions.
ax.annotate(
    "高频\n(局部位置敏感)",
    xy=(2, theta[2]),
    xytext=(12, theta[1] * 1.5),
    fontsize=9,
    ha="center",
    arrowprops=dict(arrowstyle="->", color="#DC2626", lw=1.2),
    color="#DC2626",
    bbox=bbox_style,
)
ax.annotate(
    "低频\n(全局依赖)",
    xy=(28, theta[28]),
    xytext=(18, theta[15] * 0.8),
    fontsize=9,
    ha="center",
    arrowprops=dict(arrowstyle="->", color="#059669", lw=1.2),
    color="#059669",
    bbox=bbox_style,
)
ax.set_xlim(-1, n_pairs)
ax.grid(True, alpha=0.3)

# ===== (b) Per-pair response vs. relative distance =====
ax = axes[1]
max_dist = 128
rel_pos = np.arange(0, max_dist + 1)  # relative distance 0..max_dist inclusive

# Four representative pairs, from the highest-frequency pair (i=0) to the
# lowest-frequency one (the last pair). Derived from n_pairs so that changing
# d_h cannot index past the end of theta.
selected = [0, 5, 15, n_pairs - 1]
colors_b = ["#DC2626", "#F59E0B", "#2563EB", "#059669"]
labels_b = [rf"$i={s}$" for s in selected]
for s, c, lb in zip(selected, colors_b, labels_b):
    # Contribution of pair i to the attention score, modelled as
    # cos(theta_i * delta). A single pair's term oscillates rather than
    # decays; the panel contrasts how quickly each pair oscillates.
    score = np.cos(theta[s] * rel_pos)
    ax.plot(rel_pos, score, color=c, linewidth=1.5, label=lb)

ax.set_xlabel(r"相对位置距离 $|t_1 - t_2|$")
ax.set_ylabel("注意力得分贡献", labelpad=2)
ax.tick_params(axis="y", pad=1)
ax.set_title("(b)", fontweight="bold", loc="left")
ax.legend(loc="upper right", framealpha=0.9)
ax.set_xlim(0, max_dist)
ax.set_ylim(-1.15, 1.15)
ax.axhline(0, color="gray", linewidth=0.5, linestyle="--")
ax.grid(True, alpha=0.3)

# ===== (c) Rotation trajectories in the complex plane =====
ax = axes[2]
T = 64  # positions 0..T inclusive
positions = np.arange(0, T + 1)

# High-frequency pair i=0.
i_high = 0
angles_high = theta[i_high] * positions
x_high = np.cos(angles_high)
y_high = np.sin(angles_high)

# Lowest-frequency pair (i=31 for d_h=64).
# NOTE(review): with d_h=64, theta[31] ~= 1.3e-4, so over T=64 positions this
# pair rotates only ~0.0085 rad; its trajectory and the green t=T marker are
# visually indistinguishable from the t=0 point (1, 0).
i_low = n_pairs - 1
angles_low = theta[i_low] * positions
x_low = np.cos(angles_low)
y_low = np.sin(angles_low)

# Dashed unit circle for reference.
circle_t = np.linspace(0, 2 * np.pi, 200)
ax.plot(np.cos(circle_t), np.sin(circle_t),
        color="gray", linewidth=0.6, linestyle="--", alpha=0.5)

# Trajectories: circle marker = start (t=0), square marker = end (t=T).
ax.plot(x_high, y_high, color="#DC2626", linewidth=1.5,
        label=rf"$i={i_high}$(高频)", alpha=0.85)
ax.plot(x_high[0], y_high[0], "o", color="#DC2626", markersize=5)
ax.plot(x_high[-1], y_high[-1], "s", color="#DC2626", markersize=5)

ax.plot(x_low, y_low, color="#059669", linewidth=2.0,
        label=rf"$i={i_low}$(低频)", alpha=0.85)
ax.plot(x_low[0], y_low[0], "o", color="#059669", markersize=5)
ax.plot(x_low[-1], y_low[-1], "s", color="#059669", markersize=5)

# Start label: placed outside the circle, upper right, to avoid overlap.
ax.annotate(
    r"$t=0$",
    xy=(x_high[0], y_high[0]),
    xytext=(1.15, 0.25),
    fontsize=8,
    color="#555",
    arrowprops=dict(arrowstyle="->", color="#555", lw=0.8),
    bbox=bbox_style,
)
# High-frequency end label: text placed lower left.
ax.annotate(
    rf"$t={T}$",
    xy=(x_high[-1], y_high[-1]),
    xytext=(-1.15, -0.9),
    fontsize=8,
    color="#DC2626",
    arrowprops=dict(arrowstyle="->", color="#DC2626", lw=0.8),
    bbox=bbox_style,
)
# Low-frequency end label: text placed right of centre, below the axis.
ax.annotate(
    rf"$t={T}$",
    xy=(x_low[-1], y_low[-1]),
    xytext=(0.7, -0.9),
    fontsize=8,
    color="#059669",
    arrowprops=dict(arrowstyle="->", color="#059669", lw=0.8),
    bbox=bbox_style,
)

ax.set_xlabel("实部 Re")
ax.set_ylabel("虚部 Im", labelpad=2)
ax.tick_params(axis="y", pad=1)
ax.set_title("(c)", fontweight="bold", loc="left")
ax.set_aspect("equal")
ax.legend(loc="lower left", framealpha=0.9, fontsize=8)
ax.set_xlim(-1.4, 1.4)
ax.set_ylim(-1.4, 1.4)
ax.grid(True, alpha=0.3)
ax.axhline(0, color="gray", linewidth=0.4)
ax.axvline(0, color="gray", linewidth=0.4)

# ---------- Save ----------
plt.tight_layout(w_pad=2.5)
plt.subplots_adjust(left=0.06)

output_path = "assets/2_rope_frequency.pdf"
# savefig does not create missing directories; without this the script fails
# with FileNotFoundError when run from a directory that has no ``assets/``.
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, bbox_inches="tight", pad_inches=0.1)
print(f"Saved to {output_path}")
plt.show()