# File viewer metadata: 13 lines, 444 B, Python
from transformers import HfArgumentParser
|
|
from src.config import ModelParams, DataParams, TrainingParams
|
|
from src.trainer import create_trainer
|
|
|
|
def main():
    """Parse the run configuration, build a trainer, and start training.

    The three configuration dataclasses (``ModelParams``, ``DataParams``,
    ``TrainingParams``) are populated by ``HfArgumentParser``; the parsed
    instances are handed, in that order, to ``create_trainer`` and the
    returned object's ``train()`` method is invoked.
    """
    # The order of the dataclass types fixes the order of the parsed results.
    config_types = (ModelParams, DataParams, TrainingParams)
    parsed = HfArgumentParser(config_types).parse_args_into_dataclasses()
    model_params, data_params, training_params = parsed

    create_trainer(model_params, data_params, training_params).train()
|
|
|
|
# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()