Skip to content

Commit

Permalink
New support for BSTI, Removed COVID images without visual cues
Browse files Browse the repository at this point in the history
  • Loading branch information
SuryaKalia committed Apr 22, 2020
1 parent 78b04ef commit 3f940db
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 13 deletions.
4 changes: 2 additions & 2 deletions GETTING_STARTED.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

We randomly select a subset of patients for `test` and `val` sets.
```
python data_tools/prepare_covid_data.py
python data_tools/prepare_covid_data.py
```
Modify the file and rerun to update the train-val-test data split.

Expand Down Expand Up @@ -38,7 +38,7 @@ python tools/transfer.py [--combine_pneumonia]
Next we take the best model from previous step (according to loss), and fine tune the full model. Since we are interested in increasing the recall of `COVID-19`, we specify the `inc_recall` option to `3` (see our paper [paper](http://arxiv.org/abs/2004.09803) for details).
```
python tools/trainer.py --mode train --checkpoint <PATH_TO_BEST_MOMDEL> --bs 8 --save <PATH_TO_SAVE_MODELS_FOLDER> [--combine_pneumonia]
python tools/trainer.py --mode train --checkpoint <PATH_TO_BEST_MODEL> --bs 8 --save <PATH_TO_SAVE_MODELS_FOLDER> [--combine_pneumonia]
```
## Evaluation
Expand Down
61 changes: 53 additions & 8 deletions data_tools/prepare_covid_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,25 @@
import numpy as np
import pandas as pd
from collections import Counter
import argparse
np.random.seed(0)


# Dataset path
COVID_DATA_PATH='./covid-chestxray-dataset'
BSTI_DATA_PATH = "./BSTI-dataset"
METADATA_CSV = os.path.join(COVID_DATA_PATH, 'metadata.csv')
TRAIN_FILE = './data/covid19/train_list.txt'
VAL_FILE = './data/covid19/val_list.txt'
TEST_FILE = './data/covid19/test_list.txt'
BSTI_TRAIN_FILE = './data/covid19/bsti_train_list.txt'
BSTI_VAL_FILE = './data/covid19/bsti_val_list.txt'
BSTI_TEST_FILE = './data/covid19/bsti_test_list.txt'
REMOVED_LIST = './data/covid19/removed_files.txt'

parser = argparse.ArgumentParser()
parser.add_argument("--bsti", action='store_true', default=False)
args = parser.parse_args()

# Load patient stats
covids = dict()
Expand All @@ -30,9 +42,17 @@
print ("Total Images:", total_data, '\n')

# Assign train-val-test split
test_patients = set({4, 15, 86, 59, 6, 82, 80, 78, 76, 65, 36, 32, 50, 18, 115, 152, 138, 70, 116})
val_patients = set({73, 51, 48, 11, 43, 24, 112})
test_patients = set({4, 15, 86, 59, 6, 82, 80, 78, 76, 65, 36, 32, 50, 18, 115, 152, 138, 70, 116, 121, 133, 136, 139, 144, 154, 161, 163, 165})
val_patients = set({73, 51, 48, 11, 43, 24, 112, 181})

removed_files = set()

with open(REMOVED_LIST, 'r') as removed_list:
for filename in removed_list:
filename = filename.rstrip()
removed_files.add(filename)

#Initial values for covid-chestxray-dataset prior to removal
print ('#Train patients:', len(set(covids.keys()).difference(test_patients.union(val_patients))))
print ('#Test patients:', len(test_patients))
print ('#Val patients:', len(val_patients))
Expand All @@ -48,6 +68,8 @@

for i, row in df.iterrows():
patient_id = row['patientid']
if row['filename'] in removed_files:
continue
filename = os.path.join(row['folder'], row['filename'])

if int(patient_id) in test_patients:
Expand All @@ -57,18 +79,41 @@
else:
train_list.append(filename)

print (len(train_list), len(test_list), len(val_list))
print ("covid-chestxray-dataset train-test-val split: ",len(train_list), len(test_list), len(val_list))

# Write image list in file
def make_img_list(data_file, img_file_list):
def make_img_list(data_file, img_file_list, data_path):
with open(data_file, 'w') as f:
for imgfile in img_file_list:
try:
assert os.path.isfile(os.path.join(COVID_DATA_PATH, imgfile))
assert os.path.isfile(os.path.join(data_path, imgfile))
f.write("%s\n" % imgfile)
except:
print ("Image %s NOT FOUND" % imgfile)

make_img_list(TRAIN_FILE, train_list)
make_img_list(VAL_FILE, val_list)
make_img_list(TEST_FILE, test_list)
make_img_list(TRAIN_FILE, train_list, COVID_DATA_PATH)
make_img_list(VAL_FILE, val_list, COVID_DATA_PATH)
make_img_list(TEST_FILE, test_list, COVID_DATA_PATH)

#Include BSTI Dataset
if args.bsti :
# Construct the split lists
bsti_train_list = []
bsti_test_list = []
bsti_val_list = []

bsti_images = [f for f in os.listdir(BSTI_DATA_PATH) if os.path.isfile(os.path.join(BSTI_DATA_PATH, f))]
for imgfile in bsti_images:
rand_val = np.random.rand(1)
if rand_val < 0.1:
bsti_val_list.append(imgfile)
elif rand_val < 0.3:
bsti_test_list.append(imgfile)
else:
bsti_train_list.append(imgfile)

print("BSTI train-test-val split: ",len(bsti_train_list), len(bsti_test_list), len(bsti_val_list))

make_img_list(BSTI_TRAIN_FILE, bsti_train_list, BSTI_DATA_PATH)
make_img_list(BSTI_VAL_FILE, bsti_val_list, BSTI_DATA_PATH)
make_img_list(BSTI_TEST_FILE, bsti_test_list, BSTI_DATA_PATH)
21 changes: 18 additions & 3 deletions data_tools/prepare_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,20 @@

parser = argparse.ArgumentParser()
parser.add_argument("--combine_pneumonia", action='store_true', default=False)
parser.add_argument("--bsti", action='store_true', default=False)
args = parser.parse_args()

COVID19_DATA_PATH = "./data/covid19"
COVID19_IMGS_PATH = "./covid-chestxray-dataset"
BSTI_IMGS_PATH = "./BSTI-dataset"
PNEUMONIDA_DATA_PATH = "./chest-xray-pneumonia"
DATA_PATH = "./data"

# Assert that the data directories are present
for d in [COVID19_DATA_PATH, COVID19_IMGS_PATH, PNEUMONIDA_DATA_PATH, DATA_PATH]:
check_list = [COVID19_DATA_PATH, COVID19_IMGS_PATH, PNEUMONIDA_DATA_PATH, DATA_PATH]
if args.bsti:
check_list.append(BSTI_IMGS_PATH)
for d in check_list:
try:
assert os.path.isdir(d)
except:
Expand Down Expand Up @@ -51,11 +56,21 @@ def create_list (split):
l.append((f, 2)) # Class 2
else:
l.append((f, 3)) # Class 3

# Prepare list using BSTI covid dataset
if args.bsti:
bsti_covid_file = os.path.join(COVID19_DATA_PATH, 'bsti_%s_list.txt'%split)
with open(bsti_covid_file, 'r') as cf:
for f in cf.readlines():
f = os.path.join(BSTI_IMGS_PATH, f.strip())
if args.combine_pneumonia:
l.append((f, 2)) # Class 2
else:
l.append((f, 3)) # Class 3

with open(os.path.join(DATA_PATH, '%s.txt'%split), 'w') as f:
for item in l:
f.write("%s %d\n" % item)

for split in ['train', 'test', 'val']:
create_list(split)

create_list(split)

0 comments on commit 3f940db

Please sign in to comment.