forked from pymc-devs/pymc-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
7 changed files
with
161 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
""" | ||
Automatically add tags to notebook based on which PyMC3 classes are used. | ||
E.g. if a notebook contains a section like | ||
:::{post} 30 Aug, 2021 | ||
:tags: glm, mcmc, exploratory analysis | ||
:category: beginner | ||
::: | ||
in a markdown cell, and uses the class pymc3.Categorical, then this script | ||
will change that part of the markdown cell to: | ||
:::{post} 30 Aug, 2021 | ||
:tags: glm, mcmc, exploratory analysis, pymc3.Categorical | ||
:category: beginner | ||
::: | ||
Example of how to run this: | ||
python scripts/add_tags.py examples/getting_started.ipynb | ||
""" | ||
import sys | ||
from myst_parser.main import to_tokens, MdParserConfig | ||
import subprocess | ||
import os | ||
import json | ||
import argparse | ||
|
||
|
||
def _markdown_text(notebook):
    """Return the text of every markdown cell, with a blank line between cells."""
    texts = []
    for cell in notebook["cells"]:
        if cell["cell_type"] != "markdown":
            continue
        source = cell["source"]
        if isinstance(source, str):
            # The notebook format allows ``source`` to be a single string.
            texts.append(source)
        else:
            # Otherwise it is a list of lines which normally keep their own
            # trailing newline; strip it so joining does not double them.
            texts.append("\n".join(line.rstrip("\n") for line in source))
    return "\n\n".join(texts)


def _split_tags(text):
    """Split a comma-separated tag list into a set, dropping empty entries."""
    return {tag.strip() for tag in text.split(",") if tag.strip()}


def main(argv=None):
    """Add ``pymc3.<Class>`` tags to the ``{post}`` block of each notebook.

    For every notebook path given on the command line, the classes reported
    by ``nbqa scripts.find_pm_classes`` are merged into the existing
    ``tags:`` / ``:tags:`` line of a ``{post}`` block found in the markdown
    cells. Notebooks without such a line are left untouched.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments (notebook paths). ``None`` lets ``argparse``
        read ``sys.argv``.

    Returns
    -------
    None
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("files", nargs="*")
    args = parser.parse_args(argv)

    prefixes = ("tags: ", ":tags: ")

    for file in args.files:
        # Find which PyMC3 classes are used in the code.
        output = subprocess.run(
            ["nbqa", "scripts.find_pm_classes", file],
            stdout=subprocess.PIPE,
            text=True,
        )
        # Skip blank output lines so they do not become a bare "pymc3." tag.
        classes = {
            f"pymc3.{obj.strip()}"
            for obj in output.stdout.splitlines()
            if obj.strip()
        }

        # Tokenize the notebook's markdown cells.
        with open(file, encoding="utf-8") as fd:
            content = fd.read()
        notebook = json.loads(content)
        config = MdParserConfig(enable_extensions=["dollarmath", "colon_fence"])
        tokens = to_tokens(_markdown_text(notebook), config=config)

        # Find a ```{post} or :::{post} code block, and look for a line
        # starting with tags: or :tags:.
        tags = None
        for token in tokens:
            if token.tag != "code" or not token.info.startswith("{post}"):
                continue
            for line in token.content.splitlines():
                prefix = next((p for p in prefixes if line.startswith(p)), None)
                if prefix is None:
                    continue
                line_start = prefix
                original_line = line
                tags = _split_tags(line[len(prefix):])
                break

        if tags is None:
            continue

        # Append any PyMC3 classes which might have been missed.
        new_tags = ", ".join(sorted(tags.union(classes)))
        new_line = f"{line_start}{new_tags}"
        # The replacement is done on the raw file text so the notebook's JSON
        # formatting is preserved; a tags line that is stored with JSON
        # escapes will therefore not match and the file is left unchanged.
        updated = content.replace(original_line, new_line)
        if updated != content:
            with open(file, "w", encoding="utf-8") as fd:
                fd.write(updated)
|
||
|
||
if __name__ == "__main__":
    # Use sys.exit rather than the ``exit`` helper, which is injected by the
    # ``site`` module and is not guaranteed to exist (e.g. ``python -S``).
    sys.exit(main())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
""" | ||
Find all PyMC3 classes used in script. | ||
This will find both calls of the form | ||
pymc3.Categorical(... | ||
and | ||
from pymc3 import Categorical | ||
Categorical(... | ||
""" | ||
import ast | ||
import sys | ||
|
||
|
||
class ImportVisitor(ast.NodeVisitor):
    """Collect capitalised names imported with ``from pymc3... import Name``.

    After visiting a tree, ``imports`` holds every imported name whose first
    character is uppercase and whose source module's top-level package is
    ``pymc3``. The original name is recorded, not any ``as`` alias.
    """

    def __init__(self, file):
        # ``file`` is accepted but not used by this visitor.
        self.imports = set()

    def visit_ImportFrom(self, node: ast.ImportFrom):
        """Record capitalised names imported from ``pymc3`` or a submodule."""
        # ``module`` is None for purely relative imports (``from . import x``),
        # which cannot refer to pymc3.
        if node.module is None:
            return
        if node.module.split(".")[0] != "pymc3":
            return
        for alias in node.names:
            if alias.name[0].isupper():
                self.imports.add(alias.name)
|
||
|
||
class CallVisitor(ast.NodeVisitor):
    """Collect the PyMC3 classes that are called in a syntax tree.

    A call is recorded in ``classes_used`` when it is either

    * an attribute call on the name ``pm`` or ``pymc3`` whose attribute starts
      with an uppercase letter (``pm.Normal(...)``), or
    * a call to a bare name contained in ``imports`` (``Normal(...)``).

    Parameters
    ----------
    file
        Path of the file being analysed; stored on the instance.
    imports
        Collection of names to treat as PyMC3 classes when called directly.
    """

    def __init__(self, file, imports):
        self.file = file
        self.imports = imports
        self.classes_used = set()

    def visit_Call(self, node: ast.Call):
        """Record the call if it targets a PyMC3 class, then keep descending."""
        func = node.func
        if isinstance(func, ast.Attribute):
            if (
                isinstance(func.value, ast.Name)
                and func.value.id in {"pm", "pymc3"}
                and func.attr[0].isupper()
            ):
                self.classes_used.add(func.attr)
        elif isinstance(func, ast.Name) and func.id in self.imports:
            self.classes_used.add(func.id)
        # Defining visit_Call disables the default traversal for this node, so
        # recurse explicitly to catch calls nested in arguments or the callee,
        # e.g. pm.Normal("x", mu=pm.HalfNormal("s", 1)).
        self.generic_visit(node)
|
||
|
||
if __name__ == "__main__":
    # Print, one per line, the PyMC3 classes called in each file given on the
    # command line.
    for file in sys.argv[1:]:
        # Python source defaults to UTF-8; do not rely on the platform's
        # default text encoding.
        with open(file, encoding="utf-8") as fd:
            content = fd.read()
        tree = ast.parse(content)

        # First pass: names imported via ``from pymc3 import ...``.
        import_visitor = ImportVisitor(file)
        import_visitor.visit(tree)

        # Second pass: calls to those names or to ``pm.X`` / ``pymc3.X``.
        visitor = CallVisitor(file, import_visitor.imports)
        visitor.visit(tree)
        # Sort so the output does not depend on set iteration order.
        for name in sorted(visitor.classes_used):
            print(name)