# cot_decoding.py
import torch
from transformers import PreTrainedModel, PreTrainedTokenizer
from typing import List, Tuple, Dict

def get_device():
    """Pick the best available torch device: Apple MPS, then CUDA, then CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    elif torch.cuda.is_available():
        return torch.device("cuda")
    else:
        return torch.device("cpu")

def calculate_confidence(logits: List[torch.Tensor], answer_ids: torch.Tensor) -> float:
    """
    Calculate the confidence score (Δ) as specified in the paper.

    Args:
        logits: List of logits for each decoding step
        answer_ids: Tensor of token ids for the answer

    Returns:
        Confidence score (Δ)
    """
    confidence_sum = 0.0
    valid_tokens = 0
    for t, token_id in enumerate(answer_ids):
        if t >= len(logits):
            break
        token_logits = logits[t]
        probs = torch.softmax(token_logits, dim=-1)
        if probs.size(-1) > 1:
            top_2_probs, _ = torch.topk(probs, min(2, probs.size(-1)))
            if top_2_probs.size(-1) > 1:
                # Per-step Δ is the gap between the top-1 and top-2 probabilities.
                confidence_sum += (top_2_probs[-1][0] - top_2_probs[-1][1]).item()
            else:
                confidence_sum += 1.0  # Max confidence if there's only one token
        else:
            confidence_sum += 1.0  # Max confidence if there's only one token
        valid_tokens += 1
    return confidence_sum / valid_tokens if valid_tokens > 0 else 0.0
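
# A minimal worked sketch of the Δ score (the `example_*` names below are illustrative,
# not part of the module): with a single decoding step whose top-2 probabilities are
# 0.7 and 0.2, the per-step gap is 0.5, so the averaged confidence is 0.5.
# example_logits = [torch.log(torch.tensor([[0.7, 0.2, 0.1]]))]  # softmax recovers the probs
# example_answer_ids = torch.tensor([0])
# print(calculate_confidence(example_logits, example_answer_ids))  # ≈ 0.5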

def aggregate_paths_based_on_scores(paths: List[Tuple[str, float]]) -> Tuple[str, float]:
    """Aggregate multiple paths based on their confidence scores."""
    answer_scores = {}
    for answer, delta in paths:
        answer_scores[answer] = answer_scores.get(answer, 0) + delta
    best_answer = max(answer_scores, key=answer_scores.get)
    return best_answer, answer_scores[best_answer]
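
# A minimal sketch of path aggregation (the `example_paths` list is illustrative):
# the answer "72" appears on two paths, so its summed score 0.9 + 0.4 = 1.3 beats
# the single "70" path at 0.8.
# example_paths = [("72", 0.9), ("70", 0.8), ("72", 0.4)]
# print(aggregate_paths_based_on_scores(example_paths))  # -> ("72", 1.3)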

def cot_decode(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    messages: List[Dict[str, str]],
    k: int = 10,
    num_beams: int = 1,
    max_new_tokens: int = 512,
    temperature: float = 1.0,
    top_p: float = 1.0,
    repetition_penalty: float = 1.0,
    length_penalty: float = 1.0,
    no_repeat_ngram_size: int = 0,
    early_stopping: bool = False,
    aggregate_paths: bool = False,
) -> Tuple[str, float]:
    """
    Implement CoT-decoding for a given chat input.

    Args:
        model: The Hugging Face transformer model.
        tokenizer: The associated tokenizer.
        messages: List of chat messages in the format [{"role": "user", "content": "..."}]
        k: The number of alternative tokens to consider at the first step.
        num_beams: Number of beams for beam search.
        max_new_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus sampling probability.
        repetition_penalty: Repetition penalty factor.
        length_penalty: Length penalty factor.
        no_repeat_ngram_size: Size of n-grams to avoid repeating.
        early_stopping: Whether to stop generation when all beams are finished.
        aggregate_paths: Whether to aggregate multiple paths.

    Returns:
        A tuple containing the best path (or aggregated result) and its confidence score.
    """
    device = get_device()
    model.to(device)

    # Use the chat template to format the input
    if tokenizer.chat_template:
        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    else:
        # Fallback for tokenizers without chat templates
        input_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
        input_text += "\nassistant:"

    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
    attention_mask = torch.ones_like(input_ids).to(device)

    # Set pad_token_id if it's not set
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Get the top-k tokens for the first decoding step
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        first_token_logits = outputs.logits[0, -1, :]
        top_k_logits, top_k_indices = torch.topk(first_token_logits, k)

    paths = []
    for idx in top_k_indices:
        # Generate sequence starting with the selected token
        start_ids = torch.cat([input_ids, idx.unsqueeze(0).unsqueeze(0)], dim=-1)
        start_mask = torch.cat([attention_mask, torch.ones((1, 1), dtype=torch.long, device=device)], dim=-1)

        output = model.generate(
            start_ids,
            attention_mask=start_mask,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            early_stopping=early_stopping,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            output_scores=True,
            return_dict_in_generate=True,
        )

        generated_sequence = output.sequences[0]
        answer_ids = generated_sequence[len(input_ids[0]):]
        answer_text = tokenizer.decode(answer_ids, skip_special_tokens=True)

        # Calculate confidence score (Δ)
        confidence = calculate_confidence(output.scores, answer_ids)
        paths.append((answer_text, confidence))

    if aggregate_paths:
        return aggregate_paths_based_on_scores(paths)
    else:
        return max(paths, key=lambda x: x[1])

# Usage example
# from transformers import AutoModelForCausalLM, AutoTokenizer
# model_name = "Qwen/Qwen2.5-0.5B-Instruct"
# model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager")
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# messages = [
# {"role": "user", "content": "In a dance class of 20 students, 20% enrolled in contemporary dance, 25% of the remaining enrolled in jazz dance, and the rest enrolled in hip-hop dance. What percentage of the entire students enrolled in hip-hop dance?"}
# ]
# # Generate the response using CoT decoding
# print(f"Using device: {get_device()}")
# result, confidence = cot_decode(model, tokenizer, messages, aggregate_paths=True, max_new_tokens=512)
# print(f"CoT Decoding:\n {result}")