kewei 2bd03665dd add rwkv | 4 months ago | |
---|---|---|
.. | ||
model.py | 4 months ago | |
readme.md | 4 months ago |
由于时间久远,没有找到对应的预训练脚本,对于v1主要做源码分析,v2-v6均可以找到载入模型实现的脚本。本文档详细解析 RWKV v1 版本的核心代码,具体包括初始化、时间混合模块、通道混合模块和多头注意力机制。代码使用 PyTorch 实现,具有良好的可读性和扩展性。
首先,我们定义了一个用于初始化线性层和嵌入层的函数 RWKV_Init
。
def RWKV_Init(module, config):
for m in module.modules():
if not isinstance(m, (nn.Linear, nn.Embedding)):
continue
with torch.no_grad():
name = '[unknown weight]'
for name, parameter in module.named_parameters():
if id(m.weight) == id(parameter):
break
shape = m.weight.data.shape
gain = 1.0
scale = 1.0
if isinstance(m, nn.Linear):
if m.bias is not None:
m.bias.data.zero_()
if shape[0] > shape[1]:
gain = math.sqrt(shape[0] / shape[1])
if shape[0] == config.vocab_size and shape[1] == config.n_embd:
scale = config.rwkv_emb_scale
if isinstance(m, nn.Embedding):
gain = math.sqrt(max(shape[0], shape[1]))
if shape[0] == config.vocab_size and shape[1] == config.n_embd:
scale = config.rwkv_emb_scale
if hasattr(m, 'scale_init'):
scale = m.scale_init
print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name)
gain *= scale
if gain == 0:
nn.init.zeros_(m.weight)
elif gain > 0:
nn.init.orthogonal_(m.weight, gain=gain)
else:
nn.init.normal_(m.weight, mean=0, std=-gain)
该函数遍历模块中的所有线性层和嵌入层,并根据特定条件初始化其权重。对于线性层,若偏置存在,则将其初始化为零;对于嵌入层,计算权重的增益和缩放因子。根据不同的条件使用不同的初始化方法,例如正交初始化或正态初始化。
RWKV_TimeMix
类实现了时间混合机制。
class RWKV_TimeMix(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
assert config.n_attn % config.n_head == 0
self.layer_id = layer_id
self.ctx_len = config.ctx_len
self.n_head = config.n_head
self.head_size = config.n_attn // config.n_head
with torch.no_grad():
ww = torch.ones(config.n_head, config.ctx_len)
curve = torch.tensor([-(config.ctx_len - 1 - i) for i in range(config.ctx_len)])
for h in range(config.n_head):
if h < config.n_head - 1:
decay_speed = math.pow(config.ctx_len, -(h+1)/(config.n_head-1))
else:
decay_speed = 0.0
ww[h] = torch.exp(curve * decay_speed)
self.time_w = nn.Parameter(ww)
self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.key = nn.Linear(config.n_embd, config.n_attn)
self.value = nn.Linear(config.n_embd, config.n_attn)
self.receptance = nn.Linear(config.n_embd, config.n_attn)
self.output = nn.Linear(config.n_attn, config.n_embd)
self.key.scale_init = 0
self.receptance.scale_init = 0
self.output.scale_init = 0
def forward(self, x):
B, T, C = x.size()
TT = self.ctx_len
w = F.pad(self.time_w, (0, TT))
w = torch.tile(w, [TT])
w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
w = w[:, :, TT-1:]
w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]
x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
k = self.key(x)
v = self.value(x)
r = self.receptance(x)
k = torch.clamp(k, max=30, min=-60)
k = torch.exp(k)
sum_k = torch.cumsum(k, dim=1)
kv = (k * v).view(B, T, self.n_head, self.head_size)
wkv = (torch.einsum('htu,buhc->bthc', w, kv)).contiguous().view(B, T, -1)
rwkv = torch.sigmoid(r) * wkv / sum_k
rwkv = self.output(rwkv)
return rwkv * self.time_gamma[:T, :]
该类实现了时间混合机制,通过时间权重矩阵 time_w
对输入进行变换。time_w
矩阵根据头数和上下文长度初始化,随后对输入进行时间维度上的变换和混合。key
、value
和 receptance
三个线性层分别生成键、值和接收信号,并通过 sigmoid 函数计算输出。
RWKV_ChannelMix
类实现了通道混合机制。
class RWKV_ChannelMix(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
hidden_sz = 5 * config.n_ffn // 2
self.key = nn.Linear(config.n_embd, hidden_sz)
self.value = nn.Linear(config.n_embd, hidden_sz)
self.weight = nn.Linear(hidden_sz, config.n_embd)
self.receptance = nn.Linear(config.n_embd, config.n_embd)
self.receptance.scale_init = 0
self.weight.scale_init = 0
def forward(self, x):
B, T, C = x.size()
x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
k = self.key(x)
v = self.value(x)
r = self.receptance(x)
wkv = self.weight(F.mish(k) * v)
rwkv = torch.sigmoid(r) * wkv
return rwkv
该类实现了通道混合机制,通过 key
、value
和 receptance
三个线性层对输入进行变换。key
和 value
层生成的张量经过 mish
激活函数变换后再通过 weight
层进行加权变换,最后与 receptance
层生成的接收信号相乘得到最终输出。
MHA_rotary
类实现了多头注意力机制,并引入了旋转位置编码。
class MHA_rotary(nn.Module):
def __init__(self, config, layer_id, time_shift = False):
super().__init__()
self.layer_id = layer_id
assert config.n_attn % config.n_head == 0
self.n_head = config.n_head
self.ctx_len = config.ctx_len
self.head_size = config.n_attn // config.n_head
if time_shift:
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
self.query = nn.Linear(config.n_embd, config.n_attn)
self.key = nn.Linear(config.n_embd, config.n_attn)
self.value = nn.Linear(config.n_embd, config.n_attn)
self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
self.rotary_ndims = int(self.head_size * 0.5)
self.rotary_emb = RotaryEmbedding(self.rotary_ndims)
self.output = nn.Linear(config.n_attn, config.n_embd)
def forward(self, x):
B, T, C = x.size()
if hasattr(self, 'time_shift'):
x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2)
k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2
)
v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2)
q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]
cos, sin = self.rotary_emb(q, seq_len=T)
q, k = apply_rotary_pos_emb(q, k, cos, sin)
q = torch.cat((q, query_pass), dim=-1)
k = torch.cat((k, key_pass), dim=-1)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf'))
att = F.softmax(att, dim = -1)
x = att @ v
x = x.transpose(1, 2).contiguous().view(B, T, -1)
x = self.output(x)
return x
该类实现了多头注意力机制,并引入了旋转位置编码以增强模型的位置信息表示。通过 query
、key
和 value
三个线性层生成查询、键和值,再通过旋转编码和自注意力机制计算最终输出。
RWKV v1 模型实现了基础的时间混合、通道混合和多头注意力机制。通过对输入进行多维度的变换和混合,实现了对复杂特征的提取和表示。上述代码段展示了模型的核心组件和主要计算过程,为后续版本的优化和改进提供了坚实的基础。