RWKV v1 Source Code Walkthrough

Since the v1 code dates back quite a while, no matching pretraining script could be found, so for v1 we focus on source-code analysis; for v2 through v6, scripts that load the model implementations are all available. This document walks through the core code of RWKV v1 in detail, covering model initialization, the time-mixing module, the channel-mixing module, and the multi-head attention mechanism. The code is implemented in PyTorch and is readable and easy to extend.


Model Initialization

First, we define RWKV_Init, a function that initializes the weights of all linear and embedding layers:

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def RWKV_Init(module, config):
    for m in module.modules():
        if not isinstance(m, (nn.Linear, nn.Embedding)):
            continue
        with torch.no_grad():
            name = '[unknown weight]'
            for name, parameter in module.named_parameters(): 
                if id(m.weight) == id(parameter):
                    break

            shape = m.weight.data.shape
            gain = 1.0   # > 0: gain for orthogonal init; < 0: -std for normal init; 0: zeros
            scale = 1.0  # extra multiplier applied to the gain

            if isinstance(m, nn.Linear):
                if m.bias is not None:
                    m.bias.data.zero_()
                if shape[0] > shape[1]:  # boost gain for "tall" projections
                    gain = math.sqrt(shape[0] / shape[1])
                if shape[0] == config.vocab_size and shape[1] == config.n_embd:  # output head to the vocabulary
                    scale = config.rwkv_emb_scale

            if isinstance(m, nn.Embedding):
                gain = math.sqrt(max(shape[0], shape[1]))
                if shape[0] == config.vocab_size and shape[1] == config.n_embd: 
                    scale = config.rwkv_emb_scale

            if hasattr(m, 'scale_init'):  # per-layer override set on specific layers (e.g. scale_init = 0)
                scale = m.scale_init

            print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name)

            gain *= scale
            if gain == 0:
                nn.init.zeros_(m.weight)  # layers marked with scale_init = 0 start at zero
            elif gain > 0:
                nn.init.orthogonal_(m.weight, gain=gain)
            else:
                nn.init.normal_(m.weight, mean=0, std=-gain)

This function iterates over every linear and embedding layer in the module and initializes its weights. For linear layers, the bias (if present) is zeroed; a gain and a scale factor are then computed from the weight shape. Depending on the final gain, the weights are initialized with zeros, orthogonally, or from a normal distribution.
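
To see the function in action, it can be applied to any small module tree. Below is a minimal sketch: the config fields vocab_size, n_embd, and rwkv_emb_scale are taken from the function body, and the SimpleNamespace is only a stand-in for the repo's real config object.

import types
import torch.nn as nn

config = types.SimpleNamespace(vocab_size=100, n_embd=16, rwkv_emb_scale=1.0)

model = nn.Sequential(
    nn.Embedding(100, 16),  # gain = sqrt(max(100, 16)) = 10, scale = rwkv_emb_scale
    nn.Linear(16, 64),      # "tall" weight (64 x 16): gain = sqrt(64 / 16) = 2
    nn.Linear(64, 16),      # gain stays 1.0
)
RWKV_Init(model, config)    # prints one line of shape / scale / name per initialized layer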


Time-Mixing Module

The RWKV_TimeMix class implements the time-mixing mechanism:

class RWKV_TimeMix(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        assert config.n_attn % config.n_head == 0
        self.layer_id = layer_id
        self.ctx_len = config.ctx_len
        self.n_head = config.n_head
        self.head_size = config.n_attn // config.n_head

        with torch.no_grad():  # precompute per-head time-decay curves
            ww = torch.ones(config.n_head, config.ctx_len)
            curve = torch.tensor([-(config.ctx_len - 1 - i) for i in range(config.ctx_len)])  # 0 at the newest position, increasingly negative further back
            for h in range(config.n_head):
                if h < config.n_head - 1:
                    decay_speed = math.pow(config.ctx_len, -(h+1)/(config.n_head-1))  # each head decays at a different speed
                else:
                    decay_speed = 0.0  # the last head does not decay at all
                ww[h] = torch.exp(curve * decay_speed)
        self.time_w = nn.Parameter(ww)

        self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
        self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
        self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
                
        self.time_shift = nn.ZeroPad2d((0,0,1,-1))  # shift the sequence one step back in time (zero-pad the front, drop the last step)

        self.key = nn.Linear(config.n_embd, config.n_attn)
        self.value = nn.Linear(config.n_embd, config.n_attn)
        self.receptance = nn.Linear(config.n_embd, config.n_attn)

        self.output = nn.Linear(config.n_attn, config.n_embd)

        self.key.scale_init = 0        # picked up by RWKV_Init: these layers start at zero
        self.receptance.scale_init = 0
        self.output.scale_init = 0

    def forward(self, x):
        B, T, C = x.size()
        TT = self.ctx_len
        w = F.pad(self.time_w, (0, TT))
        w = torch.tile(w, [TT])
        w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
        w = w[:, :, TT-1:]  # pad/tile/slice trick: w is now a (n_head, TT, TT) causal decay matrix (see the sketch below)
        w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]

        x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)  # mix half of the channels with the previous token

        k = self.key(x)
        v = self.value(x)
        r = self.receptance(x)

        k = torch.clamp(k, max=30, min=-60)  # clamp k so exp(k) cannot overflow
        k = torch.exp(k)
        sum_k = torch.cumsum(k, dim=1)

        kv = (k * v).view(B, T, self.n_head, self.head_size)

        wkv = (torch.einsum('htu,buhc->bthc', w, kv)).contiguous().view(B, T, -1)

        rwkv = torch.sigmoid(r) * wkv / sum_k

        rwkv = self.output(rwkv)

        return rwkv * self.time_gamma[:T, :]

This class implements the time-mixing mechanism, transforming the input with the time-weight matrix time_w. The time_w matrix is initialized from the number of heads and the context length, and then mixes the input along the time dimension. The key, value, and receptance linear layers produce keys, values, and a receptance signal, and the decayed key-value aggregate is gated by the sigmoid of the receptance.
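
The pad/tile/slice construction of w in forward is easy to miss. Here is a minimal sketch with one head and ctx_len = 4, showing that it yields a lower-triangular matrix whose diagonals hold the entries of time_w, with the newest position weighted by the last entry:

import torch
import torch.nn.functional as F

TT = 4
time_w = torch.arange(1., TT + 1).unsqueeze(0)   # one head: [1., 2., 3., 4.]

w = F.pad(time_w, (0, TT))                       # append TT zeros: [1,2,3,4,0,0,0,0]
w = torch.tile(w, [TT])                          # repeat the padded row TT times
w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)       # drop the tail, fold into rows
w = w[:, :, TT - 1:]                             # keep the causal half: (1, TT, TT)

print(w[0])
# tensor([[4., 0., 0., 0.],
#         [3., 4., 0., 0.],
#         [2., 3., 4., 0.],
#         [1., 2., 3., 4.]])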


Channel-Mixing Module

The RWKV_ChannelMix class implements the channel-mixing mechanism:

class RWKV_ChannelMix(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.time_shift = nn.ZeroPad2d((0,0,1,-1))  # same one-token shift as in the time-mixing block

        hidden_sz = 5 * config.n_ffn // 2  # 2.5x FFN expansion
        self.key = nn.Linear(config.n_embd, hidden_sz)
        self.value = nn.Linear(config.n_embd, hidden_sz)
        self.weight = nn.Linear(hidden_sz, config.n_embd)
        self.receptance = nn.Linear(config.n_embd, config.n_embd)

        self.receptance.scale_init = 0
        self.weight.scale_init = 0

    def forward(self, x):
        B, T, C = x.size()
        
        x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
        k = self.key(x)
        v = self.value(x)
        r = self.receptance(x)
        
        wkv = self.weight(F.mish(k) * v)  # mish-activated keys gate the values before projection

        rwkv = torch.sigmoid(r) * wkv

        return rwkv

This class implements channel mixing through the key, value, and receptance linear layers. The key activations are passed through the mish activation and multiplied elementwise with the value activations, projected back down through the weight layer, and finally gated by the sigmoid of the receptance signal to produce the output.
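
Both mixing modules rely on the same time_shift trick: nn.ZeroPad2d with a negative bottom pad shifts the whole sequence one step, so each position sees the previous token's features. A minimal sketch:

import torch
import torch.nn as nn

time_shift = nn.ZeroPad2d((0, 0, 1, -1))  # (left, right, top, bottom): pad one zero row at the front, crop one row at the end

x = torch.arange(1., 4.).view(1, 3, 1)    # (B=1, T=3, C=1): features [1, 2, 3]
print(time_shift(x).flatten())
# tensor([0., 1., 2.])  -- every position now carries the previous token's features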


Multi-Head Attention

The MHA_rotary class implements multi-head attention with rotary position embeddings:

class MHA_rotary(nn.Module):
    def __init__(self, config, layer_id, time_shift = False):
        super().__init__()
        self.layer_id = layer_id
        assert config.n_attn % config.n_head == 0
        self.n_head = config.n_head
        self.ctx_len = config.ctx_len
        self.head_size = config.n_attn // config.n_head

        if time_shift:
            self.time_shift = nn.ZeroPad2d((0,0,1,-1))

        self.query = nn.Linear(config.n_embd, config.n_attn)
        self.key = nn.Linear(config.n_embd, config.n_attn)
        self.value = nn.Linear(config.n_embd, config.n_attn)

        self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))  # causal mask

        self.rotary_ndims = int(self.head_size * 0.5)  # apply rotary encoding to half of each head's dims
        self.rotary_emb = RotaryEmbedding(self.rotary_ndims)

        self.output = nn.Linear(config.n_attn, config.n_embd)

    def forward(self, x):
        B, T, C = x.size()

        if hasattr(self, 'time_shift'):
            x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)

        q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2)
        k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2)
        v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2)

        q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]  # rotary applies to the first half of each head; the rest passes through
        k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]
        cos, sin = self.rotary_emb(q, seq_len=T)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
        q = torch.cat((q, query_pass), dim=-1)
        k = torch.cat((k, key_pass), dim=-1)
        
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim = -1)

        x = att @ v
        x = x.transpose(1, 2).contiguous().view(B, T, -1)

        x = self.output(x)
        return x

This class implements standard multi-head attention and adds rotary position embeddings to strengthen the model's positional representation. The query, key, and value linear layers produce queries, keys, and values; rotary encoding is applied to part of each head, and causal self-attention produces the final output.
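
MHA_rotary depends on RotaryEmbedding and apply_rotary_pos_emb, which live elsewhere in model.py and are not shown above. As a reference, here is a minimal, cache-free sketch of the standard rotary-embedding implementation these calls are consistent with; the repo's real code may differ in details such as caching:

import torch
import torch.nn as nn

class RotaryEmbedding(nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        # one inverse frequency per pair of dimensions
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x, seq_len):
        t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)  # (seq_len, dim/2)
        emb = torch.cat((freqs, freqs), dim=-1)            # (seq_len, dim)
        return emb.cos(), emb.sin()

def rotate_half(x):
    # swap the two halves of the last dimension, negating the second
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    # broadcast the (seq_len, dim) tables over (B, n_head, seq_len, dim)
    cos, sin = cos[None, None, :, :], sin[None, None, :, :]
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)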


Summary

The RWKV v1 model implements basic time-mixing, channel-mixing, and multi-head attention mechanisms. By transforming and mixing the input along both the time and channel dimensions, it extracts and represents complex features. The code above shows the model's core components and main computations, laying a solid foundation for the optimizations and improvements in later versions.