Skip to content

Commit

Permalink
Fix up
Browse files Browse the repository at this point in the history
  • Loading branch information
amyeroberts committed Apr 11, 2024
1 parent 85ebdfc commit f52900c
Showing 1 changed file with 19 additions and 17 deletions.
36 changes: 19 additions & 17 deletions utils/deprecate_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@
import argparse
import os
from collections import defaultdict
from typing import Tuple, Optional
from pathlib import Path
from typing import Optional, Tuple

import requests
from git import Repo
from packaging import version

from transformers import CONFIG_MAPPING, logging
from transformers import __version__ as current_version
from transformers import logging, CONFIG_MAPPING


REPO_PATH = Path(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
repo = Repo(REPO_PATH)
Expand Down Expand Up @@ -84,19 +86,19 @@ def get_model_doc_path(model: str) -> Tuple[Optional[str], Optional[str]]:
model_doc_path = REPO_PATH / f"docs/source/en/model_doc/{model.replace('_', '-')}.md"

if os.path.exists(model_doc_path):
return model_doc_path, model.replace('_', '-')
return model_doc_path, model.replace("_", "-")

# Try replacing _ with "" in the model name
model_doc_path = REPO_PATH / f"docs/source/en/model_doc/{model.replace('_', '')}.md"

if os.path.exists(model_doc_path):
return model_doc_path, model.replace('_', '')
return model_doc_path, model.replace("_", "")

return None, None


def extract_model_info(model):
model_info = dict()
model_info = {}
model_doc_path, model_doc_name = get_model_doc_path(model)
model_path = REPO_PATH / f"src/transformers/models/{model}"

Expand Down Expand Up @@ -173,7 +175,6 @@ def get_line_indent(s):
maybe_else_block = []
in_else_block = False
in_base_imports = False
base_import_block = []
open_indent_level = -1

# We iterate over each line in the init file to create a new init file
Expand Down Expand Up @@ -290,7 +291,7 @@ def update_init_file(filename, models):
f.write(init_file)


def remove_model_references_from_file(filename, models, condition=None):
def remove_model_references_from_file(filename, models, condition):
"""
Remove all references to the given models from the given file
Expand All @@ -299,9 +300,6 @@ def remove_model_references_from_file(filename, models, condition=None):
models (List[str]): The models to remove
condition (Callable): A function that takes the line and model and returns True if the line should be removed
"""
if condition is None:
condition = lambda line, model: model == line.strip()

with open(filename, "r") as f:
init_file = f.read()

Expand Down Expand Up @@ -372,8 +370,8 @@ def deprecate_models(models):
for model, model_info in models_info.items():
if model in CONFIG_MAPPING:
model_config_classes.append(CONFIG_MAPPING[model].__name__)
elif model_info['model_doc_name'] in CONFIG_MAPPING:
model_config_classes.append(CONFIG_MAPPING[model_info['model_doc_name']].__name__)
elif model_info["model_doc_name"] in CONFIG_MAPPING:
model_config_classes.append(CONFIG_MAPPING[model_info["model_doc_name"]].__name__)
else:
skipped_models.append(model)
print(f"Model config class not found for model: {model}")
Expand All @@ -385,7 +383,7 @@ def deprecate_models(models):
print(f"Models to deprecate: {models}")

# Remove model config classes from config check
print(f"Removing model config classes from config checks")
print("Removing model config classes from config checks")
remove_model_config_classes_from_config_check("src/transformers/configuration_utils.py", model_config_classes)

tip_message = build_tip_message(get_last_stable_minor_release())
Expand All @@ -407,13 +405,17 @@ def deprecate_models(models):
# We do the following with all models passed at once to avoid having to re-write the file multiple times

# Update the __init__.py file to point to the deprecated model.
print(f"Updating __init__.py file to point to the deprecated models")
print("Updating __init__.py file to point to the deprecated models")
update_init_file("src/transformers/__init__.py", models)

# Remove model references from other files
print(f"Removing model references from other files")
remove_model_references_from_file("src/transformers/models/__init__.py", models, lambda line, model: model == line.strip().strip(","))
remove_model_references_from_file("utils/slow_documentation_tests.txt", models, lambda line, model: "/" + model + "/" in line)
print("Removing model references from other files")
remove_model_references_from_file(
"src/transformers/models/__init__.py", models, lambda line, model: model == line.strip().strip(",")
)
remove_model_references_from_file(
"utils/slow_documentation_tests.txt", models, lambda line, model: "/" + model + "/" in line
)


if __name__ == "__main__":
Expand Down

0 comments on commit f52900c

Please sign in to comment.