{ "cells": [ { "cell_type": "markdown", "id": "a06a0543", "metadata": {}, "source": [ "# DIN" ] }, 
{ "cell_type": "markdown", "id": "bb1253ee", "metadata": {}, "source": [ "Alimama, Alibaba's advertising platform, has published three fairly well-known papers on CTR prediction:\n", "\n", "2017: Deep Interest Network (DIN).\n", "\n", "2018: Deep Interest Evolution Network (DIEN).\n", "\n", "2019: Deep Session Interest Network (DSIN).\n", "\n", "The main idea of each paper, and how the three relate, can be summed up in one sentence each:\n", "\n", "Paper 1, DIN, says: only part of a user's behavior log is relevant to the current candidate ad, so an Attention mechanism can be used to build, from that log, a user-interest representation that depends on the candidate ad. We tried it and our metrics went up, hee hee.\n", "\n", "Paper 2, DIEN, says: a user's recent behavior may matter more than older behavior, so a recurrent network (GRU) can model how user interest evolves over time. We tried it and our metrics went up too, hehe.\n", "\n", "Paper 3, DSIN, says: behaviors within the same session are highly correlated, while behaviors in different sessions are relatively independent, so the behavior log can be split into sessions by time gap, with Self-Attention modeling the interactions among the behaviors. We tried it and our metrics went up yet again, haha.\n", "\n", "\n", "References:\n", "\n", "* DIN paper: https://arxiv.org/pdf/1706.06978.pdf\n", "\n", "* Attention mechanisms in recommender systems: https://zhuanlan.zhihu.com/p/51623339\n", "\n", "* Alibaba's classic interest networks: https://zhuanlan.zhihu.com/p/429433768\n", "\n", "* From DIN to DIEN, the evolution of Alibaba's CTR algorithms: https://zhuanlan.zhihu.com/p/78365283\n", "\n", "* DIN+DIEN: Attention, the one officially designated score booster in machine learning: https://zhuanlan.zhihu.com/p/431131396\n", "\n", "* A brief summary of the Attention mechanism: https://zhuanlan.zhihu.com/p/46313756\n", "\n", "* Reference code implementation: https://github.com/GitHub-HongweiZhang/prediction-flow\n", "\n", "\n", "This article focuses on DIN. The next one covers DIEN.\n", "\n" ] }, 
{ "cell_type": "code", "execution_count": null, "id": "d4aeef18", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "ffc35888", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "4b6510eb", "metadata": {}, "source": [ "## 0. Attention overview" ] }, 
{ "cell_type": "markdown", "id": "5fbf832b", "metadata": {}, "source": [ "As is well known, Attention is a very general score-boosting technique in deep learning. Its main role is to make a model adapt to each individual input.\n", "\n", "Some common uses of Attention, with typical examples:\n", "\n", "* (1) Dynamic feature selection: weight features differently depending on the sample. Typical examples: the SE attention in SENet and the Attention in DIN.\n", "\n", "* (2) Dynamic feature interaction: dynamically decide how strongly features interact with each other, to extract higher-order features. Typical example: the Attention in the Transformer.\n", "\n", "* (3) Dynamic module ensembling: similar to fusing several models, except that the weights of the sub-modules are dynamic. Typical example: the gating mechanism in MoE (Mixture of Experts).\n", "\n", "\n", "In many applications of Attention, the input is split into a Query (Q) and Keys (K). The Query is the embedding of the item currently in focus, and the Keys are the embeddings to be matched against it.\n", "\n", "For example, in ad CTR prediction, the Query is the candidate ad being scored and the Keys are the ads the user clicked in the past. Attention measures how strongly the candidate ad relates to each of those past clicks.\n", "\n", "Likewise, in neural machine translation, the Query is the embedding of the target word currently being decoded and the Keys are the embeddings of the source-sentence words. Attention establishes which source words each target word corresponds to.\n", "\n", "\n", "The core of an Attention mechanism is computing the attention weights. Common forms include:\n", "\n", "$$\\text{attention} = f(Q,K)$$\n", "\n", "\n", "* (1) Multi-layer perceptron (MLP)\n", "\n", "Concatenate the Query and the Key, then feed the result into an MLP.\n", "\n", "This method does not require the Query and Key to have the same length, and the way they interact is learned.\n", "\n", "$$f(Q,K) = \\mathrm{MLP}([Q;K])$$\n", "\n", "\n", "* (2) Bilinear\n", "\n", "A weight matrix $W$ maps the relationship between Query and Key directly. It is fast to compute. Because $W \\in \\mathbb{R}^{d_q \\times d_k}$ can be rectangular, the Query and Key lengths may even differ.\n", "\n", "$$f(Q,K) = QWK^T$$\n", "\n", "\n", "* (3) Scaled dot-product\n", "\n", "Use the inner product of Query and Key directly as their similarity. There are no learnable parameters, so it is extremely fast, but the Query and Key must have the same length $d_k$. The variance of the dot product grows with $d_k$, which pushes the softmax into regions with tiny gradients. The scores are therefore scaled by $1/\\sqrt{d_k}$.\n", "\n", "$$f(Q,K)=\\mathrm{softmax}\\left(\\frac{QK^T}{\\sqrt{d_k}}\\right)$$\n", "\n" ] }, 
{ "cell_type": "markdown", "id": "44c9d421", "metadata": {}, "source": [ "## 1. How DIN works" ] }, 
{ "cell_type": "markdown", "id": "543a880c", "metadata": {}, "source": [ "Alibaba's display-advertising system mainly uses the following 4 kinds of features:\n", "\n", "* (1) User profile features.\n", "* (2) User behavior features, i.e. the items the user has clicked.\n", "* (3) Features of the ad to be shown (an ad is itself an item).\n", "* (4) Context features.\n", "\n", "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h3eet3webnj20jj09rdgq.jpg)" ] }, 
{ "cell_type": "markdown", "id": "78386dec", "metadata": {}, "source": [ "DIN, DIEN and DSIN all focus on modeling the user behavior log.\n", "\n", "The behavior log reflects the user's interests. How can we build a good representation of user interest from it?\n", "\n", "\n", "The most basic approach is Embedding+SumPooling: embed every item the user has clicked in the past, then sum the embeddings.\n", "\n", "\n", "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h3eezc7meij20d908x0t5.jpg)\n", "\n" ] }, 
{ "cell_type": "markdown", "id": "ee4cc87a", "metadata": {}, "source": [ "Implementing this SumPooling could hardly be simpler." ] }, 
{ "cell_type": "code", "execution_count": 54, "id": "c9c77017", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "\n", "class SumPooling(nn.Module):\n", "    def __init__(self, dim):\n", "        super(SumPooling, self).__init__()\n", "        self.dim = dim\n", "\n", "    def forward(self, x):\n", "        # sum the behavior embeddings along `dim` (the history axis)\n", "        return torch.sum(x, self.dim)" ] }, 
{ "cell_type": "markdown", "id": "5d397722", "metadata": {}, "source": [ "This approach has a major flaw: the user-interest representation is fixed and independent of the candidate ad.\n", "\n", "Whatever candidate ad comes along, the user's entire behavior log is thrown in and summed in one go.\n", "\n", "Clearly, if we could build a user-interest representation that depends on the candidate ad, the results should be much better.\n", "\n", "So how do we do that? We can compute the relevance between the candidate ad and each entry in the user's behavior history, and use these relevance scores to weight the history.\n", "\n", "This is quite natural: we focus (Attention) mainly on the parts of the user's behavior history that relate to the candidate ad.\n", "\n", "And just like that, the DIN architecture pops out.\n", "\n", "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h3ef9wpx4uj20mb0c0dh0.jpg)\n", "\n", "\n" ] }, 
{ "cell_type": "markdown", "id": "10debe94", "metadata": {}, "source": [ "The attention mechanism here is worth a closer look. It is an MLP-style attention, but instead of simply concatenating $Q$ and $K$ at the input, it concatenates $Q$, $K$, $Q-K$ and $Q \\odot K$ (element-wise product). This makes it easier for the model to learn the similarity between Q and K.\n", "\n", "Some implementation details are also worth studying. One is the masking trick that sets the attention on the padded positions of keys to 0. Another is the reshaping between [B, T, ...] and [B*T, ...].\n" ] }, 
{ "cell_type": "code", "execution_count": 55, "id": "2bc1ea8b", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import torch\n", "from torch import nn\n", "from collections import OrderedDict\n", "\n", "\n", "class MLP(nn.Module):\n", "    def __init__(self, input_size, hidden_layers,\n", "                 dropout=0.0, batchnorm=True):\n", "        super(MLP, self).__init__()\n", "        modules = OrderedDict()\n", "        previous_size = input_size\n", "        for index, hidden_layer in enumerate(hidden_layers):\n", "            modules[f\"dense{index}\"] = nn.Linear(previous_size, hidden_layer)\n", "            if batchnorm:\n", "                modules[f\"batchnorm{index}\"] = nn.BatchNorm1d(hidden_layer)\n", "            modules[f\"activation{index}\"] = nn.PReLU()\n", "            if dropout:\n", "                modules[f\"dropout{index}\"] = nn.Dropout(dropout)\n", "            previous_size = hidden_layer\n", "        self.mlp = nn.Sequential(modules)\n", "\n", "    def forward(self, x):\n", "        return self.mlp(x)\n", "\n", "\n", "class Attention(nn.Module):\n", "    def __init__(\n", "            self,\n", "            input_size,\n", "            hidden_layers,\n", "            dropout=0.0,\n", "            batchnorm=True,\n", "            return_scores=False):\n", "\n", "        super().__init__()\n", "        self.return_scores = return_scores\n", "\n", "        # the MLP scores [Q, K, Q-K, Q*K], hence 4x the input size\n", "        self.mlp = MLP(\n", "            input_size=input_size * 4,\n", "            hidden_layers=hidden_layers,\n", "            dropout=dropout,\n", "            batchnorm=batchnorm)\n", "        self.fc = nn.Linear(hidden_layers[-1], 1)\n", "\n", "    def forward(self, query, keys, keys_length):\n", "        \"\"\"\n", "        Parameters\n", "        ----------\n", "        query: 2D tensor, [Batch, Hidden]\n", "        keys: 3D tensor, [Batch, Time, Hidden]\n", "        keys_length: 1D tensor, [Batch]\n", "\n", "        Returns\n", "        -------\n", "        outputs: 2D tensor, [Batch, Hidden]\n", "        \"\"\"\n", "\n", "        batch_size, max_length, dim = keys.size()\n", "\n", "        # repeat the query at every time step: [B, H] -> [B, T, H]\n", "        query = query.unsqueeze(1).expand(-1, max_length, -1)\n", "\n", "        din_all = torch.cat(\n", "            [query, keys, query - keys, query * keys], dim=-1)\n", "\n", "        # flatten to [B*T, 4H] so the MLP scores every (query, key) pair\n", "        din_all = din_all.view(batch_size * max_length, -1)\n", "\n", "        outputs = self.mlp(din_all)\n", "\n", "        outputs = self.fc(outputs).view(batch_size, max_length)  # [B, T]\n", "\n", "        # Scale\n", "        outputs = outputs / (dim ** 0.5)\n", "\n", "        # Mask: padded positions get -inf, i.e. weight 0 after the sigmoid\n", "        mask = (torch.arange(max_length, device=keys_length.device).repeat(\n", "            batch_size, 1) < keys_length.view(-1, 1))\n", "        outputs[~mask] = -np.inf\n", "\n", "        # Activation\n", "        outputs = torch.sigmoid(outputs)  # [B, T]\n", "\n", "        if not self.return_scores:\n", "            # Weighted sum; squeeze(1) keeps the batch dim even when B == 1\n", "            outputs = torch.matmul(\n", "                outputs.unsqueeze(1), keys).squeeze(1)  # [B, H]\n", "\n", "        return outputs" ] }, 
{ "cell_type": "markdown", "id": "8a47d11a", "metadata": {}, "source": [ "What we hope to see: the more relevant a history item (keys) is to the candidate ad (query), the higher its attention weight 😋.\n", "\n" ] }, 
{ "cell_type": "markdown", "id": "6688ee66", "metadata": {}, "source": [ "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h3efhfac81j20i706vaaj.jpg)" ] }, 
{ "cell_type": "markdown", "id": "3083d63f", "metadata": {}, "source": [ "The main contribution of the DIN paper is using Attention to build, from the behavior log, a user-interest representation that depends on the current candidate ad. The paper also makes a few smaller contributions:\n", "\n", "* Dice, an activation function whose rectification point adapts to the distribution of its inputs, used in place of PReLU\n", "* A mini-batch aware L2 regularization method\n", "\n", "\n" ] }, 
{ "cell_type": "markdown", "id": "3decf586", "metadata": {}, "source": [ "## 2. A PyTorch implementation of DIN" ] }, 
{ "cell_type": "markdown", "id": "5b25ae8f", "metadata": {}, "source": [ "Below is a complete PyTorch implementation of the DIN model.\n", "\n", "The AttentionGroup class is a bit unusual. It pairs each candidate-ad attribute with the matching attribute of the user's history (e.g. the candidate `movieId` with the history `histHighRatedMovieIds`), and attention is computed over these pairs.\n" ] }, 
{ "cell_type": "code", "execution_count": 1, "id": "ab957679", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from collections import OrderedDict\n", "\n", "class MaxPooling(nn.Module):\n", "    def __init__(self, dim):\n", "        super(MaxPooling, self).__init__()\n", "        self.dim = dim\n", "\n", "    def forward(self, x):\n", "        return torch.max(x, self.dim)[0]\n", "\n", "\n", "class SumPooling(nn.Module):\n", "    def __init__(self, dim):\n", "        super(SumPooling, self).__init__()\n", "        self.dim = dim\n", "\n", "    def forward(self, x):\n", "        return torch.sum(x, self.dim)\n", "\n", "class Dice(nn.Module):\n", "    \"\"\"\n", "    The Data Adaptive Activation Function in DIN, a generalization of PReLU.\n", "    \"\"\"\n", "    def __init__(self, emb_size, dim=2, epsilon=1e-8):\n", "        super(Dice, self).__init__()\n", "        assert dim == 2 or dim == 3\n", "\n", "        self.bn = nn.BatchNorm1d(emb_size, eps=epsilon)\n", "        self.sigmoid = nn.Sigmoid()\n", "        self.dim = dim\n", "\n", "        # wrap alpha in nn.Parameter to make it trainable\n", "        self.alpha = nn.Parameter(torch.zeros((emb_size,))) if self.dim == 2 else nn.Parameter(\n", "            torch.zeros((emb_size, 1)))\n", "\n", "\n", "    def forward(self, x):\n", "        # p(x) = sigmoid(BN(x)); out = p(x) * x + alpha * (1 - p(x)) * x\n", "        assert x.dim() == self.dim\n", "        if self.dim == 2:\n", "            x_p = self.sigmoid(self.bn(x))\n", "            out = self.alpha * (1 - x_p) * x + x_p * x\n", "        else:\n", "            # [N, L, C] -> [N, C, L], the layout BatchNorm1d expects\n", "            x = torch.transpose(x, 1, 2)\n", "            x_p = self.sigmoid(self.bn(x))\n", "            out = self.alpha * (1 - x_p) * x + x_p * x\n", "            out = torch.transpose(out, 1, 2)\n", "        return out\n", "\n", "\n", "class Identity(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "    def forward(self, x):\n", "        return x\n", "\n", "def get_activation_layer(name, hidden_size=None, dice_dim=2):\n", "    name = name.lower()\n", "    # map lower-cased names to nn classes, e.g. 'prelu' -> 'PReLU'\n", "    name_dict = {x.lower():x for x in dir(nn) if '__' not in x and 'Z'>=x[0]>='A'}\n", "    if name==\"linear\":\n", "        return Identity()\n", "    elif name==\"dice\":\n", "        assert dice_dim\n", "        return Dice(hidden_size, dice_dim)\n", "    else:\n", "        assert name in name_dict, f'activation type {name} not supported!'\n", "        return getattr(nn,name_dict[name])()\n", "\n", "def init_weights(model):\n", "    if isinstance(model, nn.Linear):\n", "        if model.weight is not None:\n", "            nn.init.kaiming_uniform_(model.weight.data)\n", "        if model.bias is not None:\n", "            nn.init.normal_(model.bias.data)\n", "    elif isinstance(model, (nn.BatchNorm1d,nn.BatchNorm2d,nn.BatchNorm3d)):\n", "        if model.weight is not None:\n", "            nn.init.normal_(model.weight.data, mean=1, std=0.02)\n", "        if model.bias is not None:\n", "            nn.init.constant_(model.bias.data, 0)\n", "    else:\n", "        pass\n", "\n", "\n", "class MLP(nn.Module):\n", "    def __init__(self, input_size, hidden_layers,\n", "                 dropout=0.0, batchnorm=True, activation='relu'):\n", "        super(MLP, self).__init__()\n", "        modules = OrderedDict()\n", "        previous_size = input_size\n", "        for index, hidden_layer in enumerate(hidden_layers):\n", "            modules[f\"dense{index}\"] = nn.Linear(previous_size, hidden_layer)\n", "            if batchnorm:\n", "                modules[f\"batchnorm{index}\"] = nn.BatchNorm1d(hidden_layer)\n", "            if activation:\n", "                modules[f\"activation{index}\"] = get_activation_layer(activation,hidden_layer,2)\n", "            if dropout:\n", "                modules[f\"dropout{index}\"] = nn.Dropout(dropout)\n", "            previous_size = hidden_layer\n", "        self.mlp = nn.Sequential(modules)\n", "\n", "    def forward(self, x):\n", "        return self.mlp(x)" ] }, 
{ "cell_type": "code", "execution_count": 2, "id": "edf8a025", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "class Attention(nn.Module):\n", "    def __init__(\n", "            self,\n", "            input_size,\n", "            hidden_layers,\n", "            dropout=0.0,\n", "            batchnorm=True,\n", "            activation='prelu',\n", "            return_scores=False):\n", "\n", "        super().__init__()\n", "        self.return_scores = return_scores\n", "\n", "        # the MLP scores [Q, K, Q-K, Q*K], hence 4x the input size\n", "        self.mlp = MLP(\n", "            input_size=input_size * 4,\n", "            hidden_layers=hidden_layers,\n", "            dropout=dropout,\n", "            batchnorm=batchnorm,\n", "            activation=activation)\n", "        self.fc = nn.Linear(hidden_layers[-1], 1)\n", "\n", "    def forward(self, query, keys, keys_length):\n", "        \"\"\"\n", "        Parameters\n", "        ----------\n", "        query: 2D tensor, [Batch, Hidden]\n", "        keys: 3D tensor, [Batch, Time, Hidden]\n", "        keys_length: 1D tensor, [Batch]\n", "\n", "        Returns\n", "        -------\n", "        outputs: 2D tensor, [Batch, Hidden]\n", "        \"\"\"\n", "\n", "        batch_size, max_length, dim = keys.size()\n", "\n", "        # repeat the query at every time step: [B, H] -> [B, T, H]\n", "        query = query.unsqueeze(1).expand(-1, max_length, -1)\n", "\n", "        din_all = torch.cat(\n", "            [query, keys, query - keys, query * keys], dim=-1)\n", "\n", "        # flatten to [B*T, 4H] so the MLP scores every (query, key) pair\n", "        din_all = din_all.view(batch_size * max_length, -1)\n", "\n", "        outputs = self.mlp(din_all)\n", "\n", "        outputs = self.fc(outputs).view(batch_size, max_length)  # [B, T]\n", "\n", "        # Scale\n", "        outputs = outputs / (dim ** 0.5)\n", "\n", "        # Mask: padded positions get -inf, i.e. weight 0 after the sigmoid\n", "        mask = (torch.arange(max_length, device=keys_length.device).repeat(\n", "            batch_size, 1) < keys_length.view(-1, 1))\n", "        outputs[~mask] = -np.inf\n", "\n", "        # Activation\n", "        outputs = torch.sigmoid(outputs)  # [B, T]\n", "\n", "        if not self.return_scores:\n", "            # Weighted sum; squeeze(1) keeps the batch dim even when B == 1\n", "            outputs = torch.matmul(\n", "                outputs.unsqueeze(1), keys).squeeze(1)  # [B, H]\n", "\n", "        return outputs" ] }, { 
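"cell_type": "markdown", "id": "7a2e5d19", "metadata": {}, "source": [ "To make the `Attention` module above more concrete, the next cell is a minimal sketch that uses hand-picked toy numbers (all hypothetical, not from the paper or the MovieLens data). It replaces the learned MLP with a hand-set linear scorer `w` that reads only the $Q \\odot K$ block of $[Q, K, Q-K, Q \\odot K]$, so each score reduces to the dot product $\\langle Q, K_t \\rangle$. A single linear layer on the plain concatenation $[Q;K]$ could not express this product term. The sketch then applies the same masking trick as `Attention.forward`: padded positions get $-\\infty$, the sigmoid maps them to exactly 0, and they drop out of the weighted sum. The weights are not normalized to sum to 1 because a sigmoid is used rather than a softmax. As described in the DIN paper, this constraint is relaxed on purpose to preserve the intensity of user interest.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "c4d8a6f0", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "# Hypothetical toy batch: B=2 users, histories padded to T=3, embedding size H=4\n", "query = torch.tensor([[1.0, 0.0, 1.0, 0.0],\n", "                      [1.0, 0.0, 1.0, 0.0]])   # [B, H] candidate ad\n", "keys = torch.tensor([[[1.0, 0.0, 1.0, 0.0],    # same as the query\n", "                      [0.0, 1.0, 0.0, 1.0],    # orthogonal to the query\n", "                      [0.5, 0.5, 0.5, 0.5]],   # partially similar\n", "                     [[0.5, 0.5, 0.5, 0.5],\n", "                      [0.0, 0.0, 0.0, 0.0],    # padding\n", "                      [0.0, 0.0, 0.0, 0.0]]])  # padding; keys: [B, T, H]\n", "keys_length = torch.tensor([3, 1])\n", "B, T, H = keys.shape\n", "\n", "# DIN-style input: [Q, K, Q-K, Q*K] for every (query, key) pair\n", "q = query.unsqueeze(1).expand(-1, T, -1)                      # [B, T, H]\n", "din_input = torch.cat([q, keys, q - keys, q * keys], dim=-1)  # [B, T, 4H]\n", "\n", "# Hand-set linear scorer (hypothetical; DIN learns an MLP instead): it only\n", "# reads the Q*K block and sums it, so each score equals the dot product <Q, K_t>\n", "w = torch.zeros(4 * H)\n", "w[3 * H:] = 1.0\n", "scores = din_input @ w                                        # [B, T]\n", "\n", "# Same masking trick as Attention.forward: padded slots -> -inf -> weight 0\n", "mask = torch.arange(T).repeat(B, 1) < keys_length.view(-1, 1)\n", "scores[~mask] = float('-inf')\n", "weights = torch.sigmoid(scores)          # sigmoid, so rows need not sum to 1\n", "pooled = torch.matmul(weights.unsqueeze(1), keys).squeeze(1)  # [B, H]\n", "print(weights)\n", "print(pooled)" ] }, { 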
"cell_type": "code", "execution_count": 3, "id": "e3517c96", "metadata": {}, "outputs": [], "source": [ "class AttentionGroup(object):\n", " def __init__(self, name, pairs,\n", " hidden_layers, activation='dice', att_dropout=0.0):\n", " self.name = name\n", " self.pairs = pairs\n", " self.hidden_layers = hidden_layers\n", " self.activation = activation\n", " self.att_dropout = att_dropout\n", "\n", " self.related_feature_names = set()\n", " for pair in pairs:\n", " self.related_feature_names.add(pair['ad'])\n", " self.related_feature_names.add(pair['pos_hist'])\n", "\n", " def is_attention_feature(self, feature_name):\n", " if feature_name in self.related_feature_names:\n", " return True\n", " return False\n", "\n", " @property\n", " def pairs_count(self):\n", " return len(self.pairs)\n", " " ] }, { "cell_type": "code", "execution_count": 4, "id": "3178c1f6", "metadata": {}, "outputs": [], "source": [ "class DIN(nn.Module):\n", " def __init__(self, num_features,cat_features,seq_features, \n", " cat_nums,embedding_size, attention_groups,\n", " mlp_hidden_layers, mlp_activation='prelu', mlp_dropout=0.0,\n", " d_out = 1\n", " ):\n", " super().__init__()\n", " self.num_features = num_features\n", " self.cat_features = cat_features\n", " self.seq_features = seq_features\n", " self.cat_nums = cat_nums \n", " self.embedding_size = embedding_size\n", " \n", " self.attention_groups = attention_groups\n", " \n", " self.mlp_hidden_layers = mlp_hidden_layers\n", " self.mlp_activation = mlp_activation\n", " self.mlp_dropout = mlp_dropout\n", " \n", " self.d_out = d_out\n", " \n", " #embedding\n", " self.embeddings = OrderedDict()\n", " for feature in self.cat_features+self.seq_features:\n", " self.embeddings[feature] = nn.Embedding(\n", " self.cat_nums[feature], self.embedding_size, padding_idx=0)\n", " self.add_module(f\"embedding:{feature}\",self.embeddings[feature])\n", "\n", " self.sequence_poolings = OrderedDict()\n", " self.attention_poolings = OrderedDict()\n", " 
total_embedding_sizes = 0\n", " for feature in self.cat_features:\n", " total_embedding_sizes += self.embedding_size\n", " for feature in self.seq_features:\n", " total_embedding_sizes += self.embedding_size\n", " \n", " #sequence_pooling\n", " for feature in self.seq_features:\n", " if not self.is_attention_feature(feature):\n", " self.sequence_poolings[feature] = MaxPooling(1)\n", " self.add_module(f\"pooling:{feature}\",self.sequence_poolings[feature])\n", "\n", " #attention_pooling\n", " for attention_group in self.attention_groups:\n", " self.attention_poolings[attention_group.name] = (\n", " self.create_attention_fn(attention_group))\n", " self.add_module(f\"attention_pooling:{attention_group.name}\",\n", " self.attention_poolings[attention_group.name])\n", "\n", " total_input_size = total_embedding_sizes+len(self.num_features)\n", " \n", " self.mlp = MLP(\n", " total_input_size,\n", " mlp_hidden_layers,\n", " dropout=mlp_dropout, batchnorm=True, activation=mlp_activation)\n", " \n", " self.final_layer = nn.Linear(mlp_hidden_layers[-1], self.d_out)\n", " self.apply(init_weights)\n", "\n", " def forward(self, x):\n", " \n", " final_layer_inputs = list()\n", "\n", " number_inputs = list()\n", " for feature in self.num_features:\n", " number_inputs.append(x[feature].view(-1, 1))\n", "\n", " embeddings = OrderedDict()\n", " for feature in self.cat_features:\n", " embeddings[feature] = self.embeddings[feature](x[feature])\n", "\n", " for feature in self.seq_features:\n", " if not self.is_attention_feature(feature):\n", " embeddings[feature] = self.sequence_poolings[feature](\n", " self.embeddings[feature](x[feature]))\n", " \n", " for attention_group in self.attention_groups:\n", " query = torch.cat(\n", " [embeddings[pair['ad']]\n", " for pair in attention_group.pairs],\n", " dim=-1)\n", " keys = torch.cat(\n", " [self.embeddings[pair['pos_hist']](\n", " x[pair['pos_hist']]) for pair in attention_group.pairs],\n", " dim=-1)\n", " #hist_length = 
torch.sum(hist>0,axis=1)\n", "            # number of non-padding history items per user (shortest over the pairs)\n", "            keys_length = torch.min(torch.cat(\n", "                [torch.sum(x[pair['pos_hist']]>0,axis=1).view(-1, 1)\n", "                 for pair in attention_group.pairs],\n", "                dim=-1), dim=-1)[0]\n", "\n", "            embeddings[attention_group.name] = self.attention_poolings[\n", "                attention_group.name](query, keys, keys_length)\n", "\n", "        emb_concat = torch.cat(number_inputs + [\n", "            emb for emb in embeddings.values()], dim=-1)\n", "\n", "        final_layer_inputs = self.mlp(emb_concat)\n", "        output = self.final_layer(final_layer_inputs)\n", "        if self.d_out==1:\n", "            output = output.squeeze(-1)  # [B, 1] -> [B], also safe when B == 1\n", "\n", "        return output\n", "\n", "    def create_attention_fn(self, attention_group):\n", "        return Attention(\n", "            attention_group.pairs_count * self.embedding_size,\n", "            hidden_layers=attention_group.hidden_layers,\n", "            dropout=attention_group.att_dropout,\n", "            activation=attention_group.activation)\n", "\n", "    def is_attention_feature(self, feature):\n", "        for group in self.attention_groups:\n", "            if group.is_attention_feature(feature):\n", "                return True\n", "        return False" ] }, { "cell_type": "code", "execution_count": null, "id": "d4f0b673", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "cef0b4f2", "metadata": {}, "source": [ "## 3. A complete example on the MovieLens dataset" ] }, { "cell_type": "markdown", "id": "99bc3417", "metadata": {}, "source": [ "Below is a complete example based on the MovieLens ratings dataset. Given a user's past ratings of some movies, it predicts whether the user will rate a candidate movie highly.\n", "\n", "The dataset is small enough to run on a CPU. 😁\n" ] }, { "cell_type": "markdown", "id": "ada2a5f6", "metadata": {}, "source": [ "### 1. Prepare the data" ] }, { "cell_type": "code", "execution_count": 7, "id": "2f13dd71", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "from sklearn.base import BaseEstimator, TransformerMixin\n", "from sklearn.preprocessing import QuantileTransformer\n", "from sklearn.pipeline import Pipeline, FeatureUnion\n", "from sklearn.impute import SimpleImputer\n", "from collections import Counter\n", "\n", "# Preprocessing for categorical features\n", "class CategoryEncoder(BaseEstimator, TransformerMixin):\n", "\n", "    def __init__(self, min_cnt=5, word2idx=None, idx2word=None):\n", "        super().__init__()\n", "        self.min_cnt = min_cnt\n", "        self.word2idx = word2idx if word2idx else dict()\n", "        self.idx2word = idx2word if idx2word else dict()\n", "\n", "    def fit(self, x, y=None):\n", "        if not self.word2idx:\n", "            counter = Counter(np.asarray(x).ravel())\n", "\n", "            selected_terms = sorted(\n", "                list(filter(lambda x: counter[x] >= self.min_cnt, counter)))\n", "\n", "            self.word2idx = dict(\n", "                zip(selected_terms, range(1, len(selected_terms) + 1)))\n", "            self.word2idx['__PAD__'] = 0\n", "            if '__UNKNOWN__' not in self.word2idx:\n", "                self.word2idx['__UNKNOWN__'] = len(self.word2idx)\n", "\n", "        if not self.idx2word:\n", "            self.idx2word = {\n", "                index: word for word, index in self.word2idx.items()}\n", "\n", "        return self\n", "\n", "    def transform(self, x):\n", "        transformed_x = list()\n", "        for term in np.asarray(x).ravel():\n", "            try:\n", "                transformed_x.append(self.word2idx[term])\n", "            except KeyError:\n", "                transformed_x.append(self.word2idx['__UNKNOWN__'])\n", "\n", "        return np.asarray(transformed_x, dtype=np.int64)\n", "\n", "    def dimension(self):\n", "        return len(self.word2idx)\n", "\n", "# Preprocessing for sequence features (sequences of categories)\n", "class SequenceEncoder(BaseEstimator, TransformerMixin):\n", "    def __init__(self, sep=' ', min_cnt=5, max_len=None,\n", "                 word2idx=None, idx2word=None):\n", "        super().__init__()\n", "        self.sep = sep\n", "        self.min_cnt = min_cnt\n", "        self.max_len = max_len\n", "\n", "        self.word2idx = word2idx if word2idx else dict()\n", "        self.idx2word = idx2word if idx2word else dict()\n", "\n", "    def fit(self, x, y=None):\n", "        if not self.word2idx:\n", "            counter = Counter()\n", "\n", "            max_len = 0\n", "            for sequence in np.array(x).ravel():\n", "                words = sequence.split(self.sep)\n", "                counter.update(words)\n", "                max_len = max(max_len, len(words))\n", "\n", "            if self.max_len is None:\n", "                self.max_len 
= max_len\n", "\n", " # drop rare words\n", " words = sorted(\n", " list(filter(lambda x: counter[x] >= self.min_cnt, counter)))\n", "\n", " self.word2idx = dict(zip(words, range(1, len(words) + 1)))\n", " self.word2idx['__PAD__'] = 0\n", " if '__UNKNOWN__' not in self.word2idx:\n", " self.word2idx['__UNKNOWN__'] = len(self.word2idx)\n", "\n", " if not self.idx2word:\n", " self.idx2word = {\n", " index: word for word, index in self.word2idx.items()}\n", "\n", " if not self.max_len:\n", " max_len = 0\n", " for sequence in np.array(x).ravel():\n", " words = sequence.split(self.sep)\n", " max_len = max(max_len, len(words))\n", " self.max_len = max_len\n", "\n", " return self\n", "\n", " def transform(self, x):\n", " transformed_x = list()\n", "\n", " for sequence in np.asarray(x).ravel():\n", " words = list()\n", " for word in sequence.split(self.sep):\n", " try:\n", " words.append(self.word2idx[word])\n", " except KeyError:\n", " words.append(self.word2idx['__UNKNOWN__'])\n", "\n", " transformed_x.append(\n", " np.asarray(words[0:self.max_len], dtype=np.int64))\n", "\n", " return np.asarray(transformed_x, dtype=object)\n", " \n", " def dimension(self):\n", " return len(self.word2idx)\n", "\n", " def max_length(self):\n", " return self.max_len\n", "\n", " " ] }, { "cell_type": "code", "execution_count": null, "id": "b413951e", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 8, "id": "30d774bf", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "preprocess number features...\n", "preprocess category features...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4/4 [00:00<00:00, 142.91it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "preprocess sequence features...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 3/3 [00:00<00:00, 16.94it/s]\n" ] } ], "source": [ "from sklearn.preprocessing import 
QuantileTransformer\n", "from sklearn.pipeline import Pipeline \n", "from sklearn.impute import SimpleImputer \n", "from tqdm import tqdm \n", "\n", "dftrain = pd.read_csv(\"../data/ml_1m/train.csv\")\n", "dfval = pd.read_csv(\"../data/ml_1m/test.csv\")\n", "\n", "for col in [\"movieId\",\"histHighRatedMovieIds\",\"negHistMovieIds\",\"genres\"]:\n", " dftrain[col] = dftrain[col].astype(str)\n", " dfval[col] = dfval[col].astype(str)\n", "\n", "num_features = ['age']\n", "cat_features = ['gender', 'movieId', 'occupation', 'zipCode']\n", "seq_features = ['genres', 'histHighRatedMovieIds', 'negHistMovieIds']\n", "\n", "num_pipe = Pipeline(steps = [('impute',SimpleImputer()),('quantile',QuantileTransformer())])\n", "\n", "encoders = {}\n", "\n", "print(\"preprocess number features...\")\n", "dftrain[num_features] = num_pipe.fit_transform(dftrain[num_features]).astype(np.float32)\n", "dfval[num_features] = num_pipe.transform(dfval[num_features]).astype(np.float32)\n", "\n", "print(\"preprocess category features...\")\n", "for col in tqdm(cat_features):\n", " encoders[col] = CategoryEncoder(min_cnt=5)\n", " dftrain[col] = encoders[col].fit_transform(dftrain[col])\n", " dfval[col] = encoders[col].transform(dfval[col])\n", " \n", "print(\"preprocess sequence features...\")\n", "for col in tqdm(seq_features):\n", " encoders[col] = SequenceEncoder(sep=\"|\",min_cnt=5)\n", " dftrain[col] = encoders[col].fit_transform(dftrain[col])\n", " dfval[col] = encoders[col].transform(dfval[col])\n", " \n", "from collections import OrderedDict\n", "from itertools import chain\n", "from torch.utils.data import Dataset,DataLoader \n", "\n", "class Df2Dataset(Dataset):\n", " def __init__(self, dfdata, num_features, cat_features,\n", " seq_features, encoders, label_col=\"label\"):\n", " self.dfdata = dfdata\n", " self.num_features = num_features\n", " self.cat_features = cat_features \n", " self.seq_features = seq_features\n", " self.encoders = encoders\n", " self.label_col = label_col\n", " 
self.size = len(self.dfdata)\n", "\n", "    def __len__(self):\n", "        return self.size\n", "\n", "    @staticmethod\n", "    def pad_sequence(sequence, max_length):\n", "        # zero is the special index for padding\n", "        padded_seq = np.zeros(max_length, np.int32)\n", "        padded_seq[0: sequence.shape[0]] = sequence\n", "        return padded_seq\n", "\n", "    def __getitem__(self, idx):\n", "        record = OrderedDict()\n", "        for col in self.num_features:\n", "            record[col] = self.dfdata[col].iloc[idx].astype(np.float32)\n", "\n", "        for col in self.cat_features:\n", "            record[col] = self.dfdata[col].iloc[idx].astype(np.int64)\n", "\n", "        for col in self.seq_features:\n", "            seq = self.dfdata[col].iloc[idx]\n", "            max_length = self.encoders[col].max_length()\n", "            record[col] = Df2Dataset.pad_sequence(seq, max_length)\n", "\n", "        if self.label_col is not None:\n", "            record['label'] = self.dfdata[self.label_col].iloc[idx].astype(np.float32)\n", "        return record\n", "\n", "    def get_num_batches(self, batch_size):\n", "        return np.ceil(self.size / batch_size)\n", "\n", "ds_train = Df2Dataset(dftrain, num_features, cat_features, seq_features, encoders)\n", "ds_val = Df2Dataset(dfval, num_features, cat_features, seq_features, encoders)\n", "dl_train = DataLoader(ds_train, batch_size=128, shuffle=True)\n", "dl_val = DataLoader(ds_val, batch_size=128, shuffle=False)\n", "\n", "# vocabulary size of each categorical / sequence feature, used to size the embedding tables\n", "cat_nums = {k: v.dimension() for k, v in encoders.items()}\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "ffe7dbf0", "metadata": {}, "outputs": [], "source": [ "# take a single batch to try out the model\n", "for batch in dl_train:\n", "    break" ] }, { "cell_type": "code", "execution_count": null, "id": "f3ea3f05", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "be532d54", "metadata": {}, "source": [ "### 2. Define the model" ] }, { "cell_type": "code", "execution_count": 124, "id": "5ec35633", "metadata": {}, "outputs": [], "source": [ "def create_net():\n", "    din_attention_groups = [\n", "        AttentionGroup(\n", "            name='group1',\n", "            pairs=[{'ad': 'movieId', 
'pos_hist': 'histHighRatedMovieIds'}],\n", "            activation='dice',\n", "            hidden_layers=[16, 8], att_dropout=0.1)\n", "    ]\n", "\n", "    net = DIN(num_features=num_features,\n", "              cat_features=cat_features,\n", "              seq_features=seq_features,\n", "              cat_nums=cat_nums,\n", "              embedding_size=16,\n", "              attention_groups=din_attention_groups,\n", "              mlp_hidden_layers=[32,16],\n", "              mlp_activation=\"prelu\",\n", "              mlp_dropout=0.25,\n", "              d_out=1\n", "              )\n", "    return net\n", "\n", "net = create_net()\n", "\n", "# quick sanity check: one forward pass on a single batch\n", "out = net.forward(batch)" ] }, { "cell_type": "code", "execution_count": null, "id": "6e82a3ea", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 125, "id": "40d61b6e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------------------------------------------------------------------\n", "Layer (type) Output Shape Param #\n", "==========================================================================\n", "Embedding-1 [-1, 16] 64\n", "Embedding-2 [-1, 16] 4,480\n", "Embedding-3 [-1, 16] 368\n", "Embedding-4 [-1, 16] 1,984\n", "Embedding-5 [-1, 6, 16] 320\n", "MaxPooling-6 [-1, 16] 0\n", "Embedding-7 [-1, 10, 16] 61,888\n", "MaxPooling-8 [-1, 16] 0\n", "Embedding-9 [-1, 10, 16] 28,656\n", "Linear-10 [-1, 16] 1,040\n", "BatchNorm1d-11 [-1, 16] 32\n", "BatchNorm1d-12 [-1, 16] 32\n", "Sigmoid-13 [-1, 16] 0\n", "Dice-14 [-1, 16] 48\n", "Dropout-15 [-1, 16] 0\n", "Linear-16 [-1, 8] 136\n", "BatchNorm1d-17 [-1, 8] 16\n", "BatchNorm1d-18 [-1, 8] 16\n", "Sigmoid-19 [-1, 8] 0\n", "Dice-20 [-1, 8] 24\n", "Dropout-21 [-1, 8] 0\n", "Sequential-22 [-1, 8] 1,296\n", "MLP-23 [-1, 8] 1,296\n", "Linear-24 [-1, 1] 9\n", "Attention-25 [-1, 16] 1,305\n", "Linear-26 [-1, 32] 3,648\n", "BatchNorm1d-27 [-1, 32] 64\n", "PReLU-28 [-1, 32] 1\n", "Dropout-29 [-1, 32] 0\n", "Linear-30 [-1, 16] 528\n", "BatchNorm1d-31 [-1, 16] 32\n", "PReLU-32 [-1, 16] 1\n", "Dropout-33 [-1, 16] 0\n", "Sequential-34 [-1, 16] 4,274\n", "MLP-35 [-1, 16] 4,274\n", "Linear-36 [-1, 1] 17\n", "DIN-37 [-1] 103,356\n", "==========================================================================\n", "Total params: 219,205\n", "Trainable params: 219,205\n", "Non-trainable params: 0\n", "--------------------------------------------------------------------------\n", "Input size (MB): 0.000343\n", "Forward/backward pass size (MB): 0.006981\n", "Params size (MB): 0.836201\n", "Estimated Total Size (MB): 0.843525\n", "--------------------------------------------------------------------------\n" ] } ], "source": [ "from torchkeras.summary import summary\n", "summary(net,input_data=batch);" ] }, { "cell_type": "markdown", "id": "bd2ea9a2", "metadata": {}, "source": [ "### 3. Train the model" ] }, { "cell_type": "code", "execution_count": null, "id": "536d5450", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 128, "id": "816c3662", "metadata": {}, "outputs": [], "source": [ "import os,sys,time\n", "import numpy as np\n", "import pandas as pd\n", "import datetime\n", "from tqdm import tqdm\n", "\n", "import torch\n", "from torch import nn\n", "from accelerate import Accelerator\n", "from copy import deepcopy\n", "\n", "\n", "def printlog(info):\n", "    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')\n", "    print(\"\\n\"+\"==========\"*8 + \"%s\"%nowtime)\n", "    print(str(info)+\"\\n\")\n", "\n", "class StepRunner:\n", "    def __init__(self, net, loss_fn, stage=\"train\", metrics_dict=None,\n", "                 optimizer=None, lr_scheduler=None,\n", "                 accelerator=None\n", "                 ):\n", "        self.net,self.loss_fn,self.metrics_dict,self.stage = net,loss_fn,metrics_dict,stage\n", "        self.optimizer,self.lr_scheduler = optimizer,lr_scheduler\n", "        self.accelerator = accelerator\n", "\n", "    def __call__(self, batch):\n", "        # loss\n", "        preds = self.net(batch)\n", "        loss = self.loss_fn(preds,batch[\"label\"])\n", "\n", "        # backward()\n", "        if self.optimizer is not None and 
self.stage==\"train\":\n", " if self.accelerator is None:\n", " loss.backward()\n", " else:\n", " self.accelerator.backward(loss)\n", " self.optimizer.step()\n", " if self.lr_scheduler is not None:\n", " self.lr_scheduler.step()\n", " self.optimizer.zero_grad()\n", " \n", " #metrics\n", " step_metrics = {self.stage+\"_\"+name:metric_fn(preds, batch[\"label\"]).item() \n", " for name,metric_fn in self.metrics_dict.items()}\n", " return loss.item(),step_metrics\n", " \n", " \n", "class EpochRunner:\n", " def __init__(self,steprunner):\n", " self.steprunner = steprunner\n", " self.stage = steprunner.stage\n", " self.steprunner.net.train() if self.stage==\"train\" else self.steprunner.net.eval()\n", " \n", " def __call__(self,dataloader):\n", " total_loss,step = 0,0\n", " loop = tqdm(enumerate(dataloader), total =len(dataloader))\n", " for i, batch in loop:\n", " if self.stage==\"train\":\n", " loss, step_metrics = self.steprunner(batch)\n", " else:\n", " with torch.no_grad():\n", " loss, step_metrics = self.steprunner(batch)\n", "\n", " step_log = dict({self.stage+\"_loss\":loss},**step_metrics)\n", "\n", " total_loss += loss\n", " step+=1\n", " if i!=len(dataloader)-1:\n", " loop.set_postfix(**step_log)\n", " else:\n", " epoch_loss = total_loss/step\n", " epoch_metrics = {self.stage+\"_\"+name:metric_fn.compute().item() \n", " for name,metric_fn in self.steprunner.metrics_dict.items()}\n", " epoch_log = dict({self.stage+\"_loss\":epoch_loss},**epoch_metrics)\n", " loop.set_postfix(**epoch_log)\n", "\n", " for name,metric_fn in self.steprunner.metrics_dict.items():\n", " metric_fn.reset()\n", " return epoch_log\n", "\n", "class KerasModel(torch.nn.Module):\n", " def __init__(self,net,loss_fn,metrics_dict=None,optimizer=None,lr_scheduler = None):\n", " super().__init__()\n", " self.accelerator = Accelerator()\n", " self.history = {}\n", " \n", " self.net = net\n", " self.loss_fn = loss_fn\n", " self.metrics_dict = nn.ModuleDict(metrics_dict) \n", " \n", " 
self.optimizer = optimizer if optimizer is not None else torch.optim.Adam(\n", " self.parameters(), lr=1e-2)\n", " self.lr_scheduler = lr_scheduler\n", " \n", " self.net,self.loss_fn,self.metrics_dict,self.optimizer = self.accelerator.prepare(\n", " self.net,self.loss_fn,self.metrics_dict,self.optimizer)\n", "\n", " def forward(self, x):\n", " if self.net:\n", " return self.net.forward(x)\n", " else:\n", " raise NotImplementedError\n", "\n", "\n", " def fit(self, train_data, val_data=None, epochs=10, ckpt_path='checkpoint.pt', \n", " patience=5, monitor=\"val_loss\", mode=\"min\"):\n", " \n", " train_data = self.accelerator.prepare(train_data)\n", " val_data = self.accelerator.prepare(val_data) if val_data else []\n", "\n", " for epoch in range(1, epochs+1):\n", " printlog(\"Epoch {0} / {1}\".format(epoch, epochs))\n", " \n", " # 1,train ------------------------------------------------- \n", " train_step_runner = StepRunner(net = self.net,stage=\"train\",\n", " loss_fn = self.loss_fn,metrics_dict=deepcopy(self.metrics_dict),\n", " optimizer = self.optimizer, lr_scheduler = self.lr_scheduler,\n", " accelerator = self.accelerator)\n", " train_epoch_runner = EpochRunner(train_step_runner)\n", " train_metrics = train_epoch_runner(train_data)\n", " \n", " for name, metric in train_metrics.items():\n", " self.history[name] = self.history.get(name, []) + [metric]\n", "\n", " # 2,validate -------------------------------------------------\n", " if val_data:\n", " val_step_runner = StepRunner(net = self.net,stage=\"val\",\n", " loss_fn = self.loss_fn,metrics_dict=deepcopy(self.metrics_dict),\n", " accelerator = self.accelerator)\n", " val_epoch_runner = EpochRunner(val_step_runner)\n", " with torch.no_grad():\n", " val_metrics = val_epoch_runner(val_data)\n", " val_metrics[\"epoch\"] = epoch\n", " for name, metric in val_metrics.items():\n", " self.history[name] = self.history.get(name, []) + [metric]\n", " \n", " # 3,early-stopping 
-------------------------------------------------\n", " arr_scores = self.history[monitor]\n", " best_score_idx = np.argmax(arr_scores) if mode==\"max\" else np.argmin(arr_scores)\n", " if best_score_idx==len(arr_scores)-1:\n", " torch.save(self.net.state_dict(),ckpt_path)\n", " print(\"<<<<<< reach best {0} : {1} >>>>>>\".format(monitor,\n", " arr_scores[best_score_idx]),file=sys.stderr)\n", " if len(arr_scores)-best_score_idx>patience:\n", " print(\"<<<<<< {} without improvement in {} epoch, early stopping >>>>>>\".format(\n", " monitor,patience),file=sys.stderr)\n", " self.net.load_state_dict(torch.load(ckpt_path))\n", " break \n", " \n", " return pd.DataFrame(self.history)\n", "\n", " @torch.no_grad()\n", " def evaluate(self, val_data):\n", " val_data = self.accelerator.prepare(val_data)\n", " val_step_runner = StepRunner(net = self.net,stage=\"val\",\n", " loss_fn = self.loss_fn,metrics_dict=deepcopy(self.metrics_dict),\n", " accelerator = self.accelerator)\n", " val_epoch_runner = EpochRunner(val_step_runner)\n", " val_metrics = val_epoch_runner(val_data)\n", " return val_metrics\n", " \n", " \n", " @torch.no_grad()\n", " def predict(self, dataloader):\n", " dataloader = self.accelerator.prepare(dataloader)\n", " self.net.eval()\n", " result = torch.cat([self.forward(t) for t in dataloader])\n", " return result.data\n", " " ] }, { "cell_type": "code", "execution_count": null, "id": "b95cd39b", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 129, "id": "c5255faf", "metadata": {}, "outputs": [], "source": [ "from torchkeras.metrics import AUC\n", "\n", "loss_fn = nn.BCEWithLogitsLoss()\n", "\n", "metrics_dict = {\"auc\":AUC()}\n", "\n", "optimizer = torch.optim.Adam(net.parameters(), lr=0.002, weight_decay=0.001) \n", "\n", "\n", "model = KerasModel(net,\n", " loss_fn = loss_fn,\n", " metrics_dict= metrics_dict,\n", " optimizer = optimizer,\n", " ) \n" ] }, { "cell_type": "code", "execution_count": 130, "id": "34ad222d", 
"metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2022-07-02 09:21:26\n", "Epoch 1 / 100\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 38/38 [00:01<00:00, 24.87it/s, train_auc=0.515, train_loss=0.886]\n", "100%|██████████| 10/10 [00:00<00:00, 34.60it/s, val_auc=0.542, val_loss=0.741]\n", "<<<<<< reach best val_auc : 0.5422608256340027 >>>>>>\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2022-07-02 09:21:28\n", "Epoch 2 / 100\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 38/38 [00:01<00:00, 23.57it/s, train_auc=0.559, train_loss=0.752]\n", "100%|██████████| 10/10 [00:00<00:00, 33.08it/s, val_auc=0.564, val_loss=0.693]\n", "<<<<<< reach best val_auc : 0.5640223026275635 >>>>>>\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2022-07-02 09:21:30\n", "Epoch 3 / 100\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 38/38 [00:01<00:00, 25.03it/s, train_auc=0.566, train_loss=0.725]\n", "100%|██████████| 10/10 [00:00<00:00, 34.72it/s, val_auc=0.571, val_loss=0.682]\n", "<<<<<< reach best val_auc : 0.570896327495575 >>>>>>\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2022-07-02 09:21:32\n", "Epoch 4 / 100\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 38/38 [00:01<00:00, 23.38it/s, train_auc=0.58, train_loss=0.703] \n", "100%|██████████| 10/10 [00:00<00:00, 29.78it/s, val_auc=0.58, val_loss=0.676]\n", "<<<<<< reach best val_auc : 0.5801306366920471 >>>>>>\n" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ "\n", "================================================================================2022-07-02 09:21:34\n", "Epoch 5 / 100\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 38/38 [00:01<00:00, 20.83it/s, train_auc=0.598, train_loss=0.693]\n", "100%|██████████| 10/10 [00:00<00:00, 32.51it/s, val_auc=0.586, val_loss=0.671]\n", "<<<<<< reach best val_auc : 0.5860493779182434 >>>>>>\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2022-07-02 09:21:36\n", "Epoch 6 / 100\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 38/38 [00:01<00:00, 22.49it/s, train_auc=0.605, train_loss=0.684]\n", "100%|██████████| 10/10 [00:00<00:00, 32.40it/s, val_auc=0.594, val_loss=0.668]\n", "<<<<<< reach best val_auc : 0.5941869616508484 >>>>>>\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2022-07-02 09:21:38\n", "Epoch 7 / 100\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 38/38 [00:01<00:00, 22.59it/s, train_auc=0.629, train_loss=0.664]\n", "100%|██████████| 10/10 [00:00<00:00, 32.62it/s, val_auc=0.599, val_loss=0.668]\n", "<<<<<< reach best val_auc : 0.5990937948226929 >>>>>>\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2022-07-02 09:21:40\n", "Epoch 8 / 100\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 38/38 [00:01<00:00, 24.68it/s, train_auc=0.626, train_loss=0.668]\n", "100%|██████████| 10/10 [00:00<00:00, 36.89it/s, val_auc=0.599, val_loss=0.668]\n", "<<<<<< reach best val_auc : 0.5991701483726501 >>>>>>\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", 
"================================================================================2022-07-02 09:21:42\n", "Epoch 9 / 100\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 38/38 [00:01<00:00, 24.39it/s, train_auc=0.642, train_loss=0.655]\n", "100%|██████████| 10/10 [00:00<00:00, 27.89it/s, val_auc=0.603, val_loss=0.666]\n", "<<<<<< reach best val_auc : 0.6033138632774353 >>>>>>\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2022-07-02 09:21:44\n", "Epoch 10 / 100\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 38/38 [00:01<00:00, 23.41it/s, train_auc=0.657, train_loss=0.644]\n", "100%|██████████| 10/10 [00:00<00:00, 33.93it/s, val_auc=0.607, val_loss=0.665]\n", "<<<<<< reach best val_auc : 0.6066208481788635 >>>>>>\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2022-07-02 09:21:45\n", "Epoch 11 / 100\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 38/38 [00:01<00:00, 25.75it/s, train_auc=0.641, train_loss=0.656]\n", "100%|██████████| 10/10 [00:00<00:00, 35.30it/s, val_auc=0.607, val_loss=0.666]\n", "<<<<<< reach best val_auc : 0.6072936058044434 >>>>>>\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2022-07-02 09:21:47\n", "Epoch 12 / 100\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 38/38 [00:01<00:00, 23.64it/s, train_auc=0.668, train_loss=0.635]\n", "100%|██████████| 10/10 [00:00<00:00, 35.02it/s, val_auc=0.608, val_loss=0.665]\n", "<<<<<< reach best val_auc : 0.6082348227500916 >>>>>>\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", 
"================================================================================2022-07-02 09:21:49\n", "Epoch 13 / 100\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 38/38 [00:01<00:00, 22.63it/s, train_auc=0.683, train_loss=0.628]\n", "100%|██████████| 10/10 [00:00<00:00, 32.89it/s, val_auc=0.61, val_loss=0.664]\n", "<<<<<< reach best val_auc : 0.6097879409790039 >>>>>>\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2022-07-02 09:21:51\n", "Epoch 14 / 100\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 38/38 [00:01<00:00, 23.67it/s, train_auc=0.671, train_loss=0.636]\n", "100%|██████████| 10/10 [00:00<00:00, 33.25it/s, val_auc=0.615, val_loss=0.662]\n", "<<<<<< reach best val_auc : 0.6154622435569763 >>>>>>\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2022-07-02 09:21:53\n", "Epoch 15 / 100\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 38/38 [00:01<00:00, 23.30it/s, train_auc=0.701, train_loss=0.613]\n", "100%|██████████| 10/10 [00:00<00:00, 32.71it/s, val_auc=0.615, val_loss=0.664]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2022-07-02 09:21:55\n", "Epoch 16 / 100\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 38/38 [00:01<00:00, 21.25it/s, train_auc=0.703, train_loss=0.61] \n", "100%|██████████| 10/10 [00:00<00:00, 33.11it/s, val_auc=0.617, val_loss=0.665]\n", "<<<<<< reach best val_auc : 0.6173899173736572 >>>>>>\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2022-07-02 09:21:57\n", "Epoch 17 
/ 100\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 38/38 [00:01<00:00, 22.53it/s, train_auc=0.72, train_loss=0.601] \n", "100%|██████████| 10/10 [00:00<00:00, 33.07it/s, val_auc=0.625, val_loss=0.663]\n", "<<<<<< reach best val_auc : 0.6246682405471802 >>>>>>\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2022-07-02 09:21:59\n", "Epoch 18 / 100\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 38/38 [00:01<00:00, 23.81it/s, train_auc=0.727, train_loss=0.595]\n", "100%|██████████| 10/10 [00:00<00:00, 35.39it/s, val_auc=0.626, val_loss=0.664]\n", "<<<<<< reach best val_auc : 0.625541627407074 >>>>>>\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2022-07-02 09:22:01\n", "Epoch 19 / 100\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 38/38 [00:01<00:00, 23.72it/s, train_auc=0.747, train_loss=0.58] \n", "100%|██████████| 10/10 [00:00<00:00, 30.50it/s, val_auc=0.629, val_loss=0.667]\n", "<<<<<< reach best val_auc : 0.6285688877105713 >>>>>>\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2022-07-02 09:22:03\n", "Epoch 20 / 100\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 38/38 [00:01<00:00, 23.21it/s, train_auc=0.755, train_loss=0.572]\n", "100%|██████████| 10/10 [00:00<00:00, 32.14it/s, val_auc=0.631, val_loss=0.669]\n", "<<<<<< reach best val_auc : 0.6309064030647278 >>>>>>\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2022-07-02 09:22:05\n", "Epoch 21 / 100\n", "\n" ] }, { "name": "stderr", "output_type": 
"stream", "text": [ "100%|██████████| 38/38 [00:01<00:00, 23.53it/s, train_auc=0.782, train_loss=0.55] \n", "100%|██████████| 10/10 [00:00<00:00, 33.75it/s, val_auc=0.629, val_loss=0.676]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2022-07-02 09:22:07\n", "Epoch 22 / 100\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 38/38 [00:01<00:00, 22.40it/s, train_auc=0.801, train_loss=0.532]\n", "100%|██████████| 10/10 [00:00<00:00, 32.27it/s, val_auc=0.632, val_loss=0.688]\n", "<<<<<< reach best val_auc : 0.6316272020339966 >>>>>>\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2022-07-02 09:22:09\n", "Epoch 23 / 100\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 38/38 [00:01<00:00, 23.87it/s, train_auc=0.808, train_loss=0.522]\n", "100%|██████████| 10/10 [00:00<00:00, 34.96it/s, val_auc=0.639, val_loss=0.684]\n", "<<<<<< reach best val_auc : 0.6392135620117188 >>>>>>\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2022-07-02 09:22:11\n", "Epoch 24 / 100\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 38/38 [00:01<00:00, 24.69it/s, train_auc=0.835, train_loss=0.493]\n", "100%|██████████| 10/10 [00:00<00:00, 32.95it/s, val_auc=0.634, val_loss=0.706]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2022-07-02 09:22:13\n", "Epoch 25 / 100\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 38/38 [00:01<00:00, 23.91it/s, train_auc=0.856, train_loss=0.465]\n", "100%|██████████| 10/10 [00:00<00:00, 33.51it/s, val_auc=0.633, 
val_loss=0.716]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2022-07-02 09:22:15\n", "Epoch 26 / 100\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 38/38 [00:01<00:00, 23.46it/s, train_auc=0.872, train_loss=0.44] \n", "100%|██████████| 10/10 [00:00<00:00, 29.78it/s, val_auc=0.626, val_loss=0.775]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2022-07-02 09:22:17\n", "Epoch 27 / 100\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 38/38 [00:01<00:00, 23.15it/s, train_auc=0.897, train_loss=0.406]\n", "100%|██████████| 10/10 [00:00<00:00, 32.71it/s, val_auc=0.635, val_loss=0.78]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================2022-07-02 09:22:19\n", "Epoch 28 / 100\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 38/38 [00:01<00:00, 23.70it/s, train_auc=0.907, train_loss=0.383]\n", "100%|██████████| 10/10 [00:00<00:00, 32.35it/s, val_auc=0.633, val_loss=0.857]\n", "<<<<<< val_auc without improvement in 5 epoch, early stopping >>>>>>\n" ] } ], "source": [ "dfhistory = model.fit(train_data=dl_train,val_data=dl_val,epochs=100, patience=5,\n", " monitor = \"val_auc\",mode=\"max\",ckpt_path='checkpoint.pt')\n" ] }, { "cell_type": "markdown", "id": "3425d1a9", "metadata": {}, "source": [ "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h3sbs7wl0lj20r107nab5.jpg)" ] }, { "cell_type": "markdown", "id": "d4902b2d", "metadata": {}, "source": [ "### 4, Evaluate the model" ] }, { "cell_type": "code", "execution_count": 131, "id": "c8c7b32c", "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "%config InlineBackend.figure_format = 'svg'\n", "\n", "import
matplotlib.pyplot as plt\n", "\n", "def plot_metric(dfhistory, metric):\n", " train_metrics = dfhistory[\"train_\"+metric]\n", " val_metrics = dfhistory['val_'+metric]\n", " epochs = range(1, len(train_metrics) + 1)\n", " plt.plot(epochs, train_metrics, 'bo--')\n", " plt.plot(epochs, val_metrics, 'ro-')\n", " plt.title('Training and validation '+ metric)\n", " plt.xlabel(\"Epochs\")\n", " plt.ylabel(metric)\n", " plt.legend([\"train_\"+metric, 'val_'+metric])\n", " plt.show()" ] }, { "cell_type": "code", "execution_count": 132, "id": "a1acb98b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_metric(dfhistory,\"loss\")" ] }, { "cell_type": "markdown", "id": "ea4fc554", "metadata": {}, "source": [ "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h3sbs8ryajj20h20a1gly.jpg)" ] }, { "cell_type": "markdown", "id": "173b66a2", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": null, "id": "44f94173", "metadata": {}, "outputs": [], "source": [ "plot_metric(dfhistory,\"auc\")" ] }, { "cell_type": "markdown", "id": "a8e11ad8", "metadata": {}, "source": [ "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h3sbsf8b1wj20f30a70t3.jpg)" ] }, { "cell_type": "code", "execution_count": 134, "id": "4ae22dd7", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 10/10 [00:00<00:00, 32.40it/s, val_auc=0.639, val_loss=0.684]\n" ] }, { "data": { "text/plain": [ "{'val_loss': 0.6842133283615113, 'val_auc': 0.6392135620117188}" ] }, "execution_count": 134, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.evaluate(dl_val)" ] }, { "cell_type": "markdown", "id": "685f701c", "metadata": {}, "source": [ "{'val_loss': 0.6842133283615113, 'val_auc': 0.6392135620117188}" ] }, { "cell_type": "code", "execution_count": null, "id": "1cafbe6d", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "84830511", "metadata": {}, "source": [ "### 4,使用模型" ] }, { "cell_type": "code", "execution_count": 135, "id": "741987f7", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.6392135469811272\n" ] } ], "source": [ "labels = torch.tensor([x[\"label\"] for x in ds_val])\n", "preds = model.predict(dl_val)\n", "val_auc = roc_auc_score(labels.cpu().numpy(),preds.cpu().numpy())\n", "print(val_auc)" ] }, { "cell_type": "markdown", "id": "0da487f4", "metadata": {}, "source": [ "0.6392135469811272" ] }, { "cell_type": "code", 
"execution_count": null, "id": "caa74c34", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "21d58167", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "f0788cbf", "metadata": {}, "source": [ "### 5, 保存模型" ] }, { "cell_type": "code", "execution_count": 136, "id": "46d2aea7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 136, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.save(model.net.state_dict(),\"best_din.pt\")\n", "net_clone = create_net()\n", "net_clone.load_state_dict(torch.load(\"best_din.pt\"))" ] }, { "cell_type": "code", "execution_count": 137, "id": "5dfb7b45", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.6392135469811272\n" ] } ], "source": [ "net_clone.eval()\n", "labels = torch.tensor([x[\"label\"] for x in ds_val])\n", "preds = torch.cat([net_clone(x).data for x in dl_val]) \n", "val_auc = roc_auc_score(labels.cpu().numpy(),preds.cpu().numpy())\n", "print(val_auc)" ] }, { "cell_type": "markdown", "id": "ac7b1907", "metadata": {}, "source": [ "0.6392135469811272" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 5 }