tiny fix
This commit is contained in:
parent
15be296347
commit
c8b4c7fee5
|
@ -137,12 +137,12 @@ def get_device_count() -> int:
|
|||
r"""
|
||||
Gets the number of available GPU or NPU devices.
|
||||
"""
|
||||
if is_torch_npu_available():
|
||||
if is_torch_xpu_available():
|
||||
return torch.xpu.device_count()
|
||||
elif is_torch_npu_available():
|
||||
return torch.npu.device_count()
|
||||
elif is_torch_cuda_available():
|
||||
return torch.cuda.device_count()
|
||||
elif is_torch_xpu_available():
|
||||
return torch.xpu.device_count()
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
|
|
@ -133,9 +133,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||
ref_model=ref_model,
|
||||
tokenizer=tokenizer,
|
||||
dataset=train_dataset,
|
||||
optimizer=optimizer,
|
||||
data_collator=data_collator,
|
||||
lr_scheduler=scheduler,
|
||||
optimizer=optimizer,
|
||||
)
|
||||
|
||||
self.args = training_args
|
||||
|
|
Loading…
Reference in New Issue