-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathrun_t5.py
262 lines (226 loc) · 9.69 KB
/
run_t5.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
# Copyright 2020 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Main file for launching training/eval/predictions of mesh-transformer model."""
import importlib
import os
import sys
from absl import app
from absl import flags
from absl import logging
import t5.models.gin as gin
from mesh_tensorflow.transformer import utils
import pkg_resources
import t5
from t5.models import mtf_model
import tensorflow.compat.v1 as tf
flags.DEFINE_string(
"tpu_job_name", None,
"Name of TPU worker binary. Only necessary if job name is changed from "
"default tpu_worker.")
flags.DEFINE_string(
"model_dir", "/tmp/transformer_standalone", "Estimator model_dir")
flags.DEFINE_string(
"tpu", None,
"The Cloud TPU to use for training. This should be either the name "
"used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.")
flags.DEFINE_string(
"gcp_project",
None,
"Project name for the Cloud TPU-enabled project. If not specified, we "
"will attempt to automatically detect the GCE project from metadata.")
flags.DEFINE_string(
"tpu_zone", None,
"GCE zone where the Cloud TPU is located in. If not specified, we "
"will attempt to automatically detect the GCE project from metadata.")
flags.DEFINE_multi_string(
"module_import", None,
"Modules to import. Use this, for example, to add new `Task`s to the "
"global `TaskRegistry`.")
flags.DEFINE_string(
"t5_tfds_data_dir", None,
"If set, this directory will be used to store datasets prepared by "
"TensorFlow Datasets that are not available in the public TFDS GCS bucket. "
"Note that this flag overrides the `tfds_data_dir` attribute of all "
"`Task`s.")
flags.DEFINE_list(
"additional_task_cache_dirs", [],
"Directories to search for Tasks in addition to defaults.")
flags.DEFINE_boolean("use_model_api", False,
"Use Model API instead of utils.run.")
# Note: All the args from here on are only used when use_model_api is set
flags.DEFINE_enum("mode", None,
["train", "finetune", "eval", "predict", "export", "score"],
"Mode with which to run the model.")
flags.DEFINE_integer("batch_size", 1,
"Number of sequences per batch.")
flags.DEFINE_integer("input_sequence_length", 512,
"Number of tokens in input sequence.")
flags.DEFINE_integer("target_sequence_length", 512,
"Number of tokens in target sequence.")
# TPU-specific args.
flags.DEFINE_string("tpu_topology", "v2-8",
"The TPU topology being used. Ignored if --tpu not set.")
flags.DEFINE_integer("model_parallelism", 8,
"The number of cores per model replica. Ignored if --tpu "
"not set.")
# Train mode args
flags.DEFINE_integer("train_steps", 1000, "Number of training iterations.")
flags.DEFINE_string("mixture_or_task", "wmt_t2t_ende_v003",
"Name of Mixture or Task to use for training/evaluation.")
flags.DEFINE_string("pretrained_model_dir", "",
"Pretrained model dir for finetuning a model.")
# Eval mode args
flags.DEFINE_enum(
"checkpoint_mode", "latest", ["all", "latest", "specific"],
"Checkpoint steps to use when running 'eval', 'predict', 'finetune', and "
"'export' modes. Can specify a list of checkpoints or all or the latest "
"checkpoint. 'finetune' and 'export' modes work with 'latest' or "
"'specific' with a single checkpoint.")
flags.DEFINE_list(
"checkpoint_steps", [],
"Checkpoint step numbers used for 'eval', 'predict', and 'finetune' modes. "
"This argument is only used when which_checkpoint='specific'. "
"For the 'finetune' mode, only a single checkpoint must be specified.")
flags.DEFINE_string("eval_summary_dir", "", "Path to save eval summaries")
flags.DEFINE_string("eval_split", "validation",
"Dataset split to use for evaluation.")
# Predict mode args
flags.DEFINE_string("input_file", "",
"Path to input file for decoding or scoring.")
flags.DEFINE_string("target_file", "", "Path to target file for scoring.")
flags.DEFINE_string("output_file", "", "Path to output file to save decodes.")
# Export mode args
flags.DEFINE_string(
"export_dir", "",
"Directory to export SavedModels to. Will use `model_dir` if unspecified.")
flags.DEFINE_integer("export_beam_size", 1, "Beam size in export mode.")
flags.DEFINE_float("export_temperature", 0.0,
"Sampling emperature in export mode.")
FLAGS = flags.FLAGS
def main(_):
if FLAGS.module_import:
for module in FLAGS.module_import:
importlib.import_module(module)
if FLAGS.t5_tfds_data_dir:
t5.data.set_tfds_data_dir_override(FLAGS.t5_tfds_data_dir)
t5.data.add_global_cache_dirs(FLAGS.additional_task_cache_dirs)
# Add search path for gin files stored in package.
gin.add_config_file_search_path(
pkg_resources.resource_filename(__name__, "gin"))
try:
tf.io.gfile.makedirs(FLAGS.model_dir)
suffix = 0
command_filename = os.path.join(FLAGS.model_dir, "command")
while tf.io.gfile.exists(command_filename):
suffix += 1
command_filename = os.path.join(
FLAGS.model_dir, "command.{}".format(suffix))
with tf.io.gfile.GFile(command_filename, "w") as f:
f.write(" ".join(sys.argv))
except tf.errors.PermissionDeniedError:
logging.info(
"No write access to model directory. Skipping command logging.")
utils.parse_gin_defaults_and_flags()
if FLAGS.use_model_api:
model = mtf_model.MtfModel(
tpu_job_name=FLAGS.tpu_job_name,
tpu=FLAGS.tpu,
gcp_project=FLAGS.gcp_project,
tpu_zone=FLAGS.tpu_zone,
tpu_topology=FLAGS.tpu_topology,
model_parallelism=FLAGS.model_parallelism,
model_dir=FLAGS.model_dir,
batch_size=FLAGS.batch_size,
sequence_length={"inputs": FLAGS.input_sequence_length,
"targets": FLAGS.target_sequence_length}
)
if FLAGS.checkpoint_mode != "specific" and FLAGS.checkpoint_steps:
raise ValueError("checkpoint_mode is set to %s and checkpoint_steps is "
"also set. To use a particular checkpoint, please set "
"checkpoint_mode to 'specific'. For other modes, please "
"ensure that checkpoint_steps is not set."
% FLAGS.checkpoint_mode)
if FLAGS.checkpoint_mode == "latest":
checkpoint_steps = -1
elif FLAGS.checkpoint_mode == "all":
checkpoint_steps = "all"
else:
checkpoint_steps = [int(c) for c in FLAGS.checkpoint_steps]
if FLAGS.mode == "train":
model.train(mixture_or_task_name=FLAGS.mixture_or_task,
steps=FLAGS.train_steps)
elif FLAGS.mode == "eval":
model.eval(mixture_or_task_name=FLAGS.mixture_or_task,
checkpoint_steps=checkpoint_steps,
summary_dir=FLAGS.eval_summary_dir,
split=FLAGS.eval_split)
elif FLAGS.mode == "finetune":
if not (FLAGS.checkpoint_mode == "latest" or
(FLAGS.checkpoint_mode == "specific" and
len(FLAGS.checkpoint_steps) == 1)):
raise ValueError(
"Must specify a single checkpoint for finetuning a model.")
if isinstance(checkpoint_steps, list):
checkpoint_steps = checkpoint_steps[0]
model.finetune(
mixture_or_task_name=FLAGS.mixture_or_task,
steps=FLAGS.train_steps,
pretrained_model_dir=FLAGS.pretrained_model_dir,
checkpoint_steps=checkpoint_steps)
elif FLAGS.mode == "predict":
model.predict(
checkpoint_steps=checkpoint_steps,
input_file=FLAGS.input_file,
output_file=FLAGS.output_file)
elif FLAGS.mode == "score":
model.score(
FLAGS.input_file,
FLAGS.target_file,
scores_file=FLAGS.output_file,
checkpoint_steps=checkpoint_steps)
elif FLAGS.mode == "export":
if not (FLAGS.checkpoint_mode == "latest" or
(FLAGS.checkpoint_mode == "specific" and
len(FLAGS.checkpoint_steps) == 1)):
raise ValueError(
"Must specify a single checkpoint for exporting a model.")
if isinstance(checkpoint_steps, list):
checkpoint_steps = checkpoint_steps[0]
model.export(
export_dir=FLAGS.export_dir,
checkpoint_step=checkpoint_steps,
beam_size=FLAGS.export_beam_size,
temperature=FLAGS.export_temperature)
else:
raise ValueError("--mode flag must be set when using Model API.")
else:
if FLAGS.mode:
raise ValueError("--mode flag should only be set when using Model API.")
if not FLAGS.tpu:
with gin.unlock_config():
gin.bind_parameter("utils.get_variable_dtype.slice_dtype", "float32")
gin.bind_parameter(
"utils.get_variable_dtype.activation_dtype", "float32")
utils.run(
tpu_job_name=FLAGS.tpu_job_name,
tpu=FLAGS.tpu,
gcp_project=FLAGS.gcp_project,
tpu_zone=FLAGS.tpu_zone,
model_dir=FLAGS.model_dir)
def console_entry_point():
tf.disable_v2_behavior()
tf.logging.set_verbosity(tf.logging.INFO)
app.run(main)
if __name__ == "__main__":
console_entry_point()