From 521ad765521bb65aff5a29a8125a2b26ef00bff4 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Wed, 31 Jan 2024 11:58:07 +0800
Subject: [PATCH 1/2] fix autoset attn impl, update data readme

---
 data/README.md                |  4 +++-
 data/README_zh.md             |  4 +++-
 src/llmtuner/model/patcher.py | 26 ++++++++++++++------------
 3 files changed, 20 insertions(+), 14 deletions(-)

diff --git a/data/README.md b/data/README.md
index f2fd7bb1..3d950e1b 100644
--- a/data/README.md
+++ b/data/README.md
@@ -115,7 +115,9 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:
   },
   "tags": {
     "role_tag": "from",
-    "content_tag": "value"
+    "content_tag": "value",
+    "user_tag": "human",
+    "assistant_tag": "gpt"
   }
 }
 ```
diff --git a/data/README_zh.md b/data/README_zh.md
index 8c46e2ae..436bc49c 100644
--- a/data/README_zh.md
+++ b/data/README_zh.md
@@ -115,7 +115,9 @@
   },
   "tags": {
     "role_tag": "from",
-    "content_tag": "value"
+    "content_tag": "value",
+    "user_tag": "human",
+    "assistant_tag": "gpt"
   }
 }
 ```
diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 477d267e..ac0cc08c 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -101,6 +101,18 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
     return samples
 
 
+def _configure_attn_implementation(model_args: "ModelArguments", config_kwargs: Dict[str, Any]) -> None:
+    if model_args.flash_attn:
+        if is_flash_attn2_available():
+            config_kwargs["attn_implementation"] = "flash_attention_2"
+            logger.info("Using FlashAttention-2 for faster training and inference.")
+        else:
+            logger.warning("FlashAttention2 is not installed.")
+            config_kwargs["attn_implementation"] = None
+    else:
+        config_kwargs["attn_implementation"] = "eager"
+
+
 def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
     if not hasattr(config, "rope_scaling"):
         logger.warning("Current model does not support RoPE scaling.")
@@ -128,15 +140,6 @@ def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is
         )
 
 
-def _configure_flashattn(config_kwargs: Dict[str, Any]) -> None:
-    if not is_flash_attn2_available():
-        logger.warning("FlashAttention2 is not installed.")
-        return
-
-    config_kwargs["use_flash_attention_2"] = True
-    logger.info("Using FlashAttention-2 for faster training and inference.")
-
-
 def _configure_longlora(config: "PretrainedConfig") -> None:
     if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
         setattr(config, "group_size_ratio", 0.25)
@@ -257,12 +260,11 @@ def patch_config(
     for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
         setattr(config, dtype_name, model_args.compute_dtype == dtype)
 
+    _configure_attn_implementation(model_args, config_kwargs)
+
     if model_args.rope_scaling is not None:
         _configure_rope(config, model_args, is_trainable)
 
-    if model_args.flash_attn:
-        _configure_flashattn(config_kwargs)
-
     if is_trainable and model_args.shift_attn:
         _configure_longlora(config)
 
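In PATCH 1/2, `_configure_attn_implementation` supersedes the removed `_configure_flashattn` helper: instead of setting the older `use_flash_attention_2` config flag, it selects the backend through the `attn_implementation` keyword that recent transformers releases accept, and falls back to the `"eager"` kernel when FlashAttention is not requested. The README change documents the `user_tag`/`assistant_tag` fields for sharegpt-formatted datasets. Below is a minimal sketch of how such a kwarg is consumed by transformers; the checkpoint name is illustrative, transformers >= 4.36 is assumed, and this is not the project's loading code.

```python
from transformers import AutoModelForCausalLM

# Sketch of the decision made by _configure_attn_implementation (assumptions noted inline).
config_kwargs = {}
flash_attn_requested = True  # stands in for model_args.flash_attn

if flash_attn_requested:
    # The patch gates this on is_flash_attn2_available(); assume flash-attn 2 is installed here.
    config_kwargs["attn_implementation"] = "flash_attention_2"
else:
    config_kwargs["attn_implementation"] = "eager"

# transformers >= 4.36 accepts attn_implementation directly in from_pretrained.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # illustrative checkpoint, not prescribed by the patch
    **config_kwargs,
)
```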
From 39bd5bd52404735a2b7dab7c59fa9faa2c9017eb Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Wed, 31 Jan 2024 17:23:56 +0800
Subject: [PATCH 2/2] fix #2388

---
 src/llmtuner/api/app.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py
index 428d15de..26ee57ce 100644
--- a/src/llmtuner/api/app.py
+++ b/src/llmtuner/api/app.py
@@ -108,7 +108,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         tool_list = request.tools
         if len(tool_list):
             try:
-                tools = json.dumps([tool_list[0]["function"]], ensure_ascii=False)
+                tools = json.dumps([tool["function"] for tool in tool_list], ensure_ascii=False)
             except Exception:
                 raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
         else:
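In PATCH 2/2 (fix #2388), the one-line change replaces `[tool_list[0]["function"]]`, which serialized only the first tool of an OpenAI-style request, with a comprehension over the whole list, so every tool definition is forwarded. A self-contained sketch of the before/after behavior, using made-up tool definitions for illustration:

```python
import json

# Hypothetical request.tools payload (OpenAI-style function tools), for illustration only.
tool_list = [
    {"type": "function", "function": {"name": "get_weather", "parameters": {"type": "object", "properties": {}}}},
    {"type": "function", "function": {"name": "get_time", "parameters": {"type": "object", "properties": {}}}},
]

# Before the fix: only the first tool survived serialization.
before = json.dumps([tool_list[0]["function"]], ensure_ascii=False)

# After the fix: every tool's "function" entry is kept.
after = json.dumps([tool["function"] for tool in tool_list], ensure_ascii=False)

print(before)  # only get_weather; get_time is silently dropped
print(after)   # both get_weather and get_time are serialized
```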