app.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426
  1. import os
  2. import time
  3. import pdb
  4. import re
  5. import gradio as gr
  6. import spaces
  7. import numpy as np
  8. import sys
  9. import subprocess
  10. from huggingface_hub import snapshot_download
  11. import requests
  12. import argparse
  13. import os
  14. from omegaconf import OmegaConf
  15. import numpy as np
  16. import cv2
  17. import torch
  18. import glob
  19. import pickle
  20. from tqdm import tqdm
  21. import copy
  22. from argparse import Namespace
  23. import shutil
  24. import gdown
  25. import imageio
  26. import ffmpeg
  27. from moviepy.editor import *
  # Directory of this script (resolved to an absolute path); downloaded
  # checkpoints are stored in a "models" folder beneath it.
  _script_dir = os.path.dirname(__file__)
  ProjectDir = os.path.abspath(_script_dir)
  CheckpointsDir = os.path.join(ProjectDir, "models")
  del _script_dir
  def print_directory_contents(path):
      """Print the full path of every immediate subdirectory of ``path``.

      Regular files are skipped; entries are visited in ``os.listdir`` order.
      """
      candidates = (os.path.join(path, name) for name in os.listdir(path))
      for subdir in filter(os.path.isdir, candidates):
          print(subdir)
  def _hf_snapshot(repo_id, local_dir):
      """Download a full Hugging Face repo snapshot into ``local_dir`` (forced, no resume)."""
      snapshot_download(
          repo_id=repo_id,
          local_dir=local_dir,
          max_workers=8,
          local_dir_use_symlinks=True,
          force_download=True,
          resume_download=False,
      )


  def _fetch_url_to_file(url, file_path, timeout=60):
      """Stream ``url`` into ``file_path``.

      On a non-200 reply the status code is printed and nothing is written.

      Returns:
          True if the file was written, False otherwise.
      """
      # stream=True avoids holding the whole checkpoint in memory; the timeout
      # keeps a stalled server from hanging app start-up indefinitely.
      with requests.get(url, stream=True, timeout=timeout) as response:
          if response.status_code != 200:
              print(f"请求失败,状态码:{response.status_code}")
              return False
          with open(file_path, "wb") as f:
              for chunk in response.iter_content(chunk_size=1 << 20):
                  f.write(chunk)
      return True


  def download_model():
      """Download every checkpoint the app needs into ``CheckpointsDir``.

      The download runs only when ``CheckpointsDir`` does not exist yet. If the
      directory is present -- even from an interrupted earlier run -- it is
      assumed complete and nothing is fetched.
      """
      if os.path.exists(CheckpointsDir):
          print("Already download the model.")
          return

      os.makedirs(CheckpointsDir)
      print("Checkpoint Not Downloaded, start downloading...")
      tic = time.time()

      # MuseTalk weights (the repo snapshot lands directly in CheckpointsDir).
      _hf_snapshot("TMElyralab/MuseTalk", CheckpointsDir)

      # VAE weights (sd-vae-ft-mse). exist_ok: the snapshot above may already
      # have created some of these folders.
      vae_dir = os.path.join(CheckpointsDir, "sd-vae-ft-mse")
      os.makedirs(vae_dir, exist_ok=True)
      _hf_snapshot("stabilityai/sd-vae-ft-mse", vae_dir)

      # DWPose weights.
      dwpose_dir = os.path.join(CheckpointsDir, "dwpose")
      os.makedirs(dwpose_dir, exist_ok=True)
      _hf_snapshot("yzd-v/DWPose", dwpose_dir)

      # Whisper "tiny" checkpoint.
      whisper_dir = os.path.join(CheckpointsDir, "whisper")
      os.makedirs(whisper_dir, exist_ok=True)
      _fetch_url_to_file(
          "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
          os.path.join(whisper_dir, "tiny.pt"),
      )

      # Face-parse weights, hosted on Google Drive (fetched with gdown).
      face_parse_dir = os.path.join(CheckpointsDir, "face-parse-bisent")
      os.makedirs(face_parse_dir, exist_ok=True)
      gdown.download(
          "https://drive.google.com/uc?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812",
          os.path.join(face_parse_dir, "79999_iter.pth"),
          quiet=False,
      )

      # ResNet-18 weights, stored alongside the face-parse weights.
      _fetch_url_to_file(
          "https://download.pytorch.org/models/resnet18-5c106cde.pth",
          os.path.join(face_parse_dir, "resnet18-5c106cde.pth"),
      )

      toc = time.time()
      print(f"download cost {toc-tic} seconds")
      print_directory_contents(CheckpointsDir)


  # Run at import time so the weights exist before the model-loading code
  # below executes (for Hugging Face deployment).
  download_model()
  101. from musetalk.utils.utils import get_file_type,get_video_fps,datagen
  102. from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder,get_bbox_range
  103. from musetalk.utils.blending import get_image
  104. from musetalk.utils.utils import load_all_model
  def _extract_video_frames(video_path, save_dir):
      """Write every frame of ``video_path`` to ``save_dir`` as ``%08d.png``.

      Returns only the paths written by this call, in frame order. PNGs left in
      ``save_dir`` by an earlier (longer) video with the same basename are
      therefore never mixed into the frame list.
      """
      os.makedirs(save_dir, exist_ok=True)
      frame_paths = []
      # The context manager closes the reader (and its ffmpeg process) when done.
      with imageio.get_reader(video_path) as reader:
          for idx, im in enumerate(reader):
              frame_path = os.path.join(save_dir, f"{idx:08d}.png")
              imageio.imwrite(frame_path, im)
              frame_paths.append(frame_path)
      return frame_paths


  def _mux_audio(frames_video, audio_path, output_path, fps):
      """Attach ``audio_path`` to ``frames_video`` and encode the result to ``output_path``."""
      video_clip = VideoFileClip(frames_video)
      audio_clip = AudioFileClip(audio_path)
      try:
          video_clip.set_audio(audio_clip).write_videofile(
              output_path, codec='libx264', audio_codec='aac', fps=fps
          )
      finally:
          # Release the clips' file handles so the caller can delete the temp video.
          audio_clip.close()
          video_clip.close()


  @spaces.GPU(duration=600)
  @torch.no_grad()
  def inference(audio_path, video_path, bbox_shift, progress=gr.Progress(track_tqdm=True)):
      """Generate a lip-synced video of ``video_path`` driven by ``audio_path``.

      Args:
          audio_path: path of the driving audio file.
          video_path: a video file, or a directory of images whose stems are
              integer frame indices.
          bbox_shift: value forwarded to ``get_landmark_and_bbox`` and
              ``get_bbox_range``.
          progress: Gradio progress tracker; ``track_tqdm=True`` lets Gradio
              follow the tqdm loops below.

      Returns:
          Tuple of (path of the generated video, text returned by ``get_bbox_range``).

      Raises:
          FileNotFoundError: if the intermediate video or the audio file is missing
              before the final mux.
      """
      args_dict = {"result_dir": './results/output', "fps": 25, "batch_size": 8, "output_vid_name": '', "use_saved_coord": False}  # same as the inference script
      args = Namespace(**args_dict)

      # splitext keeps dots inside the stem ("a.b.mp4" -> "a.b"), so inputs that
      # differ only after the first dot no longer share output/cache paths.
      input_basename = os.path.splitext(os.path.basename(video_path))[0]
      audio_basename = os.path.splitext(os.path.basename(audio_path))[0]
      output_basename = f"{input_basename}_{audio_basename}"
      result_img_save_path = os.path.join(args.result_dir, output_basename)  # depends on video & audio inputs
      crop_coord_save_path = os.path.join(result_img_save_path, input_basename + ".pkl")  # depends on video input only
      os.makedirs(result_img_save_path, exist_ok=True)
      if args.output_vid_name == "":
          output_vid_name = os.path.join(args.result_dir, output_basename + ".mp4")
      else:
          output_vid_name = os.path.join(args.result_dir, args.output_vid_name)

      # ---- extract frames from the source video ----
      if get_file_type(video_path) == "video":
          save_dir_full = os.path.join(args.result_dir, input_basename)
          input_img_list = _extract_video_frames(video_path, save_dir_full)
          fps = get_video_fps(video_path)
      else:  # input is a folder of images
          input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
          input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
          fps = args.fps

      # ---- extract audio features (chunked at the same fps used for the output) ----
      whisper_feature = audio_processor.audio2feat(audio_path)
      whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature, fps=fps)

      # ---- preprocess input images ----
      if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
          print("using extracted coordinates")
          # NOTE(review): pickle.load is only acceptable because this app wrote the file itself.
          with open(crop_coord_save_path, 'rb') as f:
              coord_list = pickle.load(f)
          frame_list = read_imgs(input_img_list)
      else:
          print("extracting landmarks...time consuming")
          coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
          with open(crop_coord_save_path, 'wb') as f:
              pickle.dump(coord_list, f)
      bbox_shift_text = get_bbox_range(input_img_list, bbox_shift)

      input_latent_list = []
      for bbox, frame in zip(coord_list, frame_list):
          if bbox == coord_placeholder:
              continue  # placeholder box (presumably no usable face) -- no latent for it
          x1, y1, x2, y2 = bbox
          crop_frame = frame[y1:y2, x1:x2]
          crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
          input_latent_list.append(vae.get_latents_for_unet(crop_frame))

      # Append each sequence reversed so playback loops smoothly between the last and first frame.
      frame_list_cycle = frame_list + frame_list[::-1]
      coord_list_cycle = coord_list + coord_list[::-1]
      input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

      # ---- inference batch by batch ----
      print("start inference")
      video_num = len(whisper_chunks)
      batch_size = args.batch_size
      gen = datagen(whisper_chunks, input_latent_list_cycle, batch_size)
      res_frame_list = []
      for whisper_batch, latent_batch in tqdm(gen, total=int(np.ceil(float(video_num) / batch_size))):
          tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch]
          audio_feature_batch = torch.stack(tensor_list).to(unet.device)  # per original note: B, 5*N, 384
          audio_feature_batch = pe(audio_feature_batch)
          pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
          recon = vae.decode_latents(pred_latents)
          res_frame_list.extend(recon)

      # ---- paste generated faces back into the full frames ----
      print("pad talking image to original video")
      written_frames = []  # only this run's frames, in order
      for i, res_frame in enumerate(tqdm(res_frame_list)):
          bbox = coord_list_cycle[i % len(coord_list_cycle)]
          ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
          x1, y1, x2, y2 = bbox
          try:
              res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
          except Exception:
              # Degenerate box size (e.g. the placeholder): skip this frame, as before,
              # but without swallowing KeyboardInterrupt/SystemExit like a bare except.
              continue
          combine_frame = get_image(ori_frame, res_frame, bbox)
          frame_path = f"{result_img_save_path}/{str(i).zfill(8)}.png"
          cv2.imwrite(frame_path, combine_frame)
          written_frames.append(frame_path)

      # ---- encode frames at the audio-chunk fps, then attach the audio ----
      # Using `fps` (not a hard-coded 25) keeps lips and audio in sync for any frame rate.
      # The temp file lives in this job's result dir so concurrent jobs do not collide.
      temp_video = os.path.join(result_img_save_path, "temp.mp4")
      images = [imageio.imread(path) for path in written_frames]
      imageio.mimwrite(temp_video, images, 'FFMPEG', fps=fps, codec='libx264', pixelformat='yuv420p')
      del images  # drop decoded frames before the (memory-hungry) mux step

      if not os.path.exists(temp_video):
          raise FileNotFoundError(f"Input video file not found: {temp_video}")
      if not os.path.exists(audio_path):
          raise FileNotFoundError(f"Audio file not found: {audio_path}")
      try:
          _mux_audio(temp_video, audio_path, output_vid_name, fps)
      finally:
          os.remove(temp_video)
      print(f"result is save to {output_vid_name}")
      return output_vid_name, bbox_shift_text
  # Load all models once at import time; `inference` reuses these module globals.
  audio_processor, vae, unet, pe = load_all_model()

  # Prefer the GPU when one is visible.
  if torch.cuda.is_available():
      device = torch.device("cuda")
  else:
      device = torch.device("cpu")

  # Single-step UNet call: the timestep tensor is always [0].
  timesteps = torch.tensor([0], device=device)
  def _nearest_source_indices(num_frames, fps, target_fps):
      """Map each output frame at ``target_fps`` to a source-frame index.

      Output frame k (1-based, time ``k / target_fps``) takes the first source
      frame i whose timestamp ``(i + 1) / fps`` is >= that time. The index is
      clamped to the last source frame, so float rounding can never run past the
      end of the clip (the previous loop could index ``frames[num_frames]``).
      """
      num_target = int(num_frames / fps * target_fps)
      source_times = [x / fps for x in range(1, num_frames + 1)]
      indices = []
      src = 0
      for k in range(1, num_target + 1):
          t = k / target_fps
          while src < num_frames - 1 and t > source_times[src]:
              src += 1
          indices.append(src)
      return indices


  def check_video(video):
      """Re-encode an uploaded video to 25 fps H.264 under ``./results/input``.

      Non-string values (e.g. ``None`` when the upload is cleared) and files whose
      name already starts with ``outputxxx_`` are returned unchanged. The latter
      stops the ``video.change`` callback from re-processing the file it just
      produced.

      Returns:
          Path of the converted video, or ``video`` itself in the cases above.
      """
      if not isinstance(video, str):
          return video  # in case of None
      file_name = os.path.basename(video)
      if file_name.startswith("outputxxx_"):
          return video

      output_file_name = "outputxxx_" + file_name
      os.makedirs('./results', exist_ok=True)
      os.makedirs('./results/output', exist_ok=True)
      os.makedirs('./results/input', exist_ok=True)
      output_video = os.path.join('./results/input', output_file_name)

      target_fps = 25
      # Close the reader (and its ffmpeg process) as soon as the frames are in memory.
      with imageio.get_reader(video) as reader:
          fps = reader.get_meta_data()['fps']  # frame rate of the uploaded video
          frames = list(reader)

      target_frames = [frames[i] for i in _nearest_source_indices(len(frames), fps, target_fps)]
      imageio.mimwrite(output_video, target_frames, 'FFMPEG', fps=target_fps, codec='libx264', quality=9, pixelformat='yuv420p')
      return output_video
  css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height: 576px}"""

  with gr.Blocks(css=css) as demo:
      # Header: title, authors, affiliation and project links. "\\*" keeps the
      # literal backslash-asterisk (a Markdown-escaped "*") without relying on an
      # invalid Python string escape.
      gr.Markdown(
          "<div align='center'> "
          "<h1>MuseTalk: Real-Time High Quality Lip Synchronization with Latent Space Inpainting</h1> "
          "<h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>"
          "<br>"
          "Yue Zhang <sup>\\*</sup>, "
          "Minhao Liu<sup>\\*</sup>, "
          "Zhaokang Chen, "
          "Bin Wu<sup>†</sup>, "
          "Yingjie He, "
          "Chao Zhan, "
          "Wenjiang Zhou "
          "(<sup>*</sup>Equal Contribution, <sup>†</sup>Corresponding Author, benbinwu@tencent.com) "
          "Lyra Lab, Tencent Music Entertainment"
          "</h2> "
          "<a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Github Repo]</a> "
          "<a style='font-size:18px;color: #000000' href='https://huggingface.co/TMElyralab/MuseTalk'>[Huggingface]</a> "
          "<a style='font-size:18px;color: #000000' href=''> [Technical report(Coming Soon)] </a> "
          "<a style='font-size:18px;color: #000000' href=''> [Project Page(Coming Soon)] </a> </div>"
      )

      with gr.Row():
          with gr.Column():
              audio = gr.Audio(label="Driven Audio", type="filepath")
              video = gr.Video(label="Reference Video", sources=['upload'])
              bbox_shift = gr.Number(label="BBox_shift value, px", value=0)
              bbox_shift_scale = gr.Textbox(label="BBox_shift recommend value lower bound,The corresponding bbox range is generated after the initial result is generated. \n If the result is not good, it can be adjusted according to this reference value", value="", interactive=False)
              btn = gr.Button("Generate")
          out1 = gr.Video()

      # Normalize every upload to 25 fps; check_video returns its own output unchanged.
      video.change(
          fn=check_video, inputs=[video], outputs=[video]
      )
      btn.click(
          fn=inference,
          inputs=[
              audio,
              video,
              bbox_shift,
          ],
          outputs=[out1, bbox_shift_scale],
      )

  # Set the IP and port.
  ip_address = "0.0.0.0"  # Replace with your desired IP address
  port_number = 7860  # Replace with your desired port number
  demo.queue().launch(
      share=False, debug=True, server_name=ip_address, server_port=port_number
  )