chore: add print
This commit is contained in:
parent
a4633cc12f
commit
0071124994
|
@ -213,10 +213,6 @@ def get_dataset(
|
||||||
if has_tokenized_data(data_args.tokenized_path):
|
if has_tokenized_data(data_args.tokenized_path):
|
||||||
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
||||||
dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
|
dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
|
||||||
print(data_args.tokenized_path)
|
|
||||||
print(dataset_dict)
|
|
||||||
time.sleep(100)
|
|
||||||
|
|
||||||
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
|
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
|
||||||
|
|
||||||
dataset_module: Dict[str, "Dataset"] = {}
|
dataset_module: Dict[str, "Dataset"] = {}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import time
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset
|
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset
|
||||||
|
@ -113,6 +114,9 @@ def run_sft(
|
||||||
# Predict
|
# Predict
|
||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
|
predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
|
||||||
|
|
||||||
|
print(predict_results.metrics)
|
||||||
|
time.sleep(100)
|
||||||
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
|
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
|
||||||
predict_results.metrics.pop("predict_loss", None)
|
predict_results.metrics.pop("predict_loss", None)
|
||||||
trainer.log_metrics("predict", predict_results.metrics)
|
trainer.log_metrics("predict", predict_results.metrics)
|
||||||
|
|
Loading…
Reference in New Issue