"""
Generate the schematic of the frequency decomposition of RoPE (rotary position
embedding) (fig:ch2_rope_frequency).

Three panels: (a) rotation frequency of each dimension pair, (b) position-response
decay curves, (c) rotation trajectories in the complex plane.
"""

from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

# ---------- Global style (fonts chosen so Chinese labels render) ----------
# Collected in a named mapping first so the settings read as one table, then
# applied to matplotlib's global rcParams in a single call.
figure_style = {
    # Font families are tried in order.
    "font.family": ["Arial Unicode MS", "Songti SC", "SimSun", "serif"],
    "mathtext.fontset": "cm",
    "font.size": 9,
    "axes.labelsize": 10,
    "axes.titlesize": 11,
    "legend.fontsize": 8,
    "figure.dpi": 150,
    # Draw minus signs with the ASCII hyphen instead of U+2212.
    "axes.unicode_minus": False,
}
plt.rcParams.update(figure_style)

# ---------- RoPE parameters ----------
d_h = 64  # head dimension
omega = 10000.0  # RoPE base
n_pairs = d_h // 2  # 32 dimension pairs
indices = np.arange(n_pairs)  # dimension-pair index i = 0 .. n_pairs - 1
# Rotation frequency of pair i: theta_i = omega ** (-2 * i / d_h).
pair_exponents = -2.0 * indices / d_h
theta = np.power(omega, pair_exponents)

# ---------- Create the figure: one row of three panels ----------
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(14, 4.0))

# ===== (a) Rotation frequency of each dimension pair =====
ax = axes[0]
ax.semilogy(indices, theta, "o-", color="#2563EB", markersize=3.5, linewidth=1.5)
ax.set_xlabel(r"维度对索引 $i$")
ax.set_ylabel(r"旋转频率 $\theta_i$(对数刻度)", labelpad=2)
ax.tick_params(axis="y", pad=1)
ax.set_title("(a)", fontweight="bold", loc="left")

# Call out the high- and low-frequency regions of the curve.
# bbox_style is a white rounded text background; panel (c) reuses it.
bbox_style = {"boxstyle": "round,pad=0.25", "fc": "white", "ec": "none", "alpha": 0.85}
band_callouts = (
    # (label text, arrow tip on the curve, text position, colour)
    ("高频\n(局部位置敏感)", (2, theta[2]), (12, theta[1] * 1.5), "#DC2626"),
    ("低频\n(全局依赖)", (28, theta[28]), (18, theta[15] * 0.8), "#059669"),
)
for text, tip, anchor, colour in band_callouts:
    ax.annotate(
        text,
        xy=tip,
        xytext=anchor,
        fontsize=9,
        ha="center",
        arrowprops={"arrowstyle": "->", "color": colour, "lw": 1.2},
        color=colour,
        bbox=bbox_style,
    )

ax.set_xlim(-1, n_pairs)
ax.grid(True, alpha=0.3)

# ===== (b) Position-response decay curves =====
ax = axes[1]
rel_pos = np.arange(0, 129)  # relative position distances 0 .. 128

# Four representative dimension pairs, each with its own colour and legend label.
selected = [0, 5, 15, 31]
colors_b = ["#DC2626", "#F59E0B", "#2563EB", "#059669"]
labels_b = [rf"$i={s}$" for s in selected]

# One curve per selected pair. The two statements below must be inside the
# loop body so that every pair is drawn.
for s, c, lb in zip(selected, colors_b, labels_b):
    # Contribution to the attention score is proportional to cos(theta_i * delta).
    score = np.cos(theta[s] * rel_pos)
    ax.plot(rel_pos, score, color=c, linewidth=1.5, label=lb)

ax.set_xlabel(r"相对位置距离 $|t_1 - t_2|$")
ax.set_ylabel("注意力得分贡献", labelpad=2)
ax.tick_params(axis="y", pad=1)
ax.set_title("(b)", fontweight="bold", loc="left")
ax.legend(loc="upper right", framealpha=0.9)
ax.set_xlim(0, 128)
ax.set_ylim(-1.15, 1.15)  # cosine range [-1, 1] plus a small margin
ax.axhline(0, color="gray", linewidth=0.5, linestyle="--")  # zero reference line
ax.grid(True, alpha=0.3)

# ===== (c) Rotation trajectories in the complex plane =====
ax = axes[2]
T = 64  # position range: positions 0 .. T
positions = np.arange(0, T + 1)


def _rotation_path(pair_index):
    """Return (cos, sin) of theta[pair_index] * positions: points on the unit circle."""
    phase = theta[pair_index] * positions
    return np.cos(phase), np.sin(phase)


i_high = 0  # high-frequency dimension pair
i_low = 31  # low-frequency dimension pair
x_high, y_high = _rotation_path(i_high)
x_low, y_low = _rotation_path(i_low)

# Unit circle, drawn dashed as a reference.
circle_t = np.linspace(0, 2 * np.pi, 200)
ax.plot(
    np.cos(circle_t),
    np.sin(circle_t),
    color="gray",
    linewidth=0.6,
    linestyle="--",
    alpha=0.5,
)

# Trajectories: a circle marker at the first position, a square at the last.
paths = (
    # (x values, y values, colour, line width, legend label)
    (x_high, y_high, "#DC2626", 1.5, rf"$i={i_high}$(高频)"),
    (x_low, y_low, "#059669", 2.0, rf"$i={i_low}$(低频)"),
)
for xs, ys, colour, width, legend_label in paths:
    ax.plot(xs, ys, color=colour, linewidth=width, label=legend_label, alpha=0.85)
    ax.plot(xs[0], ys[0], "o", color=colour, markersize=5)
    ax.plot(xs[-1], ys[-1], "s", color=colour, markersize=5)

endpoint_labels = (
    # (label text, arrow tip, text position, colour)
    # Start point: placed outside the circle at the upper right to avoid overlap.
    (r"$t=0$", (x_high[0], y_high[0]), (1.15, 0.25), "#555"),
    # High-frequency end point: placed at the lower left.
    (rf"$t={T}$", (x_high[-1], y_high[-1]), (-1.15, -0.9), "#DC2626"),
    # Low-frequency end point: placed on the right, slightly below.
    (rf"$t={T}$", (x_low[-1], y_low[-1]), (0.7, -0.9), "#059669"),
)
for text, tip, anchor, colour in endpoint_labels:
    ax.annotate(
        text,
        xy=tip,
        xytext=anchor,
        fontsize=8,
        color=colour,
        arrowprops={"arrowstyle": "->", "color": colour, "lw": 0.8},
        bbox=bbox_style,
    )

ax.set_xlabel("实部 Re")
ax.set_ylabel("虚部 Im", labelpad=2)
ax.tick_params(axis="y", pad=1)
ax.set_title("(c)", fontweight="bold", loc="left")
ax.set_aspect("equal")  # equal axis scaling so the unit circle is not stretched
ax.legend(loc="lower left", framealpha=0.9, fontsize=8)
ax.set_xlim(-1.4, 1.4)
ax.set_ylim(-1.4, 1.4)
ax.grid(True, alpha=0.3)
# Thin lines through the origin mark the real and imaginary axes.
ax.axhline(0, color="gray", linewidth=0.4)
ax.axvline(0, color="gray", linewidth=0.4)

# ---------- Save ----------
plt.tight_layout(w_pad=2.5)
# Applied after tight_layout so this left margin is the one that takes effect.
plt.subplots_adjust(left=0.06)
output_path = "assets/2_rope_frequency.pdf"
# Create the output directory first: savefig does not create missing parent
# directories, so without this the save raises FileNotFoundError.
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, bbox_inches="tight", pad_inches=0.1)
print(f"Saved to {output_path}")
plt.show()