# vc/training/train.py
#
# (file-viewer metadata: 13 lines, 444 B, Python)

from transformers import HfArgumentParser
from src.config import ModelParams, DataParams, TrainingParams
from src.trainer import create_trainer
def main():
    """Parse model/data/training arguments and run training.

    Arguments are parsed with ``HfArgumentParser`` into instances of the
    ``ModelParams``, ``DataParams`` and ``TrainingParams`` dataclasses, in
    that order. If the script is invoked with exactly one argument ending in
    ``.json``, the arguments are read from that JSON file. Otherwise they are
    read from the command line. The three parsed objects are handed to
    ``create_trainer``, and ``train()`` is called on the object it returns.
    """
    import sys

    args_parser = HfArgumentParser((ModelParams, DataParams, TrainingParams))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # Allow `python train.py config.json` as an alternative to passing
        # every option as a command-line flag.
        model_args, data_args, train_args = args_parser.parse_json_file(
            json_file=sys.argv[1]
        )
    else:
        model_args, data_args, train_args = (
            args_parser.parse_args_into_dataclasses()
        )
    trainer = create_trainer(model_args, data_args, train_args)
    trainer.train()
# Run training only when executed as a script, not when imported.
if __name__ == '__main__':
    main()