Implementation of Self-adjusting Dice Loss from "Dice Loss for Data-imbalanced NLP Tasks" paper

fursovia/self-adj-dice

Self-adjusting Dice Loss

This is an unofficial PyTorch implementation of the self-adjusting Dice loss from the paper "Dice Loss for Data-imbalanced NLP Tasks".
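For orientation, here is a minimal sketch of the per-example loss in plain Python, assuming the paper's self-adjusting form with gold-class probability `p`: the Dice coefficient is `(2 * w + gamma) / (w + 1 + gamma)` with `w = (1 - p) ** alpha * p`, and the loss is one minus that. The function name `self_adj_dice` and the defaults `alpha=1.0`, `gamma=1.0` are illustrative assumptions, not the package's API, and the package's internals may differ.

```python
def self_adj_dice(p, alpha=1.0, gamma=1.0):
    """Per-example loss for gold-class probability p (assumed form, see text above)."""
    # (1 - p) ** alpha shrinks the weighted term as the prediction gets more confident
    weighted = (1.0 - p) ** alpha * p
    # soft Dice coefficient for the gold class, smoothed by gamma
    dsc = (2.0 * weighted + gamma) / (weighted + 1.0 + gamma)
    return 1.0 - dsc


# Print the loss for a few gold-class probabilities
for p in (0.1, 0.5, 0.9, 0.99):
    print(f"p={p:.2f}  loss={self_adj_dice(p):.4f}")
```

The factor `(1 - p) ** alpha` is what makes the loss "self-adjusting": it scales down the weighted term for examples the model already predicts with high confidence.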

Usage

Installation

pip install sadice

Text classification example

import torch
from sadice import SelfAdjDiceLoss

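criterion = SelfAdjDiceLoss()
# logits: (batch_size, num_classes); random values stand in for model outputs
logits = torch.rand(128, 10, requires_grad=True)
# targets: (batch_size,) integer class indices
targets = torch.randint(0, 10, size=(128,))

# with the default reduction the loss is a scalar, so backward() can be called directly
loss = criterion(logits, targets)
loss.backward()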

NER example

import torch
from sadice import SelfAdjDiceLoss

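# reduction="none" keeps one loss value per token
criterion = SelfAdjDiceLoss(reduction="none")
# logits: (batch_size, num_tokens, num_classes)
logits = torch.rand(128, 40, 10, requires_grad=True)
# targets: (batch_size, num_tokens) integer tag indices
targets = torch.randint(0, 10, size=(128, 40))

# flatten the token dimension so every token is scored as one classification example
loss = criterion(logits.view(-1, 10), targets.view(-1))
# average over the tokens of each sequence, then over the batch
loss = loss.reshape(-1, 40).mean(-1).mean()
loss.backward()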
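Per-token losses also make it possible to ignore padded positions when sequences have different lengths. The sketch below is one way to do that, under stated assumptions: `per_token_loss` is a hypothetical stand-in for `SelfAdjDiceLoss(reduction="none")` that follows the paper's formula as I read it (so the block runs without the package), and `masked_mean_loss` and `mask` are illustrative names that are not part of `sadice`.

```python
import torch


def per_token_loss(logits, targets, alpha=1.0, gamma=1.0):
    """Hypothetical stand-in for SelfAdjDiceLoss(reduction="none")."""
    # probability assigned to the gold class of each example
    p = torch.softmax(logits, dim=-1).gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    weighted = (1.0 - p) ** alpha * p
    return 1.0 - (2.0 * weighted + gamma) / (weighted + 1.0 + gamma)


def masked_mean_loss(logits, targets, mask):
    # logits: (batch_size, num_tokens, num_classes)
    # targets, mask: (batch_size, num_tokens); mask is 1.0 for real tokens, 0.0 for padding
    num_classes = logits.size(-1)
    loss = per_token_loss(logits.reshape(-1, num_classes), targets.reshape(-1))
    # zero out padded positions, then average over real tokens only
    loss = loss.reshape(targets.shape) * mask
    return loss.sum() / mask.sum().clamp(min=1)


torch.manual_seed(0)
logits = torch.rand(2, 5, 3, requires_grad=True)
targets = torch.randint(0, 3, size=(2, 5))
# the first sequence has three real tokens, the second has five
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.float32)

loss = masked_mean_loss(logits, targets, mask)
loss.backward()
```

Because padded positions are multiplied by zero before the reduction, they receive no gradient. To use the package here, swap `per_token_loss` for a `SelfAdjDiceLoss(reduction="none")` instance; the reshape to `targets.shape` also handles an output with a trailing dimension of size 1.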
