Skip to content

Commit

Permalink
Add descriptive docstring to TemperatureLogitsWarper (huggingface#24892)
Browse files Browse the repository at this point in the history
* Add descriptive docstring to TemperatureLogitsWarper

It addresses huggingface#24783

* Remove niche features

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Commit suggestion

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Refactor the examples to simpler ones

* Add a missing comma

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Make args description more compact

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Remove extra text after making description more compact

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Fix linter

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
  • Loading branch information
nablabits and gante authored Jul 26, 2023
1 parent 31acba5 commit 04a5c85
Showing 1 changed file with 49 additions and 2 deletions.
51 changes: 49 additions & 2 deletions src/transformers/generation/logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,58 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to

class TemperatureLogitsWarper(LogitsWarper):
r"""
[`LogitsWarper`] for temperature (exponential scaling output probability distribution).
[`LogitsWarper`] for temperature (exponential scaling output probability distribution), which effectively means
that it can control the randomness of the predicted tokens.
<Tip>
Make sure that `do_sample=True` is included in the `generate` arguments otherwise the temperature value won't have
any effect.
</Tip>
Args:
temperature (`float`):
The value used to module the logits distribution.
Strictly positive float value used to modulate the logits distribution. A value smaller than `1` decreases
randomness (and vice versa), with `0` being equivalent to shifting all probability mass to the most likely
token.
Examples:
```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> model.config.pad_token_id = model.config.eos_token_id
>>> model.generation_config.pad_token_id = model.config.eos_token_id
>>> input_context = "Hugging Face Company is"
>>> input_ids = tokenizer.encode(input_context, return_tensors="pt")
>>> torch.manual_seed(0)
>>> # With temperature=1, the default, we consistently get random outputs due to random sampling.
>>> outputs = model.generate(input_ids=input_ids, max_new_tokens=10, temperature=1, do_sample=True)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Hugging Face Company is one of these companies that is going to take a
>>> outputs = model.generate(input_ids=input_ids, max_new_tokens=10, temperature=1, do_sample=True)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Hugging Face Company is one of these companies, you can make a very
>>> # However, with temperature close to 0 , the output remains invariant.
>>> outputs = model.generate(input_ids=input_ids, max_new_tokens=10, temperature=0.0001, do_sample=True)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Hugging Face Company is a company that has been around for over 20 years
>>> # even if we set a different seed.
>>> torch.manual_seed(42)
>>> outputs = model.generate(input_ids=input_ids, max_new_tokens=10, temperature=0.0001, do_sample=True)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Hugging Face Company is a company that has been around for over 20 years
```
"""

def __init__(self, temperature: float):
Expand Down

0 comments on commit 04a5c85

Please sign in to comment.