diff --git a/src/llmtuner/train/sft/trainer.py b/src/llmtuner/train/sft/trainer.py index c65cd255..291bbc7a 100644 --- a/src/llmtuner/train/sft/trainer.py +++ b/src/llmtuner/train/sft/trainer.py @@ -39,10 +39,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1) if prompt_len > label_len: inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"]) - if label_len > prompt_len: - inputs["labels"] = inputs["labels"][:, :prompt_len] # truncate the labels instead of padding the inputs + if label_len > prompt_len: # truncate the labels instead of padding the inputs (llama2 fp16 compatibility) + inputs["labels"] = inputs["labels"][:, :prompt_len] - loss, generated_tokens, _ = super().prediction_step( + loss, generated_tokens, _ = super().prediction_step( # ignore the returned labels (may be truncated) model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys ) if generated_tokens is not None and self.args.predict_with_generate: @@ -79,14 +79,19 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") logger.info(f"Saving prediction results to {output_prediction_file}") - preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id) labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id) + preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id) + for i in range(len(preds)): + pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0] + if len(pad_len): + preds[i] = np.concatenate((preds[i][pad_len[0]:], preds[i][:pad_len[0]]), axis=-1) # move pad token to last + + decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=False) decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True) - decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=True) with open(output_prediction_file, "w", encoding="utf-8") as writer: res: List[str] = [] - for pred, label in zip(decoded_preds, decoded_labels): + for label, pred in zip(decoded_labels, decoded_preds): res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False)) writer.write("\n".join(res))