Commit

add comments
jeykigung committed Dec 6, 2022
1 parent 48ae261 commit 596aa11
Showing 4 changed files with 15 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/modeling_p5.py
@@ -19,6 +19,7 @@

logger = logging.get_logger(__name__)

# The encoder for the input token sequence
class JointEncoder(T5Stack):
def __init__(self, config, embed_tokens=None):
super(T5Stack, self).__init__(config)
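For reference on the new comment: JointEncoder is the stack that encodes the input token sequence. A minimal sketch of the same idea using the stock Hugging Face T5EncoderModel (not P5's JointEncoder; 't5-small' is only a placeholder checkpoint):

from transformers import T5Tokenizer, T5EncoderModel

# encode a prompt into per-token hidden states, which is the role the comment above describes
tokenizer = T5Tokenizer.from_pretrained('t5-small')
encoder = T5EncoderModel.from_pretrained('t5-small')
inputs = tokenizer('What star rating will user_123 give item_456 ?', return_tensors='pt')
hidden_states = encoder(**inputs).last_hidden_state  # shape: (batch, seq_len, d_model)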
5 changes: 5 additions & 0 deletions src/pretrain.py
@@ -34,6 +34,7 @@

from trainer_base import TrainerBase

# Trainer inherits from TrainerBase, defined in trainer_base.py
class Trainer(TrainerBase):
def __init__(self, args, train_loader=None, val_loader=None, test_loader=None, train=True):
super().__init__(
@@ -339,6 +340,7 @@ def main_worker(gpu, args):
dist.init_process_group(backend='nccl')

print(f'Building train loader at GPU {gpu}')
# define the prompts used in training
if args.train == 'yelp':
train_task_list = {'rating': ['1-1', '1-2', '1-3', '1-4', '1-5', '1-6', '1-7', '1-8', '1-9'],
'sequential': ['2-1', '2-2', '2-3', '2-4', '2-5', '2-6', '2-7', '2-8', '2-9', '2-10', '2-11', '2-12'],
@@ -353,6 +355,8 @@ def main_worker(gpu, args):
'review': ['4-1', '4-2', '4-3'],
'traditional': ['5-1', '5-2', '5-3', '5-4', '5-5', '5-6', '5-7']
}
# define the sampling number for each group of personalized prompts (see pretrain_data.py)
# if greater than 1, a data sample is reused multiple times with different prompts within that task family
train_sample_numbers = {'rating': 1, 'sequential': (5, 5, 10), 'explanation': 1, 'review': 1, 'traditional': (10, 5)}
train_loader = get_loader(
args,
@@ -366,6 +370,7 @@ def main_worker(gpu, args):
)

print(f'Building val loader at GPU {gpu}')
# define the prompts used in validation
if args.valid == 'yelp':
val_task_list = {'rating': ['1-1', '1-2', '1-3', '1-4', '1-5', '1-6', '1-7', '1-8', '1-9'],
'sequential': ['2-1', '2-2', '2-3', '2-4', '2-5', '2-6', '2-7', '2-8', '2-9', '2-10', '2-11', '2-12'],
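To illustrate the train_sample_numbers comment above, here is a hedged sketch (hypothetical prompt subset and helper function, not the repo's actual get_loader path) of how a sample number greater than 1 pairs the same datum with several prompts from one task family:

import random

sequential_prompts = ['2-1', '2-2', '2-3', '2-13']  # hypothetical subset of one task family's prompt ids
sample_number = 5  # e.g. the first entry of train_sample_numbers['sequential'] = (5, 5, 10)

def expand_datum(datum_id, prompts, n):
    # the same datum is revisited n times, each time with a randomly chosen prompt id
    return [(datum_id, random.choice(prompts)) for _ in range(n)]

print(expand_datum(0, sequential_prompts, sample_number))
# e.g. [(0, '2-3'), (0, '2-1'), (0, '2-13'), (0, '2-2'), (0, '2-1')]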
9 changes: 8 additions & 1 deletion src/pretrain_data.py
@@ -118,6 +118,7 @@ def __init__(self, all_tasks, task_list, tokenizer, args, sample_numbers, mode='
self.datum_info = []
self.compute_datum_info()

# compute_datum_info plans which data sample is used for which task group, according to the sample numbers in train_sample_numbers of pretrain.py
def compute_datum_info(self):
curr = 0
for key in list(self.task_list.keys()):
@@ -127,16 +128,19 @@ def compute_datum_info(self):
self.datum_info.append((i + curr, key, i // self.sample_numbers[key]))
curr = self.total_length
elif key == 'sequential':
# The first group of sequential prompts (directly predict the next item): 2-1 to 2-6 and 2-13
if sum([0 < int(ind.split('-')[1]) <= 6 or int(ind.split('-')[1]) == 13 for ind in self.task_list[key]]):
self.total_length += len(self.sequential_data) * self.sample_numbers[key][0]
for i in range(self.total_length - curr):
self.datum_info.append((i + curr, key, i // self.sample_numbers[key][0]))
curr = self.total_length
# The second group of sequential prompts (predict the next item from a candidate list): 2-7 to 2-10
if sum([6 < int(ind.split('-')[1]) <= 10 for ind in self.task_list[key]]):
self.total_length += len(self.sequential_data) * self.sample_numbers[key][1]
for i in range(self.total_length - curr):
self.datum_info.append((i + curr, key, i // self.sample_numbers[key][1]))
curr = self.total_length
# The third group of sequential prompts (predict yes or no for each user-item pair): 2-11 to 2-12
if sum([10 < int(ind.split('-')[1]) <= 12 for ind in self.task_list[key]]):
self.total_length += len(self.sequential_data) * self.sample_numbers[key][2]
for i in range(self.total_length - curr):
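A small self-contained rehearsal of the indexing above, with toy sizes standing in for len(self.sequential_data) and sample_numbers['sequential'][0]: each group appends (global_index, task, datum_index) tuples, so datum i // sample_number is revisited sample_number times:

datum_info, total_length, curr = [], 0, 0
n_sequential, per_datum = 3, 2  # toy stand-ins, not the real dataset size or sample number
total_length += n_sequential * per_datum
for i in range(total_length - curr):
    datum_info.append((i + curr, 'sequential', i // per_datum))
curr = total_length
print(datum_info)
# [(0, 'sequential', 0), (1, 'sequential', 0), (2, 'sequential', 1),
#  (3, 'sequential', 1), (4, 'sequential', 2), (5, 'sequential', 2)]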
@@ -153,11 +157,13 @@ def compute_datum_info(self):
self.datum_info.append((i + curr, key, i // self.sample_numbers[key]))
curr = self.total_length
elif key == 'traditional':
# The first group of direct recommendation prompts (choose one item from 100 candidates): 5-1 to 5-4
if sum([0 < int(ind.split('-')[1]) <= 4 for ind in self.task_list[key]]):
self.total_length += len(self.user2id) * self.sample_numbers[key][0]
for i in range(self.total_length - curr):
self.datum_info.append((i + curr, key, i // self.sample_numbers[key][0]))
curr = self.total_length
# The second group of direct recommendation prompts (predict yes or no for each user-item pair): 5-5 to 5-8
if sum([4 < int(ind.split('-')[1]) <= 8 for ind in self.task_list[key]]):
self.total_length += len(self.user2id) * self.sample_numbers[key][1]
for i in range(self.total_length - curr):
@@ -172,6 +178,7 @@ def compute_datum_info(self):
else:
raise NotImplementedError

# use Gaussian sampling to augment rating scores
def gaussian_sampling(self, datum):
if self.mode == 'train':
if int(datum['overall']) == 1:
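The gaussian_sampling comment refers to rating augmentation during training. A hedged sketch of the general idea only; the per-rating means and standard deviations actually used are defined in pretrain_data.py and are not reproduced here:

import random

def augment_rating(overall, std=0.5):
    # jitter the integer rating with Gaussian noise, then clamp to the 1-5 scale
    # (std=0.5 is an illustrative value, not the repo's setting)
    sampled = random.gauss(float(overall), std)
    return str(round(min(5.0, max(1.0, sampled)), 1))

print(augment_rating(4))  # e.g. '4.3'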
@@ -330,7 +337,7 @@ def __getitem__(self, idx):
start_candidates = [_ for _ in range(1, min(4, end_pos))]
start_index = random.randint(0, len(start_candidates)-1)
start_pos = start_candidates[start_index]
purchase_history = sequence[start_pos:end_pos+1]
purchase_history = sequence[start_pos:end_pos+1] # sample a history sub-sequence from the user's full purchase history
target_item = sequence[end_pos+1]
elif self.mode == 'val':
purchase_history = sequence[1:-2]
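The purchase_history comment marks where a training example slices a random sub-sequence out of the user's interaction sequence. A toy illustration (the sequence and the choice of end_pos are made up; only the start_pos logic mirrors the diff above):

import random

sequence = ['user_7', 'item_3', 'item_9', 'item_5', 'item_2', 'item_8']  # hypothetical: user id followed by purchased items
end_pos = len(sequence) - 2  # illustrative; the real bound is chosen earlier in __getitem__
start_candidates = [_ for _ in range(1, min(4, end_pos))]
start_pos = start_candidates[random.randint(0, len(start_candidates) - 1)]
purchase_history = sequence[start_pos:end_pos + 1]
target_item = sequence[end_pos + 1]
print(purchase_history, '->', target_item)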
1 change: 1 addition & 0 deletions src/test_data_group5_part1.py
@@ -38,6 +38,7 @@ def parse(path):
yield eval(l)


# This test dataloader version is created for the first four prompts in Task Family 5 (direct recommendation)
class P5_Amazon_Dataset(Dataset):
def __init__(self, all_tasks, task_list, tokenizer, args, sample_numbers, mode='train', split='toys', rating_augment=False, sample_type='random'):
self.all_tasks = all_tasks
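The comment on P5_Amazon_Dataset notes that this test dataloader targets the first four prompts of Task Family 5, i.e. choosing one item out of 100 candidates. A self-contained toy stand-in showing the same Dataset/DataLoader pattern, with made-up text fields rather than the real tokenized batches:

from torch.utils.data import Dataset, DataLoader

class ToyCandidateDataset(Dataset):
    # toy stand-in for the 100-candidate direct-recommendation setting of prompts 5-1 to 5-4
    def __init__(self, n_users=8, n_candidates=100):
        self.n_users, self.n_candidates = n_users, n_candidates

    def __len__(self):
        return self.n_users

    def __getitem__(self, idx):
        candidates = ' , '.join(f'item_{i}' for i in range(self.n_candidates))
        return {'source': f'Pick the best item for user_{idx} from : {candidates}',
                'target': f'item_{idx % self.n_candidates}'}

loader = DataLoader(ToyCandidateDataset(), batch_size=4, shuffle=False)
batch = next(iter(loader))  # default collation turns each field into a list of strings
print(len(batch['source']), batch['target'][0])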
