This document compares six versions of the RWKV model (v1 through v6) and describes each version's characteristics, improvements, and performance. A detailed analysis and comparison of the six versions follows.
- RWKV v1
- RWKV v2
- RWKV v3
- RWKV v4
- RWKV v5
- RWKV v6
1. Architecture and Implementation
Time-Mix and Channel-Mix:
v1: Basic implementation with complete functionality.
```python
import math

import torch
import torch.nn as nn

class RWKV_TimeMix(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        assert config.n_attn % config.n_head == 0
        self.layer_id = layer_id
        self.ctx_len = config.ctx_len
        self.n_head = config.n_head
        self.head_size = config.n_attn // config.n_head

        with torch.no_grad():  # initial time_w curves for better convergence
            ww = torch.ones(config.n_head, config.ctx_len)
            curve = torch.tensor([-(config.ctx_len - 1 - i) for i in range(config.ctx_len)])  # the distance
            for h in range(config.n_head):
                if h < config.n_head - 1:
                    decay_speed = math.pow(config.ctx_len, -(h+1)/(config.n_head-1))
                else:
                    decay_speed = 0.0
                ww[h] = torch.exp(curve * decay_speed)
        self.time_w = nn.Parameter(ww)

        self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
        self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))
        self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        self.key = nn.Linear(config.n_embd, config.n_attn)
        self.value = nn.Linear(config.n_embd, config.n_attn)
        self.receptance = nn.Linear(config.n_embd, config.n_attn)
        self.output = nn.Linear(config.n_attn, config.n_embd)
```
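For intuition, the `time_w` initialization above gives each head its own fixed positional decay profile. Below is a minimal standalone sketch of just that initialization; the small `ctx_len=4`, `n_head=3` values are made up for illustration:

```python
import math
import torch

ctx_len, n_head = 4, 3  # toy sizes, for illustration only
ww = torch.ones(n_head, ctx_len)
curve = torch.tensor([-(ctx_len - 1 - i) for i in range(ctx_len)], dtype=torch.float)  # distances [-3, -2, -1, 0]
for h in range(n_head):
    if h < n_head - 1:
        decay_speed = math.pow(ctx_len, -(h + 1) / (n_head - 1))  # larger for earlier heads -> faster decay
    else:
        decay_speed = 0.0  # the last head weights all positions equally
    ww[h] = torch.exp(curve * decay_speed)
print(ww)  # each row is one head's decay over distance; row 0 decays fastest, the last row is all ones
```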
v2: Optimized Time-Mix and Channel-Mix for better computational efficiency.
```python
class RWKV_ChannelMix(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.time_mix = nn.Parameter(torch.ones(1, 1, n_embd))  # n_embd is defined at module level in this excerpt

        hidden_sz = 4 * n_embd
        self.key = nn.Linear(n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, n_embd, bias=False)

    def forward(self, x):
        x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)
        k = self.key(x)
        k = torch.square(torch.relu(k))
        kv = self.value(k)
        rkv = torch.sigmoid(self.receptance(x)) * kv
        return rkv
```
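A hypothetical quick check of the v2 block above (not from the repo); since `n_embd` is read as a module-level constant, it has to be defined before the class is instantiated:

```python
import torch

n_embd = 8                     # toy embedding size, an assumption for this sketch
mix = RWKV_ChannelMix(layer_id=0)
x = torch.randn(2, 5, n_embd)  # (batch, sequence length, embedding)
out = mix(x)
print(out.shape)               # torch.Size([2, 5, 8]) -- output keeps the input shape
```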
v3: Further optimizations, plus more flexible configuration options (separate mixing coefficients for the key and receptance paths).
```python
class RWKV_ChannelMix(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, n_embd))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, n_embd))

        hidden_sz = 4 * n_embd
        self.key = nn.Linear(n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, n_embd, bias=False)

    def forward(self, x):
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        k = self.key(xk)
        k = torch.square(torch.relu(k))
        kv = self.value(k)
        rkv = torch.sigmoid(self.receptance(xr)) * kv
        return rkv
```
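The `time_shift` trick shared by v2 and v3 deserves a note: `nn.ZeroPad2d((0, 0, 1, -1))` pads one zero row at the start of the time dimension and crops one row at the end, so position `t` mixes in the embedding of position `t-1`. A small standalone check (not from the repo):

```python
import torch
import torch.nn as nn

time_shift = nn.ZeroPad2d((0, 0, 1, -1))  # (left, right, top, bottom) over the last two dims
x = torch.arange(1, 7, dtype=torch.float).reshape(1, 3, 2)  # (batch=1, T=3, C=2)
print(x[0])              # rows: [1, 2], [3, 4], [5, 6]
print(time_shift(x)[0])  # rows: [0, 0], [1, 2], [3, 4] -- everything shifted one step later in time
```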
v4: Supports larger-scale models and improves the handling of Time-Mix and Channel-Mix.
```python
import types

import torch

class RWKV_RNN(torch.jit.ScriptModule):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.eval()  # set torch to inference mode

        w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
        for k in w.keys():
            if '.time_' in k: w[k] = w[k].squeeze()
            if '.time_decay' in k: w[k] = -torch.exp(w[k].float())  # the real time decay is like e^{-e^x}
            else: w[k] = w[k].float()  # convert to f32 type

        self.w = types.SimpleNamespace()  # set self.w from w
        self.w.blocks = {}
        for k in w.keys():  # e.g. 'blocks.0.att.time_first' becomes self.w.blocks[0].att.time_first
            parts = k.split('.')
            last = parts.pop()
            here = self.w
            for p in parts:
                if p.isdigit():
                    p = int(p)
                    if p not in here: here[p] = types.SimpleNamespace()
                    here = here[p]
                else:
                    if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
                    here = getattr(here, p)
            setattr(here, last, w[k])
```
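The `.time_decay` line above matches its comment "the real time decay is like e^{-e^x}": `__init__` stores `-exp(x)`, and the recurrence later uses it inside another exponential, so the effective per-step decay factor always lands strictly between 0 and 1. A toy numeric sketch (values are made up):

```python
import torch

raw = torch.tensor([-2.0, 0.0, 1.0])  # toy values for a stored .time_decay parameter
stored = -torch.exp(raw)              # what __init__ keeps: [-0.1353, -1.0000, -2.7183]
per_step = torch.exp(stored)          # effective per-step decay e^{-e^x}
print(per_step)                       # tensor([0.8734, 0.3679, 0.0660]) -- all strictly in (0, 1)
```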
v5: Introduces a new mixing method that further improves performance.
```python
class RWKV_RNN(MyModule):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.eval()  # set torch to inference mode

        w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
        for k in w.keys():
            w[k] = w[k].float()  # convert to f32 type
            if '.time_' in k: w[k] = w[k].squeeze()
            if '.time_decay' in k: w[k] = torch.exp(-torch.exp(w[k])).unsqueeze(-1)
            if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)

        self.n_head = w['blocks.0.att.time_decay'].shape[0]
        self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head

        self.w = types.SimpleNamespace()  # set self.w from w
        self.w.blocks = {}
        for k in w.keys():
            parts = k.split('.')
            last = parts.pop()
            here = self.w
            for p in parts:
                if p.isdigit():
                    p = int(p)
                    if p not in here: here[p] = types.SimpleNamespace()
                    here = here[p]
                else:
                    if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
                    here = getattr(here, p)
            setattr(here, last, w[k])
```
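Compared with v4, the v5 loader above bakes the full `e^{-e^x}` decay into the weights at load time and reshapes it so it can broadcast over each head's channels. A toy sketch of the difference; the shapes and values are assumptions for illustration only:

```python
import torch

raw = torch.tensor([[-2.0, 0.0], [1.0, 2.0]])        # toy (n_head=2, head_size=2) parameter
v4_style = -torch.exp(raw)                           # v4: stores -e^x, exponentiated later in the recurrence
v5_style = torch.exp(-torch.exp(raw)).unsqueeze(-1)  # v5: stores e^{-e^x} directly, shape (2, 2, 1)
print(torch.allclose(torch.exp(v4_style).unsqueeze(-1), v5_style))  # True -- same effective decay
```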
v6: Refines the mixing method, improving overall performance and efficiency.
```python
class RWKV_RNN(MyModule):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.eval()  # set torch to inference mode

        w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
        for k in w.keys():
            w[k] = w[k].float()  # convert to f32 type
            if '.time_' in k: w[k] = w[k].squeeze()
            if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)

        self.n_head = w['blocks.0.att.time_faaaa'].shape[0]
        self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head

        self.w = types.SimpleNamespace()  # set self.w from w
        self.w.blocks = {}
        for k in w.keys():
            parts = k.split('.')
            last = parts.pop()
            here = self.w
            for p in parts:
                if p.isdigit():
                    p = int(p)
                    if p not in here: here[p] = types.SimpleNamespace()
                    here = here[p]
                else:
                    if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
                    here = getattr(here, p)
            setattr(here, last, w[k])
```
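A hypothetical usage sketch of the v6 loader above (the checkpoint name is made up); it also shows how a flat checkpoint key such as `blocks.0.att.time_faaaa` becomes attribute-style access through the nested `types.SimpleNamespace` tree:

```python
import types

args = types.SimpleNamespace(MODEL_NAME='RWKV-x060-World-1B6')  # hypothetical checkpoint path
model = RWKV_RNN(args)
print(model.n_head, model.head_size)           # head layout derived from blocks.0.att.time_faaaa and ln1
print(model.w.blocks[0].att.time_faaaa.shape)  # expected (n_head, head_size, 1) after the unsqueeze above
```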
2. Model Scale
3. Performance and Efficiency
4. Vocabulary and Context Length
The RWKV model has been refined continuously across versions: from the basic v1 to the sophisticated and efficient v6, both performance and functionality have advanced significantly, and each version has its own recommended use cases.
Each version offers a better option around its particular improvements; choosing the version that matches your specific needs will let you take full advantage of the RWKV model.