diff --git a/src/train.py b/src/train.py index 6a3212cb..4cc21194 100644 --- a/src/train.py +++ b/src/train.py @@ -1,3 +1,8 @@ +import os + +import torch +from transformers import is_torch_npu_available + from llmtuner.train.tuner import run_exp @@ -11,4 +16,7 @@ def _mp_fn(index): if __name__ == "__main__": + if is_torch_npu_available(): + use_jit_compile = os.getenv('JIT_COMPILE', 'False').lower() in ['true', '1'] + torch.npu.set_compile_mode(jit_compile=use_jit_compile) main()