Skip to content

Commit

Permalink
[nn] nn.Embedding : padding_idx doc update (#53809) (#54026)
Browse files Browse the repository at this point in the history
Summary:
Follow-up of #53447

Reference: #53447 (comment)

Pull Request resolved: #53809

Reviewed By: bdhirsh

Differential Revision: D27049643

Pulled By: jbschlosser

fbshipit-source-id: 623a2a254783b86391dc2b0777b688506adb4c0e

Co-authored-by: kshitij12345 <kshitijkalambarkar@gmail.com>
  • Loading branch information
malfet and kshitij12345 authored Mar 16, 2021
1 parent 51233ea commit 264d0ec
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 14 deletions.
6 changes: 3 additions & 3 deletions torch/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1846,9 +1846,9 @@ def embedding(
input (LongTensor): Tensor containing indices into the embedding matrix
weight (Tensor): The embedding matrix with number of rows equal to the maximum possible index + 1,
and number of columns equal to the embedding size
padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
whenever it encounters the index.
Note: Vector at :attr:`padding_idx` will not receive gradient update.
padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
i.e. it remains as a fixed "pad".
max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
is renormalized to have norm :attr:`max_norm`.
Note: this will modify :attr:`weight` in-place.
Expand Down
35 changes: 24 additions & 11 deletions torch/nn/modules/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@ class Embedding(Module):
Args:
num_embeddings (int): size of the dictionary of embeddings
embedding_dim (int): the size of each embedding vector
padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
(initialized to zeros) whenever it encounters the index.
padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
i.e. it remains as a fixed "pad". For a newly constructed Embedding,
the embedding vector at :attr:`padding_idx` will default to all zeros,
but can be updated to another value to be used as the padding vector.
max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
is renormalized to have norm :attr:`max_norm`.
norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
Expand All @@ -42,14 +45,6 @@ class Embedding(Module):
sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
:class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)
.. note::
With :attr:`padding_idx` set, the embedding vector at
:attr:`padding_idx` is initialized to all zeros. However, note that this
vector can be modified afterwards, e.g., using a customized
initialization method, and thus changing the vector used to pad the
output. The gradient for this vector from :class:`~torch.nn.Embedding`
is always zero.
.. note::
When :attr:`max_norm` is not ``None``, :class:`Embedding`'s forward method will modify the
:attr:`weight` tensor in-place. Since tensors needed for gradient computations cannot be
Expand Down Expand Up @@ -93,6 +88,22 @@ class Embedding(Module):
[ 0.1535, -2.0309, 0.9315],
[ 0.0000, 0.0000, 0.0000],
[-0.1655, 0.9897, 0.0635]]])
>>> # example of changing `pad` vector
>>> padding_idx = 0
>>> embedding = nn.Embedding(3, 3, padding_idx=padding_idx)
>>> embedding.weight
Parameter containing:
tensor([[ 0.0000, 0.0000, 0.0000],
[-0.7895, -0.7089, -0.0364],
[ 0.6778, 0.5803, 0.2678]], requires_grad=True)
>>> with torch.no_grad():
... embedding.weight[padding_idx] = torch.ones(3)
>>> embedding.weight
Parameter containing:
tensor([[ 1.0000, 1.0000, 1.0000],
[-0.7895, -0.7089, -0.0364],
[ 0.6778, 0.5803, 0.2678]], requires_grad=True)
"""
__constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', 'max_norm',
'norm_type', 'scale_grad_by_freq', 'sparse']
Expand Down Expand Up @@ -171,7 +182,9 @@ def from_pretrained(cls, embeddings, freeze=True, padding_idx=None,
First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``.
freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process.
Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True``
padding_idx (int, optional): See module initialization documentation.
padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
i.e. it remains as a fixed "pad".
max_norm (float, optional): See module initialization documentation.
norm_type (float, optional): See module initialization documentation. Default ``2``.
scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``.
Expand Down

0 comments on commit 264d0ec

Please sign in to comment.