Skip to content

Commit

Permalink
Automated tagging (pymc-devs#225)
Browse files Browse the repository at this point in the history
* automated tagging

* doc
  • Loading branch information
MarcoGorelli authored Sep 13, 2021
1 parent f48e7a6 commit b9820cc
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 3 deletions.
9 changes: 9 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,12 @@ repos:
language: python
name: Check all notebooks appear in table of contents
types: [jupyter]
- id: add-tags
entry: python scripts/add_tags.py
language: python
name: Add PyMC3 classes used to tags
types: [jupyter]
additional_dependencies:
- nbqa==1.1.1
- beautifulsoup4==4.9.3
- myst_parser==0.13.7
2 changes: 1 addition & 1 deletion examples/case_studies/multilevel_modeling.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"# A Primer on Bayesian Methods for Multilevel Modeling\n",
"\n",
":::{post} 30 Aug, 2021\n",
":tags: hierarchical\n",
":tags: hierarchical, pymc3.Data, pymc3.Deterministic, pymc3.Exponential, pymc3.LKJCholeskyCov, pymc3.Model, pymc3.MvNormal, pymc3.Normal\n",
":category: intermediate\n",
":::"
]
Expand Down
2 changes: 1 addition & 1 deletion examples/case_studies/rugby_analytics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"# A Hierarchical model for Rugby prediction\n",
"\n",
":::{post} 30 Aug, 2021\n",
":tags: hierarchical, sports\n",
":tags: hierarchical, pymc3.Data, pymc3.Deterministic, pymc3.HalfNormal, pymc3.Model, pymc3.Normal, pymc3.Poisson, sports\n",
":category: intermediate\n",
":::"
]
Expand Down
2 changes: 1 addition & 1 deletion examples/getting_started.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"Note: This text is based on the [PeerJ CS publication on PyMC3](https://peerj.com/articles/cs-55/).\n",
"\n",
":::{post} 30 Aug, 2021\n",
":tags: glm, mcmc, exploratory analysis\n",
":tags: exploratory analysis, glm, mcmc, pymc3.Data, pymc3.Deterministic, pymc3.DiscreteUniform, pymc3.Exponential, pymc3.GaussianRandomWalk, pymc3.HalfNormal, pymc3.Model, pymc3.Normal, pymc3.Poisson, pymc3.Slice, pymc3.StudentT\n",
":category: beginner\n",
":::"
]
Expand Down
Empty file added scripts/__init__.py
Empty file.
92 changes: 92 additions & 0 deletions scripts/add_tags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""
Automatically add tags to a notebook based on which PyMC3 classes are used.
E.g. if a notebook contains a section like
:::{post} 30 Aug, 2021
:tags: glm, mcmc, exploratory analysis
:category: beginner
:::
in a markdown cell, and uses the class pymc3.Categorical, then this script
will change that part of the markdown cell to:
:::{post} 30 Aug, 2021
:tags: glm, mcmc, exploratory analysis, pymc3.Categorical
:category: beginner
:::
Example of how to run this:
python scripts/add_tags.py examples/getting_started.ipynb
"""
import sys
from myst_parser.main import to_tokens, MdParserConfig
import subprocess
import os
import json
import argparse


# Prefixes that mark the tags line inside a ``{post}`` block, in the two
# spellings this script recognises.
_TAG_PREFIXES = ("tags: ", ":tags: ")


def _find_pymc3_classes(file):
    """Return the ``pymc3.<Class>`` tags for the classes used in ``file``.

    Runs ``scripts.find_pm_classes`` on the notebook through ``nbqa`` and
    turns every non-blank line of its standard output into a tag.
    """
    result = subprocess.run(
        ["nbqa", "scripts.find_pm_classes", file],
        stdout=subprocess.PIPE,
        text=True,
    )
    if result.returncode != 0:
        # Keep going (best effort), but make the failure visible instead of
        # silently producing an incomplete set of tags.
        print(
            f"warning: nbqa exited with status {result.returncode} for "
            f"{file}; PyMC3 classes may be missing from its tags",
            file=sys.stderr,
        )
    names = (line.strip() for line in result.stdout.splitlines())
    # Skip blank lines so that no bare "pymc3." tag is ever produced.
    return {f"pymc3.{name}" for name in names if name}


def _cell_text(cell):
    """Return the text of a notebook cell as a single string."""
    source = cell["source"]
    if isinstance(source, str):
        # Joining a plain string would put a newline between every character.
        return source
    return "\n".join(source)


def _find_tags_line(tokens):
    """Locate the tags line of a ``{post}`` block among ``tokens``.

    Returns ``(prefix, original_line, tags)`` for the first tags line of the
    last ``{post}`` block that has one, or ``None`` if there is no such line.
    """
    found = None
    for token in tokens:
        if not (token.tag == "code" and token.info.startswith("{post}")):
            continue
        for line in token.content.splitlines():
            prefix = next((p for p in _TAG_PREFIXES if line.startswith(p)), None)
            if prefix is None:
                continue
            pieces = (tag.strip() for tag in line[len(prefix):].split(","))
            # Drop empty entries (e.g. an empty tags line or a trailing comma)
            # so the rewritten line never starts with ", ".
            found = (prefix, line, {tag for tag in pieces if tag})
            break
    return found


def _add_tags(file):
    """Rewrite the tags line of ``file`` so it includes the PyMC3 classes used."""
    classes = _find_pymc3_classes(file)

    # Tokenize the notebook's markdown cells.
    with open(file, encoding="utf-8") as fd:
        content = fd.read()
    notebook = json.loads(content)
    markdown_cells = "\n".join(
        _cell_text(cell)
        for cell in notebook["cells"]
        if cell["cell_type"] == "markdown"
    )
    config = MdParserConfig(enable_extensions=["dollarmath", "colon_fence"])
    tokens = to_tokens(markdown_cells, config=config)

    # Find a ```{post} or :::{post} code block, and look for a line
    # starting with tags: or :tags:.
    match = _find_tags_line(tokens)
    if match is None:
        return
    prefix, original_line, tags = match

    # Append any PyMC3 classes which might have been missed. The replacement
    # is done on the raw file text so the rest of the notebook is untouched.
    new_line = prefix + ", ".join(sorted(tags | classes))
    new_content = content.replace(original_line, new_line)
    if new_content != content:
        with open(file, "w", encoding="utf-8") as fd:
            fd.write(new_content)


def main(argv=None):
    """Add PyMC3 class tags to every notebook given on the command line.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments (notebook paths). When ``None``, argparse
        falls back to ``sys.argv[1:]``.

    Returns
    -------
    None
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("files", nargs="*")
    args = parser.parse_args(argv)

    for file in args.files:
        _add_tags(file)


if __name__ == "__main__":
    # Raise SystemExit directly rather than calling the ``exit`` helper, which
    # is injected by the ``site`` module and is not guaranteed to exist.
    raise SystemExit(main())
57 changes: 57 additions & 0 deletions scripts/find_pm_classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""
Find all PyMC3 classes used in a script.
This finds both calls of the form
pymc3.Categorical(...
and
from pymc3 import Categorical
Categorical
"""
import ast
import sys


class ImportVisitor(ast.NodeVisitor):
    """Collect capitalised names imported from ``pymc3`` or its submodules.

    After visiting a tree, ``imports`` holds the imported names whose first
    character is upper-case, e.g. ``{"Categorical"}`` for
    ``from pymc3 import Categorical``.
    """

    def __init__(self, file):
        # ``file`` is accepted but not used by this visitor.
        super().__init__()
        self.imports = set()

    def visit_ImportFrom(self, node: ast.ImportFrom):
        """Record upper-case names from ``from pymc3[.x] import ...``."""
        # ``from . import x`` has no module name (``node.module`` is None).
        if node.module is None:
            return
        if node.module.split(".")[0] != "pymc3":
            return
        for alias in node.names:
            if alias.name[0].isupper():
                self.imports.add(alias.name)


class CallVisitor(ast.NodeVisitor):
    """Collect the PyMC3 classes that are called in a syntax tree.

    A call counts when it is either ``pm.Name(...)`` / ``pymc3.Name(...)``
    with a capitalised ``Name``, or ``Name(...)`` where ``Name`` is in
    ``imports``. Results accumulate in ``classes_used``.
    """

    def __init__(self, file, imports):
        super().__init__()
        self.file = file
        self.imports = imports
        self.classes_used = set()

    def visit_Call(self, node: ast.Call):
        """Record the call if it targets a PyMC3 class, then keep descending."""
        func = node.func
        if isinstance(func, ast.Attribute):
            if (
                isinstance(func.value, ast.Name)
                and func.value.id in {"pm", "pymc3"}
                and func.attr[0].isupper()
            ):
                self.classes_used.add(func.attr)
        elif isinstance(func, ast.Name):
            if func.id in self.imports:
                self.classes_used.add(func.id)
        # Visit children too, so calls nested inside this call's arguments or
        # callee expression (e.g. ``f(pm.Normal(...))``) are not missed.
        self.generic_visit(node)


if __name__ == "__main__":
    # For every file given on the command line, print each PyMC3 class it
    # uses, one per line.
    for file in sys.argv[1:]:
        # Read raw bytes so ast.parse applies Python's own source-encoding
        # rules instead of the platform's locale encoding.
        with open(file, "rb") as fd:
            content = fd.read()
        tree = ast.parse(content)

        import_visitor = ImportVisitor(file)
        import_visitor.visit(tree)

        visitor = CallVisitor(file, import_visitor.imports)
        visitor.visit(tree)
        # Sort so the output order is deterministic.
        for i in sorted(visitor.classes_used):
            print(i)

0 comments on commit b9820cc

Please sign in to comment.