hiyouga 2024-07-24 18:33:39 +08:00
parent 091010492b
commit ace1d44857
2 changed files with 15 additions and 15 deletions

View File

@@ -36,9 +36,11 @@ def calculate_flops(
     """
     with get_accelerator().device(0):
         chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn))
-        fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
+        fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.engine.model.device)
         input_dict = {"input_ids": fake_input, "labels": fake_input.clone()}
-        flops, macs, params = get_model_profile(chat_model.model, kwargs=input_dict, print_profile=True, detailed=True)
+        flops, macs, params = get_model_profile(
+            chat_model.engine.model, kwargs=input_dict, print_profile=True, detailed=True
+        )
         print("FLOPs:", flops)
         print("MACs:", macs)
         print("Params:", params)

View File

@@ -104,11 +104,6 @@ class Runner:
         model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
         user_config = load_config()
 
-        if get("top.quantization_bit") in QUANTIZATION_BITS:
-            quantization_bit = int(get("top.quantization_bit"))
-        else:
-            quantization_bit = None
-
         args = dict(
             stage=TRAINING_STAGES[get("train.training_stage")],
             do_train=True,
@@ -116,8 +111,6 @@ class Runner:
             cache_dir=user_config.get("cache_dir", None),
             preprocessing_num_workers=16,
             finetuning_type=finetuning_type,
-            quantization_bit=quantization_bit,
-            quantization_method=get("top.quantization_method"),
             template=get("top.template"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
@@ -166,6 +159,11 @@ class Runner:
             else:  # str
                 args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))
 
+        # quantization
+        if get("top.quantization_bit") in QUANTIZATION_BITS:
+            args["quantization_bit"] = int(get("top.quantization_bit"))
+            args["quantization_method"] = get("top.quantization_method")
+
         # freeze config
         if args["finetuning_type"] == "freeze":
             args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")
@@ -242,18 +240,12 @@ class Runner:
         model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
         user_config = load_config()
 
-        if get("top.quantization_bit") in QUANTIZATION_BITS:
-            quantization_bit = int(get("top.quantization_bit"))
-        else:
-            quantization_bit = None
-
         args = dict(
             stage="sft",
             model_name_or_path=get("top.model_path"),
             cache_dir=user_config.get("cache_dir", None),
             preprocessing_num_workers=16,
             finetuning_type=finetuning_type,
-            quantization_bit=quantization_bit,
             quantization_method=get("top.quantization_method"),
             template=get("top.template"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
@@ -277,6 +269,7 @@ class Runner:
         else:
             args["do_eval"] = True
 
+        # checkpoints
         if get("top.checkpoint_path"):
             if finetuning_type in PEFT_METHODS:  # list
                 args["adapter_name_or_path"] = ",".join(
@@ -285,6 +278,11 @@ class Runner:
             else:  # str
                 args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))
 
+        # quantization
+        if get("top.quantization_bit") in QUANTIZATION_BITS:
+            args["quantization_bit"] = int(get("top.quantization_bit"))
+            args["quantization_method"] = get("top.quantization_method")
+
         return args
 
     def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]:
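This commit replaces an unconditional quantization_bit=... entry (which could be None) with a block that only fills in the quantization keys when a supported bit width is selected. A minimal sketch of that pattern in isolation; the QUANTIZATION_BITS values and the helper name below are illustrative assumptions, not the project's actual constants:

from typing import Any, Dict

QUANTIZATION_BITS = ["8", "6", "5", "4", "3", "2", "1"]  # illustrative values only


def apply_quantization_args(args: Dict[str, Any], bit_choice: str, method: str) -> Dict[str, Any]:
    # Only add the quantization keys when a valid bit width was chosen,
    # so a "none" selection leaves args untouched instead of passing quantization_bit=None.
    if bit_choice in QUANTIZATION_BITS:
        args["quantization_bit"] = int(bit_choice)
        args["quantization_method"] = method
    return args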