diff --git a/FM_9G/fm9g/dragonfly/training_tasks/pretrain_indexed.py b/FM_9G/fm9g/dragonfly/training_tasks/pretrain_indexed.py index afee738..f247081 100644 --- a/FM_9G/fm9g/dragonfly/training_tasks/pretrain_indexed.py +++ b/FM_9G/fm9g/dragonfly/training_tasks/pretrain_indexed.py @@ -602,8 +602,7 @@ class MixedIndexedDataset(torch.utils.data.IterableDataset): idx = np.random.choice(len(self.weights), p=self.weights) data = next(self.tasks[idx]) - if step % self.update_weights_frequency == 0: - self.update_weights() + if data is None: if self.tasks[idx].allow_repeat: # _runtime_ave = self.tasks[idx].ave_tokens @@ -618,7 +617,7 @@ class MixedIndexedDataset(torch.utils.data.IterableDataset): self.tasks[idx].exhaust = True self.remain -= 1 continue - + if step % self.update_weights_frequency == 0: self.update_weights() step += 1