Skip to content

Commit

Permalink
Clean up unit tests (#1568)
Browse files Browse the repository at this point in the history
  • Loading branch information
lintool authored Jul 8, 2023
1 parent 748bf5b commit d57bf4b
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 57 deletions.
11 changes: 7 additions & 4 deletions tests/test_index_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,27 @@
#

import os
import random
import shutil
import unittest

from pyserini.search.lucene import LuceneSearcher


class TestIndexDownload(unittest.TestCase):
def setUp(self):
self.tmp_dir = f'tmp_{self.__class__.__name__}_{str(random.randint(0, 1000))}'

def test_default_cache(self):
LuceneSearcher.from_prebuilt_index('cacm')
self.assertTrue(os.path.exists(os.path.expanduser('~/.cache/pyserini/indexes')))

def test_custom_cache(self):
os.environ['PYSERINI_CACHE'] = 'temp_dir'
os.environ['PYSERINI_CACHE'] = self.tmp_dir
LuceneSearcher.from_prebuilt_index('cacm')
self.assertTrue(os.path.exists('temp_dir/indexes'))
self.assertTrue(os.path.exists(os.path.join(self.tmp_dir, 'indexes')))

def tearDown(self):
if os.path.exists('temp_dir'):
shutil.rmtree('temp_dir')
if os.path.exists(self.tmp_dir):
shutil.rmtree(self.tmp_dir)
os.environ['PYSERINI_CACHE'] = ''
17 changes: 10 additions & 7 deletions tests/test_index_faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,31 @@

import json
import os
import random

import faiss
import shutil
import unittest
import pathlib as pl


class TestSearch(unittest.TestCase):
class TestIndexFaiss(unittest.TestCase):
@staticmethod
def assertIsFile(path):
if not pl.Path(path).resolve().is_file():
raise AssertionError("File does not exist: %s" % str(path))

def setUp(self):
self.docids = []
self.texts = []
self.test_file = 'tests/resources/simple_cacm_corpus.json'
self.tmp_dir = "temp_dir"
self.tmp_dir = f'tmp_{self.__class__.__name__}_{str(random.randint(0, 1000))}'

with open(self.test_file) as f:
for line in f:
line = json.loads(line)
self.docids.append(line['id'])
self.texts.append(line['contents'])

def assertIsFile(self, path):
if not pl.Path(path).resolve().is_file():
raise AssertionError("File does not exist: %s" % str(path))

def prepare_encoded_collection(self):
encoded_corpus_dir = f'{self.tmp_dir}/temp_index'
Expand Down Expand Up @@ -116,4 +119,4 @@ def test_faiss_pq(self):
self.assertAlmostEqual(vectors[2][-1], 0.075478144, places=4)

def tearDown(self):
shutil.rmtree(self.tmp_dir)
shutil.rmtree(self.tmp_dir)
5 changes: 3 additions & 2 deletions tests/test_index_otf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,17 @@
import os
import shutil
import unittest
import random
from typing import List

from pyserini.index.lucene import LuceneIndexer, IndexReader, JacksonObjectMapper
from pyserini.search.lucene import JLuceneSearcherResult, LuceneSearcher


class TestSearch(unittest.TestCase):
class TestIndexOTF(unittest.TestCase):
def setUp(self):
self.docs = []
self.tmp_dir = "temp_dir"
self.tmp_dir = f'tmp_{self.__class__.__name__}_{str(random.randint(0, 1000))}'

# The current directory depends on if you're running inside an IDE or from command line.
curdir = os.getcwd()
Expand Down
59 changes: 26 additions & 33 deletions tests/test_load_qrels.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,20 @@
# limitations under the License.
#

import os
import shutil
import unittest

from pyserini import search


def read_file_lines(path):
with open(path) as f:
return f.readlines()
class TestLoadQrels(unittest.TestCase):
@staticmethod
def read_file_lines(path):
with open(path) as f:
return f.readlines()


class TestGetQrels(unittest.TestCase):

def setUp(self):
os.environ['PYSERINI_CACHE'] = 'temp_dir'
# Note that these test cases download and cache qrels in ~/.cache/anserini/topics-and-qrels,
# which is hard-coded from the Anserini end. So if the original source is unavailable, these
# tests will still pass.

def test_trec1_adhoc(self):
qrels = search.get_qrels('trec1-adhoc')
Expand All @@ -51,7 +49,7 @@ def test_trec3_adhoc(self):

def test_robust04a(self):
qrels_path = search.get_qrels_file('robust04')
lines = read_file_lines(qrels_path)
lines = self.read_file_lines(qrels_path)
length = len(lines)
first_line = lines[0].rstrip()
mid_line = lines[length//2].rstrip()
Expand All @@ -69,7 +67,7 @@ def test_robust04b(self):

def test_robust05a(self):
qrels_path = search.get_qrels_file('robust05')
lines = read_file_lines(qrels_path)
lines = self.read_file_lines(qrels_path)
length = len(lines)
first_line = lines[0].rstrip()
mid_line = lines[length // 2].rstrip()
Expand All @@ -87,7 +85,7 @@ def test_robust05b(self):

def test_core17a(self):
qrels_path = search.get_qrels_file('core17')
lines = read_file_lines(qrels_path)
lines = self.read_file_lines(qrels_path)
length = len(lines)
first_line = lines[0].rstrip()
mid_line = lines[length // 2].rstrip()
Expand All @@ -105,7 +103,7 @@ def test_core17b(self):

def test_core18a(self):
qrels_path = search.get_qrels_file('core18')
lines = read_file_lines(qrels_path)
lines = self.read_file_lines(qrels_path)
length = len(lines)
first_line = lines[0].rstrip()
mid_line = lines[length // 2].rstrip()
Expand Down Expand Up @@ -195,7 +193,7 @@ def test_mb14(self):

def test_car15a(self):
qrels_path = search.get_qrels_file('car17v1.5-benchmarkY1test')
lines = read_file_lines(qrels_path)
lines = self.read_file_lines(qrels_path)
length = len(lines)
first_line = lines[0].rstrip()
mid_line = lines[length // 2].rstrip()
Expand All @@ -216,7 +214,7 @@ def test_car15b(self):

def test_car20a(self):
qrels_path = search.get_qrels_file('car17v2.0-benchmarkY1test')
lines = read_file_lines(qrels_path)
lines = self.read_file_lines(qrels_path)
length = len(lines)
first_line = lines[0].rstrip()
mid_line = lines[length // 2].rstrip()
Expand All @@ -236,7 +234,7 @@ def test_car20b(self):

def test_msmarco_doc1(self):
qrels_path = search.get_qrels_file('msmarco-doc-dev')
lines = read_file_lines(qrels_path)
lines = self.read_file_lines(qrels_path)
length = len(lines)
first_line = lines[0].rstrip()
mid_line = lines[length // 2].rstrip()
Expand All @@ -254,7 +252,7 @@ def test_msmarco_doc2(self):

def test_msmarco_passage1(self):
qrels_path = search.get_qrels_file('msmarco-passage-dev-subset')
lines = read_file_lines(qrels_path)
lines = self.read_file_lines(qrels_path)
length = len(lines)
first_line = lines[0].rstrip()
mid_line = lines[length // 2].rstrip()
Expand Down Expand Up @@ -398,7 +396,7 @@ def test_fire2012_en(self):

def test_covid_round1(self):
qrels_path = search.get_qrels_file('covid-round1')
lines = read_file_lines(qrels_path)
lines = self.read_file_lines(qrels_path)
length = len(lines)
first_line = lines[0].rstrip()
mid_line = lines[length // 2].rstrip()
Expand All @@ -415,7 +413,7 @@ def test_covid_round1(self):

def test_covid_round2(self):
qrels_path = search.get_qrels_file('covid-round2')
lines = read_file_lines(qrels_path)
lines = self.read_file_lines(qrels_path)
length = len(lines)
first_line = lines[0].rstrip()
mid_line = lines[length // 2].rstrip()
Expand All @@ -432,7 +430,7 @@ def test_covid_round2(self):

def test_covid_round3(self):
qrels_path = search.get_qrels_file('covid-round3')
lines = read_file_lines(qrels_path)
lines = self.read_file_lines(qrels_path)
length = len(lines)
first_line = lines[0].rstrip()
mid_line = lines[length // 2].rstrip()
Expand All @@ -449,7 +447,7 @@ def test_covid_round3(self):

def test_covid_round4(self):
qrels_path = search.get_qrels_file('covid-round4')
lines = read_file_lines(qrels_path)
lines = self.read_file_lines(qrels_path)
length = len(lines)
first_line = lines[0].rstrip()
mid_line = lines[length // 2].rstrip()
Expand All @@ -466,7 +464,7 @@ def test_covid_round4(self):

def test_covid_round5(self):
qrels_path = search.get_qrels_file('covid-round5')
lines = read_file_lines(qrels_path)
lines = self.read_file_lines(qrels_path)
length = len(lines)
first_line = lines[0].rstrip()
mid_line = lines[length // 2].rstrip()
Expand All @@ -483,7 +481,7 @@ def test_covid_round5(self):

def test_covid_round3_cumulative(self):
qrels_path = search.get_qrels_file('covid-round3-cumulative')
lines = read_file_lines(qrels_path)
lines = self.read_file_lines(qrels_path)
length = len(lines)
first_line = lines[0].rstrip()
mid_line = lines[length // 2].rstrip()
Expand All @@ -495,7 +493,7 @@ def test_covid_round3_cumulative(self):

def test_covid_round4_cumulative(self):
qrels_path = search.get_qrels_file('covid-round4-cumulative')
lines = read_file_lines(qrels_path)
lines = self.read_file_lines(qrels_path)
length = len(lines)
first_line = lines[0].rstrip()
mid_line = lines[length // 2].rstrip()
Expand All @@ -507,7 +505,7 @@ def test_covid_round4_cumulative(self):

def test_covid_complete(self):
qrels_path = search.get_qrels_file('covid-complete')
lines = read_file_lines(qrels_path)
lines = self.read_file_lines(qrels_path)
length = len(lines)
first_line = lines[0].rstrip()
mid_line = lines[length // 2].rstrip()
Expand All @@ -519,7 +517,7 @@ def test_covid_complete(self):

def test_trec2018_bl(self):
qrels_path = search.get_qrels_file('trec2018-bl')
lines = read_file_lines(qrels_path)
lines = self.read_file_lines(qrels_path)
length = len(lines)
first_line = lines[0].rstrip()
mid_line = lines[length // 2].rstrip()
Expand All @@ -536,7 +534,7 @@ def test_trec2018_bl(self):

def test_trec2019_bl(self):
qrels_path = search.get_qrels_file('trec2019-bl')
lines = read_file_lines(qrels_path)
lines = self.read_file_lines(qrels_path)
length = len(lines)
first_line = lines[0].rstrip()
mid_line = lines[length // 2].rstrip()
Expand Down Expand Up @@ -1019,11 +1017,6 @@ def test_hc4_neuclir22(self):
self.assertEqual(len(qrels), 60)
self.assertTrue(isinstance(next(iter(qrels.keys())), int))

def tearDown(self):
if os.path.exists('temp_dir'):
shutil.rmtree('temp_dir')
os.environ['PYSERINI_CACHE'] = ''


if __name__ == '__main__':
unittest.main()
30 changes: 19 additions & 11 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,28 @@


class TestSearch(unittest.TestCase):
tarball_name = None
collection_url = None
searcher = None
searcher_index_dir = None
no_vec_searcher = None
no_vec_searcher_index_dir = None

@classmethod
def setUpClass(cls):
# Download pre-built CACM index built using Lucene 9; append a random value to avoid filename clashes.
r = randint(0, 10000000)
cls.collection_url = 'https://github.com/castorini/anserini-data/raw/master/CACM/lucene9-index.cacm.tar.gz'
cls.tarball_name = 'lucene-index.cacm-{}.tar.gz'.format(r)
cls.index_dir = 'index{}/'.format(r)
cls.searcher_index_dir = 'index{}/'.format(r)

filename, headers = urlretrieve(cls.collection_url, cls.tarball_name)
urlretrieve(cls.collection_url, cls.tarball_name)

tarball = tarfile.open(cls.tarball_name)
tarball.extractall(cls.index_dir)
tarball.extractall(cls.searcher_index_dir)
tarball.close()

cls.searcher = LuceneSearcher(f'{cls.index_dir}lucene9-index.cacm')
cls.searcher = LuceneSearcher(f'{cls.searcher_index_dir}lucene9-index.cacm')

# Create index without document vectors
# The current directory depends on if you're running inside an IDE or from command line.
Expand All @@ -50,12 +57,13 @@ def setUpClass(cls):
corpus_path = '../tests/resources/sample_collection_json'
else:
corpus_path = 'tests/resources/sample_collection_json'
cls.no_vec_index_dir = 'no_vec_index'

cls.no_vec_searcher_index_dir = 'no_vec_index'
cmd1 = f'python -m pyserini.index.lucene -collection JsonCollection ' + \
f'-generator DefaultLuceneDocumentGenerator ' + \
f'-threads 1 -input {corpus_path} -index {cls.no_vec_index_dir}'
f'-threads 1 -input {corpus_path} -index {cls.no_vec_searcher_index_dir}'
os.system(cmd1)
cls.no_vec_searcher = LuceneSearcher(cls.no_vec_index_dir)
cls.no_vec_searcher = LuceneSearcher(cls.no_vec_searcher_index_dir)

def test_basic(self):
self.assertTrue(self.searcher.get_similarity().toString().startswith('BM25'))
Expand Down Expand Up @@ -231,7 +239,7 @@ def test_different_similarity(self):
self.assertAlmostEqual(hits[9].score, 4.33320, places=5)

def test_rm3(self):
self.searcher = LuceneSearcher(f'{self.index_dir}lucene9-index.cacm')
self.searcher = LuceneSearcher(f'{self.searcher_index_dir}lucene9-index.cacm')
self.searcher.set_rm3()
self.assertTrue(self.searcher.is_using_rm3())

Expand Down Expand Up @@ -271,7 +279,7 @@ def test_rm3(self):
self.no_vec_searcher.set_rm3()

def test_rocchio(self):
self.searcher = LuceneSearcher(f'{self.index_dir}lucene9-index.cacm')
self.searcher = LuceneSearcher(f'{self.searcher_index_dir}lucene9-index.cacm')
self.searcher.set_rocchio()
self.assertTrue(self.searcher.is_using_rocchio())

Expand Down Expand Up @@ -406,8 +414,8 @@ def tearDownClass(cls):
cls.searcher.close()
cls.no_vec_searcher.close()
os.remove(cls.tarball_name)
shutil.rmtree(cls.index_dir)
shutil.rmtree(cls.no_vec_index_dir)
shutil.rmtree(cls.searcher_index_dir)
shutil.rmtree(cls.no_vec_searcher_index_dir)


if __name__ == '__main__':
Expand Down
5 changes: 5 additions & 0 deletions tests/test_search_lucene8.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@


class TestSearch(unittest.TestCase):
searcher = None
index_dir = None
collection_url = None
tarball_name = None

@classmethod
def setUpClass(cls):
# Download pre-built CACM index; append a random value to avoid filename clashes.
Expand Down

0 comments on commit d57bf4b

Please sign in to comment.