137 lines
5.1 KiB
Python
137 lines
5.1 KiB
Python
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, HfArgumentParser
|
||
|
from transformers.trainer_utils import is_main_process
|
||
|
from dataclasses import field, dataclass, asdict
|
||
|
from datasets import load_dataset
|
||
|
from typing import Optional
|
||
|
import torch
|
||
|
import os
|
||
|
|
||
|
@dataclass
class BaseModelConfig:
    """Command-line options that select the base checkpoint.

    Parsed by ``HfArgumentParser`` in ``execute``.
    """

    # Handed to ``from_pretrained`` for both the model and the tokenizer.
    pretrained_model_name_or_path: Optional[str] = "/2b_sft_model"
|
||
|
|
||
|
@dataclass
class DataConfig:
    """Dataset locations and preprocessing options.

    Parsed by ``HfArgumentParser`` in ``execute``; the whole instance is also
    expanded with ``asdict`` into the keyword arguments of ``transform_data``.
    """

    # Parquet file(s) for training; ``execute`` splits this value on ','.
    train_dataset_path: Optional[str] = "test.parquet"
    # Parquet file used for evaluation (passed to ``load_dataset`` as is).
    eval_dataset_path: Optional[str] = "eval.parquet"
    # NOTE(review): not read anywhere in this file - presumably a cap on the
    # number of evaluation samples; confirm before relying on it.
    eval_size: Optional[int] = 256
    # Token length every sequence is padded / truncated to.
    max_length: Optional[int] = 512
    # ``num_proc`` used for ``Dataset.map``.
    num_data_proc: Optional[int] = 16
    # Forwarded to ``transform_data`` as ``training_key``.
    training_key: Optional[str] = 'input'
    # Forwarded to ``transform_data``; asks it to set the label of the last
    # sequence position to -100.
    skip_eos_token: Optional[bool] = False
|
||
|
|
||
|
@dataclass
class CustomTrainingConfig(TrainingArguments):
    """``TrainingArguments`` carrying this script's default hyper-parameters.

    Every attribute below only supplies a default value; each one can still be
    overridden from the command line through ``HfArgumentParser``.
    """

    # Run identity and output location.
    run_name: Optional[str] = "cpm"
    output_dir: Optional[str] = "checkpoints/"

    # Batch sizes and schedule length.
    per_device_train_batch_size: Optional[int] = 4
    per_device_eval_batch_size: Optional[int] = 4
    num_train_epochs: Optional[int] = 20

    # Optimizer and learning-rate schedule.
    weight_decay: Optional[float] = 0
    learning_rate: Optional[float] = 1e-7
    lr_scheduler_type: Optional[str] = "cosine"
    warmup_ratio: Optional[float] = 0.1

    # Evaluation cadence.
    eval_strategy: Optional[str] = "steps"
    eval_steps: Optional[int] = 100
    load_best_model_at_end: Optional[bool] = True

    # Logging cadence.
    logging_strategy: Optional[str] = "steps"
    logging_steps: Optional[int] = 1

    # Checkpointing.
    save_strategy: Optional[str] = "steps"
    save_steps: Optional[int] = 100
    save_total_limit: Optional[int] = 10
    save_only_model: Optional[bool] = True

    # Numeric precision.
    bf16: Optional[bool] = True
|
||
|
|
||
|
def process_text(text):
    """Remove ``print``/``assert``/``# `` lines from *text* and trim it.

    A line is dropped when, ignoring surrounding whitespace, it starts with
    ``print``, ``assert`` or ``# `` (hash followed by a space). The remaining
    lines are re-joined with newlines and the result is stripped of leading
    and trailing whitespace.

    Args:
        text: Multi-line string to clean.

    Returns:
        The cleaned string.
    """
    skipped_prefixes = ('print', 'assert', '# ')
    kept = []
    for row in text.split('\n'):
        if row.strip().startswith(skipped_prefixes):
            continue
        kept.append(row)
    return '\n'.join(kept).strip()
|
||
|
|
||
|
def transform_data(samples, tokenizer, training_key, max_length, skip_eos_token=False, **kwargs):
    """Tokenize a batch of prompt/response pairs into causal-LM features.

    Each example is rendered as ``<用户>{prompt}<AI>{response}{eos}`` where the
    response is ``samples['output']`` cleaned by ``process_text``. Labels are
    a copy of the input ids with -100 (the ignore index of the loss) written
    over the prompt part and over every position the attention mask excludes.

    Args:
        samples: Batch of columns; must provide ``training_key`` and
            ``'output'`` as equally long sequences of strings.
        tokenizer: Tokenizer exposing ``eos_token``, ``encode`` and
            ``batch_encode_plus``.
        training_key: Name of the column that holds the prompt text.
        max_length: Length every sequence is padded / truncated to.
        skip_eos_token: If true, the label at the last position of every
            sequence is set to -100 as well.
        **kwargs: Ignored; allows callers to pass a whole config dict.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors,
        each of shape ``(len(batch), max_length)``.
    """
    # Read each column once; the prompt column is chosen by ``training_key``
    # instead of being hard-coded.
    prompts = samples[training_key]
    responses = samples['output']
    sequences = []
    for idx, question in enumerate(prompts):
        sequences.append(f"<用户>{question}<AI>{process_text(responses[idx])}{tokenizer.eos_token}")

    encoded = tokenizer.batch_encode_plus(
        sequences,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]
    # Positions holding token id 0 are removed from the attention mask too.
    # NOTE(review): presumably 0 is the unk id that the training script
    # reuses as pad id - confirm for the tokenizer in use.
    attention_mask = torch.where(input_ids == 0, torch.zeros_like(attention_mask), attention_mask)
    target_ids = input_ids.clone()

    # The third sub-token of "<AI>" marks where the response starts.
    # Assumes the tokenizer splits "<AI>" into at least three pieces - TODO
    # confirm. The id is loop-invariant, so it is computed once.
    marker_id = tokenizer.encode("<AI>", add_special_tokens=False)[2]
    for i, sequence in enumerate(input_ids):
        ai_marker = (sequence == marker_id).nonzero(as_tuple=False)
        if ai_marker.nelement() > 0:
            # Everything before the first marker hit is prompt: no loss there.
            start_pos = ai_marker[0].item()
            target_ids[i, :start_pos] = -100

    # Padding (and any other position excluded from attention) must not
    # contribute to the loss either.
    target_ids[attention_mask == 0] = -100

    if skip_eos_token:
        target_ids[:, -1] = -100

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": target_ids
    }
|
||
|
|
||
|
def batch_processor(items):
    """Collate per-example feature dicts into batched tensors.

    Args:
        items: Sequence of dicts, each holding ``input_ids``,
            ``attention_mask`` and ``labels`` as equally long lists of ints.

    Returns:
        Dict with the same three keys, each mapped to a tensor built from the
        stacked per-example lists.
    """
    feature_names = ("input_ids", "attention_mask", "labels")
    return {
        name: torch.tensor([example[name] for example in items])
        for name in feature_names
    }
|
||
|
|
||
|
def execute():
    """Parse the command line, prepare the datasets and run fine-tuning.

    Steps: parse ``BaseModelConfig`` / ``DataConfig`` / ``CustomTrainingConfig``,
    load model and tokenizer, load the parquet train/eval splits, tokenize
    them with ``transform_data`` and start ``Trainer.train``.
    """
    arg_parser = HfArgumentParser((BaseModelConfig, DataConfig, CustomTrainingConfig))
    model_config, data_config, training_config = arg_parser.parse_args_into_dataclasses()

    checkpoint = model_config.pretrained_model_name_or_path
    model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    # Padding reuses the unknown-token id.
    tokenizer.pad_token_id = tokenizer.unk_token_id

    # The training path may list several files separated by ','.
    train_files = data_config.train_dataset_path.split(',')
    train_data = load_dataset('parquet', data_files={'train': train_files})['train']
    eval_data = load_dataset('parquet', data_files={'train': data_config.eval_dataset_path})['train']

    # Both splits are tokenized with identical settings; every DataConfig
    # field is forwarded to ``transform_data`` as a keyword argument.
    map_options = {
        "batched": True,
        "num_proc": data_config.num_data_proc,
        "fn_kwargs": {"tokenizer": tokenizer, **asdict(data_config)},
    }
    train_data = train_data.map(transform_data, **map_options)
    eval_data = eval_data.map(transform_data, **map_options)

    # Show one processed example, on the main process only.
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if is_main_process(local_rank):
        print(train_data[0])

    trainer = Trainer(
        model=model,
        args=training_config,
        train_dataset=train_data,
        eval_dataset=eval_data,
        tokenizer=tokenizer,
        data_collator=batch_processor,
    )
    trainer.train()
|
||
|
|
||
|
# Start training only when this file is run as a script, not when imported.
if __name__ == "__main__":
    execute()
|