Commit
Removed unused temperature scaling. Fixes #1.
jeromerony committed Mar 25, 2020
1 parent f0863b8 commit b72c693
Showing 2 changed files with 3 additions and 6 deletions.
6 changes: 2 additions & 4 deletions experiment.py

@@ -41,7 +41,6 @@ def config():
 
     no_bias_decay = True
     label_smoothing = 0.1
-    temperature = 1.
 
 
 @ex.capture
@@ -67,8 +66,7 @@ def get_optimizer_scheduler(parameters, loader_length, epochs, lr, momentum, nes
 
 
 @ex.automain
-def main(epochs, cpu, cudnn_flag, visdom_port, visdom_freq, temp_dir, seed, no_bias_decay, label_smoothing,
-         temperature):
+def main(epochs, cpu, cudnn_flag, visdom_port, visdom_freq, temp_dir, seed, no_bias_decay, label_smoothing):
     device = torch.device('cuda:0' if torch.cuda.is_available() and not cpu else 'cpu')
     callback = VisdomLogger(port=visdom_port) if visdom_port else None
     if cudnn_flag == 'deterministic':
@@ -79,7 +77,7 @@ def main(epochs, cpu, cudnn_flag, visdom_port, visdom_freq, temp_dir, seed, no_b
 
     torch.manual_seed(seed)
     model = get_model(num_classes=loaders.num_classes)
-    class_loss = SmoothCrossEntropy(epsilon=label_smoothing, temperature=temperature)
+    class_loss = SmoothCrossEntropy(epsilon=label_smoothing)
 
     model.to(device)
     if torch.cuda.device_count() > 1:
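For context, after this commit the classification loss is built from the label_smoothing config value alone; no temperature argument is passed. A minimal usage sketch follows (the import path utils.utils is assumed from the file in the diff, and the tensors are dummy data, not from the repository):

    import torch

    from utils.utils import SmoothCrossEntropy  # import path assumed from the diff

    label_smoothing = 0.1                        # same value as the config default above
    class_loss = SmoothCrossEntropy(epsilon=label_smoothing)

    logits = torch.randn(8, 10)                  # dummy batch: 8 samples, 10 classes
    labels = torch.randint(0, 10, (8,))          # dummy integer class labels
    loss = class_loss(logits, labels)            # loss value; exact reduction depends on the implementation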
3 changes: 1 addition & 2 deletions utils/utils.py

@@ -24,10 +24,9 @@ def state_dict_to_cpu(state_dict: OrderedDict):
 
 
 class SmoothCrossEntropy(nn.Module):
-    def __init__(self, epsilon: float = 0., temperature: float = 1.):
+    def __init__(self, epsilon: float = 0.):
         super(SmoothCrossEntropy, self).__init__()
         self.epsilon = float(epsilon)
-        self.temperature = float(temperature)
 
     def forward(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
         target_probs = torch.full_like(logits, self.epsilon / (logits.shape[1] - 1))
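The diff truncates the forward method. For reference, a standard label-smoothed cross-entropy that begins the way the visible line does would continue roughly as below; this is a sketch, not the repository's exact code:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class SmoothCrossEntropy(nn.Module):
        # Sketch of the post-commit class; the actual repository code may differ.
        def __init__(self, epsilon: float = 0.):
            super(SmoothCrossEntropy, self).__init__()
            self.epsilon = float(epsilon)

        def forward(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
            # Spread epsilon uniformly over the non-target classes...
            target_probs = torch.full_like(logits, self.epsilon / (logits.shape[1] - 1))
            # ...and put the remaining 1 - epsilon mass on the target class.
            target_probs.scatter_(1, labels.unsqueeze(1), 1 - self.epsilon)
            # Cross-entropy between the smoothed targets and the predicted log-probabilities.
            return -(target_probs * F.log_softmax(logits, dim=1)).sum(dim=1).mean()

With epsilon = 0.1 and 10 classes, the smoothed target places 0.9 on the true class and 0.1/9 ≈ 0.011 on each of the other classes.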
