tiny fix
commit ace1d44857
parent 091010492b
@@ -36,9 +36,11 @@ def calculate_flops(
     """
     with get_accelerator().device(0):
         chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn))
-        fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
+        fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.engine.model.device)
         input_dict = {"input_ids": fake_input, "labels": fake_input.clone()}
-        flops, macs, params = get_model_profile(chat_model.model, kwargs=input_dict, print_profile=True, detailed=True)
+        flops, macs, params = get_model_profile(
+            chat_model.engine.model, kwargs=input_dict, print_profile=True, detailed=True
+        )
         print("FLOPs:", flops)
         print("MACs:", macs)
         print("Params:", params)
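The functional change in this hunk is that ChatModel now reaches its underlying Transformers model through an engine attribute, so both the device lookup and the profiled module switch from chat_model.model to chat_model.engine.model. As a reference for the profiler call itself, a minimal, self-contained sketch of DeepSpeed's get_model_profile with keyword inputs follows; the TinyNet module and its sizes are invented for illustration, and only the call pattern mirrors the script above.

# Illustrative sketch, not part of the commit: the same get_model_profile call
# pattern with a stand-in module (TinyNet and its sizes are invented here).
import torch
from deepspeed.profiling.flops_profiler import get_model_profile


class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(1000, 64)
        self.proj = torch.nn.Linear(64, 64)

    def forward(self, input_ids=None, labels=None):
        # kwargs passed to get_model_profile are forwarded to forward() as-is,
        # so the signature only needs to accept the keys of the input dict.
        return self.proj(self.embed(input_ids))


fake_input = torch.ones((1, 256), dtype=torch.long)
input_dict = {"input_ids": fake_input, "labels": fake_input.clone()}
flops, macs, params = get_model_profile(TinyNet(), kwargs=input_dict, print_profile=False, detailed=False)
print("FLOPs:", flops, "MACs:", macs, "Params:", params)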
@@ -104,11 +104,6 @@ class Runner:
         model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
         user_config = load_config()
 
-        if get("top.quantization_bit") in QUANTIZATION_BITS:
-            quantization_bit = int(get("top.quantization_bit"))
-        else:
-            quantization_bit = None
-
         args = dict(
             stage=TRAINING_STAGES[get("train.training_stage")],
             do_train=True,
@@ -116,8 +111,6 @@ class Runner:
             cache_dir=user_config.get("cache_dir", None),
             preprocessing_num_workers=16,
             finetuning_type=finetuning_type,
-            quantization_bit=quantization_bit,
-            quantization_method=get("top.quantization_method"),
             template=get("top.template"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
@@ -166,6 +159,11 @@ class Runner:
             else:  # str
                 args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))
 
+        # quantization
+        if get("top.quantization_bit") in QUANTIZATION_BITS:
+            args["quantization_bit"] = int(get("top.quantization_bit"))
+            args["quantization_method"] = get("top.quantization_method")
+
         # freeze config
         if args["finetuning_type"] == "freeze":
             args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")
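Taken together, the hunks above stop _parse_train_args from precomputing a possibly-None quantization_bit up front and instead attach the quantization keys only after the checkpoint handling, and only when a bit is actually selected in the UI. A trimmed-down sketch of the resulting behavior follows; the QUANTIZATION_BITS values and the stand-alone build_train_args function are assumptions made for illustration, not the real Runner method.

# Illustrative sketch only; QUANTIZATION_BITS values and build_train_args are
# stand-ins, not the repository's real code.
QUANTIZATION_BITS = ["8", "6", "5", "4", "3", "2"]  # assumed UI choices for the sketch


def build_train_args(quantization_bit: str, quantization_method: str) -> dict:
    args = dict(stage="sft", do_train=True, finetuning_type="lora")  # trimmed-down base dict
    # Quantization keys are added only when a bit is chosen, instead of always
    # emitting quantization_bit=None as the old code did.
    if quantization_bit in QUANTIZATION_BITS:
        args["quantization_bit"] = int(quantization_bit)
        args["quantization_method"] = quantization_method
    return args


print(build_train_args("none", "bitsandbytes"))  # no quantization keys at all
print(build_train_args("4", "bitsandbytes"))     # quantization_bit=4, quantization_method='bitsandbytes'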
@@ -242,18 +240,12 @@ class Runner:
         model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
         user_config = load_config()
 
-        if get("top.quantization_bit") in QUANTIZATION_BITS:
-            quantization_bit = int(get("top.quantization_bit"))
-        else:
-            quantization_bit = None
-
         args = dict(
             stage="sft",
             model_name_or_path=get("top.model_path"),
             cache_dir=user_config.get("cache_dir", None),
             preprocessing_num_workers=16,
             finetuning_type=finetuning_type,
-            quantization_bit=quantization_bit,
             quantization_method=get("top.quantization_method"),
             template=get("top.template"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
@@ -277,6 +269,7 @@ class Runner:
         else:
             args["do_eval"] = True
 
+        # checkpoints
         if get("top.checkpoint_path"):
             if finetuning_type in PEFT_METHODS:  # list
                 args["adapter_name_or_path"] = ",".join(
@@ -285,6 +278,11 @@ class Runner:
             else:  # str
                 args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))
 
+        # quantization
+        if get("top.quantization_bit") in QUANTIZATION_BITS:
+            args["quantization_bit"] = int(get("top.quantization_bit"))
+            args["quantization_method"] = get("top.quantization_method")
+
         return args
 
     def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]:
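With this hunk, the relocated quantization block sits after the checkpoint handling in _parse_eval_args as well, so both parse methods follow the same # checkpoints / # quantization layout. For the list-versus-str branch visible in the context lines, here is a small sketch of how the paths resolve; PEFT_METHODS and get_save_dir are simplified stand-ins, not the repository's real helpers.

# Illustrative sketch; PEFT_METHODS and get_save_dir are simplified stand-ins.
import os
from typing import Dict, List, Union

PEFT_METHODS = ["lora"]  # assumption for the sketch


def get_save_dir(model_name: str, finetuning_type: str, checkpoint: str) -> str:
    return os.path.join("saves", model_name, finetuning_type, checkpoint)  # stand-in path layout


def resolve_checkpoints(model_name: str, finetuning_type: str, checkpoint_path: Union[str, List[str]]) -> Dict[str, str]:
    args = {}
    if finetuning_type in PEFT_METHODS:  # list of adapters -> comma-joined adapter_name_or_path
        args["adapter_name_or_path"] = ",".join(
            get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoint_path
        )
    else:  # str -> the full fine-tuned checkpoint replaces model_name_or_path
        args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)
    return args


print(resolve_checkpoints("llama3-8b", "lora", ["checkpoint-100", "checkpoint-200"]))
print(resolve_checkpoints("llama3-8b", "full", "checkpoint-200"))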