Merge branch 'hiyouga:main' into main

This commit is contained in:
Johann-Peter Hartmann 2024-01-31 14:05:52 +01:00 committed by GitHub
commit 4e27950acb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 21 additions and 15 deletions

View File

@ -115,7 +115,9 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:
}, },
"tags": { "tags": {
"role_tag": "from", "role_tag": "from",
"content_tag": "value" "content_tag": "value",
"user_tag": "human",
"assistant_tag": "gpt"
} }
} }
``` ```

View File

@ -115,7 +115,9 @@
}, },
"tags": { "tags": {
"role_tag": "from", "role_tag": "from",
"content_tag": "value" "content_tag": "value",
"user_tag": "human",
"assistant_tag": "gpt"
} }
} }
``` ```

View File

@ -108,7 +108,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
tool_list = request.tools tool_list = request.tools
if len(tool_list): if len(tool_list):
try: try:
tools = json.dumps([tool_list[0]["function"]], ensure_ascii=False) tools = json.dumps([tool["function"] for tool in tool_list], ensure_ascii=False)
except Exception: except Exception:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
else: else:

View File

@ -101,6 +101,18 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
return samples return samples
def _configure_attn_implementation(model_args: "ModelArguments", config_kwargs: Dict[str, Any]) -> None:
if model_args.flash_attn:
if is_flash_attn2_available():
config_kwargs["attn_implementation"] = "flash_attention_2"
logger.info("Using FlashAttention-2 for faster training and inference.")
else:
logger.warning("FlashAttention2 is not installed.")
config_kwargs["attn_implementation"] = None
else:
config_kwargs["attn_implementation"] = "eager"
def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not hasattr(config, "rope_scaling"): if not hasattr(config, "rope_scaling"):
logger.warning("Current model does not support RoPE scaling.") logger.warning("Current model does not support RoPE scaling.")
@ -128,15 +140,6 @@ def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is
) )
def _configure_flashattn(config_kwargs: Dict[str, Any]) -> None:
if not is_flash_attn2_available():
logger.warning("FlashAttention2 is not installed.")
return
config_kwargs["use_flash_attention_2"] = True
logger.info("Using FlashAttention-2 for faster training and inference.")
def _configure_longlora(config: "PretrainedConfig") -> None: def _configure_longlora(config: "PretrainedConfig") -> None:
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN: if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
setattr(config, "group_size_ratio", 0.25) setattr(config, "group_size_ratio", 0.25)
@ -257,12 +260,11 @@ def patch_config(
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]: for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
setattr(config, dtype_name, model_args.compute_dtype == dtype) setattr(config, dtype_name, model_args.compute_dtype == dtype)
_configure_attn_implementation(model_args, config_kwargs)
if model_args.rope_scaling is not None: if model_args.rope_scaling is not None:
_configure_rope(config, model_args, is_trainable) _configure_rope(config, model_args, is_trainable)
if model_args.flash_attn:
_configure_flashattn(config_kwargs)
if is_trainable and model_args.shift_attn: if is_trainable and model_args.shift_attn:
_configure_longlora(config) _configure_longlora(config)