# forked from jiuyuan/CPM-9G-8B
import torch
import torch.nn.functional as F
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("inf")):
    # This function has been mostly taken from huggingface conversational ai code at
    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    batch_size = logits.size()[0]
    if top_p > 0.0:
        logits = logits.view(batch_size, -1).contiguous()
        for index in range(len(logits)):
            sorted_logits, sorted_indices = torch.sort(logits[index].view(-1), descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            # Remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift the indices to the right to keep also the first token above the threshold
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            logits[index][indices_to_remove] = filter_value

        logits = logits.view(batch_size, -1).contiguous()

    return logits
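

# A minimal usage sketch (illustrative, not part of the original file): filter a
# batch of random logits with top_k=50 and top_p=0.9, then sample one token per
# row. The tensor shapes and hyperparameter values here are assumptions for the
# demo, not values taken from this repository.
def _demo_top_k_top_p_filtering():
    logits = torch.randn(2, 32000)  # (batch_size, vocab_size)
    filtered = top_k_top_p_filtering(logits.clone(), top_k=50, top_p=0.9)
    probs = F.softmax(filtered, dim=-1)  # filtered-out entries get probability 0
    next_tokens = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)
    return next_tokens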
def apply_repetition_penalty(
    logits,
    batch_size,
    num_beams,
    prev_output_tokens,
    repetition_penalty,
    start_idx=None,
    end_idx=None,
    window_size=None,
):
    # Only apply the repetition penalty to the generated output tokens.
    assert repetition_penalty >= 1, "repetition penalty coefficient should be >= 1"
    # Repetition penalty (from the CTRL paper, https://arxiv.org/abs/1909.05858)
    for i in range(batch_size * num_beams):
        if start_idx is None or end_idx is None:
            output_tokens = prev_output_tokens[i].tolist()
        else:
            if end_idx >= start_idx:
                if window_size:
                    output_tokens = prev_output_tokens[i][
                        max(start_idx, end_idx + 1 - window_size) : end_idx + 1
                    ].tolist()
                else:
                    output_tokens = prev_output_tokens[i][start_idx : end_idx + 1].tolist()
            else:
                output_tokens = []
        for previous_token in set(output_tokens):
            # If the score is negative, the repetition penalty has to be
            # multiplied to reduce the previous token's probability
            if logits[i, previous_token] < 0:
                logits[i, previous_token] *= repetition_penalty
            else:
                logits[i, previous_token] /= repetition_penalty
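

# A minimal usage sketch (illustrative; the shapes and penalty value are
# assumptions): penalize tokens already generated by each of the
# batch_size * num_beams sequences. Note the function mutates `logits`
# in place and returns None.
def _demo_apply_repetition_penalty():
    batch_size, num_beams, vocab_size = 2, 2, 100
    logits = torch.randn(batch_size * num_beams, vocab_size)
    prev_output_tokens = torch.randint(0, vocab_size, (batch_size * num_beams, 8))
    apply_repetition_penalty(
        logits, batch_size, num_beams, prev_output_tokens,
        repetition_penalty=1.2, start_idx=0, end_idx=7,
    )
    return logits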
class BeamHypotheses:
    def __init__(self, n_hyp, max_len, length_penalty, early_stopping):
        """
        Initialize n-best list of hypotheses.
        """
        self.max_len = max_len
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.n_hyp = n_hyp
        self.hyp = []
        self.worst_score = 1e9

    def __len__(self):
        """
        Number of hypotheses in the list.
        """
        return len(self.hyp)

    def add(self, hyp, sum_logprobs):
        """
        Add a new hypothesis to the list.
        """
        score = sum_logprobs / len(hyp) ** self.length_penalty

        if len(self) < self.n_hyp or score > self.worst_score:
            self.hyp.append((score, hyp))
            if len(self) > self.n_hyp:
                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
                del self.hyp[sorted_scores[0][1]]
                self.worst_score = sorted_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs, cur_len):
        """
        If there are enough hypotheses and none of the hypotheses being generated
        can become better than the worst one in the heap, then we are done with this sentence.
        """
        if len(self) < self.n_hyp:
            return False
        elif self.early_stopping:
            return True
        else:
            return self.worst_score >= best_sum_logprobs / cur_len**self.length_penalty
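

# A minimal usage sketch (illustrative; the token ids and log-probabilities are
# made up): keep the 2 best hypotheses under a length penalty of 1.0, where a
# hypothesis's score is its summed log-probability normalized by length.
def _demo_beam_hypotheses():
    beams = BeamHypotheses(n_hyp=2, max_len=16, length_penalty=1.0, early_stopping=False)
    beams.add([5, 7, 9], sum_logprobs=-2.1)  # score -0.70
    beams.add([5, 7, 3], sum_logprobs=-1.4)  # score -0.47
    beams.add([5, 2, 8], sum_logprobs=-3.0)  # score -1.00, worse than both: dropped
    done = beams.is_done(best_sum_logprobs=-4.0, cur_len=3)
    return beams.hyp, done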
def NBCE(logits):
    """
    Naive Bayes-based Context Extension
    blog: https://www.kexue.fm/archives/9617
    """
    beta = 0.25
    logits = logits[:, -1]  # (num_contexts, seq_len, vocab) -> (num_contexts, vocab)
    # Normalize each row to log-probabilities.
    logits = logits - logits.logsumexp(dim=-1, keepdim=True)
    # Pick the context (excluding index 0, the unconditional one) whose
    # distribution has the lowest entropy, i.e. the most confident prediction.
    k = (logits.exp() * logits).sum(dim=-1)[1:].argmax() + 1
    logits_max = logits[k]
    logits_uncond = logits[0]
    # Contrast the most confident conditional distribution with the
    # unconditional one.
    logits = (1 + beta) * logits_max - beta * logits_uncond
    return logits
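

# A minimal usage sketch (illustrative; the random tensor stands in for the
# per-context logits a model would produce, and row 0 is assumed to hold the
# unconditional, no-context prediction as in the NBCE blog post).
def _demo_nbce():
    num_contexts, seq_len, vocab_size = 4, 10, 100
    logits = torch.randn(num_contexts, seq_len, vocab_size)
    fused = NBCE(logits)  # (vocab_size,) fused next-token logits
    next_token = fused.argmax(dim=-1)
    return next_token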