tiny fix
This commit is contained in:
parent
091010492b
commit
ace1d44857
|
@ -36,9 +36,11 @@ def calculate_flops(
|
|||
"""
|
||||
with get_accelerator().device(0):
|
||||
chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn))
|
||||
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
|
||||
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.engine.model.device)
|
||||
input_dict = {"input_ids": fake_input, "labels": fake_input.clone()}
|
||||
flops, macs, params = get_model_profile(chat_model.model, kwargs=input_dict, print_profile=True, detailed=True)
|
||||
flops, macs, params = get_model_profile(
|
||||
chat_model.engine.model, kwargs=input_dict, print_profile=True, detailed=True
|
||||
)
|
||||
print("FLOPs:", flops)
|
||||
print("MACs:", macs)
|
||||
print("Params:", params)
|
||||
|
|
|
@ -104,11 +104,6 @@ class Runner:
|
|||
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
|
||||
user_config = load_config()
|
||||
|
||||
if get("top.quantization_bit") in QUANTIZATION_BITS:
|
||||
quantization_bit = int(get("top.quantization_bit"))
|
||||
else:
|
||||
quantization_bit = None
|
||||
|
||||
args = dict(
|
||||
stage=TRAINING_STAGES[get("train.training_stage")],
|
||||
do_train=True,
|
||||
|
@ -116,8 +111,6 @@ class Runner:
|
|||
cache_dir=user_config.get("cache_dir", None),
|
||||
preprocessing_num_workers=16,
|
||||
finetuning_type=finetuning_type,
|
||||
quantization_bit=quantization_bit,
|
||||
quantization_method=get("top.quantization_method"),
|
||||
template=get("top.template"),
|
||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
||||
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
|
||||
|
@ -166,6 +159,11 @@ class Runner:
|
|||
else: # str
|
||||
args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))
|
||||
|
||||
# quantization
|
||||
if get("top.quantization_bit") in QUANTIZATION_BITS:
|
||||
args["quantization_bit"] = int(get("top.quantization_bit"))
|
||||
args["quantization_method"] = get("top.quantization_method")
|
||||
|
||||
# freeze config
|
||||
if args["finetuning_type"] == "freeze":
|
||||
args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")
|
||||
|
@ -242,18 +240,12 @@ class Runner:
|
|||
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
|
||||
user_config = load_config()
|
||||
|
||||
if get("top.quantization_bit") in QUANTIZATION_BITS:
|
||||
quantization_bit = int(get("top.quantization_bit"))
|
||||
else:
|
||||
quantization_bit = None
|
||||
|
||||
args = dict(
|
||||
stage="sft",
|
||||
model_name_or_path=get("top.model_path"),
|
||||
cache_dir=user_config.get("cache_dir", None),
|
||||
preprocessing_num_workers=16,
|
||||
finetuning_type=finetuning_type,
|
||||
quantization_bit=quantization_bit,
|
||||
quantization_method=get("top.quantization_method"),
|
||||
template=get("top.template"),
|
||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
||||
|
@ -277,6 +269,7 @@ class Runner:
|
|||
else:
|
||||
args["do_eval"] = True
|
||||
|
||||
# checkpoints
|
||||
if get("top.checkpoint_path"):
|
||||
if finetuning_type in PEFT_METHODS: # list
|
||||
args["adapter_name_or_path"] = ",".join(
|
||||
|
@ -285,6 +278,11 @@ class Runner:
|
|||
else: # str
|
||||
args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))
|
||||
|
||||
# quantization
|
||||
if get("top.quantization_bit") in QUANTIZATION_BITS:
|
||||
args["quantization_bit"] = int(get("top.quantization_bit"))
|
||||
args["quantization_method"] = get("top.quantization_method")
|
||||
|
||||
return args
|
||||
|
||||
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]:
|
||||
|
|
Loading…
Reference in New Issue