Skip to content

Commit

Permalink
Update tasks.py
Browse files Browse the repository at this point in the history
  • Loading branch information
lehougoogle authored Mar 20, 2023
1 parent 8e6ee0f commit 23cb543
Showing 1 changed file with 34 additions and 40 deletions.
74 changes: 34 additions & 40 deletions flan/v2/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,30 @@ def register_zero_shot_task(zero_shot_name: str,
metric_fns=zero_shot_config.metric_fns)


# Registers NIv2 few-shot tasks (one per train-task suffix) from a set of patterns.
def register_niv2_few_shot_task(
    few_shot_name: str,
    zero_shot_config: task_configs.TaskConfig,
    patterns: List[Tuple[str, str]],
    template_type: str = None):
  """Registers one seqio task per train-task suffix for a few-shot pattern set.

  Args:
    few_shot_name: Base name of the tasks to register. Every suffix listed in
      `constants.TRAIN_TASK_SUFFIXES_AND_FEATURES` is appended to it.
    zero_shot_config: Task config supplying the `source`, `preprocessors`,
      `postprocess_fn` and `metric_fns` of the registered tasks.
    patterns: Pairs of pattern strings. A single pair is turned into a
      formatter with `prep.get_formatter`; any other number of pairs goes
      through `prep.get_batch_formatter`.
    template_type: Forwarded to `prep.add_template_info` as `template_type`.
  """
  if len(patterns) == 1:
    formatter = prep.get_formatter(patterns[0][0], patterns[0][1])
  else:
    # This batch formatter applies many prompts to a single task.
    formatter = prep.get_batch_formatter(patterns)

  add_template_metadata_fn = functools.partial(
      prep.add_template_info, template_type=template_type)
  # Template metadata is attached after the config's own preprocessors and
  # before formatting/tokenization.
  preprocessors = (
      zero_shot_config.preprocessors + [add_template_metadata_fn] +
      formatter + prep.FLAN_TOKENIZE)

  for suffix, output_features in constants.TRAIN_TASK_SUFFIXES_AND_FEATURES:
    seqio.TaskRegistry.add(
        # The task must be registered under this function's own name
        # parameter; `zero_shot_name` is not defined in this scope.
        few_shot_name + suffix,
        source=zero_shot_config.source,
        preprocessors=preprocessors,
        postprocess_fn=zero_shot_config.postprocess_fn,
        output_features=output_features,
        metric_fns=zero_shot_config.metric_fns)


for t_name, config in task_configs.ALL_CANDIDATE_TASK_CONFIGS.items():
flan_pattern_name = utils.t_name_to_flan_pattern_name(t_name)
patterns_list = templates.PATTERNS[flan_pattern_name]
Expand Down Expand Up @@ -263,45 +287,15 @@ def register_zero_shot_task(zero_shot_name: str,
register_zero_shot_task(f"{t_name}_template_mix_no_opt_five_shot", config,
mixed_templates, "fs_noopt")

# Build few-shot with non-deterministic templates.
# The number of shots is non-deterministic too.
# We have 10 templates/task, so on average 2.5 shots/task.
num_shots_int = 25
fewshot_base_task_name = f"{t_name}_template_0to10"

all_patterns = []
few_shot_patterns = []
for few_shot_pattern in x_shot_templates:
all_patterns.append((few_shot_pattern.combined_inputs_w_target_prefix,
few_shot_pattern.combined_targets_wo_target_prefix))
all_formatter = prep.get_batch_formatter(all_patterns)
few_shot_patterns.append((few_shot_pattern.combined_inputs_w_target_prefix,
few_shot_pattern.combined_targets_wo_target_prefix))

add_template_metadata_fn = functools.partial(prep.add_template_info, template_type="fs_opt")
for task_suffix, task_output_features in constants.TRAIN_TASK_SUFFIXES_AND_FEATURES:
seqio.TaskRegistry.add(
fewshot_base_task_name + task_suffix,
source=config.source,
preprocessors=config.preprocessors + [add_template_metadata_fn] + all_formatter + prep.FLAN_TOKENIZE,
postprocess_fn=config.postprocess_fn,
output_features=task_output_features,
metric_fns=config.metric_fns)
# Task names:
# f'{t_name}_template_0to10_x_shot{suffix}'
# f'{t_name}_template_0to10_no_opt_x_shot{suffix}'
# The `_no_opt` version is not really used.
for opt_type_name in ["", "_no_opt"]:
few_shot.register_few_shot_version_of_task(
base_task_name=fewshot_base_task_name + task_suffix,
new_task_name=(fewshot_base_task_name + opt_type_name + "_x_shot" +
task_suffix),
num_shots=num_shots_int,
prune_exemplars=True,
max_input_length=constants.FEW_SHOT_MAX_LEN,
prune_based_on_template_idx=True,
x_y_delimiter="",
inputs_prefix="",
targets_prefix="",
example_separator="",
strip_targets=True,
# We always use the 0th template's final_suffix and input_pattern.
final_suffix=x_shot_templates[0].final_suffix,
input_pattern=x_shot_templates[0].input_pattern)
# Register two few-shot variants of this task under the `_template_0to10`
# name: the default one with template type "fs_opt" and the `_no_opt` one
# with template type "fs_noopt". Both are built from the same
# `few_shot_patterns`.
for opt_type_name, template_type in [
("", "fs_opt"),
("_no_opt", "fs_noopt"),
]:
few_shot_task_name = f"{t_name}_template_0to10{opt_type_name}_x_shot"
register_niv2_few_shot_task(few_shot_task_name, config,
few_shot_patterns, template_type)

0 comments on commit 23cb543

Please sign in to comment.