{ "cells": [ { "cell_type": "markdown", "id": "3eba4e9b", "metadata": {}, "source": [ "# streamlit深入浅出 " ] }, { "cell_type": "markdown", "id": "f5203f34", "metadata": {}, "source": [ "Streamlit是一个快速构建数据分析和机器学习Web页面的开源Python库。\n", "\n", "英文说明:A faster way to build and share data apps\n", "\n", "\n", "先看一个极简的效果,将一个文本情感分类的模型部署在了HuggingFace的Space托管页面了。\n", "\n", "效果如下: https://www.bilibili.com/video/bv1bN4y1G7ws\n", "\n", "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h567xz046gj20l80dkt8y.jpg) \n", "\n", "\n", "\n", "大家猜猜做出这个效果需要多少个行代码?100行? 300行?No,全部代码仅需10行,如下所示。\n", "\n", "\n", "```python\n", "import streamlit as st\n", "from transformers import pipeline\n", "\n", "st.title('Text Classification')\n", "pipe = pipeline(\"text-classification\")\n", "text = st.text_area(\"Enter some text:\")\n", "\n", "if text:\n", " out = pipe(text)\n", " st.json(out)\n", "```\n", "\n", "公众号算法美食屋后台回复关键词:**streamlit**,获取本文源代码 和 HuggingFace部署的TextClassification和FasterRCNN演示项目地址。\n", "\n", "\n", "参考资料:\n", "\n", "* 官方文档:https://docs.streamlit.io/\n", "\n", "* 范例长廊:https://streamlit.io/gallery (可以右边查看代码)\n", "\n", "* 《streamlit 教程》:https://zhuanlan.zhihu.com/p/448912854\n", "\n", "* 《算法不会前端,也可以做出好看的界面-Streamlit》: https://zhuanlan.zhihu.com/p/469582149\n" ] }, { "cell_type": "code", "execution_count": null, "id": "a6fc5907", "metadata": {}, "outputs": [], "source": [ "# 安装\n", "#!pip install streamlit -i https://pypi.tuna.tsinghua.edu.cn/simple\n", "\n", "#备注,需要python3.7及以上版本。\n", "\n", "# 环境测试\n", "#streamlit hello " ] }, { "cell_type": "markdown", "id": "defd7766", "metadata": {}, "source": [ "## 一,HelloWorld范例" ] }, { "cell_type": "code", "execution_count": null, "id": "ac94ca5b", "metadata": {}, "outputs": [], "source": [ "%%writefile demo.py\n", "import streamlit as st \n", "st.write(\"hello world\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "12f9fb10", "metadata": {}, "outputs": [], "source": [ "!streamlit run demo.py --server.port=8085 " ] }, { "cell_type": 
"markdown", "id": "3358bc02", "metadata": {}, "source": [ "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h565piz8msj20oq0ckq36.jpg)\n" ] }, { "cell_type": "markdown", "id": "9d25b56b", "metadata": {}, "source": [ "## 二,MarkDown范例" ] }, { "cell_type": "markdown", "id": "49fe56b9", "metadata": {}, "source": [ "支持常用的markdown展示\n", "\n", "* st.markdown: 按照markdown语法呈现内容\n", "\n", "* st.header\n", "\n", "* st.subheader\n", "\n", "* st.code\n", "\n", "* st.caption: 注释说明\n", "\n", "* st.text\n", "\n", "* st.latex\n" ] }, { "cell_type": "code", "execution_count": null, "id": "a2ef5253", "metadata": {}, "outputs": [], "source": [ "%%writefile demo.py\n", "import streamlit as st\n", "\n", "# markdown\n", "\n", "st.title('streamlit极简教程')\n", "\n", "st.markdown('### 一. 安装')\n", "\n", "st.text('和安装其他包一样,安装 streamlit 非常简单,一条命令即可')\n", "code1 = '''pip install streamlit'''\n", "st.code(code1, language='bash')\n", "st.caption(\"需要python3.7以及以上环境\")\n", "\n", "\n", "st.markdown('### 二. 使用')\n", "\n", "\n", "st.markdown('#### 1 生成 Markdown 文档')\n", "\n", "code2 = '''import streamlit as st\n", "st.markdown('Streamlit Demo')\n", "st.header('标题')\n", "st.text('普通文本')\n", "'''\n", "st.code(code2, language='python')\n", "\n", "\n", "st.markdown('#### 2 生成 图表')\n", "\n", "code3 = '''import streamlit as st\n", "import pandas as pd \n", "chart_data = pd.DataFrame(\n", " np.random.randn(20, 3),\n", " columns=['a', 'b', 'c'])\n", "st.line_chart(chart_data)'''\n", "st.code(code2, language='python')\n", "\n", "\n", "st.markdown('### 三. 
Run')\n", "\n", "code4 = '''streamlit run demo.py'''\n", "st.code(code4, language='bash')\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "b947e404", "metadata": {}, "outputs": [], "source": [ "!streamlit run demo.py --server.port=8085 " ] }, { "cell_type": "markdown", "id": "7d820617", "metadata": {}, "source": [ "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h565v5e6atj20kr0kbmy0.jpg)" ] }, { "cell_type": "code", "execution_count": null, "id": "8863692e", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "2c9be9b9", "metadata": {}, "source": [ "## 3. Chart Examples" ] }, { "cell_type": "markdown", "id": "6a118088", "metadata": {}, "source": [ "The following data and chart elements are supported:\n", "\n", "* st.table\n", "\n", "* st.dataframe\n", "\n", "* st.metric\n", "\n", "* st.json\n", "\n", "* st.line_chart\n", "\n", "* st.bar_chart\n", "\n", "* st.area_chart\n", "\n", "* st.map\n", "\n", "* st.pyplot: a matplotlib figure\n", "\n", "* st.plotly_chart: a plotly figure\n", "\n", "and more \n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "e04170a3", "metadata": {}, "outputs": [], "source": [ "%%writefile demo.py\n", "import streamlit as st\n", "import numpy as np \n", "import pandas as pd \n", "import plotly.express as px \n", "\n", "\n", "st.title('Streamlit chart examples')\n", "\n", "st.header(\"1. Table/DataFrame demo\")\n", "\n", "df = pd.DataFrame(\n", "    np.random.randn(10, 5),\n", "    columns=['col %d' % (i+1) for i in range(5)]\n", ")\n", "\n", "#st.table(df)\n", "st.dataframe(df.style.highlight_max(axis=0))\n", "\n", "\n", "st.header(\"2. Metrics\")\n", "col1, col2, col3 = st.columns(3)\n", "col1.metric(\"Temperature\", \"70 °F\", \"1.2 °F\")\n", "col2.metric(\"Wind\", \"9 mph\", \"-8%\")\n", "col3.metric(\"Humidity\", \"86%\", \"4%\")\n", "\n", "\n", "\n", "st.header(\"3. Built-in charts\")\n", "\n", "st.subheader(\"1. Line chart\")\n", "\n", "chart_data = pd.DataFrame(\n", "    np.random.randn(20, 3),\n", "    columns=['a', 'b', 'c'])\n", "\n", 
"st.line_chart(chart_data)\n", "\n", "\n", "st.subheader(\"2,面积图\")\n", "\n", "chart_data = pd.DataFrame(\n", " np.random.randn(20, 3),\n", " columns = ['a', 'b', 'c'])\n", "\n", "st.area_chart(chart_data)\n", "\n", "st.subheader(\"3,柱形图\")\n", "\n", "chart_data = pd.DataFrame(\n", " np.random.randn(50, 3),\n", " columns = [\"a\", \"b\", \"c\"])\n", "st.bar_chart(chart_data)\n", "\n", "st.subheader(\"4,地图\")\n", "\n", "chart_data = pd.DataFrame(\n", " np.random.randn(1000, 2) / [50, 50] + [37.76, -122.4],\n", " columns=['lat', 'lon']\n", ")\n", "st.map(chart_data)\n", "\n", "\n", "st.header(\"四,plotly图表\")\n", "\n", "fig = px.line(data_frame=px.data.stocks(),x=\"date\",y=[\"GOOG\",\"AAPL\",\"AMZN\",\"FB\"]) \n", "\n", "st.plotly_chart(fig)\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "66782db0", "metadata": {}, "outputs": [], "source": [ "!streamlit run demo.py --server.port=8085 " ] }, { "cell_type": "markdown", "id": "4d9579a2", "metadata": {}, "source": [ "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h5664m9xe3j20ka0gjjsm.jpg)\n", "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h5665pmum4j20lk0iygmr.jpg)\n", "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h56678j9yhj20m20c774s.jpg)\n", "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h5667v5cryj20m20bvgma.jpg)\n", "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h5669oiqolj20m20gj0tg.jpg)\n", "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h566917ohaj20m20eb3z7.jpg)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "f42e5267", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "2f7ad852", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "3f2920d1", "metadata": {}, "source": [ "## 四,控件范例" ] }, { "cell_type": "markdown", "id": "e7839c14", "metadata": {}, "source": [ "streamlit支持非常丰富的交互式输入控件。\n", "\n", "值得注意的是,当改变任何一个输入时,整个网页会重新计算和渲染。\n", "\n", "* button:按钮\n", "* download_button:文件下载\n", "* 
file_uploader: file upload\n", "* checkbox: checkbox\n", "* radio: radio buttons (single choice)\n", "* selectbox: dropdown (single choice)\n", "* multiselect: dropdown (multiple choice)\n", "* slider: slider\n", "* select_slider: slider over a list of options\n", "* text_input: single-line text input\n", "* text_area: multi-line text input\n", "* number_input: numeric input with plus/minus buttons\n", "* date_input: date picker\n", "* time_input: time picker\n", "* color_picker: color picker\n", "\n", "The most frequently used widgets are demonstrated below:\n", "\n", "* 1. button \n", "\n", "* 2. selectbox\n", "\n", "* 3. number_input\n", "\n", "* 4. slider\n", "\n", "* 5. text_input\n", "\n", "* 6. text_area \n", "\n", "* 7. download_button \n", "\n", "* 8. file_uploader \n" ] }, { "cell_type": "code", "execution_count": null, "id": "63ba3bc8", "metadata": {}, "outputs": [], "source": [ "%%writefile demo.py\n", "import streamlit as st \n", "import plotly.express as px \n", "import time\n", "import pandas as pd \n", "\n", "st.title('Streamlit widget examples')\n", "\n", "st.header(\"1. button\")\n", "\n", "# a button is often used to kick off a slow piece of code\n", "if st.button(\"Start counting sheep\"):\n", "    msg = st.empty()  # st.empty serves as a placeholder\n", "    for i in range(1,11):\n", "        msg.write(\"{} sheep...\".format(i))\n", "        time.sleep(0.3)\n", "else:\n", "    pass #st.stop\n", "\n", "st.header(\"2. selectbox\") \n", "\n", "stock = st.selectbox(label = \"Choose a stock\",options=[\"GOOG\",\"AAPL\",\"AMZN\",\"FB\"])\n", "\n", "st.write('You selected:', stock)\n", "\n", "fig = px.line(data_frame=px.data.stocks(),x=\"date\",y=[stock]) \n", "\n", "st.plotly_chart(fig)\n", "\n", "st.header(\"3. number_input\") \n", "\n", "st.write(\"input x and y to eval x+y:\")\n", "x = st.number_input(\"x\",min_value=-10000,max_value=10000)\n", "y = st.number_input(\"y\",min_value=0,max_value=8)\n", "st.write('x+y=', x+y)\n", "\n", "\n", "st.header(\"4. slider\") \n", "\n", "st.write(\"slide to choose your age:\")\n", "age = st.slider(label=\"age\",min_value=0,max_value=120)\n", "\n", "st.write('your age is ', age)\n", "\n", "\n", "st.header(\"5. text_input\") \n", "\n", "st.write(\"what's your name\")\n", "name = st.text_input(label=\"name\",max_chars=100)\n", 
"st.write(\"your name is \",name)\n", "\n", "st.header(\"6,text_area\") \n", "\n", "st.write(\"give an introduction of yourself\")\n", "name = st.text_area(label=\"introduction\",max_chars=1024)\n", "\n", "\n", "st.header(\"7,download_button\") \n", "\n", "@st.cache\n", "def save_csv():\n", " # IMPORTANT: Cache the conversion to prevent computation on every rerun\n", " df = px.data.stocks()\n", " return df.to_csv().encode('utf-8')\n", "\n", "csv = save_csv()\n", "\n", "st.download_button(\n", " label=\"Download stock data\",\n", " data=csv,\n", " file_name='stocks.csv',\n", " mime='text/csv',\n", " )\n", "\n", "st.header(\"8,file_uploader\") \n", "\n", "csv_file = st.file_uploader(\"Choose a csv file\")\n", "\n", "if csv_file is not None:\n", " try:\n", " dfstocks = pd.read_csv(csv_file)\n", " st.table(dfstocks)\n", " except Exception as err:\n", " st.write(err)\n", " \n", "image_file = st.file_uploader(\"choose a image file(jpg/png)\")\n", "if image_file is not None:\n", " try:\n", " st.image(image_file)\n", " except Exception as err:\n", " st.write(err)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "1f836a59", "metadata": {}, "outputs": [], "source": [ "!streamlit run demo.py --server.port=8085 " ] }, { "cell_type": "markdown", "id": "1e7e33de", "metadata": {}, "source": [ "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h566dgrh7nj20m20cjglv.jpg)\n", "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h566f6mzl7j20m20bgjrp.jpg)\n", "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h566fyzy7xj20m20g8glv.jpg)\n", "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h566hfzsezj20m20e1dg3.jpg)\n", "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h566iopiy3j20m20d074p.jpg)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "2c9356d6", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "cbf908c1", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "e43400e7", "metadata": {}, "source": [ 
"### 五,布局范例" ] }, { "cell_type": "markdown", "id": "1b2a7faa", "metadata": {}, "source": [ "Streamlit 是自上而下渲染的,组件在页面上的排列顺序与代码的执行顺序一致。\n", "\n", "可以应用如下布局组件实现非自上而下的布局。\n", "\n", "* st.sidebar:侧边栏\n", "\n", "* st.columns:列布局\n", "\n", "* st.expander:隐藏\n", "\n", "* st.empty:占位符,可以后续更新其中内容。\n", "\n", "* st.container: 容器占位符,可以后续往其中添加内容。\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "48a02971", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "2be021bd", "metadata": {}, "outputs": [], "source": [ "%%writefile demo.py\n", "import streamlit as st \n", "import time\n", "import pandas as pd \n", "\n", "st.title('streamlit布局范例')\n", "\n", "st.header(\"1,sidebar\")\n", "st.text(\"see the left side\")\n", "with st.sidebar:\n", " st.subheader(\"配置参数\")\n", " optim = st.multiselect(label = \"optimizer:\",options = [\"SGD\",\"Adam\",\"AdamW\"])\n", " lr = st.slider(label=\"lr:\",min_value=1e-5,max_value=0.1)\n", " early_stopping = st.checkbox(label = \"early_stopping\",value=True)\n", " batch_size = st.number_input(label = \"batch_size\",min_value=1,max_value=64)\n", "\n", "\n", "st.header(\"2,columns\")\n", "col1, col2, col3 = st.columns(3)\n", "col1.metric(\"accuracy\", \"0.82\", \"+32%\")\n", "col2.metric(\"AUC\", \"0.89\", \"-8%\")\n", "col3.metric(\"recall\", \"0.92\", \"+4%\")\n", "\n", "st.header(\"3,expander\")\n", "st.line_chart(data = [1,1,2,3,5,8,13,21,33,54])\n", "with st.expander(label=\"see explanation\"):\n", " st.text(\"This is the Fibonacci sequence\")\n", " st.text(\"You can see more about it in below link\")\n", " st.markdown(\"[](https://baike.baidu.com/item/%E6%96%90%E6%B3%A2%E9%82%A3%E5%A5%91%E6%95%B0%E5%88%97/99145?fr=aladdin)\")\n", "\n", "st.header(\"4,empty\")\n", "#st.empty可以作为占位符\n", "if st.button(\"Start count sheep\"):\n", " msg = st.empty() #st.empty可以作为占位符\n", " for i in range(1,11):\n", " msg.write(\"{} sheep...\".format(i))\n", " time.sleep(0.3)\n", "else:\n", " pass 
#st.stop\n", "\n", "\n", "st.header(\"5. container\")\n", "\n", "container = st.container()\n", "container.write(\"1: This should be in the container\")\n", "st.write(\"2: This should be outside the container\")\n", "container.write(\"3: This should be in the container too\")\n", "container.bar_chart(data = [1,1,2,3,5,8,13,21,34,55])\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "51ae6a7a", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "c052d775", "metadata": {}, "outputs": [], "source": [ "!streamlit run demo.py --server.port=8085 " ] }, { "cell_type": "code", "execution_count": null, "id": "f3189575", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "40dfafab", "metadata": {}, "source": [ "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h566l8xub3j20z60c174t.jpg)\n", "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h566mr402jj20mi0gzgm5.jpg)\n", "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h566p47ekrj20mi0kqjrv.jpg)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "c7702a0d", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "64178da7", "metadata": {}, "source": [ "## 6. Status Examples" ] }, { "cell_type": "markdown", "id": "68351554", "metadata": {}, "source": [ "Streamlit provides the following status elements.\n", "\n", "* st.progress: progress bar, such as a game loading bar\n", "* st.spinner: waiting indicator\n", "\n", "\n", "* st.info: shows an informational message\n", "* st.warning: shows a warning message\n", "* st.success: shows a success message\n", "* st.error: shows an error message\n", "* st.exception: shows an exception\n", "\n", "\n", "* st.balloons: balloons float up from the bottom of the page, as a celebration\n", "* st.snow: snow falls across the page, as a celebration\n" ] }, { "cell_type": "code", "execution_count": null, "id": "b4d81161", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "d3861a15", "metadata": {}, "outputs": [], "source": [ "%%writefile demo.py\n", "import streamlit as st \n", "import time\n", "\n", "\n", "if st.button(\"Start counting sheep\"):\n", "    with st.spinner('Wait for it...'):\n", "        bar = st.progress(0)\n", 
" msg = st.empty() #st.empty可以作为占位符\n", " max_num = 20\n", " for i in range(1,max_num+1):\n", " msg.write(\"{} sheep...\".format(i))\n", " time.sleep(0.3)\n", " bar.progress((i*100)//max_num)\n", " time.sleep(1)\n", " st.success(\"You count 20 sheep! congratulations\")\n", " st.balloons()\n", " time.sleep(1)\n", " st.snow()\n", "else:\n", " pass \n" ] }, { "cell_type": "code", "execution_count": null, "id": "506d4b76", "metadata": {}, "outputs": [], "source": [ "!streamlit run demo.py --server.port=8085 " ] }, { "cell_type": "markdown", "id": "29cec8bd", "metadata": {}, "source": [ "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h566s547rfj20nc0kqwf2.jpg)" ] }, { "cell_type": "code", "execution_count": null, "id": "60fa7eec", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "13895334", "metadata": {}, "source": [ "### 七,综合范例" ] }, { "cell_type": "markdown", "id": "7f756845", "metadata": {}, "source": [ "下面示范一个用streamlit实现一个FasterRCNN的网页交互APP范例。" ] }, { "cell_type": "code", "execution_count": null, "id": "d2130e58", "metadata": {}, "outputs": [], "source": [ "%%writefile demo.py\n", "import numpy as np\n", "from PIL import Image,ImageColor,ImageDraw,ImageFont \n", "import torch\n", "from torch import nn\n", "\n", "import torchvision\n", "from torchvision import datasets, models, transforms\n", "\n", "import streamlit as st \n", "\n", "# 可视化函数\n", "def plot_detection(image,prediction,idx2names,min_score = 0.8):\n", " image_result = image.copy()\n", " boxes,labels,scores = prediction['boxes'],prediction['labels'],prediction['scores']\n", " draw = ImageDraw.Draw(image_result) \n", " for idx in range(boxes.shape[0]):\n", " if scores[idx] >= min_score:\n", " x1, y1, x2, y2 = boxes[idx][0], boxes[idx][1], boxes[idx][2], boxes[idx][3]\n", " name = idx2names.get(str(labels[idx].item()))\n", " score = scores[idx]\n", " draw.rectangle((x1,y1,x2,y2), fill=None, outline ='lawngreen',width = 2)\n", " 
draw.text((x1,y1),name+\":\\n\"+str(round(score.item(),2)),fill=\"red\")\n", "    return image_result \n", "\n", "\n", "# Load the model once and reuse it across reruns\n", "# (on Streamlit >= 1.18, st.cache_resource is the recommended replacement for st.cache)\n", "@st.cache(allow_output_mutation=True)\n", "def load_model():\n", "    num_classes = 91\n", "    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True,num_classes = num_classes)\n", "    if torch.cuda.is_available():\n", "        model.to(\"cuda:0\")\n", "    model.eval()\n", "    model.idx2names = {'0': 'background', '1': 'person', '2': 'bicycle', '3': 'car', \n", "        '4': 'motorcycle', '5': 'airplane', '6': 'bus', '7': 'train', '8': 'truck', '9': 'boat', \n", "        '10': 'traffic light', '11': 'fire hydrant', '13': 'stop sign', \n", "        '14': 'parking meter', '15': 'bench', '16': 'bird', '17': 'cat',\n", "        '18': 'dog', '19': 'horse', '20': 'sheep', '21': 'cow', '22': 'elephant', \n", "        '23': 'bear', '24': 'zebra', '25': 'giraffe', '27': 'backpack', \n", "        '28': 'umbrella', '31': 'handbag', '32': 'tie', '33': 'suitcase',\n", "        '34': 'frisbee', '35': 'skis', '36': 'snowboard', '37': 'sports ball',\n", "        '38': 'kite','39': 'baseball bat', '40': 'baseball glove', '41': 'skateboard',\n", "        '42': 'surfboard', '43': 'tennis racket', '44': 'bottle', '46': 'wine glass', \n", "        '47': 'cup', '48': 'fork', '49': 'knife', '50': 'spoon', '51': 'bowl',\n", "        '52': 'banana', '53': 'apple', '54': 'sandwich', '55': 'orange', \n", "        '56': 'broccoli', '57': 'carrot', '58': 'hot dog', '59': 'pizza',\n", "        '60': 'donut', '61': 'cake', '62': 'chair', '63': 'couch', \n", "        '64': 'potted plant', '65': 'bed', '67': 'dining table',\n", "        '70': 'toilet', '72': 'tv', '73': 'laptop', '74': 'mouse', \n", "        '75': 'remote', '76': 'keyboard', '77': 'cell phone', \n", "        '78': 'microwave', '79': 'oven', '80': 'toaster', \n", "        '81': 'sink', '82': 'refrigerator', '84': 'book',\n", "        '85': 'clock', '86': 'vase', '87': 'scissors',\n", "        '88': 'teddy bear', '89': 'hair drier', '90': 'toothbrush'}\n", "    return model \n", "\n", "def predict_detection(model,image_path,min_score=0.8):\n", "    # Prepare the input\n", "    inputs = 
[]\n", "    img = Image.open(image_path).convert(\"RGB\")\n", "    img_tensor = torch.from_numpy(np.array(img)/255.).permute(2,0,1).float()\n", "    if torch.cuda.is_available():\n", "        img_tensor = img_tensor.cuda()\n", "    inputs.append(img_tensor) \n", "\n", "    # Run prediction\n", "    with torch.no_grad():\n", "        predictions = model(inputs)\n", "\n", "    # Visualize the result\n", "    img_result = plot_detection(img,predictions[0],\n", "        model.idx2names,min_score = min_score)\n", "    return img_result\n", "\n", "st.title(\"FasterRCNN Demo\")\n", "\n", "st.header(\"FasterRCNN Input:\")\n", "image_file = st.file_uploader(\"Upload an image file (jpg/png) to predict:\")\n", "if image_file is not None:\n", "    try:\n", "        st.image(image_file)\n", "    except Exception as err:\n", "        st.write(err)\n", "else:\n", "    # fall back to a local sample image when nothing is uploaded\n", "    image_file = \"horseman.png\"\n", "    st.image(image_file)\n", "\n", "min_score = st.slider(label=\"choose the min_score parameter:\",min_value=0.1,max_value=0.98,value=0.8)\n", "\n", "st.header(\"FasterRCNN Prediction:\")\n", "with st.spinner('waiting for prediction...'):\n", "    model = load_model()\n", "    img_result = predict_detection(model,image_file,min_score=min_score)\n", "    st.image(img_result)\n" ] }, { "cell_type": "markdown", "id": "6470d333", "metadata": {}, "source": [ "The live demo page is here:\n", "\n", "https://huggingface.co/spaces/lyhue1991/FasterRCNNDemo" ] }, { "cell_type": "markdown", "id": "dc43f867", "metadata": {}, "source": [ "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h567w534hnj20l80lv768.jpg)\n", "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h567w6oiqwj20l80k2tag.jpg)" ] }, { "cell_type": "code", "execution_count": null, "id": "8ff2bb14", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "285ca16e", "metadata": {}, "source": [ "## 8. Deploying to HuggingFace" ] }, { "cell_type": "markdown", "id": "2ffc4ca2", "metadata": {}, "source": [ "To make it easy to show a model app to collaborators, the Streamlit app can be deployed to a HuggingFace Space, a hosting service that is free to use.\n", "\n", "The steps are:\n", "\n", "1. Sign up for a HuggingFace account: https://huggingface.co/join\n", 
"\n", "2,在space空间中创建项目:https://huggingface.co/spaces\n", "\n", "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h5671duaw4j20ze0le0v4.jpg) " ] }, { "cell_type": "markdown", "id": "6a5a9650", "metadata": {}, "source": [ "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h567416p9tj20ku0lvq44.jpg)\n" ] }, { "cell_type": "markdown", "id": "593446bd", "metadata": {}, "source": [ "3,创建好的项目有一个Readme文档,根据说明操作即可。" ] }, { "cell_type": "markdown", "id": "b502c153", "metadata": {}, "source": [ "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h56776qqloj20m20lv0u6.jpg)" ] }, { "cell_type": "code", "execution_count": null, "id": "cc757a7e", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "97da5839", "metadata": {}, "source": [ "公众号算法美食屋后台回复关键词:**streamlit**,获取本文源代码 和 HuggingFace部署的TextClassification和FasterRCNN演示项目地址。" ] }, { "cell_type": "markdown", "id": "348cf3a2", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": null, "id": "875bd781", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": false, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }