diff --git a/scripts/autogen.py b/scripts/autogen.py index e12f8dea38..fad203bff5 100644 --- a/scripts/autogen.py +++ b/scripts/autogen.py @@ -32,7 +32,7 @@ GUIDES_GH_LOCATION = Path("keras-team") / "keras-io" / "blob" / "master" / "guides" KERAS_TEAM_GH = "https://github.com/keras-team" PROJECT_URL = { - "keras": f"{KERAS_TEAM_GH}/keras/tree/v3.6.0/", + "keras": f"{KERAS_TEAM_GH}/keras/tree/v3.7.0/", "keras_tuner": f"{KERAS_TEAM_GH}/keras-tuner/tree/v1.4.7/", "keras_hub": f"{KERAS_TEAM_GH}/keras-hub/tree/v0.17.0/", "tf_keras": f"{KERAS_TEAM_GH}/tf-keras/tree/v2.18.0/", diff --git a/scripts/render_presets.py b/scripts/render_presets.py index 30b4697d74..b6c599ad80 100644 --- a/scripts/render_presets.py +++ b/scripts/render_presets.py @@ -11,6 +11,8 @@ } """ +from hub_master import MODELS_MASTER + try: import keras_hub except Exception as e: @@ -46,10 +48,18 @@ def format_param_count(metadata): def format_path(metadata): """Returns Path for the given preset""" - try: - return f"[{metadata['official_name']}]({metadata['path']})" - except KeyError: - return "Unknown" + for child in MODELS_MASTER["children"]: + path = child["path"].strip("/") + if metadata["path"] == path: + text = child["title"] + link = f"/keras_hub/api/models/{path}" + return f"[{text}]({link})" + return "-" + + +def format_preset_link(preset, handle): + url = handle.replace("kaggle://", "https://www.kaggle.com/models/") + return f"[{preset}]({url})" def is_base_class(symbol): @@ -61,35 +71,38 @@ def is_base_class(symbol): ) -def render_all_presets(symbols): - """Renders the markdown table for backbone presets as a string.""" +def sort_presets(presets): + # Sort by path and then by parameter count. + return sorted( + presets.keys(), + key=lambda x: ( + presets[x]["metadata"]["path"], + presets[x]["metadata"]["params"], + ) + ) + + +def render_row(preset, data, add_doc_link=False): + """Renders a row for a preset in a markdown table.""" + metadata = data["metadata"] + url = data["kaggle_handle"] + url = url.replace("kaggle://", "https://www.kaggle.com/models/") + cols = [] + cols.append(format_preset_link(preset, data["kaggle_handle"])) + if add_doc_link: + cols.append(format_path(metadata)) + cols.append(format_param_count(metadata)) + cols.append(metadata["description"]) + return " | ".join(cols) + "\n" - table = TABLE_HEADER - # Backbones has alias, which duplicates some presets. - # Use a set to keep them unique. - added_presets = set() - # Bakcbone presets - for name, symbol in symbols: - if is_base_class(symbol) or "Backbone" not in name: - continue - presets = symbol.presets - # Only keep the ones with pretrained weights for KerasCV Backbones. - for preset in presets: - if preset in added_presets: - continue - else: - added_presets.add(preset) - metadata = presets[preset]["metadata"] - url = presets[preset]["kaggle_handle"] - url = url.replace("kaggle://", "https://www.kaggle.com/models/") - table += ( - f"[{preset}]({url}) | " - f"{format_path(metadata)} | " - f"{format_param_count(metadata)} | " - f"{metadata['description']}" - ) - table += "\n" +def render_all_presets(): + """Renders the markdown table for backbone presets as a string.""" + table = TABLE_HEADER + symbol = keras_hub.models.Backbone + for preset in sort_presets(symbol.presets): + data = symbol.presets[preset] + table += render_row(preset, data, add_doc_link=True) return table @@ -100,15 +113,9 @@ def render_table(symbol): table = TABLE_HEADER_PER_MODEL if is_base_class(symbol) or len(symbol.presets) == 0: return None - for preset in symbol.presets: - metadata = symbol.presets[preset]["metadata"] - url = symbol.presets[preset]["kaggle_handle"] - url = url.replace("kaggle://", "https://www.kaggle.com/models/") - table += ( - f"[{preset}]({url}) | " - f"{format_param_count(metadata)} | " - f"{metadata['description']} \n" - ) + for preset in sort_presets(symbol.presets): + data = symbol.presets[preset] + table += render_row(preset, data) return table @@ -117,9 +124,6 @@ def render_tags(template): if keras_hub is None: return template - symbols = keras_hub.models.__dict__.items() if "{{presets_table}}" in template: - template = template.replace( - "{{presets_table}}", render_all_presets(symbols) - ) + template = template.replace("{{presets_table}}", render_all_presets()) return template