Skip to content

Commit

Permalink
Automatically download models
Browse files Browse the repository at this point in the history
  • Loading branch information
movabo committed Dec 8, 2021
1 parent f546cc6 commit 181b970
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
6 changes: 5 additions & 1 deletion DeepNewsSentiment/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@
from DeepNewsSentiment.models.FXBaseModel import FXBaseModel

logger = get_logger()
nlp = spacy.load("en_core_web_sm")
try:
nlp = spacy.load("en_core_web_sm")
except OSError:
spacy.cli.download("en_core_web_sm")
nlp = spacy.load("en_core_web_sm")
# this is what we could run on spacy 2.2+
# nlp_dep_parser_labels = list(nlp.parser.labels)
# but since we're on spacy 2.1 (and need to be because of newsalyze-backend)
Expand Down
10 changes: 4 additions & 6 deletions DeepNewsSentiment/download.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
"""
Download a specific version of a finetuned model and place it in pretrained_models.
"""
import argparse
import os
import string
from contextlib import suppress

import torch

from DeepNewsSentiment.fxlogger import get_logger


class Download:
MODEL_DIRECTORY = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "pretrained_models", "state_dicts"
)

def __init__(
self, own_model_name, version="default", force=False, list_versions=False
Expand Down Expand Up @@ -76,7 +71,10 @@ def model_filename(model_cls, version=None):

@classmethod
def model_path(cls, model_cls, version=None):
return os.path.join(cls.MODEL_DIRECTORY, cls.model_filename(model_cls, version))
return os.path.join(torch.hub.get_dir(),
"pretrained_models",
"state_dicts",
cls.model_filename(model_cls, version))

@staticmethod
def add_subparser(subparser):
Expand Down
File renamed without changes.

0 comments on commit 181b970

Please sign in to comment.