
Commit b735da4
code cleanup for TF 0.12.1
akademi4eg committed Feb 13, 2017
1 parent ed69cf7
Showing 3 changed files with 14 additions and 12 deletions.
train.py (7 additions, 7 deletions)
@@ -101,7 +101,8 @@ def _str_to_bool(s):
     parser.add_argument('--gc_channels', type=int, default=None,
                         help='Number of global condition channels. Default: None. Expecting: Int')
     parser.add_argument('--max_checkpoints', type=int, default=MAX_TO_KEEP,
-                        help='Maximum amount of checkpoints that will be kept alive. Default: ' + str(MAX_TO_KEEP) + '.')
+                        help='Maximum amount of checkpoints that will be kept alive. Default: '
+                             + str(MAX_TO_KEEP) + '.')
     return parser.parse_args()


@@ -195,7 +196,6 @@ def main():
         return

     logdir = directories['logdir']
-    logdir_root = directories['logdir_root']
     restore_from = directories['restore_from']

     # Even if we restored the model, we will treat it as new training
@@ -225,7 +225,7 @@ def main():
                 wavenet_params["scalar_input"],
                 wavenet_params["initial_filter_width"]),
             sample_size=args.sample_size,
-            silence_threshold=args.silence_threshold)
+            silence_threshold=silence_threshold)
         audio_batch = reader.dequeue(args.batch_size)
         if gc_enabled:
             gc_id_batch = reader.dequeue_gc(args.batch_size)
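Note on this hunk: the reader now receives a local silence_threshold instead of the raw args.silence_threshold. That local is presumably derived from the CLI flag earlier in main(); a minimal sketch of the assumed preprocessing (the EPSILON cutoff and the None convention are assumptions, not visible in this diff):

    # Hypothetical preprocessing earlier in main(): treat a near-zero
    # threshold as "no silence trimming" by passing None to the reader.
    EPSILON = 0.001  # assumed cutoff
    silence_threshold = (args.silence_threshold
                         if args.silence_threshold > EPSILON
                         else None)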
@@ -260,14 +260,14 @@ def main():
     optim = optimizer.minimize(loss, var_list=trainable)

     # Set up logging for TensorBoard.
-    writer = tf.train.SummaryWriter(logdir)
+    writer = tf.summary.FileWriter(logdir)
     writer.add_graph(tf.get_default_graph())
     run_metadata = tf.RunMetadata()
-    summaries = tf.merge_all_summaries()
+    summaries = tf.summary.merge_all()

     # Set up session
     sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
-    init = tf.initialize_all_variables()
+    init = tf.global_variables_initializer()
     sess.run(init)

     # Saver for storing checkpoints of the model.
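These three substitutions track API renames that shipped in TensorFlow 0.12: tf.train.SummaryWriter became tf.summary.FileWriter, tf.merge_all_summaries became tf.summary.merge_all, and tf.initialize_all_variables became tf.global_variables_initializer. A minimal self-contained sketch of the updated summary setup (the toy loss tensor and the /tmp/logdir path are placeholders, not taken from this commit):

    import tensorflow as tf

    # Toy graph: any scalar tensor works as a summary source.
    loss = tf.reduce_mean(tf.square(tf.random_normal([10])))
    tf.summary.scalar('loss', loss)                 # was tf.scalar_summary

    writer = tf.summary.FileWriter('/tmp/logdir')   # was tf.train.SummaryWriter
    writer.add_graph(tf.get_default_graph())
    summaries = tf.summary.merge_all()              # was tf.merge_all_summaries

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())  # was tf.initialize_all_variables
        summary_proto, _ = sess.run([summaries, loss])
        writer.add_summary(summary_proto, global_step=0)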
@@ -290,8 +290,8 @@ def main():
     reader.start_threads(sess)

     step = None
+    last_saved_step = saved_global_step
     try:
-        last_saved_step = saved_global_step
         for step in range(saved_global_step + 1, args.num_steps):
             start_time = time.time()
             if args.store_metadata and step % 50 == 0:
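Moving last_saved_step out of the try block means the name is bound before anything in the loop can raise; the cleanup code (outside this hunk) presumably compares it with step to decide whether a final checkpoint save is needed. A sketch of the pattern under that assumption (the save call shown is hypothetical):

    step = None
    last_saved_step = saved_global_step  # bound before try:, so always defined
    try:
        for step in range(saved_global_step + 1, args.num_steps):
            pass  # one training iteration
    finally:
        # Had the assignment stayed inside try, an interrupt raised before
        # it executed would make this comparison fail with a NameError.
        if step is not None and step > last_saved_step:
            saver.save(sess, logdir + '/model.ckpt', global_step=step)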
wavenet/model.py (3 additions, 3 deletions)
@@ -667,7 +667,7 @@ def loss(self,
                 target_output)
             reduced_loss = tf.reduce_mean(loss)

-            tf.scalar_summary('loss', reduced_loss)
+            tf.summary.scalar('loss', reduced_loss)

             if l2_regularization_strength is None:
                 return reduced_loss
@@ -681,7 +681,7 @@ def loss(self,
                 total_loss = (reduced_loss +
                               l2_regularization_strength * l2_loss)

-                tf.scalar_summary('l2_loss', l2_loss)
-                tf.scalar_summary('total_loss', total_loss)
+                tf.summary.scalar('l2_loss', l2_loss)
+                tf.summary.scalar('total_loss', total_loss)

                 return total_loss
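The renamed summary ops record both the raw L2 penalty and the combined objective. Piecing the two hunks together, the loss method returns the plain cross-entropy mean when l2_regularization_strength is None and the penalized total otherwise; a sketch of that control flow (the l2_loss computation over non-bias variables is an assumption based on common practice, not visible in this diff):

    reduced_loss = tf.reduce_mean(loss)
    tf.summary.scalar('loss', reduced_loss)

    if l2_regularization_strength is None:
        return reduced_loss
    else:
        # Assumed: summed L2 norms of trainable weights, biases excluded.
        l2_loss = tf.add_n([tf.nn.l2_loss(v)
                            for v in tf.trainable_variables()
                            if 'bias' not in v.name])
        total_loss = (reduced_loss +
                      l2_regularization_strength * l2_loss)
        tf.summary.scalar('l2_loss', l2_loss)
        tf.summary.scalar('total_loss', total_loss)
        return total_loss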
wavenet/ops.py (4 additions, 2 deletions)
@@ -66,8 +66,10 @@ def mu_law_encode(audio, quantization_channels):
     with tf.name_scope('encode'):
         mu = quantization_channels - 1
         # Perform mu-law companding transformation (ITU-T, 1988).
-        # Minimum operation is here to deal with rare large amplitudes caused by resampling.
-        magnitude = tf.log(1 + mu * tf.minimum(tf.abs(audio), 1.0)) / tf.log(1. + mu)
+        # Minimum operation is here to deal with rare large amplitudes caused
+        # by resampling.
+        safe_audio_abs = tf.minimum(tf.abs(audio), 1.0)
+        magnitude = tf.log(1 + mu * safe_audio_abs) / tf.log(1. + mu)
         signal = tf.sign(audio) * magnitude
         # Quantize signal to the specified number of levels.
         return tf.cast((signal + 1) / 2 * mu + 0.5, tf.int32)
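The refactor only names the clamped magnitude; the math is unchanged: the encoder computes sign(x) * log(1 + mu*|x|) / log(1 + mu) and quantizes the result into quantization_channels integer levels. A NumPy sketch of the same transform for quick experimentation (the 256-channel default mirrors the usual WaveNet setting and is assumed here):

    import numpy as np

    def mu_law_encode(audio, quantization_channels=256):
        mu = quantization_channels - 1
        # Clamp rare out-of-range amplitudes introduced by resampling.
        safe_audio_abs = np.minimum(np.abs(audio), 1.0)
        magnitude = np.log1p(mu * safe_audio_abs) / np.log1p(mu)
        signal = np.sign(audio) * magnitude
        # Map [-1, 1] onto the integer levels {0, ..., mu}.
        return ((signal + 1) / 2 * mu + 0.5).astype(np.int32)

    print(mu_law_encode(np.array([-1.0, 0.0, 1.0])))  # [  0 128 255]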
