Skip to content

Commit

Permalink
update build_once for cycle_gan (PaddlePaddle#4108)
Browse files Browse the repository at this point in the history
* update build_once for cycle_gan
test=develop

* add input_channel arg for cycle_gan
test=develop
  • Loading branch information
songyouwei authored and phlrain committed Dec 25, 2019
1 parent 9b50a73 commit 7634fe7
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 68 deletions.
11 changes: 6 additions & 5 deletions dygraph/cycle_gan/data_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@
from PIL import Image, ImageOps
import numpy as np

A_LIST_FILE = "./data/cityscapes/trainA.txt"
B_LIST_FILE = "./data/cityscapes/trainB.txt"
A_TEST_LIST_FILE = "./data/cityscapes/testA.txt"
B_TEST_LIST_FILE = "./data/cityscapes/testB.txt"
IMAGES_ROOT = "./data/cityscapes/"
DATASET = "cityscapes"
A_LIST_FILE = "./data/"+DATASET+"/trainA.txt"
B_LIST_FILE = "./data/"+DATASET+"/trainB.txt"
A_TEST_LIST_FILE = "./data/"+DATASET+"/testA.txt"
B_TEST_LIST_FILE = "./data/"+DATASET+"/testB.txt"
IMAGES_ROOT = "./data/"+DATASET+"/"

def image_shape():
return [3, 256, 256]
Expand Down
34 changes: 17 additions & 17 deletions dygraph/cycle_gan/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@

class conv2d(fluid.dygraph.Layer):
"""docstring for Conv2D"""
def __init__(self,
name_scope,
def __init__(self,
num_channels,
num_filters=64,
filter_size=7,
stride=1,
Expand All @@ -35,15 +35,15 @@ def __init__(self,
relu=True,
relufactor=0.0,
use_bias=False):
super(conv2d, self).__init__(name_scope)
super(conv2d, self).__init__()

if use_bias == False:
con_bias_attr = False
else:
con_bias_attr = fluid.ParamAttr(name="conv_bias",initializer=fluid.initializer.Constant(0.0))

self.conv = Conv2D(
self.full_name(),
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
Expand All @@ -54,7 +54,7 @@ def __init__(self,
initializer=fluid.initializer.NormalInitializer(loc=0.0,scale=stddev)),
bias_attr=con_bias_attr)
if norm:
self.bn = BatchNorm(self.full_name(),
self.bn = BatchNorm(
num_channels=num_filters,
param_attr=fluid.ParamAttr(
name="scale",
Expand Down Expand Up @@ -82,7 +82,7 @@ def forward(self,inputs):

class DeConv2D(fluid.dygraph.Layer):
def __init__(self,
name_scope,
num_channels,
num_filters=64,
filter_size=7,
stride=1,
Expand All @@ -94,27 +94,27 @@ def __init__(self,
relufactor=0.0,
use_bias=False
):
super(DeConv2D,self).__init__(name_scope)
super(DeConv2D,self).__init__()

if use_bias == False:
de_bias_attr = False
else:
de_bias_attr = fluid.ParamAttr(name="de_bias",initializer=fluid.initializer.Constant(0.0))

self._deconv = Conv2DTranspose(self.full_name(),
num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
param_attr=fluid.ParamAttr(
name="this_is_deconv_weights",
initializer=fluid.initializer.NormalInitializer(loc=0.0, scale=stddev)),
bias_attr=de_bias_attr)
self._deconv = Conv2DTranspose(num_channels,
num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
param_attr=fluid.ParamAttr(
name="this_is_deconv_weights",
initializer=fluid.initializer.NormalInitializer(loc=0.0, scale=stddev)),
bias_attr=de_bias_attr)



if norm:
self.bn = BatchNorm(self.full_name(),
self.bn = BatchNorm(
num_channels=num_filters,
param_attr=fluid.ParamAttr(
name="de_wights",
Expand Down
62 changes: 38 additions & 24 deletions dygraph/cycle_gan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,19 @@

class build_resnet_block(fluid.dygraph.Layer):
def __init__(self,
name_scope,
dim,
use_bias=False):
super(build_resnet_block,self).__init__(name_scope)
super(build_resnet_block,self).__init__()

self.conv0 = conv2d(self.full_name(),
self.conv0 = conv2d(
num_channels=dim,
num_filters=dim,
filter_size=3,
stride=1,
stddev=0.02,
use_bias=False)
self.conv1 = conv2d(self.full_name(),
self.conv1 = conv2d(
num_channels=dim,
num_filters=dim,
filter_size=3,
stride=1,
Expand All @@ -47,61 +48,67 @@ def forward(self,inputs):
out_res = self.conv1(out_res)
return out_res + inputs


class build_generator_resnet_9blocks(fluid.dygraph.Layer):
def __init__ (self,
name_scope):
super(build_generator_resnet_9blocks,self).__init__(name_scope)
def __init__ (self, input_channel):
super(build_generator_resnet_9blocks, self).__init__()

self.conv0 = conv2d(self.full_name(),
self.conv0 = conv2d(
num_channels=input_channel,
num_filters=32,
filter_size=7,
stride=1,
padding=0,
stddev=0.02)
self.conv1 = conv2d(self.full_name(),
self.conv1 = conv2d(
num_channels=32,
num_filters=64,
filter_size=3,
stride=2,
padding=1,
stddev=0.02)
self.conv2 = conv2d(self.full_name(),
self.conv2 = conv2d(
num_channels=64,
num_filters=128,
filter_size=3,
stride=2,
padding=1,
stddev=0.02)
self.build_resnet_block_list=[]
dim = 32*4
dim = 128
for i in range(9):
Build_Resnet_Block = self.add_sublayer(
"generator_%d" % (i+1),
build_resnet_block(self.full_name(),
128))
build_resnet_block(dim))
self.build_resnet_block_list.append(Build_Resnet_Block)
self.deconv0 = DeConv2D(self.full_name(),
self.deconv0 = DeConv2D(
num_channels=dim,
num_filters=32*2,
filter_size=3,
stride=2,
stddev=0.02,
padding=[1, 1],
outpadding=[0, 1, 0, 1],
)
self.deconv1 = DeConv2D(self.full_name(),
self.deconv1 = DeConv2D(
num_channels=32*2,
num_filters=32,
filter_size=3,
stride=2,
stddev=0.02,
padding=[1, 1],
outpadding=[0, 1, 0, 1])
self.conv3 = conv2d(self.full_name(),
num_filters=3,
self.conv3 = conv2d(
num_channels=32,
num_filters=input_channel,
filter_size=7,
stride=1,
stddev=0.02,
padding=0,
relu=False,
norm=False,
use_bias=True)

def forward(self,inputs):
pad_input = fluid.layers.pad2d(inputs, [3, 3, 3, 3], mode="reflect")
y = self.conv0(pad_input)
Expand All @@ -116,11 +123,13 @@ def forward(self,inputs):
y = fluid.layers.tanh(y)
return y


class build_gen_discriminator(fluid.dygraph.Layer):
def __init__(self,name_scope):
super(build_gen_discriminator,self).__init__(name_scope)
def __init__(self, input_channel):
super(build_gen_discriminator, self).__init__()

self.conv0 = conv2d(self.full_name(),
self.conv0 = conv2d(
num_channels=input_channel,
num_filters=64,
filter_size=4,
stride=2,
Expand All @@ -129,28 +138,32 @@ def __init__(self,name_scope):
norm=False,
use_bias=True,
relufactor=0.2)
self.conv1 = conv2d(self.full_name(),
self.conv1 = conv2d(
num_channels=64,
num_filters=128,
filter_size=4,
stride=2,
stddev=0.02,
padding=1,
relufactor=0.2)
self.conv2 = conv2d(self.full_name(),
self.conv2 = conv2d(
num_channels=128,
num_filters=256,
filter_size=4,
stride=2,
stddev=0.02,
padding=1,
relufactor=0.2)
self.conv3 = conv2d(self.full_name(),
self.conv3 = conv2d(
num_channels=256,
num_filters=512,
filter_size=4,
stride=1,
stddev=0.02,
padding=1,
relufactor=0.2)
self.conv4 = conv2d(self.full_name(),
self.conv4 = conv2d(
num_channels=512,
num_filters=1,
filter_size=4,
stride=1,
Expand All @@ -159,6 +172,7 @@ def __init__(self,name_scope):
norm=False,
relu=False,
use_bias=True)

def forward(self,inputs):
y = self.conv0(inputs)
y = self.conv1(y)
Expand Down
20 changes: 4 additions & 16 deletions dygraph/cycle_gan/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def train(args):
A_test_reader = data_reader.a_test_reader()
B_test_reader = data_reader.b_test_reader()

cycle_gan = Cycle_Gan("cycle_gan", istrain=True)
cycle_gan = Cycle_Gan(input_channel=data_shape[1], istrain=True)

losses = [[], []]
t_time = 0
Expand Down Expand Up @@ -114,11 +114,7 @@ def train(args):
g_loss_out = g_loss.numpy()

g_loss.backward()
vars_G = []
for param in cycle_gan.parameters():
if param.name[:
52] == "cycle_gan/Cycle_Gan_0/build_generator_resnet_9blocks":
vars_G.append(param)
vars_G = cycle_gan.build_generator_resnet_9blocks_a.parameters() + cycle_gan.build_generator_resnet_9blocks_b.parameters()

optimizer1.minimize(g_loss, parameter_list=vars_G)
cycle_gan.clear_gradients()
Expand All @@ -141,11 +137,7 @@ def train(args):
d_loss_A = fluid.layers.reduce_mean(d_loss_A)

d_loss_A.backward()
vars_da = []
for param in cycle_gan.parameters():
if param.name[:
47] == "cycle_gan/Cycle_Gan_0/build_gen_discriminator_0":
vars_da.append(param)
vars_da = cycle_gan.build_gen_discriminator_a.parameters()
optimizer2.minimize(d_loss_A, parameter_list=vars_da)
cycle_gan.clear_gradients()

Expand All @@ -158,11 +150,7 @@ def train(args):
d_loss_B = fluid.layers.reduce_mean(d_loss_B)

d_loss_B.backward()
vars_db = []
for param in cycle_gan.parameters():
if param.name[:
47] == "cycle_gan/Cycle_Gan_0/build_gen_discriminator_1":
vars_db.append(param)
vars_db = cycle_gan.build_gen_discriminator_b.parameters()
optimizer3.minimize(d_loss_B, parameter_list=vars_db)

cycle_gan.clear_gradients()
Expand Down
12 changes: 6 additions & 6 deletions dygraph/cycle_gan/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@


class Cycle_Gan(fluid.dygraph.Layer):
def __init__(self, name_scope,istrain=True):
super (Cycle_Gan, self).__init__(name_scope)
def __init__(self, input_channel, istrain=True):
super (Cycle_Gan, self).__init__()

self.build_generator_resnet_9blocks_a = build_generator_resnet_9blocks(self.full_name())
self.build_generator_resnet_9blocks_b = build_generator_resnet_9blocks(self.full_name())
self.build_generator_resnet_9blocks_a = build_generator_resnet_9blocks(input_channel)
self.build_generator_resnet_9blocks_b = build_generator_resnet_9blocks(input_channel)
if istrain:
self.build_gen_discriminator_a = build_gen_discriminator(self.full_name())
self.build_gen_discriminator_b = build_gen_discriminator(self.full_name())
self.build_gen_discriminator_a = build_gen_discriminator(input_channel)
self.build_gen_discriminator_b = build_gen_discriminator(input_channel)

def forward(self,input_A,input_B,is_G,is_DA,is_DB):

Expand Down

0 comments on commit 7634fe7

Please sign in to comment.