train: modify inference and add start end time

parent 47b4cd3a54
commit e434505c6b
@@ -11,6 +11,7 @@ cutoff_len: 1024
 max_samples: 50
 overwrite_cache: true
 preprocessing_num_workers: 16
+include_tokens_per_second: true
 
 ### output
 output_dir: ./results/inference/Baichuan2-7B/Baichuan2_predict_1

@@ -11,6 +11,7 @@ cutoff_len: 1024
 max_samples: 50
 overwrite_cache: true
 preprocessing_num_workers: 16
+include_tokens_per_second: true
 
 ### output
 output_dir: ./results/inference/Baichuan2-7B/Baichuan2_predict_1_single

@@ -11,6 +11,7 @@ cutoff_len: 1024
 max_samples: 50
 overwrite_cache: true
 preprocessing_num_workers: 16
+include_tokens_per_second: true
 
 ### output
 output_dir: ./results/inference/Baichuan2-7B/Baichuan2_predict_2

@@ -11,6 +11,7 @@ cutoff_len: 1024
 max_samples: 50
 overwrite_cache: true
 preprocessing_num_workers: 16
+include_tokens_per_second: true
 
 ### output
 output_dir: ./results/inference/Baichuan2-7B/Baichuan2_predict_2_single

@@ -11,6 +11,7 @@ cutoff_len: 1024
 max_samples: 50
 overwrite_cache: true
 preprocessing_num_workers: 16
+include_tokens_per_second: true
 
 ### output
 output_dir: ./results/inference/Baichuan2-7B/Baichuan2_predict_3

@@ -11,6 +11,7 @@ cutoff_len: 1024
 max_samples: 50
 overwrite_cache: true
 preprocessing_num_workers: 16
+include_tokens_per_second: true
 
 ### output
 output_dir: ./results/inference/Baichuan2-7B/Baichuan2_predict_3_single
@@ -1,5 +1,5 @@
 ### model
-model_name_or_path: ZhipuAI/chatglm2-6b
+model_name_or_path: ../../llm/chatglm/data
 
 ### method
 do_predict: true
@@ -11,6 +11,7 @@ cutoff_len: 1024
 max_samples: 50
 overwrite_cache: true
 preprocessing_num_workers: 16
+include_tokens_per_second: true
 
 ### output
 output_dir: ./results/inference/ChatGLM2-6B/ChatGLM2_predict_1

@@ -1,5 +1,5 @@
 ### model
-model_name_or_path: ZhipuAI/chatglm2-6b
+model_name_or_path: ../../llm/chatglm/data
 
 ### method
 do_predict: true
@@ -11,6 +11,7 @@ cutoff_len: 1024
 max_samples: 50
 overwrite_cache: true
 preprocessing_num_workers: 16
+include_tokens_per_second: true
 
 ### output
 output_dir: ./results/inference/ChatGLM2-6B/ChatGLM2_predict_1_single

@@ -1,5 +1,5 @@
 ### model
-model_name_or_path: ZhipuAI/chatglm2-6b
+model_name_or_path: ../../llm/chatglm/data
 
 ### method
 do_predict: true
@@ -11,6 +11,7 @@ cutoff_len: 1024
 max_samples: 50
 overwrite_cache: true
 preprocessing_num_workers: 16
+include_tokens_per_second: true
 
 ### output
 output_dir: ./results/inference/ChatGLM2-6B/ChatGLM2_predict_2

@@ -1,5 +1,5 @@
 ### model
-model_name_or_path: ZhipuAI/chatglm2-6b
+model_name_or_path: ../../llm/chatglm/data
 
 ### method
 do_predict: true
@@ -11,6 +11,7 @@ cutoff_len: 1024
 max_samples: 50
 overwrite_cache: true
 preprocessing_num_workers: 16
+include_tokens_per_second: true
 
 ### output
 output_dir: ./results/inference/ChatGLM2-6B/ChatGLM2_predict_2_single

@@ -1,5 +1,5 @@
 ### model
-model_name_or_path: ZhipuAI/chatglm2-6b
+model_name_or_path: ../../llm/chatglm/data
 
 ### method
 do_predict: true
@@ -11,6 +11,7 @@ cutoff_len: 1024
 max_samples: 50
 overwrite_cache: true
 preprocessing_num_workers: 16
+include_tokens_per_second: true
 
 ### output
 output_dir: ./results/inference/ChatGLM2-6B/ChatGLM2_predict_3

@@ -1,5 +1,5 @@
 ### model
-model_name_or_path: ZhipuAI/chatglm2-6b
+model_name_or_path: ../../llm/chatglm/data
 
 ### method
 do_predict: true
@@ -11,6 +11,7 @@ cutoff_len: 1024
 max_samples: 50
 overwrite_cache: true
 preprocessing_num_workers: 16
+include_tokens_per_second: true
 
 ### output
 output_dir: ./results/inference/ChatGLM2-6B/ChatGLM2_predict_3_single
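Note that the six ChatGLM2 configs above also repoint model_name_or_path from the ModelScope ID ZhipuAI/chatglm2-6b to a local relative path. A relative path resolves against whatever directory the launcher runs from, so a small pre-flight check fails faster than a mid-run load error; a minimal sketch, assuming a Hugging-Face-style checkpoint layout (the helper and its file-name expectation are illustrative, not part of the commit):

from pathlib import Path

def check_local_checkpoint(path: str) -> None:
    """Fail fast if model_name_or_path points at a missing or incomplete checkpoint."""
    ckpt = Path(path)
    if not ckpt.is_dir():
        raise FileNotFoundError(f"checkpoint directory not found: {ckpt.resolve()}")
    # A usable checkpoint normally ships a config.json next to its weight shards.
    if not (ckpt / "config.json").is_file():
        raise FileNotFoundError(f"no config.json under {ckpt.resolve()}")

# The relative path from the commit; resolution depends on the launch directory.
check_local_checkpoint("../../llm/chatglm/data")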
@@ -11,6 +11,7 @@ cutoff_len: 1024
 max_samples: 50
 overwrite_cache: true
 preprocessing_num_workers: 16
+include_tokens_per_second: true
 
 ### output
 output_dir: ./results/inference/Llama2-7B/llama2_predict_1

@@ -11,6 +11,7 @@ cutoff_len: 1024
 max_samples: 50
 overwrite_cache: true
 preprocessing_num_workers: 16
+include_tokens_per_second: true
 
 ### output
 output_dir: ./results/inference/Llama2-7B/llama2_predict_1_single

@@ -11,6 +11,7 @@ cutoff_len: 1024
 max_samples: 50
 overwrite_cache: true
 preprocessing_num_workers: 16
+include_tokens_per_second: true
 
 ### output
 output_dir: ./results/inference/Llama2-7B/llama2_predict_2

@@ -11,6 +11,7 @@ cutoff_len: 1024
 max_samples: 50
 overwrite_cache: true
 preprocessing_num_workers: 16
+include_tokens_per_second: true
 
 ### output
 output_dir: ./results/inference/Llama2-7B/llama2_predict_2_single

@@ -11,6 +11,7 @@ cutoff_len: 1024
 max_samples: 50
 overwrite_cache: true
 preprocessing_num_workers: 16
+include_tokens_per_second: true
 
 ### output
 output_dir: ./results/inference/Llama2-7B/llama2_predict_3

@@ -11,6 +11,7 @@ cutoff_len: 1024
 max_samples: 50
 overwrite_cache: true
 preprocessing_num_workers: 16
+include_tokens_per_second: true
 
 ### output
 output_dir: ./results/inference/Llama2-7B/llama2_predict_3_single
@@ -11,6 +11,7 @@ cutoff_len: 1024
 max_samples: 50
 overwrite_cache: true
 preprocessing_num_workers: 16
+include_tokens_per_second: true
 
 ### output
 output_dir: ./results/inference/Qwen-7B/Qwen_predict_1

@@ -11,6 +11,7 @@ cutoff_len: 1024
 max_samples: 50
 overwrite_cache: true
 preprocessing_num_workers: 16
+include_tokens_per_second: true
 
 ### output
 output_dir: ./results/inference/Qwen-7B/Qwen_predict_1_single

@@ -11,6 +11,7 @@ cutoff_len: 1024
 max_samples: 50
 overwrite_cache: true
 preprocessing_num_workers: 16
+include_tokens_per_second: true
 
 ### output
 output_dir: ./results/inference/Qwen-7B/Qwen_predict_2

@@ -11,6 +11,7 @@ cutoff_len: 1024
 max_samples: 50
 overwrite_cache: true
 preprocessing_num_workers: 16
+include_tokens_per_second: true
 
 ### output
 output_dir: ./results/inference/Qwen-7B/Qwen_predict_2_single

@@ -11,6 +11,7 @@ cutoff_len: 1024
 max_samples: 50
 overwrite_cache: true
 preprocessing_num_workers: 16
+include_tokens_per_second: true
 
 ### output
 output_dir: ./results/inference/Qwen-7B/Qwen_predict_3

@@ -11,6 +11,7 @@ cutoff_len: 1024
 max_samples: 50
 overwrite_cache: true
 preprocessing_num_workers: 16
+include_tokens_per_second: true
 
 ### output
 output_dir: ./results/inference/Qwen-7B/Qwen_predict_3_single
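All twenty-four configs gain include_tokens_per_second: true, a transformers.TrainingArguments flag that asks the Trainer to include token throughput in its reported speed metrics. A hedged sketch of gathering those figures from the output_dir trees after the runs (the all_results.json file name follows the usual Trainer convention, and the suffix match on the metric key is an assumption, since the exact key depends on the phase):

import json
from pathlib import Path

def collect_throughput(root: str = "./results/inference") -> dict[str, float]:
    """Map each run directory to its reported tokens-per-second, when present."""
    rates: dict[str, float] = {}
    for metrics_file in Path(root).rglob("all_results.json"):
        metrics = json.loads(metrics_file.read_text())
        for key, value in metrics.items():
            # The Trainer prefixes the metric by phase, e.g. "train_tokens_per_second";
            # matching on the suffix keeps the sketch agnostic to the prefix.
            if key.endswith("tokens_per_second"):
                rates[str(metrics_file.parent)] = float(value)
    return rates

for run, tps in sorted(collect_throughput().items()):
    print(f"{run}: {tps:.1f} tokens/s")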
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import time
 from typing import TYPE_CHECKING, List, Optional
 
 from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset
@@ -112,7 +113,9 @@ def run_sft(
 
     # Predict
     if training_args.do_predict:
+        print("predict start time: " + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
         predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
+        print("predict end time: " + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
         if training_args.predict_with_generate:  # predict_loss will be wrong if predict_with_generate is enabled
             predict_results.metrics.pop("predict_loss", None)
         trainer.log_metrics("predict", predict_results.metrics)
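The added prints bracket trainer.predict with human-readable wall-clock timestamps, which is enough to read the duration off the log. When a machine-readable elapsed time is wanted as well, a monotonic clock is the safer source; a minimal reusable sketch of that variant (the timed helper is illustrative, not from the commit):

import time
from contextlib import contextmanager

@contextmanager
def timed(label: str):
    """Print start/end wall-clock stamps plus a monotonic elapsed time for a block."""
    print(label + " start time: " + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    start = time.perf_counter()  # monotonic, so unaffected by system clock changes
    try:
        yield
    finally:
        print(label + " end time: " + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        print(f"{label} elapsed: {time.perf_counter() - start:.1f}s")

# Usage, mirroring the diff above:
#     with timed("predict"):
#         predict_results = trainer.predict(dataset_module["eval_dataset"], ...)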