-
Notifications
You must be signed in to change notification settings - Fork 463
/
Copy pathverify_pr_labels.py
87 lines (66 loc) · 2.54 KB
/
verify_pr_labels.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# Copyright (C) 2021-2025, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
"""Borrowed & adapted from https://github.com/pytorch/vision/blob/main/.github/process_commit.py
This script finds the merger responsible for labeling a PR, given the PR number. It is used by the workflow in
'.github/workflows/pr-labels.yml'. If the PR has no recorded merger or is already properly labeled,
this script is a no-op.
Note: we ping the merger only, not the reviewers, as the reviewers can sometimes be external to docTR
with no labeling responsibility, so we don't want to bother them.
"""
from typing import Any
import requests
# A PR counts as properly labeled when it carries at least one label from
# PRIMARY_LABELS *and* at least one label from SECONDARY_LABELS.

# Primary labels: what kind of change the PR is.
PRIMARY_LABELS = {
    "type: bug",
    "type: enhancement",
    "type: misc",
    "type: new feature",
}
# Secondary labels: which part of the project the PR touches.
SECONDARY_LABELS = {
    # Peripheral components living outside the core library
    "ext: api",
    "ext: demo",
    "ext: docs",
    "ext: notebooks",
    "ext: references",
    "ext: scripts",
    "ext: tests",
    # Core library modules
    "module: datasets",
    "module: io",
    "module: models",
    "module: transforms",
    "module: utils",
    # Cross-cutting topics
    "topic: build",
    "topic: ci",
    "topic: docker",
    "topic: documentation",
}
# GitHub coordinates of the repository whose PRs are checked.
GH_ORG, GH_REPO = "mindee", "doctr"
def query_repo(cmd: str, *, accept: str, timeout: float = 30.0) -> Any:
    """Send a GET request to the GitHub REST API for this repository and decode the JSON reply.

    Args:
        cmd: path relative to ``repos/{GH_ORG}/{GH_REPO}/`` (e.g. ``"pulls/123"``).
        accept: value sent as the ``Accept`` request header.
        timeout: seconds to wait for GitHub before giving up. Without a timeout ``requests``
            may wait indefinitely, which would stall the CI job running this script.

    Returns:
        The decoded JSON body of the response.

    Raises:
        requests.HTTPError: if GitHub answers with a 4xx/5xx status (unknown PR, rate limit, ...).
            Raising here surfaces the real cause instead of handing GitHub's error payload to
            callers that expect the requested resource.
    """
    response = requests.get(
        f"https://api.github.com/repos/{GH_ORG}/{GH_REPO}/{cmd}",
        headers={"Accept": accept},
        timeout=timeout,
    )
    response.raise_for_status()
    return response.json()
def get_pr_merger_and_labels(pr_number: int) -> "tuple[str | None, set[str]]":
    """Fetch the login of the user who merged a PR and the names of the PR's labels.

    Args:
        pr_number: number of the pull request in the repository.

    Returns:
        A ``(merger, labels)`` pair. ``merger`` is the merger's login, or ``None`` when the PR
        has no recorded merger; ``labels`` is the set of label names attached to the PR.
    """
    # See https://docs.github.com/en/rest/reference/pulls#get-a-pull-request
    data = query_repo(f"pulls/{pr_number}", accept="application/vnd.github.v3+json")
    # For unmerged PRs GitHub sends ``"merged_by": null``: the key exists, so
    # ``data.get("merged_by", {})`` would return None and the chained ``.get`` would raise
    # AttributeError. ``or {}`` covers both a missing key and an explicit null.
    merged_by = data.get("merged_by") or {}
    merger = merged_by.get("login")
    labels = {label["name"] for label in data["labels"]}
    return merger, labels
def main(args):
    """Mention the PR's merger on stdout when the PR is missing required labels.

    Nothing is printed if the PR already has at least one primary and one secondary label,
    or if no merger login could be determined.

    Args:
        args: parsed command-line namespace; only ``args.pr`` (the PR number) is read.
    """
    merger, labels = get_pr_merger_and_labels(args.pr)
    labeled_ok = bool(
        PRIMARY_LABELS.intersection(labels) and SECONDARY_LABELS.intersection(labels)
    )
    if labeled_ok or not isinstance(merger, str):
        # Either the labels are fine or there is nobody to ping.
        return
    print(f"@{merger}")
def parse_args(argv: "list[str] | None" = None):
    """Parse the command line of the PR label checker.

    Args:
        argv: arguments to parse, without the program name. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, matching the previous zero-argument behavior.

    Returns:
        An ``argparse.Namespace`` whose ``pr`` attribute is the PR number as an ``int``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="PR label checker", formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("pr", type=int, help="PR number")
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: read the PR number from the command line and run the label check.
    main(parse_args())