support baichuan model
This commit is contained in:
parent
c527399424
commit
0cee6ad67f
|
@ -17,6 +17,7 @@
|
||||||
|
|
||||||
- [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B)
|
- [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B)
|
||||||
- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B)
|
- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B)
|
||||||
|
- [baichuan](https://huggingface.co/baichuan-inc/baichuan-7B) (7B)
|
||||||
|
|
||||||
## Supported Training Approaches
|
## Supported Training Approaches
|
||||||
|
|
||||||
|
|
|
@ -170,6 +170,8 @@ def load_pretrained(
|
||||||
**config_kwargs
|
**config_kwargs
|
||||||
)
|
)
|
||||||
tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
|
tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
|
||||||
|
if tokenizer.pad_token_id == 64000:
|
||||||
|
tokenizer.pad_token_id = 0 # for baichuan model (need fix)
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
|
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
|
||||||
is_mergeable = True
|
is_mergeable = True
|
||||||
|
|
|
@ -83,7 +83,13 @@ def prepare_model_for_training(
|
||||||
param.data = param.data.to(torch.float32)
|
param.data = param.data.to(torch.float32)
|
||||||
|
|
||||||
if use_gradient_checkpointing:
|
if use_gradient_checkpointing:
|
||||||
model.enable_input_require_grads()
|
if hasattr(model, "enable_input_require_grads"):
|
||||||
|
model.enable_input_require_grads()
|
||||||
|
else:
|
||||||
|
def make_inputs_require_grad(module, input, output):
|
||||||
|
output.requires_grad_(True)
|
||||||
|
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
||||||
|
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable()
|
||||||
model.config.use_cache = False # turn off when gradient checkpointing is enabled
|
model.config.use_cache = False # turn off when gradient checkpointing is enabled
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue