fix bug at inference
commit fd709eacff (parent 38ca429228)

@@ -29,8 +29,8 @@ def main():
         return prompt

     def predict(query, history: list):
-        inputs = tokenizer([format_example(query)], return_tensors="pt")
-        inputs = inputs.to(model.device)
+        input_ids = tokenizer([format_example(query)], return_tensors="pt")["input_ids"]
+        input_ids = input_ids.to(model.device)
         gen_kwargs = {
             "do_sample": True,
             "top_p": 0.9,
@@ -41,8 +41,8 @@ def main():
             "repetition_penalty": 1.5
         }
         with torch.no_grad():
-            generation_output = model.generate(**inputs, **gen_kwargs)
-        outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
+            generation_output = model.generate(input_ids=input_ids, **gen_kwargs)
+        outputs = generation_output.tolist()[0][len(input_ids[0]):]
         response = tokenizer.decode(outputs, skip_special_tokens=True)
         history = history + [(query, response)]
         return response, history
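
Both hunks above make the same fix: instead of unpacking everything the tokenizer returns into model.generate(**inputs), only the input_ids tensor is extracted and passed explicitly, and the prompt is stripped from the generated sequence by slicing with len(input_ids[0]). Below is a minimal sketch of the corrected inference pattern; the checkpoint name ("gpt2"), the prompt, and the sampling settings are placeholders for illustration, not the values used by this project.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint for illustration; the demos load their own model.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

query = "Hello, how are you?"

# Keep only the token ids, mirroring the fix: tokenizers may also return
# tensors (e.g. token_type_ids) that generate() cannot forward to every model.
input_ids = tokenizer([query], return_tensors="pt")["input_ids"]
input_ids = input_ids.to(model.device)

gen_kwargs = {"do_sample": True, "top_p": 0.9, "max_new_tokens": 64}

with torch.no_grad():
    generation_output = model.generate(input_ids=input_ids, **gen_kwargs)

# Drop the prompt tokens from the front before decoding the reply.
outputs = generation_output.tolist()[0][len(input_ids[0]):]
response = tokenizer.decode(outputs, skip_special_tokens=True)
print(response)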

@@ -83,8 +83,8 @@ def format_example(query):
 def predict(input, chatbot, max_length, top_p, temperature, history):
     chatbot.append((parse_text(input), ""))

-    inputs = tokenizer([format_example(input)], return_tensors="pt")
-    inputs = inputs.to(model.device)
+    input_ids = tokenizer([format_example(input)], return_tensors="pt")["input_ids"]
+    input_ids = input_ids.to(model.device)
     gen_kwargs = {
         "do_sample": True,
         "top_p": top_p,
@@ -94,8 +94,8 @@ def predict(input, chatbot, max_length, top_p, temperature, history):
         "repetition_penalty": 1.0
     }
     with torch.no_grad():
-        generation_output = model.generate(**inputs, **gen_kwargs)
-    outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
+        generation_output = model.generate(input_ids=input_ids, **gen_kwargs)
+    outputs = generation_output.tolist()[0][len(input_ids[0]):]
     response = tokenizer.decode(outputs, skip_special_tokens=True)
     history = history + [(input, response)]
     chatbot[-1] = (parse_text(input), parse_text(response))
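
In the web demo the sampling parameters come from the UI sliders, but the inference fix is the same one as in the CLI demo. A possible variation, not part of this commit: if batched or padded prompts were ever used, the attention mask should be kept alongside the ids, which can be done by filtering the tokenizer output instead of discarding everything but input_ids. A sketch under the same placeholder assumptions as above:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint, as in the sketch above.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

inputs = tokenizer(["an example prompt"], return_tensors="pt").to(model.device)

# Keep only the keys generate() reliably accepts, dropping extras such as
# token_type_ids that some tokenizers emit but some model forwards reject.
allowed = {"input_ids", "attention_mask"}
gen_inputs = {k: v for k, v in inputs.items() if k in allowed}

with torch.no_grad():
    output = model.generate(**gen_inputs, do_sample=True, top_p=0.9, max_new_tokens=64)

# Slice off the prompt, then decode only the newly generated tokens.
new_tokens = output[0][gen_inputs["input_ids"].shape[1]:]
print(tokenizer.decode(new_tokens, skip_special_tokens=True))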