chore: add print
This commit is contained in:
parent
a4633cc12f
commit
0071124994
|
@ -213,10 +213,6 @@ def get_dataset(
|
|||
if has_tokenized_data(data_args.tokenized_path):
|
||||
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
||||
dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
|
||||
print(data_args.tokenized_path)
|
||||
print(dataset_dict)
|
||||
time.sleep(100)
|
||||
|
||||
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
|
||||
|
||||
dataset_module: Dict[str, "Dataset"] = {}
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset
|
||||
|
@ -113,6 +114,9 @@ def run_sft(
|
|||
# Predict
|
||||
if training_args.do_predict:
|
||||
predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
|
||||
|
||||
print(predict_results.metrics)
|
||||
time.sleep(100)
|
||||
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
|
||||
predict_results.metrics.pop("predict_loss", None)
|
||||
trainer.log_metrics("predict", predict_results.metrics)
|
||||
|
|
Loading…
Reference in New Issue