CPM-9G-8B/quick_start_clean/model_inference.py

22 lines
696 B
Python
Raw Permalink Normal View History

2024-07-15 14:27:10 +08:00
# pip install vllm
# vLLM 0.2.2 suits older NVIDIA drivers (e.g. 470.xx.xx) with CUDA 11.
# vLLM 0.4.1+ suits newer NVIDIA drivers (e.g. 535.161.xx) with CUDA 12.
from vllm import LLM, SamplingParams
def main(
    model_path="fm9g-selfrec/",
    prompts=("请介绍下启元实验室",),
    *,
    temperature=0.3,
    top_p=0.8,
    max_tokens=512,
):
    """Run offline batch generation with vLLM and print each completion.

    Args:
        model_path: Model directory (or model id) passed to ``vllm.LLM``.
            Defaults to the ``fm9g-selfrec/`` checkpoint used by this quick start.
        prompts: Prompt strings to complete in one batch.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        max_tokens: Maximum number of tokens generated per prompt. Set
            explicitly because vLLM's ``SamplingParams`` defaults to 16,
            which truncates open-ended answers after a few words.

    Returns:
        The list of outputs returned by ``LLM.generate``, one per prompt.
    """
    sampling_params = SamplingParams(
        temperature=temperature, top_p=top_p, max_tokens=max_tokens
    )
    # trust_remote_code=True executes Python code shipped inside the model
    # directory; only point model_path at checkpoints you trust.
    llm = LLM(model=model_path, trust_remote_code=True)
    # generate() expects a list of prompt strings, so normalise any sequence.
    outputs = llm.generate(list(prompts), sampling_params)
    for output in outputs:
        # Each request may carry several completions; show the first one.
        generated_text = output.outputs[0].text
        print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}")
    return outputs


if __name__ == "__main__":
    main()