Rev for Pre

This commit is contained in:
panda361
2026-03-23 22:07:08 +08:00
parent de3d1598b0
commit 1fd3c5771f
12 changed files with 303 additions and 121 deletions

View File

@@ -0,0 +1,169 @@
"""
生成 RoPE 旋转位置编码的频率分解结构示意图 (fig:ch2_rope_frequency)
三个子图:(a) 维度对旋转频率 (b) 位置响应衰减曲线 (c) 复平面旋转轨迹
"""
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
# ---------- Global style (CJK-capable fonts) ----------
# All rcParams overrides for this figure are collected in one mapping and
# applied in a single call.
_RC_OVERRIDES = {
    # Candidate font families that can render the Chinese labels, listed
    # ahead of the generic "serif" entry.
    "font.family": ["Arial Unicode MS", "Songti SC", "SimSun", "serif"],
    "mathtext.fontset": "cm",
    # Text sizes, in points.
    "font.size": 9,
    "axes.labelsize": 10,
    "axes.titlesize": 11,
    "legend.fontsize": 8,
    "figure.dpi": 150,
    # Keep the plain ASCII minus sign on tick labels.
    "axes.unicode_minus": False,
}
plt.rcParams.update(_RC_OVERRIDES)
# ---------- RoPE parameters ----------
d_h = 64            # head dimension
omega = 10000.0     # RoPE base
n_pairs = d_h // 2  # number of dimension pairs (32 for d_h = 64)
indices = np.arange(n_pairs)  # pair index i = 0 .. n_pairs - 1
# Rotation frequency of pair i: theta_i = omega ** (-2 * i / d_h).
pair_exponents = -2.0 * indices / d_h
theta = np.power(omega, pair_exponents)
# ---------- Create the figure ----------
# A single row of three panels, used below as (a), (b) and (c).
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(14, 4.0))
# ===== (a) Rotation frequency of each dimension pair =====
ax = axes[0]
# Frequencies are drawn on a logarithmic y-axis.
ax.semilogy(indices, theta, "o-", color="#2563EB", markersize=3.5, linewidth=1.5)
ax.set_xlabel(r"维度对索引 $i$")
ax.set_ylabel(r"旋转频率 $\theta_i$(对数刻度)", labelpad=2)
ax.tick_params(axis="y", pad=1)
ax.set_title("(a)", fontweight="bold", loc="left")

# White rounded box drawn behind annotation text; panel (c) reuses it.
bbox_style = {"boxstyle": "round,pad=0.25", "fc": "white", "ec": "none", "alpha": 0.85}

# Call-outs for the high- and low-frequency regions:
# (text, arrow tip, text position, colour).
for band_text, arrow_tip, text_xy, band_colour in (
    ("高频\n(局部位置敏感)", (2, theta[2]), (12, theta[1] * 1.5), "#DC2626"),
    ("低频\n(全局依赖)", (28, theta[28]), (18, theta[15] * 0.8), "#059669"),
):
    ax.annotate(
        band_text,
        xy=arrow_tip,
        xytext=text_xy,
        fontsize=9,
        ha="center",
        arrowprops={"arrowstyle": "->", "color": band_colour, "lw": 1.2},
        color=band_colour,
        bbox=bbox_style,
    )

ax.set_xlim(-1, n_pairs)
ax.grid(True, alpha=0.3)
# ===== (b) Position response curves =====
ax = axes[1]
rel_pos = np.arange(129)  # relative position distances 0 .. 128

# Four representative dimension pairs, each mapped to its line colour.
pair_colours = {0: "#DC2626", 5: "#F59E0B", 15: "#2563EB", 31: "#059669"}
for pair, colour in pair_colours.items():
    # Attention-score contribution is proportional to cos(theta_i * delta).
    ax.plot(
        rel_pos,
        np.cos(theta[pair] * rel_pos),
        color=colour,
        linewidth=1.5,
        label=rf"$i={pair}$",
    )

ax.set_xlabel(r"相对位置距离 $|t_1 - t_2|$")
ax.set_ylabel("注意力得分贡献", labelpad=2)
ax.tick_params(axis="y", pad=1)
ax.set_title("(b)", fontweight="bold", loc="left")
ax.legend(loc="upper right", framealpha=0.9)
ax.set_xlim(0, 128)
# The y-range is slightly wider than the cosine's [-1, 1].
ax.set_ylim(-1.15, 1.15)
ax.axhline(0, color="gray", linewidth=0.5, linestyle="--")
ax.grid(True, alpha=0.3)
# ===== (c) Rotation trajectories in the complex plane =====
ax = axes[2]
T = 64  # last position shown
positions = np.arange(0, T + 1)

# Highest-frequency pair (i = 0) and lowest-frequency pair (i = 31): each
# position t maps to the unit-circle point (cos(theta_i * t), sin(theta_i * t)).
i_high, i_low = 0, 31
x_high, y_high = np.cos(theta[i_high] * positions), np.sin(theta[i_high] * positions)
x_low, y_low = np.cos(theta[i_low] * positions), np.sin(theta[i_low] * positions)

# Unit circle, for reference.
circle_t = np.linspace(0, 2 * np.pi, 200)
ax.plot(np.cos(circle_t), np.sin(circle_t), color="gray", linewidth=0.6, linestyle="--", alpha=0.5)

# Trajectories: a line through the sampled positions, a round marker at the
# first position and a square marker at the last.
for xs, ys, colour, width, legend_label in (
    (x_high, y_high, "#DC2626", 1.5, rf"$i={i_high}$(高频)"),
    (x_low, y_low, "#059669", 2.0, rf"$i={i_low}$(低频)"),
):
    ax.plot(xs, ys, color=colour, linewidth=width, label=legend_label, alpha=0.85)
    ax.plot(xs[0], ys[0], "o", color=colour, markersize=5)
    ax.plot(xs[-1], ys[-1], "s", color=colour, markersize=5)

# Annotations, with text positions chosen to avoid overlap:
#   start point         -> outside the circle, upper right
#   high-frequency end  -> lower left
#   low-frequency end   -> right side, slightly below
for note, arrow_tip, text_xy, colour in (
    (r"$t=0$", (x_high[0], y_high[0]), (1.15, 0.25), "#555"),
    (rf"$t={T}$", (x_high[-1], y_high[-1]), (-1.15, -0.9), "#DC2626"),
    (rf"$t={T}$", (x_low[-1], y_low[-1]), (0.7, -0.9), "#059669"),
):
    ax.annotate(
        note,
        xy=arrow_tip,
        xytext=text_xy,
        fontsize=8,
        color=colour,
        arrowprops={"arrowstyle": "->", "color": colour, "lw": 0.8},
        bbox=bbox_style,
    )

ax.set_xlabel("实部 Re")
ax.set_ylabel("虚部 Im", labelpad=2)
ax.tick_params(axis="y", pad=1)
ax.set_title("(c)", fontweight="bold", loc="left")
# Equal aspect ratio, so the unit circle is drawn as a circle.
ax.set_aspect("equal")
ax.legend(loc="lower left", framealpha=0.9, fontsize=8)
ax.set_xlim(-1.4, 1.4)
ax.set_ylim(-1.4, 1.4)
ax.grid(True, alpha=0.3)
# Thin lines marking the real and imaginary axes.
ax.axhline(0, color="gray", linewidth=0.4)
ax.axvline(0, color="gray", linewidth=0.4)
# ---------- Save ----------
# Lay out the panels with extra horizontal padding between them, then set the
# left margin.
plt.tight_layout(w_pad=2.5)
plt.subplots_adjust(left=0.06)
output_path = "assets/2_rope_frequency.pdf"
# savefig does not create missing directories, so make sure the parent
# directory exists; otherwise a fresh checkout without assets/ fails with
# FileNotFoundError.
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, bbox_inches="tight", pad_inches=0.1)
print(f"Saved to {output_path}")
plt.show()