forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 0
/
remote-watch.py
executable file
·268 lines (219 loc) · 8.58 KB
/
remote-watch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
#!/usr/bin/env python
"""
This command must be run in a git repository.
It watches the remote branch for changes, killing the given PID when it detects
that the remote branch no longer points to the local commit.
If the commit message contains a line saying "CI_KEEP_ALIVE", then killing
will not occur until the branch is deleted from the remote.
If no PID is given, then the entire process group of this process is killed.
"""
# Prefer to keep this file Python 2-compatible so that it can easily run early
# in the CI process on any system.
import argparse
import errno
import logging
import os
import signal
import subprocess
import sys
import time
# Module-level logger; its handler, format and level are set up by the
# logging.basicConfig() call in the __main__ block of this file.
logger = logging.getLogger(__name__)
# Identifiers for the supported CI services, as returned by get_current_ci().
GITHUB = "GitHub"
TRAVIS = "Travis"
def git(*args):
    """Run ``git`` with the given arguments and return what it printed.

    Args:
        *args: Command-line arguments passed to git verbatim.

    Returns:
        git's standard output, decoded as UTF-8, with trailing whitespace
        removed.

    Raises:
        subprocess.CalledProcessError: If git exits with a non-zero status.
    """
    command = ["git"]
    command.extend(args)
    raw_output = subprocess.check_output(command)
    return raw_output.decode("utf-8").rstrip()
def get_current_ci():
    """Identify the CI service this process is running under.

    Returns:
        GITHUB if the GITHUB_WORKFLOW environment variable is set, otherwise
        TRAVIS if the TRAVIS environment variable is set, otherwise None.
    """
    env = os.environ
    if "GITHUB_WORKFLOW" in env:
        return GITHUB
    if "TRAVIS" in env:
        return TRAVIS
    return None
def get_ci_event_name():
    """Return the name of the event that triggered the current CI build.

    Returns:
        The value of GITHUB_EVENT_NAME on GitHub, the value of
        TRAVIS_EVENT_TYPE on Travis, or None when no known CI is detected.

    Raises:
        KeyError: If the CI was detected but its event variable is unset.
    """
    current = get_current_ci()
    if current == GITHUB:
        variable = "GITHUB_EVENT_NAME"
    elif current == TRAVIS:
        variable = "TRAVIS_EVENT_TYPE"
    else:
        return None
    return os.environ[variable]
def get_repo_slug():
    """Return the repository slug reported by the current CI service.

    Returns:
        The value of GITHUB_REPOSITORY on GitHub, the value of
        TRAVIS_REPO_SLUG on Travis, or None when no known CI is detected.

    Raises:
        KeyError: If the CI was detected but its slug variable is unset.
    """
    current = get_current_ci()
    if current not in (GITHUB, TRAVIS):
        return None
    if current == GITHUB:
        return os.environ["GITHUB_REPOSITORY"]
    return os.environ["TRAVIS_REPO_SLUG"]
def get_remote_url(remote):
    """Return the URL that ``git ls-remote --get-url`` reports for a remote.

    Args:
        remote: Name of the git remote (e.g. "origin").
    """
    query = ("ls-remote", "--get-url", remote)
    return git(*query)
def replace_suffix(base, old_suffix, new_suffix=""):
    """Swap a trailing ``old_suffix`` of ``base`` for ``new_suffix``.

    Args:
        base: The string to transform.
        old_suffix: The ending to look for.
        new_suffix: What to put in its place (defaults to nothing).

    Returns:
        ``base`` with the suffix replaced, or ``base`` unchanged when it does
        not end with ``old_suffix``.
    """
    if not base.endswith(old_suffix):
        return base
    stem_length = len(base) - len(old_suffix)
    return base[:stem_length] + new_suffix
def git_branch_info_to_track():
    """Obtains the remote branch name, remote name, and commit hash that
    should be tracked for changes.

    The remote is the first line printed by ``git remote show -n``. The ref
    and expected commit hash are read from the environment variables of the
    detected CI service (GitHub or Travis).

    Returns:
        ("refs/heads/mybranch", "origin", "1A2B3C4...")

    Raises:
        ValueError: If the remote, ref, or hash could not be determined
            (including when no CI is detected or no remote is configured).
        KeyError: If a required CI environment variable is missing.
    """
    expected_sha = None
    ref = None
    remotes = git("remote", "show", "-n").splitlines()
    # With no remotes the listing is empty; leave the remote unset so that the
    # descriptive ValueError below is raised instead of a bare IndexError.
    remote = remotes[0] if remotes else None
    ci = get_current_ci()
    if ci == GITHUB:
        # GITHUB_HEAD_SHA is preferred when set and non-empty.
        expected_sha = os.getenv("GITHUB_HEAD_SHA") or os.environ["GITHUB_SHA"]
        # Track the ".../head" ref rather than the ".../merge" ref.
        ref = replace_suffix(os.environ["GITHUB_REF"], "/merge", "/head")
    elif ci == TRAVIS:
        # Travis sets TRAVIS_PULL_REQUEST to "false" for non-PR builds.
        pr = os.getenv("TRAVIS_PULL_REQUEST", "false")
        if pr != "false":
            expected_sha = os.environ["TRAVIS_PULL_REQUEST_SHA"]
            ref = "refs/pull/{}/head".format(pr)
        else:
            expected_sha = os.environ["TRAVIS_COMMIT"]
            ref = "refs/heads/{}".format(os.environ["TRAVIS_BRANCH"])
    result = (ref, remote, expected_sha)
    if not all(result):
        msg = "Invalid remote {!r}, ref {!r}, or hash {!r} for CI {!r}"
        raise ValueError(msg.format(remote, ref, expected_sha, ci))
    return result
def get_commit_metadata(hash):
    """Get the commit info (content hash, parents, message, etc.) as a list of
    key-value pairs.

    Args:
        hash: The object name handed to ``git cat-file -p``.

    Returns:
        A list of (key, value) tuples, one per header record in output order,
        followed by a final ("message", text) tuple. The message is None when
        the output has no blank line separating headers from a message.
        Header values keep their line ending, and continuation lines (those
        starting with a space) stay attached to the record they extend.
    """
    raw = git("cat-file", "-p", hash)
    # Everything before the first blank line is headers; the rest is message.
    header, separator, body = raw.partition("\n\n")
    message = body if separator else None
    # Temporarily hide continuation-line breaks so each record is one line.
    folded = header.replace("\n ", "\0 ")
    pairs = []
    for line in folded.splitlines(True):
        key, value = line.split(" ", 1)
        # Restore the continuation-line breaks hidden above.
        pairs.append((key, value.replace("\0 ", "\n ")))
    pairs.append(("message", message))
    return pairs
def terminate_my_process_group():
    """Signal the parent process and this process's group to stop.

    On Windows, CTRL_BREAK_EVENT is sent to process group 0 first and, after a
    15-second pause, SIGTERM is sent to the parent process. On other platforms
    SIGTERM is sent to the parent first and, after the same pause, SIGKILL is
    sent to process group 0.

    Returns:
        0 if no error was recorded, or the errno of an EBADF/ESRCH failure.

    Raises:
        OSError: For any kill failure other than EBADF or ESRCH.
    """
    exit_status = 0
    grace_seconds = 15
    on_windows = sys.platform == "win32"
    try:
        logger.warning("Attempting kill...")
        if not on_windows:
            # This SIGTERM seems to be needed to prevent jobs from lingering.
            os.kill(os.getppid(), signal.SIGTERM)
            time.sleep(grace_seconds)
            os.kill(0, signal.SIGKILL)
        else:
            os.kill(0, signal.CTRL_BREAK_EVENT)  # This might get ignored.
            time.sleep(grace_seconds)
            os.kill(os.getppid(), signal.SIGTERM)
    except OSError as ex:
        tolerated = (errno.EBADF, errno.ESRCH)
        if ex.errno in tolerated:
            logger.error("Kill error %s: %s", ex.errno, ex.strerror)
            exit_status = ex.errno
        else:
            raise
    return exit_status
def yield_poll_schedule():
    """Generate the delays, in seconds, to wait before each successive poll.

    The delays ramp up from 0 through 5, 10, 20, 40, 60 and 120 and then
    remain at 300 forever; the generator never terminates.
    """
    steady_delay = 300
    ramp = [0, 5, 5, 10, 20, 40, 40]
    ramp.extend([60] * 5)
    ramp.extend([120] * 10)
    for delay in ramp:
        yield delay
    while True:
        yield steady_delay
def detect_spurious_commit(actual, expected, remote):
    """Decide whether a changed remote commit is merely a spurious duplicate.

    GitHub sometimes spuriously generates commits multiple times with
    different dates but identical contents. See here:
    https://github.com/travis-ci/travis-ci/issues/7459#issuecomment-601346831
    When the two commit lines differ, the new commit is fetched and both
    commits' "tree" and "parent" records are compared.

    Args:
        actual: The commit line on the remote from git ls-remote, e.g.:
            da39a3ee5e6b4b0d3255bfef95601890afd80709 refs/heads/master
        expected: The commit line initially expected.
        remote: Name of the remote from which to fetch the actual commit.

    Returns:
        The new (actual) commit line, if it is suspected to be spurious.
        Otherwise, the previously expected commit line.
    """
    # The hash is the first whitespace-separated field of each line.
    actual_hash = actual.split(None, 1)[0]
    expected_hash = expected.split(None, 1)[0]
    if actual == expected:
        return expected

    compared_keys = ("tree", "parent")
    git("fetch", "-q", remote, actual_hash)

    def fingerprint(commit_hash):
        # Keep only the records that identify a commit's contents and ancestry.
        metadata = get_commit_metadata(commit_hash)
        return [pair for pair in metadata if pair[0] in compared_keys]

    if fingerprint(actual_hash) == fingerprint(expected_hash):
        return actual
    return expected
def should_keep_alive(commit_msg):
    """Check whether a commit message asks for this CI job to be kept alive.

    A line requests keep-alive if, after stripping "#" and space characters
    from both ends, the text before the first ":" is exactly "CI_KEEP_ALIVE".
    Any text after the colon is a comma- or whitespace-separated list of CI
    names, compared case-insensitively; an empty list applies to every CI.

    Args:
        commit_msg: The commit message text to scan.

    Returns:
        True if some line requests keep-alive for the current CI, else False.
    """
    current_ci = (get_current_ci() or "").lower()
    for raw_line in commit_msg.splitlines():
        key, _, val = raw_line.strip("# ").partition(":")
        if key != "CI_KEEP_ALIVE":
            continue
        names = val.replace(",", " ").lower().split()
        if not names or current_ci in names:
            return True
    return False
def monitor():
    """Poll the tracked remote ref and terminate the job once it changes.

    If the commit message(s) shown for "HEAD^-" request keep-alive, nothing is
    monitored and the function returns immediately. Otherwise the remote ref
    is polled on the yield_poll_schedule() intervals until it is deleted or
    points to a commit that is not a spurious duplicate of the expected one.

    Returns:
        None when monitoring is skipped due to keep-alive; otherwise the
        result of terminate_my_process_group().
    """
    (ref, remote, expected_sha) = git_branch_info_to_track()
    expected_line = "{}\t{}".format(expected_sha, ref)

    head_messages = git("show", "-s", "--format=%B", "HEAD^-")
    if should_keep_alive(head_messages):
        logger.info("Not monitoring %s on %s due to keep-alive on: %s", ref,
                    remote, expected_line)
        return

    logger.info("Monitoring %s (%s) for changes in %s: %s", remote,
                get_remote_url(remote), ref, expected_line)

    for delay in yield_poll_schedule():
        time.sleep(delay)
        status = 0
        line = None
        try:
            # Query the commit on the remote ref (without fetching the commit).
            line = git("ls-remote", "--exit-code", remote, ref)
        except subprocess.CalledProcessError as ex:
            status = ex.returncode

        if status == 2:
            # Exit status 2 is treated as "the ref no longer exists".
            logger.info("Terminating job as %s has been deleted on %s: %s",
                        ref, remote, expected_line)
            break
        if status != 0:
            # Any other failure is logged and polling simply continues.
            logger.error("Error %d: unable to check %s on %s: %s", status, ref,
                         remote, expected_line)
            continue

        previous_line = expected_line
        expected_line = detect_spurious_commit(line, expected_line, remote)
        if expected_line != line:
            logger.info(
                "Terminating job as %s has been updated on %s\n"
                " from:\t%s\n"
                " to: \t%s", ref, remote, expected_line, line)
            time.sleep(1)  # wait for CI to flush output
            break
        if expected_line != previous_line:
            logger.info(
                "%s appeared to spuriously change on %s\n"
                " from:\t%s\n"
                " to: \t%s", ref, remote, previous_line, expected_line)

    return terminate_my_process_group()
def main(program, *args):
    """Parse the command line and run the monitor unless the repo is skipped.

    Monitoring is skipped only when the current repository slug was passed via
    --skip_repo and the CI event is not a "pull_request".

    Args:
        program: The program name (sys.argv[0]); unused.
        *args: The remaining command-line arguments.

    Returns:
        The result of monitor(), or 0 when monitoring is skipped.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--skip_repo", action="append",
                        help="Repo to exclude.")
    options = parser.parse_args(args)
    excluded_repos = options.skip_repo or []

    repo_slug = get_repo_slug()
    event_name = get_ci_event_name()

    skip = repo_slug in excluded_repos and event_name != "pull_request"
    if skip:
        logger.info("Skipping monitoring %s %s build", repo_slug, event_name)
        return 0
    return monitor()
if __name__ == "__main__":
    # Log everything (DEBUG and above) to stderr, prefixed by the level name.
    logging.basicConfig(
        level=logging.DEBUG,
        stream=sys.stderr,
        format="%(levelname)s: %(message)s")
    try:
        # A None result from main() is reported as exit status 0.
        exit_code = main(*sys.argv) or 0
        sys.exit(exit_code)
    except KeyboardInterrupt:
        # Exit quietly on Ctrl-C instead of printing a traceback.
        pass