fix tol
This commit is contained in:
parent
7f3c19e3ab
commit
46093b5786
|
@ -213,6 +213,7 @@ def convert_pissa_adapter(
|
||||||
safe_serialization=training_args.save_safetensors,
|
safe_serialization=training_args.save_safetensors,
|
||||||
convert_pissa_to_lora=pissa_init_dir,
|
convert_pissa_to_lora=pissa_init_dir,
|
||||||
)
|
)
|
||||||
|
# TODO: the model is applied pissa again unexpectedly
|
||||||
unwrapped_model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
|
unwrapped_model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
|
||||||
unwrapped_model.set_adapter("default")
|
unwrapped_model.set_adapter("default")
|
||||||
|
|
||||||
|
|
|
@ -41,7 +41,7 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"):
|
||||||
state_dict_b = model_b.state_dict()
|
state_dict_b = model_b.state_dict()
|
||||||
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
|
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
|
||||||
for name in state_dict_a.keys():
|
for name in state_dict_a.keys():
|
||||||
assert torch.allclose(state_dict_a[name], state_dict_b[name])
|
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
|
@ -67,9 +67,9 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
|
||||||
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
|
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
|
||||||
for name in state_dict_a.keys():
|
for name in state_dict_a.keys():
|
||||||
if any(key in name for key in diff_keys):
|
if any(key in name for key in diff_keys):
|
||||||
assert torch.allclose(state_dict_a[name], state_dict_b[name]) is False
|
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is False
|
||||||
else:
|
else:
|
||||||
assert torch.allclose(state_dict_a[name], state_dict_b[name]) is True
|
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is True
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
|
@ -59,7 +59,7 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"):
|
||||||
state_dict_b = model_b.state_dict()
|
state_dict_b = model_b.state_dict()
|
||||||
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
|
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
|
||||||
for name in state_dict_a.keys():
|
for name in state_dict_a.keys():
|
||||||
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-3, atol=1e-3)
|
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
def test_pissa_init():
|
def test_pissa_init():
|
||||||
|
|
Loading…
Reference in New Issue