train: modify inference configs and log predict start/end times

This commit is contained in:
wql 2024-08-26 12:06:23 +08:00
parent 47b4cd3a54
commit e434505c6b
25 changed files with 33 additions and 6 deletions

View File

@ -11,6 +11,7 @@ cutoff_len: 1024
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16
include_tokens_per_second: true
### output
output_dir: ./results/inference/Baichuan2-7B/Baichuan2_predict_1

View File

@ -11,6 +11,7 @@ cutoff_len: 1024
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16
include_tokens_per_second: true
### output
output_dir: ./results/inference/Baichuan2-7B/Baichuan2_predict_1_single

View File

@ -11,6 +11,7 @@ cutoff_len: 1024
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16
include_tokens_per_second: true
### output
output_dir: ./results/inference/Baichuan2-7B/Baichuan2_predict_2

View File

@ -11,6 +11,7 @@ cutoff_len: 1024
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16
include_tokens_per_second: true
### output
output_dir: ./results/inference/Baichuan2-7B/Baichuan2_predict_2_single

View File

@ -11,6 +11,7 @@ cutoff_len: 1024
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16
include_tokens_per_second: true
### output
output_dir: ./results/inference/Baichuan2-7B/Baichuan2_predict_3

View File

@ -11,6 +11,7 @@ cutoff_len: 1024
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16
include_tokens_per_second: true
### output
output_dir: ./results/inference/Baichuan2-7B/Baichuan2_predict_3_single

View File

@ -1,5 +1,5 @@
### model
model_name_or_path: ZhipuAI/chatglm2-6b
model_name_or_path: ../../llm/chatglm/data
### method
do_predict: true
@ -11,6 +11,7 @@ cutoff_len: 1024
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16
include_tokens_per_second: true
### output
output_dir: ./results/inference/ChatGLM2-6B/ChatGLM2_predict_1

View File

@ -1,5 +1,5 @@
### model
model_name_or_path: ZhipuAI/chatglm2-6b
model_name_or_path: ../../llm/chatglm/data
### method
do_predict: true
@ -11,6 +11,7 @@ cutoff_len: 1024
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16
include_tokens_per_second: true
### output
output_dir: ./results/inference/ChatGLM2-6B/ChatGLM2_predict_1_single

View File

@ -1,5 +1,5 @@
### model
model_name_or_path: ZhipuAI/chatglm2-6b
model_name_or_path: ../../llm/chatglm/data
### method
do_predict: true
@ -11,6 +11,7 @@ cutoff_len: 1024
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16
include_tokens_per_second: true
### output
output_dir: ./results/inference/ChatGLM2-6B/ChatGLM2_predict_2

View File

@ -1,5 +1,5 @@
### model
model_name_or_path: ZhipuAI/chatglm2-6b
model_name_or_path: ../../llm/chatglm/data
### method
do_predict: true
@ -11,6 +11,7 @@ cutoff_len: 1024
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16
include_tokens_per_second: true
### output
output_dir: ./results/inference/ChatGLM2-6B/ChatGLM2_predict_2_single

View File

@ -1,5 +1,5 @@
### model
model_name_or_path: ZhipuAI/chatglm2-6b
model_name_or_path: ../../llm/chatglm/data
### method
do_predict: true
@ -11,6 +11,7 @@ cutoff_len: 1024
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16
include_tokens_per_second: true
### output
output_dir: ./results/inference/ChatGLM2-6B/ChatGLM2_predict_3

View File

@ -1,5 +1,5 @@
### model
model_name_or_path: ZhipuAI/chatglm2-6b
model_name_or_path: ../../llm/chatglm/data
### method
do_predict: true
@ -11,6 +11,7 @@ cutoff_len: 1024
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16
include_tokens_per_second: true
### output
output_dir: ./results/inference/ChatGLM2-6B/ChatGLM2_predict_3_single

View File

@ -11,6 +11,7 @@ cutoff_len: 1024
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16
include_tokens_per_second: true
### output
output_dir: ./results/inference/Llama2-7B/llama2_predict_1

View File

@ -11,6 +11,7 @@ cutoff_len: 1024
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16
include_tokens_per_second: true
### output
output_dir: ./results/inference/Llama2-7B/llama2_predict_1_single

View File

@ -11,6 +11,7 @@ cutoff_len: 1024
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16
include_tokens_per_second: true
### output
output_dir: ./results/inference/Llama2-7B/llama2_predict_2

View File

@ -11,6 +11,7 @@ cutoff_len: 1024
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16
include_tokens_per_second: true
### output
output_dir: ./results/inference/Llama2-7B/llama2_predict_2_single

View File

@ -11,6 +11,7 @@ cutoff_len: 1024
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16
include_tokens_per_second: true
### output
output_dir: ./results/inference/Llama2-7B/llama2_predict_3

View File

@ -11,6 +11,7 @@ cutoff_len: 1024
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16
include_tokens_per_second: true
### output
output_dir: ./results/inference/Llama2-7B/llama2_predict_3_single

View File

@ -11,6 +11,7 @@ cutoff_len: 1024
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16
include_tokens_per_second: true
### output
output_dir: ./results/inference/Qwen-7B/Qwen_predict_1

View File

@ -11,6 +11,7 @@ cutoff_len: 1024
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16
include_tokens_per_second: true
### output
output_dir: ./results/inference/Qwen-7B/Qwen_predict_1_single

View File

@ -11,6 +11,7 @@ cutoff_len: 1024
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16
include_tokens_per_second: true
### output
output_dir: ./results/inference/Qwen-7B/Qwen_predict_2

View File

@ -11,6 +11,7 @@ cutoff_len: 1024
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16
include_tokens_per_second: true
### output
output_dir: ./results/inference/Qwen-7B/Qwen_predict_2_single

View File

@ -11,6 +11,7 @@ cutoff_len: 1024
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16
include_tokens_per_second: true
### output
output_dir: ./results/inference/Qwen-7B/Qwen_predict_3

View File

@ -11,6 +11,7 @@ cutoff_len: 1024
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16
include_tokens_per_second: true
### output
output_dir: ./results/inference/Qwen-7B/Qwen_predict_3_single

View File

@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from typing import TYPE_CHECKING, List, Optional
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset
@ -112,7 +113,9 @@ def run_sft(
# Predict
if training_args.do_predict:
print("predict start time: " + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
print("predict end time: " + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
predict_results.metrics.pop("predict_loss", None)
trainer.log_metrics("predict", predict_results.metrics)