Merge pull request #3584 from zhou-wjjw/main

Enhancing Ascend 910A Training Efficiency in LlamaFactory with NPU
This commit is contained in:
hoshi-hiyouga 2024-05-14 22:18:37 +08:00 committed by GitHub
commit ee4752f6d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 8 additions and 0 deletions

View File

@ -1,3 +1,8 @@
import os
import torch
from transformers import is_torch_npu_available
from llmtuner.train.tuner import run_exp
@ -11,4 +16,7 @@ def _mp_fn(index):
if __name__ == "__main__":
if is_torch_npu_available():
use_jit_compile = os.getenv('JIT_COMPILE', 'False').lower() in ['true', '1']
torch.npu.set_compile_mode(jit_compile=use_jit_compile)
main()