hiyouga 2024-05-07 17:50:27 +08:00
parent 6159acbaa0
commit b0888262e3
7 changed files with 43 additions and 5 deletions

View File

@@ -148,6 +148,8 @@ bash examples/full_multi_gpu/predict.sh
#### Merge LoRA Adapters
+Note: DO NOT use quantized model or `quantization_bit` when merging LoRA adapters.
```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
```
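The exported model can then be loaded without the adapter. A minimal sketch, assuming the CLI also accepts arguments directly and that the `export_dir` configured in `llama3_lora_sft.yaml` is `models/llama3_lora_sft` (both the path and the template below are illustrative):
```bash
# Illustrative follow-up: chat with the merged model produced by the export step.
# models/llama3_lora_sft stands in for whatever export_dir the YAML specifies.
CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat \
    --model_name_or_path models/llama3_lora_sft \
    --template llama3
```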

View File

@@ -148,6 +148,8 @@ bash examples/full_multi_gpu/predict.sh
#### Merge LoRA Adapters
+Note: DO NOT use quantized model or `quantization_bit` when merging LoRA adapters.
```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
```

View File

@@ -1,6 +1,12 @@
#!/bin/bash
-python -m torch.distributed.run \
+NPROC_PER_NODE=4
+NNODES=2
+RANK=0
+MASTER_ADDR=192.168.0.1
+MASTER_PORT=29500
+
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.run \
    --nproc_per_node $NPROC_PER_NODE \
    --nnodes $NNODES \
    --node_rank $RANK \
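For orientation, a minimal sketch of the matching launch on the second machine (rank 1), assuming the truncated tail of the script passes the master address, port, and training YAML as usual; the YAML path below is a placeholder, not the file used by this commit:
```bash
# Hypothetical launch on the second node of the 2-node job; only RANK changes.
NPROC_PER_NODE=4
NNODES=2
RANK=1
MASTER_ADDR=192.168.0.1
MASTER_PORT=29500

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.run \
    --nproc_per_node $NPROC_PER_NODE \
    --nnodes $NNODES \
    --node_rank $RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT \
    src/train.py path/to/config.yaml  # placeholder for the YAML the script actually passes
```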

View File

@@ -1,4 +1,9 @@
#!/bin/bash
-deepspeed --include "localhost:0,1,2,3" \
+NPROC_PER_NODE=4
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.run \
+    --nproc_per_node $NPROC_PER_NODE \
+    --nnodes 1 \
+    --standalone \
    src/train.py examples/full_multi_gpu/llama3_full_sft.yaml
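Since `torchrun` is the console-script alias for `python -m torch.distributed.run`, the same single-node launch can be written more compactly; a sketch:
```bash
# Equivalent single-node launch via the torchrun entry point.
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --standalone --nproc_per_node 4 \
    src/train.py examples/full_multi_gpu/llama3_full_sft.yaml
```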

View File

@@ -1,5 +1,9 @@
#!/bin/bash
# ZeRO-3 enables weight sharding on multiple GPUs
-deepspeed --include "localhost:0,1,2,3" \
+NPROC_PER_NODE=4
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.run \
+    --nproc_per_node $NPROC_PER_NODE \
+    --nnodes 1 \
+    --standalone \
    src/train.py examples/lora_multi_gpu/llama3_lora_sft_ds.yaml
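The visible devices and the process count have to stay in sync; as a hypothetical example, a two-GPU variant of the same ZeRO-3 launch:
```bash
# Hypothetical 2-GPU variant: the device list must match --nproc_per_node.
NPROC_PER_NODE=2
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.run \
    --nproc_per_node $NPROC_PER_NODE \
    --nnodes 1 \
    --standalone \
    src/train.py examples/lora_multi_gpu/llama3_lora_sft_ds.yaml
```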

View File

@@ -1,4 +1,4 @@
-# Note: DO NOT use quantized model or quantization_bit when merging lora weights
+# Note: DO NOT use quantized model or quantization_bit when merging lora adapters
# model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct

src/api.py (new file, 19 lines)
View File

@@ -0,0 +1,19 @@
import os

import uvicorn

from llmtuner.api.app import create_app
from llmtuner.chat import ChatModel


def main():
    chat_model = ChatModel()
    app = create_app(chat_model)
    api_host = os.environ.get("API_HOST", "0.0.0.0")
    api_port = int(os.environ.get("API_PORT", "8000"))
    print("Visit http://localhost:{}/docs for API document.".format(api_port))
    uvicorn.run(app, host=api_host, port=api_port)


if __name__ == "__main__":
    main()
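A possible way to exercise the new entry point, assuming `ChatModel` reads its model arguments from the command line and that the app serves the usual OpenAI-style `/v1/chat/completions` route; the model path and request body below are illustrative:
```bash
# Illustrative launch of the new API entry point (model arguments are placeholders).
API_PORT=8000 CUDA_VISIBLE_DEVICES=0 python src/api.py \
    --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
    --template llama3

# Assuming an OpenAI-compatible route is exposed, a quick smoke test:
curl http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "llama3", "messages": [{"role": "user", "content": "Hello"}]}'
```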