import os
from dataclasses import asdict, dataclass, field
from typing import Optional

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
)
from transformers.trainer_utils import is_main_process


@dataclass
class EngineSettings:
    pretrained_model_name_or_path: Optional[str] = field(default="/2b_sft_model")


@dataclass
class RuntimeSettings(TrainingArguments):
    run_name: Optional[str] = field(default="cpm")
    output_dir: Optional[str] = field(default="checkpoints/")
    per_device_train_batch_size: Optional[int] = field(default=4)
    per_device_eval_batch_size: Optional[int] = field(default=4)
    num_train_epochs: Optional[int] = field(default=20)
    weight_decay: Optional[float] = field(default=0.0)
    learning_rate: Optional[float] = field(default=1e-7)
    lr_scheduler_type: Optional[str] = field(default="cosine")
    warmup_ratio: Optional[float] = field(default=0.1)
    eval_strategy: Optional[str] = field(default="steps")
    eval_steps: Optional[int] = field(default=100)
    load_best_model_at_end: Optional[bool] = field(default=True)
    logging_strategy: Optional[str] = field(default="steps")
    logging_steps: Optional[int] = field(default=1)
    save_strategy: Optional[str] = field(default="steps")
    save_steps: Optional[int] = field(default=100)
    save_total_limit: Optional[int] = field(default=10)
    save_only_model: Optional[bool] = field(default=True)
    bf16: Optional[bool] = field(default=True)


@dataclass
class InputSettings:
    train_dataset_path: Optional[str] = field(default="test.parquet")
    eval_dataset_path: Optional[str] = field(default="eval.parquet")
    eval_size: Optional[int] = field(default=256)
    max_length: Optional[int] = field(default=512)
    num_data_proc: Optional[int] = field(default=16)
    training_key: Optional[str] = field(default="input")
    skip_eos_token: Optional[bool] = field(default=False)


def sanitize_input(content):
    """Drop print/assert statements and line comments from a response."""
    lines = content.split("\n")
    filtered = [
        l for l in lines if not l.strip().startswith(("print", "assert", "# "))
    ]
    return "\n".join(filtered).strip()


def format_sequence(examples, tokenizer, training_key, max_length,
                    skip_eos_token=False, **kwargs):
    """Render <用户>query<AI>response pairs and build loss-masked labels."""
    results = []
    for idx, query in enumerate(examples[training_key]):
        results.append(
            f"<用户>{query}<AI>{sanitize_input(examples['output'][idx])}"
            f"{tokenizer.eos_token}"
        )
    encoded_data = tokenizer.batch_encode_plus(
        results,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    tokens = encoded_data["input_ids"]
    mask = encoded_data["attention_mask"]
    # pad_token_id is set to unk_token_id (0) at load time, so explicitly
    # zero the attention mask wherever a padding token appears.
    mask = torch.where(tokens == 0, torch.zeros_like(mask), mask)
    targets = tokens.clone()
    # "<AI>" tokenizes into several sub-tokens; the one at index 2 is used as
    # the marker for where the response begins. Everything before it (the
    # prompt) is excluded from the loss.
    response_marker = tokenizer.encode("<AI>", add_special_tokens=False)[2]
    for idx, seq in enumerate(tokens):
        response_start = (seq == response_marker).nonzero(as_tuple=False)
        if response_start.nelement() > 0:
            prefix_end = response_start[0].item()
            targets[idx, :prefix_end] = -100
        if skip_eos_token:
            targets[idx, -1] = -100
    # Exclude padding positions from the loss as well.
    targets[mask == 0] = -100
    return {"input_ids": tokens, "attention_mask": mask, "labels": targets}


def merge_batch(batch_items):
    """Collate mapped examples (stored as Python lists) back into tensors."""
    return {
        "input_ids": torch.tensor([item["input_ids"] for item in batch_items]),
        "attention_mask": torch.tensor([item["attention_mask"] for item in batch_items]),
        "labels": torch.tensor([item["labels"] for item in batch_items]),
    }
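# The pipeline above assumes parquet files with string columns "input" (user
# query) and "output" (model response). A minimal sketch for producing a
# smoke-test file, assuming pandas is available (toy data, not part of this
# script); note that sanitize_input() strips print/assert/comment lines from
# the response:
#
#   import pandas as pd
#   pd.DataFrame(
#       {"input": ["What is 1+1?"], "output": ["x = 1 + 1\nprint(x)"]}
#   ).to_parquet("test.parquet")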
def launch_training():
    config_parser = HfArgumentParser((EngineSettings, InputSettings, RuntimeSettings))
    engine_cfg, input_cfg, runtime_cfg = config_parser.parse_args_into_dataclasses()

    model = AutoModelForCausalLM.from_pretrained(
        engine_cfg.pretrained_model_name_or_path, trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(engine_cfg.pretrained_model_name_or_path)
    tokenizer.pad_token_id = tokenizer.unk_token_id

    # train_dataset_path may be a comma-separated list of parquet files.
    train_set = load_dataset(
        "parquet", data_files={"train": input_cfg.train_dataset_path.split(",")}
    )["train"]
    eval_set = load_dataset(
        "parquet", data_files={"train": input_cfg.eval_dataset_path}
    )["train"]
    # Cap the evaluation set at eval_size examples.
    eval_set = eval_set.select(range(min(input_cfg.eval_size, len(eval_set))))

    # format_sequence accepts **kwargs, so the extra InputSettings fields
    # passed along via asdict() are harmless.
    process_args = {"tokenizer": tokenizer, **asdict(input_cfg)}
    train_set = train_set.map(
        format_sequence,
        batched=True,
        num_proc=input_cfg.num_data_proc,
        fn_kwargs=process_args,
    )
    eval_set = eval_set.map(
        format_sequence,
        batched=True,
        num_proc=input_cfg.num_data_proc,
        fn_kwargs=process_args,
    )

    # Print one formatted example on the main process only, as a sanity check.
    rank = int(os.environ.get("LOCAL_RANK", -1))
    if is_main_process(rank):
        print(train_set[0])

    engine = Trainer(
        model=model,
        args=runtime_cfg,
        train_dataset=train_set,
        eval_dataset=eval_set,
        tokenizer=tokenizer,
        data_collator=merge_batch,
    )
    engine.train()


if __name__ == "__main__":
    launch_training()
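# Example launch, assuming this file is saved as finetune.py (the file name
# and paths are placeholders; any TrainingArguments flag can be overridden on
# the command line since RuntimeSettings subclasses TrainingArguments):
#
#   torchrun --nproc_per_node=8 finetune.py \
#       --pretrained_model_name_or_path /2b_sft_model \
#       --train_dataset_path train.parquet \
#       --eval_dataset_path eval.parquet \
#       --output_dir checkpoints/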