kewei 2bd03665dd add rwkv 4 months ago
..
model_v1.py 2bd03665dd add rwkv 4 months ago
model_v2.py 2bd03665dd add rwkv 4 months ago
model_v3.py 2bd03665dd add rwkv 4 months ago
model_v4.py 2bd03665dd add rwkv 4 months ago
model_v5.py 2bd03665dd add rwkv 4 months ago
model_v6.py 2bd03665dd add rwkv 4 months ago
readme.md 2bd03665dd add rwkv 4 months ago

readme.md

RWKV 模型版本比较报告

本文档旨在比较 RWKV 模型的六个不同版本(v1 至 v6),并详细介绍每个版本的特性、改进和性能。以下是对这六个模型版本的详细分析和比较。


版本概述

RWKV v1

  • 初始版本,基础实现 RWKV 时间混合和通道混合模块。
  • 主要特性:
    • 使用时间混合(Time-mix)和通道混合(Channel-mix)模块。
    • 采用标准的线性层和嵌入层初始化。
    • 使用掩码来处理因果关系。

RWKV v2

  • 增强版本,改进了时间混合和通道混合的实现。
  • 主要改进:
    • 优化了模型加载和状态管理。
    • 增加了新的归一化方法。
    • 提升了训练和推理效率。

RWKV v3

  • 进一步优化的版本,主要集中在性能提升。
  • 主要改进:
    • 调整了层数和嵌入维度,提供更灵活的配置选项。
    • 增加了预处理步骤,提高了推理效率。

RWKV v4

  • 增加了对更大规模模型的支持,提升了模型复杂度。
  • 主要改进:
    • 支持24层和1024维嵌入。
    • 增加了更多的参数调优选项。

RWKV v5

  • 继续提升模型规模和复杂度,并优化了模型架构。
  • 主要改进:
    • 支持更高的嵌入维度(2048)。
    • 引入了新的时间混合和通道混合方法,提升了模型性能。

RWKV v6

  • 最新版本,综合了前几个版本的改进,并引入了一些新特性。
  • 主要改进:
    • 增加了对更大词汇表(65536)的支持。
    • 采用了改进的混合方法,提升了推理速度和准确性。

详细比较

1. 架构与实现

  • 时间混合(Time-Mix)和通道混合(Channel-Mix)

    • v1:基本实现,功能完备。

      class RWKV_TimeMix(nn.Module):
      def __init__(self, config, layer_id):
          super().__init__()
          assert config.n_attn % config.n_head == 0
          self.layer_id = layer_id
          self.ctx_len = config.ctx_len
          self.n_head = config.n_head
          self.head_size = config.n_attn // config.n_head
      
          with torch.no_grad(): # initial time_w curves for better convergence
              ww = torch.ones(config.n_head, config.ctx_len)
              curve = torch.tensor([-(config.ctx_len - 1 - i) for i in range(config.ctx_len)]) # the distance
              for h in range(config.n_head):
                  if h < config.n_head - 1:
                      decay_speed = math.pow(config.ctx_len, -(h+1)/(config.n_head-1))
                  else:
                      decay_speed = 0.0
                  ww[h] = torch.exp(curve * decay_speed)
          self.time_w = nn.Parameter(ww)
          self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
          self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
          self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
                          
          self.time_shift = nn.ZeroPad2d((0,0,1,-1))
      
          self.key = nn.Linear(config.n_embd, config.n_attn)
          self.value = nn.Linear(config.n_embd, config.n_attn)
          self.receptance = nn.Linear(config.n_embd, config.n_attn)
      
          self.output = nn.Linear(config.n_attn, config.n_embd)
      
      • v2:优化了时间混合和通道混合,提升了计算效率。 ```python class RWKV_ChannelMix(nn.Module): def init(self, layer_id): super().init() self.layer_id = layer_id

        self.time_shift = nn.ZeroPad2d((0,0,1,-1)) self.time_mix = nn.Parameter(torch.ones(1, 1, n_embd))

        hidden_sz = 4 * n_embd self.key = nn.Linear(n_embd, hidden_sz, bias=False) self.receptance = nn.Linear(n_embd, n_embd, bias=False) self.value = nn.Linear(hidden_sz, n_embd, bias=False)

      def forward(self, x):

      x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)
      k = self.key(x)
      k = torch.square(torch.relu(k))
      kv = self.value(k)
      rkv = torch.sigmoid(self.receptance(x)) * kv
      return rkv
      

      ```

    • v3:进一步优化,并增加了灵活的配置选项。

      class RWKV_ChannelMix(nn.Module):
      def __init__(self, layer_id):
          super().__init__()
          self.layer_id = layer_id
      
          self.time_shift = nn.ZeroPad2d((0,0,1,-1))
          self.time_mix_k = nn.Parameter(torch.ones(1, 1, n_embd))
          self.time_mix_r = nn.Parameter(torch.ones(1, 1, n_embd))
                  
          hidden_sz = 4 * n_embd
          self.key = nn.Linear(n_embd, hidden_sz, bias=False)
          self.receptance = nn.Linear(n_embd, n_embd, bias=False)
          self.value = nn.Linear(hidden_sz, n_embd, bias=False)
      
      def forward(self, x):
          xx = self.time_shift(x)
          xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
          xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
          k = self.key(xk)
          k = torch.square(torch.relu(k))
          kv = self.value(k)
          rkv = torch.sigmoid(self.receptance(xr)) * kv
          return rkv
      
      • v4:支持更大规模模型,提升了时间混合和通道混合的处理能力。

        class RWKV_RNN(torch.jit.ScriptModule):
        def __init__(self, args):
        super().__init__()
        self.args = args
        self.eval() # set torch to inference mode
                    
        w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
        for k in w.keys():
            if      '.time_' in k: w[k] = w[k].squeeze()
            if '.time_decay' in k: w[k] = -torch.exp(w[k].float()) # the real time decay is like e^{-e^x}
            else: w[k] = w[k].float() # convert to f32 type
        self.w = types.SimpleNamespace() # set self.w from w
        self.w.blocks = {}
        for k in w.keys():
            parts = k.split('.')
            last = parts.pop()
            here = self.w
            for p in parts:
                if p.isdigit():
                    p = int(p)
                    if p not in here: here[p] = types.SimpleNamespace()
                    here = here[p]
                else:
                    if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
                    here = getattr(here, p)
            setattr(here, last, w[k])
        
    • v5:引入了新的混合方法,进一步提升了性能。

      class RWKV_RNN(MyModule):
      def __init__(self, args):
          super().__init__()
          self.args = args
          self.eval() # set torch to inference mode
                  
          w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
          for k in w.keys():
              w[k] = w[k].float() # convert to f32 type
              if      '.time_' in k: w[k] = w[k].squeeze()
              if '.time_decay' in k: w[k] = torch.exp(-torch.exp(w[k])).unsqueeze(-1)
              if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)
      
          self.n_head = w['blocks.0.att.time_decay'].shape[0]
          self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head
                  
          self.w = types.SimpleNamespace() # set self.w from w
          self.w.blocks = {}
          for k in w.keys():
              parts = k.split('.')
              last = parts.pop()
              here = self.w
              for p in parts:
                  if p.isdigit():
                      p = int(p)
                      if p not in here: here[p] = types.SimpleNamespace()
                      here = here[p]
                  else:
                      if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
                      here = getattr(here, p)
              setattr(here, last, w[k])
      
      • v6:改进了混合方法,提升了整体性能和效率。

        class RWKV_RNN(MyModule):
        def __init__(self, args):
        super().__init__()
        self.args = args
        self.eval() # set torch to inference mode
                    
        w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
        for k in w.keys():
            w[k] = w[k].float() # convert to f32 type
            if      '.time_' in k: w[k] = w[k].squeeze()
            if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)
        self.n_head = w['blocks.0.att.time_faaaa'].shape[0]
        self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head
        self.w = types.SimpleNamespace() # set self.w from w
        self.w.blocks = {}
        for k in w.keys():
            parts = k.split('.')
            last = parts.pop()
            here = self.w
            for p in parts:
                if p.isdigit():
                    p = int(p)
                    if p not in here: here[p] = types.SimpleNamespace()
                    here = here[p]
                else:
                    if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
                    here = getattr(here, p)
            setattr(here, last, w[k])
        

2. 模型规模

  • 层数和嵌入维度
    • v1:标准配置,适用于基础任务。
    • v2:支持12层和768维嵌入。
    • v3:提供12层和24层选项,嵌入维度为768和1024。
    • v4:支持24层和1024维嵌入。
    • v5:嵌入维度增加至2048。
    • v6:进一步增加模型复杂度,支持更大词汇表。

3. 性能与效率

  • 推理速度和资源消耗
    • v1:基础实现,资源消耗适中。
    • v2:优化后,推理速度提升。
    • v3:预处理步骤的增加,提高了推理效率。
    • v4:更大规模模型下的性能优化。
    • v5:新的混合方法提升了推理速度和准确性。
    • v6:综合改进,推理速度和资源利用进一步优化。

4. 词汇表和上下文长度

  • 词汇表大小和上下文长度支持
    • v1-v4:词汇表大小和上下文长度逐步增加。
    • v5:支持更大上下文长度,适应复杂任务。
    • v6:支持最大65536的词汇表和更长的上下文长度。

总结

RWKV 模型在每个版本中不断优化和提升,从基础的 v1 到复杂且高效的 v6,模型的性能和功能都有了显著的进步。以下是每个版本的推荐使用场景:

  • v1:适用于基础任务和初步研究。
  • v2:适用于需要更高效率和优化的任务。
  • v3:适用于需要灵活配置和更高性能的应用。
  • v4:适用于大规模模型的训练和推理任务。
  • v5:适用于需要高精度和高效推理的复杂任务。
  • v6:适用于最前沿的研究和应用,提供最高的性能和效率。

每个版本在其特定的改进点上都为用户提供了更好的选择,根据具体需求选择合适的版本将能充分发挥 RWKV 模型的优势。