146 lines
5.0 KiB
Python
146 lines
5.0 KiB
Python
import os
|
|
import torch
|
|
from typing import Optional
|
|
from dataclasses import field, dataclass, asdict
|
|
from datasets import load_dataset
|
|
from transformers import (
|
|
AutoModelForCausalLM,
|
|
AutoTokenizer,
|
|
TrainingArguments,
|
|
Trainer,
|
|
HfArgumentParser
|
|
)
|
|
from transformers.trainer_utils import is_main_process
|
|
|
|
@dataclass
class EngineSettings:
    """Model-loading options parsed by ``HfArgumentParser``."""

    # Handed to ``from_pretrained`` for both the model and the tokenizer.
    pretrained_model_name_or_path: Optional[str] = "/2b_sft_model"
|
|
|
|
@dataclass()
class RuntimeSettings(TrainingArguments):
    """``TrainingArguments`` carrying this script's default hyper-parameters.

    Every field remains a command-line option via ``HfArgumentParser``; the
    values below are only defaults.
    """

    run_name: Optional[str] = field(default="cpm")
    output_dir: Optional[str] = field(default="checkpoints/")
    per_device_train_batch_size: Optional[int] = field(default=4)
    per_device_eval_batch_size: Optional[int] = field(default=4)
    # Annotated as float, matching the base TrainingArguments field, so the CLI
    # accepts fractional epochs such as ``--num_train_epochs 0.5``. An ``int``
    # annotation made HfArgumentParser reject those values.
    num_train_epochs: Optional[float] = field(default=20.0)
    weight_decay: Optional[float] = field(default=0.0)
    learning_rate: Optional[float] = field(default=1e-7)
    lr_scheduler_type: Optional[str] = field(default="cosine")
    warmup_ratio: Optional[float] = field(default=0.1)
    # Evaluation and checkpointing share a 100-step cadence.
    # load_best_model_at_end requires the save and eval schedules to line up.
    eval_strategy: Optional[str] = field(default="steps")
    eval_steps: Optional[int] = field(default=100)
    load_best_model_at_end: Optional[bool] = field(default=True)
    logging_strategy: Optional[str] = field(default="steps")
    logging_steps: Optional[int] = field(default=1)
    save_strategy: Optional[str] = field(default="steps")
    save_steps: Optional[int] = field(default=100)
    save_total_limit: Optional[int] = field(default=10)
    # Checkpoints hold model weights only (no optimizer/scheduler/RNG state).
    save_only_model: Optional[bool] = field(default=True)
    bf16: Optional[bool] = field(default=True)
|
|
|
|
@dataclass
class InputSettings:
    """Dataset and tokenization options parsed by ``HfArgumentParser``.

    ``launch_training`` forwards every field to ``format_sequence`` as keyword
    arguments (via ``asdict``). Fields that function does not name land in its
    ``**kwargs``.
    """

    # Comma-separated list of parquet files used for training.
    train_dataset_path: Optional[str] = "test.parquet"
    # Parquet file(s) used for evaluation; passed to ``load_dataset``.
    eval_dataset_path: Optional[str] = "eval.parquet"
    # NOTE(review): only reaches format_sequence's **kwargs. No visible code
    # uses it, so confirm its intended purpose.
    eval_size: Optional[int] = 256
    # Forwarded to format_sequence as the pad/truncate length.
    max_length: Optional[int] = 512
    # ``num_proc`` for the ``Dataset.map`` tokenization passes.
    num_data_proc: Optional[int] = 16
    # Forwarded to format_sequence as ``training_key``.
    training_key: Optional[str] = 'input'
    # Forwarded to format_sequence as ``skip_eos_token``.
    skip_eos_token: Optional[bool] = False
|
|
|
|
def sanitize_input(content):
|
|
lines = content.split('\n')
|
|
filtered = [l for l in lines if not l.strip().startswith(('print', 'assert', '# '))]
|
|
return '\n'.join(filtered).strip()
|
|
|
|
def format_sequence(examples, tokenizer, training_key, max_length, skip_eos_token=False, **kwargs):
    """Tokenize a batch of (query, response) pairs into padded causal-LM features.

    Each example is rendered as ``<用户>{query}<AI>{response}{eos}``, where the
    response is first cleaned with ``sanitize_input``. The batch is padded and
    truncated to ``max_length``.

    Label positions set to ``-100`` (ignored by the loss):
      * every padding position (``attention_mask == 0``), whichever side the
        tokenizer pads on;
      * every position before the first occurrence of the response-marker token,
        so the prompt is not trained on;
      * every EOS position, when ``skip_eos_token`` is true.

    Args:
        examples: Batched ``datasets`` mapping with a ``training_key`` column of
            queries and an ``'output'`` column of responses.
        tokenizer: Hugging Face tokenizer; ``pad_token_id`` is expected to be set.
        training_key: Name of the query column.
        max_length: Length to pad/truncate every sequence to.
        skip_eos_token: If true, EOS tokens do not contribute to the loss.
        **kwargs: Ignored; lets callers forward the whole ``InputSettings`` dict.

    Returns:
        Dict of ``input_ids``, ``attention_mask`` and ``labels`` tensors, each
        with one row per example and ``max_length`` columns.
    """
    texts = [
        f"<用户>{query}<AI>{sanitize_input(response)}{tokenizer.eos_token}"
        for query, response in zip(examples[training_key], examples['output'])
    ]

    encoded_data = tokenizer.batch_encode_plus(
        texts,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    tokens = encoded_data["input_ids"]
    mask = encoded_data["attention_mask"]
    # Also hide pad-id tokens that occur inside the text itself. The launcher
    # reuses <unk> as the pad token.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    mask = torch.where(tokens == pad_id, torch.zeros_like(mask), mask)

    targets = tokens.clone()
    # Padding must never be a training target.
    targets[mask == 0] = -100

    # Token id that marks where the response begins. It is computed once per
    # batch rather than once per row. Index 2 takes the third piece of the
    # standalone "<AI>" encoding, presumably the closing ">" for the target
    # tokenizer. NOTE(review): confirm this for other tokenizers. The first
    # occurrence wins, so the same id appearing earlier in the query ends
    # prompt masking early.
    marker_id = tokenizer.encode("<AI>", add_special_tokens=False)[2]

    for idx, seq in enumerate(tokens):
        response_start = (seq == marker_id).nonzero(as_tuple=False)
        # If the marker is absent (e.g. truncated away), the row keeps its
        # prompt labels.
        if response_start.nelement() > 0:
            targets[idx, :response_start[0].item()] = -100

    if skip_eos_token:
        # Mask the EOS token itself, not merely the last column. Under right
        # padding the last column is a pad token.
        targets[tokens == tokenizer.eos_token_id] = -100

    return {
        "input_ids": tokens,
        "attention_mask": mask,
        "labels": targets,
    }
|
|
|
|
def merge_batch(batch_items):
    """Collate per-example feature dicts into a dict of batch tensors.

    For each of ``input_ids``, ``attention_mask`` and ``labels``, the per-item
    values are stacked with ``torch.tensor``. Any other keys in the items are
    ignored. ``torch.tensor`` needs equal-length rows; ``format_sequence`` pads
    every example to ``max_length``.
    """
    collated = {}
    for feature_name in ("input_ids", "attention_mask", "labels"):
        rows = [item[feature_name] for item in batch_items]
        collated[feature_name] = torch.tensor(rows)
    return collated
|
|
|
|
def _load_parquet(paths):
    """Load one or more comma-separated parquet files as a single dataset split."""
    return load_dataset('parquet', data_files={'train': paths.split(',')})['train']


def launch_training():
    """Parse arguments, tokenize the parquet datasets and run ``Trainer.train()``.

    Arguments are parsed from the command line by ``HfArgumentParser`` into
    ``EngineSettings``, ``InputSettings`` and ``RuntimeSettings``.
    """
    config_parser = HfArgumentParser((EngineSettings, InputSettings, RuntimeSettings))
    engine_cfg, input_cfg, runtime_cfg = config_parser.parse_args_into_dataclasses()

    model = AutoModelForCausalLM.from_pretrained(
        engine_cfg.pretrained_model_name_or_path,
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(engine_cfg.pretrained_model_name_or_path)
    # format_sequence pads to max_length, so a pad id is required. The unk
    # token is reused as padding.
    tokenizer.pad_token_id = tokenizer.unk_token_id

    # The train and eval paths both accept a comma-separated list of parquet
    # files. Previously only the train path was split.
    train_set = _load_parquet(input_cfg.train_dataset_path)
    eval_set = _load_parquet(input_cfg.eval_dataset_path)
    # NOTE(review): input_cfg.eval_size is parsed but never applied. If the eval
    # set is meant to be capped, add something like
    # eval_set.select(range(min(input_cfg.eval_size, len(eval_set)))).

    # Every InputSettings field is forwarded; format_sequence absorbs the ones
    # it does not name via **kwargs.
    process_args = {"tokenizer": tokenizer, **asdict(input_cfg)}

    # Under a multi-process launch, the main process tokenizes first. The other
    # ranks can then reuse the datasets cache instead of all spawning
    # num_data_proc workers at once.
    with runtime_cfg.main_process_first(desc="dataset map pre-processing"):
        train_set = train_set.map(
            format_sequence,
            batched=True,
            num_proc=input_cfg.num_data_proc,
            fn_kwargs=process_args,
        )
        eval_set = eval_set.map(
            format_sequence,
            batched=True,
            num_proc=input_cfg.num_data_proc,
            fn_kwargs=process_args,
        )

    # LOCAL_RANK defaults to -1 when the variable is absent (single process).
    rank = int(os.environ.get("LOCAL_RANK", -1))
    if is_main_process(rank):
        print(train_set[0])

    engine = Trainer(
        model=model,
        args=runtime_cfg,
        train_dataset=train_set,
        eval_dataset=eval_set,
        tokenizer=tokenizer,
        data_collator=merge_batch,
    )
    engine.train()
|
|
|
|
# Script entry point. The stray trailing "|" from the paste made this line a
# SyntaxError.
if __name__ == '__main__':
    launch_training()