# cc/training/train.py (146 lines, 5.0 KiB, Python)

import os
import re
from dataclasses import asdict, dataclass, field
from typing import Optional

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
)
from transformers.trainer_utils import is_main_process
@dataclass
class EngineSettings:
    """Model-loading options, parsed from the command line by HfArgumentParser."""

    # Local directory or hub id; used for both the model and the tokenizer.
    pretrained_model_name_or_path: Optional[str] = "/2b_sft_model"
@dataclass
class RuntimeSettings(TrainingArguments):
    """``TrainingArguments`` with project defaults; each field stays overridable from the CLI."""

    # Run identity / checkpoint location.
    run_name: Optional[str] = "cpm"
    output_dir: Optional[str] = "checkpoints/"

    # Batch sizes are per device.
    per_device_train_batch_size: Optional[int] = 4
    per_device_eval_batch_size: Optional[int] = 4

    # Optimisation schedule.
    # NOTE(review): 20 epochs at a 1e-7 peak LR is unusual -- confirm it is intended.
    num_train_epochs: Optional[int] = 20
    weight_decay: Optional[float] = 0
    learning_rate: Optional[float] = 1e-7
    lr_scheduler_type: Optional[str] = "cosine"
    warmup_ratio: Optional[float] = 0.1

    # Evaluate every 100 steps and restore the best checkpoint when training ends.
    eval_strategy: Optional[str] = "steps"
    eval_steps: Optional[int] = 100
    load_best_model_at_end: Optional[bool] = True

    # Log every step.
    logging_strategy: Optional[str] = "steps"
    logging_steps: Optional[int] = 1

    # Checkpointing is aligned with eval_steps so the best-model tracking has checkpoints to pick from.
    save_strategy: Optional[str] = "steps"
    save_steps: Optional[int] = 100
    save_total_limit: Optional[int] = 10
    save_only_model: Optional[bool] = True

    # Mixed precision.
    bf16: Optional[bool] = True
@dataclass
class InputSettings:
    """Dataset and preprocessing options, parsed from the command line by HfArgumentParser."""

    # Training parquet file(s); launch_training splits this value on ','.
    train_dataset_path: Optional[str] = "test.parquet"
    # Evaluation parquet file.
    eval_dataset_path: Optional[str] = "eval.parquet"
    # NOTE(review): not read anywhere in the visible code -- confirm whether eval should be capped.
    eval_size: Optional[int] = 256
    # Every example is padded/truncated to this many tokens.
    max_length: Optional[int] = 512
    # Worker processes for Dataset.map.
    num_data_proc: Optional[int] = 16
    # Column holding the user query.
    training_key: Optional[str] = 'input'
    # When True, the label at the last position of each example is masked out.
    skip_eos_token: Optional[bool] = False
# A line is dropped when its stripped text starts with a `print` or `assert`
# statement (whole word only, so identifiers such as `printer`,
# `print_report`, `assertion_count` or C's `printf` are kept) or with a
# `# ` comment. A bare `#` not followed by a space (e.g. `#include`, `#!`)
# does not match.
_DROPPED_LINE_RE = re.compile(r"(?:print|assert)\b|# ")


def sanitize_input(content):
    """Remove print/assert statements and `# ` comment lines from a response.

    Args:
        content: Multi-line response text.

    Returns:
        The remaining lines joined with '\\n', with surrounding whitespace
        stripped from the result.
    """
    kept = [
        line for line in content.split('\n')
        if not _DROPPED_LINE_RE.match(line.strip())
    ]
    return '\n'.join(kept).strip()
def format_sequence(examples, tokenizer, training_key, max_length, skip_eos_token=False, **kwargs):
    """Tokenize a batch of query/response rows into fixed-length SFT examples.

    Each row is rendered as ``<用户>{query}<AI>{response}{eos}``. The response
    is first passed through ``sanitize_input``. The text is then padded and
    truncated to ``max_length``.

    Args:
        examples: Batched columns as passed by ``Dataset.map(batched=True)``.
            Must contain ``training_key`` (the query) and ``'output'`` (the
            response).
        tokenizer: Tokenizer with ``eos_token`` and a pad token configured.
        training_key: Name of the column holding the query text.
        max_length: Sequence length after padding/truncation.
        skip_eos_token: If True, the label at the last position of every row
            is also set to -100.
        **kwargs: Remaining ``InputSettings`` fields; accepted and ignored.

    Returns:
        Dict of ``input_ids``, ``attention_mask`` and ``labels`` tensors, each
        of shape ``(batch, max_length)``. ``labels`` equals ``input_ids``
        except that the prompt prefix and all masked-out (padding) positions
        are set to -100.
    """
    texts = [
        f"<用户>{query}<AI>{sanitize_input(response)}{tokenizer.eos_token}"
        for query, response in zip(examples[training_key], examples['output'])
    ]
    encoded = tokenizer.batch_encode_plus(
        texts,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]
    # Also mask out any position whose token id is 0. Presumably 0 is the unk
    # token that launch_training reuses as pad -- TODO confirm for this tokenizer.
    attention_mask = torch.where(input_ids == 0, torch.zeros_like(attention_mask), attention_mask)

    labels = input_ids.clone()
    # Loop-invariant, so encode once per batch rather than once per row.
    # NOTE(review): index 2 assumes "<AI>" encodes to >= 3 pieces and that the
    # third one marks the start of the response -- verify with the tokenizer.
    response_marker_id = tokenizer.encode("<AI>", add_special_tokens=False)[2]
    for row, seq in enumerate(input_ids):
        hits = (seq == response_marker_id).nonzero(as_tuple=False)
        # If the marker was truncated away, the whole row stays a training
        # target (same as before) -- NOTE(review): consider dropping such rows.
        if hits.nelement() > 0:
            labels[row, :hits[0].item()] = -100
    if skip_eos_token:
        # NOTE(review): this is the EOS only when the row ends with EOS (left
        # padding, no truncation); with right padding it is a pad slot.
        labels[:, -1] = -100
    # Padding must not contribute to the loss; otherwise the model is trained
    # to emit the pad (unk) token after EOS.
    labels[attention_mask == 0] = -100
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }
def merge_batch(batch_items):
    """Collate pre-tokenized examples into a batch of tensors.

    Args:
        batch_items: Sequence of mappings, each holding equal-length lists
            under ``input_ids``, ``attention_mask`` and ``labels``. Any other
            keys are ignored.

    Returns:
        Dict mapping each of those three keys to a stacked ``torch.tensor``.
    """
    collated = {}
    for key in ("input_ids", "attention_mask", "labels"):
        collated[key] = torch.tensor([item[key] for item in batch_items])
    return collated
def launch_training():
    """Parse CLI settings, tokenize the parquet datasets and run ``Trainer``.

    Settings are parsed into ``EngineSettings``, ``InputSettings`` and
    ``RuntimeSettings`` by ``HfArgumentParser``. Tokenization runs inside
    ``main_process_first`` so that, under a distributed launch, the main
    process builds the ``datasets`` cache. The other ranks then reuse that
    cache instead of all mapping concurrently, each with ``num_data_proc``
    workers.
    """
    config_parser = HfArgumentParser((EngineSettings, InputSettings, RuntimeSettings))
    engine_cfg, input_cfg, runtime_cfg = config_parser.parse_args_into_dataclasses()

    model = AutoModelForCausalLM.from_pretrained(
        engine_cfg.pretrained_model_name_or_path,
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(engine_cfg.pretrained_model_name_or_path)
    # Reuse unk as the pad token (presumably the tokenizer ships without one -- verify).
    tokenizer.pad_token_id = tokenizer.unk_token_id

    # The training path may list several parquet files separated by ','.
    train_set = load_dataset(
        'parquet', data_files={'train': input_cfg.train_dataset_path.split(',')}
    )['train']
    eval_set = load_dataset('parquet', data_files={'train': input_cfg.eval_dataset_path})['train']

    # Every InputSettings field is forwarded; format_sequence absorbs the
    # ones it does not use through **kwargs.
    process_args = {"tokenizer": tokenizer, **asdict(input_cfg)}
    map_kwargs = {
        "batched": True,
        "num_proc": input_cfg.num_data_proc,
        "fn_kwargs": process_args,
    }
    # NOTE(review): input_cfg.eval_size is never applied; the full eval set is used.
    with runtime_cfg.main_process_first(desc="dataset tokenization"):
        train_set = train_set.map(format_sequence, **map_kwargs)
        eval_set = eval_set.map(format_sequence, **map_kwargs)

    rank = int(os.environ.get("LOCAL_RANK", -1))
    if is_main_process(rank):
        # Show one tokenized example as a quick sanity check.
        print(train_set[0])

    engine = Trainer(
        model=model,
        args=runtime_cfg,
        train_dataset=train_set,
        eval_dataset=eval_set,
        # NOTE(review): recent transformers releases deprecate `tokenizer=` in
        # favour of `processing_class=`; check against the pinned version.
        tokenizer=tokenizer,
        data_collator=merge_batch,
    )
    engine.train()
# Script entry point: run fine-tuning only when executed directly, not on import.
if __name__ == '__main__':
    launch_training()