
Commit

use new save/load in all dygraph models (PaddlePaddle#3495)
* use new save/load in all dygraph models; test=develop

* change load_dict to set_dict; test=develop
phlrain authored Oct 10, 2019
1 parent 7e59194 commit 7a3e0c7
Showing 10 changed files with 18 additions and 18 deletions.
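
In short: every dygraph script in this commit replaces the persistables-based checkpoint API (fluid.dygraph.save_persistables / fluid.dygraph.load_persistables / Layer.load_dict) with the state-dict pair fluid.save_dygraph / fluid.load_dygraph plus Layer.set_dict. A minimal sketch of the new pattern, assuming Paddle Fluid 1.6+ and borrowing the MNIST layer defined in dygraph/mnist/train.py (any fluid.dygraph.Layer subclass would do); the per-file hunks below are all instances of this one substitution:

    import paddle.fluid as fluid
    from train import MNIST  # assumes this runs from dygraph/mnist/; any dygraph Layer works

    with fluid.dygraph.guard():
        model = MNIST("mnist")

        # New-style save: persist the layer's parameters under a path prefix.
        fluid.save_dygraph(model.state_dict(), "save_temp")

        # New-style load: returns (parameter dict, optimizer dict); the second
        # element only matters if optimizer state was saved too, which is why
        # the scripts below unpack it as "_".
        param_dict, _ = fluid.load_dygraph("save_temp")
        model.set_dict(param_dict)  # replaces the old model.load_dict(...)
        model.eval()
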
4 changes: 2 additions & 2 deletions dygraph/cycle_gan/infer.py
@@ -52,8 +52,8 @@ def infer():
os.makedirs(out_path)
cycle_gan = Cycle_Gan("cycle_gan")
save_dir = args.init_model
- restore, _ = fluid.dygraph.load_persistables(save_dir)
- cycle_gan.load_dict(restore)
+ restore, _ = fluid.load_dygraph(save_dir)
+ cycle_gan.set_dict(restore)
cycle_gan.eval()
for file in glob.glob(args.input):
print ("read %s" % file)
4 changes: 2 additions & 2 deletions dygraph/cycle_gan/test.py
@@ -52,8 +52,8 @@ def test():
os.makedirs(out_path)
cycle_gan = Cycle_Gan("cycle_gan")
save_dir = args.init_model + str(epoch)
- restore, _ = fluid.dygraph.load_persistables(save_dir)
- cycle_gan.load_dict(restore)
+ restore, _ = fluid.load_dygraph(save_dir)
+ cycle_gan.set_dict(restore)
cycle_gan.eval()
for data_A , data_B in zip(A_test_reader(), B_test_reader()):
A_name = data_A[1]
2 changes: 1 addition & 1 deletion dygraph/cycle_gan/train.py
@@ -204,7 +204,7 @@ def train(args):
break

if args.save_checkpoints:
- fluid.dygraph.save_persistables(
+ fluid.save_dygraph(
cycle_gan.state_dict(),
args.output + "/checkpoints/{}".format(epoch))

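
The save call above and the load calls in infer.py/test.py are two ends of the same round trip: fluid.save_dygraph takes a path prefix (here one per epoch), and the identical prefix is later passed to fluid.load_dygraph. A hedged sketch of that round trip; the output path is illustrative and the import location of Cycle_Gan is an assumption about this directory's layout:

    import paddle.fluid as fluid
    from model import Cycle_Gan  # hypothetical import path; the class is defined in dygraph/cycle_gan

    with fluid.dygraph.guard():
        cycle_gan = Cycle_Gan("cycle_gan")

        epoch = 0
        prefix = "output/checkpoints/{}".format(epoch)  # same scheme as the train.py hunk above
        fluid.save_dygraph(cycle_gan.state_dict(), prefix)

        # test.py rebuilds the same prefix (args.init_model + str(epoch)) and loads it back.
        restore, _ = fluid.load_dygraph(prefix)
        cycle_gan.set_dict(restore)
        cycle_gan.eval()
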
6 changes: 3 additions & 3 deletions dygraph/mnist/train.py
@@ -150,8 +150,8 @@ def inference_mnist():
with fluid.dygraph.guard(place):
mnist_infer = MNIST("mnist")
# load checkpoint
- model_dict, _ = fluid.dygraph.load_persistables("save_dir")
- mnist_infer.load_dict(model_dict)
+ model_dict, _ = fluid.load_dygraph("save_temp")
+ mnist_infer.set_dict(model_dict)
print("checkpoint loaded")

# start evaluate mode
@@ -245,7 +245,7 @@ def train_mnist(args):
args.use_data_parallel and
fluid.dygraph.parallel.Env().local_rank == 0)
if save_parameters:
- fluid.dygraph.save_persistables(mnist.state_dict(), "save_dir")
+ fluid.save_dygraph(mnist.state_dict(), "save_temp")
print("checkpoint saved")

inference_mnist()
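
One migration detail worth noting, stated as my understanding of the 1.6 API rather than something this diff shows: fluid.save_dygraph treats its string argument as a prefix and appends a suffix on disk (.pdparams for parameters, .pdopt for optimizer state), and fluid.load_dygraph returns a (param_dict, opt_dict) tuple in which opt_dict is None when, as in these scripts, only parameters were saved. A small hedged check:

    import os
    import paddle.fluid as fluid
    from train import MNIST  # assumes this runs from dygraph/mnist/

    with fluid.dygraph.guard():
        mnist = MNIST("mnist")
        fluid.save_dygraph(mnist.state_dict(), "save_temp")

        param_dict, opt_dict = fluid.load_dygraph("save_temp")
        mnist.set_dict(param_dict)

    # Both expectations below are assumptions about Paddle 1.6 behaviour; verify
    # against the version you actually run.
    print(os.path.exists("save_temp.pdparams"))  # expected: True
    print(opt_dict is None)                      # expected: True (no optimizer state saved)
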
2 changes: 1 addition & 1 deletion dygraph/reinforcement_learning/actor_critic.py
@@ -200,5 +200,5 @@ def finish_episode():
print("Solved! Running reward is now {} and "
"the last episode runs to {} time steps!".format(
running_reward, t))
- fluid.dygraph.save_persistables(policy.state_dict(), args.save_dir)
+ fluid.save_dygraph(policy.state_dict(), args.save_dir)
break
2 changes: 1 addition & 1 deletion dygraph/reinforcement_learning/reinforce.py
@@ -186,5 +186,5 @@ def finish_episode():
print("Solved! Running reward is now {} and "
"the last episode runs to {} time steps!".format(
running_reward, t))
- fluid.dygraph.save_persistables(policy.state_dict(), args.save_dir)
+ fluid.save_dygraph(policy.state_dict(), args.save_dir)
break
4 changes: 2 additions & 2 deletions dygraph/reinforcement_learning/test_actor_critic_load.py
@@ -174,8 +174,8 @@ def finish_episode():
return returns

running_reward = 10
- model_dict, _ = fluid.dygraph.load_persistables(args.save_dir)
- policy.load_dict(model_dict)
+ model_dict, _ = fluid.load_dygraph(args.save_dir)
+ policy.set_dict(model_dict)

state, ep_reward = env.reset(), 0
for t in range(1, 10000): # Don't infinite loop while learning
4 changes: 2 additions & 2 deletions dygraph/reinforcement_learning/test_reinforce_load.py
@@ -161,8 +161,8 @@ def finish_episode():

running_reward = 10
state, ep_reward = env.reset(), 0
- model_dict, _ = fluid.dygraph.load_persistables(args.save_dir)
- policy.load_dict(model_dict)
+ model_dict, _ = fluid.load_dygraph(args.save_dir)
+ policy.set_dict(model_dict)

for t in range(1, 10000): # Don't infinite loop while learning
state = np.array(state).astype("float32")
2 changes: 1 addition & 1 deletion dygraph/resnet/train.py
@@ -380,7 +380,7 @@ def train_resnet():
args.use_data_parallel and
fluid.dygraph.parallel.Env().local_rank == 0)
if save_parameters:
- fluid.dygraph.save_persistables(resnet.state_dict(),
+ fluid.save_dygraph(resnet.state_dict(),
'resnet_params')


6 changes: 3 additions & 3 deletions dygraph/sentiment/main.py
@@ -232,7 +232,7 @@ def train():
if steps % args.save_steps == 0:
save_path = "save_dir_" + str(steps)
print('save model to: ' + save_path)
- fluid.dygraph.save_persistables(cnn_net.state_dict(),
+ fluid.save_dygraph(cnn_net.state_dict(),
save_path)
if enable_profile:
print('save profile result into /tmp/profile_file')
@@ -258,8 +258,8 @@ def infer():
print('Do inferring ...... ')
total_acc, total_num_seqs = [], []

- restore, _ = fluid.dygraph.load_persistables(args.checkpoints)
- cnn_net_infer.load_dict(restore)
+ restore, _ = fluid.load_dygraph(args.checkpoints)
+ cnn_net_infer.set_dict(restore)
cnn_net_infer.eval()

steps = 0
