Merge pull request #3584 from zhou-wjjw/main
Enhancing Ascend 910A Training Efficiency in LlamaFactory with NPU
Commit ee4752f6d2
@@ -1,3 +1,8 @@
+import os
+
+import torch
+from transformers import is_torch_npu_available
+
 from llmtuner.train.tuner import run_exp

@@ -11,4 +16,7 @@ def _mp_fn(index):
 if __name__ == "__main__":
+    if is_torch_npu_available():
+        use_jit_compile = os.getenv('JIT_COMPILE', 'False').lower() in ['true', '1']
+        torch.npu.set_compile_mode(jit_compile=use_jit_compile)
     main()
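The added lines gate Ascend NPU JIT compilation behind a JIT_COMPILE environment variable that is read once at process start. Below is a minimal standalone sketch of the same toggle, assuming transformers and an Ascend-enabled PyTorch build (torch_npu); the configure_npu_jit helper name is illustrative and not part of this PR.

import os

import torch
from transformers import is_torch_npu_available


def configure_npu_jit() -> None:
    # Mirror of the PR's logic: enable NPU JIT compilation only when the
    # JIT_COMPILE environment variable is "true" or "1" (case-insensitive).
    if not is_torch_npu_available():
        return  # no-op on hosts without an Ascend NPU
    use_jit_compile = os.getenv("JIT_COMPILE", "False").lower() in ["true", "1"]
    torch.npu.set_compile_mode(jit_compile=use_jit_compile)


if __name__ == "__main__":
    configure_npu_jit()

Leaving JIT_COMPILE unset therefore defaults to jit_compile=False; setting JIT_COMPILE=1 (or true) in the environment when launching training re-enables JIT compilation.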