Skip to content

Commit

Permalink
Update preprocessors.py
Browse files Browse the repository at this point in the history
  • Loading branch information
lehougoogle authored Apr 10, 2023
1 parent 5d1b9f2 commit 5adf329
Showing 1 changed file with 50 additions and 32 deletions.
82 changes: 50 additions & 32 deletions flan/v2/preprocessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,25 +274,35 @@ def reformat_single_example(example, patterns_list, i):
inputs_pattern, targets_pattern = patterns_list[i]
else:
assert len(patterns_list[i]) == 3
inputs_pattern, targets_pattern, example_separator = patterns_list[i]
inputs_pattern, targets_pattern, final_input_pattern = patterns_list[i]
format_strings = {"inputs": inputs_pattern, "targets": targets_pattern}
new_example = dict(example)
for f_name, format_str in format_strings.items():
if "exemplar_inputs" in example and f_name == "inputs":
# TODO(Shayne Longpre): implement format_few_shot_from_feature_dictionary.
# The inputs_pattern should be in the format of:
# inputs_prefix + {inputs} + x_y_delimiter + targets_prefix
# for example:
# "Q: {input}\nA: "
# But we actually want to use the following format (using 2-shot as an example):
# {Definition} + "\n\n" +
# inputs_prefix + {exemplar_inputs[0]} + x_y_delimiter + targets_prefix + {exemplar_targets[0]}
# + {example_separator} +
# inputs_prefix + {exemplar_inputs[1]} + x_y_delimiter + targets_prefix + {exemplar_targets[1]}
# + {example_separator} +
# inputs_prefix + {inputs} + x_y_delimiter + targets_prefix
new_example[f_name] = format_niv2_few_shot_from_feature_dictionary(
format_str, example, example_separator)
if "exemplar_inputs_0" in example and f_name == "inputs":
exp_0 = {"input": example["exemplar_inputs_0"],
"output": example["exemplar_targets_0"]}
exp_0_str = format_from_feature_dictionary(inputs_pattern + targets_pattern, exp_0)
exp_1 = {"input": example["exemplar_inputs_1"],
"output": example["exemplar_targets_1"]}
exp_1_str = format_from_feature_dictionary(inputs_pattern + targets_pattern, exp_1)

#exp_0_str = tf.constant('exp0\n', dtype=tf.string) # test
#exp_1_str = tf.constant('exp1\n', dtype=tf.string) # test

inputs_format = "{exp_0_str}{exp_1_str}" + inputs_pattern
inputs = {"exp_0_str": exp_0_str,
"exp_1_str": exp_1_str,
"input": example["input"]}

inputs_with_ex = format_from_feature_dictionary(inputs_format, inputs)
inputs_with_ex = tf.strings.regex_replace(inputs_with_ex, r" *$", "")

final_input_pattern = final_input_pattern.replace("{{Definition}}", "{Definition}")
final_input_pattern = final_input_pattern.replace("{{inputs}}", "{inputs_with_ex}")
example["inputs_with_ex"] = inputs_with_ex
example["final_suffix"] = tf.constant("", dtype=tf.string)

new_example[f_name] = format_from_feature_dictionary(final_input_pattern, example)
else:
new_example[f_name] = format_from_feature_dictionary(format_str, example)
return new_example
Expand All @@ -314,7 +324,11 @@ def batch_apply_template(dataset, patterns_list):
reformat_batched_example, patterns_list=patterns_list)

# First, strip the dataset of unbatchable features.
passthrough_keys = ["_template_idx", "_task_source", "_task_name"]
passthrough_keys = [
"_template_idx", "_task_source", "_task_name",
"exemplar_inputs_0", "exemplar_targets_0",
"exemplar_inputs_1", "exemplar_targets_1",
"Definition", "ex_explanation", "ex_input", "ex_output"]
dataset_batchable_features_only = remove_unbatchable_items_ds(
dataset, training_keys=get_training_keys(patterns_list, passthrough_keys))

Expand Down Expand Up @@ -437,16 +451,16 @@ def is_correct_fn(example):
@seqio.utils.map_over_dataset
def add_template_info(ex, template_type):
new_ex = dict(ex)
new_ex['_template_type'] = template_type
new_ex["_template_type"] = template_type
return new_ex


# Add source and dataset metadata
@seqio.utils.map_over_dataset
def add_source_info(ex, task_name, task_source):
new_ex = dict(ex)
new_ex['_task_name'] = task_name
new_ex['_task_source'] = task_source
new_ex["_task_name"] = task_name
new_ex["_task_source"] = task_source
return new_ex


Expand Down Expand Up @@ -1479,18 +1493,22 @@ def niv2_few_shot_exemplar_lookup_fn(
):
"""Lookup positive example fields and populate the dataset example."""
task_key = tf.strings.split(example["task_name"], sep=".")[0]
example["exemplar_inputs"] = exemplar_input_lookup.lookup(task_key)
example["exemplar_targets"] = exemplar_targets_lookup.lookup(task_key)
return example


@seqio.map_over_dataset
def niv2_few_shot_exemplar_lookup_fn(example):
# TODO(Shayne Longpre)
# Just keep 2 to 3 random exemplars in
# example["exemplar_inputs"] and example["exemplar_targets"].
# Note that if you pick the N-th exemplar_inputs, you also need
# to pick the N-th exemplar_targets.
example["exemplar_inputs"] = tf.strings.split(
exemplar_input_lookup.lookup(task_key),
sep='\nflanv2-separator\n')
example["exemplar_targets"] = tf.strings.split(
exemplar_targets_lookup.lookup(task_key),
sep='\nflanv2-separator\n')

n = len(example["exemplar_inputs"])
sample_size = tf.constant(2) # 2 exemplars.
idx = tf.random.shuffle(tf.range(n))[:sample_size]
exemplar_inputs = tf.gather(example["exemplar_inputs"], idx)
exemplar_targets = tf.gather(example["exemplar_targets"], idx)
example["exemplar_inputs_0"] = exemplar_inputs[0]
example["exemplar_targets_0"] = exemplar_targets[0]
example["exemplar_inputs_1"] = exemplar_inputs[1]
example["exemplar_targets_1"] = exemplar_targets[1]
return example


Expand Down

0 comments on commit 5adf329

Please sign in to comment.