remote-watch.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. #!/usr/bin/env python
"""
This command must be run in a git repository.
It watches the remote branch for changes. When it detects that the remote
branch no longer points to the expected local commit (or has been deleted), it
terminates this process's parent and its process group.
If the commit message contains a line saying "CI_KEEP_ALIVE" (optionally
followed by ": <ci names>" to restrict it to particular CI systems), then the
branch is not monitored at all and no killing occurs.
"""
  10. # Prefer to keep this file Python 2-compatible so that it can easily run early
  11. # in the CI process on any system.
  12. import argparse
  13. import errno
  14. import logging
  15. import os
  16. import signal
  17. import subprocess
  18. import sys
  19. import time
logger = logging.getLogger(__name__)
# Identifiers returned by get_current_ci(). should_keep_alive() compares their
# lowercase forms ("github", "travis") against names listed on a
# CI_KEEP_ALIVE line.
GITHUB = "GitHub"
TRAVIS = "Travis"
  23. def git(*args):
  24. cmdline = ["git"] + list(args)
  25. return subprocess.check_output(cmdline).decode("utf-8").rstrip()
  26. def get_current_ci():
  27. if "GITHUB_WORKFLOW" in os.environ:
  28. return GITHUB
  29. elif "TRAVIS" in os.environ:
  30. return TRAVIS
  31. return None
  32. def get_ci_event_name():
  33. ci = get_current_ci()
  34. if ci == GITHUB:
  35. return os.environ["GITHUB_EVENT_NAME"]
  36. elif ci == TRAVIS:
  37. return os.environ["TRAVIS_EVENT_TYPE"]
  38. return None
  39. def get_repo_slug():
  40. ci = get_current_ci()
  41. if ci == GITHUB:
  42. return os.environ["GITHUB_REPOSITORY"]
  43. elif ci == TRAVIS:
  44. return os.environ["TRAVIS_REPO_SLUG"]
  45. return None
  46. def get_remote_url(remote):
  47. return git("ls-remote", "--get-url", remote)
  48. def replace_suffix(base, old_suffix, new_suffix=""):
  49. if base.endswith(old_suffix):
  50. base = base[: len(base) - len(old_suffix)] + new_suffix
  51. return base
  52. def git_branch_info_to_track():
  53. """Obtains the remote branch name, remote name, and commit hash that
  54. should be tracked for changes.
  55. Returns:
  56. ("refs/heads/mybranch", "origin", "1A2B3C4...")
  57. """
  58. expected_sha = None
  59. ref = None
  60. remote = git("remote", "show", "-n").splitlines()[0]
  61. ci = get_current_ci()
  62. if ci == GITHUB:
  63. expected_sha = os.getenv("GITHUB_HEAD_SHA") or os.environ["GITHUB_SHA"]
  64. ref = replace_suffix(os.environ["GITHUB_REF"], "/merge", "/head")
  65. elif ci == TRAVIS:
  66. pr = os.getenv("TRAVIS_PULL_REQUEST", "false")
  67. if pr != "false":
  68. expected_sha = os.environ["TRAVIS_PULL_REQUEST_SHA"]
  69. ref = "refs/pull/{}/head".format(pr)
  70. else:
  71. expected_sha = os.environ["TRAVIS_COMMIT"]
  72. ref = "refs/heads/{}".format(os.environ["TRAVIS_BRANCH"])
  73. result = (ref, remote, expected_sha)
  74. if not all(result):
  75. msg = "Invalid remote {!r}, ref {!r}, or hash {!r} for CI {!r}"
  76. raise ValueError(msg.format(remote, ref, expected_sha, ci))
  77. return result
  78. def get_commit_metadata(hash):
  79. """Get the commit info (content hash, parents, message, etc.) as a list of
  80. key-value pairs.
  81. """
  82. info = git("cat-file", "-p", hash)
  83. parts = info.split("\n\n", 1) # Split off the commit message
  84. records = parts[0]
  85. message = parts[1] if len(parts) > 1 else None
  86. result = []
  87. records = records.replace("\n ", "\0 ") # Join multiple lines into one
  88. for record in records.splitlines(True):
  89. (key, value) = record.split(" ", 1)
  90. value = value.replace("\0 ", "\n ") # Re-split lines
  91. result.append((key, value))
  92. result.append(("message", message))
  93. return result
def terminate_my_process_group():
    """Signal the parent process and this process's group to stop the CI job.

    On Windows ("win32"): send CTRL_BREAK_EVENT with pid 0, wait ``timeout``
    seconds, then send SIGTERM to the parent process. Elsewhere: send SIGTERM
    to the parent process, wait ``timeout`` seconds, then send SIGKILL with
    pid 0. On POSIX, pid 0 targets the caller's own process group, which
    includes this process, so a successful SIGKILL presumably means this
    function never returns.

    Returns:
        0 if no error occurred; otherwise the errno of a tolerated OSError
        (EBADF or ESRCH), which is also logged.

    Raises:
        OSError: for any other errno raised by os.kill.
    """
    result = 0
    timeout = 15  # seconds to sleep between the first and second signal
    try:
        logger.warning("Attempting kill...")
        if sys.platform == "win32":
            # NOTE(review): on Windows, pid 0 with a console control event
            # presumably targets all processes sharing this console — confirm.
            os.kill(0, signal.CTRL_BREAK_EVENT)  # This might get ignored.
            time.sleep(timeout)
            os.kill(os.getppid(), signal.SIGTERM)
        else:
            # This SIGTERM seems to be needed to prevent jobs from lingering.
            os.kill(os.getppid(), signal.SIGTERM)
            time.sleep(timeout)
            os.kill(0, signal.SIGKILL)
    except OSError as ex:
        # Tolerate "no such process" (ESRCH) and EBADF: log the failure and
        # report it through the return value instead of raising.
        if ex.errno not in (errno.EBADF, errno.ESRCH):
            raise
        logger.error("Kill error %s: %s", ex.errno, ex.strerror)
        result = ex.errno
    return result
  114. def yield_poll_schedule():
  115. schedule = [0, 5, 5, 10, 20, 40, 40] + [60] * 5 + [120] * 10 + [300]
  116. for item in schedule:
  117. yield item
  118. while True:
  119. yield schedule[-1]
  120. def detect_spurious_commit(actual, expected, remote):
  121. """GitHub sometimes spuriously generates commits multiple times with
  122. different dates but identical contents. See here:
  123. https://github.com/travis-ci/travis-ci/issues/7459#issuecomment-601346831
  124. We need to detect whether this might be the case, and we do so by
  125. comparing the commits' contents ("tree" objects) and their parents.
  126. Args:
  127. actual: The commit line on the remote from git ls-remote, e.g.:
  128. da39a3ee5e6b4b0d3255bfef95601890afd80709 refs/heads/master
  129. expected: The commit line initially expected.
  130. Returns:
  131. The new (actual) commit line, if it is suspected to be spurious.
  132. Otherwise, the previously expected commit line.
  133. """
  134. actual_hash = actual.split(None, 1)[0]
  135. expected_hash = expected.split(None, 1)[0]
  136. relevant = ["tree", "parent"] # relevant parts of a commit for comparison
  137. if actual != expected:
  138. git("fetch", "-q", remote, actual_hash)
  139. actual_info = get_commit_metadata(actual_hash)
  140. expected_info = get_commit_metadata(expected_hash)
  141. a = [pair for pair in actual_info if pair[0] in relevant]
  142. b = [pair for pair in expected_info if pair[0] in relevant]
  143. if a == b:
  144. expected = actual
  145. return expected
  146. def should_keep_alive(commit_msg):
  147. result = False
  148. ci = get_current_ci() or ""
  149. for line in commit_msg.splitlines():
  150. parts = line.strip("# ").split(":", 1)
  151. (key, val) = parts if len(parts) > 1 else (parts[0], "")
  152. if key == "CI_KEEP_ALIVE":
  153. ci_names = val.replace(",", " ").lower().split() if val else []
  154. if len(ci_names) == 0 or ci.lower() in ci_names:
  155. result = True
  156. return result
def monitor():
    """Poll the remote ref until it changes or disappears, then end the job.

    The ref, remote, and expected commit come from git_branch_info_to_track().
    If should_keep_alive() approves the message text printed by
    ``git show -s --format=%B HEAD^-``, this logs and returns None without
    monitoring. Otherwise it polls ``git ls-remote`` with delays from
    yield_poll_schedule(), an endless schedule. The loop is left only when the
    ref is reported deleted (exit status 2) or it points at a commit that
    detect_spurious_commit() does not accept as equivalent. Other ls-remote
    failures are logged and polling continues.

    Returns:
        None when keep-alive is requested; otherwise the return value of
        terminate_my_process_group().
    """
    (ref, remote, expected_sha) = git_branch_info_to_track()
    # Built in the same "<sha>\t<ref>" form as the ls-remote output it is
    # compared against below.
    expected_line = "{}\t{}".format(expected_sha, ref)
    # NOTE(review): "HEAD^-" is git shorthand for the range HEAD^1..HEAD, so for
    # a merge commit this presumably checks the merged-in commits' messages.
    if should_keep_alive(git("show", "-s", "--format=%B", "HEAD^-")):
        logger.info(
            "Not monitoring %s on %s due to keep-alive on: %s",
            ref,
            remote,
            expected_line,
        )
        return
    logger.info(
        "Monitoring %s (%s) for changes in %s: %s",
        remote,
        get_remote_url(remote),
        ref,
        expected_line,
    )
    for to_wait in yield_poll_schedule():
        time.sleep(to_wait)
        status = 0
        line = None
        try:
            # Query the commit on the remote ref (without fetching the commit).
            line = git("ls-remote", "--exit-code", remote, ref)
        except subprocess.CalledProcessError as ex:
            status = ex.returncode
        if status == 2:
            # With --exit-code, status 2 is treated as the ref being deleted.
            logger.info(
                "Terminating job as %s has been deleted on %s: %s",
                ref,
                remote,
                expected_line,
            )
            break
        elif status != 0:
            # Any other failure: log it and keep polling.
            logger.error(
                "Error %d: unable to check %s on %s: %s",
                status,
                ref,
                remote,
                expected_line,
            )
        else:
            prev = expected_line
            # Adopts `line` as the new expected line if the remote commit is a
            # content-identical (spurious) re-creation of the expected one.
            expected_line = detect_spurious_commit(line, expected_line, remote)
            if expected_line != line:
                logger.info(
                    "Terminating job as %s has been updated on %s\n"
                    " from:\t%s\n"
                    " to: \t%s",
                    ref,
                    remote,
                    expected_line,
                    line,
                )
                time.sleep(1)  # wait for CI to flush output
                break
            if expected_line != prev:
                logger.info(
                    "%s appeared to spuriously change on %s\n"
                    " from:\t%s\n"
                    " to: \t%s",
                    ref,
                    remote,
                    prev,
                    expected_line,
                )
    return terminate_my_process_group()
  226. def main(program, *args):
  227. p = argparse.ArgumentParser()
  228. p.add_argument("--skip_repo", action="append", help="Repo to exclude.")
  229. parsed_args = p.parse_args(args)
  230. skipped_repos = parsed_args.skip_repo or []
  231. repo_slug = get_repo_slug()
  232. event_name = get_ci_event_name()
  233. if repo_slug not in skipped_repos or event_name == "pull_request":
  234. result = monitor()
  235. else:
  236. logger.info("Skipping monitoring %s %s build", repo_slug, event_name)
  237. result = 0
  238. return result
  239. if __name__ == "__main__":
  240. logging.basicConfig(
  241. format="%(levelname)s: %(message)s", stream=sys.stderr, level=logging.DEBUG
  242. )
  243. try:
  244. raise SystemExit(main(*sys.argv) or 0)
  245. except KeyboardInterrupt:
  246. pass