jycache-model/hffs/client/http_client.py

188 lines
6.1 KiB
Python

import asyncio
import time
import os
import aiohttp
import aiohttp.client_exceptions
import logging
from ..common.peer import Peer
from huggingface_hub import hf_hub_url, get_hf_file_metadata
from ..common.settings import load_local_service_port, HFFS_API_PING, HFFS_API_PEER_CHANGE, HFFS_API_ALIVE_PEERS
from ..common.settings import HFFS_API_STATUS, HFFS_API_STOP
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
# Loopback address used to reach the local hffs service.
LOCAL_HOST: str = "127.0.0.1"
def timeout_sess(timeout=60):
    """Create an aiohttp client session with a total request timeout.

    Args:
        timeout: value passed as ``total`` to ``aiohttp.ClientTimeout``,
            in seconds (default 60).

    Returns:
        A new ``aiohttp.ClientSession``. Callers in this module enter it with
        ``async with`` so the session is closed afterwards.
    """
    client_timeout = aiohttp.ClientTimeout(total=timeout)
    return aiohttp.ClientSession(timeout=client_timeout)
async def ping(peer, timeout=15):
    """Probe ``peer`` over HTTP and record whether it answered.

    Sends a GET to the peer's ping endpoint with a random hex ``seq`` query
    parameter. The peer counts as alive only on an HTTP 200 reply. Timeouts
    are ignored silently; any other exception is logged as a warning. In both
    failure cases the peer is marked dead.

    Args:
        peer: object exposing ``ip``, ``port``, ``set_alive`` and ``set_epoch``.
        timeout: total request timeout in seconds.

    Returns:
        The same ``peer`` object. It is updated in place with its alive flag
        and with the current Unix time (whole seconds) as its epoch.
    """
    seq = os.urandom(4).hex()
    address = f"{peer.ip}:{peer.port}"
    url = f"http://{address}{HFFS_API_PING}?seq={seq}"
    alive = False
    logger.debug(f"probing {address}, seq = {seq}")
    try:
        async with timeout_sess(timeout) as session, session.get(url) as response:
            alive = response.status == 200
    except TimeoutError:
        pass
    except Exception as exc:
        logger.warning(exc)

    peer.set_alive(alive)
    peer.set_epoch(int(time.time()))
    logger.debug(f"Peer {address} (seq:{seq}) is {'alive' if alive else 'dead'}")
    return peer
async def alive_peers(timeout=2):
    """Ask the local hffs service which peers it currently reports as alive.

    Args:
        timeout: total request timeout in seconds.

    Returns:
        list: ``Peer`` objects built with ``Peer.from_dict`` from the JSON list
        returned by the service. The list is empty when:
        - the reply is not HTTP 200 (logged as an error);
        - the connection fails or times out (a hint is logged);
        - any other exception occurs (logged as a warning).

    Note:
        ``load_local_service_port`` is called before the guarded request, so
        anything it raises propagates to the caller.
    """
    endpoint = f"http://{LOCAL_HOST}:{load_local_service_port()}" + HFFS_API_ALIVE_PEERS
    start_hint = "execute hffs daemon start to see which peers are active"
    peers = []
    try:
        async with timeout_sess(timeout) as session:
            async with session.get(endpoint) as response:
                if response.status != 200:
                    logger.error(f"Failed to get alive peers, HTTP status: {response.status}")
                else:
                    payload = await response.json()
                    peers = [Peer.from_dict(entry) for entry in payload]
    except aiohttp.client_exceptions.ClientConnectionError:
        logger.warning("Prompt: connect local service failed, may not start, " + start_hint)
    except TimeoutError:
        logger.error("Prompt: connect local service timeout, may not start, " + start_hint)
    except Exception as exc:
        logger.warning(exc)
        logger.warning("Connect service error, please check it, usually caused by service not start!")
    return peers
async def search_coro(peer, repo_id, revision, file_name):
    """Check whether a peer serves a given file of a model repository.

    Sends a HEAD request to
    ``http://{peer.ip}:{peer.port}/{repo_id}/resolve/{revision}/{file_name}``
    with a 10 second total timeout.

    Args:
        peer: object exposing ``ip`` and ``port``.
        repo_id: repository id, inserted verbatim into the URL path.
        revision: revision, inserted verbatim into the URL path.
        file_name: file name, inserted verbatim into the URL path.

    Returns:
        Peer or None: ``peer`` if it answered HTTP 200. Otherwise None, which
        covers a non-200 status, a timeout, a connection failure or any other
        error.
    """
    try:
        url = f"http://{peer.ip}:{peer.port}/{repo_id}/resolve/{revision}/{file_name}"
        async with timeout_sess(10) as session:
            async with session.head(url) as response:
                found = response.status == 200
    except Exception as exc:
        # Best effort: a failing peer is treated as "doesn't have the file".
        # Record why at debug level instead of dropping the error silently.
        logger.debug(f"Search on peer {peer.ip}:{peer.port} failed: {exc!r}")
        return None
    return peer if found else None
async def do_search(peers, repo_id, revision, file_name):
    """Concurrently ask every peer for a file and return those that have it.

    Runs one ``search_coro`` task per peer inside an ``asyncio.TaskGroup``.
    While the tasks run, a progress dot is printed to stdout about once per
    second, then a newline at the end.

    Args:
        peers: iterable of peers to query.
        repo_id: repository id forwarded to ``search_coro``.
        revision: revision forwarded to ``search_coro``.
        file_name: file name forwarded to ``search_coro``.

    Returns:
        list: the non-None results of ``search_coro``, in the order of ``peers``.
    """
    async with asyncio.TaskGroup() as group:
        tasks = [
            group.create_task(search_coro(peer, repo_id, revision, file_name))
            for peer in peers
        ]
        pending = set(tasks)
        while pending:
            # Wake up at least once per second for the progress dot, but stop
            # as soon as every task is done instead of sleeping a full second.
            _, pending = await asyncio.wait(pending, timeout=1)
            # flush=True: without a newline, buffered stdout would hide the dots.
            print(".", end="", flush=True)
    # add new line after the dots
    print("")
    results = [task.result() for task in tasks]
    return [result for result in results if result is not None]
async def search_model(peers, repo_id, file_name, revision):
    """Search ``peers`` for ``file_name`` of ``repo_id`` at ``revision``.

    Logs the output of ``Peer.print_peers`` for the peers being checked and
    for the peers that were found to have the file.

    Returns:
        list: the peers returned by ``do_search``. If ``peers`` is empty, an
        empty list is returned without searching.
    """
    if peers:
        logger.info("Will check the following peers:")
        logger.info(Peer.print_peers(peers))
        found = await do_search(peers, repo_id, revision=revision, file_name=file_name)
        logger.info("Peers who have the model:")
        logger.info(Peer.print_peers(found))
        return found
    logger.info("No active peers to search")
    return []
async def get_model_etag(endpoint, repo_id, filename, revision='main'):
    """Fetch the ETag of a file in a Hugging Face Hub style repository.

    Args:
        endpoint: Hub endpoint base URL passed to ``hf_hub_url``.
        repo_id: repository id.
        filename: path of the file inside the repository.
        revision: branch, tag or commit (default ``'main'``).

    Returns:
        The ``etag`` attribute of the metadata returned by ``get_hf_file_metadata``.

    Raises:
        Whatever ``get_hf_file_metadata`` raises (e.g. HTTP/network errors).
    """
    url = hf_hub_url(
        repo_id=repo_id,
        filename=filename,
        revision=revision,
        endpoint=endpoint
    )
    # get_hf_file_metadata is a synchronous network call. Run it in a worker
    # thread so it does not block the event loop and every other coroutine.
    metadata = await asyncio.to_thread(get_hf_file_metadata, url)
    return metadata.etag
async def notify_peer_change(timeout=2):
    """Tell the local hffs service that the peer list has changed (best effort).

    The function returns quietly in three cases:
    - ``load_local_service_port`` raises ``LookupError`` (presumably no local
      service port is recorded);
    - the request times out;
    - the service answers with a non-200 status, which is only logged at
      debug level.

    Connection failures and other ``Exception`` subclasses are logged as
    errors and never re-raised.

    Args:
        timeout: total request timeout in seconds.
    """
    try:
        port = load_local_service_port()
    except LookupError:
        return

    url = f"http://{LOCAL_HOST}:{port}{HFFS_API_PEER_CHANGE}"
    try:
        async with timeout_sess(timeout) as session, session.get(url) as response:
            status = response.status
            if status != 200:
                logger.debug(f"Peer change http status: {status}")
    except TimeoutError:
        return  # a slow service is deliberately not reported
    except aiohttp.client_exceptions.ClientConnectionError:
        logger.error("Connect local service failed, please check service!")
    except Exception as err:
        logger.error(f"Peer change error: {err}")
        logger.error("Please check the error, usually caused by local service not start!")
async def get_service_status(timeout=5):
    """Fetch the status document of the local hffs service.

    Args:
        timeout: total request timeout in seconds (default 5, as before).

    Returns:
        The decoded JSON body of the service's status endpoint.

    Raises:
        ValueError: the service answered with a status other than 200.
        ConnectionError: the service could not be reached or the request timed
            out. The underlying exception is chained as ``__cause__``.
        Anything ``load_local_service_port`` raises propagates unchanged.
    """
    port = load_local_service_port()
    url = f"http://{LOCAL_HOST}:{port}" + HFFS_API_STATUS
    try:
        async with timeout_sess(timeout) as session:
            async with session.get(url) as response:
                if response.status != 200:
                    raise ValueError(f"Server response not 200 OK! status: {response.status}")
                return await response.json()
    except (TimeoutError, ConnectionError, aiohttp.client_exceptions.ClientConnectionError) as err:
        # Chain explicitly so the real network error stays visible to callers.
        raise ConnectionError("Connect server failed or timeout") from err
async def post_stop_service(timeout=5):
    """Ask the local hffs service to stop by POSTing to its stop endpoint.

    Args:
        timeout: total request timeout in seconds (default 5, as before).

    Raises:
        ValueError: the service answered with a status other than 200.
        ConnectionError: the service could not be reached or the request timed
            out. The underlying exception is chained as ``__cause__``.
        Anything ``load_local_service_port`` raises propagates unchanged.
    """
    port = load_local_service_port()
    url = f"http://{LOCAL_HOST}:{port}" + HFFS_API_STOP
    try:
        async with timeout_sess(timeout) as session:
            async with session.post(url) as response:
                if response.status != 200:
                    raise ValueError(f"Server response not 200 OK! status: {response.status}")
    except (TimeoutError, ConnectionError, aiohttp.client_exceptions.ClientConnectionError) as err:
        # Chain explicitly so the real network error stays visible to callers.
        raise ConnectionError("Connect server failed or timeout") from err