
Commit

using wandb for logs
goncamateus committed Feb 24, 2021
1 parent 33e825b commit cf39755
Showing 3 changed files with 29 additions and 42 deletions.
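In short, the commit swaps the tensorboardX SummaryWriter calls for Weights & Biases (wandb) logging: one wandb.init() at startup, and per-episode metrics collected into a dict that is passed to a single wandb.log() call. A minimal sketch of that pattern, assuming placeholder metric values (the project name and log directory are taken from the diff below; the run name here is illustrative):

import wandb

# One run per training session.
wandb.init(name='example-run',
           project='RC-Reinforcement',
           dir='./data_path/logs')

# Collect episode metrics into one dict and log them in a single call,
# instead of one writer.add_scalar(...) call per metric.
log_dict = {
    'rw/goal_score': 1,        # placeholder value
    'rw/ball_grad': 0.42,      # placeholder value
    'rw/robot_0/total': 3.7,   # placeholder value
}
wandb.log(log_dict)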
2 changes: 2 additions & 0 deletions .gitignore
@@ -153,3 +153,5 @@ cython_debug/

# Logs
*tfevents*
*wandb/
*.exb
50 changes: 17 additions & 33 deletions envs/agents/deepvss/agents/agentDDPGMA.py
@@ -2,15 +2,16 @@
import shutil
import time
import traceback
from pprint import pprint

import numpy as np
import ptan
import torch
import torch.autograd as autograd
import wandb
from lib import common, ddpg_model
from ptan.experience import ExperienceFirstLast
from tensorboardX import SummaryWriter
from pprint import pprint

gradMax = 0
gradAvg = 0
@@ -120,7 +121,7 @@ def load_actor_model(net, checkpoint):
return net


def play(params, net, device, exp_queue, agent_env, test, writer, collected_samples, finish_event):
def play(params, net, device, exp_queue, agent_env, test, collected_samples, finish_event):

try:
agent = ddpg_model.AgentDDPG(net, device=device,
@@ -141,7 +142,7 @@ def play(params, net, device, exp_queue, agent_env, test, writer, collected_samp
next_states, rewards, done, info = agent_env.step(actions)
steps += 1
for i in range(agent_env.n_robots_blue):
epi_rewards[f'robot_{i}'] = rewards[f'robot_{i}']
epi_rewards[f'robot_{i}'] += rewards[f'robot_{i}']

next_states = next_states if not done \
else [None]*agent_env.n_robots_blue
@@ -157,32 +158,19 @@ def play(params, net, device, exp_queue, agent_env, test, writer, collected_samp
agent_env.render('human')

states = next_states

if done:
fps = steps/(time.time() - then)
then = time.time()
writer.add_scalar("rw/goal_score",
info['goal_score'],
matches_played)
writer.add_scalar("rw/ball_grad",
info['ball_grad'],
matches_played)
writer.add_scalar("rw/goals_blue",
info['goals_blue'],
matches_played)
writer.add_scalar("rw/goals_yellow",
info['goals_yellow'],
matches_played)
log_dict = {}
log_dict["rw/goal_score"] = info['goal_score']
log_dict["rw/ball_grad"] = info['ball_grad']
log_dict["rw/goals_blue"] = info['goals_blue']
log_dict["rw/goals_yellow"] = info['goals_yellow']
for i in range(agent_env.n_robots_blue):
writer.add_scalar(f"rw/robot_{i}/total",
epi_rewards[f'robot_{i}'],
matches_played)
writer.add_scalar(f"rw/robot_{i}/move",
info[f'robot_{i}']['move'],
matches_played)
writer.add_scalar(f"rw/robot_{i}/energy",
info[f'robot_{i}']['energy'],
matches_played)
log_dict[f"rw/robot_{i}/total"] = epi_rewards[f'robot_{i}']
log_dict[f"rw/robot_{i}/move"] = info[f'robot_{i}']['move']
log_dict[f"rw/robot_{i}/energy"] = info[f'robot_{i}']['energy']
wandb.log(log_dict)
print(f'<======Match {matches_played}======>')
print(f'-------Rewards:')
pprint(epi_rewards)
@@ -197,9 +185,7 @@ def play(params, net, device, exp_queue, agent_env, test, writer, collected_samp

if not test and evaluation: # evaluation just finished
for i in range(agent_env.n_robots_blue):
writer.add_scalar(f"eval/robot_{i}/",
epi_rewards[f'robot_{i}'],
matches_played)
wandb.log({f"eval/robot_{i}/": epi_rewards[f'robot_{i}']})
print("evaluation finished")

evaluation = matches_played % eval_freq_matches == 0
@@ -301,9 +287,6 @@ def train(model_params, act_net, device,
model_params['save_model_frequency']
next_net_sync = processed_samples + model_params['target_net_sync']
queue_max_size = batch_size = model_params['batch_size']
writer_path = model_params['writer_path']
writer = SummaryWriter(log_dir=writer_path+"/train")
tracker = common.RewardTracker(writer)

actor_loss = 0.0
critic_loss = 0.0
@@ -404,8 +387,9 @@ def train(model_params, act_net, device,
(processed_samples-last_loss_average)
print("avg_reward:%.4f, avg_loss:%f" %
(reward_avg, actor_loss))
tracker.track_training(
processed_samples, reward_avg, actor_loss, critic_loss)
wandb.log({"actor_loss": actor_loss})
wandb.log({"critic_loss": critic_loss})

actor_loss = 0.0
critic_loss = 0.0
last_loss_average = processed_samples
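One difference worth noting in the play() and train() hunks above: writer.add_scalar received matches_played as an explicit global step, while wandb.log without a step argument advances wandb's internal step counter once per call. If the per-match x-axis should be preserved, wandb.log accepts an explicit step; a sketch, not part of this commit:

wandb.log(log_dict, step=matches_played)  # keep matches_played as the x-axis

For the same reason, the two consecutive calls in train() could be merged so both losses land on the same step:

wandb.log({'actor_loss': actor_loss, 'critic_loss': critic_loss})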
19 changes: 10 additions & 9 deletions envs/agents/deepvss/deepvss_ma.py
@@ -9,16 +9,18 @@
# VSS customization:
import traceback
from importlib.machinery import SourceFileLoader
from pprint import pprint

import gym
import numpy as np
import rc_gym
import torch
import torch.multiprocessing as mp
import wandb
from tensorboardX import SummaryWriter

import rc_gym
from agents.agentDDPGMA import train, play, create_actor_model, load_actor_model
import numpy as np
from pprint import pprint
from agents.agentDDPGMA import (create_actor_model, load_actor_model, play,
train)

# Global variables
writer = None
@@ -119,9 +121,9 @@ def dict_to_str(dct):
if len(args.exp) > 0: # load experience buffer
checkpoint['exp'] = args.exp[0]

writer_path = model_params['data_path'] + \
'/logs/' + run_name
writer = SummaryWriter(log_dir=writer_path+"/agents", comment="-agent")
wandb.init(name=run_name,
project='RC-Reinforcement',
dir='./data_path/logs')

queue_size = model_params['batch_size']
exp_queue = mp.Queue(maxsize=queue_size)
@@ -130,7 +132,7 @@ def dict_to_str(dct):
print("Threads available: %d" % torch.get_num_threads())

th_a = threading.Thread(target=play, args=(
model_params, net, device, exp_queue, env, args.test, writer, collected_samples, finish_event))
model_params, net, device, exp_queue, env, args.test, collected_samples, finish_event))
play_threads.append(th_a)
th_a.start()

@@ -140,7 +142,6 @@ def dict_to_str(dct):

else: # create train process:
model_params['run_name'] = run_name
model_params['writer_path'] = writer_path
model_params['action_format'] = '2f'
model_params['state_format'] = f"{state_shape.shape[0]}f"
net.share_memory()
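The dir argument passed to wandb.init determines where wandb writes its local wandb/ run directory, which is what the new *wandb/ entry in .gitignore covers. Note that './data_path/logs' is a literal relative path; if the intent is to follow the configured data path, as the removed SummaryWriter code did, a variant (an assumption, not what this commit does) would be:

wandb.init(name=run_name,
           project='RC-Reinforcement',
           dir=model_params['data_path'] + '/logs')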
