wework_channel.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. import io
  2. import os
  3. import random
  4. import tempfile
  5. import threading
  6. os.environ['ntwork_LOG'] = "ERROR"
  7. import ntwork
  8. import requests
  9. import uuid
  10. from bridge.context import *
  11. from bridge.reply import *
  12. from channel.chat_channel import ChatChannel
  13. from channel.wework.wework_message import *
  14. from channel.wework.wework_message import WeworkMessage
  15. from common.singleton import singleton
  16. from common.log import logger
  17. from common.time_check import time_checker
  18. from common.utils import compress_imgfile, fsize
  19. from config import conf
  20. from channel.wework.run import wework
  21. from channel.wework import run
  22. from PIL import Image
  def get_wxid_by_name(room_members, group_wxid, name):
      """Find a group member's user id from a display name.

      Looks at the ``member_list`` stored under ``group_wxid`` in
      ``room_members`` and returns the ``user_id`` of the first entry whose
      ``room_nickname`` or ``username`` equals ``name``.

      Returns ``None`` when the group is not in ``room_members`` or no member
      matches.
      """
      if group_wxid not in room_members:
          return None

      matching_ids = (
          entry['user_id']
          for entry in room_members[group_wxid]['member_list']
          if entry['room_nickname'] == name or entry['username'] == name
      )
      # First hit wins; fall back to None when nothing matched.
      return next(matching_ids, None)
  def download_and_compress_image(url, filename, quality=30, timeout=30):
      """Download an image, compress it if it is too large, and save it as PNG.

      Args:
          url: Address of the image to download.
          filename: Base name (without extension) of the file written to the
              ``tmp`` directory under the current working directory.
          quality: Not used by this function; kept so existing callers that
              pass it keep working.
          timeout: Seconds ``requests`` waits for the connection and for each
              read before giving up, so a stalled server cannot block forever.

      Returns:
          Path of the saved ``.png`` file.

      Raises:
          requests.RequestException: On a network failure, a timeout, or an
              HTTP error status.
      """
      directory = os.path.join(os.getcwd(), "tmp")
      # exist_ok tolerates the directory being created between a check and
      # the create call, which a separate exists()/makedirs() pair does not.
      os.makedirs(directory, exist_ok=True)

      # Stream the body into memory; the with-block releases the connection
      # even when the download fails half way.
      image_storage = io.BytesIO()
      with requests.get(url, stream=True, timeout=timeout) as pic_res:
          # Fail on 4xx/5xx here instead of handing an error page to PIL.
          pic_res.raise_for_status()
          for block in pic_res.iter_content(1024):
              image_storage.write(block)

      # Compress when the image is 10 MB or larger.
      sz = fsize(image_storage)
      if sz >= 10 * 1024 * 1024:
          logger.info("[wework] image too large, ready to compress, sz={}".format(sz))
          image_storage = compress_imgfile(image_storage, 10 * 1024 * 1024 - 1)
          logger.info("[wework] image compressed, sz={}".format(fsize(image_storage)))

      # Rewind the buffer before handing it to PIL.
      image_storage.seek(0)

      image = Image.open(image_storage)
      image_path = os.path.join(directory, f"{filename}.png")
      image.save(image_path, "png")
      return image_path
  def download_video(url, filename, timeout=30):
      """Download a video into the ``tmp`` directory, giving up past 30 MB.

      Args:
          url: Address of the video to download.
          filename: Base name (without extension) of the ``.mp4`` file written
              to the ``tmp`` directory under the current working directory.
          timeout: Seconds ``requests`` waits for the connection and for each
              read before giving up.

      Returns:
          Path of the downloaded file, or ``None`` when the video is larger
          than 30 MB (the partial file is removed in that case).

      Raises:
          requests.RequestException: On a network failure, a timeout, or an
              HTTP error status.
      """
      size_limit = 30 * 1024 * 1024

      directory = os.path.join(os.getcwd(), "tmp")
      # exist_ok tolerates the directory being created between a check and
      # the create call.
      os.makedirs(directory, exist_ok=True)
      video_path = os.path.join(directory, f"{filename}.mp4")

      total_size = 0
      too_large = False
      # The with-block closes the streamed response even when we stop early.
      with requests.get(url, stream=True, timeout=timeout) as response:
          # Fail on 4xx/5xx instead of saving an error page as a video.
          response.raise_for_status()
          with open(video_path, 'wb') as f:
              for block in response.iter_content(1024):
                  total_size += len(block)
                  if total_size > size_limit:
                      too_large = True
                      break
                  f.write(block)

      if too_large:
          logger.info("[WX] Video is larger than 30MB, skipping...")
          # Do not leave the truncated download behind on disk.
          os.remove(video_path)
          return None
      return video_path
  def create_message(wework_instance, message, is_group):
      """Wrap a raw message in a ``WeworkMessage`` and return it.

      ``wework_instance`` and ``is_group`` are passed through to the
      ``WeworkMessage`` constructor; any exception it raises propagates to the
      caller.
      """
      chat_kind = '群聊' if is_group else '单聊'
      logger.debug(f"正在为{chat_kind}创建 WeworkMessage")
      wrapped = WeworkMessage(message, wework=wework_instance, is_group=is_group)
      logger.debug(f"cmsg:{wrapped}")
      return wrapped
  def handle_message(cmsg, is_group):
      """Route a parsed message to the group or single-chat handler.

      Calls ``WeworkChannel().handle_group`` when ``is_group`` is true and
      ``WeworkChannel().handle_single`` otherwise, logging before and after.
      """
      chat_kind = '群聊' if is_group else '单聊'
      logger.debug(f"准备用 WeworkChannel 处理{chat_kind}消息")
      channel = WeworkChannel()
      handler = channel.handle_group if is_group else channel.handle_single
      handler(cmsg)
      logger.debug(f"已用 WeworkChannel 处理完{chat_kind}消息")
  def _check(func):
      """Decorate a message handler so that stale messages are skipped.

      The returned wrapper takes ``(self, cmsg)``:

      * if ``cmsg.create_time`` is ``None`` the handler is always called;
      * if it is more than 60 seconds older than the current time the message
        is logged and dropped, and the wrapper returns ``None``;
      * otherwise the handler is called and its result returned.

      ``functools.wraps`` keeps the handler's name and docstring on the wrapper.
      """
      # Imported here so the decorator does not rely on these names reaching
      # the module through star imports.
      import functools
      import time

      @functools.wraps(func)
      def wrapper(self, cmsg: "ChatMessage"):
          msg_id = cmsg.msg_id
          create_time = cmsg.create_time  # message timestamp
          if create_time is None:
              return func(self, cmsg)
          if int(create_time) < int(time.time()) - 60:  # older than one minute
              logger.debug("[WX]history message {} skipped".format(msg_id))
              return None
          return func(self, cmsg)

      return wrapper
  @wework.msg_register(
      [ntwork.MT_RECV_TEXT_MSG, ntwork.MT_RECV_IMAGE_MSG, 11072, ntwork.MT_RECV_LINK_CARD_MSG,
       ntwork.MT_RECV_FILE_MSG, ntwork.MT_RECV_VOICE_MSG])
  def all_msg_handler(wework_instance: ntwork.WeWork, message):
      """Entry point for every registered incoming message type.

      Reads the conversation id from ``message['data']``, builds a
      ``WeworkMessage`` via ``create_message`` and schedules
      ``handle_message`` on a timer after a random delay of 1-2 seconds.
      Always returns ``None``.
      """
      logger.debug(f"收到消息: {message}")

      if 'data' not in message:
          return None

      payload = message['data']
      # Prefer conversation_id; fall back to room_conversation_id when the
      # key is absent.
      conversation_id = payload.get('conversation_id', payload.get('room_conversation_id'))
      if conversation_id is None:
          logger.debug("消息数据中无 conversation_id")
          return None

      # Conversation ids containing "R:" are treated as group chats.
      is_group = "R:" in conversation_id
      try:
          cmsg = create_message(wework_instance=wework_instance, message=message, is_group=is_group)
      except NotImplementedError as e:
          logger.error(f"[WX]{message.get('MsgId', 'unknown')} 跳过: {e}")
          return None

      wait_seconds = random.randint(1, 2)
      threading.Timer(wait_seconds, handle_message, args=(cmsg, is_group)).start()
      return None
  def accept_friend_with_retries(wework_instance, user_id, corp_id):
      """Accept a friend request and log the result.

      Makes a single ``accept_friend`` call on ``wework_instance``; despite the
      name, no retry is performed here. Returns ``None``.
      """
      outcome = wework_instance.accept_friend(user_id, corp_id)
      logger.debug(f'result:{outcome}')
  119. # @wework.msg_register(ntwork.MT_RECV_FRIEND_MSG)
  120. # def friend(wework_instance: ntwork.WeWork, message):
  121. # data = message["data"]
  122. # user_id = data["user_id"]
  123. # corp_id = data["corp_id"]
  124. # logger.info(f"接收到好友请求,消息内容:{data}")
  125. # delay = random.randint(1, 180)
  126. # threading.Timer(delay, accept_friend_with_retries, args=(wework_instance, user_id, corp_id)).start()
  127. #
  128. # return None
  def get_with_retry(get_func, max_retries=5, delay=5):
      """Call ``get_func`` until it returns a truthy value or attempts run out.

      Args:
          get_func: Zero-argument callable to invoke.
          max_retries: Maximum number of calls to make.
          delay: Seconds to wait between two consecutive attempts.

      Returns:
          The first truthy result, otherwise the last (falsy) result, or
          ``None`` when ``max_retries`` allows no attempt at all.
      """
      # Imported here so the helper does not rely on ``time`` reaching the
      # module through a star import.
      import time

      result = None
      attempt = 0
      while attempt < max_retries:
          result = get_func()
          if result:
              break
          attempt += 1
          logger.warning(f"获取数据失败,重试第{attempt}次······")
          # Only wait when another attempt follows; sleeping after the final
          # failure would just delay the return.
          if attempt < max_retries:
              time.sleep(delay)
      return result
  @singleton
  class WeworkChannel(ChatChannel):
      """Chat channel that receives and sends messages through ``wework``.

      ``startup`` logs in and dumps contact/room data to the ``tmp`` directory,
      ``handle_single`` / ``handle_group`` turn incoming messages into contexts,
      and ``send`` delivers a ``Reply`` according to its type.
      """

      NOT_SUPPORT_REPLYTYPE = []

      def __init__(self):
          super().__init__()

      def startup(self):
          """Log in, cache contacts/rooms/members as JSON files, then run forever.

          Terminates the process when contacts or rooms cannot be fetched.
          """
          smart = conf().get("wework_smart", True)
          wework.open(smart)
          logger.info("等待登录······")
          wework.wait_login()
          login_info = wework.get_login_info()
          self.user_id = login_info['user_id']
          self.name = login_info['nickname']
          logger.info(f"登录信息:>>>user_id:{self.user_id}>>>>>>>>name:{self.name}")
          logger.info("静默延迟60s,等待客户端刷新数据,请勿进行任何操作······")
          time.sleep(60)
          contacts = get_with_retry(wework.get_external_contacts)
          rooms = get_with_retry(wework.get_rooms)
          directory = os.path.join(os.getcwd(), "tmp")
          if not contacts or not rooms:
              logger.error("获取contacts或rooms失败,程序退出")
              ntwork.exit_()
              # ``os.exit`` does not exist and would raise AttributeError;
              # ``os._exit`` ends the process immediately.
              os._exit(0)
          # exist_ok tolerates the directory being created concurrently.
          os.makedirs(directory, exist_ok=True)
          # Save contacts and rooms as JSON files.
          with open(os.path.join(directory, 'wework_contacts.json'), 'w', encoding='utf-8') as f:
              json.dump(contacts, f, ensure_ascii=False, indent=4)
          with open(os.path.join(directory, 'wework_rooms.json'), 'w', encoding='utf-8') as f:
              json.dump(rooms, f, ensure_ascii=False, indent=4)
          # Map every room's conversation id to its member data.
          result = {}
          for room in rooms['room_list']:
              room_wxid = room['conversation_id']
              result[room_wxid] = wework.get_room_members(room_wxid)
          # Save the room -> members mapping as JSON.
          with open(os.path.join(directory, 'wework_room_members.json'), 'w', encoding='utf-8') as f:
              json.dump(result, f, ensure_ascii=False, indent=4)
          logger.info("wework程序初始化完成········")
          run.forever()

      @time_checker
      @_check
      def handle_single(self, cmsg: ChatMessage):
          """Handle a one-to-one message: log it, build a context, queue it."""
          if cmsg.from_user_id == cmsg.to_user_id:
              # ignore self reply
              return
          if cmsg.ctype == ContextType.VOICE:
              # Voice messages are dropped unless speech recognition is enabled.
              if not conf().get("speech_recognition"):
                  return
              logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
          elif cmsg.ctype == ContextType.IMAGE:
              logger.debug("[WX]receive image msg: {}".format(cmsg.content))
          elif cmsg.ctype == ContextType.PATPAT:
              logger.debug("[WX]receive patpat msg: {}".format(cmsg.content))
          elif cmsg.ctype == ContextType.TEXT:
              logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
          else:
              logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg))
          context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
          if context:
              self.produce(context)

      @time_checker
      @_check
      def handle_group(self, cmsg: ChatMessage):
          """Handle a group message: log it, build a context, queue it."""
          if cmsg.ctype == ContextType.VOICE:
              # Voice messages are dropped unless speech recognition is enabled.
              if not conf().get("speech_recognition"):
                  return
              logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
          elif cmsg.ctype == ContextType.IMAGE:
              logger.debug("[WX]receive image for group msg: {}".format(cmsg.content))
          elif cmsg.ctype in [ContextType.JOIN_GROUP, ContextType.PATPAT]:
              logger.debug("[WX]receive note msg: {}".format(cmsg.content))
          elif cmsg.ctype == ContextType.TEXT:
              pass
          else:
              logger.debug("[WX]receive group msg: {}".format(cmsg.content))
          context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
          if context:
              self.produce(context)

      # Unified send function; each channel implements its own and dispatches
      # on the reply's ``type`` field.
      def send(self, reply: Reply, context: Context):
          """Deliver ``reply`` to ``context["receiver"]`` according to its type.

          Handles text (with an optional leading ``@name`` line turned into an
          at-mention), error/info text, in-memory images, image URLs, video
          URLs and voice files. Other reply types are ignored.
          """
          logger.debug(f"context: {context}")
          receiver = context["receiver"]
          actual_user_id = context["msg"].actual_user_id
          if reply.type == ReplyType.TEXT or reply.type == ReplyType.TEXT_:
              match = re.search(r"^@(.*?)\n", reply.content)
              logger.debug(f"match: {match}")
              if match:
                  # Replace the leading "@name" line with a newline and
                  # at-mention the actual sender instead.
                  new_content = re.sub(r"^@(.*?)\n", "\n", reply.content)
                  at_list = [actual_user_id]
                  logger.debug(f"new_content: {new_content}")
                  wework.send_room_at_msg(receiver, new_content, at_list)
              else:
                  wework.send_text(receiver, reply.content)
              logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
          elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
              wework.send_text(receiver, reply.content)
              logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
          elif reply.type == ReplyType.IMAGE:  # image read from a file-like object
              image_storage = reply.content
              image_storage.seek(0)
              data = image_storage.read()
              # Write the bytes to a temporary file and close it before
              # passing its path to send_image.
              with tempfile.NamedTemporaryFile(delete=False) as temp:
                  temp_path = temp.name
                  temp.write(data)
              try:
                  wework.send_image(receiver, temp_path)
                  logger.info("[WX] sendImage, receiver={}".format(receiver))
              finally:
                  # Remove the temporary file even when sending fails.
                  os.remove(temp_path)
          elif reply.type == ReplyType.IMAGE_URL:  # image downloaded from the network
              img_url = reply.content
              filename = str(uuid.uuid4())
              # Download the image and save it as a local file.
              image_path = download_and_compress_image(img_url, filename)
              wework.send_image(receiver, file_path=image_path)
              logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
          elif reply.type == ReplyType.VIDEO_URL:
              video_url = reply.content
              filename = str(uuid.uuid4())
              video_path = download_video(video_url, filename)
              if video_path is None:
                  # download_video returns None when the video is too large.
                  wework.send_text(receiver, "抱歉,视频太大了!!!")
              else:
                  wework.send_video(receiver, video_path)
              logger.info("[WX] sendVideo, receiver={}".format(receiver))
          elif reply.type == ReplyType.VOICE:
              # Re-root the voice file name under <cwd>/tmp and send it as a file.
              current_dir = os.getcwd()
              voice_file = reply.content.split("/")[-1]
              reply.content = os.path.join(current_dir, "tmp", voice_file)
              wework.send_file(receiver, reply.content)
              logger.info("[WX] sendFile={}, receiver={}".format(reply.content, receiver))