{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "12e4e650-e036-4ca1-83a7-806911fdf0c1", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import math, json, time, types, copy, sys, os\n", "import torch\n", "from torch.nn import functional as F\n", "import torch.nn as nn\n", "\n", "from transformers import PreTrainedTokenizerFast\n", "\n", "np.set_printoptions(precision=4, suppress=True, linewidth=200)" ] }, { "cell_type": "raw", "id": "d189546c-4d6e-4643-91fb-aab36cfd1935", "metadata": {}, "source": [ "模型下载地址(本脚本请用如下169m参数的,不要用这个链接):https://hf-mirror.com/BlinkDL/rwkv-3-pile-430m/resolve/main/RWKV-3-Pile-430M-20220817-10602.pth?download=true\n", "模型下载地址(用这个):https://hf-mirror.com/BlinkDL/rwkv-3-pile-169m/resolve/main/RWKV-3-Pile-20220720-10704.pth?download=true" ] }, { "cell_type": "code", "execution_count": 2, "id": "9c02901b-63f9-41fd-bf0b-0c28a2ee57cb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "* running on cpu\n" ] } ], "source": [ "RUN_DEVICE = 'cpu' # cpu cuda\n", "ctx_len = 768\n", "n_layer = 12\n", "n_embd = 768\n", "# n_layer = 24\n", "# n_embd = 1024\n", "\n", "# ---> download RWKV-3 169M model from https://huggingface.co/BlinkDL/rwkv-3-pile-169m/tree/main\n", "\n", "# MODEL_NAME = '/data1/ckw/RWKV-3-Pile-430M-20220817-10602'\n", "MODEL_NAME = '/data1/ckw/RWKV-3-Pile-20220720-10704'\n", "K_EPS = 1e-8\n", "\n", "vocab_size = 50277\n", "VOCAB_NAME = '20B_tokenizer.json'\n", "\n", "print(f'\\n* running on {RUN_DEVICE}')" ] }, { "cell_type": "markdown", "id": "65ece3b5-1682-4bf0-a4b0-67192230cd47", "metadata": {}, "source": [ "在v3版本中,对RWKV模型的TimeMix和ChannelMix部分进行了一些修改,主要变化如下:\n", "\n", "1. **ChannelMix**:\n", " - 在v2版本中,ChannelMix模块的计算中,采用了一个时间混合的技巧,即通过时间滑窗对输入进行一定程度的平滑处理。而在v3版本中,这一技巧被移除,不再使用时间混合。\n", " - 另外,ChannelMix模块中的`time_mix_k`和`time_mix_r`参数在v3版本中仍然存在,但在v2版本中不存在。\n", "\n", "2. **TimeMix**:\n", " - 在v2版本中,TimeMix模块的计算中,同样使用了时间混合技巧,并且采用了`time_mix`参数来控制时间混合的程度。而在v3版本中,这一技巧被移除,不再使用时间混合。\n", " - 同时,v3版本中的TimeMix模块取消了`time_mix`参数,而是直接在计算中采用了时间滑窗对输入进行处理。\n", " - 另外,v3版本中取消了`time_mix_v`参数,在计算中不再对值进行时间混合。\n", "\n", "总体来说,v3版本对时间混合和通道混合的操作进行了简化和调整,取消了之前版本中的一些复杂性,使模型更加简洁和高效。" ] }, { "cell_type": "code", "execution_count": 3, "id": "4599dc46-75c5-4e47-af1f-012bb72954e6", "metadata": {}, "outputs": [], "source": [ "class RWKV_ChannelMix(nn.Module):\n", " def __init__(self, layer_id):\n", " super().__init__()\n", " self.layer_id = layer_id\n", "\n", " self.time_shift = nn.ZeroPad2d((0,0,1,-1))\n", " self.time_mix_k = nn.Parameter(torch.ones(1, 1, n_embd))\n", " self.time_mix_r = nn.Parameter(torch.ones(1, 1, n_embd))\n", " \n", " hidden_sz = 4 * n_embd\n", " self.key = nn.Linear(n_embd, hidden_sz, bias=False)\n", " self.receptance = nn.Linear(n_embd, n_embd, bias=False)\n", " self.value = nn.Linear(hidden_sz, n_embd, bias=False)\n", "\n", " def forward(self, x):\n", " xx = self.time_shift(x)\n", " xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)\n", " xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)\n", "\n", " k = self.key(xk)\n", " k = torch.square(torch.relu(k))\n", " kv = self.value(k)\n", " \n", " rkv = torch.sigmoid(self.receptance(xr)) * kv\n", " return rkv" ] }, { "cell_type": "code", "execution_count": 4, "id": "a991b609-53bc-4643-82a3-86d6df15fe09", "metadata": {}, "outputs": [], "source": [ "class RWKV_TimeMix(nn.Module):\n", " def __init__(self, layer_id):\n", " super().__init__()\n", " self.layer_id = layer_id\n", " self.time_decay = nn.Parameter(torch.ones(n_embd, 1))\n", " self.time_curve = torch.tensor([-(ctx_len - 2 - i) for i in range(ctx_len-1)]).unsqueeze(0)\n", " self.time_first = nn.Parameter(torch.ones(n_embd, 1) * math.log(0.3))\n", " \n", " self.time_shift = nn.ZeroPad2d((0,0,1,-1))\n", " self.time_mix_k = nn.Parameter(torch.ones(1,1,n_embd))\n", " self.time_mix_v = nn.Parameter(torch.ones(1,1,n_embd))\n", " self.time_mix_r = nn.Parameter(torch.ones(1,1,n_embd))\n", "\n", " self.key = nn.Linear(n_embd, n_embd, bias=False)\n", " self.value = nn.Linear(n_embd, n_embd, bias=False)\n", " self.receptance = nn.Linear(n_embd, n_embd, bias=False)\n", "\n", " self.output = nn.Linear(n_embd, n_embd, bias=False)\n", "\n", " def forward(self, x):\n", " B, T, C = x.size()\n", "\n", " xx = self.time_shift(x)\n", " xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)\n", " xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)\n", " xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)\n", "\n", " k = self.key(xk).transpose(-1, -2)\n", " v = self.value(xv).transpose(-1, -2)\n", " r = self.receptance(xr)\n", "\n", " k = torch.clamp(k, max=60)\n", " k = torch.exp(k)\n", "\n", " kv = k * v\n", "\n", " self.time_w = torch.cat([torch.exp(self.time_decay) * self.time_curve.to(self.time_decay.device), self.time_first], dim=-1)\n", " w = torch.exp(self.time_w)\n", " \n", " w = w[:,-T:].unsqueeze(1)\n", " wkv = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(kv), w, groups=C)\n", " wk = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(k), w, groups=C) + K_EPS\n", "\n", " rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)\n", " \n", " rwkv = self.output(rwkv)\n", " return rwkv" ] }, { "cell_type": "markdown", "id": "fcf76ee6-9264-44e9-be3a-85a48db31fd4", "metadata": {}, "source": [ "在这段代码中,`xk`、`xv` 和 `xr` 分别表示经过时间混合后的输入,用于后续计算的键、值和接收向量。在v3版本中,这三个向量是分别计算的,而在v2版本中,这三个向量是在同一个时间混合过程中计算的。\n", "\n", "具体来说,`xk` 是通过时间混合后的输入用于计算键向量,`xv` 是用于计算值向量,`xr` 是用于计算接收向量。这种分别计算的方法可以使得模型更加灵活,能够更好地适应不同的数据特征。\n", "\n", "分别计算键、值和接收向量的方法,类似的操作在一些相关的论文或研究中也有提及,例如在自注意力机制中,通常会分别计算键、值和查询向量。这种分别计算的方法可以提高模型的灵活性和表现能力,因此在实践中被广泛应用。" ] }, { "cell_type": "markdown", "id": "c5b3e83c-1fda-498b-92f9-0e98744fd999", "metadata": {}, "source": [ "### RWKV的Block\n", "\n", "RWKV的Block是一个基本的模块,它结合了时间混合(TimeMix)和通道混合(ChannelMix)操作。Block中的每个模块(时间混合和通道混合)都通过归一化和残差连接来处理输入数据,从而增强模型的稳定性和性能。\n", "\n", "### 主要组件和操作\n", "\n", "1. **LayerNorm**:用于归一化输入,增强训练的稳定性。\n", " - `self.ln1` 和 `self.ln2` 分别在时间混合和通道混合之前对输入进行归一化。\n", " \n", "2. **时间混合(TimeMix)**:结合当前时间步和前一个时间步的信息,捕获时间依赖性。\n", " - `self.att = RWKV_TimeMix(layer_id)` 初始化时间混合模块。\n", " \n", "3. **通道混合(ChannelMix)**:在不同通道间进行混合,增强模型的表达能力。\n", " - `self.ffn = RWKV_ChannelMix(layer_id)` 初始化通道混合模块。\n", " \n", "4. **残差连接**:通过将混合操作的输出加回到原始输入上,保持信息流动并增强模型的梯度传播能力。\n", "\n", "通过这种设计,RWKV的Block能够高效地处理序列数据,结合时间和通道信息,提高模型的表现。" ] }, { "cell_type": "code", "execution_count": 5, "id": "2630a3ff-6a8c-49d9-a4d9-05098b424d02", "metadata": {}, "outputs": [], "source": [ "class Block(nn.Module):\n", " def __init__(self, layer_id):\n", " super().__init__()\n", " self.layer_id = layer_id\n", "\n", " self.ln1 = nn.LayerNorm(n_embd)\n", " self.ln2 = nn.LayerNorm(n_embd)\n", " if self.layer_id == 0: #增加了初始的归一化\n", " self.ln0 = nn.LayerNorm(n_embd)\n", " \n", " self.att = RWKV_TimeMix(layer_id)\n", " self.ffn = RWKV_ChannelMix(layer_id)\n", "\n", " def forward(self, x):\n", " if self.layer_id == 0:\n", " x = self.ln0(x)\n", " x = x + self.att(self.ln1(x))\n", " x = x + self.ffn(self.ln2(x))\n", " return x" ] }, { "cell_type": "markdown", "id": "bed3d877-9847-450b-8c92-5fd5c19a1811", "metadata": {}, "source": [ "接下来,实现了RWKV模型的主要部分:\n", "\n", "1. **模型加载和预处理**:代码中加载模型权重并进行时间相关权重的预处理。\n", "2. **LayerNorm**:在`LN`方法中实现了层归一化,关于LayerNorm的使用。\n", "3. **前馈网络(FF)和自注意力(SA)**:`FF`方法实现了前馈网络的计算,`SA`方法实现了自注意力机制的计算。这两部分对应TimeMix和ChannelMix的详细计算。\n", "4. **运行模型**:`run`方法实现了模型的整体运行逻辑,依次通过每一层,并最终输出结果。即模型的运行和推理过程。" ] }, { "cell_type": "code", "execution_count": 11, "id": "e1c04132-b0de-4a8e-9769-2b4619e2e70a", "metadata": {}, "outputs": [], "source": [ "time_buf = {}\n", "\n", "class RWKV_RNN():\n", " def __init__(self, MODEL_NAME=MODEL_NAME):\n", " print('\\nloading RWKV-RNN', MODEL_NAME)\n", " self.ctx_len = ctx_len\n", " self.n_layer = n_layer\n", " self.n_embd = n_embd\n", " self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=VOCAB_NAME)\n", "\n", " self.w = types.SimpleNamespace()\n", " \n", " w = torch.load(MODEL_NAME + '.pth', map_location=torch.device(RUN_DEVICE))\n", "\n", " for x in w.keys():\n", " if '.time_' in x:\n", " w[x] = w[x].squeeze()\n", " if '.time_decay' in x:\n", " w[x] = torch.exp(-torch.exp(w[x]))\n", " if '.time_first' in x:\n", " w[x] = torch.exp(w[x])\n", " \n", " xx = x.split('.')\n", " here = self.w\n", " for i in range(len(xx)):\n", " if xx[i].isdigit():\n", " ii = int(xx[i])\n", " if ii not in here:\n", " here[ii] = types.SimpleNamespace()\n", " here = here[ii]\n", " else:\n", " if i == len(xx) - 1:\n", " setattr(here, xx[i], w[x])\n", " elif not hasattr(here, xx[i]):\n", " if xx[i+1].isdigit():\n", " setattr(here, xx[i], {})\n", " else:\n", " setattr(here, xx[i], types.SimpleNamespace())\n", " here = getattr(here, xx[i])\n", "\n", " self.clear()\n", " \n", " def clear(self):\n", " self.xx = {}\n", " self.aa = {}\n", " self.bb = {}\n", " def save(self, target):\n", " target.xx = copy.deepcopy(self.xx)\n", " target.aa = copy.deepcopy(self.aa)\n", " target.bb = copy.deepcopy(self.bb)\n", " def load(self, target):\n", " self.xx = copy.deepcopy(target.xx)\n", " self.aa = copy.deepcopy(target.aa)\n", " self.bb = copy.deepcopy(target.bb)\n", "\n", " def LN(self, xx, w):\n", " return F.layer_norm(xx, (n_embd,), weight=w.weight, bias=w.bias)\n", "\n", " def FF(self, xx, w, name):\n", " if name not in self.xx:\n", " self.xx[name] = torch.zeros(n_embd, device=RUN_DEVICE)\n", " xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)\n", " xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)\n", "\n", " self.xx[name] = xx\n", "\n", " r = torch.sigmoid(w.receptance.weight @ xr)\n", " k = torch.square(torch.relu(w.key.weight @ xk))\n", " kv = w.value.weight @ k\n", "\n", " return r * kv\n", "\n", " def SA(self, xx, w, name):\n", " if name not in self.xx:\n", " self.xx[name] = torch.zeros(n_embd, device=RUN_DEVICE)\n", " self.aa[name] = torch.zeros(n_embd, device=RUN_DEVICE)\n", " self.bb[name] = torch.zeros(n_embd, device=RUN_DEVICE)\n", "\n", " xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)\n", " xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v)\n", " xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)\n", "\n", " self.xx[name] = xx\n", "\n", " r = torch.sigmoid(w.receptance.weight @ xr)\n", "\n", " k = torch.exp(torch.clamp(w.key.weight @ xk, max=60))\n", " v = w.value.weight @ xv\n", " kv = k * v\n", "\n", " a = self.aa[name] + w.time_first * kv\n", " b = self.bb[name] + w.time_first * k\n", " self.aa[name] = w.time_decay * self.aa[name] + kv\n", " self.bb[name] = w.time_decay * self.bb[name] + k\n", "\n", " rwkv = r * a / (b + K_EPS)\n", "\n", " return w.output.weight @ rwkv\n", "\n", " def run(self, ctx):\n", " w = self.w\n", " x = w.emb.weight[ctx[-1]]\n", "\n", " x = self.LN(x, w.blocks[0].ln0) #相比v2版本,增加了一个初始的归一化\n", " for i in range(n_layer):\n", " x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}')\n", " x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}')\n", "\n", " x = self.LN(x, w.ln_out)\n", "\n", " x = w.head.weight @ x\n", " x = x.tolist()\n", "\n", " return x" ] }, { "cell_type": "code", "execution_count": 12, "id": "99e34d05-8296-4d2f-9f45-29912ddde8f1", "metadata": {}, "outputs": [], "source": [ "# Edit model.py to set CPU / CUDA mode. Runs on CPU by default.\n", "\n", "TEMPERATURE = 1.0\n", "TOP_P = 0.7\n", "\n", "DEBUG_DEBUG = False\n", "LENGTH_OF_EACH = 333\n", "NUM_TRIALS = 3\n", "\n", "context = '\\nDataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence.'\n", "\n", "##############################################################################################################" ] }, { "cell_type": "code", "execution_count": 13, "id": "4a620d9b-e93e-4c6e-b015-99246c1e036c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "loading RWKV-RNN /data1/ckw/RWKV-3-Pile-20220720-10704\n" ] } ], "source": [ "model = RWKV_RNN()" ] }, { "cell_type": "markdown", "id": "0546efbe-600b-47de-8873-1174686beaa4", "metadata": {}, "source": [ "下面我们从给定的输出logits中进行采样,以生成一个新的token。它实现了**温度调节采样**和**核采样(Top-p采样)**,具体步骤如下:\n", "\n", "1. **Softmax转换**:将模型输出的logits通过softmax函数转换为概率分布。\n", "2. **排序和累积概率计算**:对概率从高到低进行排序,并计算累积概率分布。\n", "3. **核采样**:\n", " - 计算累积概率超过`top_p`的最小值,确定截断值`cutoff`。\n", " - 将所有低于截断值的概率置为0,从而保留最重要的`top_p`部分概率。\n", "4. **温度调节**:如果`temperature`不为1,则调整概率分布,使得概率分布更平滑或更尖锐。\n", "5. **采样**:从调整后的概率分布中采样一个值,返回对应的索引。\n", "\n", "这种方法在文本生成任务中尤为常用,通过调节`temperature`和`top_p`参数,可以控制生成文本的多样性和质量。\n", "\n", "v2和v3的采样方法是没有变化的。" ] }, { "cell_type": "code", "execution_count": 14, "id": "46c528e4-ef2f-4136-9022-5c1fc08b185b", "metadata": {}, "outputs": [], "source": [ "def sample_logits(out, temperature=1.0, top_p=None):\n", " # 将输出转化为概率分布(通过softmax函数)\n", " probs = F.softmax(torch.tensor(out), dim=-1)\n", " \n", " # 按概率从高到低排序\n", " sorted_probs, _ = torch.sort(probs, descending=True)\n", "\n", " # 计算累积概率分布\n", " cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy()\n", " \n", " # 根据累积概率和top_p计算截断值(cutoff)\n", " cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])\n", " \n", " # 将低于截断值的概率置为0\n", " probs[probs < cutoff] = 0\n", "\n", " # 如果temperature不等于1,则对概率进行温度调节\n", " if temperature != 1.0:\n", " probs = probs.pow(1.0 / temperature)\n", "\n", " # 从调整后的概率分布中采样一个值并返回\n", " return torch.multinomial(probs, num_samples=1)[0]\n" ] }, { "cell_type": "code", "execution_count": 15, "id": "14581d92-4590-4303-9405-d56f65a3c675", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. The technology focuses on learning from human behavior and not information.\n", "\n", "We are an independent Data Whalechina team. This team of trained Data Whalechina team is available to answer any question and provide guidance.\n", "\n", "The information provided on this site is not legal advice, and should not be construed as legal advice. You should consult a lawyer for advice regarding your specific situation.\n", "\n", "We take no responsibility for the content, accuracy, or completeness of any information on this site or any information provided from third parties.\n", "\n", "Cookies are used to store information about you so that we can remember and provide you with products and services you may have used. You can opt-out of our use of cookies at any time. To learn more, please read our cookie policy.\n", "\n", "Cookies are tiny files stored on your computer that allow you to access information such as your email address, your phone number, and other information about you. By clicking on the \"cookies\" button, you agree to our use of cookies. To learn more about our use of cookies, please see our privacy policy.\n", "\n", "We do not store any of your personal information on our servers. Instead, we store the information you give us in a cookie file on your computer. We do not store or share your information with third parties. You can choose to not use cookies at all.\n", "\n", "How to Opt Out of Your Personal Information\n", "\n", "We do not sell or share your personal information with third parties. Your information is only used to improve our services and to provide our products and services. By continuing to use our website, you accept that you will\n", "----------------------------------------------------------------------\n", "DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. The company focuses on creating educational tools and apps that are accessible for all learners, as well as providing students with tools and tools to improve their performance in school and work.\n", "\n", "The purpose of DataWhalechina is to create a safe and open environment for learning in a professional manner. We believe that learning is important for the mental health of children and that learning is always an important part of life.\n", "\n", "DataWhalechina is located in Shanghai, China. It offers more than 40 schools, kindergartens, elementary schools, and institutions for elementary school children to gain knowledge and skills for using the technologies of artificial intelligence and data science.\n", "\n", "We believe that AI is an important way to improve the lives of the people around us. The use of AI in the classroom has led to a dramatic increase in the number of students in higher education. DataWhalechina provides tools that can help people become more efficient, flexible, and competent in their learning.\n", "\n", "In this article, we will discuss the importance of the quality of education in the creation of life in the world. DataWhalechina will also give an insight into the various technologies that can be used to increase the quality of life in a sustainable way.\n", "\n", "By doing this article, we want to give you a simple, practical, and useful reference on how to create data-driven education and create an environment where data is being used in the most effective and efficient way.\n", "\n", "Here is a guide for how to create a data-driven education.\n", "\n", "In this article, we will discuss how to create a learning environment for students to understand the needs\n", "----------------------------------------------------------------------\n", "DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. In the field of artificial intelligence, they believe that AI should be able to act on its own to improve its performance.\n", "\n", "External links\n", "\n", "Notes\n", "\n", "Category:Chinese artificial intelligence\n", "Category:Education in China\n", "Category:Education in Shanghai\n", "Category:Learning in China\n", "Category:Science education in China\n", "Category:Information science\n", "Category:Science education in China\n", "Category:Educational technology in China\n", "Category:Taiwanese education\n", "Category:Chinese technology\n", "Category:Science education in China\n", "Category:Science education in Taiwan\n", "Category:Technology in society\n", "Category:Education in Taiwan\n", "Category:Science education in Asia\n", "Category:Education in the United States\n", "Category:Education in Taiwan\n", "Category:Educational technology in China\n", "Category:Education in the United States\n", "Category:Educational technology in Asia\n", "Category:Education in Taiwan\n", "Category:Educational technology in the United States\n", "Category:Education in China\n", "Category:Educational technology in Taiwan\n", "Category:Education in China\n", "Category:Educational technology in China\n", "Category:Education in the United States\n", "Category:Education in the United States\n", "Category:Education in the United Kingdom\n", "Category:Education in the United States\n", "Category:Education in Canada\n", "Category:Educational technology in Canada\n", "Category:Educational technology in the United States\n", "Category:Educational technology in Canada\n", "Category:Education in China\n", "Category:Education in China\n", "Category:Education in Taiwan\n", "Category:Educational technology in China\n", "Category:Education in China\n", "Category:Education in Taiwan\n", "Category:Education in Canada\n", "Category:Educational technology in China\n", "Category:Education in\n", "----------------------------------------------------------------------" ] } ], "source": [ "for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):\n", " ctx = [model.tokenizer.encode(context)][0]\n", " src_len = len(ctx)\n", " print(context, end='')\n", "\n", " model.clear()\n", " if TRIAL == 0: # build the RNN hidden state?\n", " init_state = types.SimpleNamespace()\n", " for i in range(src_len if DEBUG_DEBUG else src_len):\n", " x = ctx[:i+1]\n", " if i == src_len - 1:\n", " init_state.out = model.run(x)\n", " else:\n", " model.run(x)\n", " model.save(init_state)\n", " else:\n", " model.load(init_state)\n", "\n", " if DEBUG_DEBUG:\n", " out = init_state.out\n", " print('\\n', np.array(x), '==>', np.array(\n", " out), np.max(out), np.min(out))\n", "\n", " for i in range(src_len, src_len + (0 if DEBUG_DEBUG else LENGTH_OF_EACH)):\n", " x = ctx[:i+1]\n", " x = x[-model.ctx_len:]\n", "\n", " if i == src_len:\n", " out = copy.deepcopy(init_state.out) # load the RNN hidden state\n", " else:\n", " out = model.run(x) # run the RNN\n", "\n", " out[0] = -999999999 # disable <|endoftext|>\n", "\n", " char = sample_logits(out, temperature=TEMPERATURE, top_p=TOP_P)\n", " char = char.item()\n", " print(model.tokenizer.decode(char), end='', flush=True)\n", "\n", " ctx += [char]\n", " print('\\n' + '-' * 70, end='')" ] }, { "cell_type": "code", "execution_count": null, "id": "bc37dfc5-71c7-44a5-a4d4-b71534574872", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "kewei-ai", "language": "python", "name": "kewei-ai" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.5" } }, "nbformat": 4, "nbformat_minor": 5 }