[testing] a new TestCasePlus subclass + get_auto_remove_tmp_dir() (huggingface#6494)

* [testing] switch to a new TestCasePlus + get_auto_remove_tmp_dir() for auto-removal of tmp dirs

* respect after=True for tempfile, simplify code

* comments

* comment fix

* put `before` last in args, so debugging can be made even faster
stas00 authored Aug 17, 2020
1 parent 36010cb commit 9dbe409
Showing 3 changed files with 124 additions and 49 deletions.
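
In short: the example tests stop writing to a hand-rolled ./tests/fixtures/tests_samples/temp_dir_* directory and cleaning it up with a module-level clean_test_dir() helper, and instead subclass the new TestCasePlus, which hands out temp dirs and removes them in tearDown(). A minimal before/after sketch of the pattern (the class-name suffixes, the placeholder directory name, and the elided script invocation are illustrative; the real diffs follow):

import shutil
import unittest

from transformers.testing_utils import TestCasePlus


# Before: each test module hand-rolled an output dir and a cleanup helper
def clean_test_dir(path):
    shutil.rmtree(path, ignore_errors=True)


class ExamplesTestsBefore(unittest.TestCase):
    def test_run_glue(self):
        # the real tests derived this name from hash(testargs)
        output_dir = "./tests/fixtures/tests_samples/temp_dir_example"
        # ... run the example script with --output_dir pointing at output_dir ...
        clean_test_dir(output_dir)


# After: TestCasePlus hands out a temp dir and removes it in tearDown()
class ExamplesTestsAfter(TestCasePlus):
    def test_run_glue(self):
        tmp_dir = self.get_auto_remove_tmp_dir()
        # ... run the example script with --output_dir pointing at tmp_dir ...
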
23 changes: 8 additions & 15 deletions examples/bert-loses-patience/test_run_glue_with_pabee.py
@@ -1,11 +1,10 @@
import argparse
import logging
import shutil
import sys
import unittest
from unittest.mock import patch

import run_glue_with_pabee
from transformers.testing_utils import TestCasePlus


logging.basicConfig(level=logging.DEBUG)
@@ -20,20 +19,19 @@ def get_setup_file():
return args.f


def clean_test_dir(path):
shutil.rmtree(path, ignore_errors=True)


class PabeeTests(unittest.TestCase):
class PabeeTests(TestCasePlus):
def test_run_glue(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)

testargs = """
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_glue_with_pabee.py
--model_type albert
--model_name_or_path albert-base-v2
--data_dir ./tests/fixtures/tests_samples/MRPC/
--output_dir {tmp_dir}
--overwrite_output_dir
--task_name mrpc
--do_train
--do_eval
@@ -42,16 +40,11 @@ def test_run_glue(self):
--learning_rate=2e-5
--max_steps=50
--warmup_steps=2
--overwrite_output_dir
--seed=42
--max_seq_length=128
"""
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
testargs += "--output_dir " + output_dir
testargs = testargs.split()
""".split()

with patch.object(sys, "argv", testargs):
result = run_glue_with_pabee.main()
for value in result.values():
self.assertGreaterEqual(value, 0.75)

clean_test_dir(output_dir)
58 changes: 24 additions & 34 deletions examples/test_examples.py
@@ -17,13 +17,13 @@
import argparse
import logging
import os
import shutil
import sys
import unittest
from unittest.mock import patch

import torch

from transformers.testing_utils import TestCasePlus


SRC_DIRS = [
os.path.join(os.path.dirname(__file__), dirname)
@@ -52,19 +52,18 @@ def get_setup_file():
return args.f


def clean_test_dir(path):
shutil.rmtree(path, ignore_errors=True)


class ExamplesTests(unittest.TestCase):
class ExamplesTests(TestCasePlus):
def test_run_glue(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)

testargs = """
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_glue.py
--model_name_or_path distilbert-base-uncased
--data_dir ./tests/fixtures/tests_samples/MRPC/
--output_dir {tmp_dir}
--overwrite_output_dir
--task_name mrpc
--do_train
--do_eval
@@ -73,28 +72,26 @@ def test_run_glue(self):
--learning_rate=1e-4
--max_steps=10
--warmup_steps=2
--overwrite_output_dir
--seed=42
--max_seq_length=128
"""
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
testargs += "--output_dir " + output_dir
testargs = testargs.split()
""".split()

with patch.object(sys, "argv", testargs):
result = run_glue.main()
del result["eval_loss"]
for value in result.values():
self.assertGreaterEqual(value, 0.75)
clean_test_dir(output_dir)

def test_run_pl_glue(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)

testargs = """
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_pl_glue.py
--model_name_or_path bert-base-cased
--data_dir ./tests/fixtures/tests_samples/MRPC/
--output_dir {tmp_dir}
--task mrpc
--do_train
--do_predict
@@ -103,11 +100,7 @@ def test_run_pl_glue(self):
--num_train_epochs=1
--seed=42
--max_seq_length=128
"""
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
testargs += "--output_dir " + output_dir
testargs = testargs.split()

""".split()
if torch.cuda.is_available():
testargs += ["--fp16", "--gpus=1"]

@@ -123,43 +116,44 @@ def test_run_pl_glue(self):
# for k, v in result.items():
# self.assertGreaterEqual(v, 0.75, f"({k})")
#
clean_test_dir(output_dir)

def test_run_language_modeling(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)

testargs = """
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_language_modeling.py
--model_name_or_path distilroberta-base
--model_type roberta
--mlm
--line_by_line
--train_data_file ./tests/fixtures/sample_text.txt
--eval_data_file ./tests/fixtures/sample_text.txt
--output_dir {tmp_dir}
--overwrite_output_dir
--do_train
--do_eval
--num_train_epochs=1
--no_cuda
"""
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
testargs += "--output_dir " + output_dir
testargs = testargs.split()
""".split()

with patch.object(sys, "argv", testargs):
result = run_language_modeling.main()
self.assertLess(result["perplexity"], 35)
clean_test_dir(output_dir)

def test_run_squad(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)

testargs = """
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_squad.py
--model_type=distilbert
--model_name_or_path=sshleifer/tiny-distilbert-base-cased-distilled-squad
--data_dir=./tests/fixtures/tests_samples/SQUAD
--output_dir {tmp_dir}
--overwrite_output_dir
--max_steps=10
--warmup_steps=2
--do_train
@@ -168,17 +162,13 @@ def test_run_squad(self):
--learning_rate=2e-4
--per_gpu_train_batch_size=2
--per_gpu_eval_batch_size=1
--overwrite_output_dir
--seed=42
"""
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
testargs += "--output_dir " + output_dir
testargs = testargs.split()
""".split()

with patch.object(sys, "argv", testargs):
result = run_squad.main()
self.assertGreaterEqual(result["f1"], 25)
self.assertGreaterEqual(result["exact"], 21)
clean_test_dir(output_dir)

def test_generation(self):
stream_handler = logging.StreamHandler(sys.stdout)
92 changes: 92 additions & 0 deletions src/transformers/testing_utils.py
@@ -1,9 +1,12 @@
import os
import re
import shutil
import sys
import tempfile
import unittest
from distutils.util import strtobool
from io import StringIO
from pathlib import Path

from .file_utils import _tf_available, _torch_available, _torch_tpu_available

@@ -255,3 +258,92 @@ class CaptureStderr(CaptureStd):

def __init__(self):
super().__init__(out=False)


class TestCasePlus(unittest.TestCase):
"""This class extends `unittest.TestCase` with additional features.
Feature 1: Flexible auto-removable temp dirs which are guaranteed to get
removed at the end of test.
In all the following scenarios the temp dir will be auto-removed at the end
of test, unless `after=False`.
# 1. create a unique temp dir, `tmp_dir` will contain the path to the created temp dir
def test_whatever(self):
tmp_dir = self.get_auto_remove_tmp_dir()
# 2. create a temp dir of my choice and delete it at the end - useful for debug when you want to
# monitor a specific directory
def test_whatever(self):
tmp_dir = self.get_auto_remove_tmp_dir(tmp_dir="./tmp/run/test")
# 3. create a temp dir of my choice and do not delete it at the end - useful for when you want
# to look at the temp results
def test_whatever(self):
tmp_dir = self.get_auto_remove_tmp_dir(tmp_dir="./tmp/run/test", after=False)
# 4. create a temp dir of my choice and ensure to delete it right away - useful for when you
# disabled deletion in the previous test run and want to make sure that the tmp dir is empty
# before the new test is run
def test_whatever(self):
tmp_dir = self.get_auto_remove_tmp_dir(tmp_dir="./tmp/run/test", before=True)
Note 1: In order to run the equivalent of `rm -r` safely, only subdirs of the
project repository checkout are allowed if an explicit `tmp_dir` is used, so
that no `/tmp` or similar important part of the filesystem gets nuked by
mistake, i.e. please always pass paths that start with `./`
Note 2: Each test can register multiple temp dirs and they all will get
auto-removed, unless requested otherwise.
"""

def setUp(self):
self.teardown_tmp_dirs = []

def get_auto_remove_tmp_dir(self, tmp_dir=None, after=True, before=False):
"""
Args:
tmp_dir (:obj:`string`, `optional`, defaults to :obj:`None`):
use this path, if None a unique path will be assigned
before (:obj:`bool`, `optional`, defaults to :obj:`False`):
if `True` and tmp dir already exists make sure to empty it right away
after (:obj:`bool`, `optional`, defaults to :obj:`True`):
delete the tmp dir at the end of the test
Returns:
tmp_dir(:obj:`string`):
either the same value as passed via `tmp_dir` or the path to the auto-created tmp dir
"""
if tmp_dir is not None:
# using provided path
path = Path(tmp_dir).resolve()

# to avoid nuking parts of the filesystem, only relative paths are allowed
if not tmp_dir.startswith("./"):
raise ValueError(
f"`tmp_dir` can only be a relative path, i.e. `./some/path`, but received `{tmp_dir}`"
)

# ensure the dir is empty to start with
if before is True and path.exists():
shutil.rmtree(tmp_dir, ignore_errors=True)

path.mkdir(parents=True, exist_ok=True)

else:
# using unique tmp dir (always empty, regardless of `before`)
tmp_dir = tempfile.mkdtemp()

if after is True:
# register for deletion
self.teardown_tmp_dirs.append(tmp_dir)

return tmp_dir

def tearDown(self):
# remove registered temp dirs
for path in self.teardown_tmp_dirs:
shutil.rmtree(path, ignore_errors=True)
self.teardown_tmp_dirs = []
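
For orientation, a small usage sketch of the helper defined above, mirroring the scenarios from its docstring (the class name, test name, and ./tmp/... paths are illustrative, not part of the commit):

import os

from transformers.testing_utils import TestCasePlus


class TmpDirUsageExample(TestCasePlus):
    def test_tmp_dir_scenarios(self):
        # 1. unique temp dir, auto-removed in tearDown()
        tmp_dir = self.get_auto_remove_tmp_dir()

        # 2. explicit dir (must start with "./"): emptied now via before=True,
        #    and still auto-removed after the test
        debug_dir = self.get_auto_remove_tmp_dir(tmp_dir="./tmp/run/test", before=True)

        # 3. explicit dir that is kept after the test, so the results can be inspected
        keep_dir = self.get_auto_remove_tmp_dir(tmp_dir="./tmp/run/keep", after=False)

        # all three directories exist while the test runs; only the first two
        # are removed when tearDown() runs
        for d in (tmp_dir, debug_dir, keep_dir):
            self.assertTrue(os.path.isdir(d))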
