run_inference.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import json
  4. from facechain.inference_fact import GenPortrait
  5. import cv2
  6. from facechain.utils import snapshot_download
  7. from facechain.constants import neg_prompt, pos_prompt_with_cloth, pos_prompt_with_style, base_models
  8. def generate_pos_prompt(style_model, prompt_cloth):
  9. if style_model is not None:
  10. matched = list(filter(lambda style: style_model == style['name'], styles))
  11. if len(matched) == 0:
  12. raise ValueError(f'styles not found: {style_model}')
  13. matched = matched[0]
  14. if matched['model_id'] is None:
  15. pos_prompt = pos_prompt_with_cloth.format(prompt_cloth)
  16. else:
  17. pos_prompt = pos_prompt_with_style.format(matched['add_prompt_style'])
  18. else:
  19. pos_prompt = pos_prompt_with_cloth.format(prompt_cloth)
  20. return pos_prompt
  21. styles = []
  22. for base_model in base_models:
  23. style_in_base = []
  24. folder_path = f"styles/{base_model['name']}"
  25. files = os.listdir(folder_path)
  26. files.sort()
  27. for file in files:
  28. file_path = os.path.join(folder_path, file)
  29. with open(file_path, "r") as f:
  30. data = json.load(f)
  31. style_in_base.append(data['name'])
  32. styles.append(data)
  33. base_model['style_list'] = style_in_base
  34. use_pose_model = False
  35. input_img_path = 'poses/man/pose2.png'
  36. pose_image = 'poses/man/pose1.png'
  37. num_generate = 5
  38. multiplier_style = 0.25
  39. output_dir = './generated'
  40. base_model_idx = 0
  41. style_idx = 0
  42. base_model = base_models[base_model_idx]
  43. style = styles[style_idx]
  44. model_id = style['model_id']
  45. if model_id == None:
  46. style_model_path = None
  47. pos_prompt = generate_pos_prompt(style['name'], style['add_prompt_style'])
  48. else:
  49. if os.path.exists(model_id):
  50. model_dir = model_id
  51. else:
  52. model_dir = snapshot_download(model_id, revision=style['revision'])
  53. style_model_path = os.path.join(model_dir, style['bin_file'])
  54. pos_prompt = generate_pos_prompt(style['name'], style['add_prompt_style']) # style has its own prompt
  55. if not use_pose_model:
  56. pose_image = None
  57. gen_portrait = GenPortrait()
  58. outputs = gen_portrait(num_generate, base_model_idx, style_model_path, pos_prompt, neg_prompt, input_img_path, pose_image, multiplier_style)
  59. os.makedirs(output_dir, exist_ok=True)
  60. for i, out_tmp in enumerate(outputs):
  61. cv2.imwrite(os.path.join(output_dir, f'{i}.png'), out_tmp)