fix #818
This commit is contained in:
parent
ed1c2c5557
commit
5a9970dbef
|
@ -79,7 +79,7 @@ def load_model_and_tokenizer(
|
|||
|
||||
config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
|
||||
|
||||
if hasattr(config, "fp16") and hasattr(config, "bf16"): # fix Qwen config
|
||||
if is_trainable and hasattr(config, "fp16") and hasattr(config, "bf16"): # fix Qwen config
|
||||
if model_args.compute_dtype == torch.bfloat16:
|
||||
setattr(config, "bf16", True)
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue