Patch summary for flan/v2/preprocessors.py

* reformat_single_example: the third element of a 3-tuple pattern is now
  unpacked as final_input_pattern (previously example_separator). When the
  example contains "exemplar_inputs_0", two exemplars are rendered with
  inputs_pattern + targets_pattern, placed in front of inputs_pattern,
  trailing spaces are removed with a regex, and the result is substituted
  into final_input_pattern via the "inputs_with_ex" key.
* batch_apply_template: passthrough_keys additionally lists the four
  exemplar_* keys plus "Definition", "ex_explanation", "ex_input" and
  "ex_output".
* add_template_info / add_source_info: quote style only.
* niv2_few_shot_exemplar_lookup_fn: both lookup results are split on
  '\nflanv2-separator\n'; the first two indices of a shuffled range are
  gathered and stored as exemplar_inputs_0/1 and exemplar_targets_0/1. The
  stub second definition of the same name is deleted.

Review notes (to be confirmed by the author):
* exemplar_inputs[1] / exemplar_targets[1] are indexed unconditionally;
  presumably every task key yields at least two exemplars -- verify.
* The few-shot branch writes "inputs_with_ex" and "final_suffix" into the
  caller's `example` dict rather than into `new_example`.
* Two commented-out "# test" lines are introduced by this patch.

NOTE: leading whitespace below was reconstructed assuming a 2-space indent.
Hunk line counts agree with the @@ headers, but if context fails to match,
apply with `git apply --ignore-whitespace`.

diff --git a/flan/v2/preprocessors.py b/flan/v2/preprocessors.py
index e753373..f7d08e0 100644
--- a/flan/v2/preprocessors.py
+++ b/flan/v2/preprocessors.py
@@ -274,25 +274,35 @@ def reformat_single_example(example, patterns_list, i):
     inputs_pattern, targets_pattern = patterns_list[i]
   else:
     assert len(patterns_list[i]) == 3
-    inputs_pattern, targets_pattern, example_separator = patterns_list[i]
+    inputs_pattern, targets_pattern, final_input_pattern = patterns_list[i]
   format_strings = {"inputs": inputs_pattern, "targets": targets_pattern}
   new_example = dict(example)
   for f_name, format_str in format_strings.items():
-    if "exemplar_inputs" in example and f_name == "inputs":
-      # TODO(Shayne Longpre): implement format_few_shot_from_feature_dictionary.
-      # The inputs_pattern should be in the format of:
-      # inputs_prefix + {inputs} + x_y_delimiter + targets_prefix
-      # for example:
-      # "Q: {input}\nA: "
-      # But we actually want to use the following format (using 2-shot as an example):
-      # {Definition} + "\n\n" +
-      # inputs_prefix + {exemplar_inputs[0]} + x_y_delimiter + targets_prefix + {exemplar_targets[0]}
-      # + {example_separator} +
-      # inputs_prefix + {exemplar_inputs[1]} + x_y_delimiter + targets_prefix + {exemplar_targets[1]}
-      # + {example_separator} +
-      # inputs_prefix + {inputs} + x_y_delimiter + targets_prefix
-      new_example[f_name] = format_niv2_few_shot_from_feature_dictionary(
-          format_str, example, example_separator)
+    if "exemplar_inputs_0" in example and f_name == "inputs":
+      exp_0 = {"input": example["exemplar_inputs_0"],
+               "output": example["exemplar_targets_0"]}
+      exp_0_str = format_from_feature_dictionary(inputs_pattern + targets_pattern, exp_0)
+      exp_1 = {"input": example["exemplar_inputs_1"],
+               "output": example["exemplar_targets_1"]}
+      exp_1_str = format_from_feature_dictionary(inputs_pattern + targets_pattern, exp_1)
+
+      #exp_0_str = tf.constant('exp0\n', dtype=tf.string) # test
+      #exp_1_str = tf.constant('exp1\n', dtype=tf.string) # test
+
+      inputs_format = "{exp_0_str}{exp_1_str}" + inputs_pattern
+      inputs = {"exp_0_str": exp_0_str,
+                "exp_1_str": exp_1_str,
+                "input": example["input"]}
+
+      inputs_with_ex = format_from_feature_dictionary(inputs_format, inputs)
+      inputs_with_ex = tf.strings.regex_replace(inputs_with_ex, r" *$", "")
+
+      final_input_pattern = final_input_pattern.replace("{{Definition}}", "{Definition}")
+      final_input_pattern = final_input_pattern.replace("{{inputs}}", "{inputs_with_ex}")
+      example["inputs_with_ex"] = inputs_with_ex
+      example["final_suffix"] = tf.constant("", dtype=tf.string)
+
+      new_example[f_name] = format_from_feature_dictionary(final_input_pattern, example)
     else:
       new_example[f_name] = format_from_feature_dictionary(format_str, example)
   return new_example
@@ -314,7 +324,11 @@ def batch_apply_template(dataset, patterns_list):
       reformat_batched_example, patterns_list=patterns_list)
 
   # First, strip the dataset of unbatchable features.
-  passthrough_keys = ["_template_idx", "_task_source", "_task_name"]
+  passthrough_keys = [
+      "_template_idx", "_task_source", "_task_name",
+      "exemplar_inputs_0", "exemplar_targets_0",
+      "exemplar_inputs_1", "exemplar_targets_1",
+      "Definition", "ex_explanation", "ex_input", "ex_output"]
   dataset_batchable_features_only = remove_unbatchable_items_ds(
       dataset, training_keys=get_training_keys(patterns_list, passthrough_keys))
 
@@ -437,7 +451,7 @@ def is_correct_fn(example):
 @seqio.utils.map_over_dataset
 def add_template_info(ex, template_type):
   new_ex = dict(ex)
-  new_ex['_template_type'] = template_type
+  new_ex["_template_type"] = template_type
   return new_ex
 
 
@@ -445,8 +459,8 @@ def add_template_info(ex, template_type):
 @seqio.utils.map_over_dataset
 def add_source_info(ex, task_name, task_source):
   new_ex = dict(ex)
-  new_ex['_task_name'] = task_name
-  new_ex['_task_source'] = task_source
+  new_ex["_task_name"] = task_name
+  new_ex["_task_source"] = task_source
   return new_ex
 
 
@@ -1479,18 +1493,22 @@ def niv2_few_shot_exemplar_lookup_fn(
 ):
   """Lookup positive example fields and populate the dataset example."""
   task_key = tf.strings.split(example["task_name"], sep=".")[0]
-  example["exemplar_inputs"] = exemplar_input_lookup.lookup(task_key)
-  example["exemplar_targets"] = exemplar_targets_lookup.lookup(task_key)
-  return example
-
-
-@seqio.map_over_dataset
-def niv2_few_shot_exemplar_lookup_fn(example):
-  # TODO(Shayne Longpre)
-  # Just keep 2 to 3 random exemplars in
-  # example["exemplar_inputs"] and example["exemplar_targets"].
-  # Note that if you pick the N-th exemplar_inputs, you also need
-  # to pick the N-th exemplar_targets.
+  example["exemplar_inputs"] = tf.strings.split(
+      exemplar_input_lookup.lookup(task_key),
+      sep='\nflanv2-separator\n')
+  example["exemplar_targets"] = tf.strings.split(
+      exemplar_targets_lookup.lookup(task_key),
+      sep='\nflanv2-separator\n')
+
+  n = len(example["exemplar_inputs"])
+  sample_size = tf.constant(2) # 2 exemplars.
+  idx = tf.random.shuffle(tf.range(n))[:sample_size]
+  exemplar_inputs = tf.gather(example["exemplar_inputs"], idx)
+  exemplar_targets = tf.gather(example["exemplar_targets"], idx)
+  example["exemplar_inputs_0"] = exemplar_inputs[0]
+  example["exemplar_targets_0"] = exemplar_targets[0]
+  example["exemplar_inputs_1"] = exemplar_inputs[1]
+  example["exemplar_targets_1"] = exemplar_targets[1]
   return example
 
 