# global_config.py
import os
from typing import List, Optional

import yaml
from typing_extensions import TypedDict
  5. class GlobalConfig(TypedDict):
  6. byod_ray_ecr: str
  7. byod_ray_cr_repo: str
  8. byod_ray_ml_cr_repo: str
  9. byod_ecr: str
  10. byod_aws_cr: str
  11. byod_gcp_cr: str
  12. state_machine_pr_aws_bucket: str
  13. state_machine_branch_aws_bucket: str
  14. state_machine_disabled: bool
  15. aws2gce_credentials: str
  16. ci_pipeline_premerge: List[str]
  17. ci_pipeline_postmerge: List[str]
  18. ci_pipeline_buildkite_secret: str
  19. config = None
  20. def init_global_config(config_file: str):
  21. """
  22. Initiate the global configuration singleton.
  23. """
  24. global config
  25. if not config:
  26. _init_global_config(config_file)
  27. def get_global_config():
  28. """
  29. Get the global configuration singleton. Need to be invoked after
  30. init_global_config().
  31. """
  32. global config
  33. return config
  34. def _init_global_config(config_file: str):
  35. global config
  36. config_content = yaml.safe_load(open(config_file, "rt"))
  37. config = GlobalConfig(
  38. byod_ray_ecr=(
  39. config_content.get("byod", {}).get("ray_ecr")
  40. or config_content.get("release_byod", {}).get("ray_ecr")
  41. ),
  42. byod_ray_cr_repo=(
  43. config_content.get("byod", {}).get("ray_cr_repo")
  44. or config_content.get("release_byod", {}).get("ray_cr_repo")
  45. ),
  46. byod_ray_ml_cr_repo=(
  47. config_content.get("byod", {}).get("ray_ml_cr_repo")
  48. or config_content.get("release_byod", {}).get("ray_ml_cr_repo")
  49. ),
  50. byod_ecr=(
  51. config_content.get("byod", {}).get("byod_ecr")
  52. or config_content.get("release_byod", {}).get("byod_ecr")
  53. ),
  54. byod_aws_cr=(
  55. config_content.get("byod", {}).get("aws_cr")
  56. or config_content.get("release_byod", {}).get("aws_cr")
  57. ),
  58. byod_gcp_cr=(
  59. config_content.get("byod", {}).get("gcp_cr")
  60. or config_content.get("release_byod", {}).get("gcp_cr")
  61. ),
  62. aws2gce_credentials=(
  63. config_content.get("credentials", {}).get("aws2gce")
  64. or config_content.get("release_byod", {}).get("aws2gce_credentials")
  65. ),
  66. state_machine_pr_aws_bucket=config_content.get("state_machine", {})
  67. .get("pr", {})
  68. .get(
  69. "aws_bucket",
  70. ),
  71. state_machine_branch_aws_bucket=config_content.get("state_machine", {})
  72. .get("branch", {})
  73. .get(
  74. "aws_bucket",
  75. ),
  76. state_machine_disabled=config_content.get("state_machine", {}).get(
  77. "disabled", 0
  78. )
  79. == 1,
  80. ci_pipeline_premerge=config_content.get("ci_pipeline", {}).get("premerge", []),
  81. ci_pipeline_postmerge=config_content.get("ci_pipeline", {}).get(
  82. "postmerge", []
  83. ),
  84. ci_pipeline_buildkite_secret=config_content.get("ci_pipeline", {}).get(
  85. "buildkite_secret"
  86. ),
  87. )
  88. # setup GCP workload identity federation
  89. os.environ[
  90. "GOOGLE_APPLICATION_CREDENTIALS"
  91. ] = f"/workdir/{config['aws2gce_credentials']}"