From 46093b5786611d99adf1fd3d42926a728fc629f8 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Sun, 16 Jun 2024 01:38:44 +0800 Subject: [PATCH] fix tol --- src/llamafactory/train/trainer_utils.py | 1 + tests/model/test_base.py | 2 +- tests/model/test_lora.py | 4 ++-- tests/model/test_pissa.py | 2 +- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py index 2d6bab24..9052c96d 100644 --- a/src/llamafactory/train/trainer_utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -213,6 +213,7 @@ def convert_pissa_adapter( safe_serialization=training_args.save_safetensors, convert_pissa_to_lora=pissa_init_dir, ) + # TODO: PiSSA is unexpectedly applied to the model again unwrapped_model.load_adapter(pissa_backup_dir, "default", is_trainable=True) unwrapped_model.set_adapter("default") diff --git a/tests/model/test_base.py b/tests/model/test_base.py index 954492ef..e1991b20 100644 --- a/tests/model/test_base.py +++ b/tests/model/test_base.py @@ -41,7 +41,7 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"): state_dict_b = model_b.state_dict() assert set(state_dict_a.keys()) == set(state_dict_b.keys()) for name in state_dict_a.keys(): - assert torch.allclose(state_dict_a[name], state_dict_b[name]) + assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) @pytest.fixture diff --git a/tests/model/test_lora.py b/tests/model/test_lora.py index fe032332..64566fe8 100644 --- a/tests/model/test_lora.py +++ b/tests/model/test_lora.py @@ -67,9 +67,9 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k assert set(state_dict_a.keys()) == set(state_dict_b.keys()) for name in state_dict_a.keys(): if any(key in name for key in diff_keys): - assert torch.allclose(state_dict_a[name], state_dict_b[name]) is False + assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is 
False else: - assert torch.allclose(state_dict_a[name], state_dict_b[name]) is True + assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is True @pytest.fixture diff --git a/tests/model/test_pissa.py b/tests/model/test_pissa.py index 41d02752..030310d0 100644 --- a/tests/model/test_pissa.py +++ b/tests/model/test_pissa.py @@ -59,7 +59,7 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"): state_dict_b = model_b.state_dict() assert set(state_dict_a.keys()) == set(state_dict_b.keys()) for name in state_dict_a.keys(): - assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-3, atol=1e-3) + assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) def test_pissa_init():