train: modify inference and add start/end time
This commit is contained in:
parent
47b4cd3a54
commit
e434505c6b
|
@ -11,6 +11,7 @@ cutoff_len: 1024
|
|||
max_samples: 50
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
include_tokens_per_second: true
|
||||
|
||||
### output
|
||||
output_dir: ./results/inference/Baichuan2-7B/Baichuan2_predict_1
|
||||
|
|
|
@ -11,6 +11,7 @@ cutoff_len: 1024
|
|||
max_samples: 50
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
include_tokens_per_second: true
|
||||
|
||||
### output
|
||||
output_dir: ./results/inference/Baichuan2-7B/Baichuan2_predict_1_single
|
||||
|
|
|
@ -11,6 +11,7 @@ cutoff_len: 1024
|
|||
max_samples: 50
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
include_tokens_per_second: true
|
||||
|
||||
### output
|
||||
output_dir: ./results/inference/Baichuan2-7B/Baichuan2_predict_2
|
||||
|
|
|
@ -11,6 +11,7 @@ cutoff_len: 1024
|
|||
max_samples: 50
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
include_tokens_per_second: true
|
||||
|
||||
### output
|
||||
output_dir: ./results/inference/Baichuan2-7B/Baichuan2_predict_2_single
|
||||
|
|
|
@ -11,6 +11,7 @@ cutoff_len: 1024
|
|||
max_samples: 50
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
include_tokens_per_second: true
|
||||
|
||||
### output
|
||||
output_dir: ./results/inference/Baichuan2-7B/Baichuan2_predict_3
|
||||
|
|
|
@ -11,6 +11,7 @@ cutoff_len: 1024
|
|||
max_samples: 50
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
include_tokens_per_second: true
|
||||
|
||||
### output
|
||||
output_dir: ./results/inference/Baichuan2-7B/Baichuan2_predict_3_single
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
### model
|
||||
model_name_or_path: ZhipuAI/chatglm2-6b
|
||||
model_name_or_path: ../../llm/chatglm/data
|
||||
|
||||
### method
|
||||
do_predict: true
|
||||
|
@ -11,6 +11,7 @@ cutoff_len: 1024
|
|||
max_samples: 50
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
include_tokens_per_second: true
|
||||
|
||||
### output
|
||||
output_dir: ./results/inference/ChatGLM2-6B/ChatGLM2_predict_1
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
### model
|
||||
model_name_or_path: ZhipuAI/chatglm2-6b
|
||||
model_name_or_path: ../../llm/chatglm/data
|
||||
|
||||
### method
|
||||
do_predict: true
|
||||
|
@ -11,6 +11,7 @@ cutoff_len: 1024
|
|||
max_samples: 50
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
include_tokens_per_second: true
|
||||
|
||||
### output
|
||||
output_dir: ./results/inference/ChatGLM2-6B/ChatGLM2_predict_1_single
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
### model
|
||||
model_name_or_path: ZhipuAI/chatglm2-6b
|
||||
model_name_or_path: ../../llm/chatglm/data
|
||||
|
||||
### method
|
||||
do_predict: true
|
||||
|
@ -11,6 +11,7 @@ cutoff_len: 1024
|
|||
max_samples: 50
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
include_tokens_per_second: true
|
||||
|
||||
### output
|
||||
output_dir: ./results/inference/ChatGLM2-6B/ChatGLM2_predict_2
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
### model
|
||||
model_name_or_path: ZhipuAI/chatglm2-6b
|
||||
model_name_or_path: ../../llm/chatglm/data
|
||||
|
||||
### method
|
||||
do_predict: true
|
||||
|
@ -11,6 +11,7 @@ cutoff_len: 1024
|
|||
max_samples: 50
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
include_tokens_per_second: true
|
||||
|
||||
### output
|
||||
output_dir: ./results/inference/ChatGLM2-6B/ChatGLM2_predict_2_single
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
### model
|
||||
model_name_or_path: ZhipuAI/chatglm2-6b
|
||||
model_name_or_path: ../../llm/chatglm/data
|
||||
|
||||
### method
|
||||
do_predict: true
|
||||
|
@ -11,6 +11,7 @@ cutoff_len: 1024
|
|||
max_samples: 50
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
include_tokens_per_second: true
|
||||
|
||||
### output
|
||||
output_dir: ./results/inference/ChatGLM2-6B/ChatGLM2_predict_3
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
### model
|
||||
model_name_or_path: ZhipuAI/chatglm2-6b
|
||||
model_name_or_path: ../../llm/chatglm/data
|
||||
|
||||
### method
|
||||
do_predict: true
|
||||
|
@ -11,6 +11,7 @@ cutoff_len: 1024
|
|||
max_samples: 50
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
include_tokens_per_second: true
|
||||
|
||||
### output
|
||||
output_dir: ./results/inference/ChatGLM2-6B/ChatGLM2_predict_3_single
|
||||
|
|
|
@ -11,6 +11,7 @@ cutoff_len: 1024
|
|||
max_samples: 50
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
include_tokens_per_second: true
|
||||
|
||||
### output
|
||||
output_dir: ./results/inference/Llama2-7B/llama2_predict_1
|
||||
|
|
|
@ -11,6 +11,7 @@ cutoff_len: 1024
|
|||
max_samples: 50
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
include_tokens_per_second: true
|
||||
|
||||
### output
|
||||
output_dir: ./results/inference/Llama2-7B/llama2_predict_1_single
|
||||
|
|
|
@ -11,6 +11,7 @@ cutoff_len: 1024
|
|||
max_samples: 50
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
include_tokens_per_second: true
|
||||
|
||||
### output
|
||||
output_dir: ./results/inference/Llama2-7B/llama2_predict_2
|
||||
|
|
|
@ -11,6 +11,7 @@ cutoff_len: 1024
|
|||
max_samples: 50
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
include_tokens_per_second: true
|
||||
|
||||
### output
|
||||
output_dir: ./results/inference/Llama2-7B/llama2_predict_2_single
|
||||
|
|
|
@ -11,6 +11,7 @@ cutoff_len: 1024
|
|||
max_samples: 50
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
include_tokens_per_second: true
|
||||
|
||||
### output
|
||||
output_dir: ./results/inference/Llama2-7B/llama2_predict_3
|
||||
|
|
|
@ -11,6 +11,7 @@ cutoff_len: 1024
|
|||
max_samples: 50
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
include_tokens_per_second: true
|
||||
|
||||
### output
|
||||
output_dir: ./results/inference/Llama2-7B/llama2_predict_3_single
|
||||
|
|
|
@ -11,6 +11,7 @@ cutoff_len: 1024
|
|||
max_samples: 50
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
include_tokens_per_second: true
|
||||
|
||||
### output
|
||||
output_dir: ./results/inference/Qwen-7B/Qwen_predict_1
|
||||
|
|
|
@ -11,6 +11,7 @@ cutoff_len: 1024
|
|||
max_samples: 50
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
include_tokens_per_second: true
|
||||
|
||||
### output
|
||||
output_dir: ./results/inference/Qwen-7B/Qwen_predict_1_single
|
||||
|
|
|
@ -11,6 +11,7 @@ cutoff_len: 1024
|
|||
max_samples: 50
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
include_tokens_per_second: true
|
||||
|
||||
### output
|
||||
output_dir: ./results/inference/Qwen-7B/Qwen_predict_2
|
||||
|
|
|
@ -11,6 +11,7 @@ cutoff_len: 1024
|
|||
max_samples: 50
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
include_tokens_per_second: true
|
||||
|
||||
### output
|
||||
output_dir: ./results/inference/Qwen-7B/Qwen_predict_2_single
|
||||
|
|
|
@ -11,6 +11,7 @@ cutoff_len: 1024
|
|||
max_samples: 50
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
include_tokens_per_second: true
|
||||
|
||||
### output
|
||||
output_dir: ./results/inference/Qwen-7B/Qwen_predict_3
|
||||
|
|
|
@ -11,6 +11,7 @@ cutoff_len: 1024
|
|||
max_samples: 50
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
include_tokens_per_second: true
|
||||
|
||||
### output
|
||||
output_dir: ./results/inference/Qwen-7B/Qwen_predict_3_single
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset
|
||||
|
@ -112,7 +113,9 @@ def run_sft(
|
|||
|
||||
# Predict
|
||||
if training_args.do_predict:
|
||||
print("predict start time: " + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
|
||||
predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
|
||||
print("predict end time: " + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
|
||||
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
|
||||
predict_results.metrics.pop("predict_loss", None)
|
||||
trainer.log_metrics("predict", predict_results.metrics)
|
||||
|
|
Loading…
Reference in New Issue