sugar/training/train.py

137 lines
5.1 KiB
Python

from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, HfArgumentParser
from transformers.trainer_utils import is_main_process
from dataclasses import field, dataclass, asdict
from datasets import load_dataset
from typing import Optional
import torch
import os
@dataclass
class BaseModelConfig:
    """Model-loading options, filled from the command line by ``HfArgumentParser``.

    Attributes:
        pretrained_model_name_or_path: Passed to ``from_pretrained`` for both the
            causal-LM model and its tokenizer.
    """

    pretrained_model_name_or_path: Optional[str] = "/2b_sft_model"
@dataclass
class DataConfig:
    """Dataset and tokenization options, filled from the command line.

    The whole instance is forwarded (via ``asdict``) as keyword arguments to the
    per-batch tokenization function, so field names double as its parameter names.
    """

    # Parquet file(s) for training. NOTE(review): the default is named
    # "test.parquet" -- confirm that is the intended training file.
    train_dataset_path: Optional[str] = "test.parquet"
    # Parquet file used for evaluation.
    eval_dataset_path: Optional[str] = "eval.parquet"
    # NOTE(review): not referenced by the visible training code -- confirm use.
    eval_size: Optional[int] = 256
    # Token length every example is padded/truncated to.
    max_length: Optional[int] = 512
    # Worker process count for dataset preprocessing.
    num_data_proc: Optional[int] = 16
    # Forwarded to the tokenization function as ``training_key``.
    training_key: Optional[str] = "input"
    # Forwarded to the tokenization function; True masks the final label position.
    skip_eos_token: Optional[bool] = False
@dataclass
class CustomTrainingConfig(TrainingArguments):
    """Project-specific defaults layered over ``transformers.TrainingArguments``.

    Each field below re-declares a parent argument with a new default; every
    argument not listed keeps the upstream default. All of them can still be
    overridden on the command line through ``HfArgumentParser``.
    """

    # Run name and checkpoint directory.
    run_name: Optional[str] = "cpm"
    output_dir: Optional[str] = "checkpoints/"

    # Batch sizes (per device) and training length.
    per_device_train_batch_size: Optional[int] = 4
    per_device_eval_batch_size: Optional[int] = 4
    # NOTE(review): upstream presumably types this as a float; declaring it int
    # here means fractional epochs cannot be given on the CLI -- confirm.
    num_train_epochs: Optional[int] = 20

    # Optimiser and learning-rate schedule.
    weight_decay: Optional[float] = 0
    learning_rate: Optional[float] = 1e-7
    lr_scheduler_type: Optional[str] = "cosine"
    warmup_ratio: Optional[float] = 0.1

    # Evaluation cadence (step-based).
    eval_strategy: Optional[str] = "steps"
    eval_steps: Optional[int] = 100
    load_best_model_at_end: Optional[bool] = True

    # Logging cadence: every optimisation step.
    logging_strategy: Optional[str] = "steps"
    logging_steps: Optional[int] = 1

    # Checkpointing cadence and retention.
    save_strategy: Optional[str] = "steps"
    save_steps: Optional[int] = 100
    save_total_limit: Optional[int] = 10
    save_only_model: Optional[bool] = True

    # Mixed precision.
    bf16: Optional[bool] = True
def process_text(text):
    """Remove debug/test lines from a response and return the stripped result.

    A line is dropped when, after stripping surrounding whitespace, it starts
    with ``print``, ``assert`` or ``# `` (hash followed by a space). The kept
    lines are re-joined with ``'\\n'`` and the whole result is stripped.

    NOTE(review): the ``print`` prefix also matches identifiers such as
    ``printer = ...``; such lines are removed as well.
    """
    dropped_prefixes = ('print', 'assert', '# ')
    kept = []
    for raw_line in text.split('\n'):
        if raw_line.strip().startswith(dropped_prefixes):
            continue
        kept.append(raw_line)
    return '\n'.join(kept).strip()
def transform_data(samples, tokenizer, training_key, max_length, skip_eos_token=False, **kwargs):
    """Tokenize a batch of prompt/response pairs into causal-LM training features.

    Each example is rendered as ``<用户>{prompt}<AI>{response}{eos_token}``, where
    the response is first cleaned with ``process_text``. Labels are a copy of the
    input ids in which padding positions, and every position before the response
    marker token, are set to -100.

    Args:
        samples: A batch from ``datasets.Dataset.map(batched=True)``, i.e. a
            mapping of column name to list of values. Must contain the
            ``training_key`` column (prompts) and an ``'output'`` column
            (responses).
        tokenizer: Tokenizer providing ``batch_encode_plus``, ``encode`` and
            ``eos_token``.
        training_key: Name of the column holding the prompt text.
        max_length: Every sequence is padded/truncated to exactly this length.
        skip_eos_token: If True, the label at the last position of every row is
            also set to -100.
        **kwargs: Ignored; absorbs the other ``DataConfig`` fields the caller
            forwards wholesale.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors of
        shape ``(batch, max_length)``.
    """
    # The prompt column is selected by ``training_key`` (it was previously
    # hard-coded to 'input', leaving the parameter without effect).
    sequences = [
        f"<用户>{question}<AI>{process_text(answer)}{tokenizer.eos_token}"
        for question, answer in zip(samples[training_key], samples['output'])
    ]
    encoded = tokenizer.batch_encode_plus(
        sequences,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"]
    # Also hide every token with id 0 from attention. NOTE(review): this
    # hard-codes 0 as the pad id; the caller sets pad_token_id = unk_token_id,
    # so confirm unk maps to id 0 for the tokenizer in use.
    attention_mask = torch.where(
        input_ids == 0, torch.zeros_like(encoded["attention_mask"]), encoded["attention_mask"]
    )

    target_ids = input_ids.clone()
    # Positions hidden from attention (padding) must not contribute to the loss;
    # otherwise right-padded rows train the model to emit pad tokens after EOS.
    target_ids[attention_mask == 0] = -100

    # Token id marking where the response begins; computed once per batch.
    # NOTE(review): index [2] assumes "<AI>" encodes to at least three tokens and
    # that the third one is a reliable boundary. Because the FIRST match in each
    # row is used, the same token appearing earlier (e.g. in "<用户>" or in the
    # prompt) would move the boundary forward -- confirm for this tokenizer.
    ai_marker_id = tokenizer.encode("<AI>", add_special_tokens=False)[2]

    for row, sequence in enumerate(input_ids):
        matches = (sequence == ai_marker_id).nonzero(as_tuple=False)
        if matches.nelement() > 0:
            # Exclude the prompt (and anything preceding it) from the loss.
            target_ids[row, :matches[0].item()] = -100
        # NOTE(review): if truncation removed the marker, the row keeps its
        # prompt-token labels.
        if skip_eos_token:
            target_ids[row, -1] = -100

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": target_ids,
    }
def batch_processor(items):
    """Collate pre-tokenized dataset rows into a batch of tensors.

    Each item must provide ``input_ids``, ``attention_mask`` and ``labels`` as
    sequences; rows must share a length so ``torch.tensor`` can stack them.

    Returns:
        Dict mapping each of the three keys to a tensor built from that column.
    """
    columns = ("input_ids", "attention_mask", "labels")
    return {column: torch.tensor([item[column] for item in items]) for column in columns}
def _load_parquet_split(paths):
    """Load one or more comma-separated parquet file paths as a single Dataset."""
    return load_dataset('parquet', data_files={'train': paths.split(',')})['train']


def execute():
    """Parse CLI configuration, tokenize the datasets and run fine-tuning.

    Flags are parsed into ``BaseModelConfig``, ``DataConfig`` and
    ``CustomTrainingConfig``. The train and eval parquet inputs (each may be a
    comma-separated list of files) are tokenized with ``transform_data`` and
    handed to a ``transformers.Trainer`` that uses ``batch_processor`` as its
    data collator.
    """
    parser = HfArgumentParser((BaseModelConfig, DataConfig, CustomTrainingConfig))
    model_config, data_config, training_config = parser.parse_args_into_dataclasses()

    model = AutoModelForCausalLM.from_pretrained(
        model_config.pretrained_model_name_or_path, trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(model_config.pretrained_model_name_or_path)
    # Padding reuses the unk token; presumably the tokenizer ships without a
    # dedicated pad token -- confirm.
    tokenizer.pad_token_id = tokenizer.unk_token_id

    def tokenize(dataset):
        # Every DataConfig field is forwarded; transform_data absorbs the ones it
        # does not use through **kwargs.
        return dataset.map(
            transform_data,
            batched=True,
            num_proc=data_config.num_data_proc,
            fn_kwargs={"tokenizer": tokenizer, **asdict(data_config)},
        )

    # Load both splits before tokenizing so a bad path fails fast. Both the
    # train and eval paths accept a comma-separated list of files.
    train_data = _load_parquet_split(data_config.train_dataset_path)
    eval_data = _load_parquet_split(data_config.eval_dataset_path)
    train_data = tokenize(train_data)
    eval_data = tokenize(eval_data)

    # Show one processed example, once, in distributed runs.
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if is_main_process(local_rank):
        print(train_data[0])

    trainer = Trainer(
        model=model,
        args=training_config,
        train_dataset=train_data,
        eval_dataset=eval_data,
        tokenizer=tokenizer,
        data_collator=batch_processor,
    )
    trainer.train()
# Start training only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    execute()