install.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import launch
  2. import os
  3. # TODO: add pip dependencies here for modules that are needed only by this extension
  def _install_if_missing(module, banner, *pip_calls):
      """Run the given pip calls when ``launch.is_installed(module)`` is false.

      ``banner`` is printed once before the first pip call. Each entry of
      ``pip_calls`` is a ``(command, description)`` pair that is forwarded
      to ``launch.run_pip`` in the order given.
      """
      if launch.is_installed(module):
          return
      print(banner)
      for command, description in pip_calls:
          launch.run_pip(command, description)


  _install_if_missing(
      "slugify",
      "--installing slugify...",
      ("install slugify", "requirements for slugify"),
      ("install python-slugify==8.0.1", "requirements for python-slugify"),
  )
  _install_if_missing(
      "diffusers",
      "--installing diffusers...",
      ("install diffusers==0.29.0", "requirements for diffusers"),
  )
  _install_if_missing(
      "peft",
      "--installing peft...",
      ("install peft==0.11.1", "requirements for peft"),
  )

  # onnxruntime is handled separately: the package to install depends on
  # whether torch reports a usable CUDA device, and torch is only imported
  # when an install is actually needed.
  if not (launch.is_installed("onnxruntime") or launch.is_installed("onnxruntime-gpu")):
      import torch.cuda as cuda
      print("Installing onnxruntime")
      if cuda.is_available():
          _onnx_command = "install onnxruntime-gpu"
      else:
          _onnx_command = "install onnxruntime"
      launch.run_pip(_onnx_command)

  _install_if_missing(
      "modelscope",
      "--installing modelscope...",
      ("install modelscope", "requirements for modelscope"),
  )
  _install_if_missing(
      "controlnet_aux",
      "--installing controlnet_aux...",
      ("install controlnet_aux==0.0.6", "requirements for controlnet_aux"),
  )
  def get_pytorch_version():
      """Return the version string of the installed PyTorch.

      torch is imported inside the function so that loading this script
      does not itself require torch to be importable.

      Returns:
          The value of ``torch.__version__``.
      """
      import torch

      # The original assigned the version to a local and fell off the end
      # of the function, so every caller received None.
      return torch.__version__
  def get_python_version():
      """Return the running interpreter's version as ``"<major>.<minor>"``.

      Returns:
          A string such as ``"3.10"`` built from ``sys.version_info``.
      """
      import sys

      # sys.version_info holds ints; the original concatenated them with a
      # str using ``+``, which raises TypeError. Format them instead.
      return f"{sys.version_info[0]}.{sys.version_info[1]}"
  if not launch.is_installed("mmcv-full"):
      print("--installing mmcv...")
      # TODO: pitfall here - this install can fail; the hints printed below
      # name the toolchains that are expected to be present.
      try:
          launch.run_pip("install mmcv-full==1.7.2")
      except Exception as e:
          # Best effort: report the failure and carry on with the remaining
          # requirements instead of aborting the whole script.
          print(e)
          if os.name == 'nt':  # Windows
              print('ERROR facechain: failed to install mmcv, make sure to have "CUDA Toolkit" and "Build Tools for Visual Studio" installed')
          else:
              print('ERROR facechain: failed to install mmcv, make sure to have "CUDA Toolkit" installed')

  if not launch.is_installed("mmdet"):
      print("--installing mmdet...")
      launch.run_pip("install mmdet==2.26.0", "requirements for mmdet")

  if not launch.is_installed("mediapipe"):
      # Fixed copy-paste slip: this banner previously announced "mmdet".
      print("--installing mediapipe...")
      launch.run_pip("install mediapipe==0.10.3", "requirements for mediapipe")

  if not launch.is_installed("edge_tts"):
      print("--installing edge_tts...")
      # Fixed copy-paste slip: the description previously said "mediapipe".
      launch.run_pip("install edge_tts", "requirements for edge_tts")

  if not launch.is_installed("cv2"):
      launch.run_pip("install opencv-python", "requirements for opencv")

  # NOTE(review): diffusers is already probed (and installed pinned) earlier
  # in this file; this unpinned fallback is kept only to preserve behavior.
  if not launch.is_installed("diffusers"):
      launch.run_pip("install diffusers", "requirements for diffusers")

  # The original probe passed the requirement string "protobuf==3.20.1" to
  # launch.is_installed(), which everywhere else in this file receives a bare
  # module/package name - presumably it never matched, so pip ran on every
  # start (TODO confirm against launch.is_installed). Compare the installed
  # distribution version explicitly instead.
  from importlib import metadata as _metadata

  try:
      _protobuf_version = _metadata.version("protobuf")
  except _metadata.PackageNotFoundError:
      _protobuf_version = None
  if _protobuf_version != "3.20.1":
      # Fixed copy-paste slip: the description previously said "diffusers".
      launch.run_pip("install protobuf==3.20.1", "requirements for protobuf")
  # There seems to be a bug in fsspec 2023.10.0 that triggers an error during training:
  #   NotImplementedError: Loading a dataset cached in a LocalFileSystem is not supported.
  # webui currently installs 2023.10.0 by default.
  # TODO: remove this fsspec pin once the issue is resolved; keep an eye on it,
  # since webui may start pinning its own fsspec version in the future.
  import pkg_resources

  required_fsspec_version = '2023.9.2'

  # Only the version lookup is guarded. Previously the pip call sat inside the
  # same try block, so a failed pip run was swallowed by the handler and then
  # retried under a misleading description.
  try:
      fsspec_version = pkg_resources.get_distribution('fsspec').version
  except Exception:
      # get_distribution raises when fsspec is NOT installed. webui installs
      # fsspec by default, so this branch should rarely be taken; any lookup
      # failure is treated as "not installed" (best effort, as before).
      fsspec_version = None

  if fsspec_version is None:
      print("--installing fsspec...")
      launch.run_pip(f"install -U fsspec=={required_fsspec_version}", "requirements for facechain")
  elif fsspec_version != required_fsspec_version:
      print("--installing fsspec...")
      launch.run_pip(f"install -U fsspec=={required_fsspec_version}", f"facechain changing fsspec version from {fsspec_version} to {required_fsspec_version}")