# app_real3dportrait.py

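# Gradio WebUI for Real3D-Portrait: one-shot realistic 3D talking portrait synthesis.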
import os, sys
sys.path.append('./')
import argparse
import gradio as gr
from inference.real3d_infer import GeneFace2Infer
from utils.commons.hparams import hparams
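
# Subclass of GeneFace2Infer that adapts the flat positional argument list coming
# from Gradio's submit.click() into the keyword dict expected by infer_once(),
# and reloads the models whenever a checkpoint path changes between runs.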
class Inferer(GeneFace2Infer):
    def infer_once_args(self, *args, **kwargs):
        assert len(kwargs) == 0
        # Positional args arrive in the same order as the `inputs` list of submit.click() below.
        keys = [
            'src_image_name',
            'drv_audio_name',
            'drv_pose_name',
            'bg_image_name',
            'blink_mode',
            'temperature',
            'mouth_amp',
            'out_mode',
            'map_to_init_pose',
            'low_memory_usage',
            'hold_eye_opened',
            'a2m_ckpt',
            'head_ckpt',
            'torso_ckpt',
            'min_face_area_percent',
        ]
        inp = {}
        out_name = None
        info = ""
        try:  # catch errors and jump to the output part
            for key_index, key in enumerate(keys):
                inp[key] = args[key_index]
                if '_name' in key:
                    inp[key] = inp[key] if inp[key] is not None else ''
            if inp['src_image_name'] == '':
                info = "Input Error: Source image is REQUIRED!"
                raise ValueError
            if inp['drv_audio_name'] == '' and inp['drv_pose_name'] == '':
                info = "Input Error: At least one of driving audio or video is REQUIRED!"
                raise ValueError
            if inp['drv_audio_name'] == '' and inp['drv_pose_name'] != '':
                inp['drv_audio_name'] = inp['drv_pose_name']
                print("No audio input; using the driving pose video for video-driven generation")
            if inp['drv_pose_name'] == '':
                inp['drv_pose_name'] = 'static'

            # Reload the models if any checkpoint path changed since the last run.
            reload_flag = False
            if inp['a2m_ckpt'] != self.audio2secc_dir:
                print("Change of a2m_ckpt detected, reloading model")
                reload_flag = True
            if inp['head_ckpt'] != self.head_model_dir:
                print("Change of head_ckpt detected, reloading model")
                reload_flag = True
            if inp['torso_ckpt'] != self.torso_model_dir:
                print("Change of torso_ckpt detected, reloading model")
                reload_flag = True

            inp['out_name'] = ''
            inp['seed'] = 42

            print(f"infer inputs : {inp}")

            try:
                if reload_flag:
                    self.__init__(inp['a2m_ckpt'], inp['head_ckpt'], inp['torso_ckpt'], inp=inp, device=self.device)
            except Exception as e:
                content = f"{e}"
                info = f"Reload ERROR: {content}"
                raise ValueError
            try:
                out_name = self.infer_once(inp)
            except Exception as e:
                content = f"{e}"
                info = f"Inference ERROR: {content}"
                raise ValueError
        except Exception as e:
            if info == "":  # unexpected errors
                content = f"{e}"
                info = f"WebUI ERROR: {content}"

        # Output part: surface the error textbox only when something went wrong.
        if len(info) > 0:  # there are errors
            print(info)
            info_gr = gr.update(visible=True, value=info)
        else:  # no errors
            info_gr = gr.update(visible=False, value=info)
        if out_name is not None and len(out_name) > 0 and os.path.exists(out_name):  # good output
            print(f"Successfully generated at {out_name}")
            video_gr = gr.update(visible=True, value=out_name)
        else:
            print("Failed to generate")
            # Keep the video widget visible; its value may be None on failure.
            video_gr = gr.update(visible=True, value=out_name)

        return video_gr, info_gr
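
# Helper callbacks for switching between audio-driven and video-driven input
# widgets; defined here but not wired to any component on this page.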
def toggle_audio_file(choice):
    if not choice:
        return gr.update(visible=True), gr.update(visible=False)
    else:
        return gr.update(visible=False), gr.update(visible=True)

def ref_video_fn(path_of_ref_video):
    if path_of_ref_video is not None:
        return gr.update(value=True)
    else:
        return gr.update(value=False)
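
# Build the Gradio Blocks page: three panels (inputs, settings/output, checkpoints),
# with the Generate button wired to Inferer.infer_once_args.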
def real3dportrait_demo(
        audio2secc_dir,
        head_model_dir,
        torso_model_dir,
        device='cuda',
        warpfn=None,
):
    sep_line = "-" * 40

    infer_obj = Inferer(
        audio2secc_dir=audio2secc_dir,
        head_model_dir=head_model_dir,
        torso_model_dir=torso_model_dir,
        device=device,
    )

    print(sep_line)
    print("Model loading is finished.")
    print(sep_line)
    with gr.Blocks(analytics_enabled=False) as real3dportrait_interface:
        gr.Markdown("\
            <div align='center'> <h2> Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis (ICLR 2024 Spotlight) </h2> \
            <a style='font-size:18px;color: #a0a0a0' href='https://arxiv.org/pdf/2401.08503.pdf'>Arxiv</a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; \
            <a style='font-size:18px;color: #a0a0a0' href='https://real3dportrait.github.io/'>Homepage</a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; \
            <a style='font-size:18px;color: #a0a0a0' href='https://github.com/yerfor/Real3DPortrait/'> Github </a> </div>")

        sources = None
        with gr.Row():
            # Left column: source image, driving audio, driving pose video, background.
            with gr.Column(variant='panel'):
                with gr.Tabs(elem_id="source_image"):
                    with gr.TabItem('Upload image'):
                        with gr.Row():
                            src_image_name = gr.Image(label="Source image (required)", sources=sources, type="filepath", value="data/raw/examples/Macron.png")
                with gr.Tabs(elem_id="driven_audio"):
                    with gr.TabItem('Upload audio'):
                        with gr.Column(variant='panel'):
                            drv_audio_name = gr.Audio(label="Input audio (required for audio-driven)", sources=sources, type="filepath", value="data/raw/examples/Obama_5s.wav")
                with gr.Tabs(elem_id="driven_pose"):
                    with gr.TabItem('Upload video'):
                        with gr.Column(variant='panel'):
                            drv_pose_name = gr.Video(label="Driving pose (required for video-driven, optional for audio-driven)", sources=sources, value="data/raw/examples/May_5s.mp4")
                with gr.Tabs(elem_id="bg_image"):
                    with gr.TabItem('Upload image'):
                        with gr.Row():
                            bg_image_name = gr.Image(label="Background image (optional)", sources=sources, type="filepath", value="data/raw/examples/bg.png")

            # Middle column: generation settings and the output video.
            with gr.Column(variant='panel'):
                with gr.Tabs(elem_id="checkbox"):
                    with gr.TabItem('General Settings'):
                        with gr.Column(variant='panel'):
                            blink_mode = gr.Radio(['none', 'period'], value='period', label='blink mode', info="whether to blink periodically")
                            min_face_area_percent = gr.Slider(minimum=0.15, maximum=0.5, step=0.01, label="min_face_area_percent", value=0.2, info='The minimum face area percentage in the output frame; prevents bad cases caused by a face that is too small.')
                            temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.025, label="temperature", value=0.2, info='Sampling temperature of the audio-to-SECC model')
                            mouth_amp = gr.Slider(minimum=0.0, maximum=1.0, step=0.025, label="mouth amplitude", value=0.45, info='higher -> the mouth opens wider')
                            out_mode = gr.Radio(['final', 'concat_debug'], value='concat_debug', label='output layout', info="final: only the final output; concat_debug: final output concatenated with internal features")
                            low_memory_usage = gr.Checkbox(label="Low Memory Usage Mode: saves memory at the expense of slower inference. Useful when processing long (minutes-long) audio.", value=False)
                            map_to_init_pose = gr.Checkbox(label="Whether to map the pose of the first frame to the initial pose", value=True)
                            hold_eye_opened = gr.Checkbox(label="Whether to keep the eyes always open")
                            submit = gr.Button('Generate', elem_id="generate", variant='primary')
                with gr.Tabs(elem_id="generated_video"):
                    info_box = gr.Textbox(label="Error", interactive=False, visible=False)
                    gen_video = gr.Video(label="Generated video", format="mp4", visible=True)

            # Right column: checkpoint selection.
            with gr.Column(variant='panel'):
                with gr.Tabs(elem_id="checkbox"):
                    with gr.TabItem('Checkpoints'):
                        with gr.Column(variant='panel'):
                            ckpt_info_box = gr.Textbox(value="Please select \"ckpt\" under the checkpoint folder ", interactive=False, visible=True, show_label=False)
                            audio2secc_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=audio2secc_dir, file_count='single', label='audio2secc model ckpt path or directory')
                            head_model_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=head_model_dir, file_count='single', label='head model ckpt path or directory (will be ignored if torso model is set)')
                            torso_model_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=torso_model_dir, file_count='single', label='torso model ckpt path or directory')
                            # audio2secc_dir = gr.Textbox(audio2secc_dir, max_lines=1, label='audio2secc model ckpt path or directory (will be ignored if torso model is set)')
                            # head_model_dir = gr.Textbox(head_model_dir, max_lines=1, label='head model ckpt path or directory (will be ignored if torso model is set)')
                            # torso_model_dir = gr.Textbox(torso_model_dir, max_lines=1, label='torso model ckpt path or directory')

        fn = infer_obj.infer_once_args
        if warpfn:
            fn = warpfn(fn)
        # The order of `inputs` must match the `keys` list in Inferer.infer_once_args.
        submit.click(
            fn=fn,
            inputs=[
                src_image_name,
                drv_audio_name,
                drv_pose_name,
                bg_image_name,
                blink_mode,
                temperature,
                mouth_amp,
                out_mode,
                map_to_init_pose,
                low_memory_usage,
                hold_eye_opened,
                audio2secc_dir,
                head_model_dir,
                torso_model_dir,
                min_face_area_percent,
            ],
            outputs=[
                gen_video,
                info_box,
            ],
        )

    print(sep_line)
    print("Gradio page is constructed.")
    print(sep_line)
    return real3dportrait_interface
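
# CLI entry point: select checkpoints and server options, then launch the demo.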
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--a2m_ckpt", type=str, default='checkpoints/240210_real3dportrait_orig/audio2secc_vae/model_ckpt_steps_400000.ckpt')
    parser.add_argument("--head_ckpt", type=str, default='')
    parser.add_argument("--torso_ckpt", type=str, default='checkpoints/240210_real3dportrait_orig/secc2plane_torso_orig/model_ckpt_steps_100000.ckpt')
    parser.add_argument("--port", type=int, default=None)
    parser.add_argument("--server", type=str, default='127.0.0.1')
    parser.add_argument("--share", action='store_true', dest='share', help='share the server to the Internet')

    args = parser.parse_args()
    demo = real3dportrait_demo(
        audio2secc_dir=args.a2m_ckpt,
        head_model_dir=args.head_ckpt,
        torso_model_dir=args.torso_ckpt,
        device='cuda:0',
        warpfn=None,
    )
    demo.queue()
    demo.launch(share=args.share, server_name=args.server, server_port=args.port)
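
# Example launch (a sketch; checkpoint paths fall back to the argparse defaults
# above, so pass --a2m_ckpt / --torso_ckpt if your checkpoints live elsewhere):
#   python app_real3dportrait.py --server 0.0.0.0 --port 7860 --share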