from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, HfArgumentParser
from transformers.trainer_utils import is_main_process
from dataclasses import field, dataclass, asdict
from datasets import load_dataset
from typing import Optional
import torch
import os


@dataclass()
class BaseModelConfig:
    pretrained_model_name_or_path: Optional[str] = field(
        default="/2b_sft_model"
    )


@dataclass()
class DataConfig:
    train_dataset_path: Optional[str] = field(
        default="test.parquet"
    )
    eval_dataset_path: Optional[str] = field(
        default="eval.parquet"
    )
    eval_size: Optional[int] = field(
        default=256
    )
    max_length: Optional[int] = field(
        default=512
    )
    num_data_proc: Optional[int] = field(
        default=16
    )
    training_key: Optional[str] = field(
        default='input'
    )
    skip_eos_token: Optional[bool] = field(
        default=False
    )


@dataclass()
class CustomTrainingConfig(TrainingArguments):
    run_name: Optional[str] = field(default="cpm")
    output_dir: Optional[str] = field(default="checkpoints/")
    per_device_train_batch_size: Optional[int] = field(default=4)
    per_device_eval_batch_size: Optional[int] = field(default=4)
    num_train_epochs: Optional[int] = field(default=20)
    weight_decay: Optional[float] = field(default=0)
    learning_rate: Optional[float] = field(default=1e-7)
    lr_scheduler_type: Optional[str] = field(default="cosine")
    warmup_ratio: Optional[float] = field(default=0.1)
    eval_strategy: Optional[str] = field(default="steps")
    eval_steps: Optional[int] = field(default=100)
    load_best_model_at_end: Optional[bool] = field(default=True)
    logging_strategy: Optional[str] = field(default="steps")
    logging_steps: Optional[int] = field(default=1)
    save_strategy: Optional[str] = field(default="steps")
    save_steps: Optional[int] = field(default=100)
    save_total_limit: Optional[int] = field(default=10)
    save_only_model: Optional[bool] = field(default=True)
    bf16: Optional[bool] = field(default=True)


def process_text(text):
    """Drop print/assert/comment lines from the target completion before training."""
    text_lines = text.split('\n')
    processed_lines = [line for line in text_lines if not line.strip().startswith(('print', 'assert', '# '))]
    return '\n'.join(processed_lines).strip()


def transform_data(samples, tokenizer, training_key, max_length, skip_eos_token=False, **kwargs):
    """Tokenize <用户>question<AI>answer pairs and mask the prompt portion of the labels."""
    sequences = []
    for idx, question in enumerate(samples[training_key]):
        # NOTE: "<AI>" is assumed here as the assistant marker of the chat template
        # (MiniCPM-style "<用户>...<AI>..."); adjust it if your model uses a different template.
        sequences.append(f"<用户>{question}<AI>{process_text(samples['output'][idx])}{tokenizer.eos_token}")

    encoded = tokenizer.batch_encode_plus(
        sequences, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
    )
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]
    # Make sure padding positions (token id 0) are never attended to.
    attention_mask = torch.where(input_ids == 0, torch.zeros_like(attention_mask), attention_mask)

    target_ids = input_ids.clone()
    for i, sequence in enumerate(input_ids):
        # Locate the "<AI>" marker and mask everything before it so the loss is only
        # computed on the assistant response. Index [2] assumes the marker is split
        # into several sub-tokens by the tokenizer.
        ai_marker = (sequence == tokenizer.encode("<AI>", add_special_tokens=False)[2]).nonzero(as_tuple=False)
        if ai_marker.nelement() > 0:
            start_pos = ai_marker[0].item()
            target_ids[i, :start_pos] = -100
        if skip_eos_token:
            target_ids[i, -1] = -100

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": target_ids
    }


def batch_processor(items):
    """Collate pre-tokenized examples (stored as Python lists by datasets.map) into tensors."""
    input_ids = torch.tensor([item["input_ids"] for item in items])
    attention_mask = torch.tensor([item["attention_mask"] for item in items])
    labels = torch.tensor([item["labels"] for item in items])
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }


def execute():
    parser = HfArgumentParser((BaseModelConfig, DataConfig, CustomTrainingConfig))
    model_config, data_config, training_config = parser.parse_args_into_dataclasses()
    # Load the base model and tokenizer; trust_remote_code is required for custom model code.
    model = AutoModelForCausalLM.from_pretrained(model_config.pretrained_model_name_or_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_config.pretrained_model_name_or_path)
    # Reuse the <unk> token for padding since the tokenizer has no dedicated pad token.
    tokenizer.pad_token_id = tokenizer.unk_token_id

    # train_dataset_path may be a comma-separated list of parquet files.
    train_data = load_dataset('parquet', data_files={'train': data_config.train_dataset_path.split(',')})['train']
    eval_data = load_dataset('parquet', data_files={'train': data_config.eval_dataset_path})['train']

    train_data = train_data.map(transform_data, batched=True, num_proc=data_config.num_data_proc, fn_kwargs={
        "tokenizer": tokenizer,
        **asdict(data_config)
    })
    eval_data = eval_data.map(transform_data, batched=True, num_proc=data_config.num_data_proc, fn_kwargs={
        "tokenizer": tokenizer,
        **asdict(data_config)
    })

    # Print one processed sample on the main process only, as a quick sanity check.
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if is_main_process(local_rank):
        print(train_data[0])

    trainer = Trainer(
        model=model,
        args=training_config,
        train_dataset=train_data,
        eval_dataset=eval_data,
        tokenizer=tokenizer,
        data_collator=batch_processor,
    )
    trainer.train()


if __name__ == '__main__':
    execute()
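
# Example launch (illustrative only; the script name, data files, and GPU count below
# are placeholders, not part of this repo). HfArgumentParser exposes every dataclass
# field above as a --flag of the same name, so any default can be overridden this way:
#
#   torchrun --nproc_per_node=8 finetune.py \
#       --pretrained_model_name_or_path /2b_sft_model \
#       --train_dataset_path train_a.parquet,train_b.parquet \
#       --eval_dataset_path eval.parquet \
#       --output_dir checkpoints/ \
#       --learning_rate 1e-5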