diff --git a/flan/v2/tasks.py b/flan/v2/tasks.py
index b18a045..3c8e92f 100644
--- a/flan/v2/tasks.py
+++ b/flan/v2/tasks.py
@@ -53,6 +53,41 @@ def register_zero_shot_task(zero_shot_name: str,
         metric_fns=zero_shot_config.metric_fns)
 
 
+# Add NIv2 few-shot tasks built from an existing task config.
+def register_niv2_few_shot_task(
+    few_shot_name: str,
+    zero_shot_config: task_configs.TaskConfig,
+    patterns: List[Tuple[str, str]],
+    template_type: str=None):
+  """Registers `few_shot_name + suffix` for every train suffix/feature pair.
+
+  Args:
+    few_shot_name: Base task name; each suffix in
+      constants.TRAIN_TASK_SUFFIXES_AND_FEATURES is appended to it.
+    zero_shot_config: Supplies the source, preprocessors, postprocess_fn and
+      metric_fns of the registered tasks.
+    patterns: (inputs, targets) template pairs. One pair uses
+      prep.get_formatter; several use prep.get_batch_formatter.
+    template_type: Forwarded to prep.add_template_info.
+  """
+  if len(patterns) == 1:
+    formatter = prep.get_formatter(patterns[0][0], patterns[0][1])
+  else:
+    # This batch formatter applies many prompts to a single task.
+    formatter = prep.get_batch_formatter(patterns)
+
+  add_template_metadata_fn = functools.partial(prep.add_template_info, template_type=template_type)
+  for suffix, output_features in constants.TRAIN_TASK_SUFFIXES_AND_FEATURES:
+    seqio.TaskRegistry.add(
+        few_shot_name + suffix,
+        source=zero_shot_config.source,
+        preprocessors=zero_shot_config.preprocessors + [add_template_metadata_fn] + formatter +
+        prep.FLAN_TOKENIZE,
+        postprocess_fn=zero_shot_config.postprocess_fn,
+        output_features=output_features,
+        metric_fns=zero_shot_config.metric_fns)
+
+
 for t_name, config in task_configs.ALL_CANDIDATE_TASK_CONFIGS.items():
   flan_pattern_name = utils.t_name_to_flan_pattern_name(t_name)
   patterns_list = templates.PATTERNS[flan_pattern_name]
@@ -263,45 +298,15 @@ def register_zero_shot_task(zero_shot_name: str,
   register_zero_shot_task(f"{t_name}_template_mix_no_opt_five_shot", config,
                           mixed_templates, "fs_noopt")
 
-  # Build few-shot with non-deterministic templates.
-  # The #. of shots is non-deterministic too.
-  # We have 10 templates/task. So in average 2.5 shots/task.
-  num_shots_int = 25
-  fewshot_base_task_name = f"{t_name}_template_0to10"
-
-  all_patterns = []
+  few_shot_patterns = []
   for few_shot_pattern in x_shot_templates:
-    all_patterns.append((few_shot_pattern.combined_inputs_w_target_prefix,
-                         few_shot_pattern.combined_targets_wo_target_prefix))
-  all_formatter = prep.get_batch_formatter(all_patterns)
+    few_shot_patterns.append((few_shot_pattern.combined_inputs_w_target_prefix,
+                              few_shot_pattern.combined_targets_wo_target_prefix))
 
-  add_template_metadata_fn = functools.partial(prep.add_template_info, template_type="fs_opt")
-  for task_suffix, task_output_features in constants.TRAIN_TASK_SUFFIXES_AND_FEATURES:
-    seqio.TaskRegistry.add(
-        fewshot_base_task_name + task_suffix,
-        source=config.source,
-        preprocessors=config.preprocessors + [add_template_metadata_fn] + all_formatter + prep.FLAN_TOKENIZE,
-        postprocess_fn=config.postprocess_fn,
-        output_features=task_output_features,
-        metric_fns=config.metric_fns)
-    # Task names:
-    # f'{t_name}_template_0to10_x_shot{suffix}'
-    # f'{t_name}_template_0to10_no_opt_x_shot{suffix}'
-    # The `_no_opt` version is not really used.
-    for opt_type_name in ["", "_no_opt"]:
-      few_shot.register_few_shot_version_of_task(
-          base_task_name=fewshot_base_task_name + task_suffix,
-          new_task_name=(fewshot_base_task_name + opt_type_name + "_x_shot" +
-                         task_suffix),
-          num_shots=num_shots_int,
-          prune_exemplars=True,
-          max_input_length=constants.FEW_SHOT_MAX_LEN,
-          prune_based_on_template_idx=True,
-          x_y_delimiter="",
-          inputs_prefix="",
-          targets_prefix="",
-          example_separator="",
-          strip_targets=True,
-          # We always use the 0th template's final_suffix and input_pattern.
-          final_suffix=x_shot_templates[0].final_suffix,
-          input_pattern=x_shot_templates[0].input_pattern)
+  for opt_type_name, template_type in [
+      ("", "fs_opt"),
+      ("_no_opt", "fs_noopt"),
+  ]:
+    few_shot_task_name = f"{t_name}_template_0to10{opt_type_name}_x_shot"
+    register_niv2_few_shot_task(few_shot_task_name, config,
+                                few_shot_patterns, template_type)