diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py
index d5a87df092f710..e17cbae2037cc7 100644
--- a/src/transformers/generation/logits_process.py
+++ b/src/transformers/generation/logits_process.py
@@ -172,11 +172,58 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
 class TemperatureLogitsWarper(LogitsWarper):
     r"""
-    [`LogitsWarper`] for temperature (exponential scaling output probability distribution).
+    [`LogitsWarper`] for temperature (exponential scaling output probability distribution), which effectively means
+    that it can control the randomness of the predicted tokens.
+
+    <Tip>
+
+    Make sure that `do_sample=True` is included in the `generate` arguments, otherwise the temperature value won't have
+    any effect.
+
+    </Tip>
 
     Args:
         temperature (`float`):
-            The value used to module the logits distribution.
+            Strictly positive float value used to modulate the logits distribution. A value smaller than `1` decreases
+            randomness (and vice versa), with `0` being equivalent to shifting all probability mass to the most likely
+            token.
+
+    Examples:
+
+    ```python
+    >>> import torch
+    >>> from transformers import AutoTokenizer, AutoModelForCausalLM
+
+
+    >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
+    >>> model.config.pad_token_id = model.config.eos_token_id
+    >>> model.generation_config.pad_token_id = model.config.eos_token_id
+    >>> input_context = "Hugging Face Company is"
+    >>> input_ids = tokenizer.encode(input_context, return_tensors="pt")
+
+    >>> torch.manual_seed(0)
+
+    >>> # With temperature=1 (the default), the outputs vary from call to call due to random sampling.
+    >>> outputs = model.generate(input_ids=input_ids, max_new_tokens=10, temperature=1, do_sample=True)
+    >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+    Hugging Face Company is one of these companies that is going to take a
+
+    >>> outputs = model.generate(input_ids=input_ids, max_new_tokens=10, temperature=1, do_sample=True)
+    >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+    Hugging Face Company is one of these companies, you can make a very
+
+    >>> # However, with temperature close to 0, the output remains invariant.
+    >>> outputs = model.generate(input_ids=input_ids, max_new_tokens=10, temperature=0.0001, do_sample=True)
+    >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+    Hugging Face Company is a company that has been around for over 20 years
+
+    >>> # even if we set a different seed.
+    >>> torch.manual_seed(42)
+    >>> outputs = model.generate(input_ids=input_ids, max_new_tokens=10, temperature=0.0001, do_sample=True)
+    >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+    Hugging Face Company is a company that has been around for over 20 years
+    ```
     """
 
     def __init__(self, temperature: float):
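For context while reviewing: the new docstring claims that a temperature close to `0` makes sampling effectively deterministic. Below is a minimal standalone sketch of why, assuming the warper's `__call__` simply divides the logits by `temperature` before softmax (an assumption for illustration; this diff only touches documentation and does not show the `__call__` body):

```python
import torch

# Sketch of temperature scaling: divide the logits by the temperature,
# then renormalize with softmax. (Assumed mechanism; the diff above
# changes only the docstring.)
logits = torch.tensor([2.0, 1.0, 0.5])

for temperature in (1.0, 0.7, 0.0001):
    probs = torch.softmax(logits / temperature, dim=-1)
    print(f"temperature={temperature}: {probs.tolist()}")

# temperature < 1 sharpens the distribution toward the most likely token;
# as it approaches 0, essentially all probability mass lands on the argmax,
# which is why sampling with temperature=0.0001 in the docstring example is
# invariant to the random seed.
```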