setup_credentials.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
"""
Set up credentials for third-party services in the CI environment.

The credentials are fetched from AWS Secrets Manager and written to each
service's local configuration file so that test scripts can use the
service. For instance, it can fetch the WandB API key and write it to
``~/.netrc``.

Usage: python setup_credentials.py <service1> [service2] ...

Supported services: wandb, comet_ml, sigopt.
"""
  6. import json
  7. import sys
  8. from pathlib import Path
  9. import boto3
  10. AWS_AIR_SECRETS_ARN = (
  11. "arn:aws:secretsmanager:us-west-2:029272617770:secret:"
  12. "oss-ci/ray-air-test-secrets20221014164754935800000002-UONblX"
  13. )
  14. def get_ray_air_secrets(client):
  15. raw_string = client.get_secret_value(SecretId=AWS_AIR_SECRETS_ARN)["SecretString"]
  16. return json.loads(raw_string)
  17. def write_wandb_api_key(api_key: str):
  18. with open(Path("~/.netrc").expanduser(), "w") as fp:
  19. fp.write(f"machine api.wandb.ai\n" f" login user\n" f" password {api_key}\n")
  20. def write_comet_ml_api_key(api_key: str):
  21. with open(Path("~/.comet.config").expanduser(), "w") as fp:
  22. fp.write(f"[comet]\napi_key={api_key}\n")
  23. def write_sigopt_api_key(api_key: str):
  24. sigopt_config_file = Path("~/.sigopt/client/config.json").expanduser()
  25. sigopt_config_file.parent.mkdir(parents=True, exist_ok=True)
  26. with open(sigopt_config_file, "wt") as f:
  27. json.dump(
  28. {
  29. "api_token": api_key,
  30. "code_tracking_enabled": False,
  31. "log_collection_enabled": False,
  32. },
  33. f,
  34. )
  35. SERVICES = {
  36. "wandb": ("wandb_key", write_wandb_api_key),
  37. "comet_ml": ("comet_ml_token", write_comet_ml_api_key),
  38. "sigopt": ("sigopt_key", write_sigopt_api_key),
  39. }
  40. def main():
  41. if len(sys.argv) < 2:
  42. print(f"Usage: python {sys.argv[0]} <service1> [service2] ...")
  43. sys.exit(0)
  44. services = sys.argv[1:]
  45. if any(service not in SERVICES for service in services):
  46. raise RuntimeError(
  47. f"All services must be included in {list(SERVICES.keys())}. "
  48. f"Got: {services}"
  49. )
  50. try:
  51. client = boto3.client("secretsmanager", region_name="us-west-2")
  52. ray_air_secrets = get_ray_air_secrets(client)
  53. except Exception as e:
  54. print(f"Could not get Ray AIR secrets: {e}")
  55. return
  56. for service in services:
  57. try:
  58. secret_key, setup_fn = SERVICES[service]
  59. setup_fn(ray_air_secrets[secret_key])
  60. except Exception as e:
  61. print(f"Could not setup service credentials for {service}: {e}")
  62. if __name__ == "__main__":
  63. main()