python-werkzeug/tests/conftest.py

132 lines
3.8 KiB
Python

import http.client
import json
import os
import socket
import ssl
import sys
import time
from pathlib import Path

import ephemeral_port_reserve
import pytest
from xprocess import ProcessStarter

from werkzeug.utils import cached_property
# Script executed (as ``python run.py NAME KWARGS_JSON``) by the dev_server
# fixture's subprocess; it lives next to this file under ``live_apps``.
run_path = os.fspath(Path(__file__).parent.joinpath("live_apps", "run.py"))
class UnixSocketHTTPConnection(http.client.HTTPConnection):
    """An :class:`http.client.HTTPConnection` over a Unix domain socket.

    The ``host`` given to the constructor is the filesystem path of the
    socket instead of a network host name. NOTE(review): ``HTTPConnection``
    still parses ``host`` for a ``:port`` suffix, so a socket path that
    contains ``:`` would be mangled.
    """

    def connect(self):
        """Connect to the Unix socket at ``self.host``.

        Raises :exc:`FileNotFoundError` if no socket exists at that path
        yet (e.g. the server hasn't started). On failure the socket is
        closed and ``self.sock`` is left unchanged.
        """
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)

        try:
            # Honor an explicit timeout like HTTPConnection does. The
            # constructor's default is a non-numeric sentinel, in which case
            # the new socket keeps socket.getdefaulttimeout().
            if self.timeout is None or isinstance(self.timeout, (int, float)):
                sock.settimeout(self.timeout)

            sock.connect(self.host)
        except BaseException:
            # Don't leak a half-open socket when connecting fails.
            sock.close()
            raise

        self.sock = sock
class DevServerClient:
    """Client for a dev server started by the ``dev_server`` fixture.

    Knows the address the server listens on, opens connections and sends
    requests to it, and tails the server's log file so tests can wait for
    specific log output.

    :param kwargs: The options that will be passed to the dev server. For a
        TCP host without a ``port``, a free port is reserved and written back
        into ``kwargs`` so the server is started on that same port.
    """

    def __init__(self, kwargs):
        host = kwargs.get("hostname", "127.0.0.1")

        if host.startswith("unix://"):
            # Unix domain socket: connect to the filesystem path after the
            # scheme; the URL is the hostname as given.
            self.addr = host[len("unix://") :]
            self.url = host
        else:
            port = kwargs.get("port")

            if port is None:
                # Mutate the caller's dict on purpose so the server and this
                # client agree on the port.
                kwargs["port"] = port = ephemeral_port_reserve.reserve(host)

            scheme = "https" if "ssl_context" in kwargs else "http"
            self.addr = f"{host}:{port}"
            self.url = f"{scheme}://{self.addr}"

        self.log = None

    def tail_log(self, path):
        """Open the log file at *path* and skip past its current content.

        Afterwards :meth:`wait_for_log` only sees lines written after this
        call. Any previously tailed log file is closed first.
        """
        if self.log is not None:
            self.log.close()

        self.log = open(path)
        # Consume everything already written so only new output is matched.
        self.log.read()

    def connect(self, **kwargs):
        """Create (but don't send on) a connection to the server.

        The connection class is chosen from the URL scheme. For ``https``,
        unless a ``context`` is passed, an SSL context with hostname checking
        and certificate verification disabled is used. Other keyword
        arguments are passed to the connection class.
        """
        protocol = self.url.partition(":")[0]

        if protocol == "https":
            if "context" not in kwargs:
                context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
                context.check_hostname = False
                context.verify_mode = ssl.CERT_NONE
                kwargs["context"] = context

            return http.client.HTTPSConnection(self.addr, **kwargs)

        if protocol == "unix":
            return UnixSocketHTTPConnection(self.addr, **kwargs)

        return http.client.HTTPConnection(self.addr, **kwargs)

    def request(self, path="", **kwargs):
        """Send a request to the server and return the closed response.

        ``method`` defaults to ``"GET"`` and ``url`` defaults to *path*; all
        keyword arguments are passed to ``HTTPConnection.request``. The body
        is stored on ``response.data``. ``response.json`` is the decoded body
        when the ``Content-Type`` starts with ``application/json``, otherwise
        ``None``.
        """
        kwargs.setdefault("method", "GET")
        kwargs.setdefault("url", path)
        conn = self.connect()

        try:
            conn.request(**kwargs)

            with conn.getresponse() as response:
                response.data = response.read()
        finally:
            # Close the connection even if sending or reading fails, so a
            # failed request doesn't leak the socket.
            conn.close()

        if response.headers.get("Content-Type", "").startswith("application/json"):
            response.json = json.loads(response.data)
        else:
            response.json = None

        return response

    def wait_for_log(self, start, *, timeout=None):
        """Block until a new log line starting with *start* is read.

        :param start: Prefix the log line must start with.
        :param timeout: Maximum seconds to wait. ``None`` waits forever.
        :raises TimeoutError: If *timeout* elapses without a match.
        """
        deadline = None if timeout is None else time.monotonic() + timeout

        while True:
            # Iterating a file again after EOF picks up newly appended lines.
            for line in self.log:
                if line.startswith(start):
                    return

            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(
                    f"Timed out waiting for a log line starting with {start!r}."
                )

            # Nothing new yet; sleep briefly instead of busy-spinning.
            time.sleep(0.01)

    def wait_for_reload(self, *, timeout=None):
        """Block until the log shows the reloader restarting the server.

        :param timeout: Passed to :meth:`wait_for_log`.
        """
        self.wait_for_log(" * Restarting with ", timeout=timeout)
@pytest.fixture()
def dev_server(xprocess, request, tmp_path):
    """A function that will start a dev server in an external process
    and return a client for interacting with the server.

    The returned function takes the live app ``name`` (default
    ``"standard"``) and keyword options for the server. The process is
    terminated and the client's log file closed when the test finishes.
    """

    def start_dev_server(name="standard", **kwargs):
        # Must run before json.dumps(kwargs) below: it may insert a reserved
        # port into kwargs so the server and client agree on the address.
        client = DevServerClient(kwargs)

        class Starter(ProcessStarter):
            args = [sys.executable, run_path, name, json.dumps(kwargs)]
            # Extend the existing env, otherwise Windows and CI fails.
            # Modules will be imported from tmp_path for the reloader.
            # Unbuffered output so the logs update immediately.
            env = {**os.environ, "PYTHONPATH": str(tmp_path), "PYTHONUNBUFFERED": "1"}

            @cached_property
            def pattern(self):
                # Accessing the pattern first sends a request to /ensure, so
                # the returned pattern matches the log line for that request.
                client.request("/ensure")
                return "GET /ensure"

        # Each test that uses the fixture will have a different log.
        xp_name = f"dev_server-{request.node.name}"
        _, log_path = xprocess.ensure(xp_name, Starter, restart=True)

        # Register cleanup as soon as the process exists, so a failure in
        # tail_log below doesn't leave the server running.
        @request.addfinalizer
        def close():
            try:
                xprocess.getinfo(xp_name).terminate()
            finally:
                # Close the log even if terminating the process fails.
                if client.log is not None:
                    client.log.close()

        client.tail_log(log_path)
        return client

    return start_dev_server
@pytest.fixture()
def standard_app(dev_server):
    """Start the ``"standard"`` live app and return its client.

    Shorthand for calling ``dev_server("standard")`` in a test.
    """
    app_name = "standard"
    return dev_server(app_name)