fix saving custom code

parent 2c867b9bb1
commit 1e1358431d
@@ -11,7 +11,7 @@ from transformers import (
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
 from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
-from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers.tokenization_utils import PreTrainedTokenizerBase
 from trl import AutoModelForCausalLMWithValueHead
 
 from llmtuner.extras.logging import get_logger
@@ -36,7 +36,7 @@ def load_model_and_tokenizer(
     finetuning_args: FinetuningArguments,
     is_trainable: Optional[bool] = False,
     stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
-) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
     r"""
     Loads pretrained model and tokenizer.
 
@@ -113,12 +113,12 @@ def load_model_and_tokenizer(
     )
 
     # Register auto class to save the custom code files.
-    if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map and isinstance(config, PretrainedConfig):
+    if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
         config.__class__.register_for_auto_class()
-    if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map and isinstance(tokenizer, PreTrainedTokenizer):
-        tokenizer.__class__.register_for_auto_class()
-    if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map and isinstance(model, PreTrainedModel):
+    if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
         model.__class__.register_for_auto_class()
+    if isinstance(tokenizer, PreTrainedTokenizerBase) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
+        tokenizer.__class__.register_for_auto_class()
 
     # Initialize adapters
     model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
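Note (not part of the commit): a minimal sketch of what the registration above buys when a checkpoint ships its own modeling code. The repository name and output path are hypothetical placeholders; only the register_for_auto_class() pattern mirrors the diff.

# Hypothetical checkpoint that provides custom *.py files via auto_map / trust_remote_code.
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

name = "some-org/custom-model"  # placeholder, not a real repo
config = AutoConfig.from_pretrained(name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(name, config=config, trust_remote_code=True)

# Registering the dynamically loaded classes for their Auto* entry points lets
# save_pretrained() copy the custom code files next to the weights, so the exported
# checkpoint can be reloaded later with trust_remote_code=True.
config.__class__.register_for_auto_class()
model.__class__.register_for_auto_class()
tokenizer.__class__.register_for_auto_class()

model.save_pretrained("./exported")
tokenizer.save_pretrained("./exported")

Guarding on auto_map, as the new conditions do, keeps the registration limited to checkpoints that actually carry remote code; built-in classes need no copying.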
@@ -300,6 +300,45 @@ class BaichuanPreTrainedModel(PreTrainedModel):
         if isinstance(module, BaichuanModel):
             module.gradient_checkpointing = value
 
+    @staticmethod
+    def _convert_to_standard_cache(
+        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
+        num_heads, ...]))
+        """
+        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
+        num_heads = batch_size_times_num_heads // batch_size
+        # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
+        # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
+        return tuple(
+            (
+                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
+                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
+            )
+            for layer_past in past_key_value
+        )
+
+    @staticmethod
+    def _convert_to_baichuan_cache(
+        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Converts the cache to the format expected by Baichuan, i.e. to tuple(tuple([batch_size * num_heads, ...]))
+        """
+        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
+        batch_size_times_num_heads = batch_size * num_heads
+        # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
+        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
+        return tuple(
+            (
+                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
+                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
+            )
+            for layer_past in past_key_value
+        )
+
 
 class BaichuanModel(BaichuanPreTrainedModel):
 
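Note (not part of the commit): a tiny shape sketch, with made-up sizes, of the two cache layouts the new helpers convert between. The conversion is a plain view(), so both layouts share the same storage.

import torch

batch_size, num_heads, head_dim, seq_length = 2, 4, 8, 5

# Baichuan layout: batch and heads fused in the leading dimension.
key = torch.randn(batch_size * num_heads, head_dim, seq_length)
value = torch.randn(batch_size * num_heads, seq_length, head_dim)

# Standard layout (what _convert_to_standard_cache produces): explicit batch dimension.
std_key = key.view(batch_size, num_heads, head_dim, seq_length)
std_value = value.view(batch_size, num_heads, seq_length, head_dim)

assert std_value.shape == (batch_size, num_heads, seq_length, head_dim)
# Round-tripping back to the Baichuan layout (_convert_to_baichuan_cache) copies no data.
assert std_key.view(batch_size * num_heads, head_dim, seq_length).data_ptr() == key.data_ptr()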
@@ -559,11 +598,20 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
         )
 
     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
-    ):
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs
+    ) -> dict:
         if past_key_values:
             input_ids = input_ids[:, -1:]
+
+            # the cache may be in the standard format (e.g. in contrastive search)
+            if past_key_values[0][0].shape[0] == input_ids.shape[0]:
+                past_key_values = self._convert_to_baichuan_cache(past_key_values)
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
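Note (not part of the commit): a short sketch, with made-up sizes, of why the added shape comparison can tell the two layouts apart. During incremental decoding input_ids has shape [batch_size, 1], while a Baichuan-format cache has batch_size * num_heads rows, so the leading dimensions only match when the cache is still in the standard format (e.g. handed back by contrastive search).

import torch

batch_size, num_heads, head_dim, seq_length = 2, 4, 8, 5
input_ids = torch.zeros(batch_size, 1, dtype=torch.long)

standard_key = torch.randn(batch_size, num_heads, head_dim, seq_length)
baichuan_key = torch.randn(batch_size * num_heads, head_dim, seq_length)

assert standard_key.shape[0] == input_ids.shape[0]  # conversion branch would run
assert baichuan_key.shape[0] != input_ids.shape[0]  # cache already in Baichuan layout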
@@ -579,13 +627,30 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
         )
         return model_inputs
 
-    @staticmethod
-    def _reorder_cache(past_key_values, beam_idx):
-        return tuple(
-            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
-            for layer_past in past_key_values
-        )
+    def _reorder_cache(
+        self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
+        """
+        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
+        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
+        beam_idx at every generation step.
+
+        Output shares the same memory storage as `past`.
+        """
+        standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
+
+        # Get a copy of `beam_idx` on all the devices where we need those indices.
+        device_to_beam_idx = {
+            past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
+        }
+        reordered_past = tuple(
+            (
+                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
+                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
+            )
+            for layer_past in standardized_past
+        )
+        return self._convert_to_baichuan_cache(reordered_past)
 
     def quantize(self, bits: int):
         try:
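Note (not part of the commit): a miniature of the reordering _reorder_cache performs for beam search, shown on a single standard-format key tensor with hypothetical sizes. index_select gathers, for every new beam, the cache rows of the old beam it continues from.

import torch

num_beams, num_heads, head_dim, seq_length = 3, 2, 4, 5
key = torch.randn(num_beams, num_heads, head_dim, seq_length)

# beam_idx[i] is the index of the previous beam that new beam i continues from.
beam_idx = torch.tensor([2, 0, 0])
reordered = key.index_select(0, beam_idx)

assert torch.equal(reordered[0], key[2])  # new beam 0 took old beam 2's cache
assert torch.equal(reordered[1], key[0])  # new beam 1 took old beam 0's cache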